gitextract_qha1vuxj/ ├── .devcontainer/ │ ├── README.md │ ├── devcontainer.json │ └── setup.sh ├── .gemini/ │ ├── config.yaml │ └── styleguide.md ├── .github/ │ ├── dependabot.yml │ └── workflows/ │ ├── actions.yml │ ├── auto-assignment.yaml │ ├── config/ │ │ ├── jax/ │ │ │ └── keras.json │ │ ├── numpy/ │ │ │ └── keras.json │ │ ├── openvino/ │ │ │ └── keras.json │ │ ├── tensorflow/ │ │ │ └── keras.json │ │ └── torch/ │ │ └── keras.json │ ├── gpu_tests.yml │ ├── labeler.yaml │ ├── nightly.yml │ ├── scorecard.yml │ ├── scripts/ │ │ ├── auto-assignment.js │ │ └── labeler.js │ ├── stale-issue-pr.yaml │ └── tpu_tests.yml ├── .gitignore ├── .kokoro/ │ ├── README.md │ └── github/ │ └── ubuntu/ │ └── gpu/ │ ├── build.sh │ ├── jax/ │ │ ├── continuous.cfg │ │ └── presubmit.cfg │ └── tensorflow/ │ ├── continuous.cfg │ └── presubmit.cfg ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── SECURITY.md ├── api_gen.py ├── benchmarks/ │ ├── __init__.py │ ├── layer_benchmark/ │ │ ├── README.md │ │ ├── __init__.py │ │ ├── activation_benchmark.py │ │ ├── attention_benchmark.py │ │ ├── base_benchmark.py │ │ ├── conv_benchmark.py │ │ ├── core_benchmark.py │ │ ├── merge_benchmark.py │ │ ├── normalization_benchmark.py │ │ ├── pooling_benchmark.py │ │ ├── random_rotation_benchmark.py │ │ ├── regularization_benchmark.py │ │ ├── reshaping_benchmark.py │ │ └── rnn_benchmark.py │ ├── model_benchmark/ │ │ ├── __init__.py │ │ ├── benchmark_utils.py │ │ ├── bert_benchmark.py │ │ └── image_classification_benchmark.py │ └── torch_ctl_benchmark/ │ ├── README.md │ ├── __init__.py │ ├── benchmark_utils.py │ ├── conv_model_benchmark.py │ └── dense_model_benchmark.py ├── codecov.yml ├── conftest.py ├── examples/ │ ├── demo_custom_jax_workflow.py │ ├── demo_custom_layer_backend_agnostic.py │ ├── demo_custom_tf_workflow.py │ ├── demo_custom_torch_workflow.py │ ├── demo_functional.py │ ├── demo_jax_distributed.py │ ├── demo_mnist_convnet.py │ ├── demo_subclass.py │ └── demo_torch_multi_gpu.py ├── guides/ │ ├── custom_train_step_in_jax.py │ ├── custom_train_step_in_tensorflow.py │ ├── custom_train_step_in_torch.py │ ├── distributed_training_with_jax.py │ ├── distributed_training_with_tensorflow.py │ ├── distributed_training_with_torch.py │ ├── functional_api.py │ ├── making_new_layers_and_models_via_subclassing.py │ ├── sequential_model.py │ ├── training_with_built_in_methods.py │ ├── transfer_learning.py │ ├── understanding_masking_and_padding.py │ ├── writing_a_custom_training_loop_in_jax.py │ ├── writing_a_custom_training_loop_in_tensorflow.py │ ├── writing_a_custom_training_loop_in_torch.py │ └── writing_your_own_callbacks.py ├── integration_tests/ │ ├── basic_full_flow.py │ ├── dataset_tests/ │ │ ├── boston_housing_test.py │ │ ├── california_housing_test.py │ │ ├── cifar100_test.py │ │ ├── cifar10_test.py │ │ ├── fashion_mnist_test.py │ │ ├── imdb_test.py │ │ ├── mnist_test.py │ │ └── reuters_test.py │ ├── import_test.py │ ├── jax_custom_fit_test.py │ ├── model_visualization_test.py │ ├── numerical_test.py │ ├── pytorch_export_test.py │ ├── tf_custom_fit_test.py │ ├── tf_distribute_training_test.py │ ├── torch_custom_fit_test.py │ └── torch_workflow_test.py ├── keras/ │ ├── __init__.py │ ├── api/ │ │ ├── __init__.py │ │ ├── _tf_keras/ │ │ │ ├── __init__.py │ │ │ └── keras/ │ │ │ ├── __init__.py │ │ │ ├── activations/ │ │ │ │ └── __init__.py │ │ │ ├── applications/ │ │ │ │ ├── __init__.py │ │ │ │ ├── convnext/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── densenet/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── efficientnet/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── efficientnet_v2/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── imagenet_utils/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── inception_resnet_v2/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── inception_v3/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── mobilenet/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── mobilenet_v2/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── mobilenet_v3/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── nasnet/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── resnet/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── resnet50/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── resnet_v2/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── vgg16/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── vgg19/ │ │ │ │ │ └── __init__.py │ │ │ │ └── xception/ │ │ │ │ └── __init__.py │ │ │ ├── backend/ │ │ │ │ └── __init__.py │ │ │ ├── callbacks/ │ │ │ │ └── __init__.py │ │ │ ├── config/ │ │ │ │ └── __init__.py │ │ │ ├── constraints/ │ │ │ │ └── __init__.py │ │ │ ├── datasets/ │ │ │ │ ├── __init__.py │ │ │ │ ├── boston_housing/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── california_housing/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── cifar10/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── cifar100/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── fashion_mnist/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── imdb/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── mnist/ │ │ │ │ │ └── __init__.py │ │ │ │ └── reuters/ │ │ │ │ └── __init__.py │ │ │ ├── distillation/ │ │ │ │ └── __init__.py │ │ │ ├── distribution/ │ │ │ │ └── __init__.py │ │ │ ├── dtype_policies/ │ │ │ │ └── __init__.py │ │ │ ├── export/ │ │ │ │ └── __init__.py │ │ │ ├── initializers/ │ │ │ │ └── __init__.py │ │ │ ├── layers/ │ │ │ │ └── __init__.py │ │ │ ├── legacy/ │ │ │ │ ├── __init__.py │ │ │ │ └── saving/ │ │ │ │ └── __init__.py │ │ │ ├── losses/ │ │ │ │ └── __init__.py │ │ │ ├── metrics/ │ │ │ │ └── __init__.py │ │ │ ├── mixed_precision/ │ │ │ │ └── __init__.py │ │ │ ├── models/ │ │ │ │ └── __init__.py │ │ │ ├── ops/ │ │ │ │ ├── __init__.py │ │ │ │ ├── image/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── linalg/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── nn/ │ │ │ │ │ └── __init__.py │ │ │ │ └── numpy/ │ │ │ │ └── __init__.py │ │ │ ├── optimizers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── legacy/ │ │ │ │ │ └── __init__.py │ │ │ │ └── schedules/ │ │ │ │ └── __init__.py │ │ │ ├── preprocessing/ │ │ │ │ ├── __init__.py │ │ │ │ ├── image/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── sequence/ │ │ │ │ │ └── __init__.py │ │ │ │ └── text/ │ │ │ │ └── __init__.py │ │ │ ├── quantizers/ │ │ │ │ └── __init__.py │ │ │ ├── random/ │ │ │ │ └── __init__.py │ │ │ ├── regularizers/ │ │ │ │ └── __init__.py │ │ │ ├── saving/ │ │ │ │ └── __init__.py │ │ │ ├── tree/ │ │ │ │ └── __init__.py │ │ │ ├── utils/ │ │ │ │ ├── __init__.py │ │ │ │ ├── bounding_boxes/ │ │ │ │ │ └── __init__.py │ │ │ │ └── legacy/ │ │ │ │ └── __init__.py │ │ │ ├── visualization/ │ │ │ │ └── __init__.py │ │ │ └── wrappers/ │ │ │ └── __init__.py │ │ ├── activations/ │ │ │ └── __init__.py │ │ ├── applications/ │ │ │ ├── __init__.py │ │ │ ├── convnext/ │ │ │ │ └── __init__.py │ │ │ ├── densenet/ │ │ │ │ └── __init__.py │ │ │ ├── efficientnet/ │ │ │ │ └── __init__.py │ │ │ ├── efficientnet_v2/ │ │ │ │ └── __init__.py │ │ │ ├── imagenet_utils/ │ │ │ │ └── __init__.py │ │ │ ├── inception_resnet_v2/ │ │ │ │ └── __init__.py │ │ │ ├── inception_v3/ │ │ │ │ └── __init__.py │ │ │ ├── mobilenet/ │ │ │ │ └── __init__.py │ │ │ ├── mobilenet_v2/ │ │ │ │ └── __init__.py │ │ │ ├── mobilenet_v3/ │ │ │ │ └── __init__.py │ │ │ ├── nasnet/ │ │ │ │ └── __init__.py │ │ │ ├── resnet/ │ │ │ │ └── __init__.py │ │ │ ├── resnet50/ │ │ │ │ └── __init__.py │ │ │ ├── resnet_v2/ │ │ │ │ └── __init__.py │ │ │ ├── vgg16/ │ │ │ │ └── __init__.py │ │ │ ├── vgg19/ │ │ │ │ └── __init__.py │ │ │ └── xception/ │ │ │ └── __init__.py │ │ ├── backend/ │ │ │ └── __init__.py │ │ ├── callbacks/ │ │ │ └── __init__.py │ │ ├── config/ │ │ │ └── __init__.py │ │ ├── constraints/ │ │ │ └── __init__.py │ │ ├── datasets/ │ │ │ ├── __init__.py │ │ │ ├── boston_housing/ │ │ │ │ └── __init__.py │ │ │ ├── california_housing/ │ │ │ │ └── __init__.py │ │ │ ├── cifar10/ │ │ │ │ └── __init__.py │ │ │ ├── cifar100/ │ │ │ │ └── __init__.py │ │ │ ├── fashion_mnist/ │ │ │ │ └── __init__.py │ │ │ ├── imdb/ │ │ │ │ └── __init__.py │ │ │ ├── mnist/ │ │ │ │ └── __init__.py │ │ │ └── reuters/ │ │ │ └── __init__.py │ │ ├── distillation/ │ │ │ └── __init__.py │ │ ├── distribution/ │ │ │ └── __init__.py │ │ ├── dtype_policies/ │ │ │ └── __init__.py │ │ ├── export/ │ │ │ └── __init__.py │ │ ├── initializers/ │ │ │ └── __init__.py │ │ ├── layers/ │ │ │ └── __init__.py │ │ ├── legacy/ │ │ │ ├── __init__.py │ │ │ └── saving/ │ │ │ └── __init__.py │ │ ├── losses/ │ │ │ └── __init__.py │ │ ├── metrics/ │ │ │ └── __init__.py │ │ ├── mixed_precision/ │ │ │ └── __init__.py │ │ ├── models/ │ │ │ └── __init__.py │ │ ├── ops/ │ │ │ ├── __init__.py │ │ │ ├── image/ │ │ │ │ └── __init__.py │ │ │ ├── linalg/ │ │ │ │ └── __init__.py │ │ │ ├── nn/ │ │ │ │ └── __init__.py │ │ │ └── numpy/ │ │ │ └── __init__.py │ │ ├── optimizers/ │ │ │ ├── __init__.py │ │ │ ├── legacy/ │ │ │ │ └── __init__.py │ │ │ └── schedules/ │ │ │ └── __init__.py │ │ ├── preprocessing/ │ │ │ ├── __init__.py │ │ │ ├── image/ │ │ │ │ └── __init__.py │ │ │ └── sequence/ │ │ │ └── __init__.py │ │ ├── quantizers/ │ │ │ └── __init__.py │ │ ├── random/ │ │ │ └── __init__.py │ │ ├── regularizers/ │ │ │ └── __init__.py │ │ ├── saving/ │ │ │ └── __init__.py │ │ ├── tree/ │ │ │ └── __init__.py │ │ ├── utils/ │ │ │ ├── __init__.py │ │ │ ├── bounding_boxes/ │ │ │ │ └── __init__.py │ │ │ └── legacy/ │ │ │ └── __init__.py │ │ ├── visualization/ │ │ │ └── __init__.py │ │ └── wrappers/ │ │ └── __init__.py │ └── src/ │ ├── __init__.py │ ├── activations/ │ │ ├── __init__.py │ │ ├── activations.py │ │ └── activations_test.py │ ├── api_export.py │ ├── applications/ │ │ ├── __init__.py │ │ ├── applications_test.py │ │ ├── convnext.py │ │ ├── densenet.py │ │ ├── efficientnet.py │ │ ├── efficientnet_v2.py │ │ ├── imagenet_utils.py │ │ ├── imagenet_utils_test.py │ │ ├── inception_resnet_v2.py │ │ ├── inception_v3.py │ │ ├── mobilenet.py │ │ ├── mobilenet_v2.py │ │ ├── mobilenet_v3.py │ │ ├── nasnet.py │ │ ├── resnet.py │ │ ├── resnet_v2.py │ │ ├── vgg16.py │ │ ├── vgg19.py │ │ └── xception.py │ ├── backend/ │ │ ├── __init__.py │ │ ├── common/ │ │ │ ├── __init__.py │ │ │ ├── backend_utils.py │ │ │ ├── backend_utils_test.py │ │ │ ├── compute_output_spec_test.py │ │ │ ├── dtypes.py │ │ │ ├── dtypes_test.py │ │ │ ├── global_state.py │ │ │ ├── global_state_test.py │ │ │ ├── keras_tensor.py │ │ │ ├── keras_tensor_test.py │ │ │ ├── masking.py │ │ │ ├── masking_test.py │ │ │ ├── name_scope.py │ │ │ ├── name_scope_test.py │ │ │ ├── remat.py │ │ │ ├── remat_test.py │ │ │ ├── stateless_scope.py │ │ │ ├── stateless_scope_test.py │ │ │ ├── symbolic_scope.py │ │ │ ├── symbolic_scope_test.py │ │ │ ├── tensor_attributes.py │ │ │ ├── thread_safe_test.py │ │ │ ├── variables.py │ │ │ └── variables_test.py │ │ ├── config.py │ │ ├── jax/ │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── core_test.py │ │ │ ├── distribution_lib.py │ │ │ ├── distribution_lib_test.py │ │ │ ├── excluded_tpu_tests.txt │ │ │ ├── export.py │ │ │ ├── image.py │ │ │ ├── layer.py │ │ │ ├── linalg.py │ │ │ ├── math.py │ │ │ ├── nn.py │ │ │ ├── numpy.py │ │ │ ├── optimizer.py │ │ │ ├── random.py │ │ │ ├── rnn.py │ │ │ ├── sparse.py │ │ │ ├── tensorboard.py │ │ │ ├── trainer.py │ │ │ └── trainer_test.py │ │ ├── numpy/ │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── export.py │ │ │ ├── image.py │ │ │ ├── layer.py │ │ │ ├── linalg.py │ │ │ ├── math.py │ │ │ ├── nn.py │ │ │ ├── numpy.py │ │ │ ├── random.py │ │ │ ├── rnn.py │ │ │ └── trainer.py │ │ ├── openvino/ │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── excluded_concrete_tests.txt │ │ │ ├── excluded_tests.txt │ │ │ ├── export.py │ │ │ ├── image.py │ │ │ ├── layer.py │ │ │ ├── linalg.py │ │ │ ├── math.py │ │ │ ├── nn.py │ │ │ ├── numpy.py │ │ │ ├── random.py │ │ │ ├── rnn.py │ │ │ └── trainer.py │ │ ├── tensorflow/ │ │ │ ├── __init__.py │ │ │ ├── core.py │ │ │ ├── distribute_test.py │ │ │ ├── distribution_lib.py │ │ │ ├── export.py │ │ │ ├── image.py │ │ │ ├── layer.py │ │ │ ├── linalg.py │ │ │ ├── math.py │ │ │ ├── name_scope_test.py │ │ │ ├── nn.py │ │ │ ├── numpy.py │ │ │ ├── optimizer.py │ │ │ ├── optimizer_distribute_test.py │ │ │ ├── random.py │ │ │ ├── rnn.py │ │ │ ├── saved_model_test.py │ │ │ ├── sparse.py │ │ │ ├── tensorboard.py │ │ │ ├── trackable.py │ │ │ └── trainer.py │ │ ├── tests/ │ │ │ ├── compute_output_spec_test.py │ │ │ └── device_scope_test.py │ │ └── torch/ │ │ ├── __init__.py │ │ ├── core.py │ │ ├── export.py │ │ ├── image.py │ │ ├── layer.py │ │ ├── linalg.py │ │ ├── math.py │ │ ├── nn.py │ │ ├── numpy.py │ │ ├── optimizers/ │ │ │ ├── __init__.py │ │ │ ├── torch_adadelta.py │ │ │ ├── torch_adagrad.py │ │ │ ├── torch_adam.py │ │ │ ├── torch_adamax.py │ │ │ ├── torch_adamw.py │ │ │ ├── torch_lion.py │ │ │ ├── torch_nadam.py │ │ │ ├── torch_optimizer.py │ │ │ ├── torch_parallel_optimizer.py │ │ │ ├── torch_rmsprop.py │ │ │ └── torch_sgd.py │ │ ├── random.py │ │ ├── rnn.py │ │ └── trainer.py │ ├── callbacks/ │ │ ├── __init__.py │ │ ├── backup_and_restore.py │ │ ├── backup_and_restore_test.py │ │ ├── callback.py │ │ ├── callback_list.py │ │ ├── callback_test.py │ │ ├── csv_logger.py │ │ ├── csv_logger_test.py │ │ ├── early_stopping.py │ │ ├── early_stopping_test.py │ │ ├── history.py │ │ ├── lambda_callback.py │ │ ├── lambda_callback_test.py │ │ ├── learning_rate_scheduler.py │ │ ├── learning_rate_scheduler_test.py │ │ ├── model_checkpoint.py │ │ ├── model_checkpoint_test.py │ │ ├── monitor_callback.py │ │ ├── monitor_callback_test.py │ │ ├── orbax_checkpoint.py │ │ ├── orbax_checkpoint_test.py │ │ ├── progbar_logger.py │ │ ├── reduce_lr_on_plateau.py │ │ ├── reduce_lr_on_plateau_test.py │ │ ├── remote_monitor.py │ │ ├── remote_monitor_test.py │ │ ├── swap_ema_weights.py │ │ ├── swap_ema_weights_test.py │ │ ├── tensorboard.py │ │ ├── tensorboard_test.py │ │ ├── terminate_on_nan.py │ │ └── terminate_on_nan_test.py │ ├── constraints/ │ │ ├── __init__.py │ │ ├── constraints.py │ │ └── constraints_test.py │ ├── datasets/ │ │ ├── __init__.py │ │ ├── boston_housing.py │ │ ├── california_housing.py │ │ ├── cifar.py │ │ ├── cifar10.py │ │ ├── cifar100.py │ │ ├── fashion_mnist.py │ │ ├── imdb.py │ │ ├── mnist.py │ │ └── reuters.py │ ├── distillation/ │ │ ├── __init__.py │ │ ├── distillation_loss.py │ │ ├── distillation_loss_test.py │ │ ├── distiller.py │ │ └── distiller_test.py │ ├── distribution/ │ │ ├── __init__.py │ │ ├── distribution_lib.py │ │ └── distribution_lib_test.py │ ├── dtype_policies/ │ │ ├── __init__.py │ │ ├── dtype_policy.py │ │ ├── dtype_policy_map.py │ │ ├── dtype_policy_map_test.py │ │ └── dtype_policy_test.py │ ├── export/ │ │ ├── __init__.py │ │ ├── export_utils.py │ │ ├── litert.py │ │ ├── litert_test.py │ │ ├── neptune_model_export_archive.py │ │ ├── onnx.py │ │ ├── onnx_test.py │ │ ├── openvino.py │ │ ├── openvino_test.py │ │ ├── saved_model.py │ │ ├── saved_model_export_archive.py │ │ ├── saved_model_test.py │ │ ├── tf2onnx_lib.py │ │ ├── tfsm_layer.py │ │ └── tfsm_layer_test.py │ ├── initializers/ │ │ ├── __init__.py │ │ ├── constant_initializers.py │ │ ├── constant_initializers_test.py │ │ ├── initializer.py │ │ ├── random_initializers.py │ │ └── random_initializers_test.py │ ├── layers/ │ │ ├── __init__.py │ │ ├── activations/ │ │ │ ├── __init__.py │ │ │ ├── activation.py │ │ │ ├── activation_test.py │ │ │ ├── elu.py │ │ │ ├── elu_test.py │ │ │ ├── leaky_relu.py │ │ │ ├── leaky_relu_test.py │ │ │ ├── prelu.py │ │ │ ├── prelu_test.py │ │ │ ├── relu.py │ │ │ ├── relu_test.py │ │ │ ├── softmax.py │ │ │ └── softmax_test.py │ │ ├── attention/ │ │ │ ├── __init__.py │ │ │ ├── additive_attention.py │ │ │ ├── additive_attention_test.py │ │ │ ├── attention.py │ │ │ ├── attention_test.py │ │ │ ├── grouped_query_attention.py │ │ │ ├── grouped_query_attention_test.py │ │ │ ├── multi_head_attention.py │ │ │ └── multi_head_attention_test.py │ │ ├── convolutional/ │ │ │ ├── __init__.py │ │ │ ├── base_conv.py │ │ │ ├── base_conv_transpose.py │ │ │ ├── base_depthwise_conv.py │ │ │ ├── base_separable_conv.py │ │ │ ├── conv1d.py │ │ │ ├── conv1d_transpose.py │ │ │ ├── conv2d.py │ │ │ ├── conv2d_transpose.py │ │ │ ├── conv3d.py │ │ │ ├── conv3d_transpose.py │ │ │ ├── conv_test.py │ │ │ ├── conv_transpose_test.py │ │ │ ├── depthwise_conv1d.py │ │ │ ├── depthwise_conv2d.py │ │ │ ├── depthwise_conv_test.py │ │ │ ├── separable_conv1d.py │ │ │ ├── separable_conv2d.py │ │ │ └── separable_conv_test.py │ │ ├── core/ │ │ │ ├── __init__.py │ │ │ ├── dense.py │ │ │ ├── dense_test.py │ │ │ ├── einsum_dense.py │ │ │ ├── einsum_dense_test.py │ │ │ ├── embedding.py │ │ │ ├── embedding_test.py │ │ │ ├── identity.py │ │ │ ├── identity_test.py │ │ │ ├── input_layer.py │ │ │ ├── input_layer_test.py │ │ │ ├── lambda_layer.py │ │ │ ├── lambda_layer_test.py │ │ │ ├── masking.py │ │ │ ├── masking_test.py │ │ │ ├── reversible_embedding.py │ │ │ ├── reversible_embedding_test.py │ │ │ ├── wrapper.py │ │ │ └── wrapper_test.py │ │ ├── input_spec.py │ │ ├── layer.py │ │ ├── layer_test.py │ │ ├── merging/ │ │ │ ├── __init__.py │ │ │ ├── add.py │ │ │ ├── average.py │ │ │ ├── base_merge.py │ │ │ ├── concatenate.py │ │ │ ├── dot.py │ │ │ ├── maximum.py │ │ │ ├── merging_test.py │ │ │ ├── minimum.py │ │ │ ├── multiply.py │ │ │ └── subtract.py │ │ ├── normalization/ │ │ │ ├── __init__.py │ │ │ ├── batch_normalization.py │ │ │ ├── batch_normalization_test.py │ │ │ ├── group_normalization.py │ │ │ ├── group_normalization_test.py │ │ │ ├── layer_normalization.py │ │ │ ├── layer_normalization_test.py │ │ │ ├── rms_normalization.py │ │ │ ├── rms_normalization_test.py │ │ │ ├── spectral_normalization.py │ │ │ ├── spectral_normalization_test.py │ │ │ ├── unit_normalization.py │ │ │ └── unit_normalization_test.py │ │ ├── pooling/ │ │ │ ├── __init__.py │ │ │ ├── adaptive_average_pooling1d.py │ │ │ ├── adaptive_average_pooling2d.py │ │ │ ├── adaptive_average_pooling3d.py │ │ │ ├── adaptive_max_pooling1d.py │ │ │ ├── adaptive_max_pooling2d.py │ │ │ ├── adaptive_max_pooling3d.py │ │ │ ├── adaptive_pooling1d_test.py │ │ │ ├── adaptive_pooling2d_test.py │ │ │ ├── adaptive_pooling3d_test.py │ │ │ ├── average_pooling1d.py │ │ │ ├── average_pooling2d.py │ │ │ ├── average_pooling3d.py │ │ │ ├── average_pooling_test.py │ │ │ ├── base_adaptive_pooling.py │ │ │ ├── base_global_pooling.py │ │ │ ├── base_pooling.py │ │ │ ├── global_average_pooling1d.py │ │ │ ├── global_average_pooling2d.py │ │ │ ├── global_average_pooling3d.py │ │ │ ├── global_average_pooling_test.py │ │ │ ├── global_max_pooling1d.py │ │ │ ├── global_max_pooling2d.py │ │ │ ├── global_max_pooling3d.py │ │ │ ├── global_max_pooling_test.py │ │ │ ├── max_pooling1d.py │ │ │ ├── max_pooling2d.py │ │ │ ├── max_pooling3d.py │ │ │ └── max_pooling_test.py │ │ ├── preprocessing/ │ │ │ ├── __init__.py │ │ │ ├── category_encoding.py │ │ │ ├── category_encoding_test.py │ │ │ ├── data_layer.py │ │ │ ├── data_layer_test.py │ │ │ ├── discretization.py │ │ │ ├── discretization_test.py │ │ │ ├── feature_space.py │ │ │ ├── feature_space_test.py │ │ │ ├── hashed_crossing.py │ │ │ ├── hashed_crossing_test.py │ │ │ ├── hashing.py │ │ │ ├── hashing_test.py │ │ │ ├── image_preprocessing/ │ │ │ │ ├── __init__.py │ │ │ │ ├── aug_mix.py │ │ │ │ ├── aug_mix_test.py │ │ │ │ ├── auto_contrast.py │ │ │ │ ├── auto_contrast_test.py │ │ │ │ ├── base_image_preprocessing_layer.py │ │ │ │ ├── bounding_boxes/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── bounding_box.py │ │ │ │ │ ├── converters.py │ │ │ │ │ ├── converters_test.py │ │ │ │ │ ├── formats.py │ │ │ │ │ ├── iou.py │ │ │ │ │ ├── iou_test.py │ │ │ │ │ ├── validation.py │ │ │ │ │ └── validation_test.py │ │ │ │ ├── center_crop.py │ │ │ │ ├── center_crop_test.py │ │ │ │ ├── clahe.py │ │ │ │ ├── clahe_test.py │ │ │ │ ├── cut_mix.py │ │ │ │ ├── cut_mix_test.py │ │ │ │ ├── equalization.py │ │ │ │ ├── equalization_test.py │ │ │ │ ├── max_num_bounding_box.py │ │ │ │ ├── max_num_bounding_box_test.py │ │ │ │ ├── mix_up.py │ │ │ │ ├── mix_up_test.py │ │ │ │ ├── rand_augment.py │ │ │ │ ├── rand_augment_test.py │ │ │ │ ├── random_brightness.py │ │ │ │ ├── random_brightness_test.py │ │ │ │ ├── random_color_degeneration.py │ │ │ │ ├── random_color_degeneration_test.py │ │ │ │ ├── random_color_jitter.py │ │ │ │ ├── random_color_jitter_test.py │ │ │ │ ├── random_contrast.py │ │ │ │ ├── random_contrast_test.py │ │ │ │ ├── random_crop.py │ │ │ │ ├── random_crop_test.py │ │ │ │ ├── random_elastic_transform.py │ │ │ │ ├── random_elastic_transform_test.py │ │ │ │ ├── random_erasing.py │ │ │ │ ├── random_erasing_test.py │ │ │ │ ├── random_flip.py │ │ │ │ ├── random_flip_test.py │ │ │ │ ├── random_gaussian_blur.py │ │ │ │ ├── random_gaussian_blur_test.py │ │ │ │ ├── random_grayscale.py │ │ │ │ ├── random_grayscale_test.py │ │ │ │ ├── random_hue.py │ │ │ │ ├── random_hue_test.py │ │ │ │ ├── random_invert.py │ │ │ │ ├── random_invert_test.py │ │ │ │ ├── random_perspective.py │ │ │ │ ├── random_perspective_test.py │ │ │ │ ├── random_posterization.py │ │ │ │ ├── random_posterization_test.py │ │ │ │ ├── random_rotation.py │ │ │ │ ├── random_rotation_test.py │ │ │ │ ├── random_saturation.py │ │ │ │ ├── random_saturation_test.py │ │ │ │ ├── random_sharpness.py │ │ │ │ ├── random_sharpness_test.py │ │ │ │ ├── random_shear.py │ │ │ │ ├── random_shear_test.py │ │ │ │ ├── random_translation.py │ │ │ │ ├── random_translation_test.py │ │ │ │ ├── random_zoom.py │ │ │ │ ├── random_zoom_test.py │ │ │ │ ├── resizing.py │ │ │ │ ├── resizing_test.py │ │ │ │ ├── solarization.py │ │ │ │ └── solarization_test.py │ │ │ ├── index_lookup.py │ │ │ ├── index_lookup_test.py │ │ │ ├── integer_lookup.py │ │ │ ├── integer_lookup_test.py │ │ │ ├── mel_spectrogram.py │ │ │ ├── mel_spectrogram_test.py │ │ │ ├── normalization.py │ │ │ ├── normalization_test.py │ │ │ ├── pipeline.py │ │ │ ├── pipeline_test.py │ │ │ ├── rescaling.py │ │ │ ├── rescaling_test.py │ │ │ ├── stft_spectrogram.py │ │ │ ├── stft_spectrogram_test.py │ │ │ ├── string_lookup.py │ │ │ ├── string_lookup_test.py │ │ │ ├── text_vectorization.py │ │ │ └── text_vectorization_test.py │ │ ├── regularization/ │ │ │ ├── __init__.py │ │ │ ├── activity_regularization.py │ │ │ ├── activity_regularization_test.py │ │ │ ├── alpha_dropout.py │ │ │ ├── alpha_dropout_test.py │ │ │ ├── dropout.py │ │ │ ├── dropout_test.py │ │ │ ├── gaussian_dropout.py │ │ │ ├── gaussian_dropout_test.py │ │ │ ├── gaussian_noise.py │ │ │ ├── gaussian_noise_test.py │ │ │ ├── spatial_dropout.py │ │ │ └── spatial_dropout_test.py │ │ ├── reshaping/ │ │ │ ├── __init__.py │ │ │ ├── cropping1d.py │ │ │ ├── cropping1d_test.py │ │ │ ├── cropping2d.py │ │ │ ├── cropping2d_test.py │ │ │ ├── cropping3d.py │ │ │ ├── cropping3d_test.py │ │ │ ├── flatten.py │ │ │ ├── flatten_test.py │ │ │ ├── permute.py │ │ │ ├── permute_test.py │ │ │ ├── repeat_vector.py │ │ │ ├── repeat_vector_test.py │ │ │ ├── reshape.py │ │ │ ├── reshape_test.py │ │ │ ├── up_sampling1d.py │ │ │ ├── up_sampling1d_test.py │ │ │ ├── up_sampling2d.py │ │ │ ├── up_sampling2d_test.py │ │ │ ├── up_sampling3d.py │ │ │ ├── up_sampling3d_test.py │ │ │ ├── zero_padding1d.py │ │ │ ├── zero_padding1d_test.py │ │ │ ├── zero_padding2d.py │ │ │ ├── zero_padding2d_test.py │ │ │ ├── zero_padding3d.py │ │ │ └── zero_padding3d_test.py │ │ └── rnn/ │ │ ├── __init__.py │ │ ├── bidirectional.py │ │ ├── bidirectional_test.py │ │ ├── conv_lstm.py │ │ ├── conv_lstm1d.py │ │ ├── conv_lstm1d_test.py │ │ ├── conv_lstm2d.py │ │ ├── conv_lstm2d_test.py │ │ ├── conv_lstm3d.py │ │ ├── conv_lstm3d_test.py │ │ ├── conv_lstm_test.py │ │ ├── dropout_rnn_cell.py │ │ ├── dropout_rnn_cell_test.py │ │ ├── gru.py │ │ ├── gru_test.py │ │ ├── lstm.py │ │ ├── lstm_test.py │ │ ├── rnn.py │ │ ├── rnn_test.py │ │ ├── simple_rnn.py │ │ ├── simple_rnn_test.py │ │ ├── stacked_rnn_cells.py │ │ ├── stacked_rnn_cells_test.py │ │ ├── time_distributed.py │ │ └── time_distributed_test.py │ ├── legacy/ │ │ ├── __init__.py │ │ ├── backend.py │ │ ├── layers.py │ │ ├── losses.py │ │ ├── preprocessing/ │ │ │ ├── __init__.py │ │ │ ├── image.py │ │ │ ├── sequence.py │ │ │ └── text.py │ │ └── saving/ │ │ ├── __init__.py │ │ ├── json_utils.py │ │ ├── json_utils_test.py │ │ ├── legacy_h5_format.py │ │ ├── legacy_h5_format_test.py │ │ ├── saving_options.py │ │ ├── saving_utils.py │ │ └── serialization.py │ ├── losses/ │ │ ├── __init__.py │ │ ├── loss.py │ │ ├── loss_test.py │ │ ├── losses.py │ │ └── losses_test.py │ ├── metrics/ │ │ ├── __init__.py │ │ ├── accuracy_metrics.py │ │ ├── accuracy_metrics_test.py │ │ ├── confusion_metrics.py │ │ ├── confusion_metrics_test.py │ │ ├── correlation_metrics.py │ │ ├── correlation_metrics_test.py │ │ ├── f_score_metrics.py │ │ ├── f_score_metrics_test.py │ │ ├── hinge_metrics.py │ │ ├── hinge_metrics_test.py │ │ ├── iou_metrics.py │ │ ├── iou_metrics_test.py │ │ ├── metric.py │ │ ├── metric_test.py │ │ ├── metrics_utils.py │ │ ├── probabilistic_metrics.py │ │ ├── probabilistic_metrics_test.py │ │ ├── reduction_metrics.py │ │ ├── reduction_metrics_test.py │ │ ├── regression_metrics.py │ │ └── regression_metrics_test.py │ ├── models/ │ │ ├── __init__.py │ │ ├── cloning.py │ │ ├── cloning_test.py │ │ ├── functional.py │ │ ├── functional_test.py │ │ ├── model.py │ │ ├── model_test.py │ │ ├── sequential.py │ │ ├── sequential_test.py │ │ ├── variable_mapping.py │ │ └── variable_mapping_test.py │ ├── ops/ │ │ ├── __init__.py │ │ ├── core.py │ │ ├── core_test.py │ │ ├── einops.py │ │ ├── einops_test.py │ │ ├── function.py │ │ ├── function_test.py │ │ ├── image.py │ │ ├── image_test.py │ │ ├── linalg.py │ │ ├── linalg_test.py │ │ ├── math.py │ │ ├── math_test.py │ │ ├── nn.py │ │ ├── nn_test.py │ │ ├── node.py │ │ ├── node_test.py │ │ ├── numpy.py │ │ ├── numpy_test.py │ │ ├── operation.py │ │ ├── operation_test.py │ │ ├── operation_utils.py │ │ ├── operation_utils_test.py │ │ ├── ops_test.py │ │ ├── symbolic_arguments.py │ │ └── symbolic_arguments_test.py │ ├── optimizers/ │ │ ├── __init__.py │ │ ├── adadelta.py │ │ ├── adadelta_test.py │ │ ├── adafactor.py │ │ ├── adafactor_test.py │ │ ├── adagrad.py │ │ ├── adagrad_test.py │ │ ├── adam.py │ │ ├── adam_test.py │ │ ├── adamax.py │ │ ├── adamax_test.py │ │ ├── adamw.py │ │ ├── adamw_test.py │ │ ├── base_optimizer.py │ │ ├── ftrl.py │ │ ├── ftrl_test.py │ │ ├── lamb.py │ │ ├── lamb_test.py │ │ ├── lion.py │ │ ├── lion_test.py │ │ ├── loss_scale_optimizer.py │ │ ├── loss_scale_optimizer_test.py │ │ ├── muon.py │ │ ├── muon_test.py │ │ ├── nadam.py │ │ ├── nadam_test.py │ │ ├── optimizer.py │ │ ├── optimizer_sparse_test.py │ │ ├── optimizer_test.py │ │ ├── rmsprop.py │ │ ├── rmsprop_test.py │ │ ├── schedule_free_adamw.py │ │ ├── schedule_free_adamw_test.py │ │ ├── schedules/ │ │ │ ├── __init__.py │ │ │ ├── learning_rate_schedule.py │ │ │ └── learning_rate_schedule_test.py │ │ ├── sgd.py │ │ └── sgd_test.py │ ├── quantizers/ │ │ ├── __init__.py │ │ ├── awq.py │ │ ├── awq_config.py │ │ ├── awq_config_test.py │ │ ├── awq_core.py │ │ ├── awq_test.py │ │ ├── gptq.py │ │ ├── gptq_config.py │ │ ├── gptq_config_test.py │ │ ├── gptq_core.py │ │ ├── gptq_core_test.py │ │ ├── gptq_test.py │ │ ├── quantization_config.py │ │ ├── quantization_config_test.py │ │ ├── quantizers.py │ │ ├── quantizers_test.py │ │ ├── utils.py │ │ └── utils_test.py │ ├── random/ │ │ ├── __init__.py │ │ ├── random.py │ │ ├── random_test.py │ │ ├── seed_generator.py │ │ └── seed_generator_test.py │ ├── regularizers/ │ │ ├── __init__.py │ │ ├── regularizers.py │ │ └── regularizers_test.py │ ├── saving/ │ │ ├── __init__.py │ │ ├── file_editor.py │ │ ├── file_editor_test.py │ │ ├── keras_saveable.py │ │ ├── object_registration.py │ │ ├── object_registration_test.py │ │ ├── orbax_util.py │ │ ├── saving_api.py │ │ ├── saving_api_test.py │ │ ├── saving_lib.py │ │ ├── saving_lib_test.py │ │ ├── serialization_lib.py │ │ └── serialization_lib_test.py │ ├── testing/ │ │ ├── __init__.py │ │ ├── test_case.py │ │ ├── test_utils.py │ │ └── test_utils_test.py │ ├── trainers/ │ │ ├── __init__.py │ │ ├── compile_utils.py │ │ ├── compile_utils_test.py │ │ ├── data_adapters/ │ │ │ ├── __init__.py │ │ │ ├── array_data_adapter.py │ │ │ ├── array_data_adapter_test.py │ │ │ ├── array_slicing.py │ │ │ ├── data_adapter.py │ │ │ ├── data_adapter_utils.py │ │ │ ├── data_adapter_utils_test.py │ │ │ ├── generator_data_adapter.py │ │ │ ├── generator_data_adapter_test.py │ │ │ ├── grain_dataset_adapter.py │ │ │ ├── grain_dataset_adapter_test.py │ │ │ ├── py_dataset_adapter.py │ │ │ ├── py_dataset_adapter_test.py │ │ │ ├── tf_dataset_adapter.py │ │ │ ├── tf_dataset_adapter_test.py │ │ │ ├── torch_data_loader_adapter.py │ │ │ └── torch_data_loader_adapter_test.py │ │ ├── epoch_iterator.py │ │ ├── epoch_iterator_test.py │ │ ├── trainer.py │ │ └── trainer_test.py │ ├── tree/ │ │ ├── __init__.py │ │ ├── dmtree_impl.py │ │ ├── optree_impl.py │ │ ├── torchtree_impl.py │ │ ├── tree_api.py │ │ └── tree_test.py │ ├── utils/ │ │ ├── __init__.py │ │ ├── argument_validation.py │ │ ├── audio_dataset_utils.py │ │ ├── audio_dataset_utils_test.py │ │ ├── backend_utils.py │ │ ├── backend_utils_test.py │ │ ├── code_stats.py │ │ ├── code_stats_test.py │ │ ├── config.py │ │ ├── dataset_utils.py │ │ ├── dataset_utils_test.py │ │ ├── dtype_utils.py │ │ ├── dtype_utils_test.py │ │ ├── file_utils.py │ │ ├── file_utils_test.py │ │ ├── grain_utils.py │ │ ├── image_dataset_utils.py │ │ ├── image_dataset_utils_test.py │ │ ├── image_utils.py │ │ ├── image_utils_test.py │ │ ├── io_utils.py │ │ ├── io_utils_test.py │ │ ├── jax_layer.py │ │ ├── jax_layer_test.py │ │ ├── jax_utils.py │ │ ├── model_visualization.py │ │ ├── module_utils.py │ │ ├── naming.py │ │ ├── naming_test.py │ │ ├── numerical_utils.py │ │ ├── numerical_utils_test.py │ │ ├── progbar.py │ │ ├── progbar_test.py │ │ ├── python_utils.py │ │ ├── python_utils_test.py │ │ ├── rng_utils.py │ │ ├── rng_utils_test.py │ │ ├── sequence_utils.py │ │ ├── sequence_utils_test.py │ │ ├── summary_utils.py │ │ ├── summary_utils_test.py │ │ ├── text_dataset_utils.py │ │ ├── text_dataset_utils_test.py │ │ ├── tf_utils.py │ │ ├── timeseries_dataset_utils.py │ │ ├── timeseries_dataset_utils_test.py │ │ ├── torch_utils.py │ │ ├── torch_utils_test.py │ │ ├── traceback_utils.py │ │ ├── tracking.py │ │ └── tracking_test.py │ ├── version.py │ ├── visualization/ │ │ ├── __init__.py │ │ ├── draw_bounding_boxes.py │ │ ├── draw_segmentation_masks.py │ │ ├── plot_bounding_box_gallery.py │ │ ├── plot_image_gallery.py │ │ └── plot_segmentation_mask_gallery.py │ └── wrappers/ │ ├── __init__.py │ ├── fixes.py │ ├── sklearn_test.py │ ├── sklearn_wrapper.py │ └── utils.py ├── pip_build.py ├── pyproject.toml ├── requirements-common.txt ├── requirements-jax-cuda.txt ├── requirements-jax-tpu.txt ├── requirements-tensorflow-cuda.txt ├── requirements-tensorflow-tpu.txt ├── requirements-torch-cuda.txt ├── requirements.txt └── shell/ ├── api_gen.sh └── format.sh