Repository: keras-team/keras
Branch: master
Commit: 4d047a81be4e
Files: 1011
Total size: 10.3 MB
Directory structure:
gitextract_qha1vuxj/
├── .devcontainer/
│ ├── README.md
│ ├── devcontainer.json
│ └── setup.sh
├── .gemini/
│ ├── config.yaml
│ └── styleguide.md
├── .github/
│ ├── dependabot.yml
│ └── workflows/
│ ├── actions.yml
│ ├── auto-assignment.yaml
│ ├── config/
│ │ ├── jax/
│ │ │ └── keras.json
│ │ ├── numpy/
│ │ │ └── keras.json
│ │ ├── openvino/
│ │ │ └── keras.json
│ │ ├── tensorflow/
│ │ │ └── keras.json
│ │ └── torch/
│ │ └── keras.json
│ ├── gpu_tests.yml
│ ├── labeler.yaml
│ ├── nightly.yml
│ ├── scorecard.yml
│ ├── scripts/
│ │ ├── auto-assignment.js
│ │ └── labeler.js
│ ├── stale-issue-pr.yaml
│ └── tpu_tests.yml
├── .gitignore
├── .kokoro/
│ ├── README.md
│ └── github/
│ └── ubuntu/
│ └── gpu/
│ ├── build.sh
│ ├── jax/
│ │ ├── continuous.cfg
│ │ └── presubmit.cfg
│ └── tensorflow/
│ ├── continuous.cfg
│ └── presubmit.cfg
├── .pre-commit-config.yaml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── SECURITY.md
├── api_gen.py
├── benchmarks/
│ ├── __init__.py
│ ├── layer_benchmark/
│ │ ├── README.md
│ │ ├── __init__.py
│ │ ├── activation_benchmark.py
│ │ ├── attention_benchmark.py
│ │ ├── base_benchmark.py
│ │ ├── conv_benchmark.py
│ │ ├── core_benchmark.py
│ │ ├── merge_benchmark.py
│ │ ├── normalization_benchmark.py
│ │ ├── pooling_benchmark.py
│ │ ├── random_rotation_benchmark.py
│ │ ├── regularization_benchmark.py
│ │ ├── reshaping_benchmark.py
│ │ └── rnn_benchmark.py
│ ├── model_benchmark/
│ │ ├── __init__.py
│ │ ├── benchmark_utils.py
│ │ ├── bert_benchmark.py
│ │ └── image_classification_benchmark.py
│ └── torch_ctl_benchmark/
│ ├── README.md
│ ├── __init__.py
│ ├── benchmark_utils.py
│ ├── conv_model_benchmark.py
│ └── dense_model_benchmark.py
├── codecov.yml
├── conftest.py
├── examples/
│ ├── demo_custom_jax_workflow.py
│ ├── demo_custom_layer_backend_agnostic.py
│ ├── demo_custom_tf_workflow.py
│ ├── demo_custom_torch_workflow.py
│ ├── demo_functional.py
│ ├── demo_jax_distributed.py
│ ├── demo_mnist_convnet.py
│ ├── demo_subclass.py
│ └── demo_torch_multi_gpu.py
├── guides/
│ ├── custom_train_step_in_jax.py
│ ├── custom_train_step_in_tensorflow.py
│ ├── custom_train_step_in_torch.py
│ ├── distributed_training_with_jax.py
│ ├── distributed_training_with_tensorflow.py
│ ├── distributed_training_with_torch.py
│ ├── functional_api.py
│ ├── making_new_layers_and_models_via_subclassing.py
│ ├── sequential_model.py
│ ├── training_with_built_in_methods.py
│ ├── transfer_learning.py
│ ├── understanding_masking_and_padding.py
│ ├── writing_a_custom_training_loop_in_jax.py
│ ├── writing_a_custom_training_loop_in_tensorflow.py
│ ├── writing_a_custom_training_loop_in_torch.py
│ └── writing_your_own_callbacks.py
├── integration_tests/
│ ├── basic_full_flow.py
│ ├── dataset_tests/
│ │ ├── boston_housing_test.py
│ │ ├── california_housing_test.py
│ │ ├── cifar100_test.py
│ │ ├── cifar10_test.py
│ │ ├── fashion_mnist_test.py
│ │ ├── imdb_test.py
│ │ ├── mnist_test.py
│ │ └── reuters_test.py
│ ├── import_test.py
│ ├── jax_custom_fit_test.py
│ ├── model_visualization_test.py
│ ├── numerical_test.py
│ ├── pytorch_export_test.py
│ ├── tf_custom_fit_test.py
│ ├── tf_distribute_training_test.py
│ ├── torch_custom_fit_test.py
│ └── torch_workflow_test.py
├── keras/
│ ├── __init__.py
│ ├── api/
│ │ ├── __init__.py
│ │ ├── _tf_keras/
│ │ │ ├── __init__.py
│ │ │ └── keras/
│ │ │ ├── __init__.py
│ │ │ ├── activations/
│ │ │ │ └── __init__.py
│ │ │ ├── applications/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── convnext/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── densenet/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── efficientnet/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── efficientnet_v2/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── imagenet_utils/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── inception_resnet_v2/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── inception_v3/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── mobilenet/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── mobilenet_v2/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── mobilenet_v3/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── nasnet/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── resnet/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── resnet50/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── resnet_v2/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── vgg16/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── vgg19/
│ │ │ │ │ └── __init__.py
│ │ │ │ └── xception/
│ │ │ │ └── __init__.py
│ │ │ ├── backend/
│ │ │ │ └── __init__.py
│ │ │ ├── callbacks/
│ │ │ │ └── __init__.py
│ │ │ ├── config/
│ │ │ │ └── __init__.py
│ │ │ ├── constraints/
│ │ │ │ └── __init__.py
│ │ │ ├── datasets/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── boston_housing/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── california_housing/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── cifar10/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── cifar100/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── fashion_mnist/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── imdb/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── mnist/
│ │ │ │ │ └── __init__.py
│ │ │ │ └── reuters/
│ │ │ │ └── __init__.py
│ │ │ ├── distillation/
│ │ │ │ └── __init__.py
│ │ │ ├── distribution/
│ │ │ │ └── __init__.py
│ │ │ ├── dtype_policies/
│ │ │ │ └── __init__.py
│ │ │ ├── export/
│ │ │ │ └── __init__.py
│ │ │ ├── initializers/
│ │ │ │ └── __init__.py
│ │ │ ├── layers/
│ │ │ │ └── __init__.py
│ │ │ ├── legacy/
│ │ │ │ ├── __init__.py
│ │ │ │ └── saving/
│ │ │ │ └── __init__.py
│ │ │ ├── losses/
│ │ │ │ └── __init__.py
│ │ │ ├── metrics/
│ │ │ │ └── __init__.py
│ │ │ ├── mixed_precision/
│ │ │ │ └── __init__.py
│ │ │ ├── models/
│ │ │ │ └── __init__.py
│ │ │ ├── ops/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── image/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── linalg/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── nn/
│ │ │ │ │ └── __init__.py
│ │ │ │ └── numpy/
│ │ │ │ └── __init__.py
│ │ │ ├── optimizers/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── legacy/
│ │ │ │ │ └── __init__.py
│ │ │ │ └── schedules/
│ │ │ │ └── __init__.py
│ │ │ ├── preprocessing/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── image/
│ │ │ │ │ └── __init__.py
│ │ │ │ ├── sequence/
│ │ │ │ │ └── __init__.py
│ │ │ │ └── text/
│ │ │ │ └── __init__.py
│ │ │ ├── quantizers/
│ │ │ │ └── __init__.py
│ │ │ ├── random/
│ │ │ │ └── __init__.py
│ │ │ ├── regularizers/
│ │ │ │ └── __init__.py
│ │ │ ├── saving/
│ │ │ │ └── __init__.py
│ │ │ ├── tree/
│ │ │ │ └── __init__.py
│ │ │ ├── utils/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── bounding_boxes/
│ │ │ │ │ └── __init__.py
│ │ │ │ └── legacy/
│ │ │ │ └── __init__.py
│ │ │ ├── visualization/
│ │ │ │ └── __init__.py
│ │ │ └── wrappers/
│ │ │ └── __init__.py
│ │ ├── activations/
│ │ │ └── __init__.py
│ │ ├── applications/
│ │ │ ├── __init__.py
│ │ │ ├── convnext/
│ │ │ │ └── __init__.py
│ │ │ ├── densenet/
│ │ │ │ └── __init__.py
│ │ │ ├── efficientnet/
│ │ │ │ └── __init__.py
│ │ │ ├── efficientnet_v2/
│ │ │ │ └── __init__.py
│ │ │ ├── imagenet_utils/
│ │ │ │ └── __init__.py
│ │ │ ├── inception_resnet_v2/
│ │ │ │ └── __init__.py
│ │ │ ├── inception_v3/
│ │ │ │ └── __init__.py
│ │ │ ├── mobilenet/
│ │ │ │ └── __init__.py
│ │ │ ├── mobilenet_v2/
│ │ │ │ └── __init__.py
│ │ │ ├── mobilenet_v3/
│ │ │ │ └── __init__.py
│ │ │ ├── nasnet/
│ │ │ │ └── __init__.py
│ │ │ ├── resnet/
│ │ │ │ └── __init__.py
│ │ │ ├── resnet50/
│ │ │ │ └── __init__.py
│ │ │ ├── resnet_v2/
│ │ │ │ └── __init__.py
│ │ │ ├── vgg16/
│ │ │ │ └── __init__.py
│ │ │ ├── vgg19/
│ │ │ │ └── __init__.py
│ │ │ └── xception/
│ │ │ └── __init__.py
│ │ ├── backend/
│ │ │ └── __init__.py
│ │ ├── callbacks/
│ │ │ └── __init__.py
│ │ ├── config/
│ │ │ └── __init__.py
│ │ ├── constraints/
│ │ │ └── __init__.py
│ │ ├── datasets/
│ │ │ ├── __init__.py
│ │ │ ├── boston_housing/
│ │ │ │ └── __init__.py
│ │ │ ├── california_housing/
│ │ │ │ └── __init__.py
│ │ │ ├── cifar10/
│ │ │ │ └── __init__.py
│ │ │ ├── cifar100/
│ │ │ │ └── __init__.py
│ │ │ ├── fashion_mnist/
│ │ │ │ └── __init__.py
│ │ │ ├── imdb/
│ │ │ │ └── __init__.py
│ │ │ ├── mnist/
│ │ │ │ └── __init__.py
│ │ │ └── reuters/
│ │ │ └── __init__.py
│ │ ├── distillation/
│ │ │ └── __init__.py
│ │ ├── distribution/
│ │ │ └── __init__.py
│ │ ├── dtype_policies/
│ │ │ └── __init__.py
│ │ ├── export/
│ │ │ └── __init__.py
│ │ ├── initializers/
│ │ │ └── __init__.py
│ │ ├── layers/
│ │ │ └── __init__.py
│ │ ├── legacy/
│ │ │ ├── __init__.py
│ │ │ └── saving/
│ │ │ └── __init__.py
│ │ ├── losses/
│ │ │ └── __init__.py
│ │ ├── metrics/
│ │ │ └── __init__.py
│ │ ├── mixed_precision/
│ │ │ └── __init__.py
│ │ ├── models/
│ │ │ └── __init__.py
│ │ ├── ops/
│ │ │ ├── __init__.py
│ │ │ ├── image/
│ │ │ │ └── __init__.py
│ │ │ ├── linalg/
│ │ │ │ └── __init__.py
│ │ │ ├── nn/
│ │ │ │ └── __init__.py
│ │ │ └── numpy/
│ │ │ └── __init__.py
│ │ ├── optimizers/
│ │ │ ├── __init__.py
│ │ │ ├── legacy/
│ │ │ │ └── __init__.py
│ │ │ └── schedules/
│ │ │ └── __init__.py
│ │ ├── preprocessing/
│ │ │ ├── __init__.py
│ │ │ ├── image/
│ │ │ │ └── __init__.py
│ │ │ └── sequence/
│ │ │ └── __init__.py
│ │ ├── quantizers/
│ │ │ └── __init__.py
│ │ ├── random/
│ │ │ └── __init__.py
│ │ ├── regularizers/
│ │ │ └── __init__.py
│ │ ├── saving/
│ │ │ └── __init__.py
│ │ ├── tree/
│ │ │ └── __init__.py
│ │ ├── utils/
│ │ │ ├── __init__.py
│ │ │ ├── bounding_boxes/
│ │ │ │ └── __init__.py
│ │ │ └── legacy/
│ │ │ └── __init__.py
│ │ ├── visualization/
│ │ │ └── __init__.py
│ │ └── wrappers/
│ │ └── __init__.py
│ └── src/
│ ├── __init__.py
│ ├── activations/
│ │ ├── __init__.py
│ │ ├── activations.py
│ │ └── activations_test.py
│ ├── api_export.py
│ ├── applications/
│ │ ├── __init__.py
│ │ ├── applications_test.py
│ │ ├── convnext.py
│ │ ├── densenet.py
│ │ ├── efficientnet.py
│ │ ├── efficientnet_v2.py
│ │ ├── imagenet_utils.py
│ │ ├── imagenet_utils_test.py
│ │ ├── inception_resnet_v2.py
│ │ ├── inception_v3.py
│ │ ├── mobilenet.py
│ │ ├── mobilenet_v2.py
│ │ ├── mobilenet_v3.py
│ │ ├── nasnet.py
│ │ ├── resnet.py
│ │ ├── resnet_v2.py
│ │ ├── vgg16.py
│ │ ├── vgg19.py
│ │ └── xception.py
│ ├── backend/
│ │ ├── __init__.py
│ │ ├── common/
│ │ │ ├── __init__.py
│ │ │ ├── backend_utils.py
│ │ │ ├── backend_utils_test.py
│ │ │ ├── compute_output_spec_test.py
│ │ │ ├── dtypes.py
│ │ │ ├── dtypes_test.py
│ │ │ ├── global_state.py
│ │ │ ├── global_state_test.py
│ │ │ ├── keras_tensor.py
│ │ │ ├── keras_tensor_test.py
│ │ │ ├── masking.py
│ │ │ ├── masking_test.py
│ │ │ ├── name_scope.py
│ │ │ ├── name_scope_test.py
│ │ │ ├── remat.py
│ │ │ ├── remat_test.py
│ │ │ ├── stateless_scope.py
│ │ │ ├── stateless_scope_test.py
│ │ │ ├── symbolic_scope.py
│ │ │ ├── symbolic_scope_test.py
│ │ │ ├── tensor_attributes.py
│ │ │ ├── thread_safe_test.py
│ │ │ ├── variables.py
│ │ │ └── variables_test.py
│ │ ├── config.py
│ │ ├── jax/
│ │ │ ├── __init__.py
│ │ │ ├── core.py
│ │ │ ├── core_test.py
│ │ │ ├── distribution_lib.py
│ │ │ ├── distribution_lib_test.py
│ │ │ ├── excluded_tpu_tests.txt
│ │ │ ├── export.py
│ │ │ ├── image.py
│ │ │ ├── layer.py
│ │ │ ├── linalg.py
│ │ │ ├── math.py
│ │ │ ├── nn.py
│ │ │ ├── numpy.py
│ │ │ ├── optimizer.py
│ │ │ ├── random.py
│ │ │ ├── rnn.py
│ │ │ ├── sparse.py
│ │ │ ├── tensorboard.py
│ │ │ ├── trainer.py
│ │ │ └── trainer_test.py
│ │ ├── numpy/
│ │ │ ├── __init__.py
│ │ │ ├── core.py
│ │ │ ├── export.py
│ │ │ ├── image.py
│ │ │ ├── layer.py
│ │ │ ├── linalg.py
│ │ │ ├── math.py
│ │ │ ├── nn.py
│ │ │ ├── numpy.py
│ │ │ ├── random.py
│ │ │ ├── rnn.py
│ │ │ └── trainer.py
│ │ ├── openvino/
│ │ │ ├── __init__.py
│ │ │ ├── core.py
│ │ │ ├── excluded_concrete_tests.txt
│ │ │ ├── excluded_tests.txt
│ │ │ ├── export.py
│ │ │ ├── image.py
│ │ │ ├── layer.py
│ │ │ ├── linalg.py
│ │ │ ├── math.py
│ │ │ ├── nn.py
│ │ │ ├── numpy.py
│ │ │ ├── random.py
│ │ │ ├── rnn.py
│ │ │ └── trainer.py
│ │ ├── tensorflow/
│ │ │ ├── __init__.py
│ │ │ ├── core.py
│ │ │ ├── distribute_test.py
│ │ │ ├── distribution_lib.py
│ │ │ ├── export.py
│ │ │ ├── image.py
│ │ │ ├── layer.py
│ │ │ ├── linalg.py
│ │ │ ├── math.py
│ │ │ ├── name_scope_test.py
│ │ │ ├── nn.py
│ │ │ ├── numpy.py
│ │ │ ├── optimizer.py
│ │ │ ├── optimizer_distribute_test.py
│ │ │ ├── random.py
│ │ │ ├── rnn.py
│ │ │ ├── saved_model_test.py
│ │ │ ├── sparse.py
│ │ │ ├── tensorboard.py
│ │ │ ├── trackable.py
│ │ │ └── trainer.py
│ │ ├── tests/
│ │ │ ├── compute_output_spec_test.py
│ │ │ └── device_scope_test.py
│ │ └── torch/
│ │ ├── __init__.py
│ │ ├── core.py
│ │ ├── export.py
│ │ ├── image.py
│ │ ├── layer.py
│ │ ├── linalg.py
│ │ ├── math.py
│ │ ├── nn.py
│ │ ├── numpy.py
│ │ ├── optimizers/
│ │ │ ├── __init__.py
│ │ │ ├── torch_adadelta.py
│ │ │ ├── torch_adagrad.py
│ │ │ ├── torch_adam.py
│ │ │ ├── torch_adamax.py
│ │ │ ├── torch_adamw.py
│ │ │ ├── torch_lion.py
│ │ │ ├── torch_nadam.py
│ │ │ ├── torch_optimizer.py
│ │ │ ├── torch_parallel_optimizer.py
│ │ │ ├── torch_rmsprop.py
│ │ │ └── torch_sgd.py
│ │ ├── random.py
│ │ ├── rnn.py
│ │ └── trainer.py
│ ├── callbacks/
│ │ ├── __init__.py
│ │ ├── backup_and_restore.py
│ │ ├── backup_and_restore_test.py
│ │ ├── callback.py
│ │ ├── callback_list.py
│ │ ├── callback_test.py
│ │ ├── csv_logger.py
│ │ ├── csv_logger_test.py
│ │ ├── early_stopping.py
│ │ ├── early_stopping_test.py
│ │ ├── history.py
│ │ ├── lambda_callback.py
│ │ ├── lambda_callback_test.py
│ │ ├── learning_rate_scheduler.py
│ │ ├── learning_rate_scheduler_test.py
│ │ ├── model_checkpoint.py
│ │ ├── model_checkpoint_test.py
│ │ ├── monitor_callback.py
│ │ ├── monitor_callback_test.py
│ │ ├── orbax_checkpoint.py
│ │ ├── orbax_checkpoint_test.py
│ │ ├── progbar_logger.py
│ │ ├── reduce_lr_on_plateau.py
│ │ ├── reduce_lr_on_plateau_test.py
│ │ ├── remote_monitor.py
│ │ ├── remote_monitor_test.py
│ │ ├── swap_ema_weights.py
│ │ ├── swap_ema_weights_test.py
│ │ ├── tensorboard.py
│ │ ├── tensorboard_test.py
│ │ ├── terminate_on_nan.py
│ │ └── terminate_on_nan_test.py
│ ├── constraints/
│ │ ├── __init__.py
│ │ ├── constraints.py
│ │ └── constraints_test.py
│ ├── datasets/
│ │ ├── __init__.py
│ │ ├── boston_housing.py
│ │ ├── california_housing.py
│ │ ├── cifar.py
│ │ ├── cifar10.py
│ │ ├── cifar100.py
│ │ ├── fashion_mnist.py
│ │ ├── imdb.py
│ │ ├── mnist.py
│ │ └── reuters.py
│ ├── distillation/
│ │ ├── __init__.py
│ │ ├── distillation_loss.py
│ │ ├── distillation_loss_test.py
│ │ ├── distiller.py
│ │ └── distiller_test.py
│ ├── distribution/
│ │ ├── __init__.py
│ │ ├── distribution_lib.py
│ │ └── distribution_lib_test.py
│ ├── dtype_policies/
│ │ ├── __init__.py
│ │ ├── dtype_policy.py
│ │ ├── dtype_policy_map.py
│ │ ├── dtype_policy_map_test.py
│ │ └── dtype_policy_test.py
│ ├── export/
│ │ ├── __init__.py
│ │ ├── export_utils.py
│ │ ├── litert.py
│ │ ├── litert_test.py
│ │ ├── neptune_model_export_archive.py
│ │ ├── onnx.py
│ │ ├── onnx_test.py
│ │ ├── openvino.py
│ │ ├── openvino_test.py
│ │ ├── saved_model.py
│ │ ├── saved_model_export_archive.py
│ │ ├── saved_model_test.py
│ │ ├── tf2onnx_lib.py
│ │ ├── tfsm_layer.py
│ │ └── tfsm_layer_test.py
│ ├── initializers/
│ │ ├── __init__.py
│ │ ├── constant_initializers.py
│ │ ├── constant_initializers_test.py
│ │ ├── initializer.py
│ │ ├── random_initializers.py
│ │ └── random_initializers_test.py
│ ├── layers/
│ │ ├── __init__.py
│ │ ├── activations/
│ │ │ ├── __init__.py
│ │ │ ├── activation.py
│ │ │ ├── activation_test.py
│ │ │ ├── elu.py
│ │ │ ├── elu_test.py
│ │ │ ├── leaky_relu.py
│ │ │ ├── leaky_relu_test.py
│ │ │ ├── prelu.py
│ │ │ ├── prelu_test.py
│ │ │ ├── relu.py
│ │ │ ├── relu_test.py
│ │ │ ├── softmax.py
│ │ │ └── softmax_test.py
│ │ ├── attention/
│ │ │ ├── __init__.py
│ │ │ ├── additive_attention.py
│ │ │ ├── additive_attention_test.py
│ │ │ ├── attention.py
│ │ │ ├── attention_test.py
│ │ │ ├── grouped_query_attention.py
│ │ │ ├── grouped_query_attention_test.py
│ │ │ ├── multi_head_attention.py
│ │ │ └── multi_head_attention_test.py
│ │ ├── convolutional/
│ │ │ ├── __init__.py
│ │ │ ├── base_conv.py
│ │ │ ├── base_conv_transpose.py
│ │ │ ├── base_depthwise_conv.py
│ │ │ ├── base_separable_conv.py
│ │ │ ├── conv1d.py
│ │ │ ├── conv1d_transpose.py
│ │ │ ├── conv2d.py
│ │ │ ├── conv2d_transpose.py
│ │ │ ├── conv3d.py
│ │ │ ├── conv3d_transpose.py
│ │ │ ├── conv_test.py
│ │ │ ├── conv_transpose_test.py
│ │ │ ├── depthwise_conv1d.py
│ │ │ ├── depthwise_conv2d.py
│ │ │ ├── depthwise_conv_test.py
│ │ │ ├── separable_conv1d.py
│ │ │ ├── separable_conv2d.py
│ │ │ └── separable_conv_test.py
│ │ ├── core/
│ │ │ ├── __init__.py
│ │ │ ├── dense.py
│ │ │ ├── dense_test.py
│ │ │ ├── einsum_dense.py
│ │ │ ├── einsum_dense_test.py
│ │ │ ├── embedding.py
│ │ │ ├── embedding_test.py
│ │ │ ├── identity.py
│ │ │ ├── identity_test.py
│ │ │ ├── input_layer.py
│ │ │ ├── input_layer_test.py
│ │ │ ├── lambda_layer.py
│ │ │ ├── lambda_layer_test.py
│ │ │ ├── masking.py
│ │ │ ├── masking_test.py
│ │ │ ├── reversible_embedding.py
│ │ │ ├── reversible_embedding_test.py
│ │ │ ├── wrapper.py
│ │ │ └── wrapper_test.py
│ │ ├── input_spec.py
│ │ ├── layer.py
│ │ ├── layer_test.py
│ │ ├── merging/
│ │ │ ├── __init__.py
│ │ │ ├── add.py
│ │ │ ├── average.py
│ │ │ ├── base_merge.py
│ │ │ ├── concatenate.py
│ │ │ ├── dot.py
│ │ │ ├── maximum.py
│ │ │ ├── merging_test.py
│ │ │ ├── minimum.py
│ │ │ ├── multiply.py
│ │ │ └── subtract.py
│ │ ├── normalization/
│ │ │ ├── __init__.py
│ │ │ ├── batch_normalization.py
│ │ │ ├── batch_normalization_test.py
│ │ │ ├── group_normalization.py
│ │ │ ├── group_normalization_test.py
│ │ │ ├── layer_normalization.py
│ │ │ ├── layer_normalization_test.py
│ │ │ ├── rms_normalization.py
│ │ │ ├── rms_normalization_test.py
│ │ │ ├── spectral_normalization.py
│ │ │ ├── spectral_normalization_test.py
│ │ │ ├── unit_normalization.py
│ │ │ └── unit_normalization_test.py
│ │ ├── pooling/
│ │ │ ├── __init__.py
│ │ │ ├── adaptive_average_pooling1d.py
│ │ │ ├── adaptive_average_pooling2d.py
│ │ │ ├── adaptive_average_pooling3d.py
│ │ │ ├── adaptive_max_pooling1d.py
│ │ │ ├── adaptive_max_pooling2d.py
│ │ │ ├── adaptive_max_pooling3d.py
│ │ │ ├── adaptive_pooling1d_test.py
│ │ │ ├── adaptive_pooling2d_test.py
│ │ │ ├── adaptive_pooling3d_test.py
│ │ │ ├── average_pooling1d.py
│ │ │ ├── average_pooling2d.py
│ │ │ ├── average_pooling3d.py
│ │ │ ├── average_pooling_test.py
│ │ │ ├── base_adaptive_pooling.py
│ │ │ ├── base_global_pooling.py
│ │ │ ├── base_pooling.py
│ │ │ ├── global_average_pooling1d.py
│ │ │ ├── global_average_pooling2d.py
│ │ │ ├── global_average_pooling3d.py
│ │ │ ├── global_average_pooling_test.py
│ │ │ ├── global_max_pooling1d.py
│ │ │ ├── global_max_pooling2d.py
│ │ │ ├── global_max_pooling3d.py
│ │ │ ├── global_max_pooling_test.py
│ │ │ ├── max_pooling1d.py
│ │ │ ├── max_pooling2d.py
│ │ │ ├── max_pooling3d.py
│ │ │ └── max_pooling_test.py
│ │ ├── preprocessing/
│ │ │ ├── __init__.py
│ │ │ ├── category_encoding.py
│ │ │ ├── category_encoding_test.py
│ │ │ ├── data_layer.py
│ │ │ ├── data_layer_test.py
│ │ │ ├── discretization.py
│ │ │ ├── discretization_test.py
│ │ │ ├── feature_space.py
│ │ │ ├── feature_space_test.py
│ │ │ ├── hashed_crossing.py
│ │ │ ├── hashed_crossing_test.py
│ │ │ ├── hashing.py
│ │ │ ├── hashing_test.py
│ │ │ ├── image_preprocessing/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── aug_mix.py
│ │ │ │ ├── aug_mix_test.py
│ │ │ │ ├── auto_contrast.py
│ │ │ │ ├── auto_contrast_test.py
│ │ │ │ ├── base_image_preprocessing_layer.py
│ │ │ │ ├── bounding_boxes/
│ │ │ │ │ ├── __init__.py
│ │ │ │ │ ├── bounding_box.py
│ │ │ │ │ ├── converters.py
│ │ │ │ │ ├── converters_test.py
│ │ │ │ │ ├── formats.py
│ │ │ │ │ ├── iou.py
│ │ │ │ │ ├── iou_test.py
│ │ │ │ │ ├── validation.py
│ │ │ │ │ └── validation_test.py
│ │ │ │ ├── center_crop.py
│ │ │ │ ├── center_crop_test.py
│ │ │ │ ├── clahe.py
│ │ │ │ ├── clahe_test.py
│ │ │ │ ├── cut_mix.py
│ │ │ │ ├── cut_mix_test.py
│ │ │ │ ├── equalization.py
│ │ │ │ ├── equalization_test.py
│ │ │ │ ├── max_num_bounding_box.py
│ │ │ │ ├── max_num_bounding_box_test.py
│ │ │ │ ├── mix_up.py
│ │ │ │ ├── mix_up_test.py
│ │ │ │ ├── rand_augment.py
│ │ │ │ ├── rand_augment_test.py
│ │ │ │ ├── random_brightness.py
│ │ │ │ ├── random_brightness_test.py
│ │ │ │ ├── random_color_degeneration.py
│ │ │ │ ├── random_color_degeneration_test.py
│ │ │ │ ├── random_color_jitter.py
│ │ │ │ ├── random_color_jitter_test.py
│ │ │ │ ├── random_contrast.py
│ │ │ │ ├── random_contrast_test.py
│ │ │ │ ├── random_crop.py
│ │ │ │ ├── random_crop_test.py
│ │ │ │ ├── random_elastic_transform.py
│ │ │ │ ├── random_elastic_transform_test.py
│ │ │ │ ├── random_erasing.py
│ │ │ │ ├── random_erasing_test.py
│ │ │ │ ├── random_flip.py
│ │ │ │ ├── random_flip_test.py
│ │ │ │ ├── random_gaussian_blur.py
│ │ │ │ ├── random_gaussian_blur_test.py
│ │ │ │ ├── random_grayscale.py
│ │ │ │ ├── random_grayscale_test.py
│ │ │ │ ├── random_hue.py
│ │ │ │ ├── random_hue_test.py
│ │ │ │ ├── random_invert.py
│ │ │ │ ├── random_invert_test.py
│ │ │ │ ├── random_perspective.py
│ │ │ │ ├── random_perspective_test.py
│ │ │ │ ├── random_posterization.py
│ │ │ │ ├── random_posterization_test.py
│ │ │ │ ├── random_rotation.py
│ │ │ │ ├── random_rotation_test.py
│ │ │ │ ├── random_saturation.py
│ │ │ │ ├── random_saturation_test.py
│ │ │ │ ├── random_sharpness.py
│ │ │ │ ├── random_sharpness_test.py
│ │ │ │ ├── random_shear.py
│ │ │ │ ├── random_shear_test.py
│ │ │ │ ├── random_translation.py
│ │ │ │ ├── random_translation_test.py
│ │ │ │ ├── random_zoom.py
│ │ │ │ ├── random_zoom_test.py
│ │ │ │ ├── resizing.py
│ │ │ │ ├── resizing_test.py
│ │ │ │ ├── solarization.py
│ │ │ │ └── solarization_test.py
│ │ │ ├── index_lookup.py
│ │ │ ├── index_lookup_test.py
│ │ │ ├── integer_lookup.py
│ │ │ ├── integer_lookup_test.py
│ │ │ ├── mel_spectrogram.py
│ │ │ ├── mel_spectrogram_test.py
│ │ │ ├── normalization.py
│ │ │ ├── normalization_test.py
│ │ │ ├── pipeline.py
│ │ │ ├── pipeline_test.py
│ │ │ ├── rescaling.py
│ │ │ ├── rescaling_test.py
│ │ │ ├── stft_spectrogram.py
│ │ │ ├── stft_spectrogram_test.py
│ │ │ ├── string_lookup.py
│ │ │ ├── string_lookup_test.py
│ │ │ ├── text_vectorization.py
│ │ │ └── text_vectorization_test.py
│ │ ├── regularization/
│ │ │ ├── __init__.py
│ │ │ ├── activity_regularization.py
│ │ │ ├── activity_regularization_test.py
│ │ │ ├── alpha_dropout.py
│ │ │ ├── alpha_dropout_test.py
│ │ │ ├── dropout.py
│ │ │ ├── dropout_test.py
│ │ │ ├── gaussian_dropout.py
│ │ │ ├── gaussian_dropout_test.py
│ │ │ ├── gaussian_noise.py
│ │ │ ├── gaussian_noise_test.py
│ │ │ ├── spatial_dropout.py
│ │ │ └── spatial_dropout_test.py
│ │ ├── reshaping/
│ │ │ ├── __init__.py
│ │ │ ├── cropping1d.py
│ │ │ ├── cropping1d_test.py
│ │ │ ├── cropping2d.py
│ │ │ ├── cropping2d_test.py
│ │ │ ├── cropping3d.py
│ │ │ ├── cropping3d_test.py
│ │ │ ├── flatten.py
│ │ │ ├── flatten_test.py
│ │ │ ├── permute.py
│ │ │ ├── permute_test.py
│ │ │ ├── repeat_vector.py
│ │ │ ├── repeat_vector_test.py
│ │ │ ├── reshape.py
│ │ │ ├── reshape_test.py
│ │ │ ├── up_sampling1d.py
│ │ │ ├── up_sampling1d_test.py
│ │ │ ├── up_sampling2d.py
│ │ │ ├── up_sampling2d_test.py
│ │ │ ├── up_sampling3d.py
│ │ │ ├── up_sampling3d_test.py
│ │ │ ├── zero_padding1d.py
│ │ │ ├── zero_padding1d_test.py
│ │ │ ├── zero_padding2d.py
│ │ │ ├── zero_padding2d_test.py
│ │ │ ├── zero_padding3d.py
│ │ │ └── zero_padding3d_test.py
│ │ └── rnn/
│ │ ├── __init__.py
│ │ ├── bidirectional.py
│ │ ├── bidirectional_test.py
│ │ ├── conv_lstm.py
│ │ ├── conv_lstm1d.py
│ │ ├── conv_lstm1d_test.py
│ │ ├── conv_lstm2d.py
│ │ ├── conv_lstm2d_test.py
│ │ ├── conv_lstm3d.py
│ │ ├── conv_lstm3d_test.py
│ │ ├── conv_lstm_test.py
│ │ ├── dropout_rnn_cell.py
│ │ ├── dropout_rnn_cell_test.py
│ │ ├── gru.py
│ │ ├── gru_test.py
│ │ ├── lstm.py
│ │ ├── lstm_test.py
│ │ ├── rnn.py
│ │ ├── rnn_test.py
│ │ ├── simple_rnn.py
│ │ ├── simple_rnn_test.py
│ │ ├── stacked_rnn_cells.py
│ │ ├── stacked_rnn_cells_test.py
│ │ ├── time_distributed.py
│ │ └── time_distributed_test.py
│ ├── legacy/
│ │ ├── __init__.py
│ │ ├── backend.py
│ │ ├── layers.py
│ │ ├── losses.py
│ │ ├── preprocessing/
│ │ │ ├── __init__.py
│ │ │ ├── image.py
│ │ │ ├── sequence.py
│ │ │ └── text.py
│ │ └── saving/
│ │ ├── __init__.py
│ │ ├── json_utils.py
│ │ ├── json_utils_test.py
│ │ ├── legacy_h5_format.py
│ │ ├── legacy_h5_format_test.py
│ │ ├── saving_options.py
│ │ ├── saving_utils.py
│ │ └── serialization.py
│ ├── losses/
│ │ ├── __init__.py
│ │ ├── loss.py
│ │ ├── loss_test.py
│ │ ├── losses.py
│ │ └── losses_test.py
│ ├── metrics/
│ │ ├── __init__.py
│ │ ├── accuracy_metrics.py
│ │ ├── accuracy_metrics_test.py
│ │ ├── confusion_metrics.py
│ │ ├── confusion_metrics_test.py
│ │ ├── correlation_metrics.py
│ │ ├── correlation_metrics_test.py
│ │ ├── f_score_metrics.py
│ │ ├── f_score_metrics_test.py
│ │ ├── hinge_metrics.py
│ │ ├── hinge_metrics_test.py
│ │ ├── iou_metrics.py
│ │ ├── iou_metrics_test.py
│ │ ├── metric.py
│ │ ├── metric_test.py
│ │ ├── metrics_utils.py
│ │ ├── probabilistic_metrics.py
│ │ ├── probabilistic_metrics_test.py
│ │ ├── reduction_metrics.py
│ │ ├── reduction_metrics_test.py
│ │ ├── regression_metrics.py
│ │ └── regression_metrics_test.py
│ ├── models/
│ │ ├── __init__.py
│ │ ├── cloning.py
│ │ ├── cloning_test.py
│ │ ├── functional.py
│ │ ├── functional_test.py
│ │ ├── model.py
│ │ ├── model_test.py
│ │ ├── sequential.py
│ │ ├── sequential_test.py
│ │ ├── variable_mapping.py
│ │ └── variable_mapping_test.py
│ ├── ops/
│ │ ├── __init__.py
│ │ ├── core.py
│ │ ├── core_test.py
│ │ ├── einops.py
│ │ ├── einops_test.py
│ │ ├── function.py
│ │ ├── function_test.py
│ │ ├── image.py
│ │ ├── image_test.py
│ │ ├── linalg.py
│ │ ├── linalg_test.py
│ │ ├── math.py
│ │ ├── math_test.py
│ │ ├── nn.py
│ │ ├── nn_test.py
│ │ ├── node.py
│ │ ├── node_test.py
│ │ ├── numpy.py
│ │ ├── numpy_test.py
│ │ ├── operation.py
│ │ ├── operation_test.py
│ │ ├── operation_utils.py
│ │ ├── operation_utils_test.py
│ │ ├── ops_test.py
│ │ ├── symbolic_arguments.py
│ │ └── symbolic_arguments_test.py
│ ├── optimizers/
│ │ ├── __init__.py
│ │ ├── adadelta.py
│ │ ├── adadelta_test.py
│ │ ├── adafactor.py
│ │ ├── adafactor_test.py
│ │ ├── adagrad.py
│ │ ├── adagrad_test.py
│ │ ├── adam.py
│ │ ├── adam_test.py
│ │ ├── adamax.py
│ │ ├── adamax_test.py
│ │ ├── adamw.py
│ │ ├── adamw_test.py
│ │ ├── base_optimizer.py
│ │ ├── ftrl.py
│ │ ├── ftrl_test.py
│ │ ├── lamb.py
│ │ ├── lamb_test.py
│ │ ├── lion.py
│ │ ├── lion_test.py
│ │ ├── loss_scale_optimizer.py
│ │ ├── loss_scale_optimizer_test.py
│ │ ├── muon.py
│ │ ├── muon_test.py
│ │ ├── nadam.py
│ │ ├── nadam_test.py
│ │ ├── optimizer.py
│ │ ├── optimizer_sparse_test.py
│ │ ├── optimizer_test.py
│ │ ├── rmsprop.py
│ │ ├── rmsprop_test.py
│ │ ├── schedule_free_adamw.py
│ │ ├── schedule_free_adamw_test.py
│ │ ├── schedules/
│ │ │ ├── __init__.py
│ │ │ ├── learning_rate_schedule.py
│ │ │ └── learning_rate_schedule_test.py
│ │ ├── sgd.py
│ │ └── sgd_test.py
│ ├── quantizers/
│ │ ├── __init__.py
│ │ ├── awq.py
│ │ ├── awq_config.py
│ │ ├── awq_config_test.py
│ │ ├── awq_core.py
│ │ ├── awq_test.py
│ │ ├── gptq.py
│ │ ├── gptq_config.py
│ │ ├── gptq_config_test.py
│ │ ├── gptq_core.py
│ │ ├── gptq_core_test.py
│ │ ├── gptq_test.py
│ │ ├── quantization_config.py
│ │ ├── quantization_config_test.py
│ │ ├── quantizers.py
│ │ ├── quantizers_test.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── random/
│ │ ├── __init__.py
│ │ ├── random.py
│ │ ├── random_test.py
│ │ ├── seed_generator.py
│ │ └── seed_generator_test.py
│ ├── regularizers/
│ │ ├── __init__.py
│ │ ├── regularizers.py
│ │ └── regularizers_test.py
│ ├── saving/
│ │ ├── __init__.py
│ │ ├── file_editor.py
│ │ ├── file_editor_test.py
│ │ ├── keras_saveable.py
│ │ ├── object_registration.py
│ │ ├── object_registration_test.py
│ │ ├── orbax_util.py
│ │ ├── saving_api.py
│ │ ├── saving_api_test.py
│ │ ├── saving_lib.py
│ │ ├── saving_lib_test.py
│ │ ├── serialization_lib.py
│ │ └── serialization_lib_test.py
│ ├── testing/
│ │ ├── __init__.py
│ │ ├── test_case.py
│ │ ├── test_utils.py
│ │ └── test_utils_test.py
│ ├── trainers/
│ │ ├── __init__.py
│ │ ├── compile_utils.py
│ │ ├── compile_utils_test.py
│ │ ├── data_adapters/
│ │ │ ├── __init__.py
│ │ │ ├── array_data_adapter.py
│ │ │ ├── array_data_adapter_test.py
│ │ │ ├── array_slicing.py
│ │ │ ├── data_adapter.py
│ │ │ ├── data_adapter_utils.py
│ │ │ ├── data_adapter_utils_test.py
│ │ │ ├── generator_data_adapter.py
│ │ │ ├── generator_data_adapter_test.py
│ │ │ ├── grain_dataset_adapter.py
│ │ │ ├── grain_dataset_adapter_test.py
│ │ │ ├── py_dataset_adapter.py
│ │ │ ├── py_dataset_adapter_test.py
│ │ │ ├── tf_dataset_adapter.py
│ │ │ ├── tf_dataset_adapter_test.py
│ │ │ ├── torch_data_loader_adapter.py
│ │ │ └── torch_data_loader_adapter_test.py
│ │ ├── epoch_iterator.py
│ │ ├── epoch_iterator_test.py
│ │ ├── trainer.py
│ │ └── trainer_test.py
│ ├── tree/
│ │ ├── __init__.py
│ │ ├── dmtree_impl.py
│ │ ├── optree_impl.py
│ │ ├── torchtree_impl.py
│ │ ├── tree_api.py
│ │ └── tree_test.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── argument_validation.py
│ │ ├── audio_dataset_utils.py
│ │ ├── audio_dataset_utils_test.py
│ │ ├── backend_utils.py
│ │ ├── backend_utils_test.py
│ │ ├── code_stats.py
│ │ ├── code_stats_test.py
│ │ ├── config.py
│ │ ├── dataset_utils.py
│ │ ├── dataset_utils_test.py
│ │ ├── dtype_utils.py
│ │ ├── dtype_utils_test.py
│ │ ├── file_utils.py
│ │ ├── file_utils_test.py
│ │ ├── grain_utils.py
│ │ ├── image_dataset_utils.py
│ │ ├── image_dataset_utils_test.py
│ │ ├── image_utils.py
│ │ ├── image_utils_test.py
│ │ ├── io_utils.py
│ │ ├── io_utils_test.py
│ │ ├── jax_layer.py
│ │ ├── jax_layer_test.py
│ │ ├── jax_utils.py
│ │ ├── model_visualization.py
│ │ ├── module_utils.py
│ │ ├── naming.py
│ │ ├── naming_test.py
│ │ ├── numerical_utils.py
│ │ ├── numerical_utils_test.py
│ │ ├── progbar.py
│ │ ├── progbar_test.py
│ │ ├── python_utils.py
│ │ ├── python_utils_test.py
│ │ ├── rng_utils.py
│ │ ├── rng_utils_test.py
│ │ ├── sequence_utils.py
│ │ ├── sequence_utils_test.py
│ │ ├── summary_utils.py
│ │ ├── summary_utils_test.py
│ │ ├── text_dataset_utils.py
│ │ ├── text_dataset_utils_test.py
│ │ ├── tf_utils.py
│ │ ├── timeseries_dataset_utils.py
│ │ ├── timeseries_dataset_utils_test.py
│ │ ├── torch_utils.py
│ │ ├── torch_utils_test.py
│ │ ├── traceback_utils.py
│ │ ├── tracking.py
│ │ └── tracking_test.py
│ ├── version.py
│ ├── visualization/
│ │ ├── __init__.py
│ │ ├── draw_bounding_boxes.py
│ │ ├── draw_segmentation_masks.py
│ │ ├── plot_bounding_box_gallery.py
│ │ ├── plot_image_gallery.py
│ │ └── plot_segmentation_mask_gallery.py
│ └── wrappers/
│ ├── __init__.py
│ ├── fixes.py
│ ├── sklearn_test.py
│ ├── sklearn_wrapper.py
│ └── utils.py
├── pip_build.py
├── pyproject.toml
├── requirements-common.txt
├── requirements-jax-cuda.txt
├── requirements-jax-tpu.txt
├── requirements-tensorflow-cuda.txt
├── requirements-tensorflow-tpu.txt
├── requirements-torch-cuda.txt
├── requirements.txt
└── shell/
├── api_gen.sh
└── format.sh
================================================
FILE CONTENTS
================================================
================================================
FILE: .devcontainer/README.md
================================================
# Dev container configurations
This directory contains the configuration for dev containers, which is used to
initialize the development environment in **Codespaces**, **Visual Studio
Code**, and **JetBrains IDEs**. The environment comes preinstalled with all the
dependencies needed for development and is ready for linting, formatting, and
running tests.
* **GitHub Codespaces**. Create a codespace for the repo by clicking
the "Code" button on the main page of the repo, selecting the "Codespaces"
tab, and clicking the "+". The configurations will automatically be used.
Follow
[this guide](https://docs.github.com/en/codespaces/developing-in-a-codespace/creating-a-codespace-for-a-repository)
for more details.
* **Visual Studio Code**. Open the root folder of the repo in VS Code. A
notification will pop up to open it in a dev container with the
configuration. Follow
[this guide](https://code.visualstudio.com/docs/devcontainers/tutorial)
for more details.
* **JetBrains IDEs**. Open `.devcontainer/devcontainer.json` in your
JetBrains IDE. Click the Docker icon to create a dev container.
Follow
[this guide](https://www.jetbrains.com/help/idea/connect-to-devcontainer.html)
for more details.
================================================
FILE: .devcontainer/devcontainer.json
================================================
{
    "image": "mcr.microsoft.com/vscode/devcontainers/python:3.10",
    "postCreateCommand": "sh ./.devcontainer/setup.sh && pip install -r requirements.txt",
    "customizations": {
        "vscode": {
            "settings": {
                "python.testing.pytestEnabled": true,
                "editor.formatOnSave": true,
                "editor.codeActionsOnSave": {
                    "source.organizeImports": true
                },
                "[python]": {
                    "editor.defaultFormatter": "charliermarsh.ruff"
                },
                "editor.rulers": [
                    80
                ]
            },
            "extensions": [
                "charliermarsh.ruff",
                "ms-python.python"
            ]
        }
    },
    "features": {
        "ghcr.io/devcontainers/features/github-cli:1": {}
    }
}
================================================
FILE: .devcontainer/setup.sh
================================================
sudo pip install --upgrade pip
sudo pip install -r requirements.txt
echo "bash shell/lint.sh" > .git/hooks/pre-commit
chmod a+x .git/hooks/pre-commit
================================================
FILE: .gemini/config.yaml
================================================
have_fun: false
memory_config:
  disabled: false
code_review:
  disable: false
  comment_severity_threshold: MEDIUM
  max_review_comments: -1
  pull_request_opened:
    help: true
    summary: true
    code_review: true
    include_drafts: true
ignore_patterns: []
================================================
FILE: .gemini/styleguide.md
================================================
# Keras API design guidelines
These guidelines are meant to help focus design discussions and help us create delightful developer experiences.
They are guidelines, not rules: each decision should be debated in its own unique context.
Some text remixed from external references:
- [User experience design for APIs](https://blog.keras.io/user-experience-design-for-apis.html)
- [Notes to Myself on Software Engineering](https://medium.com/s/story/notes-to-myself-on-software-engineering-c890f16f4e4d)
---
## Design end-to-end workflows, not individual functions and classes.
When developing APIs, start by designing end-to-end workflows, and only sketch out specific function/class signatures at the end.
- The goal is to arrive at workflows that feel like they are purposefully designed and well-optimized, rather than cobbled together to route around the features provided by the API. The workflows should come first, before atomic features. **Features only exist to support a workflow.** No feature should exist to provide a capability "just in case", "because we can".
- **Every design review document should prominently feature a code example of one or two end-to-end workflows showing the canonical use-case for the new API.**
- Every time we discuss choices surrounding a specific API feature, we should start by asking: **in what workflows will this be used?** Then we should make the choice that makes the most sense with respect to these workflows. We should not make API design decisions about features in isolation.
- This implies that we will often ask the question: **do users really need to configure this parameter?**, and in many cases, the answer will be "no", rather than being "yes" by default.
---
## Carefully weigh whether a new feature should be included.
It's okay to say no: just because someone asks for a feature doesn't mean we should do it. Every feature has a cost that goes beyond the initial CL: maintenance cost, documentation cost, and cognitive cost for our users (a sprawling API surface is a major usability issue).
In particular, in the Keras API, every new feature has to be maintained in perpetuity.
As such, our criteria for adding a new feature in the API is the following:
- **It should be broadly useful to our users**, rather than a niche feature that is only relevant to a specific vertical of researchers. Niche features should be maintained independently by those who need them (e.g. by extending the API via subclassing), as third-party add-on packages.
- **It should be widely recognized as a machine learning best practice.** We will not add new layers/etc that were recently published to ArXiv.org, even in case of claims of increased accuracy/etc. We only add new objects that are already commonly used in the machine learning community. Presumably, a new technique that does result in meaningful gains would be broadly adopted after a few months anyway (like ResNet), and that's when we would be adding it to the core API. SIG-addons maintains a repository of significantly more volatile and independently maintained code to which the barriers to entry are lower.
- **It should have an owner committed to maintaining it in the long term.** In particular, the code should be maintainable by multiple people on the team, not just by one technical guru.
In addition, when saying yes to a request for supporting a new use case, remember that **literally adding what the user/team requested is often not the optimal choice**. Users are focused on their own specific use case, and we must counter this with a holistic and principled vision of the whole project (see: designing end-to-end workflows, not atomic functions/classes). Often, the right answer is to extend an existing feature. **Find the natural place to integrate the new feature in existing APIs.**
### Examples:
- We should not have added the self-normalizing activation function to the API. It was added before passing the test of time, and the technique has since been shown not to reach broad adoption. **Note that citation count is not a good metric of adoption**; that paper has a high citation count.
- We should not move to core an API that has debuted somewhere on GitHub or TF-Addons but has failed to gain more than a few users after a few months.
---
## Seek to minimize cognitive load for our users.
Always seek to minimize the cognitive load imposed on our users in the course of using our APIs.
At a high level:
- **Automate everything that can be automated.**
- **Minimize the actions & choices required from the user.** Make sure default values for arguments are sensible and reflect best practices (so that users usually wouldn't have to manually configure these). Don't expose options that are not important or do not match real use cases, "just in case".
- **Design simple and consistent workflows that reflect simple and consistent mental models.**
Here are a few practical rules:
- **No API should deal with internal implementation details.** An API is a language for our users to talk about the problem they care about -- and they don't care about our internal hacks. For instance, an option like `use_locking` in an optimizer should be avoided. If an argument requires users to understand the implementation (not just what the code is supposed to implement, like SGD in this case), then the argument should not be included in the public API. **An API is all about the problem it solves, not about how the code works in the background.**
- **Introduce as few new concepts as possible.** It's not just that additional data structures require more effort in order to learn about their methods and properties, it's that they multiply the number of **mental models** that are necessary to grok your API. Ideally, you should only need **a single universal mental model around which everything is organized** (in Keras, that's the `Layer`). Definitely avoid having more than 2 or 3 mental models underlying the workflows you design. Likewise, avoid having concepts that are mostly overlapping but subtly different, since the difference will be difficult to convey clearly and will confuse our users (like, say, `Network` and `Model` -- this is why we don't export `Network` as a public API).
- **Objects that do interchangeable things should have identical or very close APIs.** In particular they should have the same positional arguments. For example, it should be possible to swap one optimizer for another in user code (when leaving all arguments to their default value) without editing the arguments (see the sketch after this list).
- **If you find yourself proposing a signature with more than 6-7 arguments, consider whether all of these arguments are useful.** How many people and use cases would be affected if you removed one argument? How much would they be affected -- would they be able to easily extend the API (e.g. via subclassing) to support their use case without that built-in argument? Could this API be broken up into smaller, modular objects?
- **Best-practices should come baked into your API.** The simplest way to use your API (leaving all arguments to their default value, using the most obvious tool for the task, etc) should be as close as possible to the best way of solving the problem. In particular, all arguments that can be given a default value should be given a default value, and that default should match the most common use case.
- **Plain Python types are preferable to custom types.** Use tuples, strings, ints... A custom type requires more knowledge and effort on the part of the user (e.g. `TensorShape`, which is also breaking established conventions of scientific Python). **When using enums, make sure that their values are strings**, so as to make it possible for users to pass plain strings (example: `data_format="channels_last"`, `padding="valid"`).
- **Explicit, single-level configuration arguments are preferable to nested, hidden configuration arguments.** Avoid something like: `MyLayer(hyperparameter_dict)`, instead use `MyLayer(units, activation=None, ...)`.
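A minimal sketch of several of these rules in practice (interchangeable signatures, string-valued enums, and flat, explicit arguments), using a throwaway example built from standard Keras APIs:
```python
import keras

# Interchangeable objects, interchangeable APIs: swapping one
# optimizer for another requires no other change to user code.
model = keras.Sequential([keras.layers.Dense(10)])
model.compile(optimizer=keras.optimizers.Adam(), loss="mse")
model.compile(optimizer=keras.optimizers.RMSprop(), loss="mse")

# String-valued enums: users pass plain strings, not custom types.
layer = keras.layers.Conv2D(
    filters=32, kernel_size=3, padding="valid", data_format="channels_last"
)

# Explicit, single-level configuration arguments, not a nested dict.
dense = keras.layers.Dense(units=64, activation="relu")
```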
In particular, naming is important and difficult:
- **The meaning of an argument should be clear from its name and should not require knowledge that only the implementers have.** In particular, argument names should only involve recognized terms of art ("L1 norm" is a term of art), and should not involve implementation-related vocabulary (e.g. "fused batchnorm").
- **Avoid `OverlyLongAndSpecificNamingPatterns`.** If you find yourself with argument names that involve more than 3 subparts (e.g. "squared_operator_norm"), reconsider. Argument names should be intuitive and easy to remember.
- Avoid overly generic names (`x`, `variable`, `parameter`).
- **Make sure you are consistent in your naming choices.** Naming consistency means both **internal naming consistency** (don't call `dim` what is called `axis` in other places, don't call `ndims` what is called `ndim` elsewhere) and **consistency with established conventions for the problem domain (terms of art)**. Before settling on a name, make sure to look up existing names used by domain experts (or other APIs). In our case, argument names should be consistent with the broader scientific Python conventions, in particular NumPy.
Note that Keras uses the following naming rules:
- We use the convention `num_*` for counters, though omitting an explicit counter is nicer when there is no ambiguity (e.g. `units`, `epochs`, `filters`).
- The rank of a tensor is its `ndim`. A specific dimension index is an `axis`. The number of dimensions in a linear projection (or similar) is `units`.
- By convention Keras layers are named with nouns rather than verbs (e.g. `Normalization` and not `Normalize`, `Convolution` and not `Convolve`).
- Following Python conventions, classes use capitalized parts (e.g. `ClassName`) and functions and methods use snake case (e.g. `function_name`).
- If an argument name has a numerical suffix (e.g. `alpha_1`), we put an underscore before the suffix in snake case. The capitalized equivalent would be e.g. `Alpha1`.
- We use fully spelled-out names, e.g. `attention_scores` and not `attn_scores`. There are a couple of standardized exceptions to this rule, in particular `dim` for "dimension" and `num` for "number". These are sufficiently common that they are not ambiguous to a first-time reader.
### Example:
```python
MyConstructor(
    per_variable_sparsity_config=[
        'layer_1/kernel:0.8', 'layer_2/kernel:1.5'])
```
What's wrong with this?
- Overly long argument name
- Too much cognitive load involved in preparing an appropriate argument value
- Preparing an argument value requires internal implementation knowledge
- Reliance on TF variable names (subject to changes at any time, thus breaking this code)
- Nested config adding indirection
- Incorrect typing (float values being passed as strings)
Possible alternative:
```python
obj = MyConstructor()
obj.configure_sparsity(some_layer.kernel, value=0.8)
obj.configure_sparsity(some_other_layer.kernel, value=1.5)
```
What's nice about this?
- Object-based variable references.
- Modular, simple action, with a clear name.
- Plain Python types.
---
## Balance expressivity vs. user-friendliness.
### Simple use cases should be simple, advanced use cases should be possible:
**Don't increase the cognitive load of common use cases for the sake of niche use cases**, even minimally.
**Make sure that advanced users have a path to support their use case**, even if this path requires the users to roll their own plugins or other API extensions (in particular via subclassing). **It is ok for advanced use cases not to be directly supported in the built-in API options.**
### Keep our APIs modular.
**Complex objects should be achievable by composing simple objects with few arguments, that do one thing reliably.** There is a balance to strike between having complex signatures on fewer objects, and having more objects with simpler signatures. A good API has a reasonable number of objects, with reasonably simple signatures (see also: avoiding signatures with more than 6-7 arguments).
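For illustration, a minimal sketch of this kind of composition (a throwaway model assembled from small, single-purpose layers rather than one object with a sprawling signature):
```python
import keras

# Each object does one thing reliably; complexity comes from
# composing them, not from any one signature.
model = keras.Sequential([
    keras.layers.Dense(64, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(10),
])
```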
**Things that create state or side-effects should be classes. Functions should be stateless.**
For instance, layers that create weights should not be cast as functions, since it makes the weights (and other elements of state) hard to access, impossible to update, and forces reliance on a global state capturing the side effects of layer-functions.
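A minimal sketch of the distinction; the `Scale` layer and `normalize` function here are hypothetical examples, not Keras APIs:
```python
import keras
from keras import ops

class Scale(keras.layers.Layer):
    """State (a trainable weight) lives in a class, where it can be
    accessed, updated, and serialized."""

    def build(self, input_shape):
        self.alpha = self.add_weight(
            shape=(), initializer="ones", name="alpha"
        )

    def call(self, inputs):
        return inputs * self.alpha

def normalize(x, epsilon=1e-7):
    """A stateless transformation is fine as a plain function."""
    return x / (ops.norm(x) + epsilon)
```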
### APIs should be strictly compartmentalized.
For instance, the optimizer API or the layers API should not contain arguments for configuring distributed training. That should go into the distribution API.
---
## Don't neglect error messages, docstrings, and documentation.
Documentation and error messages are an integral part of the API. Good docs and helpful error messages are key to a delightful user experience.
- **Catch user errors early and anticipate common mistakes.** Do user input validation as soon as possible. Actively keep track of common mistakes that people make (by screening GitHub and StackOverflow), and either solve them by simplifying our API, adding targeted error messages for these mistakes, or having a "solutions to common issues" page in our docs. Consider adding automated fallback behaviors (e.g. casting a wrongly-typed input) instead of raising errors, when applicable. Be nice to our users.
- **Provide detailed feedback messages upon user error.** Error messages should be contextual, informative, and actionable. Every error message that transparently provides the user with the solution to their problem means one less support ticket, multiplied by how many times users run into the same issue. A good error message should answer:
- What happened, in what context?
- What did the software expect?
- How can the user fix it?
- **A docstring should answer the question: what is this about, and why & how should I use it?** It should assume as little context as possible, and it shouldn't mention specialized terms without first introducing them (for example, "num_blocks: Number of blocks in the kernel" is not a good argument description if this is the first time you mention "blocks" in your docstring).
- **Show, don't tell: your documentation should not talk about how the software works, it should show how to use it.** Show code examples for end-to-end workflows; show code examples for each and every common use case and key feature of your API. **All docstrings should include code examples.**
- **Deliberately design the user onboarding process for your feature.** How are complete newcomers going to find out the best way to solve their use case with your tool? Have an answer ready. Make sure your onboarding material closely maps to what your users care about: don't teach newcomers how your framework is implemented, teach them how they can use it to solve their own problems. After shipping a CL and writing good docstrings, make sure to create a Colab guide / tutorial showcasing the target workflow, and post it on the docs website.
- The feature is not ready until:
- 1) Users know about it
- 2) They know how to use it
- 3) They're actually using it to solve the corresponding problem.
Note that Keras uses the following rules for writing docstrings:
- For class docstrings, document arguments in an `Arguments:` section in the class docstring, not in `__init__`.
- When a user creates a class, they are not calling the `MyLayer.__init__()` method as if it were a regular method, they are calling `MyLayer`. We don't want to generate documentation for the `__init__()` method as a standalone method that needs to be called directly; that would be confusing. We also don't need `__init__()` docstrings that always start with "Initializes a MyLayer class.", which is useless information. Leaving `__init__()` without a docstring is the best practice.
- If constructor arguments are documented in `__init__`, it forces us to programmatically copy the `__init__` docstring when generating docs and concatenate it to the class docstring. This means that the Arguments section becomes the last thing in the docstring, which is bad.
- The order of information in a class docstring should be (see the skeleton after this list):
- One-line description of the class, that gives initial context to the user. e.g. `Applies Dropout to the input.` Make sure the one-line description is useful. No `Instantiates an ObscureName class instance.`
- Paragraph(s) of more detailed information that tells the user what the object is for and when they need to use it. e.g. `The Dropout layer randomly sets input units to 0 with a frequency of "rate" at each step during training time, which helps prevent overfitting. Inputs not set to 0 are scaled up by "1/(1 - rate)" such that the sum over all inputs is unchanged. [...]`
- If there is a reference paper, cite it here.
- `Arguments` section.
- If it's a layer that has arguments in `call`, the `Call arguments` section.
- If it's a `Layer`, `Input shape` and `Output shape` sections.
- Example(s).
- Lastly, addendum. Information that isn't very important and that most users don't need, but that should be documented somewhere.
- e.g. the section "About the layer's `dtype` attribute" in the base Layer class.
- e.g. warnings about edge cases or compatibility issues.
- e.g. pointers to further guides and tutorials.
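A skeletal illustration of that ordering, with abridged placeholder docstring text:
```python
import keras

class Dropout(keras.layers.Layer):
    """Applies Dropout to the input.

    The Dropout layer randomly sets input units to 0 with a frequency
    of `rate` at each step during training time, which helps prevent
    overfitting. [...]

    Arguments:
        rate: Float between 0 and 1. Fraction of the input units to drop.
        seed: A Python integer to use as random seed.

    Call arguments:
        inputs: Input tensor (of any rank).
        training: Python boolean indicating whether the layer should
            behave in training mode (applying dropout) or in inference
            mode (doing nothing).

    Example:

    >>> layer = keras.layers.Dropout(0.5)
    """
```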
### Error messages: a case study
The following would be a very poor error message:
```
AssertionError: '1 != 3'
```
In general, to validate user input, always use `ValueError` and avoid `assert`.
Also bad:
```
ValueError: 'Invalid target shape (600, 1).'
```
The following is better, but still not sufficient, because it does not tell the user what they passed, and does not quite say how to fix it:
```
ValueError: 'categorical_crossentropy requires target.shape[1] == classes'
```
Now, here's a good example, that says **what was passed**, **what was expected**, and **how to fix the issue**:
```
ValueError: '''You are passing a target array of shape (600, 1) while using as loss `categorical_crossentropy`.
`categorical_crossentropy` expects targets to be binary matrices (1s and 0s) of shape (samples, classes).
If your targets are integer classes, you can convert them to the expected format via:
---
from keras.utils import to_categorical
y_binary = to_categorical(y_int)
---
Alternatively, you can use the loss function `sparse_categorical_crossentropy` instead, which does expect integer targets.'''
```
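A sketch of how raising such an error might look in code; `validate_targets` and its arguments are hypothetical:
```python
def validate_targets(target, num_classes):
    # Validate user input as early as possible, with `ValueError`
    # rather than `assert`.
    if len(target.shape) != 2 or target.shape[1] != num_classes:
        raise ValueError(
            f"You are passing a target array of shape {target.shape} "
            "while using as loss `categorical_crossentropy`. "
            "`categorical_crossentropy` expects targets to be binary "
            "matrices (1s and 0s) of shape (samples, classes). "
            "If your targets are integer classes, you can convert them "
            "via `keras.utils.to_categorical(y_int)`, or use the loss "
            "`sparse_categorical_crossentropy` instead, which does "
            "expect integer targets."
        )
```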
---
When performing code reviews on pull requests, you must strictly adhere to the following principles in addition to the API design guidelines above:
1. **Question the Necessity of Changes**: Do not assume that the pull request changes are strictly necessary. Critically review the proposed changes to ensure they add real value. Point out any code that is solving a non-existent problem or adding unnecessary complexity.
2. **Call out "AI Slop"**: Actively look for and identify "AI slop"—generic, overly verbose, or hallucinated code that lacks context or violates best practices. If you suspect the code is AI slop, explicitly call it out.
3. **Poke Holes in the Implementation**: Your goal is to critically test the logic. Actively search for and point out failing edge cases, race conditions, or unhandled exceptions in the implementation.
4. **Demand Robustness**: Do not accept fragile code. If the proposed code is not robust enough or lacks proper error handling, explicitly tell the author why the current approach is brittle and what must be done to reinforce it.
5. **Respect Existing Repo Patterns**: Before suggesting review comments (like asking users to add boilerplate or specific patterns), actively check for existing design patterns across the repository. Do not suggest adding useless code or structures that contradict or fall outside the established Keras repo coding style.
================================================
FILE: .github/dependabot.yml
================================================
# To get started with Dependabot version updates, you'll need to specify which
# package ecosystems to update and where the package manifests are located.
# Please see the documentation for all configuration options:
# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates
version: 2
updates:
  - package-ecosystem: "github-actions"
    directory: "/"
    schedule:
      interval: "monthly"
    groups:
      github-actions:
        patterns:
          - "*"
  - package-ecosystem: "pip"
    directory: "/"
    schedule:
      interval: "monthly"
    groups:
      python:
        patterns:
          - "*"
    ignore:
      # 2.19.1 is the last version of TensorFlow that supports TPUs.
      - dependency-name: "tensorflow-tpu"
      # TODO: ignore all updates for JAX GPU due to cuda version issue
      - dependency-name: "jax[cuda12_pip]"
      # TODO(#21914): Update this version when TF is updated
      - dependency-name: "ai-edge-litert"
================================================
FILE: .github/workflows/actions.yml
================================================
name: Tests
# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future
# Currently only basic flow tests run with NNX enabled
on:
push:
branches: [ master ]
pull_request:
release:
types: [created]
permissions:
contents: read
jobs:
build:
strategy:
fail-fast: false
matrix:
python-version: ['3.11']
backend: [tensorflow, jax, torch, numpy, openvino]
nnx_enabled: [false]
include:
- python-version: '3.11'
backend: jax
nnx_enabled: true
name: ${{ matrix.backend == 'jax' && format('Run tests ({0}, {1}, nnx_enabled = {2})', matrix.python-version, matrix.backend, matrix.nnx_enabled) || format('Run tests ({0}, {1})', matrix.python-version, matrix.backend) }}
runs-on: ubuntu-latest
env:
PYTHON: ${{ matrix.python-version }}
KERAS_HOME: .github/workflows/config/${{ matrix.backend }}
steps:
- uses: actions/checkout@v6.0.2
- name: Check for changes in keras/src/applications
uses: dorny/paths-filter@v3
id: filter
with:
filters: |
applications:
- 'keras/src/applications/**'
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@v5
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
if [ "${{ matrix.nnx_enabled }}" == "true" ]; then
pip install --upgrade flax>=0.11.1
fi
pip install --no-deps tf_keras==2.20.0
pip uninstall -y keras keras-nightly
pip install -e "." --progress-bar off --upgrade
- name: Test applications with pytest
if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }}
run: |
pytest keras/src/applications --cov=keras/src/applications --cov-config=pyproject.toml
coverage xml --include='keras/src/applications/*' -o apps-coverage.xml
- name: Codecov keras.applications
if: ${{ steps.filter.outputs.applications == 'true' && matrix.nnx_enabled == false }}
uses: codecov/codecov-action@v5
with:
env_vars: PYTHON,KERAS_HOME
flags: keras.applications,keras.applications-${{ matrix.backend }}
files: apps-coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
- name: Test integrations
if: ${{ matrix.backend != 'numpy' && matrix.nnx_enabled == false }}
run: |
python integration_tests/import_test.py
python integration_tests/numerical_test.py
- name: Test JAX-specific integrations
if: ${{ matrix.backend == 'jax' && matrix.nnx_enabled == false }}
run: |
python integration_tests/jax_custom_fit_test.py
- name: Test basic flow with NNX
if: ${{ matrix.nnx_enabled == true }}
env:
KERAS_NNX_ENABLED: true
run: |
python integration_tests/import_test.py
python integration_tests/basic_full_flow.py
- name: Test TF-specific integrations
if: ${{ matrix.backend == 'tensorflow'}}
run: |
python integration_tests/tf_distribute_training_test.py
python integration_tests/tf_custom_fit_test.py
- name: Test Torch-specific integrations
if: ${{ matrix.backend == 'torch'}}
run: |
pytest integration_tests/torch_workflow_test.py
python integration_tests/torch_custom_fit_test.py
- name: Test with pytest
if: ${{ matrix.nnx_enabled == false }}
run: |
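# For OpenVINO, build one --ignore flag per line of its excluded-tests file.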
if [ "${{ matrix.backend }}" == "openvino" ]; then
IGNORE_FILE="keras/src/backend/openvino/excluded_tests.txt"
IGNORE_ARGS=$(awk '{print "--ignore=" $0}' "$IGNORE_FILE")
else
IGNORE_ARGS=""
fi
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS
coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml
- name: Codecov keras
if: ${{ matrix.nnx_enabled == false }}
uses: codecov/codecov-action@v5
with:
env_vars: PYTHON,KERAS_HOME,KERAS_NNX_ENABLED
flags: keras,keras-${{ matrix.backend }}${{ matrix.nnx_enabled == 'true' && '-nnx' || '' }}
files: core-coverage.xml
token: ${{ secrets.CODECOV_TOKEN }}
fail_ci_if_error: false
format:
name: Check the code format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6.0.2
- name: Set up Python 3.11
uses: actions/setup-python@v6
with:
python-version: '3.11'
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@v5
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
pip uninstall -y keras keras-nightly
pip install -e "." --progress-bar off --upgrade
- name: Run pre-commit
run: pre-commit run --all-files --hook-stage manual
================================================
FILE: .github/workflows/auto-assignment.yaml
================================================
name: auto-assignment
on:
issues:
types:
- opened
permissions:
contents: read
issues: write
pull-requests: write
jobs:
welcome:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6.0.2
- uses: actions/github-script@v8
with:
script: |
const script = require('./\.github/workflows/scripts/auto-assignment.js')
script({github, context})
================================================
FILE: .github/workflows/config/jax/keras.json
================================================
{
"floatx": "float32",
"epsilon": 1e-07,
"backend": "jax",
"image_data_format": "channels_last",
"nnx_enabled": false
}
================================================
FILE: .github/workflows/config/numpy/keras.json
================================================
{
"floatx": "float32",
"epsilon": 1e-07,
"backend": "numpy",
"image_data_format": "channels_last"
}
================================================
FILE: .github/workflows/config/openvino/keras.json
================================================
{
"floatx": "float32",
"epsilon": 1e-07,
"backend": "openvino",
"image_data_format": "channels_last"
}
================================================
FILE: .github/workflows/config/tensorflow/keras.json
================================================
{
"floatx": "float32",
"epsilon": 1e-07,
"backend": "tensorflow",
"image_data_format": "channels_last"
}
================================================
FILE: .github/workflows/config/torch/keras.json
================================================
{
"floatx": "float32",
"epsilon": 1e-07,
"backend": "torch",
"image_data_format": "channels_first"
}
================================================
FILE: .github/workflows/gpu_tests.yml
================================================
name: Keras GPU Tests
on:
push:
branches: [master]
pull_request:
types: [unlabeled]
release:
types: [created]
permissions:
contents: read
jobs:
test-in-container:
name: Run tests on GPU
runs-on: linux-x86-g2-16-l4-1gpu
# Only run on pushes to master, releases or "kokoro:force-run" unlabel
if: |
github.event_name == 'push' ||
github.event_name == 'release' ||
(github.event_name == 'pull_request' && github.event.action == 'unlabeled' && github.event.label.name == 'kokoro:force-run')
strategy:
fail-fast: false
matrix:
backend: [torch]
container:
image: python:3.11-slim
options: --privileged --network host
steps:
- name: Checkout Repository
uses: actions/checkout@v4
- name: Check CUDA Version
run: nvidia-smi
- name: Install Dependencies
run: pip install --no-cache-dir -r requirements-${{ matrix.backend }}-cuda.txt
- name: Set Keras Backend
run: echo "KERAS_BACKEND=jax" >> $GITHUB_ENV
- name: Verify TF Installation
if: ${{ matrix.backend == 'tensorflow'}}
run: python3 -c "import tensorflow as tf; print('Tensorflow devices:', tf.config.list_logical_devices()); assert len(tf.config.list_physical_devices('GPU')) > 0"
- name: Verify JAX Installation
if: ${{ matrix.backend == 'jax'}}
run: python3 -c "import jax; print('JAX devices:', jax.devices()); assert jax.default_backend() == 'gpu'"
- name: Verify Torch Installation
if: ${{ matrix.backend == 'torch'}}
run: python3 -c "import torch; print('Torch devices:', [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]); assert torch.cuda.device_count() > 0"
- name: Run Tests
run: pytest -s keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml
- name: Run Distribution Tests
if: ${{ matrix.backend == 'jax'}}
run: pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml
================================================
FILE: .github/workflows/labeler.yaml
================================================
# Copyright 2024 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This workflow automatically identifies issues and pull requests (PRs) and adds the
# appropriate label as per the defined rules.
# First labeler rule: search for the keyword "Gemma" (case-insensitive) in both the title
# and description of the issue/PR. If a match is found, the workflow adds the label 'Gemma' to the issue/PR.
name: 'Labeler'
on:
issues:
types: [edited,opened]
pull_request_target:
types: [opened, edited]
permissions:
contents: read
issues: write
pull-requests: write
jobs:
add_labels:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6.0.2
- uses: actions/github-script@v8
with:
script: |
const script = require('./\.github/workflows/scripts/labeler.js')
script({github, context})
================================================
FILE: .github/workflows/nightly.yml
================================================
name: Nightly
on:
workflow_dispatch: # To Generate wheels on demand outside of schedule.
schedule:
- cron: "0 3 * * *" # run at 3 AM UTC / 8 PM PDT
permissions:
contents: read
jobs:
build:
strategy:
fail-fast: false
matrix:
python-version: ["3.11"]
backend: [tensorflow, jax, torch, numpy]
name: Run tests (Python ${{ matrix.python-version }})
runs-on: ubuntu-latest
env:
PYTHON: ${{ matrix.python-version }}
KERAS_BACKEND: ${{ matrix.backend }}
steps:
- uses: actions/checkout@v6.0.2
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@v5
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
pip uninstall -y keras keras-nightly
pip install -e "." --progress-bar off --upgrade
- name: Test integrations
if: ${{ matrix.backend != 'numpy'}}
run: |
python integration_tests/import_test.py
- name: Test TF-specific integrations
if: ${{ matrix.backend == 'tensorflow'}}
run: |
python integration_tests/tf_distribute_training_test.py
- name: Test Torch-specific integrations
if: ${{ matrix.backend == 'torch'}}
run: |
pytest integration_tests/torch_workflow_test.py
- name: Test with pytest
run: |
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml
format:
name: Check the code format
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6.0.2
- name: Set up Python 3.11
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip setuptools
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@v5
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
pip uninstall -y keras keras-nightly
pip install -e "." --progress-bar off --upgrade
- name: Run pre-commit
run: pre-commit run --all-files --hook-stage manual
nightly:
name: Build Wheel file and upload
needs: [build, format]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6.0.2
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip setuptools
pip install twine
pip install -r requirements.txt --progress-bar off --upgrade
pip uninstall -y keras keras-nightly
- name: Build wheel file
run: |
python pip_build.py --nightly
- name: Publish to PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_NIGHTLY_API_TOKEN }}
packages-dir: dist/
verbose: true
================================================
FILE: .github/workflows/scorecard.yml
================================================
name: Scorecard supply-chain security
on:
# For Branch-Protection check. Only the default branch is supported. See
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection
branch_protection_rule:
# To guarantee Maintained check is occasionally updated. See
# https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained
schedule:
- cron: '42 8 * * 2'
push:
branches: [ "master" ]
# Declare default permissions as read only.
permissions: read-all
jobs:
analysis:
name: Scorecard analysis
runs-on: ubuntu-latest
permissions:
# Needed to upload the results to code-scanning dashboard.
security-events: write
# Needed to publish results and get a badge (see publish_results below).
id-token: write
steps:
- name: "Checkout code"
uses: actions/checkout@0c366fd6a839edf440554fa01a7085ccba70ac98 # v4.1.1
with:
persist-credentials: false
- name: "Run analysis"
uses: ossf/scorecard-action@4eaacf0543bb3f2c246792bd56e8cdeffafb205a # v2.4.3
with:
results_file: results.sarif
results_format: sarif
# (Optional) "write" PAT token. Uncomment the `repo_token` line below if:
# - you want to enable the Branch-Protection check on a *public* repository, or
# - you are installing Scorecard on a *private* repository
# To create the PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-pat.
# repo_token: ${{ secrets.SCORECARD_TOKEN }}
# Publish results to OpenSSF REST API for easy access by consumers
# Allows the repository to include the Scorecard badge.
# See https://github.com/ossf/scorecard-action#publishing-results.
publish_results: true
# Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF
# format to the repository Actions tab.
- name: "Upload artifact"
uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0
with:
name: SARIF file
path: results.sarif
retention-days: 5
# Upload the results to GitHub's code scanning dashboard.
- name: "Upload to code-scanning"
uses: github/codeql-action/upload-sarif@89a39a4e59826350b863aa6b6252a07ad50cf83e # v3.29.5
with:
sarif_file: results.sarif
================================================
FILE: .github/workflows/scripts/auto-assignment.js
================================================
/**
* @license
* Copyright 2023 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/
/** Automatically assign issues and PRs to users in the `assigneesList`
* on a rotating basis.
@param {!object} params The destructured `{github, context}` argument:
`github` can call GitHub APIs using its built-in library functions;
`context` contains details of the triggering issue or PR.
*/
module.exports = async ({ github, context }) => {
let issueNumber;
let assigneesList;
// Is this an issue? If so, assign the issue number. Otherwise, assign the PR number.
if (context.payload.issue) {
// Assignee list for issues.
assigneesList = ["mehtamansi29", "sachinprasadhs"];
issueNumber = context.payload.issue.number;
} else {
// Assignee list for PRs.
assigneesList = [];
issueNumber = context.payload.number;
}
console.log("assignee list", assigneesList);
console.log("entered auto assignment for this issue: ", issueNumber);
if (!assigneesList.length) {
console.log("No assignees found for this repo.");
return;
}
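// Rotate assignment: the issue/PR number modulo the list length selects
// the assignee, spreading new items evenly across the list.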
let noOfAssignees = assigneesList.length;
let selection = issueNumber % noOfAssignees;
let assigneeForIssue = assigneesList[selection];
console.log(
"issue Number = ",
issueNumber + " , assigning to: ",
assigneeForIssue
);
return github.rest.issues.addAssignees({
issue_number: context.issue.number,
owner: context.repo.owner,
repo: context.repo.repo,
assignees: [assigneeForIssue],
});
};
================================================
FILE: .github/workflows/scripts/labeler.js
================================================
/*
Copyright 2024 Google LLC. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
/**
* Invoked from the labeler.yaml file to add the label 'Gemma' to issues
* and PRs whose title or description contains the keyword "gemma".
* @param {!Object.<string,!Object>} github contains predefined functions;
* context contains information about the workflow run.
*/
module.exports = async ({ github, context }) => {
const issue_title = context.payload.issue ? context.payload.issue.title : context.payload.pull_request.title
const issue_description = context.payload.issue ? context.payload.issue.body : context.payload.pull_request.body
const issue_number = context.payload.issue ? context.payload.issue.number : context.payload.pull_request.number
const keyword_label = {
gemma:'Gemma'
}
const labelsToAdd = []
console.log(issue_title,issue_description,issue_number)
for(const [keyword, label] of Object.entries(keyword_label)){
if(issue_title.toLowerCase().indexOf(keyword) !=-1 || issue_description.toLowerCase().indexOf(keyword) !=-1 ){
console.log(`'${keyword}' keyword is present in the title or description. Adding label '${label}'.`)
labelsToAdd.push(label)
}
}
if(labelsToAdd.length > 0){
console.log(`Adding labels ${labelsToAdd} to the issue '#${issue_number}'.`)
github.rest.issues.addLabels({
owner: context.repo.owner,
repo: context.repo.repo,
issue_number: context.issue.number,
labels: labelsToAdd
})
}
};
================================================
FILE: .github/workflows/stale-issue-pr.yaml
================================================
name: Close inactive issues
on:
schedule:
- cron: "30 1 * * *"
jobs:
close-issues:
# Don't do this in forks
if: github.repository == 'keras-team/keras'
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
actions: write
steps:
- name: Awaiting response issues
uses: actions/stale@v10
with:
operations-per-run: 500
days-before-issue-stale: 14
days-before-issue-close: 14
stale-issue-label: "stale"
# Reason for closing the issue; the default value is not_planned.
close-issue-reason: completed
only-labels: "stat:awaiting response from contributor"
stale-issue-message: >
This issue is stale because it has been open for 14 days with no activity.
It will be closed if no further activity occurs. Thank you.
# List of labels to remove when issues/PRs unstale.
labels-to-remove-when-unstale: "stat:awaiting response from contributor"
close-issue-message: >
This issue was closed because it has been inactive for 28 days.
Please reopen if you'd like to work on this further.
days-before-pr-stale: 14
days-before-pr-close: 14
stale-pr-message: "This PR is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you."
close-pr-message: "This PR was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further."
repo-token: ${{ secrets.GITHUB_TOKEN }}
- name: Contribution issues
uses: actions/stale@v10
with:
operations-per-run: 500
days-before-issue-stale: 180
days-before-issue-close: 365
stale-issue-label: "stale"
# Reason for closing the issue; the default value is not_planned.
close-issue-reason: not_planned
any-of-labels: "stat:contributions welcome,good first issue"
# List of labels to remove when issues/PRs unstale.
labels-to-remove-when-unstale: "stat:contributions welcome,good first issue"
stale-issue-message: >
This issue is stale because it has been open for 180 days with no activity.
It will be closed if no further activity occurs. Thank you.
close-issue-message: >
This issue was closed because it has been inactive for more than 1 year.
repo-token: ${{ secrets.GITHUB_TOKEN }}
================================================
FILE: .github/workflows/tpu_tests.yml
================================================
name: Keras TPU Tests
on:
push:
branches: [master]
pull_request:
types: [unlabeled]
release:
types: [created]
permissions:
contents: read
jobs:
test-in-container:
name: Run tests on TPU
runs-on: linux-x86-ct6e-44-1tpu
# Only run on pushes to master, releases or "kokoro:force-run" unlabel
if: |
github.event_name == 'push' ||
github.event_name == 'release' ||
(github.event_name == 'pull_request' && github.event.action == 'unlabeled' && github.event.label.name == 'kokoro:force-run')
strategy:
fail-fast: false
matrix:
backend: [jax]
container:
image: python:3.11-slim
options: --privileged --network host
steps:
- name: Checkout Repository
uses: actions/checkout@v6.0.2
- name: Install Dependencies
run: pip install --no-cache-dir -r requirements-${{ matrix.backend }}-tpu.txt
- name: Set Keras Backend
run: echo "KERAS_BACKEND=jax" >> $GITHUB_ENV
- name: Verify JAX Installation
run: python3 -c "import jax; print('JAX devices:', jax.devices()); assert jax.default_backend() == 'tpu'"
- name: Run Tests
run: pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml
================================================
FILE: .gitignore
================================================
.DS_Store
*.pyc
.vscode-test
__pycache__
**/.vscode-test/**
**/.vscode test/**
**/.vscode-smoke/**
**/.venv*/
venv
bin/**
build/**
obj/**
.pytest_cache
tmp/**
.vs/
dist/**
**/*.egg-info/*
.vscode
examples/**/*.jpg
.python-version
.coverage
*coverage.xml
.ruff_cache
pytest.ini
venv/
================================================
FILE: .kokoro/README.md
================================================
CI to run on PRs and on merges to master.
================================================
FILE: .kokoro/github/ubuntu/gpu/build.sh
================================================
set -e
set -x
cd "${KOKORO_ROOT}/"
sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1
PYTHON_BINARY="/usr/bin/python3.10"
"${PYTHON_BINARY}" -m venv venv
source venv/bin/activate
# Check the python version
python --version
python3 --version
# Setting LD_LIBRARY_PATH manually causes a segmentation fault.
#export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:"
# Check cuda
nvidia-smi
nvcc --version
cd "src/github/keras"
pip install -U pip setuptools
# psutil is used by the background log reader.
pip install -U psutil
if [ "$KERAS_BACKEND" == "tensorflow" ]
then
echo "TensorFlow backend detected."
pip install -r requirements-tensorflow-cuda.txt --progress-bar off --timeout 1000
pip uninstall -y keras keras-nightly
echo "Check that TensorFlow uses GPU"
python3 -c 'import tensorflow as tf;print(tf.__version__);print(tf.config.list_physical_devices("GPU"))'
# Raise error if GPU is not detected.
python3 -c 'import tensorflow as tf;assert len(tf.config.list_physical_devices("GPU")) > 0'
# TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
pytest keras --ignore keras/src/applications \
--ignore keras/src/layers/merging/merging_test.py \
--cov=keras \
--cov-config=pyproject.toml
fi
if [ "$KERAS_BACKEND" == "jax" ]
then
echo "JAX backend detected."
pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000
pip uninstall -y keras keras-nightly
python3 -c 'import jax;print(jax.__version__);print(jax.default_backend())'
# Raise error if GPU is not detected.
python3 -c 'import jax;assert jax.default_backend().lower() == "gpu"'
# TODO: keras/layers/merging/merging_test.py::MergingLayersTest::test_sparse_dot_2d Fatal Python error: Aborted
# TODO: keras/trainers/data_adapters/py_dataset_adapter_test.py::PyDatasetAdapterTest::test_basic_flow0 Fatal Python error: Aborted
# keras/backend/jax/distribution_lib_test.py is configured for CPU test for now.
pytest keras --ignore keras/src/applications \
--ignore keras/src/layers/merging/merging_test.py \
--ignore keras/src/trainers/data_adapters/py_dataset_adapter_test.py \
--ignore keras/src/backend/jax/distribution_lib_test.py \
--ignore keras/src/distribution/distribution_lib_test.py \
--cov=keras \
--cov-config=pyproject.toml
pytest keras/src/distribution/distribution_lib_test.py --cov=keras --cov-config=pyproject.toml
fi
if [ "$KERAS_BACKEND" == "torch" ]
then
echo "PyTorch backend detected."
pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000
pip uninstall -y keras keras-nightly
python3 -c 'import torch;print(torch.__version__);print(torch.cuda.is_available())'
# Raise error if GPU is not detected.
python3 -c 'import torch;assert torch.cuda.is_available()'
pytest keras --ignore keras/src/applications \
--cov=keras \
--cov-config=pyproject.toml
fi
================================================
FILE: .kokoro/github/ubuntu/gpu/jax/continuous.cfg
================================================
build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh"
action {
define_artifacts {
regex: "**/sponge_log.log"
regex: "**/sponge_log.xml"
}
}
env_vars: {
key: "KERAS_BACKEND"
value: "jax"
}
# Set timeout to 120 mins from default 180 mins
timeout_mins: 120
================================================
FILE: .kokoro/github/ubuntu/gpu/jax/presubmit.cfg
================================================
build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh"
action {
define_artifacts {
regex: "**/sponge_log.log"
regex: "**/sponge_log.xml"
}
}
env_vars: {
key: "KERAS_BACKEND"
value: "jax"
}
# Set timeout to 120 mins from default 180 mins
timeout_mins: 120
================================================
FILE: .kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg
================================================
build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh"
action {
define_artifacts {
regex: "**/sponge_log.log"
regex: "**/sponge_log.xml"
}
}
env_vars: {
key: "KERAS_BACKEND"
value: "tensorflow"
}
# Set timeout to 60 mins from default 180 mins
timeout_mins: 60
================================================
FILE: .kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg
================================================
build_file: "keras/.kokoro/github/ubuntu/gpu/build.sh"
action {
define_artifacts {
regex: "**/sponge_log.log"
regex: "**/sponge_log.xml"
}
}
env_vars: {
key: "KERAS_BACKEND"
value: "tensorflow"
}
# Set timeout to 60 mins from default 180 mins
timeout_mins: 60
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: local
hooks:
- id: api-gen
name: api_gen
entry: |
bash shell/api_gen.sh
git status
clean=$(git status | grep "nothing to commit")
if [ -z "$clean" ]; then
echo "Please run shell/api_gen.sh to generate API."
exit 1
fi
language: system
stages: [pre-commit, manual]
require_serial: true
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.2
hooks:
- id: ruff
args: [--config, pyproject.toml, --fix, .]
stages: [pre-commit]
- id: ruff-format
args: [--config, pyproject.toml, .]
stages: [pre-commit]
- id: ruff
args: [--config, pyproject.toml, .]
stages: [manual]
- id: ruff-format
args: ["--check", --config, pyproject.toml, .]
stages: [manual]
================================================
FILE: CONTRIBUTING.md
================================================
Keras 3 is a high-velocity open-source project. We welcome contributions!
Contributions can be made in a variety of ways, including coding, enriching documentation, refining docstrings, and providing code examples.
## Current items open for contributions
At [this link](https://github.com/keras-team/keras/issues/18442), you'll find a list of items where your help is needed!
## How to contribute code
Follow these steps to submit your code contribution.
### Step 1. Open an issue
Before making any changes, we recommend opening an issue (if one doesn't already
exist) and discussing your proposed changes. This way, we can give you feedback
and validate the proposed changes.
If the changes are minor (simple bug fix or documentation fix), then feel free
to open a Pull Request (PR) without discussion.
### Step 2. Make code changes
To make code changes, you need to fork the repository. You will need to set up a
development environment and run the unit tests. This is covered in the
"Setup environment" section.
### Step 3. Create a pull request
Once the change is ready, open a pull request from your branch in your fork to
the master branch in [keras-team/keras](https://github.com/keras-team/keras).
### Step 4. Sign the Contributor License Agreement
After creating the pull request, the `cla/google` check will be performed and,
if you haven't signed the Contributor License Agreement (CLA), it will fail with
instructions on how to do so. Please follow the instructions to sign the CLA and
the check will pass.

### Step 5. Code review
If the tests fail, look into the error messages and try to fix them.

A reviewer will review the pull request and provide comments. There may be
several rounds of comments and code changes before the pull request gets
approved by the reviewer.

### Step 6. Merging
Once the pull request is approved, a `ready to pull` tag will be added to the
pull request. A team member will take care of the merging.

Here is an [example pull request](https://github.com/keras-team/keras/pull/18848)
for your reference.
## Setup environment
We provide two ways of setting up a development environment. One is to use a
dev container, and the other one is to set up a local environment by installing
the dev tools needed.
### Option 1: GitHub Codespace or dev container
We support GitHub Codespaces, Visual Studio Code dev containers, and JetBrains dev
containers. Please see the
[Dev container documentation](https://github.com/keras-team/keras/tree/master/.devcontainer).
### Option 2: Set up a local environment
To set up your local dev environment, you will need the following tools.
1. [git](https://github.com/) for code repository management.
2. [python](https://www.python.org/) to build and code in Keras.
The following commands check that the tools above are installed successfully.
Note that Keras requires at least Python 3.10 to run.
```shell
git --version
python --version
```
Clone your forked repo to your local machine. Go to the cloned directory to
install the dependencies.
```shell
git clone https://github.com/YOUR_GITHUB_USERNAME/keras.git
cd keras
pip install -r requirements.txt
```
You then need to configure the backend to use, see the
[Configuring your backend](https://github.com/keras-team/keras/blob/master/README.md#configuring-your-backend)
section of the README.
You can also add GPU support to your environment, see the
[Adding GPU support](https://github.com/keras-team/keras/blob/master/README.md#adding-gpu-support)
section of the README.
## Generating public API and formatting the code
When setting up the repo for the first time, please run `pre-commit install`.
Note that this only needs to be done once.
Now, whenever you run `git commit -m "<message>"`, three things are
automatically done:
- Public API generation
- Code formatting
- Code linting
If there's any error, the commit will not go through. Please fix the error
(most of the time, the error is fixed automatically by the formatter/linter) and
re-run the following:
```
git add .
git commit -m "<message>" # This will not get logged as a duplicate commit.
```
In case you want to run the above manually on all files, you can do the
following:
```
pre-commit run --all-files
```
Keras uses [Ruff](https://docs.astral.sh/ruff/) to format the code.
### Docstrings
We do not have an automated way to check docstring style, so if you write
or edit any docstring, please make sure to check them manually.
Keras docstrings follow the conventions below:
A **class docstring** may contain the following items:
* A one-line description of the class.
* Paragraph(s) of more detailed information.
* Optional `Examples` section.
* `Args` section for arguments in `__init__()`.
* If it's a layer:
* `Call arguments` section for arguments in `Layer.call()`.
* `Returns` section for the return values of `Layer.call()`.
* Optional `Raises` section for possible errors.
You can check out `MultiHeadAttention` as an example
[(link)](https://github.com/keras-team/keras/blob/v3.0.0/keras/layers/attention/multi_head_attention.py#L20).
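For instance, here is a minimal sketch of a layer class docstring following
these conventions (the `Scale` layer and its `factor` argument are
hypothetical, invented purely for illustration):

```python
import keras


class Scale(keras.layers.Layer):
    """Multiplies inputs by a fixed scalar.

    This layer performs an elementwise multiplication of its inputs
    by a constant factor set at construction time.

    Args:
        factor: Float, the scalar to multiply inputs by.

    Call arguments:
        inputs: Input tensor of any rank.

    Returns:
        A tensor of the same shape as `inputs`.
    """

    def __init__(self, factor, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor

    def call(self, inputs):
        return inputs * self.factor
```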
A **function docstring** may contain the following items:
* One-line description of the function.
* Paragraph(s) of more detailed information.
* Optional `Examples` section.
* `Args` section for the function arguments.
* `Returns` section for the return values.
* Optional `Raises` section for possible errors.
You can check out `text_dataset_from_directory` as an example
[(link)](https://github.com/keras-team/keras/blob/v3.0.0/keras/utils/text_dataset_utils.py#L27).
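Similarly, a minimal function docstring sketch (the `normalize` function is
hypothetical):

```python
import numpy as np


def normalize(x, epsilon=1e-7):
    """Scales a NumPy array to the `[0, 1]` range.

    Args:
        x: Input NumPy array.
        epsilon: Small float added to the denominator to avoid
            division by zero.

    Returns:
        A float NumPy array with values scaled to `[0, 1]`.

    Raises:
        ValueError: If `x` is empty.
    """
    x = np.asarray(x, dtype="float32")
    if x.size == 0:
        raise ValueError("`x` must not be empty.")
    return (x - x.min()) / (x.max() - x.min() + epsilon)
```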
## Run tests
We use [pytest](https://pytest.org/) to run the tests.
### Run a test file
To run the tests in `keras/src/losses/losses_test.py`, use the following command
at the root directory of the repo.
```shell
pytest keras/src/losses/losses_test.py
```
### Run a single test case
You can specify a single test class to run within a file.
```shell
pytest keras/src/losses/losses_test.py::MeanSquaredErrorTest
```
You can also specify a single test method to run within a class.
```shell
pytest keras/src/losses/losses_test.py::MeanSquaredErrorTest::test_sample_weighted
```
### Run all tests
You can run all the tests locally by running the following command in the repo
root directory.
```shell
pytest keras
```
Note that you can skip the Keras applications tests using the
`SKIP_APPLICATIONS_TESTS` environment variable. This will cut down the testing
time significantly.
```shell
SKIP_APPLICATIONS_TESTS=True pytest keras
```
To run all tests using a different backend, you can simply specify it on the
command line.
```shell
KERAS_BACKEND=jax SKIP_APPLICATIONS_TESTS=True pytest keras
```
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# Keras 3: Deep Learning for Humans
Keras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlow, PyTorch, and OpenVINO (inference only).
Effortlessly build and train models for computer vision, natural language processing, audio processing,
timeseries forecasting, recommender systems, etc.
- **Accelerated model development**: Ship deep learning solutions faster thanks to the high-level UX of Keras
and the availability of easy-to-debug runtimes like PyTorch or JAX eager execution.
- **State-of-the-art performance**: By picking the backend that is the fastest for your model architecture (often JAX!),
leverage speedups ranging from 20% to 350% compared to other frameworks. [Benchmark here](https://keras.io/getting_started/benchmarks/).
- **Datacenter-scale training**: Scale confidently from your laptop to large clusters of GPUs or TPUs.
Join nearly three million developers, from burgeoning startups to global enterprises, in harnessing the power of Keras 3.
## Installation
### Install with pip
Keras 3 is available on PyPI as `keras`. Note that Keras 2 remains available as the `tf-keras` package.
1. Install `keras`:
```
pip install keras --upgrade
```
2. Install backend package(s).
To use `keras`, you should also install the backend of choice: `tensorflow`, `jax`, or `torch`. Additionally,
the `openvino` backend is available, with support for model inference only.
### Local installation
#### Minimal installation
Keras 3 is compatible with Linux and macOS systems. For Windows users, we recommend using WSL2 to run Keras.
To install a local development version:
1. Install dependencies:
```
pip install -r requirements.txt
```
2. Run installation command from the root directory.
```
python pip_build.py --install
```
3. Run the API generation script when creating PRs that update `keras_export` public APIs:
```
./shell/api_gen.sh
```
## Backend Compatibility Table
The following table lists the minimum supported versions of each backend for the latest stable release of Keras (v3.x):
| Backend | Minimum Supported Version |
|------------|---------------------------|
| TensorFlow | 2.16.1 |
| JAX | 0.4.20 |
| PyTorch | 2.1.0 |
| OpenVINO | 2025.3.0 |
#### Adding GPU support
The `requirements.txt` file will install a CPU-only version of TensorFlow, JAX, and PyTorch. For GPU support, we also
provide a separate `requirements-{backend}-cuda.txt` for TensorFlow, JAX, and PyTorch. These install all CUDA
dependencies via `pip` and expect an NVIDIA driver to be pre-installed. We recommend a clean Python environment for each
backend to avoid CUDA version mismatches. As an example, here is how to create a JAX GPU environment with `conda`:
```shell
conda create -y -n keras-jax python=3.10
conda activate keras-jax
pip install -r requirements-jax-cuda.txt
python pip_build.py --install
```
## Configuring your backend
You can export the environment variable `KERAS_BACKEND` or you can edit your local config file at `~/.keras/keras.json`
to configure your backend. Available backend options are: `"tensorflow"`, `"jax"`, `"torch"`, `"openvino"`. Example:
```
export KERAS_BACKEND="jax"
```
In Colab, you can do:
```python
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
```
**Note:** The backend must be configured before importing `keras`, and the backend cannot be changed after
the package has been imported.
**Note:** The OpenVINO backend is an inference-only backend, meaning it is designed only for running model
predictions using the `model.predict()` method.
## Backwards compatibility
Keras 3 is intended to work as a drop-in replacement for `tf.keras` (when using the TensorFlow backend). Just take your
existing `tf.keras` code, make sure that your calls to `model.save()` are using the up-to-date `.keras` format, and you're
done.
If your `tf.keras` model does not include custom components, you can start running it on top of JAX or PyTorch immediately.
If it does include custom components (e.g. custom layers or a custom `train_step()`), it is usually possible to convert it
to a backend-agnostic implementation in just a few minutes.
In addition, Keras models can consume datasets in any format, regardless of the backend you're using:
you can train your models with your existing `tf.data.Dataset` pipelines or PyTorch `DataLoaders`.
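For example, a minimal sketch of the drop-in workflow, saving in the modern
`.keras` format and training on an existing `tf.data` pipeline (the model
here is a trivial stand-in for your own):

```python
import numpy as np
import tensorflow as tf
import keras  # Keras 3

# A placeholder model; your existing tf.keras model definition works as-is.
model = keras.Sequential([keras.Input(shape=(4,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Existing tf.data pipelines keep working, regardless of the active backend.
dataset = tf.data.Dataset.from_tensor_slices(
    (np.random.rand(32, 4).astype("float32"),
     np.random.rand(32, 1).astype("float32"))
).batch(8)
model.fit(dataset, epochs=1)

# Save using the up-to-date .keras format.
model.save("my_model.keras")
```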
## Why use Keras 3?
- Run your high-level Keras workflows on top of any framework -- benefiting at will from the advantages of each framework,
e.g. the scalability and performance of JAX or the production ecosystem options of TensorFlow.
- Write custom components (e.g. layers, models, metrics) that you can use in low-level workflows in any framework.
- You can take a Keras model and train it in a training loop written from scratch in native TF, JAX, or PyTorch.
- You can take a Keras model and use it as part of a PyTorch-native `Module` or as part of a JAX-native model function.
- Make your ML code future-proof by avoiding framework lock-in.
- As a PyTorch user: get access to the power and usability of Keras, at last!
- As a JAX user: get access to a fully-featured, battle-tested, well-documented modeling and training library.
Read more in the [Keras 3 release announcement](https://keras.io/keras_3/).
================================================
FILE: SECURITY.md
================================================
# Security Policy
- [**Using Keras Securely**](#using-keras-securely)
- [Untrusted inputs](#untrusted-inputs)
- [Data privacy](#data-privacy)
- [Untrusted environments or networks](#untrusted-environments-or-networks)
- [Multi-Tenant environments](#multi-tenant-environments)
- [**Reporting a Vulnerability**](#reporting-a-vulnerability)
## Using Keras Securely
### Untrusted inputs
Some models accept various input formats (text, images, audio, etc.). The libraries converting these inputs have varying security levels, so it's crucial to isolate the model and carefully pre-process inputs to mitigate script injection risks.
For maximum security when handling untrusted inputs, you may need to employ the following:
* Sandboxing: Isolate the model process.
* Pre-analysis: check how the model performs by default when exposed to prompt injection (e.g. using [fuzzing for prompt injection](https://github.com/FonduAI/awesome-prompt-injection?tab=readme-ov-file#tools)). This will give you leads on how hard you will have to work on the next topics.
* Updates: Keep your model and libraries updated with the latest security patches.
* Input Sanitization: Before feeding data to the model, sanitize inputs rigorously. This involves techniques such as:
* Validation: Enforce strict rules on allowed characters and data types.
* Filtering: Remove potentially malicious scripts or code fragments.
* Encoding: Convert special characters into safe representations.
* Verification: Run tooling that identifies potential script injections (e.g. [models that detect prompt injection attempts](https://python.langchain.com/docs/guides/safety/hugging_face_prompt_injection)).
### Data privacy
To protect sensitive data from potential leaks or unauthorized access, it is essential to sandbox the model execution. This means running the model in a secure, isolated environment, which helps mitigate many attack vectors.
When training the model with sensitive data, expose your newly-trained model to tests to identify potential sensitive data leaks.
### Untrusted environments or networks
If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
* Confirm that the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value; see the sketch after this list.
* Encrypt your data while sending it over the network.
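A minimal sketch of that first check, comparing a downloaded weights file
against a SHA-256 digest pinned out-of-band (the file path and digest below
are placeholders):

```python
import hashlib


def sha256_of(path, chunk_size=1 << 20):
    """Computes the SHA-256 digest of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


KNOWN_GOOD_SHA256 = "..."  # placeholder: pin the digest you trust

if sha256_of("model.weights.h5") != KNOWN_GOOD_SHA256:
    raise RuntimeError("Downloaded artifact failed the integrity check.")
```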
### Multi-Tenant environments
If you intend to run multiple models in parallel with shared memory, it is your responsibility to ensure the models do not interact or access each other's data. The primary areas of concern are tenant isolation, resource allocation, model sharing and hardware attacks.
#### Tenant Isolation
You must make sure that models run separately. Since models can run code, it's important to use strong isolation methods to prevent unwanted access to the data from other tenants.
Separating networks is also a big part of isolation. If you keep model network traffic separate, you not only prevent unauthorized access to data or models, but also prevent malicious users or tenants sending graphs to execute under another tenant’s identity.
#### Resource Allocation
A denial of service caused by one model can impact the overall system health. Implement safeguards like rate limits, access controls, and health monitoring.
#### Model Sharing
In a multitenant design that allows sharing models, make sure that tenants and users fully understand the potential security risks involved. They must be aware that they will essentially be running code provided by other users. Unfortunately, there are no reliable methods available to detect malicious models, graphs, or checkpoints. To mitigate this risk, the recommended approach is to sandbox the model execution, effectively isolating it from the rest of the system.
#### Hardware Attacks
Besides the virtual environment, the hardware (GPUs or TPUs) can also be attacked. [Research](https://scholar.google.com/scholar?q=gpu+side+channel) has shown that side channel attacks on GPUs are possible, which can make data leak from other models or processes running on the same system at the same time.
## Reporting a Vulnerability
Beware that none of the topics under [Using Keras Securely](#using-keras-securely) are considered vulnerabilities of Keras.
If you have discovered a security vulnerability in this project, please report it
privately. **Do not disclose it as a public issue.** This gives us time to work with you
to fix the issue before public exposure, reducing the chance that the exploit will be
used before a patch is released.
You may submit the report in the following ways:
- send an email to francois.chollet@gmail.com and/or
- send a [private vulnerability report](https://github.com/keras-team/keras/security/advisories/new)
Please provide the following information in your report:
- A description of the vulnerability and its impact
- How to reproduce the issue
This project is maintained by volunteers on a reasonable-effort basis. As such,
please give us 90 days to work on a fix before public exposure.
================================================
FILE: api_gen.py
================================================
"""Script to generate keras public API in `keras/api` directory.
Usage:
Run via `./shell/api_gen.sh`.
It generates the API files and formats both the user-written and generated APIs.
"""
import os
import re
import shutil
import namex
PACKAGE = "keras"
BUILD_DIR_NAME = "tmp_build_dir"
def ignore_files(_, filenames):
return [f for f in filenames if f.endswith("_test.py")]
def copy_source_to_build_directory(root_path):
# Copy sources (`keras/` directory and setup files) to build dir
build_dir = os.path.join(root_path, BUILD_DIR_NAME)
build_package_dir = os.path.join(build_dir, PACKAGE)
build_src_dir = os.path.join(build_package_dir, "src")
root_src_dir = os.path.join(root_path, PACKAGE, "src")
if os.path.exists(build_dir):
shutil.rmtree(build_dir)
os.makedirs(build_package_dir)
shutil.copytree(root_src_dir, build_src_dir)
return build_dir
def create_legacy_directory(package_dir):
src_dir = os.path.join(package_dir, "src")
# Make keras/_tf_keras/ by copying keras/
tf_keras_dirpath_parent = os.path.join(package_dir, "_tf_keras")
tf_keras_dirpath = os.path.join(tf_keras_dirpath_parent, "keras")
os.makedirs(tf_keras_dirpath, exist_ok=True)
with open(os.path.join(tf_keras_dirpath_parent, "__init__.py"), "w") as f:
f.write("from keras._tf_keras import keras\n")
with open(os.path.join(package_dir, "__init__.py")) as f:
init_file = f.read()
init_file = init_file.replace(
"from keras import _legacy as _legacy",
"from keras import _tf_keras as _tf_keras",
)
with open(os.path.join(package_dir, "__init__.py"), "w") as f:
f.write(init_file)
# Remove the import of `_tf_keras` in `keras/_tf_keras/keras/__init__.py`
init_file = init_file.replace("from keras import _tf_keras\n", "\n")
with open(os.path.join(tf_keras_dirpath, "__init__.py"), "w") as f:
f.write(init_file)
for dirname in os.listdir(package_dir):
dirpath = os.path.join(package_dir, dirname)
if os.path.isdir(dirpath) and dirname not in (
"_legacy",
"_tf_keras",
"src",
):
destpath = os.path.join(tf_keras_dirpath, dirname)
if os.path.exists(destpath):
shutil.rmtree(destpath)
shutil.copytree(
dirpath,
destpath,
ignore=ignore_files,
)
# Copy keras/_legacy/ file contents to keras/_tf_keras/keras
legacy_submodules = [
path[:-3]
for path in os.listdir(os.path.join(src_dir, "legacy"))
if path.endswith(".py")
]
legacy_submodules += [
path
for path in os.listdir(os.path.join(src_dir, "legacy"))
if os.path.isdir(os.path.join(src_dir, "legacy", path))
]
for root, _, fnames in os.walk(os.path.join(package_dir, "_legacy")):
for fname in fnames:
if fname.endswith(".py"):
legacy_fpath = os.path.join(root, fname)
tf_keras_root = root.replace(
os.path.join(os.path.sep, "_legacy"),
os.path.join(os.path.sep, "_tf_keras", "keras"),
)
core_api_fpath = os.path.join(
root.replace(os.path.join(os.path.sep, "_legacy"), ""),
fname,
)
if not os.path.exists(tf_keras_root):
os.makedirs(tf_keras_root)
tf_keras_fpath = os.path.join(tf_keras_root, fname)
with open(legacy_fpath) as f:
legacy_contents = f.read()
legacy_contents = legacy_contents.replace(
"keras._legacy", "keras._tf_keras.keras"
)
if os.path.exists(core_api_fpath):
with open(core_api_fpath) as f:
core_api_contents = f.read()
core_api_contents = core_api_contents.replace(
"from keras import _tf_keras as _tf_keras\n", ""
)
for legacy_submodule in legacy_submodules:
core_api_contents = core_api_contents.replace(
f"from keras import {legacy_submodule} as {legacy_submodule}\n", # noqa: E501
"",
)
core_api_contents = core_api_contents.replace(
f"keras.{legacy_submodule}",
f"keras._tf_keras.keras.{legacy_submodule}",
)
                # Remove the module docstring duplicated from the generated
                # file. Newlines are temporarily escaped so the whole
                # docstring can be matched by one single-line regex, then
                # restored afterwards.
legacy_contents = re.sub(r"\n", r"\\n", legacy_contents)
legacy_contents = re.sub('""".*"""', "", legacy_contents)
legacy_contents = re.sub(r"\\n", r"\n", legacy_contents)
# If the same module is in legacy and core_api, use legacy
legacy_imports = re.findall(
r"import (\w+)", legacy_contents
)
for import_name in legacy_imports:
core_api_contents = re.sub(
f"\n.* import {import_name} as {import_name}\n",
r"\n",
core_api_contents,
)
legacy_contents = f"{core_api_contents}\n{legacy_contents}"
with open(tf_keras_fpath, "w") as f:
f.write(legacy_contents)
# Delete keras/api/_legacy/
shutil.rmtree(os.path.join(package_dir, "_legacy"))
def export_version_string(api_init_fname):
with open(api_init_fname) as f:
contents = f.read()
with open(api_init_fname, "w") as f:
contents += "from keras.src.version import __version__ as __version__\n"
f.write(contents)
def build():
root_path = os.path.dirname(os.path.abspath(__file__))
code_api_dir = os.path.join(root_path, PACKAGE, "api")
# Create temp build dir
build_dir = copy_source_to_build_directory(root_path)
build_api_dir = os.path.join(build_dir, PACKAGE)
build_src_dir = os.path.join(build_api_dir, "src")
build_api_init_fname = os.path.join(build_api_dir, "__init__.py")
try:
os.chdir(build_dir)
open(build_api_init_fname, "w").close()
namex.generate_api_files(
"keras",
code_directory="src",
exclude_directories=[
os.path.join("src", "backend", "jax"),
os.path.join("src", "backend", "openvino"),
os.path.join("src", "backend", "tensorflow"),
os.path.join("src", "backend", "torch"),
],
)
# Add __version__ to `api/`.
export_version_string(build_api_init_fname)
# Creates `_tf_keras` with full keras API
create_legacy_directory(package_dir=os.path.join(build_dir, PACKAGE))
# Copy back the keras/api and keras/__init__.py from build directory
if os.path.exists(build_src_dir):
shutil.rmtree(build_src_dir)
if os.path.exists(code_api_dir):
shutil.rmtree(code_api_dir)
shutil.copytree(
build_api_dir, code_api_dir, ignore=shutil.ignore_patterns("src/")
)
finally:
# Clean up: remove the build directory (no longer needed)
shutil.rmtree(build_dir)
if __name__ == "__main__":
build()
================================================
FILE: benchmarks/__init__.py
================================================
================================================
FILE: benchmarks/layer_benchmark/README.md
================================================
# Benchmark the layer performance
This directory contains benchmarks comparing the performance of
`keras.layers.XXX` against `tf.keras.layers.XXX`, for both the forward pass
and the train step (forward & backward pass).
To run a benchmark, use the command below, changing the flags to match your
target:
```shell
python3 -m benchmarks.layer_benchmark.conv_benchmark \
--benchmark_name=benchmark_conv2D \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
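A benchmark can also be driven programmatically. Below is a minimal sketch
(assuming the repository root is on `PYTHONPATH` so that `benchmarks` is
importable; the layer name and arguments are illustrative):
```python
# Compare Keras 3 vs. tf.keras throughput for a single layer.
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark

benchmark = LayerBenchmark(
    layer_name="Dense",        # class name looked up on `keras.layers`
    init_args={"units": 256},  # kwargs for the layer constructor
    input_shape=[256, 256],    # per-sample shape, without the batch dim
    jit_compile=True,
)
benchmark.benchmark_predict(num_samples=2048, batch_size=256)
benchmark.benchmark_train(num_samples=2048, batch_size=256)
```
Each call prints the measured throughput (samples/sec) for both Keras 3 and
`tf.keras`.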
================================================
FILE: benchmarks/layer_benchmark/__init__.py
================================================
================================================
FILE: benchmarks/layer_benchmark/activation_benchmark.py
================================================
"""Benchmark activation layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.activation_benchmark \
--benchmark_name=benchmark_elu \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_elu(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "ELU"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_prelu(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "PReLU"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_relu(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "ReLU"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_leaky_relu(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "LeakyReLU"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_softmax(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Softmax"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_elu": benchmark_elu,
"benchmark_relu": benchmark_relu,
"benchmark_leaky_relu": benchmark_leaky_relu,
"benchmark_prelu": benchmark_prelu,
"benchmark_softmax": benchmark_softmax,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/attention_benchmark.py
================================================
"""Benchmark attention layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.attention_benchmark \
--benchmark_name=benchmark_attention \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_attention(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Attention"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 64], [256, 64]],
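        # `Attention` expects its inputs packed into a single list, hence
        # `flat_call_inputs=False` below.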
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_multi_head_attention(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "MultiHeadAttention"
init_args = {
"num_heads": 4,
"key_dim": 16,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 64], [256, 64], [256, 64]],
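        # `MultiHeadAttention` takes query, value and key as separate
        # arguments, hence `flat_call_inputs=True` below.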
flat_call_inputs=True,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_additive_attention(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "AdditiveAttention"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 64], [256, 64], [256, 64]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_attention": benchmark_attention,
"benchmark_multi_head_attention": benchmark_multi_head_attention,
"benchmark_additive_attention": benchmark_additive_attention,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/base_benchmark.py
================================================
import time
import numpy as np
import tensorflow as tf
from absl import flags
import keras
FLAGS = flags.FLAGS
flags.DEFINE_string(
"benchmark_name",
None,
"The name of benchmark to run. If None, all benchmarks in the file will be "
"run.",
)
flags.DEFINE_integer(
"num_samples",
1000,
"Number of input data samples.",
)
flags.DEFINE_integer(
"batch_size",
20,
"Batch size of data.",
)
flags.DEFINE_bool(
"jit_compile",
True,
"If True, the benchmark will run with XLA compilation.",
)
class BenchmarkMetricsCallback:
def __init__(self, start_batch=1, stop_batch=None):
self.start_batch = start_batch
self.stop_batch = stop_batch
self.state = {}
def on_train_batch_begin(self, batch, logs=None):
if batch == self.start_batch:
self.state["benchmark_begin"] = time.time()
def on_train_batch_end(self, batch, logs=None):
if batch == self.stop_batch:
self.state["benchmark_end"] = time.time()
throughput = (self.stop_batch - self.start_batch + 1) / (
self.state["benchmark_end"] - self.state["benchmark_begin"]
)
self.state["throughput"] = throughput
def on_predict_batch_begin(self, batch, logs=None):
if batch == self.start_batch:
self.state["benchmark_begin"] = time.time()
def on_predict_batch_end(self, batch, logs=None):
if batch == self.stop_batch:
self.state["benchmark_end"] = time.time()
throughput = (self.stop_batch - self.start_batch + 1) / (
self.state["benchmark_end"] - self.state["benchmark_begin"]
)
self.state["throughput"] = throughput
class KerasCoreBenchmarkMetricsCallback(keras.callbacks.Callback):
def __init__(self, start_batch=1, stop_batch=None):
self._callback = BenchmarkMetricsCallback(start_batch, stop_batch)
def on_train_batch_begin(self, batch, logs=None):
self._callback.on_train_batch_begin(batch, logs)
def on_train_batch_end(self, batch, logs=None):
self._callback.on_train_batch_end(batch, logs)
def on_predict_batch_begin(self, batch, logs=None):
self._callback.on_predict_batch_begin(batch, logs)
def on_predict_batch_end(self, batch, logs=None):
self._callback.on_predict_batch_end(batch, logs)
class TFKerasBenchmarkMetricsCallback(tf.keras.callbacks.Callback):
def __init__(self, start_batch=1, stop_batch=None):
self._callback = BenchmarkMetricsCallback(start_batch, stop_batch)
def on_train_batch_begin(self, batch, logs=None):
self._callback.on_train_batch_begin(batch, logs)
def on_train_batch_end(self, batch, logs=None):
self._callback.on_train_batch_end(batch, logs)
def on_predict_batch_begin(self, batch, logs=None):
self._callback.on_predict_batch_begin(batch, logs)
def on_predict_batch_end(self, batch, logs=None):
self._callback.on_predict_batch_end(batch, logs)
class LayerBenchmark:
def __init__(
self,
layer_name,
init_args,
input_shape,
flat_call_inputs=True,
jit_compile=True,
keras_layer=None,
tf_keras_layer=None,
):
self.layer_name = layer_name
_keras_layer_class = getattr(keras.layers, layer_name)
_tf_keras_layer_class = getattr(tf.keras.layers, layer_name)
if keras_layer is None:
            # Sometimes the keras layer and the tf_keras layer need to be
            # initialized in different ways. For example, the
            # `Bidirectional` wrapper takes a `keras.layers.Layer` and a
            # `tf.keras.layers.Layer` instance respectively.
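            # A hypothetical sketch of such a case (see the real usage in
            # rnn_benchmark.py):
            #   LayerBenchmark(
            #       "Bidirectional",
            #       {},
            #       input_shape=[256, 256],
            #       keras_layer=keras.layers.Bidirectional(
            #           keras.layers.LSTM(32)),
            #       tf_keras_layer=tf.keras.layers.Bidirectional(
            #           tf.keras.layers.LSTM(32)),
            #   )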
self._keras_layer = _keras_layer_class(**init_args)
else:
self._keras_layer = keras_layer
if tf_keras_layer is None:
self._tf_keras_layer = _tf_keras_layer_class(**init_args)
else:
self._tf_keras_layer = tf_keras_layer
self.input_shape = input_shape
self._keras_model = self._build_keras_model(
input_shape, flat_call_inputs
)
self._tf_keras_model = self._build_tf_keras_model(
input_shape, flat_call_inputs
)
self._keras_model.compile(
loss="mse", optimizer="sgd", jit_compile=jit_compile
)
self._tf_keras_model.compile(
loss="mse", optimizer="sgd", jit_compile=jit_compile
)
        self.flat_call_inputs = flat_call_inputs
        self.jit_compile = jit_compile
def _build_keras_model(self, input_shape, flat_call_inputs=True):
inputs = []
if not isinstance(input_shape[0], (tuple, list)):
input_shape = [input_shape]
for shape in input_shape:
inputs.append(keras.Input(shape=shape))
if flat_call_inputs:
outputs = self._keras_layer(*inputs)
else:
outputs = self._keras_layer(inputs)
return keras.Model(inputs=inputs, outputs=outputs)
def _build_tf_keras_model(self, input_shape, flat_call_inputs=True):
inputs = []
if not isinstance(input_shape[0], (tuple, list)):
input_shape = [input_shape]
for shape in input_shape:
inputs.append(tf.keras.Input(shape=shape))
if flat_call_inputs:
outputs = self._tf_keras_layer(*inputs)
else:
outputs = self._tf_keras_layer(inputs)
return tf.keras.Model(inputs=inputs, outputs=outputs)
def benchmark_predict(self, num_samples, batch_size, data=None):
if data is None:
# Generate default data if not provided.
if isinstance(self.input_shape[0], (tuple, list)):
# The layer has multiple inputs.
data = []
for data_shape in self.input_shape:
data_shape = [num_samples] + list(data_shape)
data.append(np.random.normal(size=data_shape))
else:
data_shape = [num_samples] + list(self.input_shape)
data = np.random.normal(size=data_shape)
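        # Batches are 0-indexed: the callbacks' default `start_batch=1`
        # skips batch 0 (which absorbs tracing/compilation overhead), and
        # `stop_batch` is set to the index of the last full batch.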
num_iterations = num_samples // batch_size - 1
callback = KerasCoreBenchmarkMetricsCallback(stop_batch=num_iterations)
tf_keras_callback = TFKerasBenchmarkMetricsCallback(
stop_batch=num_iterations
)
self._keras_model.predict(
data,
batch_size=batch_size,
callbacks=[callback],
)
self._tf_keras_model.predict(
data,
batch_size=batch_size,
callbacks=[tf_keras_callback],
)
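        # The callback records throughput in batches/sec; multiply by the
        # batch size to report samples/sec.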
keras_throughput = callback._callback.state["throughput"] * batch_size
tf_keras_throughput = (
tf_keras_callback._callback.state["throughput"] * batch_size
)
print(
f"Keras 3 throughput of forward pass of {self.layer_name}: "
f"{keras_throughput:.2f} samples/sec."
)
print(
f"TF Keras throughput of forward pass of {self.layer_name}: "
f"{tf_keras_throughput:.2f} samples/sec."
)
def benchmark_train(self, num_samples, batch_size, data=None, label=None):
if data is None:
# Generate default data if not provided.
if isinstance(self.input_shape[0], (tuple, list)):
# The layer has multiple inputs.
data = []
for data_shape in self.input_shape:
data_shape = [num_samples] + list(data_shape)
data.append(np.random.normal(size=data_shape))
else:
data_shape = [num_samples] + list(self.input_shape)
data = [np.random.normal(size=data_shape)]
if label is None:
# Generate default label if not provided.
if self.flat_call_inputs:
# Scale by a small factor to avoid zero gradients.
label = (
keras.backend.convert_to_numpy(self._keras_layer(*data))
* 1.001
)
else:
label = (
keras.backend.convert_to_numpy(self._keras_layer(data))
* 1.001
)
num_iterations = num_samples // batch_size - 1
callback = KerasCoreBenchmarkMetricsCallback(stop_batch=num_iterations)
tf_keras_callback = TFKerasBenchmarkMetricsCallback(
stop_batch=num_iterations
)
self._keras_model.fit(
data,
label,
batch_size=batch_size,
callbacks=[callback],
)
self._tf_keras_model.fit(
data,
label,
batch_size=batch_size,
callbacks=[tf_keras_callback],
)
keras_throughput = callback._callback.state["throughput"] * batch_size
tf_keras_throughput = (
tf_keras_callback._callback.state["throughput"] * batch_size
)
print(
f"Keras 3 throughput of forward & backward pass of "
f"{self.layer_name}: {keras_throughput:.2f} samples/sec."
)
print(
f"TF Keras throughput of forward & backward pass of "
f"{self.layer_name}: {tf_keras_throughput:.2f} samples/sec."
)
================================================
FILE: benchmarks/layer_benchmark/conv_benchmark.py
================================================
"""Benchmark conv layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.conv_benchmark \
--benchmark_name=benchmark_conv2D \
    --num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_conv1D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Conv1D"
init_args = {
"filters": 64,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[1024, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_conv2D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Conv2D"
init_args = {
"filters": 16,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[128, 128, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_conv3D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Conv3D"
init_args = {
"filters": 16,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 32, 32, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_depthwise_conv1D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "DepthwiseConv1D"
init_args = {
"kernel_size": 16,
"depth_multiplier": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 64],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_depthwise_conv2D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "DepthwiseConv2D"
init_args = {
"kernel_size": 16,
"depth_multiplier": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[128, 128, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_separable_conv1D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "SeparableConv1D"
init_args = {
"kernel_size": 16,
"depth_multiplier": 2,
"filters": 3,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 64],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_separable_conv2D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "SeparableConv2D"
init_args = {
"kernel_size": 16,
"depth_multiplier": 2,
"filters": 3,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[128, 128, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_conv1D_transpose(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Conv1DTranspose"
init_args = {
"filters": 32,
"kernel_size": 4,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_conv2D_transpose(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Conv2DTranspose"
init_args = {
"filters": 16,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[128, 128, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_conv3D_transpose(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Conv3DTranspose"
init_args = {
"filters": 16,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 32, 32, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_conv1D": benchmark_conv1D,
"benchmark_conv2D": benchmark_conv2D,
"benchmark_conv3D": benchmark_conv3D,
"benchmark_depthwise_conv1D": benchmark_depthwise_conv1D,
"benchmark_depthwise_conv2D": benchmark_depthwise_conv2D,
"benchmark_separable_conv1D": benchmark_separable_conv1D,
"benchmark_separable_conv2D": benchmark_separable_conv2D,
"benchmark_conv1D_transpose": benchmark_conv1D_transpose,
"benchmark_conv2D_transpose": benchmark_conv2D_transpose,
"benchmark_conv3D_transpose": benchmark_conv3D_transpose,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
        for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/core_benchmark.py
================================================
"""Benchmark core layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.core_benchmark \
--benchmark_name=benchmark_dense \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
import numpy as np
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_dense(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Dense"
init_args = {"units": 256}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_einsum_dense(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "EinsumDense"
init_args = {
"equation": "abc,cd->abd",
"output_shape": (None, 256),
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_embedding(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Embedding"
init_args = {
"input_dim": 128,
"output_dim": 256,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[
256,
],
jit_compile=jit_compile,
)
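    # `Embedding` expects integer token indices, so override the default
    # random-normal input data with random integers.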
data = [np.random.randint(30, size=(num_samples, 256))]
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
data=data,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
data=data,
)
BENCHMARK_NAMES = {
"benchmark_dense": benchmark_dense,
"benchmark_einsum_dense": benchmark_einsum_dense,
"benchmark_embedding": benchmark_embedding,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/merge_benchmark.py
================================================
"""Benchmark merge layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.merge_benchmark \
--benchmark_name=benchmark_add \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_add(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Add"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 256], [256, 256]],
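        # Merge layers take a list of inputs, hence
        # `flat_call_inputs=False` below.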
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_average(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Average"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 256], [256, 256]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_concatenate(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Concatenate"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 256], [256, 256]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_dot(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Dot"
init_args = {"axes": [2, 1]}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 32], [32, 64]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_maximum(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Maximum"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 256], [256, 256]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_minimum(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Minimum"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 256], [256, 256]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_multiply(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Multiply"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 64], [256, 64]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_subtract(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Subtract"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[[256, 256], [256, 256]],
flat_call_inputs=False,
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_add": benchmark_add,
"benchmark_average": benchmark_average,
"benchmark_concatenate": benchmark_concatenate,
"benchmark_dot": benchmark_dot,
"benchmark_maximum": benchmark_maximum,
"benchmark_minimum": benchmark_minimum,
"benchmark_multiply": benchmark_multiply,
"benchmark_subtract": benchmark_subtract,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/normalization_benchmark.py
================================================
"""Benchmark normalization layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.normalization_benchmark \
--benchmark_name=benchmark_batch_normalization \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_batch_normalization(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "BatchNormalization"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_group_normalization(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GroupNormalization"
init_args = {
"groups": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_layer_normalization(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "LayerNormalization"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 128, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_unit_normalization(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "UnitNormalization"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 128, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_batch_normalization": benchmark_batch_normalization,
"benchmark_group_normalization": benchmark_group_normalization,
"benchmark_layer_normalization": benchmark_layer_normalization,
"benchmark_unit_normalization": benchmark_unit_normalization,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/pooling_benchmark.py
================================================
"""Benchmark pooling layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.pooling_benchmark \
--benchmark_name=benchmark_max_pooling1d \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_average_pooling1d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "AveragePooling1D"
init_args = {
"pool_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[1024, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_average_pooling2d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "AveragePooling2D"
init_args = {
"pool_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_average_pooling3d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "AveragePooling3D"
init_args = {
"pool_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 64, 32, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_max_pooling1d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "MaxPooling1D"
init_args = {
"pool_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[1024, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_max_pooling2d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "MaxPooling2D"
init_args = {
"pool_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_max_pooling3d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "MaxPooling3D"
init_args = {
"pool_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 64, 32, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_global_average_pooling1d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GlobalAveragePooling1D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[1024, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_global_average_pooling2d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GlobalAveragePooling2D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_global_average_pooling3d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GlobalAveragePooling3D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 64, 32, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_global_max_pooling1d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GlobalMaxPooling1D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[1024, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_global_max_pooling2d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GlobalMaxPooling2D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_global_max_pooling3d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GlobalMaxPooling3D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[64, 64, 32, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_average_pooling1d": benchmark_average_pooling1d,
"benchmark_average_pooling2d": benchmark_average_pooling2d,
"benchmark_average_pooling3d": benchmark_average_pooling3d,
"benchmark_max_pooling1d": benchmark_max_pooling1d,
"benchmark_max_pooling2d": benchmark_max_pooling2d,
"benchmark_max_pooling3d": benchmark_max_pooling3d,
"benchmark_global_average_pooling1d": benchmark_global_average_pooling1d,
"benchmark_global_average_pooling2d": benchmark_global_average_pooling2d,
"benchmark_global_average_pooling3d": benchmark_global_average_pooling3d,
"benchmark_global_max_pooling1d": benchmark_global_max_pooling1d,
"benchmark_global_max_pooling2d": benchmark_global_max_pooling2d,
"benchmark_global_max_pooling3d": benchmark_global_max_pooling3d,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/random_rotation_benchmark.py
================================================
"""Benchmark RandomRotation layer."""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_random_rotation(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "RandomRotation"
init_args = {"factor": 0.1}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[224, 224, 3],
jit_compile=jit_compile,
)
# Predict is effectively a no-op for preprocessing layers,
# but we still call it to follow the standard benchmark structure.
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_random_rotation": benchmark_random_rotation,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for benchmark_fn in BENCHMARK_NAMES.values():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, "
f"`benchmark_name` must be one of {BENCHMARK_NAMES.keys()}"
)
BENCHMARK_NAMES[benchmark_name](num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/regularization_benchmark.py
================================================
"""Benchmark regularization layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.regularization_benchmark \
    --benchmark_name=benchmark_dropout \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_dropout(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Dropout"
init_args = {
"rate": 0.5,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_gaussian_dropout(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GaussianDropout"
init_args = {
"rate": 0.5,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_gaussian_noise(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GaussianNoise"
init_args = {
"stddev": 0.5,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 4],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_spatial_dropout1D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "SpatialDropout1D"
init_args = {
"rate": 0.5,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_spatial_dropout2D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "SpatialDropout2D"
init_args = {
"rate": 0.5,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_spatial_dropout3D(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "SpatialDropout3D"
init_args = {
"rate": 0.5,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 32, 32, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_dropout": benchmark_dropout,
"benchmark_gaussian_dropout": benchmark_gaussian_dropout,
"benchmark_gaussian_noise": benchmark_gaussian_noise,
"benchmark_spatial_dropout1D": benchmark_spatial_dropout1D,
"benchmark_spatial_dropout2D": benchmark_spatial_dropout2D,
"benchmark_spatial_dropout3D": benchmark_spatial_dropout3D,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/reshaping_benchmark.py
================================================
"""Benchmark reshaping layers.
To run benchmarks, use the following command as an example, changing the
flags to your desired values:
```
python3 -m benchmarks.layer_benchmark.reshaping_benchmark \
--benchmark_name=benchmark_cropping2d \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
from absl import app
from absl import flags
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_cropping1d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Cropping1D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[1024, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_cropping2d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Cropping2D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_cropping3d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Cropping3D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 32, 32, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_flatten(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Flatten"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_permute(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Permute"
init_args = {
"dims": (2, 1),
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_up_sampling1d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "UpSampling1D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_up_sampling2d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "UpSampling2D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[128, 128, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_up_sampling3d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "UpSampling3D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 16, 16, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_zero_padding1d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "ZeroPadding1D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_zero_padding2d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "ZeroPadding2D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_zero_padding3d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "ZeroPadding3D"
init_args = {}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 32, 32, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_cropping1d": benchmark_cropping1d,
"benchmark_cropping2d": benchmark_cropping2d,
"benchmark_cropping3d": benchmark_cropping3d,
"benchmark_flatten": benchmark_flatten,
"benchmark_permute": benchmark_permute,
"benchmark_up_sampling1d": benchmark_up_sampling1d,
"benchmark_up_sampling2d": benchmark_up_sampling2d,
"benchmark_up_sampling3d": benchmark_up_sampling3d,
"benchmark_zero_padding1d": benchmark_zero_padding1d,
"benchmark_zero_padding2d": benchmark_zero_padding2d,
"benchmark_zero_padding3d": benchmark_zero_padding3d,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/layer_benchmark/rnn_benchmark.py
================================================
"""Benchmark rnn layers.
To run benchmarks, see the following command for an example, please change the
flag to your custom value:
```
python3 -m benchmarks.layer_benchmark.rnn_benchmark \
--benchmark_name=benchmark_lstm \
--num_samples=2048 \
--batch_size=256 \
--jit_compile=True
```
"""
import tensorflow as tf
from absl import app
from absl import flags
import keras
from benchmarks.layer_benchmark.base_benchmark import LayerBenchmark
FLAGS = flags.FLAGS
def benchmark_conv_lstm1d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "ConvLSTM1D"
init_args = {
"filters": 16,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 256, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_conv_lstm2d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "ConvLSTM2D"
init_args = {
"filters": 16,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[32, 32, 32, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_conv_lstm3d(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "ConvLSTM3D"
init_args = {
"filters": 8,
"kernel_size": 2,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[8, 16, 16, 16, 3],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_gru(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "GRU"
init_args = {
"units": 32,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_lstm(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "LSTM"
init_args = {
"units": 32,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_simple_rnn(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "SimpleRNN"
init_args = {
"units": 32,
}
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_bidirectional(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "Bidirectional"
init_args = {}
keras_layer = keras.layers.Bidirectional(keras.layers.LSTM(32))
tf_keras_layer = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32))
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[256, 256],
jit_compile=jit_compile,
keras_layer=keras_layer,
tf_keras_layer=tf_keras_layer,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
def benchmark_time_distributed(
num_samples,
batch_size,
jit_compile=True,
):
layer_name = "TimeDistributed"
init_args = {}
keras_layer = keras.layers.TimeDistributed(keras.layers.Conv2D(16, (3, 3)))
tf_keras_layer = tf.keras.layers.TimeDistributed(
tf.keras.layers.Conv2D(16, (3, 3))
)
benchmark = LayerBenchmark(
layer_name,
init_args,
input_shape=[10, 32, 32, 3],
jit_compile=jit_compile,
keras_layer=keras_layer,
tf_keras_layer=tf_keras_layer,
)
benchmark.benchmark_predict(
num_samples=num_samples,
batch_size=batch_size,
)
benchmark.benchmark_train(
num_samples=num_samples,
batch_size=batch_size,
)
BENCHMARK_NAMES = {
"benchmark_conv_lstm1d": benchmark_conv_lstm1d,
"benchmark_conv_lstm2d": benchmark_conv_lstm2d,
"benchmark_conv_lstm3d": benchmark_conv_lstm3d,
"benchmark_gru": benchmark_gru,
"benchmark_lstm": benchmark_lstm,
"benchmark_simple_rnn": benchmark_simple_rnn,
"benchmark_bidirectional": benchmark_bidirectional,
"benchmark_time_distributed": benchmark_time_distributed,
}
def main(_):
benchmark_name = FLAGS.benchmark_name
num_samples = FLAGS.num_samples
batch_size = FLAGS.batch_size
jit_compile = FLAGS.jit_compile
if benchmark_name is None:
for name, benchmark_fn in BENCHMARK_NAMES.items():
benchmark_fn(num_samples, batch_size, jit_compile)
return
if benchmark_name not in BENCHMARK_NAMES:
raise ValueError(
f"Invalid benchmark name: {benchmark_name}, `benchmark_name` must "
f"be one of {BENCHMARK_NAMES.keys()}"
)
benchmark_fn = BENCHMARK_NAMES[benchmark_name]
benchmark_fn(num_samples, batch_size, jit_compile)
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/model_benchmark/__init__.py
================================================
================================================
FILE: benchmarks/model_benchmark/benchmark_utils.py
================================================
import time
import keras
class BenchmarkMetricsCallback(keras.callbacks.Callback):
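    """Measure training throughput between `start_batch` and `stop_batch`.
    Timing starts at `start_batch` rather than batch 0 so that one-time
    warm-up costs are excluded; one throughput value (in batches per
    second) is recorded per epoch.
    """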
def __init__(self, start_batch=1, stop_batch=None):
self.start_batch = start_batch
self.stop_batch = stop_batch
# Store the throughput of each epoch.
self.state = {"throughput": []}
def on_train_batch_begin(self, batch, logs=None):
if batch == self.start_batch:
self.state["epoch_begin_time"] = time.time()
def on_train_batch_end(self, batch, logs=None):
if batch == self.stop_batch:
epoch_end_time = time.time()
throughput = (self.stop_batch - self.start_batch + 1) / (
epoch_end_time - self.state["epoch_begin_time"]
)
self.state["throughput"].append(throughput)
================================================
FILE: benchmarks/model_benchmark/bert_benchmark.py
================================================
"""Benchmark BERT model on GLUE/MRPC task.
To run the script, make sure you are in the benchmarks/ directory, and run the
command below:
```
python3 -m model_benchmark.bert_benchmark \
--epochs 2 \
--batch_size 32
```
"""
import time
import keras_nlp
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import app
from absl import flags
from absl import logging
from model_benchmark.benchmark_utils import BenchmarkMetricsCallback
import keras
flags.DEFINE_string("model_size", "small", "The size of model to benchmark.")
flags.DEFINE_string(
"mixed_precision_policy",
"mixed_float16",
"The global precision policy to use, e.g., 'mixed_float16' or 'float32'.",
)
flags.DEFINE_integer("epochs", 2, "The number of epochs.")
flags.DEFINE_integer("batch_size", 8, "Batch Size.")
FLAGS = flags.FLAGS
MODEL_SIZE_MAP = {
"tiny": "bert_tiny_en_uncased",
"small": "bert_small_en_uncased",
"base": "bert_base_en_uncased",
"large": "bert_large_en_uncased",
}
def load_data():
"""Load data.
Load GLUE/MRPC dataset, and convert the dictionary format to
(features, label), where `features` is a tuple of all input sentences.
"""
feature_names = ("sentence1", "sentence2")
def split_features(x):
# GLUE comes with dictionary data, we convert it to a uniform format
# (features, label), where features is a tuple consisting of all
# features. This format is necessary for using KerasNLP preprocessors.
features = tuple([x[name] for name in feature_names])
label = x["label"]
return (features, label)
train_ds, test_ds, validation_ds = tfds.load(
"glue/mrpc",
split=["train", "test", "validation"],
)
train_ds = (
train_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE)
.batch(FLAGS.batch_size)
.prefetch(tf.data.AUTOTUNE)
)
test_ds = (
test_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE)
.batch(FLAGS.batch_size)
.prefetch(tf.data.AUTOTUNE)
)
validation_ds = (
validation_ds.map(split_features, num_parallel_calls=tf.data.AUTOTUNE)
.batch(FLAGS.batch_size)
.prefetch(tf.data.AUTOTUNE)
)
return train_ds, test_ds, validation_ds
def load_model():
if FLAGS.model_size not in MODEL_SIZE_MAP.keys():
raise KeyError(
f"`model_size` must be one of {MODEL_SIZE_MAP.keys()}, but "
f"received {FLAGS.model_size}."
)
return keras_nlp.models.BertClassifier.from_preset(
MODEL_SIZE_MAP[FLAGS.model_size], num_classes=2
)
def main(_):
keras.mixed_precision.set_dtype_policy(FLAGS.mixed_precision_policy)
logging.info(
"Benchmarking configs...\n"
"=========================\n"
f"MODEL: BERT {FLAGS.model_size}\n"
f"TASK: glue/mrpc \n"
f"BATCH_SIZE: {FLAGS.batch_size}\n"
f"EPOCHS: {FLAGS.epochs}\n"
"=========================\n"
)
# Load datasets.
train_ds, test_ds, validation_ds = load_data()
# Load the model.
model = load_model()
# Set loss and metrics.
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = [keras.metrics.SparseCategoricalAccuracy()]
# Configure optimizer.
lr = keras.optimizers.schedules.PolynomialDecay(
5e-4,
decay_steps=train_ds.cardinality() * FLAGS.epochs,
end_learning_rate=0.0,
)
optimizer = keras.optimizers.AdamW(lr, weight_decay=0.01)
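    # Do not apply weight decay to LayerNorm or bias variables, following
    # common BERT fine-tuning practice.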
optimizer.exclude_from_weight_decay(
var_names=["LayerNorm", "layer_norm", "bias"]
)
model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
benchmark_metrics_callback = BenchmarkMetricsCallback(
start_batch=1,
stop_batch=train_ds.cardinality().numpy() - 1,
)
# Start training.
logging.info("Starting Training...")
st = time.time()
history = model.fit(
train_ds,
validation_data=validation_ds,
epochs=FLAGS.epochs,
callbacks=[benchmark_metrics_callback],
)
wall_time = time.time() - st
validation_accuracy = history.history["val_sparse_categorical_accuracy"][-1]
examples_per_second = (
np.mean(np.array(benchmark_metrics_callback.state["throughput"]))
* FLAGS.batch_size
)
logging.info("Training Finished!")
logging.info(f"Wall Time: {wall_time:.4f} seconds.")
logging.info(f"Validation Accuracy: {validation_accuracy:.4f}")
logging.info(f"examples_per_second: {examples_per_second:.4f}")
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/model_benchmark/image_classification_benchmark.py
================================================
"""Image classification benchmark.
This script runs an image classification benchmark with the "cats vs dogs"
dataset. It supports the following 3 models:
- EfficientNetV2B0
- Xception
- ResNet50V2
To run the benchmark, make sure you are in the benchmarks/ directory, and run
the command below:
python3 -m model_benchmark.image_classification_benchmark \
--model="EfficientNetV2B0" \
--epochs=2 \
--batch_size=32 \
--mixed_precision_policy="mixed_float16"
"""
import time
import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds
from absl import app
from absl import flags
from absl import logging
from model_benchmark.benchmark_utils import BenchmarkMetricsCallback
import keras
flags.DEFINE_string("model", "EfficientNetV2B0", "The model to benchmark.")
flags.DEFINE_integer("epochs", 1, "The number of epochs.")
flags.DEFINE_integer("batch_size", 4, "Batch Size.")
flags.DEFINE_string(
"mixed_precision_policy",
"mixed_float16",
"The global precision policy to use, e.g., 'mixed_float16' or 'float32'.",
)
FLAGS = flags.FLAGS
BATCH_SIZE = 32
IMAGE_SIZE = (224, 224)
CHANNELS = 3
MODEL_MAP = {
"EfficientNetV2B0": keras.applications.EfficientNetV2B0,
"Xception": keras.applications.Xception,
"ResNet50V2": keras.applications.ResNet50V2,
}
def load_data():
# Load cats vs dogs dataset, and split into train and validation sets.
train_dataset, val_dataset = tfds.load(
"cats_vs_dogs", split=["train[:90%]", "train[90%:]"], as_supervised=True
)
resizing = keras.layers.Resizing(
IMAGE_SIZE[0], IMAGE_SIZE[1], crop_to_aspect_ratio=True
)
def preprocess_inputs(image, label):
image = tf.cast(image, "float32")
return resizing(image), label
train_dataset = (
train_dataset.map(
preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE
)
.batch(FLAGS.batch_size)
.prefetch(tf.data.AUTOTUNE)
)
val_dataset = (
val_dataset.map(preprocess_inputs, num_parallel_calls=tf.data.AUTOTUNE)
.batch(FLAGS.batch_size)
.cache()
.prefetch(tf.data.AUTOTUNE)
)
return train_dataset, val_dataset
def load_model():
model_class = MODEL_MAP[FLAGS.model]
    # Load the selected model and add a classification head.
model = model_class(include_top=False, weights="imagenet")
classifier = keras.models.Sequential(
[
keras.Input([IMAGE_SIZE[0], IMAGE_SIZE[1], CHANNELS]),
model,
keras.layers.GlobalAveragePooling2D(),
keras.layers.Dense(2),
]
)
return classifier
def main(_):
keras.mixed_precision.set_dtype_policy(FLAGS.mixed_precision_policy)
logging.info(
"Benchmarking configs...\n"
"=========================\n"
f"MODEL: {FLAGS.model}\n"
f"TASK: image classification/dogs-vs-cats \n"
f"BATCH_SIZE: {FLAGS.batch_size}\n"
f"EPOCHS: {FLAGS.epochs}\n"
"=========================\n"
)
# Load datasets.
train_ds, validation_ds = load_data()
# Load the model.
classifier = load_model()
lr = keras.optimizers.schedules.PolynomialDecay(
5e-4,
decay_steps=train_ds.cardinality() * FLAGS.epochs,
end_learning_rate=0.0,
)
optimizer = keras.optimizers.Adam(lr)
loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
benchmark_metrics_callback = BenchmarkMetricsCallback(
start_batch=1,
stop_batch=train_ds.cardinality().numpy() - 1,
)
classifier.compile(
optimizer=optimizer,
loss=loss,
metrics=["sparse_categorical_accuracy"],
)
# Start training.
logging.info("Starting Training...")
st = time.time()
history = classifier.fit(
train_ds,
validation_data=validation_ds,
epochs=FLAGS.epochs,
callbacks=[benchmark_metrics_callback],
)
wall_time = time.time() - st
validation_accuracy = history.history["val_sparse_categorical_accuracy"][-1]
examples_per_second = (
np.mean(np.array(benchmark_metrics_callback.state["throughput"]))
* FLAGS.batch_size
)
logging.info("Training Finished!")
logging.info(f"Wall Time: {wall_time:.4f} seconds.")
logging.info(f"Validation Accuracy: {validation_accuracy:.4f}")
logging.info(f"examples_per_second: {examples_per_second:.4f}")
if __name__ == "__main__":
app.run(main)
================================================
FILE: benchmarks/torch_ctl_benchmark/README.md
================================================
# Benchmark the performance of the torch custom training loop
This directory contains benchmarks to compare the performance of a Keras model
and an equivalent Torch model while using the same Torch custom training loop.
The purpose of the benchmark is to understand the performance difference
resulting from the choice of modeling API (Keras or Torch).
To run the benchmark, use the command below, substituting your target module:
```shell
python3 -m benchmarks.torch_ctl_benchmark.conv_model_benchmark
```
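For example, to benchmark the dense model instead of the conv model:
```shell
python3 -m benchmarks.torch_ctl_benchmark.dense_model_benchmark
```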
================================================
FILE: benchmarks/torch_ctl_benchmark/__init__.py
================================================
================================================
FILE: benchmarks/torch_ctl_benchmark/benchmark_utils.py
================================================
import time
import numpy as np
import torch
def train_loop(model, train_loader, num_epochs, optimizer, loss_fn, framework):
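    """Run a vanilla torch training loop and report the average time per batch.
    Timing starts at the second batch (`batch_idx == 1`) so that one-time
    warm-up costs are excluded from the per-batch average.
    """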
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
start = None
average_batch_time_per_epoch = []
for _ in range(num_epochs):
running_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(train_loader):
if batch_idx == 1:
start = time.time()
inputs = inputs.to(device)
targets = targets.to(device)
# Forward pass
outputs = model(inputs)
loss = loss_fn(outputs, targets)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
end = time.time()
average_batch_time_per_epoch.append(
(end - start) / (len(train_loader) - 1)
)
average_time = np.mean(average_batch_time_per_epoch)
print(f"Time per batch in {framework}: {average_time:.2f}")
================================================
FILE: benchmarks/torch_ctl_benchmark/conv_model_benchmark.py
================================================
"""Benchmark Keras performance with torch custom training loop.
In this file we use a convolutional model. The training loop is written in the
vanilla torch way, and we compare the performance of building the model with
Keras versus building it directly with torch.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import keras
from benchmarks.torch_ctl_benchmark.benchmark_utils import train_loop
from keras import layers
num_classes = 2
input_shape = (3, 256, 256)
batch_size = 128
num_batches = 20
num_epochs = 1
x_train = np.random.normal(
size=(num_batches * batch_size, *input_shape)
).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=(num_batches * batch_size,))
# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
torch.from_numpy(x_train), torch.from_numpy(y_train)
)
# Create a DataLoader
train_loader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=False
)
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.conv = torch.nn.Conv2d(3, 32, kernel_size=(3, 3))
self.activation = torch.nn.ReLU()
self.max_pool = torch.nn.MaxPool2d((2, 2))
self.flatten = torch.nn.Flatten()
self.dense = torch.nn.LazyLinear(num_classes)
self.softmax = torch.nn.Softmax(dim=1)
def forward(self, x):
x = self.conv(x)
x = self.activation(x)
x = self.max_pool(x)
x = self.flatten(x)
x = self.dense(x)
x = self.softmax(x)
return x
def run_keras_custom_training_loop():
keras_model = keras.Sequential(
[
layers.Input(shape=input_shape),
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dense(num_classes),
layers.Softmax(),
]
)
optimizer = optim.Adam(keras_model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
train_loop(
keras_model,
train_loader,
num_epochs=num_epochs,
optimizer=optimizer,
loss_fn=loss_fn,
framework="keras",
)
def run_torch_custom_training_loop():
torch_model = TorchModel()
optimizer = optim.Adam(torch_model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
train_loop(
torch_model,
train_loader,
num_epochs=num_epochs,
optimizer=optimizer,
loss_fn=loss_fn,
framework="torch",
)
if __name__ == "__main__":
run_keras_custom_training_loop()
run_torch_custom_training_loop()
================================================
FILE: benchmarks/torch_ctl_benchmark/dense_model_benchmark.py
================================================
"""Benchmark Keras performance with torch custom training loop.
In this file we use a model with 3 dense layers. The training loop is written
in the vanilla torch way, and we compare the performance of building the model
with Keras versus building it directly with torch.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import keras
from benchmarks.torch_ctl_benchmark.benchmark_utils import train_loop
from keras import layers
num_classes = 2
input_shape = (8192,)
batch_size = 4096
num_batches = 20
num_epochs = 1
x_train = np.random.normal(
size=(num_batches * batch_size, *input_shape)
).astype(np.float32)
y_train = np.random.randint(0, num_classes, size=(num_batches * batch_size,))
# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
torch.from_numpy(x_train), torch.from_numpy(y_train)
)
# Create a DataLoader
train_loader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=False
)
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.dense1 = torch.nn.Linear(8192, 64)
self.activation1 = torch.nn.ReLU()
self.dense2 = torch.nn.Linear(64, 8)
self.activation2 = torch.nn.ReLU()
self.dense3 = torch.nn.Linear(8, num_classes)
self.softmax = torch.nn.Softmax(dim=1)
def forward(self, x):
x = self.dense1(x)
x = self.activation1(x)
x = self.dense2(x)
x = self.activation2(x)
x = self.dense3(x)
x = self.softmax(x)
return x
def run_keras_custom_training_loop():
keras_model = keras.Sequential(
[
layers.Input(shape=input_shape),
layers.Dense(64, activation="relu"),
layers.Dense(8, activation="relu"),
layers.Dense(num_classes),
layers.Softmax(),
]
)
optimizer = optim.Adam(keras_model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
train_loop(
keras_model,
train_loader,
num_epochs=num_epochs,
optimizer=optimizer,
loss_fn=loss_fn,
framework="keras",
)
def run_torch_custom_training_loop():
torch_model = TorchModel()
optimizer = optim.Adam(torch_model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
train_loop(
torch_model,
train_loader,
num_epochs=num_epochs,
optimizer=optimizer,
loss_fn=loss_fn,
framework="torch",
)
if __name__ == "__main__":
run_keras_custom_training_loop()
run_torch_custom_training_loop()
================================================
FILE: codecov.yml
================================================
coverage:
status:
project:
default:
# `auto` compares coverage with the base-commit
target: auto
patch:
default:
        target: auto
comment:
layout: "header, reach, diff, flags, files"
behavior: default
require_changes: no
require_base: no
require_head: yes
show_carryforward_flags: yes
flag_management:
default_rules:
carryforward: false
statuses:
- type: project
target: auto
- type: patch
target: auto
individual_flags:
- name: keras
paths:
- keras
- name: keras.applications
paths:
- keras/applications
carryforward: true
================================================
FILE: conftest.py
================================================
try:
# When using torch and tensorflow, torch needs to be imported first,
# otherwise it will segfault upon import. This should force the torch
# import to happen first for all tests.
import torch # noqa: F401
except ImportError:
torch = None
import pytest # noqa: E402
from keras.src.backend import backend # noqa: E402
def pytest_configure(config):
config.addinivalue_line(
"markers",
"requires_trainable_backend: mark test for trainable backend only",
)
def pytest_collection_modifyitems(config, items):
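    """Add backend-specific skip markers to collected tests.
    Tests marked `requires_trainable_backend` are skipped on the NumPy and
    OpenVINO backends, and tests listed in the OpenVINO and JAX-on-TPU
    exclusion files are skipped when running on those backends.
    """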
openvino_skipped_tests = []
if backend() == "openvino":
with open(
"keras/src/backend/openvino/excluded_concrete_tests.txt", "r"
) as file:
openvino_skipped_tests = file.readlines()
        # Strip whitespace and drop empty lines.
openvino_skipped_tests = [
line.strip() for line in openvino_skipped_tests if line.strip()
]
tpu_skipped_tests = []
if backend() == "jax":
import jax
if jax.default_backend() == "tpu":
with open(
"keras/src/backend/jax/excluded_tpu_tests.txt", "r"
) as file:
tpu_skipped_tests = file.readlines()
            # Strip whitespace and drop empty lines.
tpu_skipped_tests = [
line.strip() for line in tpu_skipped_tests if line.strip()
]
requires_trainable_backend = pytest.mark.skipif(
backend() in ["numpy", "openvino"],
reason="Trainer not implemented for NumPy and OpenVINO backend.",
)
for item in items:
if "requires_trainable_backend" in item.keywords:
item.add_marker(requires_trainable_backend)
        # Also skip concrete OpenVINO tests listed in the exclusion file.
        # This is a more granular mechanism to exclude tests than the
        # --ignore option.
for skipped_test in openvino_skipped_tests:
if skipped_test in item.nodeid:
item.add_marker(
skip_if_backend(
"openvino",
"Not supported operation by openvino backend",
)
)
        # Also skip tests known to fail on TPU when using the JAX backend.
for skipped_test in tpu_skipped_tests:
if skipped_test in item.nodeid:
item.add_marker(
pytest.mark.skip(
reason="Known TPU test failure",
)
)
def skip_if_backend(given_backend, reason):
return pytest.mark.skipif(backend() == given_backend, reason=reason)
================================================
FILE: examples/demo_custom_jax_workflow.py
================================================
# flake8: noqa
import os
# Set backend env to JAX
os.environ["KERAS_BACKEND"] = "jax"
import jax
import numpy as np
from keras import Model
from keras import backend
from keras import initializers
from keras import layers
from keras import ops
from keras import optimizers
class MyDense(layers.Layer):
def __init__(self, units, name=None):
super().__init__(name=name)
self.units = units
def build(self, input_shape):
input_dim = input_shape[-1]
w_shape = (input_dim, self.units)
w_value = initializers.GlorotUniform()(w_shape)
self.w = backend.Variable(w_value, name="kernel")
b_shape = (self.units,)
b_value = initializers.Zeros()(b_shape)
self.b = backend.Variable(b_value, name="bias")
def call(self, inputs):
return ops.matmul(inputs, self.w) + self.b
class MyModel(Model):
def __init__(self, hidden_dim, output_dim):
super().__init__()
self.dense1 = MyDense(hidden_dim)
self.dense2 = MyDense(hidden_dim)
self.dense3 = MyDense(output_dim)
def call(self, x):
x = jax.nn.relu(self.dense1(x))
x = jax.nn.relu(self.dense2(x))
return self.dense3(x)
def Dataset():
for _ in range(20):
yield (np.random.random((32, 128)), np.random.random((32, 4)))
def loss_fn(y_true, y_pred):
return ops.sum((y_true - y_pred) ** 2)
model = MyModel(hidden_dim=256, output_dim=4)
optimizer = optimizers.SGD(learning_rate=0.001)
dataset = Dataset()
# Build model
x = np.random.random((1, 128))
model(x)
# Build optimizer
optimizer.build(model.trainable_variables)
######### Custom JAX workflow ###############
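# `jax.jit` requires pure functions, so the training step below threads all
# variable values through explicitly: `stateless_call` runs the model with
# the given values and returns updated non-trainable state instead of
# mutating variables in place.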
def compute_loss_and_updates(
trainable_variables, non_trainable_variables, x, y
):
y_pred, non_trainable_variables = model.stateless_call(
trainable_variables, non_trainable_variables, x
)
loss = loss_fn(y, y_pred)
return loss, non_trainable_variables
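# `has_aux=True` tells JAX that the function returns (loss, aux_data), so
# gradients are taken with respect to the loss only.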
grad_fn = jax.value_and_grad(compute_loss_and_updates, has_aux=True)
@jax.jit
def train_step(state, data):
trainable_variables, non_trainable_variables, optimizer_variables = state
x, y = data
(loss, non_trainable_variables), grads = grad_fn(
trainable_variables, non_trainable_variables, x, y
)
trainable_variables, optimizer_variables = optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)
# Return updated state
return loss, (
trainable_variables,
non_trainable_variables,
optimizer_variables,
)
trainable_variables = model.trainable_variables
non_trainable_variables = model.non_trainable_variables
optimizer_variables = optimizer.variables
state = trainable_variables, non_trainable_variables, optimizer_variables
# Training loop
for data in dataset:
loss, state = train_step(state, data)
print("Loss:", loss)
# Post-processing model state update
trainable_variables, non_trainable_variables, optimizer_variables = state
for variable, value in zip(model.trainable_variables, trainable_variables):
variable.assign(value)
for variable, value in zip(
model.non_trainable_variables, non_trainable_variables
):
variable.assign(value)
================================================
FILE: examples/demo_custom_layer_backend_agnostic.py
================================================
import numpy as np
import keras
from keras import Model
from keras import initializers
from keras import layers
from keras import losses
from keras import metrics
from keras import ops
from keras import optimizers
class MyDense(layers.Layer):
def __init__(self, units, name=None):
super().__init__(name=name)
self.units = units
def build(self, input_shape):
input_dim = input_shape[-1]
self.w = self.add_weight(
shape=(input_dim, self.units),
initializer=initializers.GlorotNormal(),
name="kernel",
trainable=True,
)
self.b = self.add_weight(
shape=(self.units,),
initializer=initializers.Zeros(),
name="bias",
trainable=True,
)
def call(self, inputs):
# Use Keras ops to create backend-agnostic layers/metrics/etc.
return ops.matmul(inputs, self.w) + self.b
class MyDropout(layers.Layer):
def __init__(self, rate, name=None):
super().__init__(name=name)
self.rate = rate
# Use seed_generator for managing RNG state.
# It is a state element and its seed variable is
# tracked as part of `layer.variables`.
self.seed_generator = keras.random.SeedGenerator(1337)
def call(self, inputs):
# Use `keras.random` for random ops.
return keras.random.dropout(inputs, self.rate, seed=self.seed_generator)
class MyModel(Model):
def __init__(self, hidden_dim, output_dim):
super().__init__()
self.dense1 = MyDense(hidden_dim)
self.dense2 = MyDense(hidden_dim)
self.dense3 = MyDense(output_dim)
self.dp = MyDropout(0.5)
def call(self, x):
x1 = self.dense1(x)
x2 = self.dense2(x)
        # Use a Keras op here as well, to keep the layer backend-agnostic.
x = ops.concatenate([x1, x2], axis=-1)
x = self.dp(x)
return self.dense3(x)
model = MyModel(hidden_dim=256, output_dim=16)
x = np.random.random((50000, 128))
y = np.random.random((50000, 16))
batch_size = 32
epochs = 5
model.compile(
optimizer=optimizers.SGD(learning_rate=0.001),
loss=losses.MeanSquaredError(),
metrics=[metrics.MeanSquaredError()],
)
history = model.fit(x, y, batch_size=batch_size, epochs=epochs)
model.summary()
print("History:")
print(history.history)
================================================
FILE: examples/demo_custom_tf_workflow.py
================================================
# flake8: noqa
import os
# Set backend env to tensorflow
os.environ["KERAS_BACKEND"] = "tensorflow"
import numpy as np
import tensorflow as tf
from keras import Model
from keras import backend
from keras import initializers
from keras import layers
from keras import ops
from keras import optimizers
class MyDense(layers.Layer):
def __init__(self, units, name=None):
super().__init__(name=name)
self.units = units
def build(self, input_shape):
input_dim = input_shape[-1]
w_shape = (input_dim, self.units)
w_value = initializers.GlorotUniform()(w_shape)
self.w = backend.Variable(w_value, name="kernel")
b_shape = (self.units,)
b_value = initializers.Zeros()(b_shape)
self.b = backend.Variable(b_value, name="bias")
def call(self, inputs):
return ops.matmul(inputs, self.w) + self.b
class MyModel(Model):
def __init__(self, hidden_dim, output_dim):
super().__init__()
self.dense1 = MyDense(hidden_dim)
self.dense2 = MyDense(hidden_dim)
self.dense3 = MyDense(output_dim)
def call(self, x):
x = tf.nn.relu(self.dense1(x))
x = tf.nn.relu(self.dense2(x))
return self.dense3(x)
def Dataset():
for _ in range(20):
yield (
np.random.random((32, 128)).astype("float32"),
np.random.random((32, 4)).astype("float32"),
)
def loss_fn(y_true, y_pred):
return ops.sum((y_true - y_pred) ** 2)
model = MyModel(hidden_dim=256, output_dim=4)
optimizer = optimizers.SGD(learning_rate=0.001)
dataset = Dataset()
######### Custom TF workflow ###############
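# `jit_compile=True` compiles the training step with XLA for faster execution.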
@tf.function(jit_compile=True)
def train_step(data):
x, y = data
with tf.GradientTape() as tape:
y_pred = model(x)
loss = loss_fn(y, y_pred)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
return loss
for data in dataset:
loss = train_step(data)
print("Loss:", float(loss))
================================================
FILE: examples/demo_custom_torch_workflow.py
================================================
# flake8: noqa
import os
# Set backend env to torch
os.environ["KERAS_BACKEND"] = "torch"
import torch
import torch.nn as nn
import torch.optim as optim
from keras import layers
import keras
import numpy as np
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)
learning_rate = 0.01
batch_size = 64
num_epochs = 1
# Load the data and split it between train and test sets
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255
# Make sure images have shape (28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")
# Create the Keras model
model = keras.Sequential(
[
layers.Input(shape=(28, 28, 1)),
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(num_classes),
]
)
#################################################################
######## Writing a torch training loop for a Keras model ########
#################################################################
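# With the torch backend, a Keras model is a `torch.nn.Module`, so its
# `parameters()` can be handed directly to a torch optimizer and its outputs
# fed to a torch loss function.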
# Instantiate the torch optimizer
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Instantiate the torch loss function
loss_fn = nn.CrossEntropyLoss()
def train(model, train_loader, num_epochs, optimizer, loss_fn):
for epoch in range(num_epochs):
running_loss = 0.0
for batch_idx, (inputs, targets) in enumerate(train_loader):
# Forward pass
outputs = model(inputs)
loss = loss_fn(outputs, targets)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
running_loss += loss.item()
# Print loss statistics
if (batch_idx + 1) % 10 == 0:
print(
f"Epoch [{epoch + 1}/{num_epochs}], "
f"Batch [{batch_idx + 1}/{len(train_loader)}], "
f"Loss: {running_loss / 10}"
)
running_loss = 0.0
# Create a TensorDataset
dataset = torch.utils.data.TensorDataset(
torch.from_numpy(x_train), torch.from_numpy(y_train)
)
# Create a DataLoader
train_loader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=False
)
train(model, train_loader, num_epochs, optimizer, loss_fn)
################################################################
######## Using a Keras model or layer in a torch Module ########
################################################################
class MyModel(nn.Module):
def __init__(self):
super().__init__()
self.model = keras.Sequential(
[
layers.Input(shape=(28, 28, 1)),
layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
layers.MaxPooling2D(pool_size=(2, 2)),
layers.Flatten(),
layers.Dropout(0.5),
layers.Dense(num_classes),
]
)
def forward(self, x):
return self.model(x)
torch_module = MyModel()
# Instantiate the torch optimizer
optimizer = optim.Adam(torch_module.parameters(), lr=learning_rate)
# Instantiate the torch loss function
loss_fn = nn.CrossEntropyLoss()
train(torch_module, train_loader, num_epochs, optimizer, loss_fn)
================================================
FILE: examples/demo_functional.py
================================================
import numpy as np
from keras import Model
from keras import layers
from keras import losses
from keras import metrics
from keras import optimizers
import keras
keras.config.disable_traceback_filtering()
inputs = layers.Input((100,))
x = layers.Dense(512, activation="relu")(inputs)
residual = x
x = layers.Dense(512, activation="relu")(x)
x = layers.Dense(512, activation="relu")(x)
x += residual
x = layers.Dense(512, activation="relu")(x)
residual = x
x = layers.Dense(512, activation="relu")(x)
x = layers.Dense(512, activation="relu")(x)
x += residual
residual = x
x = layers.Dense(512, activation="relu")(x)
x = layers.Dense(512, activation="relu")(x)
x += residual
outputs = layers.Dense(16)(x)
model = Model(inputs, outputs)
model.summary()
x = np.random.random((50000, 100))
y = np.random.random((50000, 16))
batch_size = 32
epochs = 5
model.compile(
optimizer=optimizers.Adam(learning_rate=0.001),
loss=losses.MeanSquaredError(),
metrics=[
metrics.CategoricalAccuracy(name="acc"),
metrics.MeanSquaredError(name="mse"),
],
)
print("\nTrain model")
history = model.fit(
x, y, batch_size=batch_size, epochs=epochs, validation_split=0.2
)
print("\nHistory:")
print(history.history)
print("\nEvaluate model")
scores = model.evaluate(x, y, return_dict=True)
print(scores)
print("\nRun inference")
pred = model.predict(x)
print(f"Inferred output shape {pred.shape}")
================================================
FILE: examples/demo_jax_distributed.py
================================================
# To run this demo, you will need to spin up a "TPU VM" on Google Cloud.
# Please follow instructions here: https://cloud.google.com/tpu/docs/run-calculation-jax
# Force a JAX backend
import os, pprint, collections
os.environ["KERAS_BACKEND"] = "jax"
pp = pprint.PrettyPrinter()
import jax
import jax.numpy as jnp
import tensorflow as tf # just for tf.data
import keras # Keras multi-backend
import numpy as np
from tqdm import tqdm
from jax.experimental import mesh_utils
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P
""" Dataset
Classic MNIST, loaded using tf.data
"""
BATCH_SIZE = 192
(
(x_train, train_labels),
(x_eval, eval_labels),
) = keras.datasets.mnist.load_data()
x_train = np.expand_dims(x_train, axis=-1).astype(
np.float32
)  # from (28, 28) to (28, 28, 1): one B&W color channel
x_eval = np.expand_dims(x_eval, axis=-1).astype(np.float32)
train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)
train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)
train_data = train_data.repeat()
eval_data = tf.data.Dataset.from_tensor_slices((x_eval, eval_labels))
eval_data = eval_data.batch(10000) # everything as one batch
STEPS_PER_EPOCH = len(train_labels) // BATCH_SIZE
""" Keras model
Simple but non-trivial model with:
* Batch Normalization (non-trainable state updated during training, different training-time and inference behavior)
* Dropout (randomness, different training time and inference behavior)
"""
# Keras "sequential" model building style
def make_backbone():
return keras.Sequential(
[
keras.layers.Rescaling(
1.0 / 255.0
), # input images are in the range [0, 255]
keras.layers.Conv2D(
filters=12, kernel_size=3, padding="same", use_bias=False
),
keras.layers.BatchNormalization(scale=False, center=True),
keras.layers.Activation("relu"),
keras.layers.Conv2D(
filters=24,
kernel_size=6,
padding="same",
use_bias=False,
strides=2,
),
keras.layers.BatchNormalization(scale=False, center=True),
keras.layers.Activation("relu"),
keras.layers.Conv2D(
filters=32,
kernel_size=6,
padding="same",
use_bias=False,
strides=2,
name="large_k",
),
keras.layers.BatchNormalization(scale=False, center=True),
keras.layers.Activation("relu"),
],
name="backbone",
)
def make_model():
input = keras.Input(shape=[28, 28, 1])
y = make_backbone()(input)
y = keras.layers.Flatten()(y)
y = keras.layers.Dense(200, activation="relu")(y)
y = keras.layers.Dropout(0.4)(y)
y = keras.layers.Dense(10, activation="softmax")(y)
model = keras.Model(inputs=input, outputs=y)
return model
""" JAX-native distribution with a Keras model
For now, you have to write a custom training loop for this
Note: The features required by jax.sharding are not supported by the Colab TPU
runtime at this time, but are available on Cloud TPU VMs and Kaggle TPU VMs.
"""
if len(jax.local_devices()) < 8:
    raise Exception("This demo requires at least 8 local devices to run.")
│ │ ├── gptq_config_test.py
│ │ ├── gptq_core.py
│ │ ├── gptq_core_test.py
│ │ ├── gptq_test.py
│ │ ├── quantization_config.py
│ │ ├── quantization_config_test.py
│ │ ├── quantizers.py
│ │ ├── quantizers_test.py
│ │ ├── utils.py
│ │ └── utils_test.py
│ ├── random/
│ │ ├── __init__.py
│ │ ├── random.py
│ │ ├── random_test.py
│ │ ├── seed_generator.py
│ │ └── seed_generator_test.py
│ ├── regularizers/
│ │ ├── __init__.py
│ │ ├── regularizers.py
│ │ └── regularizers_test.py
│ ├── saving/
│ │ ├── __init__.py
│ │ ├── file_editor.py
│ │ ├── file_editor_test.py
│ │ ├── keras_saveable.py
│ │ ├── object_registration.py
│ │ ├── object_registration_test.py
│ │ ├── orbax_util.py
│ │ ├── saving_api.py
│ │ ├── saving_api_test.py
│ │ ├── saving_lib.py
│ │ ├── saving_lib_test.py
│ │ ├── serialization_lib.py
│ │ └── serialization_lib_test.py
│ ├── testing/
│ │ ├── __init__.py
│ │ ├── test_case.py
│ │ ├── test_utils.py
│ │ └── test_utils_test.py
│ ├── trainers/
│ │ ├── __init__.py
│ │ ├── compile_utils.py
│ │ ├── compile_utils_test.py
│ │ ├── data_adapters/
│ │ │ ├── __init__.py
│ │ │ ├── array_data_adapter.py
│ │ │ ├── array_data_adapter_test.py
│ │ │ ├── array_slicing.py
│ │ │ ├── data_adapter.py
│ │ │ ├── data_adapter_utils.py
│ │ │ ├── data_adapter_utils_test.py
│ │ │ ├── generator_data_adapter.py
│ │ │ ├── generator_data_adapter_test.py
│ │ │ ├── grain_dataset_adapter.py
│ │ │ ├── grain_dataset_adapter_test.py
│ │ │ ├── py_dataset_adapter.py
│ │ │ ├── py_dataset_adapter_test.py
│ │ │ ├── tf_dataset_adapter.py
│ │ │ ├── tf_dataset_adapter_test.py
│ │ │ ├── torch_data_loader_adapter.py
│ │ │ └── torch_data_loader_adapter_test.py
│ │ ├── epoch_iterator.py
│ │ ├── epoch_iterator_test.py
│ │ ├── trainer.py
│ │ └── trainer_test.py
│ ├── tree/
│ │ ├── __init__.py
│ │ ├── dmtree_impl.py
│ │ ├── optree_impl.py
│ │ ├── torchtree_impl.py
│ │ ├── tree_api.py
│ │ └── tree_test.py
│ ├── utils/
│ │ ├── __init__.py
│ │ ├── argument_validation.py
│ │ ├── audio_dataset_utils.py
│ │ ├── audio_dataset_utils_test.py
│ │ ├── backend_utils.py
│ │ ├── backend_utils_test.py
│ │ ├── code_stats.py
│ │ ├── code_stats_test.py
│ │ ├── config.py
│ │ ├── dataset_utils.py
│ │ ├── dataset_utils_test.py
│ │ ├── dtype_utils.py
│ │ ├── dtype_utils_test.py
│ │ ├── file_utils.py
│ │ ├── file_utils_test.py
│ │ ├── grain_utils.py
│ │ ├── image_dataset_utils.py
│ │ ├── image_dataset_utils_test.py
│ │ ├── image_utils.py
│ │ ├── image_utils_test.py
│ │ ├── io_utils.py
│ │ ├── io_utils_test.py
│ │ ├── jax_layer.py
│ │ ├── jax_layer_test.py
│ │ ├── jax_utils.py
│ │ ├── model_visualization.py
│ │ ├── module_utils.py
│ │ ├── naming.py
│ │ ├── naming_test.py
│ │ ├── numerical_utils.py
│ │ ├── numerical_utils_test.py
│ │ ├── progbar.py
│ │ ├── progbar_test.py
│ │ ├── python_utils.py
│ │ ├── python_utils_test.py
│ │ ├── rng_utils.py
│ │ ├── rng_utils_test.py
│ │ ├── sequence_utils.py
│ │ ├── sequence_utils_test.py
│ │ ├── summary_utils.py
│ │ ├── summary_utils_test.py
│ │ ├── text_dataset_utils.py
│ │ ├── text_dataset_utils_test.py
│ │ ├── tf_utils.py
│ │ ├── timeseries_dataset_utils.py
│ │ ├── timeseries_dataset_utils_test.py
│ │ ├── torch_utils.py
│ │ ├── torch_utils_test.py
│ │ ├── traceback_utils.py
│ │ ├── tracking.py
│ │ └── tracking_test.py
│ ├── version.py
│ ├── visualization/
│ │ ├── __init__.py
│ │ ├── draw_bounding_boxes.py
│ │ ├── draw_segmentation_masks.py
│ │ ├── plot_bounding_box_gallery.py
│ │ ├── plot_image_gallery.py
│ │ └── plot_segmentation_mask_gallery.py
│ └── wrappers/
│ ├── __init__.py
│ ├── fixes.py
│ ├── sklearn_test.py
│ ├── sklearn_wrapper.py
│ └── utils.py
├── pip_build.py
├── pyproject.toml
├── requirements-common.txt
├── requirements-jax-cuda.txt
├── requirements-jax-tpu.txt
├── requirements-tensorflow-cuda.txt
├── requirements-tensorflow-tpu.txt
├── requirements-torch-cuda.txt
├── requirements.txt
└── shell/
├── api_gen.sh
└── format.sh
SYMBOL INDEX (14089 symbols across 779 files)
FILE: api_gen.py
function ignore_files (line 19) | def ignore_files(_, filenames):
function copy_source_to_build_directory (line 23) | def copy_source_to_build_directory(root_path):
function create_legacy_directory (line 36) | def create_legacy_directory(package_dir):
function export_version_string (line 140) | def export_version_string(api_init_fname):
function build (line 148) | def build():
FILE: benchmarks/layer_benchmark/activation_benchmark.py
function benchmark_elu (line 23) | def benchmark_elu(
function benchmark_prelu (line 47) | def benchmark_prelu(
function benchmark_relu (line 71) | def benchmark_relu(
function benchmark_leaky_relu (line 95) | def benchmark_leaky_relu(
function benchmark_softmax (line 119) | def benchmark_softmax(
function main (line 152) | def main(_):
FILE: benchmarks/layer_benchmark/attention_benchmark.py
function benchmark_attention (line 23) | def benchmark_attention(
function benchmark_multi_head_attention (line 49) | def benchmark_multi_head_attention(
function benchmark_additive_attention (line 78) | def benchmark_additive_attention(
function main (line 111) | def main(_):
FILE: benchmarks/layer_benchmark/base_benchmark.py
class BenchmarkMetricsCallback (line 37) | class BenchmarkMetricsCallback:
method __init__ (line 38) | def __init__(self, start_batch=1, stop_batch=None):
method on_train_batch_begin (line 44) | def on_train_batch_begin(self, batch, logs=None):
method on_train_batch_end (line 48) | def on_train_batch_end(self, batch, logs=None):
method on_predict_batch_begin (line 56) | def on_predict_batch_begin(self, batch, logs=None):
method on_predict_batch_end (line 60) | def on_predict_batch_end(self, batch, logs=None):
class KerasCoreBenchmarkMetricsCallback (line 69) | class KerasCoreBenchmarkMetricsCallback(keras.callbacks.Callback):
method __init__ (line 70) | def __init__(self, start_batch=1, stop_batch=None):
method on_train_batch_begin (line 73) | def on_train_batch_begin(self, batch, logs=None):
method on_train_batch_end (line 76) | def on_train_batch_end(self, batch, logs=None):
method on_predict_batch_begin (line 79) | def on_predict_batch_begin(self, batch, logs=None):
method on_predict_batch_end (line 82) | def on_predict_batch_end(self, batch, logs=None):
class TFKerasBenchmarkMetricsCallback (line 86) | class TFKerasBenchmarkMetricsCallback(tf.keras.callbacks.Callback):
method __init__ (line 87) | def __init__(self, start_batch=1, stop_batch=None):
method on_train_batch_begin (line 90) | def on_train_batch_begin(self, batch, logs=None):
method on_train_batch_end (line 93) | def on_train_batch_end(self, batch, logs=None):
method on_predict_batch_begin (line 96) | def on_predict_batch_begin(self, batch, logs=None):
method on_predict_batch_end (line 99) | def on_predict_batch_end(self, batch, logs=None):
class LayerBenchmark (line 103) | class LayerBenchmark:
method __init__ (line 104) | def __init__(
method _build_keras_model (line 151) | def _build_keras_model(self, input_shape, flat_call_inputs=True):
method _build_tf_keras_model (line 165) | def _build_tf_keras_model(self, input_shape, flat_call_inputs=True):
method benchmark_predict (line 179) | def benchmark_predict(self, num_samples, batch_size, data=None):
method benchmark_train (line 223) | def benchmark_train(self, num_samples, batch_size, data=None, label=No...
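For orientation, a minimal sketch of the batch-timing pattern the callbacks above implement — not the repo's exact code; `start_batch`/`stop_batch` follow the signatures listed, while the class name `BatchTimingCallback` and the timing details are illustrative:

import time

import keras


class BatchTimingCallback(keras.callbacks.Callback):
    """Records per-batch wall time between start_batch and stop_batch."""

    def __init__(self, start_batch=1, stop_batch=None):
        super().__init__()
        self.start_batch = start_batch
        self.stop_batch = stop_batch
        self.batch_times = []  # one entry per measured batch

    def on_train_batch_begin(self, batch, logs=None):
        if batch >= self.start_batch:
            self._t0 = time.time()

    def on_train_batch_end(self, batch, logs=None):
        in_window = batch >= self.start_batch and (
            self.stop_batch is None or batch <= self.stop_batch
        )
        if in_window:
            self.batch_times.append(time.time() - self._t0)

Skipping the first batch (the `start_batch=1` default) avoids counting one-time tracing/compilation cost in the measured throughput.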
FILE: benchmarks/layer_benchmark/conv_benchmark.py
function benchmark_conv1D (line 23) | def benchmark_conv1D(
function benchmark_conv2D (line 51) | def benchmark_conv2D(
function benchmark_conv3D (line 79) | def benchmark_conv3D(
function benchmark_depthwise_conv1D (line 107) | def benchmark_depthwise_conv1D(
function benchmark_depthwise_conv2D (line 135) | def benchmark_depthwise_conv2D(
function benchmark_separable_conv1D (line 163) | def benchmark_separable_conv1D(
function benchmark_separable_conv2D (line 192) | def benchmark_separable_conv2D(
function benchmark_conv1D_transpose (line 221) | def benchmark_conv1D_transpose(
function benchmark_conv2D_transpose (line 249) | def benchmark_conv2D_transpose(
function benchmark_conv3D_transpose (line 277) | def benchmark_conv3D_transpose(
function main (line 319) | def main(_):
FILE: benchmarks/layer_benchmark/core_benchmark.py
function benchmark_dense (line 24) | def benchmark_dense(
function benchmark_einsum_dense (line 49) | def benchmark_einsum_dense(
function benchmark_embedding (line 77) | def benchmark_embedding(
function main (line 117) | def main(_):
FILE: benchmarks/layer_benchmark/merge_benchmark.py
function benchmark_add (line 23) | def benchmark_add(
function benchmark_average (line 49) | def benchmark_average(
function benchmark_concatenate (line 75) | def benchmark_concatenate(
function benchmark_dot (line 101) | def benchmark_dot(
function benchmark_maximum (line 127) | def benchmark_maximum(
function benchmark_minimum (line 153) | def benchmark_minimum(
function benchmark_multiply (line 179) | def benchmark_multiply(
function benchmark_subtract (line 205) | def benchmark_subtract(
function main (line 243) | def main(_):
FILE: benchmarks/layer_benchmark/normalization_benchmark.py
function benchmark_batch_normalization (line 23) | def benchmark_batch_normalization(
function benchmark_group_normalization (line 48) | def benchmark_group_normalization(
function benchmark_layer_normalization (line 75) | def benchmark_layer_normalization(
function benchmark_unit_normalization (line 100) | def benchmark_unit_normalization(
function main (line 133) | def main(_):
FILE: benchmarks/layer_benchmark/pooling_benchmark.py
function benchmark_average_pooling1d (line 23) | def benchmark_average_pooling1d(
function benchmark_average_pooling2d (line 50) | def benchmark_average_pooling2d(
function benchmark_average_pooling3d (line 77) | def benchmark_average_pooling3d(
function benchmark_max_pooling1d (line 104) | def benchmark_max_pooling1d(
function benchmark_max_pooling2d (line 131) | def benchmark_max_pooling2d(
function benchmark_max_pooling3d (line 158) | def benchmark_max_pooling3d(
function benchmark_global_average_pooling1d (line 185) | def benchmark_global_average_pooling1d(
function benchmark_global_average_pooling2d (line 210) | def benchmark_global_average_pooling2d(
function benchmark_global_average_pooling3d (line 235) | def benchmark_global_average_pooling3d(
function benchmark_global_max_pooling1d (line 260) | def benchmark_global_max_pooling1d(
function benchmark_global_max_pooling2d (line 285) | def benchmark_global_max_pooling2d(
function benchmark_global_max_pooling3d (line 310) | def benchmark_global_max_pooling3d(
function main (line 351) | def main(_):
FILE: benchmarks/layer_benchmark/random_rotation_benchmark.py
function benchmark_random_rotation (line 11) | def benchmark_random_rotation(
function main (line 44) | def main(_):
FILE: benchmarks/layer_benchmark/regularization_benchmark.py
function benchmark_dropout (line 23) | def benchmark_dropout(
function benchmark_gaussian_dropout (line 50) | def benchmark_gaussian_dropout(
function benchmark_gaussian_noise (line 77) | def benchmark_gaussian_noise(
function benchmark_spatial_dropout1D (line 104) | def benchmark_spatial_dropout1D(
function benchmark_spatial_dropout2D (line 131) | def benchmark_spatial_dropout2D(
function benchmark_spatial_dropout3D (line 158) | def benchmark_spatial_dropout3D(
function main (line 195) | def main(_):
FILE: benchmarks/layer_benchmark/reshaping_benchmark.py
function benchmark_cropping1d (line 23) | def benchmark_cropping1d(
function benchmark_cropping2d (line 48) | def benchmark_cropping2d(
function benchmark_cropping3d (line 73) | def benchmark_cropping3d(
function benchmark_flatten (line 98) | def benchmark_flatten(
function benchmark_permute (line 123) | def benchmark_permute(
function benchmark_up_sampling1d (line 150) | def benchmark_up_sampling1d(
function benchmark_up_sampling2d (line 175) | def benchmark_up_sampling2d(
function benchmark_up_sampling3d (line 200) | def benchmark_up_sampling3d(
function benchmark_zero_padding1d (line 225) | def benchmark_zero_padding1d(
function benchmark_zero_padding2d (line 250) | def benchmark_zero_padding2d(
function benchmark_zero_padding3d (line 275) | def benchmark_zero_padding3d(
function main (line 315) | def main(_):
FILE: benchmarks/layer_benchmark/rnn_benchmark.py
function benchmark_conv_lstm1d (line 25) | def benchmark_conv_lstm1d(
function benchmark_conv_lstm2d (line 53) | def benchmark_conv_lstm2d(
function benchmark_conv_lstm3d (line 81) | def benchmark_conv_lstm3d(
function benchmark_gru (line 109) | def benchmark_gru(
function benchmark_lstm (line 136) | def benchmark_lstm(
function benchmark_simple_rnn (line 163) | def benchmark_simple_rnn(
function benchmark_bidirectional (line 190) | def benchmark_bidirectional(
function benchmark_time_distributed (line 219) | def benchmark_time_distributed(
function main (line 262) | def main(_):
FILE: benchmarks/model_benchmark/benchmark_utils.py
class BenchmarkMetricsCallback (line 6) | class BenchmarkMetricsCallback(keras.callbacks.Callback):
method __init__ (line 7) | def __init__(self, start_batch=1, stop_batch=None):
method on_train_batch_begin (line 14) | def on_train_batch_begin(self, batch, logs=None):
method on_train_batch_end (line 18) | def on_train_batch_end(self, batch, logs=None):
FILE: benchmarks/model_benchmark/bert_benchmark.py
function load_data (line 47) | def load_data():
function load_model (line 86) | def load_model():
function main (line 97) | def main(_):
FILE: benchmarks/model_benchmark/image_classification_benchmark.py
function load_data (line 54) | def load_data():
function load_model (line 84) | def load_model():
function main (line 99) | def main(_):
FILE: benchmarks/torch_ctl_benchmark/benchmark_utils.py
function train_loop (line 7) | def train_loop(model, train_loader, num_epochs, optimizer, loss_fn, fram...
FILE: benchmarks/torch_ctl_benchmark/conv_model_benchmark.py
class TorchModel (line 38) | class TorchModel(torch.nn.Module):
method __init__ (line 39) | def __init__(self):
method forward (line 49) | def forward(self, x):
function run_keras_custom_training_loop (line 59) | def run_keras_custom_training_loop():
function run_torch_custom_training_loop (line 82) | def run_torch_custom_training_loop():
FILE: benchmarks/torch_ctl_benchmark/dense_model_benchmark.py
class TorchModel (line 38) | class TorchModel(torch.nn.Module):
method __init__ (line 39) | def __init__(self):
method forward (line 49) | def forward(self, x):
function run_keras_custom_training_loop (line 59) | def run_keras_custom_training_loop():
function run_torch_custom_training_loop (line 81) | def run_torch_custom_training_loop():
FILE: conftest.py
function pytest_configure (line 14) | def pytest_configure(config):
function pytest_collection_modifyitems (line 21) | def pytest_collection_modifyitems(config, items):
function skip_if_backend (line 77) | def skip_if_backend(given_backend, reason):
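A plausible shape for `skip_if_backend` above, assuming it wraps `pytest.mark.skipif` around the active-backend check (the actual implementation may differ):

import pytest

from keras import backend


def skip_if_backend(given_backend, reason):
    # Skip a test when the currently selected Keras backend matches.
    return pytest.mark.skipif(backend.backend() == given_backend, reason=reason)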
FILE: examples/demo_custom_jax_workflow.py
class MyDense (line 18) | class MyDense(layers.Layer):
method __init__ (line 19) | def __init__(self, units, name=None):
method build (line 23) | def build(self, input_shape):
method call (line 33) | def call(self, inputs):
class MyModel (line 37) | class MyModel(Model):
method __init__ (line 38) | def __init__(self, hidden_dim, output_dim):
method call (line 44) | def call(self, x):
function Dataset (line 50) | def Dataset():
function loss_fn (line 55) | def loss_fn(y_true, y_pred):
function compute_loss_and_updates (line 74) | def compute_loss_and_updates(
function train_step (line 88) | def train_step(state, data):
FILE: examples/demo_custom_layer_backend_agnostic.py
class MyDense (line 13) | class MyDense(layers.Layer):
method __init__ (line 14) | def __init__(self, units, name=None):
method build (line 18) | def build(self, input_shape):
method call (line 34) | def call(self, inputs):
class MyDropout (line 39) | class MyDropout(layers.Layer):
method __init__ (line 40) | def __init__(self, rate, name=None):
method call (line 48) | def call(self, inputs):
class MyModel (line 53) | class MyModel(Model):
method __init__ (line 54) | def __init__(self, hidden_dim, output_dim):
method call (line 61) | def call(self, x):
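The backend-agnostic pattern this demo is built on, as a minimal sketch — it uses only `keras.ops`, so the same layer runs unchanged on TensorFlow, JAX, and PyTorch (not the demo's verbatim code):

from keras import layers, ops


class MyDense(layers.Layer):
    def __init__(self, units, name=None):
        super().__init__(name=name)
        self.units = units

    def build(self, input_shape):
        # Weights are created lazily, once the input shape is known.
        self.w = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer="glorot_uniform",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,), initializer="zeros", trainable=True
        )

    def call(self, inputs):
        return ops.matmul(inputs, self.w) + self.b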
FILE: examples/demo_custom_tf_workflow.py
class MyDense (line 18) | class MyDense(layers.Layer):
method __init__ (line 19) | def __init__(self, units, name=None):
method build (line 23) | def build(self, input_shape):
method call (line 33) | def call(self, inputs):
class MyModel (line 37) | class MyModel(Model):
method __init__ (line 38) | def __init__(self, hidden_dim, output_dim):
method call (line 44) | def call(self, x):
function Dataset (line 50) | def Dataset():
function loss_fn (line 58) | def loss_fn(y_true, y_pred):
function train_step (line 72) | def train_step(data):
FILE: examples/demo_custom_torch_workflow.py
function train (line 59) | def train(model, train_loader, num_epochs, optimizer, loss_fn):
class MyModel (line 102) | class MyModel(nn.Module):
method __init__ (line 103) | def __init__(self):
method forward (line 118) | def forward(self, x):
FILE: examples/demo_jax_distributed.py
function make_backbone (line 57) | def make_backbone():
function make_model (line 92) | def make_model():
function compute_loss (line 254) | def compute_loss(trainable_variables, non_trainable_variables, x, y):
function train_step (line 268) | def train_step(train_state, x, y):
function predict (line 306) | def predict(data):
FILE: examples/demo_subclass.py
class MyModel (line 10) | class MyModel(Model):
method __init__ (line 11) | def __init__(self, hidden_dim, output_dim):
method call (line 17) | def call(self, x):
FILE: examples/demo_torch_multi_gpu.py
function get_data (line 28) | def get_data():
function get_model (line 49) | def get_model():
class MyModel (line 66) | class MyModel(nn.Module):
method __init__ (line 67) | def __init__(self):
method forward (line 82) | def forward(self, x):
function train (line 86) | def train(model, train_loader, num_epochs, optimizer, loss_fn):
function setup (line 114) | def setup(current_gpu_index, num_gpu):
function prepare (line 128) | def prepare(dataset, current_gpu_index, num_gpu, batch_size):
function cleanup (line 147) | def cleanup():
function main (line 152) | def main(current_gpu_index, num_gpu):
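The `setup`/`cleanup` pair above typically boots a `torch.distributed` process group, one process per GPU; a minimal sketch under that assumption (the address and port values are placeholders):

import os

import torch
import torch.distributed as dist


def setup(current_gpu_index, num_gpu):
    # Single-node bootstrap: rendezvous via env vars, NCCL for GPU comms.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "56492"
    dist.init_process_group(
        backend="nccl", rank=current_gpu_index, world_size=num_gpu
    )
    torch.cuda.set_device(current_gpu_index)


def cleanup():
    dist.destroy_process_group()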
FILE: guides/custom_train_step_in_jax.py
class CustomModel (line 77) | class CustomModel(keras.Model):
method compute_loss_and_updates (line 78) | def compute_loss_and_updates(
method train_step (line 95) | def train_step(self, state, data):
class CustomModel (line 180) | class CustomModel(keras.Model):
method __init__ (line 181) | def __init__(self, *args, **kwargs):
method compute_loss_and_updates (line 187) | def compute_loss_and_updates(
method train_step (line 204) | def train_step(self, state, data):
method metrics (line 268) | def metrics(self):
class CustomModel (line 297) | class CustomModel(keras.Model):
method test_step (line 298) | def test_step(self, state, data):
FILE: guides/custom_train_step_in_tensorflow.py
class CustomModel (line 83) | class CustomModel(keras.Model):
method train_step (line 84) | def train_step(self, data):
class CustomModel (line 149) | class CustomModel(keras.Model):
method __init__ (line 150) | def __init__(self, *args, **kwargs):
method train_step (line 156) | def train_step(self, data):
method metrics (line 180) | def metrics(self):
class CustomModel (line 215) | class CustomModel(keras.Model):
method train_step (line 216) | def train_step(self, data):
class CustomModel (line 275) | class CustomModel(keras.Model):
method test_step (line 276) | def test_step(self, data):
class GAN (line 357) | class GAN(keras.Model):
method __init__ (line 358) | def __init__(self, discriminator, generator, latent_dim):
method metrics (line 368) | def metrics(self):
method compile (line 371) | def compile(self, d_optimizer, g_optimizer, loss_fn):
method train_step (line 377) | def train_step(self, real_images):
FILE: guides/custom_train_step_in_torch.py
class CustomModel (line 84) | class CustomModel(keras.Model):
method train_step (line 85) | def train_step(self, data):
class CustomModel (line 157) | class CustomModel(keras.Model):
method __init__ (line 158) | def __init__(self, *args, **kwargs):
method train_step (line 164) | def train_step(self, data):
method metrics (line 195) | def metrics(self):
class CustomModel (line 230) | class CustomModel(keras.Model):
method train_step (line 231) | def train_step(self, data):
class CustomModel (line 295) | class CustomModel(keras.Model):
method test_step (line 296) | def test_step(self, data):
class GAN (line 377) | class GAN(keras.Model):
method __init__ (line 378) | def __init__(self, discriminator, generator, latent_dim):
method metrics (line 389) | def metrics(self):
method compile (line 392) | def compile(self, d_optimizer, g_optimizer, loss_fn):
method train_step (line 398) | def train_step(self, real_images):
FILE: guides/distributed_training_with_jax.py
function get_model (line 58) | def get_model():
function get_datasets (line 92) | def get_datasets():
function compute_loss (line 176) | def compute_loss(trainable_variables, non_trainable_variables, x, y):
function train_step (line 190) | def train_step(train_state, x, y):
function get_replicated_train_state (line 212) | def get_replicated_train_state(devices):
FILE: guides/distributed_training_with_tensorflow.py
function get_compiled_model (line 116) | def get_compiled_model():
function get_dataset (line 131) | def get_dataset():
function make_or_restore_model (line 193) | def make_or_restore_model():
function run_training (line 208) | def run_training(epochs=1):
FILE: guides/distributed_training_with_torch.py
function get_model (line 53) | def get_model():
function get_dataset (line 87) | def get_dataset():
function train_model (line 112) | def train_model(model, dataloader, num_epochs, optimizer, loss_fn):
function setup_device (line 193) | def setup_device(current_gpu_index, num_gpus):
function cleanup (line 207) | def cleanup():
function prepare_dataloader (line 211) | def prepare_dataloader(dataset, current_gpu_index, num_gpus, batch_size):
function per_device_launch_fn (line 227) | def per_device_launch_fn(current_gpu_index, num_gpu):
FILE: guides/functional_api.py
function get_model (line 290) | def get_model():
class CustomDense (line 596) | class CustomDense(layers.Layer):
method __init__ (line 597) | def __init__(self, units=32):
method build (line 601) | def build(self, input_shape):
method call (line 611) | def call(self, inputs):
class CustomDense (line 626) | class CustomDense(layers.Layer):
method __init__ (line 627) | def __init__(self, units=32):
method build (line 631) | def build(self, input_shape):
method call (line 641) | def call(self, inputs):
method get_config (line 644) | def get_config(self):
class CustomRNN (line 796) | class CustomRNN(layers.Layer):
method __init__ (line 797) | def __init__(self):
method call (line 805) | def call(self, inputs):
class CustomRNN (line 850) | class CustomRNN(layers.Layer):
method __init__ (line 851) | def __init__(self):
method call (line 858) | def call(self, inputs):
FILE: guides/making_new_layers_and_models_via_subclassing.py
class Linear (line 49) | class Linear(keras.layers.Layer):
method __init__ (line 50) | def __init__(self, units=32, input_dim=32):
method call (line 61) | def call(self, inputs):
class ComputeSum (line 93) | class ComputeSum(keras.layers.Layer):
method __init__ (line 94) | def __init__(self, input_dim):
method call (line 100) | def call(self, inputs):
class Linear (line 130) | class Linear(keras.layers.Layer):
method __init__ (line 131) | def __init__(self, units=32, input_dim=32):
method call (line 142) | def call(self, inputs):
class Linear (line 156) | class Linear(keras.layers.Layer):
method __init__ (line 157) | def __init__(self, units=32):
method build (line 161) | def build(self, input_shape):
method call (line 171) | def call(self, inputs):
class MLPBlock (line 202) | class MLPBlock(keras.layers.Layer):
method __init__ (line 203) | def __init__(self):
method call (line 209) | def call(self, inputs):
class ActivityRegularizationLayer (line 291) | class ActivityRegularizationLayer(keras.layers.Layer):
method __init__ (line 292) | def __init__(self, rate=1e-2):
method call (line 296) | def call(self, inputs):
class OuterLayer (line 309) | class OuterLayer(keras.layers.Layer):
method __init__ (line 310) | def __init__(self):
method call (line 314) | def call(self, inputs):
class OuterLayerWithKernelRegularizer (line 336) | class OuterLayerWithKernelRegularizer(keras.layers.Layer):
method __init__ (line 337) | def __init__(self):
method call (line 343) | def call(self, inputs):
class Linear (line 384) | class Linear(keras.layers.Layer):
method __init__ (line 385) | def __init__(self, units=32):
method build (line 389) | def build(self, input_shape):
method call (line 399) | def call(self, inputs):
method get_config (line 402) | def get_config(self):
class Linear (line 420) | class Linear(keras.layers.Layer):
method __init__ (line 421) | def __init__(self, units=32, **kwargs):
method build (line 425) | def build(self, input_shape):
method call (line 435) | def call(self, inputs):
method get_config (line 438) | def get_config(self):
class CustomDropout (line 477) | class CustomDropout(keras.layers.Layer):
method __init__ (line 478) | def __init__(self, rate, **kwargs):
method call (line 482) | def call(self, inputs, training=None):
class Sampling (line 591) | class Sampling(layers.Layer):
method call (line 594) | def call(self, inputs):
class Encoder (line 602) | class Encoder(layers.Layer):
method __init__ (line 605) | def __init__(
method call (line 614) | def call(self, inputs):
class Decoder (line 622) | class Decoder(layers.Layer):
method __init__ (line 625) | def __init__(
method call (line 632) | def call(self, inputs):
class VariationalAutoEncoder (line 637) | class VariationalAutoEncoder(keras.Model):
method __init__ (line 640) | def __init__(
method call (line 655) | def call(self, inputs):
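The `Sampling` layer listed above is the VAE reparameterization trick, z = mean + exp(0.5 * log_var) * eps; a minimal stateless-RNG sketch (the seed value 1337 is arbitrary):

import keras
from keras import layers, ops


class Sampling(layers.Layer):
    """Draws z ~ N(z_mean, exp(z_log_var)) via reparameterization."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.seed_generator = keras.random.SeedGenerator(1337)

    def call(self, inputs):
        z_mean, z_log_var = inputs
        eps = keras.random.normal(
            shape=ops.shape(z_mean), seed=self.seed_generator
        )
        return z_mean + ops.exp(0.5 * z_log_var) * eps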
FILE: guides/training_with_built_in_methods.py
function get_uncompiled_model (line 193) | def get_uncompiled_model():
function get_compiled_model (line 202) | def get_compiled_model():
function custom_mean_squared_error (line 251) | def custom_mean_squared_error(y_true, y_pred):
class CustomMSE (line 280) | class CustomMSE(keras.losses.Loss):
method __init__ (line 281) | def __init__(self, regularization_factor=0.1, name="custom_mse"):
method call (line 285) | def call(self, y_true, y_pred):
class CategoricalTruePositives (line 320) | class CategoricalTruePositives(keras.metrics.Metric):
method __init__ (line 321) | def __init__(self, name="categorical_true_positives", **kwargs):
method update_state (line 327) | def update_state(self, y_true, y_pred, sample_weight=None):
method result (line 336) | def result(self):
method reset_state (line 339) | def reset_state(self):
class ActivityRegularizationLayer (line 368) | class ActivityRegularizationLayer(layers.Layer):
method call (line 369) | def call(self, inputs):
class LogisticEndpoint (line 402) | class LogisticEndpoint(keras.layers.Layer):
method __init__ (line 403) | def __init__(self, name=None):
method call (line 407) | def call(self, targets, logits, sample_weights=None):
class ExamplePyDataset (line 601) | class ExamplePyDataset(keras.utils.PyDataset):
method __init__ (line 602) | def __init__(self, x, y, batch_size, **kwargs):
method __len__ (line 608) | def __len__(self):
method __getitem__ (line 611) | def __getitem__(self, idx):
class ExampleTorchDataset (line 686) | class ExampleTorchDataset(torch.utils.data.Dataset):
method __init__ (line 687) | def __init__(self, x, y):
method __len__ (line 691) | def __len__(self):
method __getitem__ (line 694) | def __getitem__(self, idx):
class LossHistory (line 1079) | class LossHistory(keras.callbacks.Callback):
method on_train_begin (line 1080) | def on_train_begin(self, logs):
method on_batch_end (line 1083) | def on_batch_end(self, batch, logs):
function make_or_restore_model (line 1132) | def make_or_restore_model():
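`ExamplePyDataset` above follows the standard `keras.utils.PyDataset` contract (`__len__` = batches per epoch, `__getitem__` = one batch); a minimal sketch of that contract, with illustrative batching logic:

import math

import keras


class ExamplePyDataset(keras.utils.PyDataset):
    """Array-backed dataset yielding (x, y) batches."""

    def __init__(self, x, y, batch_size, **kwargs):
        super().__init__(**kwargs)  # forwards workers / use_multiprocessing
        self.x, self.y = x, y
        self.batch_size = batch_size

    def __len__(self):
        # Number of batches per epoch (last batch may be short).
        return math.ceil(len(self.x) / self.batch_size)

    def __getitem__(self, idx):
        lo = idx * self.batch_size
        hi = lo + self.batch_size
        return self.x[lo:hi], self.y[lo:hi]

An instance can be passed directly to `model.fit(dataset, epochs=...)`.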
FILE: guides/transfer_learning.py
function data_augmentation (line 425) | def data_augmentation(x):
FILE: guides/understanding_masking_and_padding.py
class MyLayer (line 173) | class MyLayer(layers.Layer):
method __init__ (line 174) | def __init__(self, **kwargs):
method call (line 181) | def call(self, inputs):
class TemporalSplit (line 218) | class TemporalSplit(keras.layers.Layer):
method call (line 221) | def call(self, inputs):
method compute_mask (line 226) | def compute_mask(self, inputs, mask=None):
class CustomEmbedding (line 243) | class CustomEmbedding(keras.layers.Layer):
method __init__ (line 244) | def __init__(self, input_dim, output_dim, mask_zero=False, **kwargs):
method build (line 250) | def build(self, input_shape):
method call (line 257) | def call(self, inputs):
method compute_mask (line 261) | def compute_mask(self, inputs, mask=None):
class MyActivation (line 300) | class MyActivation(keras.layers.Layer):
method __init__ (line 301) | def __init__(self, **kwargs):
method call (line 306) | def call(self, inputs):
class TemporalSoftmax (line 340) | class TemporalSoftmax(keras.layers.Layer):
method __init__ (line 341) | def __init__(self, **kwargs):
method call (line 345) | def call(self, inputs, mask=None):
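`TemporalSplit` above demonstrates mask propagation via `compute_mask`; a minimal sketch of the pattern (assumes an even time dimension):

from keras import layers, ops


class TemporalSplit(layers.Layer):
    """Splits the time axis in two and splits the mask the same way."""

    def call(self, inputs):
        # Two halves along the time (second) axis.
        return ops.split(inputs, 2, axis=1)

    def compute_mask(self, inputs, mask=None):
        if mask is None:
            return None
        # Keep the mask aligned with each output half.
        return ops.split(mask, 2, axis=1)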
FILE: guides/writing_a_custom_training_loop_in_jax.py
function get_model (line 61) | def get_model():
function compute_loss_and_updates (line 178) | def compute_loss_and_updates(
function train_step (line 214) | def train_step(state, data):
function train_step (line 248) | def train_step(state, data):
function compute_loss_and_updates (line 349) | def compute_loss_and_updates(
function train_step (line 366) | def train_step(state, data):
function eval_step (line 395) | def eval_step(state, data):
class ActivityRegularizationLayer (line 484) | class ActivityRegularizationLayer(keras.layers.Layer):
method call (line 485) | def call(self, inputs):
function compute_loss_and_updates (line 511) | def compute_loss_and_updates(
FILE: guides/writing_a_custom_training_loop_in_tensorflow.py
function get_model (line 48) | def get_model():
function train_step (line 235) | def train_step(x, y):
function test_step (line 251) | def test_step(x, y):
class ActivityRegularizationLayer (line 312) | class ActivityRegularizationLayer(keras.layers.Layer):
method call (line 313) | def call(self, inputs):
function train_step (line 337) | def train_step(x, y):
function train_step (line 449) | def train_step(real_images):
FILE: guides/writing_a_custom_training_loop_in_torch.py
function get_model (line 61) | def get_model():
class ActivityRegularizationLayer (line 301) | class ActivityRegularizationLayer(keras.layers.Layer):
method call (line 302) | def call(self, inputs):
FILE: guides/writing_your_own_callbacks.py
function get_model (line 90) | def get_model():
class CustomCallback (line 127) | class CustomCallback(keras.callbacks.Callback):
method on_train_begin (line 128) | def on_train_begin(self, logs=None):
method on_train_end (line 132) | def on_train_end(self, logs=None):
method on_epoch_begin (line 136) | def on_epoch_begin(self, epoch, logs=None):
method on_epoch_end (line 142) | def on_epoch_end(self, epoch, logs=None):
method on_test_begin (line 146) | def on_test_begin(self, logs=None):
method on_test_end (line 150) | def on_test_end(self, logs=None):
method on_predict_begin (line 154) | def on_predict_begin(self, logs=None):
method on_predict_end (line 158) | def on_predict_end(self, logs=None):
method on_train_batch_begin (line 162) | def on_train_batch_begin(self, batch, logs=None):
method on_train_batch_end (line 170) | def on_train_batch_end(self, batch, logs=None):
method on_test_batch_begin (line 176) | def on_test_batch_begin(self, batch, logs=None):
method on_test_batch_end (line 184) | def on_test_batch_end(self, batch, logs=None):
method on_predict_batch_begin (line 192) | def on_predict_batch_begin(self, batch, logs=None):
method on_predict_batch_end (line 200) | def on_predict_batch_end(self, batch, logs=None):
class LossAndErrorPrintingCallback (line 238) | class LossAndErrorPrintingCallback(keras.callbacks.Callback):
method on_train_batch_end (line 239) | def on_train_batch_end(self, batch, logs=None):
method on_test_batch_end (line 246) | def on_test_batch_end(self, batch, logs=None):
method on_epoch_end (line 253) | def on_epoch_end(self, epoch, logs=None):
class EarlyStoppingAtMinLoss (line 316) | class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
method __init__ (line 324) | def __init__(self, patience=0):
method on_train_begin (line 330) | def on_train_begin(self, logs=None):
method on_epoch_end (line 338) | def on_epoch_end(self, epoch, logs=None):
method on_train_end (line 353) | def on_train_end(self, logs=None):
class CustomLearningRateScheduler (line 378) | class CustomLearningRateScheduler(keras.callbacks.Callback):
method __init__ (line 387) | def __init__(self, schedule):
method on_epoch_begin (line 391) | def on_epoch_begin(self, epoch, logs=None):
function lr_schedule (line 414) | def lr_schedule(epoch, lr):
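`EarlyStoppingAtMinLoss` above is the canonical "stop when loss stops improving" callback; a minimal sketch (monitoring the training loss, with illustrative patience semantics):

import numpy as np

import keras


class EarlyStoppingAtMinLoss(keras.callbacks.Callback):
    def __init__(self, patience=0):
        super().__init__()
        self.patience = patience

    def on_train_begin(self, logs=None):
        self.wait = 0
        self.best = np.inf

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get("loss")
        if current is None:
            return
        if current < self.best:
            self.best = current
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                # Ends training cleanly after the current epoch.
                self.model.stop_training = True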
FILE: integration_tests/basic_full_flow.py
class MyModel (line 12) | class MyModel(keras.Model):
method __init__ (line 13) | def __init__(self, hidden_dim, output_dim, **kwargs):
method call (line 21) | def call(self, x):
class BasicFlowTest (line 27) | class BasicFlowTest(testing.TestCase):
method test_basic_fit (line 29) | def test_basic_fit(self):
method test_basic_fit_no_training (line 50) | def test_basic_fit_no_training(self):
FILE: integration_tests/dataset_tests/boston_housing_test.py
class BostonHousingTest (line 5) | class BostonHousingTest(testing.TestCase):
method test_load_data (line 6) | def test_load_data(self):
method test_seed_reproducibility (line 11) | def test_seed_reproducibility(self):
method test_invalid_test_split (line 18) | def test_invalid_test_split(self):
FILE: integration_tests/dataset_tests/california_housing_test.py
class CaliforniaHousingTest (line 5) | class CaliforniaHousingTest(testing.TestCase):
method test_load_data_large (line 6) | def test_load_data_large(self):
method test_load_data_small (line 14) | def test_load_data_small(self):
method test_invalid_version (line 22) | def test_invalid_version(self):
method test_seed_reproducibility (line 26) | def test_seed_reproducibility(self):
FILE: integration_tests/dataset_tests/cifar100_test.py
class Cifar100LoadDataTest (line 7) | class Cifar100LoadDataTest(testing.TestCase):
method test_shapes_fine_label_mode (line 8) | def test_shapes_fine_label_mode(self):
method test_shapes_coarse_label_mode (line 17) | def test_shapes_coarse_label_mode(self):
method test_dtypes (line 26) | def test_dtypes(self):
method test_invalid_label_mode (line 33) | def test_invalid_label_mode(self):
FILE: integration_tests/dataset_tests/cifar10_test.py
class Cifar10LoadDataTest (line 7) | class Cifar10LoadDataTest(testing.TestCase):
method test_x_train_shape (line 8) | def test_x_train_shape(self):
method test_y_train_shape (line 12) | def test_y_train_shape(self):
method test_x_test_shape (line 16) | def test_x_test_shape(self):
method test_y_test_shape (line 20) | def test_y_test_shape(self):
method test_x_train_dtype (line 24) | def test_x_train_dtype(self):
method test_y_train_dtype (line 28) | def test_y_train_dtype(self):
method test_x_test_dtype (line 32) | def test_x_test_dtype(self):
method test_y_test_dtype (line 36) | def test_y_test_dtype(self):
FILE: integration_tests/dataset_tests/fashion_mnist_test.py
class FashionMnistLoadDataTest (line 7) | class FashionMnistLoadDataTest(testing.TestCase):
method test_x_train_shape (line 8) | def test_x_train_shape(self):
method test_y_train_shape (line 12) | def test_y_train_shape(self):
method test_x_test_shape (line 16) | def test_x_test_shape(self):
method test_y_test_shape (line 20) | def test_y_test_shape(self):
method test_x_train_dtype (line 24) | def test_x_train_dtype(self):
method test_y_train_dtype (line 28) | def test_y_train_dtype(self):
method test_x_test_dtype (line 32) | def test_x_test_dtype(self):
method test_y_test_dtype (line 36) | def test_y_test_dtype(self):
FILE: integration_tests/dataset_tests/imdb_test.py
class ImdbLoadDataTest (line 7) | class ImdbLoadDataTest(testing.TestCase):
method test_load_data_default (line 8) | def test_load_data_default(self):
method test_num_words (line 25) | def test_num_words(self):
method test_skip_top (line 32) | def test_skip_top(self):
method test_maxlen (line 38) | def test_maxlen(self):
method test_get_word_index (line 43) | def test_get_word_index(self):
FILE: integration_tests/dataset_tests/mnist_test.py
class MnistLoadDataTest (line 7) | class MnistLoadDataTest(testing.TestCase):
method test_x_train_shape (line 8) | def test_x_train_shape(self):
method test_y_train_shape (line 12) | def test_y_train_shape(self):
method test_x_test_shape (line 16) | def test_x_test_shape(self):
method test_y_test_shape (line 20) | def test_y_test_shape(self):
method test_x_train_dtype (line 24) | def test_x_train_dtype(self):
method test_y_train_dtype (line 28) | def test_y_train_dtype(self):
method test_x_test_dtype (line 32) | def test_x_test_dtype(self):
method test_y_test_dtype (line 36) | def test_y_test_dtype(self):
FILE: integration_tests/dataset_tests/reuters_test.py
class ReutersLoadDataTest (line 7) | class ReutersLoadDataTest(testing.TestCase):
method test_load_data_default (line 8) | def test_load_data_default(self):
method test_num_words (line 22) | def test_num_words(self):
method test_skip_top (line 29) | def test_skip_top(self):
method test_maxlen (line 35) | def test_maxlen(self):
method test_get_word_index (line 40) | def test_get_word_index(self):
method test_get_label_names (line 46) | def test_get_label_names(self):
FILE: integration_tests/import_test.py
function setup_package (line 20) | def setup_package():
function create_virtualenv (line 40) | def create_virtualenv():
function manage_venv_installs (line 61) | def manage_venv_installs(whl_path):
function run_keras_flow (line 84) | def run_keras_flow():
function cleanup (line 92) | def cleanup():
function run_commands_local (line 104) | def run_commands_local(commands):
function run_commands_venv (line 110) | def run_commands_venv(commands):
function test_keras_imports (line 123) | def test_keras_imports():
FILE: integration_tests/jax_custom_fit_test.py
function test_custom_fit (line 7) | def test_custom_fit():
FILE: integration_tests/model_visualization_test.py
class SubclassModel (line 9) | class SubclassModel(keras.models.Model):
method __init__ (line 10) | def __init__(self, name):
method call (line 13) | def call(self, x):
function parse_text_from_html (line 17) | def parse_text_from_html(html):
function get_node_text (line 27) | def get_node_text(node):
function get_edge_dict (line 37) | def get_edge_dict(dot):
class ModelVisualizationTest (line 85) | class ModelVisualizationTest(testing.TestCase):
method multi_plot_model (line 86) | def multi_plot_model(self, model, name, expand_nested=False):
method test_plot_sequential_model (line 139) | def test_plot_sequential_model(self):
method test_plot_functional_model (line 157) | def test_plot_functional_model(self):
method test_plot_subclassed_model (line 196) | def test_plot_subclassed_model(self):
method test_plot_nested_functional_model (line 202) | def test_plot_nested_functional_model(self):
method test_plot_functional_model_with_splits_and_merges (line 265) | def test_plot_functional_model_with_splits_and_merges(self):
method test_plot_sequential_in_sequential (line 298) | def test_plot_sequential_in_sequential(self):
method test_plot_functional_in_functional (line 362) | def test_plot_functional_in_functional(self):
method test_plot_sequential_in_sequential_in_sequential (line 425) | def test_plot_sequential_in_sequential_in_sequential(self):
method test_plot_functional_in_sequential_in_sequential (line 510) | def test_plot_functional_in_sequential_in_sequential(self):
method test_plot_functional_in_functional_in_functional (line 592) | def test_plot_functional_in_functional_in_functional(self):
method test_plot_complex (line 673) | def test_plot_complex(self):
FILE: integration_tests/numerical_test.py
function build_mnist_data (line 16) | def build_mnist_data(num_classes):
function build_keras_model (line 33) | def build_keras_model(keras_module, num_classes):
function compile_model (line 56) | def compile_model(model):
function train_model (line 66) | def train_model(model, x, y):
function eval_model (line 77) | def eval_model(model, x, y):
function check_history (line 83) | def check_history(h1, h2):
function predict_model (line 95) | def predict_model(model, x):
function numerical_test (line 99) | def numerical_test():
FILE: integration_tests/pytorch_export_test.py
class TestPyTorchExportWithDynamicShapes (line 31) | class TestPyTorchExportWithDynamicShapes(testing.TestCase):
method test_issue_22102_model_inference (line 39) | def test_issue_22102_model_inference(self, input_shape, expected_shape):
method test_issue_22102_export_methods (line 61) | def test_issue_22102_export_methods(self, export_method):
method test_fixed_layers_export (line 187) | def test_fixed_layers_export(self, layer_type):
FILE: integration_tests/tf_custom_fit_test.py
function test_custom_fit (line 7) | def test_custom_fit():
FILE: integration_tests/tf_distribute_training_test.py
function test_model_fit (line 13) | def test_model_fit():
FILE: integration_tests/torch_custom_fit_test.py
function test_custom_fit (line 7) | def test_custom_fit():
FILE: integration_tests/torch_workflow_test.py
class Net (line 8) | class Net(torch.nn.Module):
method __init__ (line 9) | def __init__(self):
method forward (line 13) | def forward(self, x):
class TorchWorkflowTest (line 18) | class TorchWorkflowTest(testing.TestCase):
method test_keras_layer_in_nn_module (line 19) | def test_keras_layer_in_nn_module(self):
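`Net` above embeds a Keras layer inside a `torch.nn.Module`; a minimal sketch of the same idea (requires `KERAS_BACKEND=torch`; the layer width is arbitrary):

import torch

import keras  # with the torch backend, Keras layers are torch modules


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The Dense layer's weights register as torch parameters once built.
        self.fc = keras.layers.Dense(4)

    def forward(self, x):
        return self.fc(x)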
FILE: keras/src/activations/__init__.py
function serialize (line 77) | def serialize(activation):
function deserialize (line 106) | def deserialize(config, custom_objects=None):
function get (line 116) | def get(identifier):
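`get`/`serialize`/`deserialize` above form the usual Keras registry triple; a short round trip (the identifier "relu" is just an example):

import keras

fn = keras.activations.get("relu")         # resolve from a string identifier
cfg = keras.activations.serialize(fn)      # to a serializable config
same = keras.activations.deserialize(cfg)  # and back to the callable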
FILE: keras/src/activations/activations.py
function relu (line 7) | def relu(x, negative_slope=0.0, max_value=None, threshold=0.0):
class ReLU (line 55) | class ReLU(ops.Operation):
method __init__ (line 56) | def __init__(
method call (line 64) | def call(self, x):
method compute_output_spec (line 72) | def compute_output_spec(self, x):
method static_call (line 76) | def static_call(x, negative_slope=0.0, max_value=None, threshold=0.0):
function leaky_relu (line 114) | def leaky_relu(x, negative_slope=0.2):
function relu6 (line 126) | def relu6(x):
function softmax (line 138) | def softmax(x, axis=-1):
function elu (line 171) | def elu(x, alpha=1.0):
function selu (line 200) | def selu(x):
function softplus (line 240) | def softplus(x):
function softsign (line 252) | def softsign(x):
function soft_shrink (line 264) | def soft_shrink(x, threshold=0.5):
function sparse_plus (line 282) | def sparse_plus(x):
function silu (line 299) | def silu(x):
function squareplus (line 319) | def squareplus(x, b=4):
function gelu (line 340) | def gelu(x, approximate=False):
function celu (line 363) | def celu(x, alpha=1.0):
function glu (line 384) | def glu(x, axis=-1):
function tanh (line 405) | def tanh(x):
function tanh_shrink (line 419) | def tanh_shrink(x):
function hard_tanh (line 433) | def hard_tanh(x):
function hard_shrink (line 448) | def hard_shrink(x, threshold=0.5):
function threshold (line 465) | def threshold(x, threshold, default_value):
function sigmoid (line 483) | def sigmoid(x):
function exponential (line 510) | def exponential(x):
function hard_sigmoid (line 520) | def hard_sigmoid(x):
function log_sigmoid (line 543) | def log_sigmoid(x):
function sparse_sigmoid (line 556) | def sparse_sigmoid(x):
function hard_silu (line 577) | def hard_silu(x):
function linear (line 600) | def linear(x):
class Mish (line 612) | class Mish(ops.Operation):
method call (line 613) | def call(self, x):
method compute_output_spec (line 616) | def compute_output_spec(self, x):
method static_call (line 620) | def static_call(x):
function mish (line 625) | def mish(x):
function log_softmax (line 648) | def log_softmax(x, axis=-1):
function sparsemax (line 663) | def sparsemax(x, axis=-1):
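Per its signature above, `relu` accepts `negative_slope`, `max_value`, and `threshold`; a quick demonstration of all three together (the input values are arbitrary):

import numpy as np

from keras import activations, ops

x = np.array([-10.0, -1.0, 0.0, 2.0, 10.0], dtype="float32")
# Leaky below the threshold, identity above it, clipped at max_value.
y = activations.relu(x, negative_slope=0.1, max_value=6.0, threshold=1.0)
print(ops.convert_to_numpy(y))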
FILE: keras/src/activations/activations_test.py
function _ref_softmax (line 8) | def _ref_softmax(values):
function _ref_softplus (line 14) | def _ref_softplus(x):
function _ref_log_softmax (line 18) | def _ref_log_softmax(values):
function _ref_leaky_relu (line 25) | def _ref_leaky_relu(x, alpha=0.2):
function _ref_relu6 (line 29) | def _ref_relu6(x):
function _ref_silu (line 33) | def _ref_silu(x):
function _ref_hard_sigmoid (line 37) | def _ref_hard_sigmoid(x):
function _ref_sparse_sigmoid (line 43) | def _ref_sparse_sigmoid(x):
function _ref_log_sigmoid (line 47) | def _ref_log_sigmoid(x):
function _ref_hard_silu (line 51) | def _ref_hard_silu(x):
function _ref_sigmoid (line 55) | def _ref_sigmoid(x):
function _ref_softsign (line 63) | def _ref_softsign(x):
class ActivationsTest (line 67) | class ActivationsTest(testing.TestCase):
method test_softmax (line 68) | def test_softmax(self):
method test_softmax_2d_axis_0 (line 75) | def test_softmax_2d_axis_0(self):
method test_softmax_3d_axis_tuple (line 83) | def test_softmax_3d_axis_tuple(self):
method test_softmax_1d (line 91) | def test_softmax_1d(self):
method test_softmax_higher_dim (line 97) | def test_softmax_higher_dim(self):
method test_softmax_higher_dim_multiple_axes (line 106) | def test_softmax_higher_dim_multiple_axes(self):
method test_softmax_negative_axis (line 115) | def test_softmax_negative_axis(self):
method test_temporal_softmax (line 123) | def test_temporal_softmax(self):
method test_log_softmax_2d_axis_0 (line 129) | def test_log_softmax_2d_axis_0(self):
method test_log_softmax_3d_axis_tuple (line 137) | def test_log_softmax_3d_axis_tuple(self):
method test_log_softmax_1d (line 145) | def test_log_softmax_1d(self):
method test_log_softmax_higher_dim (line 151) | def test_log_softmax_higher_dim(self):
method test_log_softmax_higher_dim_multiple_axes (line 160) | def test_log_softmax_higher_dim_multiple_axes(self):
method test_log_softmax_negative_axis (line 169) | def test_log_softmax_negative_axis(self):
method test_temporal_log_softmax (line 177) | def test_temporal_log_softmax(self):
method test_selu (line 183) | def test_selu(self):
method test_softplus (line 196) | def test_softplus(self):
method test_softsign (line 237) | def test_softsign(self):
method test_sigmoid (line 278) | def test_sigmoid(self):
method test_hard_sigmoid (line 319) | def test_hard_sigmoid(self):
method test_sparse_sigmoid (line 348) | def test_sparse_sigmoid(self):
method test_log_sigmoid (line 387) | def test_log_sigmoid(self):
method test_hard_silu (line 426) | def test_hard_silu(self):
method test_relu_negative_slope (line 459) | def test_relu_negative_slope(self):
method test_relu_max_value (line 470) | def test_relu_max_value(self):
method test_relu_threshold (line 479) | def test_relu_threshold(self):
method test_relu_combined_threshold_and_max_value (line 488) | def test_relu_combined_threshold_and_max_value(self):
method test_relu_combined_all_parameters (line 497) | def test_relu_combined_all_parameters(self):
method test_relu_to_trigger_relu6 (line 508) | def test_relu_to_trigger_relu6(self):
method test_relu_to_trigger_leaky (line 514) | def test_relu_to_trigger_leaky(self):
method test_relu (line 520) | def test_relu(self):
method test_leaky_relu (line 563) | def test_leaky_relu(self):
method test_relu6 (line 600) | def test_relu6(self):
method test_silu (line 621) | def test_silu(self):
method test_gelu (line 642) | def test_gelu(self):
method test_celu (line 671) | def test_celu(self):
method test_glu (line 687) | def test_glu(self):
method test_tanh_shrink (line 702) | def test_tanh_shrink(self):
method test_hard_tanh (line 711) | def test_hard_tanh(self):
method test_hard_shrink (line 720) | def test_hard_shrink(self):
method test_threshold (line 729) | def test_threshold(self):
method test_squareplus (line 740) | def test_squareplus(self):
method test_soft_shrink (line 750) | def test_soft_shrink(self):
method test_sparse_plus (line 763) | def test_sparse_plus(self):
method test_elu (line 776) | def test_elu(self):
method test_tanh (line 785) | def test_tanh(self):
method test_exponential (line 834) | def test_exponential(self):
method test_mish (line 882) | def test_mish(self):
method test_linear (line 930) | def test_linear(self):
method test_sparsemax (line 953) | def test_sparsemax(self):
method test_get_method (line 1002) | def test_get_method(self):
FILE: keras/src/api_export.py
function register_internal_serializable (line 13) | def register_internal_serializable(path, symbol):
function get_symbol_from_name (line 23) | def get_symbol_from_name(name):
function get_name_from_symbol (line 27) | def get_name_from_symbol(symbol):
class keras_export (line 33) | class keras_export(namex.export):
method __init__ (line 34) | def __init__(self, path):
method __call__ (line 37) | def __call__(self, symbol):
class keras_export (line 43) | class keras_export:
method __init__ (line 44) | def __init__(self, path):
method __call__ (line 47) | def __call__(self, symbol):
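
A hedged sketch of how keras_export is used throughout the codebase: the decorator registers a symbol under one or more public `keras.*` paths (the path below is hypothetical, for illustration only; the import is from a private module and subject to change).

    from keras.src.api_export import get_symbol_from_name, keras_export

    @keras_export("keras.utils.my_helper")  # hypothetical public path
    def my_helper():
        return 42

    # Registered symbols can typically be looked up again by public name:
    print(get_symbol_from_name("keras.utils.my_helper") is my_helper)  # True
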
FILE: keras/src/applications/applications_test.py
function _get_elephant (line 100) | def _get_elephant(target_size):
class ApplicationsTest (line 123) | class ApplicationsTest(testing.TestCase):
method setUpClass (line 125) | def setUpClass(cls):
method tearDownClass (line 129) | def tearDownClass(cls):
method skip_if_invalid_image_data_format_for_model (line 132) | def skip_if_invalid_image_data_format_for_model(
method test_application_notop_variable_input_channels (line 150) | def test_application_notop_variable_input_channels(
method test_application_base (line 183) | def test_application_base(self, app, _, app_module, image_data_format):
method test_application_notop_custom_input_shape (line 224) | def test_application_notop_custom_input_shape(
method test_application_notop_custom_input_tensor (line 245) | def test_application_notop_custom_input_tensor(
method test_application_pooling (line 271) | def test_application_pooling(self, app, last_dim, _, image_data_format):
method test_application_classifier_activation (line 284) | def test_application_classifier_activation(self, app, *_):
FILE: keras/src/applications/convnext.py
class StochasticDepth (line 141) | class StochasticDepth(Layer):
method __init__ (line 159) | def __init__(self, drop_path_rate, **kwargs):
method call (line 163) | def call(self, x, training=None):
method get_config (line 172) | def get_config(self):
class LayerScale (line 178) | class LayerScale(Layer):
method __init__ (line 194) | def __init__(self, init_values, projection_dim, **kwargs):
method build (line 199) | def build(self, _):
method call (line 206) | def call(self, x):
method get_config (line 209) | def get_config(self):
function ConvNeXtBlock (line 220) | def ConvNeXtBlock(
function PreStem (line 282) | def PreStem(name=None):
function Head (line 302) | def Head(num_classes=1000, classifier_activation=None, name=None):
function ConvNeXt (line 331) | def ConvNeXt(
function ConvNeXtTiny (line 581) | def ConvNeXtTiny(
function ConvNeXtSmall (line 617) | def ConvNeXtSmall(
function ConvNeXtBase (line 653) | def ConvNeXtBase(
function ConvNeXtLarge (line 689) | def ConvNeXtLarge(
function ConvNeXtXLarge (line 725) | def ConvNeXtXLarge(
function preprocess_input (line 763) | def preprocess_input(x, data_format=None):
function decode_predictions (line 785) | def decode_predictions(preds, top=5):
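
A hedged sketch of the usual keras.applications workflow for the ConvNeXt variants listed above ("elephant.jpg" is a placeholder path; ImageNet weights download on first use). Note that for ConvNeXt, preprocess_input is a pass-through because input normalization is built into the model itself:

    import keras

    model = keras.applications.ConvNeXtTiny(weights="imagenet")
    img = keras.utils.load_img("elephant.jpg", target_size=(224, 224))
    x = keras.utils.img_to_array(img)[None, ...]  # add a batch dimension
    preds = model.predict(x)
    print(keras.applications.convnext.decode_predictions(preds, top=5))
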
FILE: keras/src/applications/densenet.py
function dense_block (line 35) | def dense_block(x, blocks, name):
function transition_block (line 51) | def transition_block(x, reduction, name):
function conv_block (line 77) | def conv_block(x, growth_rate, name):
function DenseNet (line 107) | def DenseNet(
function DenseNet121 (line 331) | def DenseNet121(
function DenseNet169 (line 361) | def DenseNet169(
function DenseNet201 (line 391) | def DenseNet201(
function preprocess_input (line 416) | def preprocess_input(x, data_format=None):
function decode_predictions (line 423) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/efficientnet.py
function EfficientNet (line 208) | def EfficientNet(
function block (line 444) | def block(
function EfficientNetB0 (line 565) | def EfficientNetB0(
function EfficientNetB1 (line 598) | def EfficientNetB1(
function EfficientNetB2 (line 631) | def EfficientNetB2(
function EfficientNetB3 (line 664) | def EfficientNetB3(
function EfficientNetB4 (line 697) | def EfficientNetB4(
function EfficientNetB5 (line 730) | def EfficientNetB5(
function EfficientNetB6 (line 763) | def EfficientNetB6(
function EfficientNetB7 (line 796) | def EfficientNetB7(
function preprocess_input (line 834) | def preprocess_input(x, data_format=None):
function decode_predictions (line 856) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/efficientnet_v2.py
function round_filters (line 589) | def round_filters(filters, width_coefficient, min_depth, depth_divisor):
function round_repeats (line 600) | def round_repeats(repeats, depth_coefficient):
function MBConvBlock (line 605) | def MBConvBlock(
function FusedMBConvBlock (line 721) | def FusedMBConvBlock(
function EfficientNetV2 (line 823) | def EfficientNetV2(
function EfficientNetV2B0 (line 1098) | def EfficientNetV2B0(
function EfficientNetV2B1 (line 1132) | def EfficientNetV2B1(
function EfficientNetV2B2 (line 1166) | def EfficientNetV2B2(
function EfficientNetV2B3 (line 1200) | def EfficientNetV2B3(
function EfficientNetV2S (line 1234) | def EfficientNetV2S(
function EfficientNetV2M (line 1268) | def EfficientNetV2M(
function EfficientNetV2L (line 1302) | def EfficientNetV2L(
function preprocess_input (line 1340) | def preprocess_input(x, data_format=None):
function decode_predictions (line 1362) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/imagenet_utils.py
function preprocess_input (line 87) | def preprocess_input(x, data_format=None, mode="caffe"):
function decode_predictions (line 117) | def decode_predictions(preds, top=5):
function _preprocess_numpy_input (line 161) | def _preprocess_numpy_input(x, data_format, mode):
function _preprocess_tensor_input (line 234) | def _preprocess_tensor_input(x, data_format, mode):
function obtain_input_shape (line 296) | def obtain_input_shape(
function correct_pad (line 416) | def correct_pad(inputs, kernel_size):
function validate_activation (line 441) | def validate_activation(classifier_activation, weights):
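
The shared preprocess_input above dispatches on its mode argument ("caffe" is the default, per the signature). A small sketch of the three modes:

    import numpy as np
    from keras.applications.imagenet_utils import preprocess_input

    x = np.random.uniform(0, 255, size=(1, 224, 224, 3)).astype("float32")
    caffe = preprocess_input(x.copy(), mode="caffe")   # RGB->BGR, ImageNet mean subtracted
    tf_ = preprocess_input(x.copy(), mode="tf")        # scaled to the [-1, 1] range
    torch_ = preprocess_input(x.copy(), mode="torch")  # scaled to [0, 1], then normalized
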
FILE: keras/src/applications/imagenet_utils_test.py
class TestImageNetUtils (line 12) | class TestImageNetUtils(testing.TestCase):
method test_preprocess_input (line 13) | def test_preprocess_input(self):
method test_preprocess_input_symbolic (line 80) | def test_preprocess_input_symbolic(self, mode):
method test_preprocess_input_symbolic_mixed_precision (line 150) | def test_preprocess_input_symbolic_mixed_precision(self, mode):
method test_obtain_input_shape (line 174) | def test_obtain_input_shape(self, data_format):
FILE: keras/src/applications/inception_resnet_v2.py
function InceptionResNetV2 (line 22) | def InceptionResNetV2(
function conv2d_bn (line 248) | def conv2d_bn(
class CustomScaleLayer (line 294) | class CustomScaleLayer(Layer):
method __init__ (line 295) | def __init__(self, scale, **kwargs):
method get_config (line 299) | def get_config(self):
method call (line 304) | def call(self, inputs):
function inception_resnet_block (line 308) | def inception_resnet_block(x, scale, block_type, block_idx, activation="...
function preprocess_input (line 380) | def preprocess_input(x, data_format=None):
function decode_predictions (line 387) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/inception_v3.py
function InceptionV3 (line 25) | def InceptionV3(
function conv2d_bn (line 383) | def conv2d_bn(
function preprocess_input (line 426) | def preprocess_input(x, data_format=None):
function decode_predictions (line 433) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/mobilenet.py
function MobileNet (line 22) | def MobileNet(
function _conv_block (line 276) | def _conv_block(inputs, filters, alpha, kernel=(3, 3), strides=(1, 1)):
function _depthwise_conv_block (line 332) | def _depthwise_conv_block(
function preprocess_input (line 419) | def preprocess_input(x, data_format=None):
function decode_predictions (line 426) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/mobilenet_v2.py
function MobileNetV2 (line 22) | def MobileNetV2(
function _inverted_res_block (line 395) | def _inverted_res_block(inputs, expansion, stride, alpha, filters, block...
function _make_divisible (line 470) | def _make_divisible(v, divisor, min_value=None):
function preprocess_input (line 481) | def preprocess_input(x, data_format=None):
function decode_predictions (line 488) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/mobilenet_v3.py
function MobileNetV3 (line 153) | def MobileNetV3(
function MobileNetV3Small (line 407) | def MobileNetV3Small(
function MobileNetV3Large (line 474) | def MobileNetV3Large(
function relu (line 542) | def relu(x):
function hard_sigmoid (line 546) | def hard_sigmoid(x):
function hard_swish (line 550) | def hard_swish(x):
function _depth (line 561) | def _depth(v, divisor=8, min_value=None):
function _se_block (line 571) | def _se_block(inputs, filters, se_ratio, prefix):
function _inverted_res_block (line 593) | def _inverted_res_block(
function preprocess_input (line 661) | def preprocess_input(x, data_format=None):
function decode_predictions (line 684) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/nasnet.py
function NASNet (line 20) | def NASNet(
function NASNetMobile (line 320) | def NASNetMobile(
function NASNetLarge (line 416) | def NASNetLarge(
function _separable_conv_block (line 498) | def _separable_conv_block(
function _adjust_block (line 556) | def _adjust_block(p, ip, filters, block_id=None):
function _normal_a_cell (line 643) | def _normal_a_cell(ip, p, filters, block_id=None):
function _reduction_a_cell (line 736) | def _reduction_a_cell(ip, p, filters, block_id=None):
function preprocess_input (line 853) | def preprocess_input(x, data_format=None):
function decode_predictions (line 860) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/resnet.py
function ResNet (line 48) | def ResNet(
function residual_block_v1 (line 219) | def residual_block_v1(
function stack_residual_blocks_v1 (line 276) | def stack_residual_blocks_v1(x, filters, blocks, stride1=2, name=None):
function residual_block_v2 (line 298) | def residual_block_v2(
function stack_residual_blocks_v2 (line 361) | def stack_residual_blocks_v2(x, filters, blocks, stride1=2, name=None):
function ResNet50 (line 391) | def ResNet50(
function ResNet101 (line 431) | def ResNet101(
function ResNet152 (line 471) | def ResNet152(
function preprocess_input (line 511) | def preprocess_input(x, data_format=None):
function decode_predictions (line 523) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/resnet_v2.py
function ResNet50V2 (line 12) | def ResNet50V2(
function ResNet101V2 (line 54) | def ResNet101V2(
function ResNet152V2 (line 96) | def ResNet152V2(
function preprocess_input (line 133) | def preprocess_input(x, data_format=None):
function decode_predictions (line 140) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/vgg16.py
function VGG16 (line 21) | def VGG16(
function preprocess_input (line 232) | def preprocess_input(x, data_format=None):
function decode_predictions (line 239) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/vgg19.py
function VGG19 (line 21) | def VGG19(
function preprocess_input (line 240) | def preprocess_input(x, data_format=None):
function decode_predictions (line 247) | def decode_predictions(preds, top=5):
FILE: keras/src/applications/xception.py
function Xception (line 25) | def Xception(
function preprocess_input (line 339) | def preprocess_input(x, data_format=None):
function decode_predictions (line 346) | def decode_predictions(preds, top=5):
FILE: keras/src/backend/__init__.py
class Variable (line 63) | class Variable(BackendVariable): # noqa: F811
class name_scope (line 71) | class name_scope(backend_name_scope):
function device (line 76) | def device(device_name):
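
The device() helper above backs the public keras.device scope. A minimal sketch (device strings are backend-specific; "cpu" is broadly safe):

    import keras
    from keras import ops

    with keras.device("cpu"):
        x = ops.ones((2, 2))  # placed on the CPU regardless of the default device
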
FILE: keras/src/backend/common/backend_utils.py
function _convert_conv_transpose_padding_args_from_keras_to_jax (line 8) | def _convert_conv_transpose_padding_args_from_keras_to_jax(
function _convert_conv_transpose_padding_args_from_keras_to_torch (line 53) | def _convert_conv_transpose_padding_args_from_keras_to_torch(
function compute_conv_transpose_padding_args_for_jax (line 119) | def compute_conv_transpose_padding_args_for_jax(
function compute_conv_transpose_padding_args_for_torch (line 158) | def compute_conv_transpose_padding_args_for_torch(
function _get_output_shape_given_tf_padding (line 215) | def _get_output_shape_given_tf_padding(
function compute_conv_transpose_output_shape (line 244) | def compute_conv_transpose_output_shape(
function canonicalize_axis (line 292) | def canonicalize_axis(axis, num_dims):
function standardize_axis_for_numpy (line 305) | def standardize_axis_for_numpy(axis):
function to_tuple_or_list (line 310) | def to_tuple_or_list(value):
function _vectorize_parse_gufunc_signature (line 334) | def _vectorize_parse_gufunc_signature(
function _vectorize_update_dim_sizes (line 349) | def _vectorize_update_dim_sizes(dim_sizes, shape, core_dims, is_input=Tr...
function _vectorize_parse_input_dimensions (line 376) | def _vectorize_parse_input_dimensions(
function _vectorize_check_output_dims (line 401) | def _vectorize_check_output_dims(
function _vectorize_apply_excluded (line 439) | def _vectorize_apply_excluded(func, excluded, args, kwargs):
function vectorize_impl (line 463) | def vectorize_impl(pyfunc, vmap_fn, *, excluded=None, signature=None):
function slice_along_axis (line 545) | def slice_along_axis(x, start=0, stop=None, step=1, axis=0):
function compute_adaptive_pooling_window_sizes (line 557) | def compute_adaptive_pooling_window_sizes(input_dim, output_dim):
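
A small sketch of canonicalize_axis from the helpers above, which resolves negative axis indices against a tensor rank (imported from a private module path, so subject to change):

    from keras.src.backend.common.backend_utils import canonicalize_axis

    print(canonicalize_axis(-1, num_dims=4))  # -> 3
    print(canonicalize_axis(2, num_dims=4))   # -> 2
    # An out-of-range axis such as canonicalize_axis(5, num_dims=4) is rejected.
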
FILE: keras/src/backend/common/backend_utils_test.py
class ConvertConvTransposePaddingArgsJAXTest (line 19) | class ConvertConvTransposePaddingArgsJAXTest(test_case.TestCase):
method test_valid_padding_without_output_padding (line 20) | def test_valid_padding_without_output_padding(self):
method test_same_padding_without_output_padding (line 35) | def test_same_padding_without_output_padding(self):
class ConvertConvTransposePaddingArgsTorchTest (line 51) | class ConvertConvTransposePaddingArgsTorchTest(test_case.TestCase):
method test_valid_padding_without_output_padding (line 52) | def test_valid_padding_without_output_padding(self):
method test_same_padding_without_output_padding (line 67) | def test_same_padding_without_output_padding(self):
class ComputeConvTransposePaddingArgsForJAXTest (line 83) | class ComputeConvTransposePaddingArgsForJAXTest(test_case.TestCase):
method test_valid_padding_without_output_padding (line 84) | def test_valid_padding_without_output_padding(self):
method test_same_padding_without_output_padding (line 96) | def test_same_padding_without_output_padding(self):
class ComputeConvTransposePaddingArgsForTorchTest (line 110) | class ComputeConvTransposePaddingArgsForTorchTest(test_case.TestCase):
method test_valid_padding_without_output_padding (line 111) | def test_valid_padding_without_output_padding(self):
method test_same_padding_without_output_padding (line 127) | def test_same_padding_without_output_padding(self):
method test_valid_padding_with_none_output_padding (line 143) | def test_valid_padding_with_none_output_padding(self):
method test_valid_padding_with_output_padding (line 158) | def test_valid_padding_with_output_padding(self):
method test_output_padding_clamped_for_torch_constraint (line 173) | def test_output_padding_clamped_for_torch_constraint(self):
class GetOutputShapeGivenTFPaddingTest (line 193) | class GetOutputShapeGivenTFPaddingTest(test_case.TestCase):
method test_valid_padding_without_output_padding (line 194) | def test_valid_padding_without_output_padding(self):
method test_same_padding_without_output_padding (line 206) | def test_same_padding_without_output_padding(self):
method test_valid_padding_with_output_padding (line 218) | def test_valid_padding_with_output_padding(self):
method test_warning_for_inconsistencies (line 230) | def test_warning_for_inconsistencies(self):
method test_same_padding_without_output_padding_for_torch_ (line 241) | def test_same_padding_without_output_padding_for_torch_(self):
FILE: keras/src/backend/common/compute_output_spec_test.py
function example_fn (line 7) | def example_fn(x):
class ComputeOutputSpecTest (line 13) | class ComputeOutputSpecTest(testing.TestCase):
method test_basics (line 14) | def test_basics(self):
method test_torch_meta_device_incompatible_ops (line 36) | def test_torch_meta_device_incompatible_ops(self):
FILE: keras/src/backend/common/dtypes.py
function _type_promotion_lattice (line 61) | def _type_promotion_lattice():
function _make_lattice_upper_bounds (line 93) | def _make_lattice_upper_bounds():
function _least_upper_bound (line 115) | def _least_upper_bound(*nodes):
function _dtype_and_weaktype (line 175) | def _dtype_and_weaktype(value):
function _respect_weak_type (line 187) | def _respect_weak_type(dtype, weak_type):
function _resolve_weak_type (line 207) | def _resolve_weak_type(dtype, precision="32"):
function _lattice_result_type (line 245) | def _lattice_result_type(*args):
function result_type (line 281) | def result_type(*dtypes):
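
result_type implements JAX-style type promotion over the lattice built above; it is also exposed publicly as keras.backend.result_type. A minimal sketch:

    from keras.src.backend.common.dtypes import result_type

    print(result_type("int8", "int32"))       # -> "int32"
    print(result_type("float32", "float64"))  # -> "float64"
    print(result_type("bool", "int8"))        # -> "int8"
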
FILE: keras/src/backend/common/dtypes_test.py
class DtypesTest (line 13) | class DtypesTest(test_case.TestCase):
method test_result_type_with_python_scalar_types (line 44) | def test_result_type_with_python_scalar_types(self, dtype1, dtype2):
method test_result_type_with_tensor (line 54) | def test_result_type_with_tensor(self, dtype1, dtype2):
method test_result_type_with_int64 (line 82) | def test_result_type_with_int64(self, dtype):
method test_result_type_with_float64 (line 108) | def test_result_type_with_float64(self, dtype):
method test_result_type_with_none (line 116) | def test_result_type_with_none(self):
method test_result_type_empty_list (line 121) | def test_result_type_empty_list(self):
method test_respect_weak_type_for_bool (line 124) | def test_respect_weak_type_for_bool(self):
method test_respect_weak_type_for_int (line 127) | def test_respect_weak_type_for_int(self):
method test_respect_weak_type_for_float (line 130) | def test_respect_weak_type_for_float(self):
method test_resolve_weak_type_for_bfloat16 (line 133) | def test_resolve_weak_type_for_bfloat16(self):
method test_resolve_weak_type_for_bfloat16_with_precision (line 136) | def test_resolve_weak_type_for_bfloat16_with_precision(self):
method test_respect_weak_type_for_complex64 (line 141) | def test_respect_weak_type_for_complex64(self):
method test_respect_weak_type_for_complex128 (line 146) | def test_respect_weak_type_for_complex128(self):
method test_invalid_dtype_for_keras_promotion (line 151) | def test_invalid_dtype_for_keras_promotion(self):
method test_resolve_weak_type_for_invalid_dtype (line 157) | def test_resolve_weak_type_for_invalid_dtype(self):
method test_resolve_weak_type_for_invalid_precision (line 163) | def test_resolve_weak_type_for_invalid_precision(self):
method test_cycle_detection_in_make_lattice_upper_bounds (line 170) | def test_cycle_detection_in_make_lattice_upper_bounds(self):
method test_respect_weak_type_for_invalid_dtype (line 188) | def test_respect_weak_type_for_invalid_dtype(self):
method test_invalid_dtype_in_least_upper_bound (line 194) | def test_invalid_dtype_in_least_upper_bound(self):
method test_empty_lub_in_least_upper_bound (line 201) | def test_empty_lub_in_least_upper_bound(self):
method test_valid_dtype_leading_to_single_lub_element (line 214) | def test_valid_dtype_leading_to_single_lub_element(self):
method test_valid_dtype_leading_to_keyerror_and_valueerror (line 219) | def test_valid_dtype_leading_to_keyerror_and_valueerror(self):
method test_resolve_weak_type_bool (line 226) | def test_resolve_weak_type_bool(self):
method test_resolve_weak_type_int (line 229) | def test_resolve_weak_type_int(self):
method test_resolve_weak_type_uint (line 237) | def test_resolve_weak_type_uint(self):
method test_resolve_weak_type_float (line 245) | def test_resolve_weak_type_float(self):
method test_least_upper_bound_ensure_order_independence (line 253) | def test_least_upper_bound_ensure_order_independence(self):
method test_least_upper_bound_single_element (line 259) | def test_least_upper_bound_single_element(self):
method test_least_upper_bound_no_element (line 263) | def test_least_upper_bound_no_element(self):
method test_least_upper_bound_with_no_common_upper_bound (line 270) | def test_least_upper_bound_with_no_common_upper_bound(self):
method test_invalid_float8_dtype (line 281) | def test_invalid_float8_dtype(self):
FILE: keras/src/backend/common/global_state.py
function set_global_attribute (line 11) | def set_global_attribute(name, value):
function get_global_attribute (line 15) | def get_global_attribute(name, default=None, set_to_default=False):
function clear_session (line 25) | def clear_session(free_memory=True):
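
clear_session() above resets Keras's per-thread global state (auto-naming counters, cached scopes, etc.) and is public as keras.utils.clear_session. A minimal sketch:

    import keras

    keras.layers.Dense(4)        # auto-named "dense", then "dense_1", ...
    keras.utils.clear_session()  # reset naming counters and cached state
    keras.layers.Dense(4)        # naming restarts at "dense"
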
FILE: keras/src/backend/common/global_state_test.py
class GlobalStateTest (line 6) | class GlobalStateTest(test_case.TestCase):
method test_clear_session (line 7) | def test_clear_session(self):
FILE: keras/src/backend/common/keras_tensor.py
class KerasTensor (line 7) | class KerasTensor:
method __init__ (line 30) | def __init__(
method shape (line 70) | def shape(self):
method shape (line 74) | def shape(self, value):
method dtype (line 81) | def dtype(self):
method dtype (line 85) | def dtype(self, value):
method sparse (line 92) | def sparse(self):
method sparse (line 96) | def sparse(self, value):
method ragged_rank (line 103) | def ragged_rank(self):
method ragged_rank (line 107) | def ragged_rank(self, value):
method row_splits_dtype (line 114) | def row_splits_dtype(self):
method row_splits_dtype (line 118) | def row_splits_dtype(self, value):
method ragged (line 125) | def ragged(self):
method ragged (line 129) | def ragged(self, value):
method ndim (line 136) | def ndim(self):
method reshape (line 139) | def reshape(self, newshape):
method squeeze (line 144) | def squeeze(self, axis=None):
method __int__ (line 149) | def __int__(self):
method __float__ (line 156) | def __float__(self):
method __array__ (line 163) | def __array__(self):
method __jax_array__ (line 170) | def __jax_array__(self):
method __tf_tensor__ (line 193) | def __tf_tensor__(self, dtype=None, name=None):
method __repr__ (line 216) | def __repr__(self):
method __iter__ (line 222) | def __iter__(self):
method __bool__ (line 227) | def __bool__(self):
method __add__ (line 230) | def __add__(self, other):
method __radd__ (line 235) | def __radd__(self, other):
method __sub__ (line 240) | def __sub__(self, other):
method __rsub__ (line 245) | def __rsub__(self, other):
method __mul__ (line 250) | def __mul__(self, other):
method __rmul__ (line 255) | def __rmul__(self, other):
method __matmul__ (line 260) | def __matmul__(self, other):
method __rmatmul__ (line 265) | def __rmatmul__(self, other):
method __div__ (line 270) | def __div__(self, other):
method __rdiv__ (line 275) | def __rdiv__(self, other):
method __truediv__ (line 280) | def __truediv__(self, other):
method __rtruediv__ (line 285) | def __rtruediv__(self, other):
method __neg__ (line 290) | def __neg__(self):
method __abs__ (line 295) | def __abs__(self):
method __pow__ (line 300) | def __pow__(self, other):
method __rpow__ (line 305) | def __rpow__(self, other):
method __floordiv__ (line 310) | def __floordiv__(self, other):
method __rfloordiv__ (line 315) | def __rfloordiv__(self, other):
method __mod__ (line 320) | def __mod__(self, other):
method __rmod__ (line 325) | def __rmod__(self, other):
method __lt__ (line 330) | def __lt__(self, other):
method __le__ (line 335) | def __le__(self, other):
method __gt__ (line 340) | def __gt__(self, other):
method __ge__ (line 345) | def __ge__(self, other):
method __ne__ (line 350) | def __ne__(self, other):
method __and__ (line 355) | def __and__(self, other):
method __rand__ (line 360) | def __rand__(self, other):
method __or__ (line 365) | def __or__(self, other):
method __ror__ (line 370) | def __ror__(self, other):
method __invert__ (line 375) | def __invert__(self):
method __xor__ (line 380) | def __xor__(self, other):
method __rxor__ (line 385) | def __rxor__(self, other):
method __getitem__ (line 390) | def __getitem__(self, key):
method __round__ (line 395) | def __round__(self, ndigits=None):
function any_symbolic_tensors (line 402) | def any_symbolic_tensors(args=None, kwargs=None):
function is_keras_tensor (line 412) | def is_keras_tensor(x):
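
KerasTensor is the symbolic placeholder returned by keras.Input; the operator overloads listed above dispatch to keras.ops, so arithmetic builds graph nodes rather than eager values. A minimal sketch:

    import keras

    x = keras.Input(shape=(4,))   # a KerasTensor; no concrete data attached
    y = x * 2.0 + 1.0             # __mul__ / __add__ route through keras.ops
    print(y.shape)                # (None, 4); the batch dimension is unknown
    print(isinstance(y, keras.KerasTensor))  # True
    # bool(y), int(y), float(y) raise, as the dunder listing above suggests.
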
FILE: keras/src/backend/common/keras_tensor_test.py
class KerasTensorTest (line 13) | class KerasTensorTest(testing.TestCase):
method test_attributes (line 14) | def test_attributes(self):
method test_attributes_sparse (line 30) | def test_attributes_sparse(self):
method test_attributes_ragged (line 41) | def test_attributes_ragged(self):
method test_init_sparse_ragged_raises (line 52) | def test_init_sparse_ragged_raises(self):
method test_numpy_methods (line 58) | def test_numpy_methods(self):
method test_invalid_usage (line 75) | def test_invalid_usage(self):
method test_bool (line 95) | def test_bool(self):
method test_representation (line 100) | def test_representation(self):
method test_iterating (line 104) | def test_iterating(self):
method test_any_symbolic_tensors (line 109) | def test_any_symbolic_tensors(self):
method test_is_keras_tensor (line 115) | def test_is_keras_tensor(self):
method test_abs_method (line 122) | def test_abs_method(self, mock_symbolic_call):
method test_neg_method (line 131) | def test_neg_method(self, mock_method):
method test_sub_method (line 135) | def test_sub_method(self, mock_method):
method test_mul_method (line 140) | def test_mul_method(self, mock_method):
method test_matmul_method (line 145) | def test_matmul_method(self, mock_method):
method test_pow_method (line 150) | def test_pow_method(self, mock_method):
method test_mod_method (line 155) | def test_mod_method(self, mock_method):
method test_lt_method (line 160) | def test_lt_method(self, mock_method):
method test_and_method (line 165) | def test_and_method(self, mock_method):
method test_or_method (line 170) | def test_or_method(self, mock_method):
method test_getitem_method (line 175) | def test_getitem_method(self, mock_method):
method _test_unary_op_method (line 179) | def _test_unary_op_method(self, mock_method, operator):
method _test_binary_op_method (line 187) | def _test_binary_op_method(self, mock_method, other, operator):
method test_radd_method (line 196) | def test_radd_method(self, mock_symbolic_call):
method test_rsub_method (line 207) | def test_rsub_method(self, mock_symbolic_call):
method test_rmul_method (line 218) | def test_rmul_method(self, mock_symbolic_call):
method test_rmatmul_method (line 229) | def test_rmatmul_method(self, mock_symbolic_call):
method test_rpow_method (line 240) | def test_rpow_method(self, mock_symbolic_call):
method test_floordiv_method (line 251) | def test_floordiv_method(self, mock_symbolic_call):
method test_rfloordiv_method (line 262) | def test_rfloordiv_method(self, mock_symbolic_call):
method test_rmod_method (line 273) | def test_rmod_method(self, mock_symbolic_call):
method test_le_method (line 284) | def test_le_method(self, mock_symbolic_call):
method test_gt_method (line 295) | def test_gt_method(self, mock_symbolic_call):
method test_ge_method (line 306) | def test_ge_method(self, mock_symbolic_call):
method test_ne_method (line 317) | def test_ne_method(self, mock_symbolic_call):
method test_rand_method (line 328) | def test_rand_method(self, mock_symbolic_call):
method test_ror_method (line 339) | def test_ror_method(self, mock_symbolic_call):
method test_invert_method (line 350) | def test_invert_method(self, mock_symbolic_call):
method test_xor_method (line 360) | def test_xor_method(self, mock_symbolic_call):
method test_rxor_method (line 371) | def test_rxor_method(self, mock_symbolic_call):
method test_truediv_method (line 382) | def test_truediv_method(self, mock_symbolic_call):
method test_rtruediv_method (line 393) | def test_rtruediv_method(self, mock_symbolic_call):
method test_div_method (line 404) | def test_div_method(self, mock_symbolic_call):
method test_rdiv_method (line 416) | def test_rdiv_method(self, mock_symbolic_call):
FILE: keras/src/backend/common/masking.py
function set_keras_mask (line 5) | def set_keras_mask(x, mask):
function get_keras_mask (line 16) | def get_keras_mask(x):
FILE: keras/src/backend/common/masking_test.py
class MaskingTest (line 8) | class MaskingTest(testing.TestCase):
method test_mask_on_eager_tensor (line 9) | def test_mask_on_eager_tensor(self):
method test_mask_on_tracer_tensor (line 26) | def test_mask_on_tracer_tensor(self):
FILE: keras/src/backend/common/name_scope.py
class name_scope (line 4) | class name_scope:
method __init__ (line 18) | def __init__(
method __enter__ (line 39) | def __enter__(self):
method __exit__ (line 56) | def __exit__(self, *args, **kwargs):
function current_path (line 65) | def current_path():
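
name_scope is a context manager that prefixes variable paths, and current_path() joins the active scope stack with "/". A minimal sketch (keras.name_scope is the public alias; current_path comes from a private module):

    import keras
    from keras.src.backend.common.name_scope import current_path

    with keras.name_scope("outer"):
        with keras.name_scope("inner"):
            print(current_path())  # -> "outer/inner"
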
FILE: keras/src/backend/common/name_scope_test.py
class NameScopeTest (line 9) | class NameScopeTest(testing.TestCase):
method test_stacking (line 10) | def test_stacking(self):
method test_deduplication (line 25) | def test_deduplication(self):
method test_errors (line 35) | def test_errors(self):
method test_override_parent (line 41) | def test_override_parent(self):
method test_exit_with_none_stack (line 53) | def test_exit_with_none_stack(self):
method test_exit_with_empty_stack (line 70) | def test_exit_with_empty_stack(self):
method test_multithreaded_name_scope (line 90) | def test_multithreaded_name_scope(self):
method test_exit_without_pop_on_exit (line 116) | def test_exit_without_pop_on_exit(self):
FILE: keras/src/backend/common/remat.py
class RematScope (line 9) | class RematScope:
method __init__ (line 78) | def __init__(
method __enter__ (line 98) | def __enter__(self):
method __exit__ (line 106) | def __exit__(self, *args, **kwargs):
function get_current_remat_mode (line 119) | def get_current_remat_mode():
function remat (line 138) | def remat(f):
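
A hedged sketch of the rematerialization scope above (public as keras.RematScope): layers created inside a "full" scope recompute their activations during the backward pass, trading compute for memory.

    import keras

    with keras.RematScope(mode="full"):
        # Layers created here pick up the active remat mode.
        dense = keras.layers.Dense(64, activation="relu")
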
FILE: keras/src/backend/common/remat_test.py
class TestRematScope (line 14) | class TestRematScope(testing.TestCase):
method test_remat_scope_activation (line 15) | def test_remat_scope_activation(self):
method test_remat_scope_nested (line 29) | def test_remat_scope_nested(self):
method test_remat_scope_stack_management (line 49) | def test_remat_scope_stack_management(self):
method test_invalid_mode (line 75) | def test_invalid_mode(self):
class RematTest (line 85) | class RematTest(testing.TestCase):
method test_remat_basic_call (line 86) | def test_remat_basic_call(self):
method test_remat_with_kwargs (line 117) | def test_remat_with_kwargs(self):
FILE: keras/src/backend/common/stateless_scope.py
class StatelessScope (line 6) | class StatelessScope:
method __init__ (line 35) | def __init__(
method __enter__ (line 70) | def __enter__(self):
method add_loss (line 75) | def add_loss(self, loss):
method add_update (line 78) | def add_update(self, update):
method get_current_value (line 82) | def get_current_value(self, variable):
method __exit__ (line 85) | def __exit__(self, *args, **kwargs):
function in_stateless_scope (line 100) | def in_stateless_scope():
function get_stateless_scope (line 104) | def get_stateless_scope():
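
StatelessScope (public as keras.StatelessScope) evaluates Keras code against substituted variable values without mutating the variables themselves, which is the basis of the functional JAX train step. A minimal sketch:

    import numpy as np
    import keras

    v = keras.Variable(np.zeros((2,)))
    with keras.StatelessScope(state_mapping=[(v, np.ones((2,)))]) as scope:
        v.assign_add(np.ones((2,)))        # recorded in the scope, not applied
        print(scope.get_current_value(v))  # [2. 2.]
    print(v.numpy())                       # still [0. 0.] outside the scope
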
FILE: keras/src/backend/common/stateless_scope_test.py
class TestStatelessScope (line 9) | class TestStatelessScope(testing.TestCase):
method test_basic_flow (line 10) | def test_basic_flow(self):
method test_invalid_key_in_state_mapping (line 38) | def test_invalid_key_in_state_mapping(self):
method test_invalid_value_shape_in_state_mapping (line 48) | def test_invalid_value_shape_in_state_mapping(self):
FILE: keras/src/backend/common/symbolic_scope.py
class SymbolicScope (line 6) | class SymbolicScope:
method __enter__ (line 9) | def __enter__(self):
method __exit__ (line 14) | def __exit__(self, *args, **kwargs):
function in_symbolic_scope (line 18) | def in_symbolic_scope():
function get_symbolic_scope (line 22) | def get_symbolic_scope():
FILE: keras/src/backend/common/symbolic_scope_test.py
class TestSymbolicScope (line 9) | class TestSymbolicScope(testing.TestCase):
method test_basic_flow (line 10) | def test_basic_flow(self):
FILE: keras/src/backend/common/tensor_attributes.py
function _clear_tensor_attr (line 6) | def _clear_tensor_attr(tensor_id, attr):
function set_tensor_attr (line 12) | def set_tensor_attr(tensor, attr, value):
function get_tensor_attr (line 29) | def get_tensor_attr(tensor, attr):
FILE: keras/src/backend/common/thread_safe_test.py
class TestThreadSafe (line 10) | class TestThreadSafe(testing.TestCase):
method test_is_thread_safe (line 11) | def test_is_thread_safe(self):
FILE: keras/src/backend/common/variables.py
class Variable (line 15) | class Variable:
method __init__ (line 92) | def __init__(
method _deferred_initialize (line 212) | def _deferred_initialize(self):
method _validate_shape (line 231) | def _validate_shape(self, shape):
method _maybe_autocast (line 241) | def _maybe_autocast(self, value):
method numpy (line 247) | def numpy(self):
method aggregation (line 251) | def aggregation(self):
method synchronization (line 256) | def synchronization(self):
method value (line 261) | def value(self):
method assign (line 278) | def assign(self, value):
method assign_add (line 296) | def assign_add(self, value):
method assign_sub (line 299) | def assign_sub(self, value):
method dtype (line 303) | def dtype(self):
method shape (line 317) | def shape(self):
method ndim (line 322) | def ndim(self):
method trainable (line 327) | def trainable(self):
method trainable (line 332) | def trainable(self, value):
method name (line 336) | def name(self):
method path (line 341) | def path(self):
method overwrite_with_gradient (line 346) | def overwrite_with_gradient(self):
method overwrite_with_gradient (line 359) | def overwrite_with_gradient(self, value):
method regularizer (line 368) | def regularizer(self):
method regularizer (line 372) | def regularizer(self, value):
method constraint (line 384) | def constraint(self):
method constraint (line 388) | def constraint(self, value):
method __repr__ (line 399) | def __repr__(self):
method _initialize (line 413) | def _initialize(self, value):
method _initialize_with_initializer (line 416) | def _initialize_with_initializer(self, initializer):
method _convert_to_tensor (line 422) | def _convert_to_tensor(self, value, dtype=None):
method __getitem__ (line 425) | def __getitem__(self, idx):
method __int__ (line 428) | def __int__(self):
method __float__ (line 436) | def __float__(self):
method __array__ (line 444) | def __array__(self, dtype=None):
method __bool__ (line 451) | def __bool__(self):
method __neg__ (line 454) | def __neg__(self):
method __pos__ (line 457) | def __pos__(self):
method __abs__ (line 460) | def __abs__(self):
method __invert__ (line 463) | def __invert__(self):
method __eq__ (line 466) | def __eq__(self, other):
method __ne__ (line 469) | def __ne__(self, other):
method __lt__ (line 472) | def __lt__(self, other):
method __le__ (line 475) | def __le__(self, other):
method __gt__ (line 478) | def __gt__(self, other):
method __ge__ (line 481) | def __ge__(self, other):
method __add__ (line 484) | def __add__(self, other):
method __radd__ (line 487) | def __radd__(self, other):
method __sub__ (line 490) | def __sub__(self, other):
method __rsub__ (line 493) | def __rsub__(self, other):
method __mul__ (line 496) | def __mul__(self, other):
method __rmul__ (line 499) | def __rmul__(self, other):
method __truediv__ (line 502) | def __truediv__(self, other):
method __rtruediv__ (line 505) | def __rtruediv__(self, other):
method __floordiv__ (line 508) | def __floordiv__(self, other):
method __rfloordiv__ (line 511) | def __rfloordiv__(self, other):
method __mod__ (line 514) | def __mod__(self, other):
method __rmod__ (line 517) | def __rmod__(self, other):
method __pow__ (line 520) | def __pow__(self, other):
method __rpow__ (line 523) | def __rpow__(self, other):
method __matmul__ (line 526) | def __matmul__(self, other):
method __rmatmul__ (line 529) | def __rmatmul__(self, other):
method __and__ (line 532) | def __and__(self, other):
method __rand__ (line 535) | def __rand__(self, other):
method __or__ (line 538) | def __or__(self, other):
method __ror__ (line 541) | def __ror__(self, other):
method __xor__ (line 544) | def __xor__(self, other):
method __rxor__ (line 547) | def __rxor__(self, other):
method __round__ (line 550) | def __round__(self, ndigits=None):
function register_uninitialized_variable (line 555) | def register_uninitialized_variable(variable):
function initialize_all_variables (line 562) | def initialize_all_variables():
function standardize_dtype (line 573) | def standardize_dtype(dtype):
function standardize_shape (line 591) | def standardize_shape(shape):
function shape_equal (line 651) | def shape_equal(a_shape, b_shape):
function is_float_dtype (line 662) | def is_float_dtype(dtype):
function is_int_dtype (line 668) | def is_int_dtype(dtype):
function get_autocast_scope (line 673) | def get_autocast_scope():
class AutocastScope (line 677) | class AutocastScope:
method __init__ (line 684) | def __init__(self, dtype):
method maybe_cast (line 696) | def maybe_cast(self, value):
method __enter__ (line 703) | def __enter__(self):
method __exit__ (line 707) | def __exit__(self, *args, **kwargs):
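
A minimal sketch of the common Variable surface listed above (public as keras.Variable); float inputs standardize to the global floatx dtype:

    import numpy as np
    import keras

    v = keras.Variable(np.array([1.0, 2.0]), trainable=True, name="weights")
    v.assign_add(np.array([0.5, 0.5]))
    print(v.shape, v.dtype)  # (2,) float32
    print(v.numpy())         # [1.5 2.5]
    print(v * 2.0)           # operator overloads dispatch to the active backend
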
FILE: keras/src/backend/common/variables_test.py
class VariableInitializationTest (line 21) | class VariableInitializationTest(test_case.TestCase):
method test_deferred_initialization (line 24) | def test_deferred_initialization(self):
method test_variable_initialization_with_numpy_array (line 39) | def test_variable_initialization_with_numpy_array(self):
method test_variable_initialization_with_native_array (line 47) | def test_variable_initialization_with_native_array(self):
method test_variable_initialization_with_python_array (line 55) | def test_variable_initialization_with_python_array(self):
method test_variable_initialization_with_lambda_expression (line 66) | def test_variable_initialization_with_lambda_expression(self):
method test_variable_initialization_with_strings (line 103) | def test_variable_initialization_with_strings(self):
method test_variable_initialization_with_non_trainable (line 108) | def test_variable_initialization_with_non_trainable(self):
method test_variable_initialization_without_shape (line 113) | def test_variable_initialization_without_shape(self):
method test_deferred_initialize_already_initialized (line 121) | def test_deferred_initialize_already_initialized(self):
method test_variable_initialize (line 129) | def test_variable_initialize(self):
method test_variable_without_shape_from_callable_initializer (line 136) | def test_variable_without_shape_from_callable_initializer(self):
class VariablePropertiesTest (line 145) | class VariablePropertiesTest(test_case.TestCase):
method test_deferred_assignment (line 151) | def test_deferred_assignment(self):
method test_trainable_setter (line 164) | def test_trainable_setter(self):
method test_autocasting_float (line 180) | def test_autocasting_float(self):
method test_autocasting_float_assign (line 195) | def test_autocasting_float_assign(self):
method test_autocasting_int (line 221) | def test_autocasting_int(self):
method test_autocasting_float_with_autocast_off (line 235) | def test_autocasting_float_with_autocast_off(self):
method test_standardize_dtype (line 259) | def test_standardize_dtype(self, dtype):
method test_standardize_dtype_with_torch_dtype (line 290) | def test_standardize_dtype_with_torch_dtype(self):
method test_name_validation (line 297) | def test_name_validation(self):
method test_standardize_shape_with_none (line 311) | def test_standardize_shape_with_none(self):
method test_standardize_shape_with_non_iterable (line 317) | def test_standardize_shape_with_non_iterable(self):
method test_standardize_shape_with_valid_input (line 323) | def test_standardize_shape_with_valid_input(self):
method test_standardize_shape_with_valid_input_with_none (line 328) | def test_standardize_shape_with_valid_input_with_none(self):
method test_standardize_shape_with_valid_not_tuple_input (line 333) | def test_standardize_shape_with_valid_not_tuple_input(self):
method test_standardize_shape_with_numpy (line 338) | def test_standardize_shape_with_numpy(self):
method test_standardize_shape_with_string (line 345) | def test_standardize_shape_with_string(self):
method test_standardize_shape_with_float (line 353) | def test_standardize_shape_with_float(self):
method test_standardize_shape_with_object (line 361) | def test_standardize_shape_with_object(self):
method test_standardize_shape_with_negative_dimension (line 369) | def test_standardize_shape_with_negative_dimension(self):
method test_standardize_shape_preserves_none (line 381) | def test_standardize_shape_preserves_none(self, input_shape, expected):
method test_shape_equal_length_mismatch (line 386) | def test_shape_equal_length_mismatch(self):
method test_autocast_scope_with_non_float_dtype (line 392) | def test_autocast_scope_with_non_float_dtype(self):
method test_variable_path_creation (line 400) | def test_variable_path_creation(self):
method test_overwrite_with_gradient_setter (line 409) | def test_overwrite_with_gradient_setter(self):
class VariableNumpyValueAndAssignmentTest (line 422) | class VariableNumpyValueAndAssignmentTest(test_case.TestCase):
method test_variable_numpy (line 425) | def test_variable_numpy(self):
method test_variable_numpy_scalar (line 435) | def test_variable_numpy_scalar(self):
method test_variable_value (line 446) | def test_variable_value(self):
method test_variable_assign (line 451) | def test_variable_assign(self):
method test_variable_assign_return (line 457) | def test_variable_assign_return(self):
method test_variable_assign_add (line 463) | def test_variable_assign_add(self):
method test_variable_assign_add_return (line 469) | def test_variable_assign_add_return(self):
method test_variable_assign_sub (line 475) | def test_variable_assign_sub(self):
method test_variable_assign_sub_return (line 481) | def test_variable_assign_sub_return(self):
method test_deferred_initialize_within_stateless_scope (line 487) | def test_deferred_initialize_within_stateless_scope(self):
class VariableDtypeShapeNdimRepr (line 501) | class VariableDtypeShapeNdimRepr(test_case.TestCase):
method test_variable_dtype (line 504) | def test_variable_dtype(self):
method test_variable_shape (line 511) | def test_variable_shape(self):
method test_variable_ndim (line 516) | def test_variable_ndim(self):
method test_variable_repr (line 521) | def test_variable_repr(self):
method test_variable_getitem (line 543) | def test_variable_getitem(self):
method test_variable_initialize (line 548) | def test_variable_initialize(self):
method test_variable_convert_to_tensor (line 555) | def test_variable_convert_to_tensor(self):
method test_variable_convert_to_tensor_with_dtype (line 562) | def test_variable_convert_to_tensor_with_dtype(self):
method test_variable_array (line 570) | def test_variable_array(self):
class VariableOpsCorrectnessTest (line 576) | class VariableOpsCorrectnessTest(test_case.TestCase):
method test_int (line 579) | def test_int(self):
method test_float (line 583) | def test_float(self):
method test__neg__ (line 587) | def test__neg__(self):
method test__abs__ (line 592) | def test__abs__(self):
method test__invert__ (line 597) | def test__invert__(self):
method test__eq__ (line 604) | def test__eq__(self):
method test__ne__ (line 611) | def test__ne__(self):
method test__lt__ (line 618) | def test__lt__(self):
method test__le__ (line 625) | def test__le__(self):
method test__gt__ (line 632) | def test__gt__(self):
method test__ge__ (line 639) | def test__ge__(self):
method test__add__ (line 646) | def test__add__(self):
method test__radd__ (line 652) | def test__radd__(self):
method test__sub__ (line 658) | def test__sub__(self):
method test__rsub__ (line 664) | def test__rsub__(self):
method test__mul__ (line 670) | def test__mul__(self):
method test__rmul__ (line 676) | def test__rmul__(self):
method test__truediv__ (line 682) | def test__truediv__(self):
method test__rtruediv__ (line 688) | def test__rtruediv__(self):
method test__floordiv__ (line 697) | def test__floordiv__(self):
method test__rfloordiv__ (line 706) | def test__rfloordiv__(self):
method test__mod__ (line 712) | def test__mod__(self):
method test__rmod__ (line 718) | def test__rmod__(self):
method test__pow__ (line 724) | def test__pow__(self):
method test__rpow__ (line 730) | def test__rpow__(self):
method test__matmul__ (line 736) | def test__matmul__(self):
method test__rmatmul__ (line 744) | def test__rmatmul__(self):
method test__and__ (line 752) | def test__and__(self):
method test__rand__ (line 762) | def test__rand__(self):
method test__or__ (line 772) | def test__or__(self):
method test__ror__ (line 782) | def test__ror__(self):
method test__xor__ (line 792) | def test__xor__(self):
method test__rxor__ (line 802) | def test__rxor__(self):
method test__pos__ (line 812) | def test__pos__(self):
method test_variable_pow (line 817) | def test_variable_pow(self):
method test_variable_rpow (line 824) | def test_variable_rpow(self):
method test_round (line 831) | def test_round(self):
class VariableOpsBehaviorTest (line 836) | class VariableOpsBehaviorTest(test_case.TestCase):
method test_invalid_bool (line 837) | def test_invalid_bool(self):
method test_invalid_int (line 845) | def test_invalid_int(self):
method test_invalid_float (line 852) | def test_invalid_float(self):
class VariableOpsDTypeTest (line 860) | class VariableOpsDTypeTest(test_case.TestCase):
method test_eq (line 902) | def test_eq(self, dtypes):
method test_ne (line 917) | def test_ne(self, dtypes):
method test_lt (line 932) | def test_lt(self, dtypes):
method test_le (line 947) | def test_le(self, dtypes):
method test_gt (line 962) | def test_gt(self, dtypes):
method test_ge (line 977) | def test_ge(self, dtypes):
method test_add (line 994) | def test_add(self, dtypes):
method test_sub (line 1010) | def test_sub(self, dtypes):
method test_mul (line 1026) | def test_mul(self, dtypes):
method test_truediv (line 1042) | def test_truediv(self, dtypes):
method test_floordiv (line 1063) | def test_floordiv(self, dtypes):
method test_mod (line 1081) | def test_mod(self, dtypes):
method test_pow (line 1097) | def test_pow(self, dtypes):
method test_matmul (line 1113) | def test_matmul(self, dtypes):
method test_and (line 1129) | def test_and(self, dtypes):
method test_or (line 1147) | def test_or(self, dtypes):
method test_xor (line 1163) | def test_xor(self, dtypes):
class TestStandardizeShapeWithTorch (line 1183) | class TestStandardizeShapeWithTorch(test_case.TestCase):
method test_standardize_shape_with_torch_size (line 1184) | def test_standardize_shape_with_torch_size(self):
method test_standardize_shape_with_torch_symint (line 1195) | def test_standardize_shape_with_torch_symint(self):
class TestStandardizeShapeWithTensorflow (line 1222) | class TestStandardizeShapeWithTensorflow(test_case.TestCase):
method test_standardize_shape_with_tensor_size (line 1223) | def test_standardize_shape_with_tensor_size(self):
FILE: keras/src/backend/config.py
function floatx (line 27) | def floatx():
function set_floatx (line 45) | def set_floatx(value):
function epsilon (line 82) | def epsilon():
function set_epsilon (line 98) | def set_epsilon(value):
function image_data_format (line 126) | def image_data_format():
function set_image_data_format (line 147) | def set_image_data_format(data_format):
function enable_flash_attention (line 178) | def enable_flash_attention():
function disable_flash_attention (line 199) | def disable_flash_attention():
function is_flash_attention_enabled (line 215) | def is_flash_attention_enabled():
function is_nnx_enabled (line 237) | def is_nnx_enabled():
function set_nnx_enabled (line 247) | def set_nnx_enabled(value):
function standardize_data_format (line 262) | def standardize_data_format(data_format):
function keras_home (line 286) | def keras_home():
function backend (line 377) | def backend():
function set_max_epochs (line 394) | def set_max_epochs(max_epochs):
function set_max_steps_per_epoch (line 410) | def set_max_steps_per_epoch(max_steps_per_epoch):
function max_epochs (line 427) | def max_epochs():
function max_steps_per_epoch (line 442) | def max_steps_per_epoch():
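
The configuration knobs above are public under keras.backend (and keras.config). A minimal sketch:

    import keras

    print(keras.backend.backend())            # "tensorflow", "jax", "torch", ...
    print(keras.backend.floatx())             # "float32" by default
    print(keras.backend.image_data_format())  # "channels_last" by default

    keras.backend.set_floatx("float64")  # affects layers created afterwards
    keras.backend.set_floatx("float32")  # restore the default
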
FILE: keras/src/backend/jax/core.py
class JaxVariable (line 28) | class JaxVariable(KerasVariable):
method __init__ (line 29) | def __init__(self, *args, layout=None, **kwargs):
method _initialize_layout (line 35) | def _initialize_layout(self):
method _initialize (line 48) | def _initialize(self, value):
method _initialize_with_initializer (line 54) | def _initialize_with_initializer(self, initializer):
method _direct_assign (line 69) | def _direct_assign(self, value):
method _convert_to_tensor (line 74) | def _convert_to_tensor(self, value, dtype=None):
method __jax_array__ (line 78) | def __jax_array__(self):
class NnxVariable (line 86) | class NnxVariable(JaxVariable, nnx.Variable):
method __init__ (line 87) | def __init__(
method _initialize_with_initializer (line 135) | def _initialize_with_initializer(self, initializer):
method _value (line 142) | def _value(self):
method _value (line 148) | def _value(self, new_keras_value):
method __getstate__ (line 151) | def __getstate__(self):
method __setstate__ (line 182) | def __setstate__(self, state):
method _direct_assign (line 216) | def _direct_assign(self, value):
method value (line 233) | def value(self):
function _flatten_nnx_variable (line 261) | def _flatten_nnx_variable(variable):
function _unflatten_nnx_variable (line 275) | def _unflatten_nnx_variable(aux_data, children):
function __setattr__ (line 300) | def __setattr__(self, name, value):
function should_shard_at_init (line 315) | def should_shard_at_init(init_layout, shape):
function convert_to_tensor (line 327) | def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
function convert_to_numpy (line 360) | def convert_to_numpy(x):
function is_tensor (line 368) | def is_tensor(x):
function shape (line 374) | def shape(x):
function cast (line 378) | def cast(x, dtype):
function compute_output_spec (line 383) | def compute_output_spec(fn, *args, **kwargs):
function cond (line 487) | def cond(pred, true_fn, false_fn):
function vectorized_map (line 491) | def vectorized_map(function, elements):
function map (line 495) | def map(f, xs):
function scan (line 499) | def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
function associative_scan (line 511) | def associative_scan(f, elems, reverse=False, axis=0):
function scatter (line 515) | def scatter(indices, values, shape):
function scatter_update (line 521) | def scatter_update(inputs, indices, updates, reduction=None):
function slice (line 540) | def slice(inputs, start_indices, shape):
function slice_update (line 550) | def slice_update(inputs, start_indices, updates):
function switch (line 554) | def switch(index, branches, *operands):
function while_loop (line 558) | def while_loop(
function fori_loop (line 594) | def fori_loop(lower, upper, body_fun, init_val):
function stop_gradient (line 598) | def stop_gradient(variable):
function unstack (line 604) | def unstack(x, num=None, axis=0):
function random_seed_dtype (line 611) | def random_seed_dtype():
function custom_gradient (line 616) | def custom_gradient(fun):
function remat (line 630) | def remat(f):
class name_scope (line 642) | class name_scope(base_name_scope):
method __init__ (line 643) | def __init__(self, name, **kwargs):
method __enter__ (line 647) | def __enter__(self):
method __exit__ (line 665) | def __exit__(self, *args, **kwargs):
function device_scope (line 671) | def device_scope(device_name):
FILE: keras/src/backend/jax/core_test.py
class NnxVariableTest (line 27) | class NnxVariableTest(testing.TestCase):
method setup (line 28) | def setup(self):
method test_variable_in_nnx_module (line 47) | def test_variable_in_nnx_module(self):
method test_model_saving (line 55) | def test_model_saving(self):
method test_keras_variable_nnx_split_merge_sync (line 63) | def test_keras_variable_nnx_split_merge_sync(self):
FILE: keras/src/backend/jax/distribution_lib.py
function list_devices (line 11) | def list_devices(device_type=None):
function get_device_count (line 29) | def get_device_count(device_type=None):
function distribute_tensor (line 43) | def distribute_tensor(tensor, layout):
function distribute_data_input (line 86) | def distribute_data_input(per_process_batch, layout, batch_dim_name):
function initialize_rng (line 109) | def initialize_rng():
function initialize (line 138) | def initialize(job_addresses, num_processes, process_id):
function num_processes (line 165) | def num_processes():
function process_id (line 170) | def process_id():
function _to_backend_device (line 175) | def _to_backend_device(device_name):
function _to_backend_mesh (line 191) | def _to_backend_mesh(device_mesh):
function _to_backend_layout (line 206) | def _to_backend_layout(tensor_layout):
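
These JAX functions back the public keras.distribution API. A hedged sketch of a data-parallel setup (requires the JAX backend, and is only useful with more than one visible device):

    import keras

    devices = keras.distribution.list_devices()  # e.g. ["gpu:0", "gpu:1"]
    mesh = keras.distribution.DeviceMesh(
        shape=(len(devices),), axis_names=["batch"], devices=devices
    )
    keras.distribution.set_distribution(
        keras.distribution.DataParallel(device_mesh=mesh)
    )
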
FILE: keras/src/backend/jax/distribution_lib_test.py
class JaxDistributionLibTest (line 35) | class JaxDistributionLibTest(testing.TestCase):
method _create_jax_layout (line 36) | def _create_jax_layout(self, sharding):
method test_get_device_count (line 45) | def test_get_device_count(self):
method test_list_devices (line 49) | def test_list_devices(self):
method test_device_conversion (line 54) | def test_device_conversion(self):
method test_initialize_with_all_job_addresses (line 64) | def test_initialize_with_all_job_addresses(self, mock_jax_initialize):
method test_initialize_validate_job_and_process (line 70) | def test_initialize_validate_job_and_process(self):
method test_initialize_with_coordinator_address (line 77) | def test_initialize_with_coordinator_address(self, mock_jax_initialize):
method test_distribute_tensor (line 83) | def test_distribute_tensor(self):
method test_distribute_tensor_with_jax_layout (line 107) | def test_distribute_tensor_with_jax_layout(self):
method test_processes (line 137) | def test_processes(self):
method test_to_backend_mesh (line 141) | def test_to_backend_mesh(self):
method test_to_backend_layout (line 153) | def test_to_backend_layout(self):
method test_validation_for_device_mesh (line 168) | def test_validation_for_device_mesh(self):
method test_variable_assignment_reuse_layout (line 177) | def test_variable_assignment_reuse_layout(self):
method test_e2e_data_parallel_model (line 215) | def test_e2e_data_parallel_model(self):
method test_e2e_model_parallel_model (line 239) | def test_e2e_model_parallel_model(self):
method test_e2e_model_parallel_with_output_sharding (line 278) | def test_e2e_model_parallel_with_output_sharding(self):
method test_distribute_data_input (line 334) | def test_distribute_data_input(self):
class ShardingCaptureLayer (line 368) | class ShardingCaptureLayer(layers.Layer):
method __init__ (line 369) | def __init__(self, **kwargs):
method call (line 374) | def call(self, inputs):
method capture_input_sharding (line 380) | def capture_input_sharding(self, sharding):
FILE: keras/src/backend/jax/export.py
class JaxExportArchive (line 13) | class JaxExportArchive(SavedModelExportArchive):
method _backend_init (line 16) | def _backend_init(self):
method _backend_track_layer (line 22) | def _backend_track_layer(self, layer):
method _backend_add_endpoint (line 46) | def _backend_add_endpoint(self, name, fn, input_signature, **kwargs):
method _convert_jax2tf_function (line 131) | def _convert_jax2tf_function(self, fn, input_signature, jax2tf_kwargs=...
method _to_polymorphic_shape (line 141) | def _to_polymorphic_shape(self, struct, allow_none=True):
method _check_device_compatible (line 170) | def _check_device_compatible(self):
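
`JaxExportArchive` wraps each endpoint with `jax2tf` so a JAX-backed model can still be written out as a TF SavedModel. A hedged sketch of the usual entry point, `Model.export`, assuming `tensorflow` is installed alongside JAX (the model and output path are illustrative):

    import keras

    model = keras.Sequential([keras.Input((4,)), keras.layers.Dense(2)])

    # On the JAX backend this goes through the jax2tf conversion above; the
    # result can be served with TF Serving or reloaded via tf.saved_model.load.
    model.export("my_saved_model", format="tf_saved_model")
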
FILE: keras/src/backend/jax/image.py
function rgb_to_grayscale (line 47) | def rgb_to_grayscale(images, data_format=None):
function rgb_to_hsv (line 71) | def rgb_to_hsv(images, data_format=None):
function hsv_to_rgb (line 121) | def hsv_to_rgb(images, data_format=None):
function resize (line 161) | def resize(
function affine_transform (line 403) | def affine_transform(
function perspective_transform (line 501) | def perspective_transform(
function compute_homography_matrix (line 602) | def compute_homography_matrix(start_points, end_points):
function map_coordinates (line 647) | def map_coordinates(
function gaussian_blur (line 680) | def gaussian_blur(
function elastic_transform (line 743) | def elastic_transform(
function scale_and_translate (line 872) | def scale_and_translate(
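
These functions implement `keras.ops.image` for the JAX backend. A short sketch of two common ones (shapes are illustrative):

    import numpy as np
    from keras import ops

    images = np.random.uniform(size=(2, 64, 64, 3)).astype("float32")

    # Bilinear resize to 32x32 ("channels_last" is the default data_format).
    small = ops.image.resize(images, size=(32, 32), interpolation="bilinear")

    # Weighted-average grayscale; the channel axis is kept: (2, 32, 32, 1).
    gray = ops.image.rgb_to_grayscale(small)
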
FILE: keras/src/backend/jax/layer.py
class BaseLayer (line 6) | class BaseLayer(nnx.Module):
method __init_subclass__ (line 7) | def __init_subclass__(cls, **kwargs):
class JaxLayer (line 13) | class JaxLayer(BaseLayer):
FILE: keras/src/backend/jax/linalg.py
function cholesky (line 12) | def cholesky(a, upper=False):
function cholesky_inverse (line 29) | def cholesky_inverse(a, upper=False):
function det (line 39) | def det(a):
function eig (line 43) | def eig(x):
function eigh (line 47) | def eigh(x):
function inv (line 51) | def inv(a):
function lu_factor (line 55) | def lu_factor(x):
function norm (line 64) | def norm(x, ord=None, axis=None, keepdims=False):
function qr (line 74) | def qr(x, mode="reduced"):
function solve (line 84) | def solve(a, b):
function solve_triangular (line 88) | def solve_triangular(a, b, lower=False):
function svd (line 92) | def svd(x, full_matrices=True, compute_uv=True):
function lstsq (line 96) | def lstsq(a, b, rcond=None):
function jvp (line 102) | def jvp(fun, primals, tangents, has_aux=False):
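
Thin wrappers over `jax.numpy.linalg` and `jax.scipy.linalg`, surfaced as `keras.ops.*` (with `keras.ops.linalg.*` aliases). A small sketch with an illustrative symmetric positive-definite matrix:

    import numpy as np
    from keras import ops

    a = np.array([[4.0, 2.0], [2.0, 3.0]])
    b = np.array([1.0, 2.0])

    x = ops.solve(a, b)               # solves a @ x == b
    l = ops.cholesky(a)               # lower-triangular factor, a == l @ l.T
    q, r = ops.qr(a, mode="reduced")  # reduced QR factorization
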
FILE: keras/src/backend/jax/math.py
function segment_sum (line 11) | def segment_sum(data, segment_ids, num_segments=None, sorted=False):
function segment_max (line 22) | def segment_max(data, segment_ids, num_segments=None, sorted=False):
function top_k (line 33) | def top_k(x, k, sorted=True):
function in_top_k (line 39) | def in_top_k(targets, predictions, k):
function logsumexp (line 51) | def logsumexp(x, axis=None, keepdims=False):
function qr (line 55) | def qr(x, mode="reduced"):
function extract_sequences (line 65) | def extract_sequences(x, sequence_length, sequence_stride):
function _get_complex_tensor_from_tuple (line 79) | def _get_complex_tensor_from_tuple(x):
function fft (line 107) | def fft(x):
function fft2 (line 113) | def fft2(x):
function ifft2 (line 119) | def ifft2(x):
function rfft (line 125) | def rfft(x, fft_length=None):
function irfft (line 130) | def irfft(x, fft_length=None):
function stft (line 135) | def stft(
function istft (line 198) | def istft(
function rsqrt (line 261) | def rsqrt(x):
function erf (line 265) | def erf(x):
function erfinv (line 269) | def erfinv(x):
function logdet (line 273) | def logdet(x):
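
Segment reductions, top-k, and spectral ops. Since not every backend has a complex dtype, the FFT helpers pass complex tensors as a `(real, imag)` tuple — see `_get_complex_tensor_from_tuple` above. A sketch with illustrative data:

    import numpy as np
    from keras import ops

    data = np.array([1.0, 2.0, 3.0, 4.0])
    segment_ids = np.array([0, 0, 1, 1])

    # Sum entries that share a segment id: [1+2, 3+4] == [3., 7.]
    sums = ops.segment_sum(data, segment_ids, num_segments=2)

    # Two largest values and their positions.
    values, indices = ops.top_k(data, k=2)

    # FFT over a (real, imag) pair; the result uses the same pairing.
    real, imag = ops.fft((data, np.zeros_like(data)))
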
FILE: keras/src/backend/jax/nn.py
function relu (line 29) | def relu(x):
function relu6 (line 34) | def relu6(x):
function sigmoid (line 39) | def sigmoid(x):
function sparse_sigmoid (line 44) | def sparse_sigmoid(x):
function tanh (line 49) | def tanh(x):
function tanh_shrink (line 54) | def tanh_shrink(x):
function softplus (line 59) | def softplus(x):
function softsign (line 64) | def softsign(x):
function soft_shrink (line 69) | def soft_shrink(x, threshold=0.5):
function sparse_plus (line 78) | def sparse_plus(x):
function silu (line 83) | def silu(x):
function squareplus (line 88) | def squareplus(x, b=4):
function log_sigmoid (line 93) | def log_sigmoid(x):
function leaky_relu (line 98) | def leaky_relu(x, negative_slope=0.2):
function hard_sigmoid (line 103) | def hard_sigmoid(x):
function hard_silu (line 108) | def hard_silu(x):
function elu (line 113) | def elu(x, alpha=1.0):
function selu (line 118) | def selu(x):
function gelu (line 123) | def gelu(x, approximate=True):
function celu (line 128) | def celu(x, alpha=1.0):
function glu (line 133) | def glu(x, axis=-1):
function hard_tanh (line 138) | def hard_tanh(x):
function hard_shrink (line 143) | def hard_shrink(x, threshold=0.5):
function threshold (line 148) | def threshold(x, threshold, default_value):
function softmax (line 153) | def softmax(x, axis=-1):
function log_softmax (line 158) | def log_softmax(x, axis=-1):
function sparsemax (line 163) | def sparsemax(x, axis=-1):
function _convert_to_spatial_operand (line 181) | def _convert_to_spatial_operand(
function _pool (line 198) | def _pool(
function max_pool (line 236) | def max_pool(
function average_pool (line 255) | def average_pool(
function _compute_adaptive_pooling_gather_indices (line 295) | def _compute_adaptive_pooling_gather_indices(
function _adaptive_average_pool1d (line 320) | def _adaptive_average_pool1d(inputs, output_size, data_format="channels_...
function _adaptive_max_pool1d (line 354) | def _adaptive_max_pool1d(inputs, output_size, data_format="channels_firs...
function _adaptive_average_pool2d (line 384) | def _adaptive_average_pool2d(inputs, output_size, data_format="channels_...
function _adaptive_max_pool2d (line 440) | def _adaptive_max_pool2d(inputs, output_size, data_format="channels_firs...
function _adaptive_average_pool3d (line 484) | def _adaptive_average_pool3d(inputs, output_size, data_format="channels_...
function _adaptive_max_pool3d (line 585) | def _adaptive_max_pool3d(inputs, output_size, data_format="channels_firs...
function adaptive_average_pool (line 668) | def adaptive_average_pool(inputs, output_size, data_format=None):
function adaptive_max_pool (line 680) | def adaptive_max_pool(inputs, output_size, data_format=None):
function _convert_to_lax_conv_dimension_numbers (line 692) | def _convert_to_lax_conv_dimension_numbers(
function conv (line 717) | def conv(
function depthwise_conv (line 777) | def depthwise_conv(
function separable_conv (line 824) | def separable_conv(
function conv_transpose (line 852) | def conv_transpose(
function one_hot (line 900) | def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
function multi_hot (line 928) | def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
function categorical_crossentropy (line 951) | def categorical_crossentropy(target, output, from_logits=False, axis=-1):
function sparse_categorical_crossentropy (line 977) | def sparse_categorical_crossentropy(target, output, from_logits=False, a...
function binary_crossentropy (line 1005) | def binary_crossentropy(target, output, from_logits=False):
function moments (line 1027) | def moments(x, axes, keepdims=False, synchronized=False):
function batch_normalization (line 1060) | def batch_normalization(
function ctc_loss (line 1081) | def ctc_loss(target, output, target_length, output_length, mask_index=0):
function _ctc_greedy_decode (line 1189) | def _ctc_greedy_decode(
function _ctc_beam_search_decode (line 1232) | def _ctc_beam_search_decode(
function ctc_decode (line 1395) | def ctc_decode(
function psnr (line 1430) | def psnr(x1, x2, max_val):
function _can_use_flash_attention (line 1443) | def _can_use_flash_attention(query, key, value, bias, raise_error=False):
function _apply_masks (line 1517) | def _apply_masks(logits, mask, is_causal):
function _dot_product_attention_core (line 1538) | def _dot_product_attention_core(
function wrap_flash_attention (line 1558) | def wrap_flash_attention(
function dot_product_attention (line 1626) | def dot_product_attention(
function unfold (line 1860) | def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
function fold (line 1903) | def fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):
function depth_to_space (line 1973) | def depth_to_space(x, block_size, data_format="channels_last"):
function space_to_depth (line 2012) | def space_to_depth(x, block_size, data_format="channels_last"):
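
Activations, pooling, convolutions, losses, and attention for the JAX backend; note that `dot_product_attention` can dispatch to a flash-attention kernel when `_can_use_flash_attention` allows it. A hedged sketch of a few of these through `keras.ops` (shapes are illustrative):

    import numpy as np
    from keras import ops

    logits = np.array([[2.0, 1.0, 0.1]])
    probs = ops.softmax(logits, axis=-1)

    # Integer labels -> one-hot targets, then the matching loss.
    onehot = ops.one_hot(np.array([0]), num_classes=3)
    loss = ops.categorical_crossentropy(onehot, logits, from_logits=True)

    # Scaled dot-product attention over (batch, seq, heads, head_dim).
    q = k = v = np.random.uniform(size=(1, 8, 2, 16)).astype("float32")
    attn = ops.dot_product_attention(q, k, v, is_causal=True)
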
FILE: keras/src/backend/jax/numpy.py
function _uses_cpu (line 20) | def _uses_cpu(x):
function rot90 (line 32) | def rot90(array, k=1, axes=(0, 1)):
function add (line 48) | def add(x1, x2):
function bartlett (line 54) | def bartlett(x):
function hamming (line 59) | def hamming(x):
function hanning (line 64) | def hanning(x):
function heaviside (line 69) | def heaviside(x1, x2):
function hypot (line 75) | def hypot(x1, x2):
function kaiser (line 81) | def kaiser(x, beta):
function bincount (line 86) | def bincount(x, weights=None, minlength=0, sparse=False):
function einsum (line 134) | def einsum(subscripts, *operands, **kwargs):
function subtract (line 148) | def subtract(x1, x2):
function matmul (line 154) | def matmul(x1, x2):
function multiply (line 185) | def multiply(x1, x2):
function mean (line 226) | def mean(x, axis=None, keepdims=False):
function max (line 266) | def max(x, axis=None, keepdims=False, initial=None):
function ones (line 271) | def ones(shape, dtype=None):
function zeros (line 276) | def zeros(shape, dtype=None):
function absolute (line 282) | def absolute(x):
function abs (line 287) | def abs(x):
function all (line 291) | def all(x, axis=None, keepdims=False):
function angle (line 295) | def angle(x):
function any (line 305) | def any(x, axis=None, keepdims=False):
function amax (line 309) | def amax(x, axis=None, keepdims=False):
function amin (line 313) | def amin(x, axis=None, keepdims=False):
function append (line 317) | def append(x1, x2, axis=None):
function arange (line 323) | def arange(start, stop=None, step=None, dtype=None):
function arccos (line 343) | def arccos(x):
function arccosh (line 354) | def arccosh(x):
function arcsin (line 365) | def arcsin(x):
function arcsinh (line 376) | def arcsinh(x):
function arctan (line 387) | def arctan(x):
function arctan2 (line 397) | def arctan2(x1, x2):
function arctanh (line 407) | def arctanh(x):
function argmax (line 417) | def argmax(x, axis=None, keepdims=False):
function argmin (line 432) | def argmin(x, axis=None, keepdims=False):
function argsort (line 447) | def argsort(x, axis=-1):
function array (line 454) | def array(x, dtype=None):
function view (line 458) | def view(x, dtype=None):
function average (line 463) | def average(x, axis=None, weights=None):
function bitwise_and (line 476) | def bitwise_and(x, y):
function bitwise_invert (line 482) | def bitwise_invert(x):
function bitwise_not (line 487) | def bitwise_not(x):
function bitwise_or (line 491) | def bitwise_or(x, y):
function bitwise_xor (line 497) | def bitwise_xor(x, y):
function bitwise_left_shift (line 503) | def bitwise_left_shift(x, y):
function left_shift (line 510) | def left_shift(x, y):
function bitwise_right_shift (line 514) | def bitwise_right_shift(x, y):
function right_shift (line 521) | def right_shift(x, y):
function blackman (line 525) | def blackman(x):
function broadcast_to (line 530) | def broadcast_to(x, shape):
function cbrt (line 535) | def cbrt(x):
function ceil (line 541) | def ceil(x):
function clip (line 551) | def clip(x, x_min, x_max):
function concatenate (line 558) | def concatenate(xs, axis=0):
function conjugate (line 576) | def conjugate(x):
function conj (line 582) | def conj(x):
function copy (line 588) | def copy(x):
function cos (line 594) | def cos(x):
function cosh (line 605) | def cosh(x):
function count_nonzero (line 615) | def count_nonzero(x, axis=None):
function cross (line 619) | def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
function cumprod (line 632) | def cumprod(x, axis=None, dtype=None):
function cumsum (line 637) | def cumsum(x, axis=None, dtype=None):
function deg2rad (line 642) | def deg2rad(x):
function diag (line 647) | def diag(x, k=0):
function diagflat (line 652) | def diagflat(x, k=0):
function diagonal (line 657) | def diagonal(x, offset=0, axis1=0, axis2=1):
function diff (line 667) | def diff(a, n=1, axis=-1):
function digitize (line 673) | def digitize(x, bins):
function dot (line 679) | def dot(x1, x2):
function dstack (line 685) | def dstack(xs):
function empty (line 689) | def empty(shape, dtype=None):
function empty_like (line 694) | def empty_like(x, dtype=None):
function equal (line 698) | def equal(x1, x2):
function exp (line 705) | def exp(x):
function exp2 (line 714) | def exp2(x):
function expand_dims (line 722) | def expand_dims(x, axis):
function expm1 (line 739) | def expm1(x):
function flip (line 747) | def flip(x, axis=None):
function floor (line 752) | def floor(x):
function full (line 762) | def full(shape, fill_value, dtype=None):
function full_like (line 767) | def full_like(x, fill_value, dtype=None):
function gcd (line 771) | def gcd(x1, x2):
function geomspace (line 777) | def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
function greater (line 783) | def greater(x1, x2):
function greater_equal (line 789) | def greater_equal(x1, x2):
function hstack (line 795) | def hstack(xs):
function hsplit (line 799) | def hsplit(x, indices_or_sections):
function identity (line 804) | def identity(n, dtype=None):
function imag (line 810) | def imag(x):
function isclose (line 815) | def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):
function allclose (line 821) | def allclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):
function isfinite (line 828) | def isfinite(x):
function isin (line 833) | def isin(x1, x2, assume_unique=False, invert=False):
function isinf (line 840) | def isinf(x):
function isnan (line 846) | def isnan(x):
function isneginf (line 851) | def isneginf(x):
function isposinf (line 856) | def isposinf(x):
function isreal (line 861) | def isreal(x):
function kron (line 866) | def kron(x1, x2):
function lcm (line 872) | def lcm(x1, x2):
function ldexp (line 878) | def ldexp(x1, x2):
function less (line 891) | def less(x1, x2):
function less_equal (line 897) | def less_equal(x1, x2):
function linspace (line 903) | def linspace(
function log (line 918) | def log(x):
function log10 (line 926) | def log10(x):
function log1p (line 934) | def log1p(x):
function log2 (line 942) | def log2(x):
function logaddexp (line 949) | def logaddexp(x1, x2):
function logaddexp2 (line 958) | def logaddexp2(x1, x2):
function logical_and (line 967) | def logical_and(x1, x2):
function logical_not (line 973) | def logical_not(x):
function logical_or (line 978) | def logical_or(x1, x2):
function logspace (line 984) | def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, ax...
function maximum (line 997) | def maximum(x1, x2):
function median (line 1003) | def median(x, axis=None, keepdims=False):
function meshgrid (line 1020) | def meshgrid(*x, indexing="xy"):
function min (line 1024) | def min(x, axis=None, keepdims=False, initial=None):
function minimum (line 1030) | def minimum(x1, x2):
function mod (line 1036) | def mod(x1, x2):
function fmod (line 1042) | def fmod(x1, x2):
function moveaxis (line 1048) | def moveaxis(x, source, destination):
function nanargmax (line 1052) | def nanargmax(x, axis=None, keepdims=False):
function nanargmin (line 1057) | def nanargmin(x, axis=None, keepdims=False):
function nancumsum (line 1062) | def nancumsum(x, axis=None, dtype=None):
function nancumprod (line 1067) | def nancumprod(x, axis=None, dtype=None):
function nanmax (line 1072) | def nanmax(x, axis=None, keepdims=False):
function nanmean (line 1077) | def nanmean(x, axis=None, keepdims=False):
function nanmin (line 1082) | def nanmin(x, axis=None, keepdims=False):
function nanprod (line 1087) | def nanprod(x, axis=None, keepdims=False):
function nanstd (line 1092) | def nanstd(x, axis=None, keepdims=False):
function nansum (line 1097) | def nansum(x, axis=None, keepdims=False):
function nanvar (line 1102) | def nanvar(x, axis=None, keepdims=False):
function nan_to_num (line 1107) | def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
function ndim (line 1112) | def ndim(x):
function nonzero (line 1116) | def nonzero(x):
function not_equal (line 1120) | def not_equal(x1, x2):
function ones_like (line 1126) | def ones_like(x, dtype=None):
function zeros_like (line 1130) | def zeros_like(x, dtype=None):
function outer (line 1134) | def outer(x1, x2):
function pad (line 1138) | def pad(x, pad_width, mode="constant", constant_values=None):
function prod (line 1152) | def prod(x, axis=None, keepdims=False, dtype=None):
function ptp (line 1157) | def ptp(x, axis=None, keepdims=False):
function quantile (line 1162) | def quantile(x, q, axis=None, method="linear", keepdims=False):
function ravel (line 1178) | def ravel(x):
function unravel_index (line 1183) | def unravel_index(indices, shape):
function real (line 1189) | def real(x):
function reciprocal (line 1195) | def reciprocal(x):
function repeat (line 1200) | def repeat(x, repeats, axis=None):
function reshape (line 1205) | def reshape(x, newshape):
function roll (line 1220) | def roll(x, shift, axis=None):
function searchsorted (line 1224) | def searchsorted(sorted_sequence, values, side="left"):
function sign (line 1236) | def sign(x):
function signbit (line 1242) | def signbit(x):
function sin (line 1248) | def sin(x):
function sinc (line 1258) | def sinc(x):
function sinh (line 1269) | def sinh(x):
function size (line 1279) | def size(x):
function sort (line 1283) | def sort(x, axis=-1):
function split (line 1288) | def split(x, indices_or_sections, axis=0):
function array_split (line 1293) | def array_split(x, indices_or_sections, axis=0):
function stack (line 1298) | def stack(x, axis=0):
function std (line 1303) | def std(x, axis=None, keepdims=False):
function swapaxes (line 1310) | def swapaxes(x, axis1, axis2):
function take (line 1315) | def take(x, indices, axis=None):
function take_along_axis (line 1321) | def take_along_axis(x, indices, axis=None):
function tan (line 1328) | def tan(x):
function tanh (line 1339) | def tanh(x):
function tensordot (line 1349) | def tensordot(x1, x2, axes=2):
function round (line 1356) | def round(x, decimals=0):
function tile (line 1372) | def tile(x, repeats):
function trace (line 1376) | def trace(x, offset=0, axis1=0, axis2=1):
function tri (line 1381) | def tri(N, M=None, k=0, dtype=None):
function tril (line 1386) | def tril(x, k=0):
function triu (line 1391) | def triu(x, k=0):
function trunc (line 1396) | def trunc(x):
function vdot (line 1404) | def vdot(x1, x2):
function inner (line 1410) | def inner(x1, x2):
function vstack (line 1416) | def vstack(xs):
function vsplit (line 1420) | def vsplit(x, indices_or_sections):
function vectorize (line 1425) | def vectorize(pyfunc, *, excluded=None, signature=None):
function where (line 1431) | def where(condition, x1=None, x2=None):
function divide (line 1436) | def divide(x1, x2):
function divide_no_nan (line 1442) | def divide_no_nan(x1, x2):
function true_divide (line 1449) | def true_divide(x1, x2):
function power (line 1453) | def power(x1, x2):
function negative (line 1460) | def negative(x):
function nextafter (line 1465) | def nextafter(x1, x2):
function square (line 1472) | def square(x):
function sqrt (line 1478) | def sqrt(x):
function squeeze (line 1485) | def squeeze(x, axis=None):
function transpose (line 1495) | def transpose(x, axes=None):
function trapezoid (line 1510) | def trapezoid(y, x=None, dx=1.0, axis=-1):
function vander (line 1518) | def vander(x, N=None, increasing=False):
function var (line 1523) | def var(x, axis=None, keepdims=False):
function sum (line 1535) | def sum(x, axis=None, keepdims=False):
function eye (line 1560) | def eye(N, M=None, k=0, dtype=None):
function floor_divide (line 1565) | def floor_divide(x1, x2):
function logical_xor (line 1571) | def logical_xor(x1, x2):
function corrcoef (line 1577) | def corrcoef(x):
function correlate (line 1582) | def correlate(x1, x2, mode="valid"):
function select (line 1588) | def select(condlist, choicelist, default=0):
function slogdet (line 1592) | def slogdet(x):
function argpartition (line 1597) | def argpartition(x, kth, axis=-1):
function histogram (line 1601) | def histogram(x, bins=10, range=None):
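
The NumPy-flavored tensor ops, implemented on `jax.numpy`; results follow Keras dtype-promotion rules rather than raw NumPy's, which is why most wrappers resolve a result dtype before dispatching. Usage mirrors NumPy:

    import numpy as np
    from keras import ops

    x = np.arange(6, dtype="float32").reshape(2, 3)

    y = ops.matmul(x, ops.transpose(x))     # (2, 2)
    m = ops.mean(x, axis=1, keepdims=True)  # (2, 1)
    w = ops.where(x > 2.0, x, 0.0)          # elementwise select
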
FILE: keras/src/backend/jax/optimizer.py
class JaxOptimizer (line 15) | class JaxOptimizer(base_optimizer.BaseOptimizer):
method _backend_apply_gradients (line 16) | def _backend_apply_gradients(self, grads, trainable_variables):
FILE: keras/src/backend/jax/random.py
function jax_draw_seed (line 9) | def jax_draw_seed(seed):
function normal (line 16) | def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
function uniform (line 23) | def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
function categorical (line 31) | def categorical(logits, num_samples, dtype="int32", seed=None):
function randint (line 42) | def randint(shape, minval, maxval, dtype="int32", seed=None):
function truncated_normal (line 49) | def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
function _get_concrete_noise_shape (line 58) | def _get_concrete_noise_shape(inputs, noise_shape):
function dropout (line 71) | def dropout(inputs, rate, noise_shape=None, seed=None):
function shuffle (line 88) | def shuffle(x, axis=0, seed=None):
function gamma (line 93) | def gamma(shape, alpha, dtype=None, seed=None):
function binomial (line 99) | def binomial(shape, counts, probabilities, dtype=None, seed=None):
function beta (line 111) | def beta(shape, alpha, beta, dtype=None, seed=None):
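
JAX requires explicit PRNG keys, so `jax_draw_seed` converts the incoming seed into one. At the user level the usual stateful pattern is `keras.random` plus a `SeedGenerator` (the seed value is illustrative):

    import keras

    seed_gen = keras.random.SeedGenerator(seed=42)

    # Each call advances the generator, so draws differ from each other
    # while the whole sequence stays reproducible from the initial seed.
    a = keras.random.normal((2, 3), seed=seed_gen)
    b = keras.random.uniform((2, 3), minval=-1.0, maxval=1.0, seed=seed_gen)
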
FILE: keras/src/backend/jax/rnn.py
function rnn (line 10) | def rnn(
function cudnn_ok (line 214) | def cudnn_ok(*args, **kwargs):
function lstm (line 218) | def lstm(*args, **kwargs):
function gru (line 222) | def gru(*args, **kwargs):
function unstack (line 226) | def unstack(x, axis=0):
FILE: keras/src/backend/jax/sparse.py
function axis_shape_dims_for_broadcast_in_dim (line 9) | def axis_shape_dims_for_broadcast_in_dim(axis, input_shape, insert_dims):
function bcoo_add_indices (line 67) | def bcoo_add_indices(x1, x2, sum_duplicates):
function densifying_unary (line 88) | def densifying_unary(func):
function elementwise_unary (line 116) | def elementwise_unary(linear):
function elementwise_binary_union (line 163) | def elementwise_binary_union(linear, use_sparsify):
function elementwise_division (line 255) | def elementwise_division(func):
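
These decorators let JAX ops accept `jax.experimental.sparse` BCOO tensors: `elementwise_unary` keeps results sparse when the function maps zero to zero, while `densifying_unary` falls back to dense outputs. A hedged sketch, assuming a JAX build that ships `jax.experimental.sparse`:

    import jax.numpy as jnp
    from jax.experimental import sparse as jax_sparse
    from keras import ops

    x = jax_sparse.BCOO.fromdense(jnp.array([[0.0, -1.0], [2.0, 0.0]]))

    y = ops.abs(x)  # zero-preserving: the result can stay sparse
    z = ops.exp(x)  # exp(0) != 0, so the result is densified
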
FILE: keras/src/backend/jax/tensorboard.py
function start_trace (line 4) | def start_trace(logdir):
function stop_trace (line 9) | def stop_trace(save):
function start_batch_trace (line 14) | def start_batch_trace(batch):
function stop_batch_trace (line 22) | def stop_batch_trace(batch_trace_context):
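
These four hooks wrap `jax.profiler` so the `TensorBoard` callback can capture device traces under the JAX backend. A hedged sketch (the log directory and batch range are illustrative):

    import keras

    # Profile batches 10-15 into ./logs; on the JAX backend the trace is
    # captured through the profiler hooks listed above.
    tb = keras.callbacks.TensorBoard(log_dir="./logs", profile_batch=(10, 15))
    # model.fit(x, y, epochs=2, callbacks=[tb])
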
FILE: keras/src/backend/jax/trainer.py
class JAXTrainer (line 31) | class JAXTrainer(base_trainer.Trainer):
method __init__ (line 32) | def __init__(self):
method compute_loss_and_updates (line 39) | def compute_loss_and_updates(
method _update_metrics_variables (line 98) | def _update_metrics_variables(
method train_step (line 123) | def train_step(self, state, data):
method test_step (line 167) | def test_step(self, state, data):
method predict_step (line 198) | def predict_step(self, state, data):
method _make_function (line 210) | def _make_function(self, step_function, concatenate_outputs=False):
method make_train_function (line 261) | def make_train_function(self, force=False):
method make_test_function (line 287) | def make_test_function(self, force=False):
method make_predict_function (line 323) | def make_predict_function(self, force=False):
method fit (line 362) | def fit(
method evaluate (line 550) | def evaluate(
method predict (line 653) | def predict(
method train_on_batch (line 749) | def train_on_batch(
method test_on_batch (line 809) | def test_on_batch(
method predict_on_batch (line 850) | def predict_on_batch(self, x):
method jax_state_sync (line 878) | def jax_state_sync(self):
method _get_state_sharding_spec (line 904) | def _get_state_sharding_spec(self):
method _check_sharding_consistency (line 933) | def _check_sharding_consistency(
method _purge_model_variables (line 1007) | def _purge_model_variables(
method _get_jax_state (line 1037) | def _get_jax_state(
function _distribute_data (line 1064) | def _distribute_data(data, layouts=None):
class JAXEpochIterator (line 1082) | class JAXEpochIterator(EpochIterator):
method __next__ (line 1083) | def __next__(self):
method _get_iterator (line 1086) | def _get_iterator(self):
method _get_distributed_iterator (line 1095) | def _get_distributed_iterator(self, distribution):
method _one_batch_ahead_iterator (line 1108) | def _one_batch_ahead_iterator(self, numpy_iterator):
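
`JAXTrainer` compiles each step into a purely functional, jitted function that threads explicit state (trainable, non-trainable, optimizer, and metrics variables) and copies it back into the model via `jax_state_sync`. None of that changes the user-facing workflow; a sketch with illustrative data:

    import numpy as np
    import keras

    x = np.random.uniform(size=(256, 8)).astype("float32")
    y = np.random.uniform(size=(256, 1)).astype("float32")

    model = keras.Sequential([keras.Input((8,)), keras.layers.Dense(1)])
    model.compile(optimizer="adam", loss="mse")

    model.fit(x, y, epochs=2, batch_size=32)  # runs the jitted train_step
    loss = model.evaluate(x, y, verbose=0)
    preds = model.predict(x, verbose=0)
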
FILE: keras/src/backend/jax/trainer_test.py
class JAXTrainerTest (line 14) | class JAXTrainerTest(testing.TestCase, parameterized.TestCase):
method _skip_if_not_distributed (line 15) | def _skip_if_not_distributed(self):
method _make_distribution (line 21) | def _make_distribution(self, dist_type):
method test_warns_when_model_built_outside_scope (line 41) | def test_warns_when_model_built_outside_scope(self, dist_type):
method test_no_warning_when_model_built_inside_scope (line 82) | def test_no_warning_when_model_built_inside_scope(self, dist_type):
FILE: keras/src/backend/numpy/core.py
class Variable (line 22) | class Variable(KerasVariable):
method _initialize (line 23) | def _initialize(self, value):
method _direct_assign (line 26) | def _direct_assign(self, value):
method _convert_to_tensor (line 29) | def _convert_to_tensor(self, value, dtype=None):
method __array__ (line 33) | def __array__(self):
function convert_to_tensor (line 37) | def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
function convert_to_numpy (line 59) | def convert_to_numpy(x):
function is_tensor (line 63) | def is_tensor(x):
function shape (line 69) | def shape(x):
function cast (line 73) | def cast(x, dtype):
function cond (line 77) | def cond(pred, true_fn, false_fn):
function vectorized_map (line 83) | def vectorized_map(function, elements):
function compute_output_spec (line 95) | def compute_output_spec(fn, *args, **kwargs):
function map (line 156) | def map(f, xs):
function scan (line 164) | def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
function associative_scan (line 213) | def associative_scan(f, elems, reverse=False, axis=0):
function scatter (line 328) | def scatter(indices, values, shape):
function scatter_update (line 343) | def scatter_update(inputs, indices, updates, reduction=None):
function slice (line 362) | def slice(inputs, start_indices, shape):
function slice_update (line 382) | def slice_update(inputs, start_indices, updates):
function switch (line 395) | def switch(index, branches, *operands):
function while_loop (line 401) | def while_loop(
function fori_loop (line 423) | def fori_loop(lower, upper, body_fun, init_val):
function stop_gradient (line 430) | def stop_gradient(variable):
function unstack (line 434) | def unstack(x, num=None, axis=0):
function random_seed_dtype (line 439) | def random_seed_dtype():
class custom_gradient (line 443) | class custom_gradient:
method __init__ (line 450) | def __init__(self, fun):
method __call__ (line 458) | def __call__(self, *args, **kwargs):
function device_scope (line 464) | def device_scope(device_name):
function remat (line 468) | def remat(f):
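
The NumPy backend keeps variables as plain `ndarray`s and executes everything eagerly, while `compute_output_spec` infers output shapes without running on real data. A minimal sketch:

    import os
    os.environ["KERAS_BACKEND"] = "numpy"  # must be set before importing keras

    import numpy as np
    from keras import ops

    # Ops run eagerly and produce numpy arrays on this backend.
    x = ops.convert_to_tensor(np.arange(4.0))
    y = ops.square(x)
    print(type(ops.convert_to_numpy(y)))  # <class 'numpy.ndarray'>
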
FILE: keras/src/backend/numpy/export.py
class NumpyExportArchive (line 1) | class NumpyExportArchive:
method track (line 2) | def track(self, resource):
method add_endpoint (line 7) | def add_endpoint(self, name, fn, input_signature=None, **kwargs):
FILE: keras/src/backend/numpy/image.py
function rgb_to_grayscale (line 46) | def rgb_to_grayscale(images, data_format=None):
function rgb_to_hsv (line 68) | def rgb_to_hsv(images, data_format=None):
function hsv_to_rgb (line 118) | def hsv_to_rgb(images, data_format=None):
function resize (line 158) | def resize(
function _compute_weight_mat (line 398) | def _compute_weight_mat(
function _resize (line 440) | def _resize(image, shape, method, antialias):
function _resize_nearest (line 467) | def _resize_nearest(x, output_shape):
function _fill_triangle_kernel (line 483) | def _fill_triangle_kernel(x):
function _fill_keys_cubic_kernel (line 487) | def _fill_keys_cubic_kernel(x):
function _fill_lanczos_kernel (line 493) | def _fill_lanczos_kernel(radius, x):
function _scale_and_translate (line 511) | def _scale_and_translate(
function affine_transform (line 542) | def affine_transform(
function perspective_transform (line 651) | def perspective_transform(
function compute_homography_matrix (line 760) | def compute_homography_matrix(start_points, end_points):
function map_coordinates (line 910) | def map_coordinates(
function gaussian_blur (line 971) | def gaussian_blur(
function elastic_transform (line 1047) | def elastic_transform(
function scale_and_translate (line 1171) | def scale_and_translate(
FILE: keras/src/backend/numpy/layer.py
class NumpyLayer (line 1) | class NumpyLayer:
FILE: keras/src/backend/numpy/linalg.py
function cholesky (line 9) | def cholesky(a, upper=False):
function cholesky_inverse (line 13) | def cholesky_inverse(a, upper=False):
function det (line 23) | def det(a):
function eig (line 27) | def eig(a):
function eigh (line 31) | def eigh(a):
function inv (line 35) | def inv(a):
function lu_factor (line 39) | def lu_factor(a):
function norm (line 53) | def norm(x, ord=None, axis=None, keepdims=False):
function qr (line 63) | def qr(x, mode="reduced"):
function solve (line 73) | def solve(a, b):
function solve_triangular (line 77) | def solve_triangular(a, b, lower=False):
function svd (line 91) | def svd(x, full_matrices=True, compute_uv=True):
function lstsq (line 95) | def lstsq(a, b, rcond=None):
function jvp (line 101) | def jvp(fun, primals, tangents, has_aux=False):
FILE: keras/src/backend/numpy/math.py
function _segment_reduction_fn (line 10) | def _segment_reduction_fn(
function segment_sum (line 42) | def segment_sum(data, segment_ids, num_segments=None, sorted=False):
function segment_max (line 48) | def segment_max(data, segment_ids, num_segments=None, sorted=False):
function top_k (line 54) | def top_k(x, k, sorted=True):
function in_top_k (line 69) | def in_top_k(targets, predictions, k):
function logsumexp (line 77) | def logsumexp(x, axis=None, keepdims=False):
function qr (line 81) | def qr(x, mode="reduced"):
function extract_sequences (line 91) | def extract_sequences(x, sequence_length, sequence_stride):
function _get_complex_tensor_from_tuple (line 106) | def _get_complex_tensor_from_tuple(x):
function fft (line 134) | def fft(x):
function fft2 (line 139) | def fft2(x):
function ifft2 (line 144) | def ifft2(x):
function rfft (line 150) | def rfft(x, fft_length=None):
function irfft (line 159) | def irfft(x, fft_length=None):
function stft (line 167) | def stft(
function istft (line 231) | def istft(
function rsqrt (line 288) | def rsqrt(x):
function erf (line 292) | def erf(x):
function erfinv (line 296) | def erfinv(x):
function logdet (line 300) | def logdet(x):
FILE: keras/src/backend/numpy/nn.py
function relu (line 18) | def relu(x):
function relu6 (line 23) | def relu6(x):
function sigmoid (line 32) | def sigmoid(x):
function sparse_sigmoid (line 37) | def sparse_sigmoid(x):
function tanh (line 48) | def tanh(x):
function tanh_shrink (line 52) | def tanh_shrink(x):
function softplus (line 57) | def softplus(x):
function softsign (line 62) | def softsign(x):
function soft_shrink (line 67) | def soft_shrink(x, threshold=0.5):
function sparse_plus (line 79) | def sparse_plus(x):
function silu (line 87) | def silu(x):
function squareplus (line 92) | def squareplus(x, b=4):
function log_sigmoid (line 99) | def log_sigmoid(x):
function leaky_relu (line 104) | def leaky_relu(x, negative_slope=0.2):
function hard_sigmoid (line 109) | def hard_sigmoid(x):
function hard_silu (line 120) | def hard_silu(x):
function elu (line 124) | def elu(x, alpha=1.0):
function selu (line 131) | def selu(x):
function gelu (line 138) | def gelu(x, approximate=True):
function celu (line 160) | def celu(x, alpha=1.0):
function glu (line 168) | def glu(x, axis=-1):
function hard_tanh (line 180) | def hard_tanh(x):
function hard_shrink (line 187) | def hard_shrink(x, threshold=0.5):
function threshold (line 196) | def threshold(x, threshold, default_value):
function softmax (line 201) | def softmax(x, axis=-1):
function log_softmax (line 206) | def log_softmax(x, axis=-1):
function sparsemax (line 212) | def sparsemax(x, axis=-1):
function _convert_to_spatial_operand (line 230) | def _convert_to_spatial_operand(
function _pool (line 247) | def _pool(
function max_pool (line 287) | def max_pool(
function average_pool (line 306) | def average_pool(
function _compute_adaptive_pooling_gather_indices (line 346) | def _compute_adaptive_pooling_gather_indices(
function _strided_view_1d (line 370) | def _strided_view_1d(x, window_size):
function _adaptive_pool1d_impl (line 381) | def _adaptive_pool1d_impl(inputs, output_size, mode, data_format):
function _adaptive_pool2d_impl (line 415) | def _adaptive_pool2d_impl(inputs, output_size, mode, data_format):
function _adaptive_pool3d_impl (line 478) | def _adaptive_pool3d_impl(inputs, output_size, mode, data_format):
function adaptive_average_pool (line 562) | def adaptive_average_pool(inputs, output_size, data_format=None):
function adaptive_max_pool (line 580) | def adaptive_max_pool(inputs, output_size, data_format=None):
function _convert_to_lax_conv_dimension_numbers (line 592) | def _convert_to_lax_conv_dimension_numbers(
function conv (line 617) | def conv(
function depthwise_conv (line 677) | def depthwise_conv(
function separable_conv (line 724) | def separable_conv(
function conv_transpose (line 752) | def conv_transpose(
function one_hot (line 802) | def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
function multi_hot (line 830) | def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
function categorical_crossentropy (line 842) | def categorical_crossentropy(target, output, from_logits=False, axis=-1):
function sparse_categorical_crossentropy (line 868) | def sparse_categorical_crossentropy(target, output, from_logits=False, a...
function binary_crossentropy (line 896) | def binary_crossentropy(target, output, from_logits=False):
function moments (line 916) | def moments(x, axes, keepdims=False, synchronized=False):
function batch_normalization (line 951) | def batch_normalization(
function ctc_loss (line 972) | def ctc_loss(target, output, target_length, output_length, mask_index=0):
function _ctc_greedy_decode (line 1085) | def _ctc_greedy_decode(
function _ctc_beam_search_decode (line 1128) | def _ctc_beam_search_decode(
function ctc_decode (line 1293) | def ctc_decode(
function psnr (line 1328) | def psnr(x1, x2, max_val):
function _get_large_negative (line 1341) | def _get_large_negative(dtype):
function _apply_masks (line 1347) | def _apply_masks(logits, mask, is_causal):
function _dot_product_attention_xla (line 1367) | def _dot_product_attention_xla(query, key, value, bias, mask, is_causal,...
function dot_product_attention (line 1396) | def dot_product_attention(
function unfold (line 1438) | def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
function fold (line 1491) | def fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):
function depth_to_space (line 1548) | def depth_to_space(x, block_size, data_format="channels_last"):
function space_to_depth (line 1587) | def space_to_depth(x, block_size, data_format="channels_last"):
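
The same neural-net ops implemented for the NumPy backend (the convolution helpers mirror the JAX `lax` dimension-number conventions, as the shared `_convert_to_lax_conv_dimension_numbers` name suggests). A quick pooling and moments sketch through `keras.ops`, with an illustrative input shape:

    import numpy as np
    from keras import ops

    x = np.random.uniform(size=(1, 6, 6, 3)).astype("float32")

    # 2x2 average pooling with stride 2: output shape (1, 3, 3, 3).
    y = ops.average_pool(x, pool_size=2, strides=2, padding="valid")

    # Per-channel mean and variance, as used by batch normalization.
    mean, variance = ops.moments(x, axes=[0, 1, 2])
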
FILE: keras/src/backend/numpy/numpy.py
function rot90 (line 11) | def rot90(array, k=1, axes=(0, 1)):
function add (line 26) | def add(x1, x2):
function einsum (line 40) | def einsum(subscripts, *operands, **kwargs):
function subtract (line 58) | def subtract(x1, x2):
function matmul (line 72) | def matmul(x1, x2):
function multiply (line 88) | def multiply(x1, x2):
function mean (line 102) | def mean(x, axis=None, keepdims=False):
function max (line 113) | def max(x, axis=None, keepdims=False, initial=None):
function ones (line 118) | def ones(shape, dtype=None):
function zeros (line 123) | def zeros(shape, dtype=None):
function absolute (line 128) | def absolute(x):
function abs (line 132) | def abs(x):
function all (line 136) | def all(x, axis=None, keepdims=False):
function allclose (line 141) | def allclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
function angle (line 145) | def angle(x):
function any (line 155) | def any(x, axis=None, keepdims=False):
function amax (line 160) | def amax(x, axis=None, keepdims=False):
function amin (line 165) | def amin(x, axis=None, keepdims=False):
function append (line 170) | def append(x1, x2, axis=None):
function arange (line 180) | def arange(start, stop=None, step=None, dtype=None):
function arccos (line 195) | def arccos(x):
function arccosh (line 205) | def arccosh(x):
function arcsin (line 215) | def arcsin(x):
function arcsinh (line 225) | def arcsinh(x):
function arctan (line 235) | def arctan(x):
function arctan2 (line 245) | def arctan2(x1, x2):
function arctanh (line 254) | def arctanh(x):
function argmax (line 264) | def argmax(x, axis=None, keepdims=False):
function argmin (line 278) | def argmin(x, axis=None, keepdims=False):
function argsort (line 292) | def argsort(x, axis=-1):
function array (line 297) | def array(x, dtype=None):
function view (line 301) | def view(x, dtype=None):
function average (line 306) | def average(x, axis=None, weights=None):
function bartlett (line 320) | def bartlett(x):
function hamming (line 325) | def hamming(x):
function hanning (line 330) | def hanning(x):
function heaviside (line 335) | def heaviside(x1, x2):
function kaiser (line 348) | def kaiser(x, beta):
function bincount (line 353) | def bincount(x, weights=None, minlength=0, sparse=False):
function bitwise_and (line 384) | def bitwise_and(x, y):
function bitwise_invert (line 393) | def bitwise_invert(x):
function bitwise_not (line 398) | def bitwise_not(x):
function bitwise_or (line 402) | def bitwise_or(x, y):
function bitwise_xor (line 411) | def bitwise_xor(x, y):
function bitwise_left_shift (line 420) | def bitwise_left_shift(x, y):
function left_shift (line 430) | def left_shift(x, y):
function bitwise_right_shift (line 434) | def bitwise_right_shift(x, y):
function right_shift (line 444) | def right_shift(x, y):
function blackman (line 448) | def blackman(x):
function broadcast_to (line 453) | def broadcast_to(x, shape):
function cbrt (line 457) | def cbrt(x):
function ceil (line 469) | def ceil(x):
function clip (line 479) | def clip(x, x_min, x_max):
function concatenate (line 487) | def concatenate(xs, axis=0):
function conjugate (line 498) | def conjugate(x):
function conj (line 502) | def conj(x):
function copy (line 506) | def copy(x):
function cos (line 510) | def cos(x):
function cosh (line 520) | def cosh(x):
function count_nonzero (line 530) | def count_nonzero(x, axis=None):
function cross (line 537) | def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
function cumprod (line 554) | def cumprod(x, axis=None, dtype=None):
function cumsum (line 562) | def cumsum(x, axis=None, dtype=None):
function deg2rad (line 570) | def deg2rad(x):
function diag (line 583) | def diag(x, k=0):
function diagflat (line 587) | def diagflat(x, k=0):
function diagonal (line 591) | def diagonal(x, offset=0, axis1=0, axis2=1):
function diff (line 597) | def diff(a, n=1, axis=-1):
function digitize (line 601) | def digitize(x, bins):
function dot (line 605) | def dot(x1, x2):
function dstack (line 614) | def dstack(xs):
function empty (line 624) | def empty(shape, dtype=None):
function empty_like (line 629) | def empty_like(x, dtype=None):
function equal (line 633) | def equal(x1, x2):
function exp (line 637) | def exp(x):
function exp2 (line 645) | def exp2(x):
function expand_dims (line 653) | def expand_dims(x, axis):
function expm1 (line 658) | def expm1(x):
function flip (line 666) | def flip(x, axis=None):
function floor (line 671) | def floor(x):
function full (line 682) | def full(shape, fill_value, dtype=None):
function full_like (line 687) | def full_like(x, fill_value, dtype=None):
function gcd (line 691) | def gcd(x1, x2):
function geomspace (line 699) | def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
function greater (line 706) | def greater(x1, x2):
function greater_equal (line 710) | def greater_equal(x1, x2):
function hstack (line 714) | def hstack(xs):
function hsplit (line 724) | def hsplit(x, indices_or_sections):
function hypot (line 729) | def hypot(x1, x2):
function identity (line 742) | def identity(n, dtype=None):
function imag (line 747) | def imag(x):
function isclose (line 751) | def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):
function isfinite (line 755) | def isfinite(x):
function isin (line 759) | def isin(x1, x2, assume_unique=False, invert=False):
function isinf (line 765) | def isinf(x):
function isnan (line 769) | def isnan(x):
function isneginf (line 773) | def isneginf(x):
function isposinf (line 778) | def isposinf(x):
function isreal (line 783) | def isreal(x):
function kron (line 788) | def kron(x1, x2):
function lcm (line 795) | def lcm(x1, x2):
function ldexp (line 802) | def ldexp(x1, x2):
function less (line 815) | def less(x1, x2):
function less_equal (line 819) | def less_equal(x1, x2):
function linspace (line 823) | def linspace(
function log (line 845) | def log(x):
function log10 (line 855) | def log10(x):
function log1p (line 865) | def log1p(x):
function log2 (line 875) | def log2(x):
function logaddexp (line 885) | def logaddexp(x1, x2):
function logaddexp2 (line 894) | def logaddexp2(x1, x2):
function logical_and (line 901) | def logical_and(x1, x2):
function logical_not (line 905) | def logical_not(x):
function logical_or (line 909) | def logical_or(x1, x2):
function logspace (line 913) | def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, ax...
function maximum (line 932) | def maximum(x1, x2):
function median (line 946) | def median(x, axis=None, keepdims=False):
function meshgrid (line 951) | def meshgrid(*x, indexing="xy"):
function min (line 955) | def min(x, axis=None, keepdims=False, initial=None):
function minimum (line 960) | def minimum(x1, x2):
function mod (line 974) | def mod(x1, x2):
function fmod (line 985) | def fmod(x1, x2):
function moveaxis (line 996) | def moveaxis(x, source, destination):
function nanargmax (line 1000) | def nanargmax(x, axis=None, keepdims=False):
function nanargmin (line 1015) | def nanargmin(x, axis=None, keepdims=False):
function nancumsum (line 1030) | def nancumsum(x, axis=None, dtype=None):
function nancumprod (line 1038) | def nancumprod(x, axis=None, dtype=None):
function nanmax (line 1046) | def nanmax(x, axis=None, keepdims=False):
function nanmean (line 1050) | def nanmean(x, axis=None, keepdims=False):
function nanmin (line 1055) | def nanmin(x, axis=None, keepdims=False):
function nanprod (line 1059) | def nanprod(x, axis=None, keepdims=False):
function nanstd (line 1072) | def nanstd(x, axis=None, keepdims=False):
function nansum (line 1082) | def nansum(x, axis=None, keepdims=False):
function nanvar (line 1093) | def nanvar(x, axis=None, keepdims=False):
function nan_to_num (line 1103) | def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
function ndim (line 1107) | def ndim(x):
function nonzero (line 1111) | def nonzero(x):
function not_equal (line 1115) | def not_equal(x1, x2):
function zeros_like (line 1119) | def zeros_like(x, dtype=None):
function ones_like (line 1123) | def ones_like(x, dtype=None):
function outer (line 1127) | def outer(x1, x2):
function pad (line 1136) | def pad(x, pad_width, mode="constant", constant_values=None):
function prod (line 1149) | def prod(x, axis=None, keepdims=False, dtype=None):
function ptp (line 1161) | def ptp(x, axis=None, keepdims=False):
function quantile (line 1165) | def quantile(x, q, axis=None, method="linear", keepdims=False):
function ravel (line 1182) | def ravel(x):
function unravel_index (line 1186) | def unravel_index(indices, shape):
function real (line 1193) | def real(x):
function reciprocal (line 1197) | def reciprocal(x):
function repeat (line 1201) | def repeat(x, repeats, axis=None):
function reshape (line 1205) | def reshape(x, newshape):
function roll (line 1209) | def roll(x, shift, axis=None):
function searchsorted (line 1213) | def searchsorted(sorted_sequence, values, side="left"):
function sign (line 1229) | def sign(x):
function signbit (line 1233) | def signbit(x):
function sin (line 1237) | def sin(x):
function sinc (line 1247) | def sinc(x):
function sinh (line 1257) | def sinh(x):
function size (line 1267) | def size(x):
function sort (line 1271) | def sort(x, axis=-1):
function split (line 1276) | def split(x, indices_or_sections, axis=0):
function array_split (line 1281) | def array_split(x, indices_or_sections, axis=0):
function stack (line 1286) | def stack(x, axis=0):
function std (line 1295) | def std(x, axis=None, keepdims=False):
function swapaxes (line 1304) | def swapaxes(x, axis1, axis2):
function take (line 1308) | def take(x, indices, axis=None):
function take_along_axis (line 1313) | def take_along_axis(x, indices, axis=None):
function tan (line 1318) | def tan(x):
function tanh (line 1328) | def tanh(x):
function tensordot (line 1338) | def tensordot(x1, x2, axes=2):
function round (line 1348) | def round(x, decimals=0):
function tile (line 1352) | def tile(x, repeats):
function trace (line 1356) | def trace(x, offset=0, axis1=0, axis2=1):
function tri (line 1368) | def tri(N, M=None, k=0, dtype=None):
function tril (line 1373) | def tril(x, k=0):
function triu (line 1377) | def triu(x, k=0):
function trunc (line 1381) | def trunc(x):
function vdot (line 1389) | def vdot(x1, x2):
function inner (line 1398) | def inner(x1, x2):
function vstack (line 1407) | def vstack(xs):
function vsplit (line 1417) | def vsplit(x, indices_or_sections):
function vectorize (line 1422) | def vectorize(pyfunc, *, excluded=None, signature=None):
function where (line 1426) | def where(condition, x1=None, x2=None):
function divide (line 1443) | def divide(x1, x2):
function divide_no_nan (line 1458) | def divide_no_nan(x1, x2):
function true_divide (line 1475) | def true_divide(x1, x2):
function power (line 1479) | def power(x1, x2):
function negative (line 1493) | def negative(x):
function nextafter (line 1497) | def nextafter(x1, x2):
function square (line 1505) | def square(x):
function sqrt (line 1512) | def sqrt(x):
function squeeze (line 1523) | def squeeze(x, axis=None):
function transpose (line 1528) | def transpose(x, axes=None):
function trapezoid (line 1533) | def trapezoid(y, x=None, dx=1.0, axis=-1):
function vander (line 1542) | def vander(x, N=None, increasing=False):
function var (line 1550) | def var(x, axis=None, keepdims=False):
function sum (line 1560) | def sum(x, axis=None, keepdims=False):
function eye (line 1571) | def eye(N, M=None, k=0, dtype=None):
function floor_divide (line 1576) | def floor_divide(x1, x2):
function logical_xor (line 1589) | def logical_xor(x1, x2):
function corrcoef (line 1593) | def corrcoef(x):
function correlate (line 1606) | def correlate(x1, x2, mode="valid"):
function select (line 1621) | def select(condlist, choicelist, default=0):
function slogdet (line 1625) | def slogdet(x):
function argpartition (line 1629) | def argpartition(x, kth, axis=-1):
function histogram (line 1633) | def histogram(x, bins=10, range=None):
FILE: keras/src/backend/numpy/random.py
function normal (line 10) | def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
function uniform (line 17) | def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
function categorical (line 24) | def categorical(logits, num_samples, dtype="int64", seed=None):
function randint (line 36) | def randint(shape, minval, maxval, dtype="int32", seed=None):
function truncated_normal (line 43) | def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
function dropout (line 69) | def dropout(inputs, rate, noise_shape=None, seed=None):
function shuffle (line 97) | def shuffle(x, axis=0, seed=None):
function gamma (line 103) | def gamma(shape, alpha, dtype=None, seed=None):
function binomial (line 110) | def binomial(shape, counts, probabilities, dtype=None, seed=None):
function beta (line 118) | def beta(shape, alpha, beta, dtype=None, seed=None):
FILE: keras/src/backend/numpy/rnn.py
function rnn (line 6) | def rnn(
function lstm (line 203) | def lstm(*args, **kwargs):
function gru (line 207) | def gru(*args, **kwargs):
function unstack (line 211) | def unstack(x, axis=0):
function numpy_scan (line 215) | def numpy_scan(f, init, xs, reverse=False, mask=None):
function cudnn_ok (line 242) | def cudnn_ok(*args, **kwargs):
FILE: keras/src/backend/numpy/trainer.py
class NumpyTrainer (line 16) | class NumpyTrainer(base_trainer.Trainer):
method __init__ (line 17) | def __init__(self):
method test_step (line 22) | def test_step(self, data):
method predict_step (line 40) | def predict_step(self, data):
method make_test_function (line 48) | def make_test_function(self, force=False):
method make_predict_function (line 68) | def make_predict_function(self, force=False):
method _symbolic_build (line 95) | def _symbolic_build(self, data_batch):
method fit (line 151) | def fit(
method predict (line 173) | def predict(
method evaluate (line 226) | def evaluate(
method train_on_batch (line 293) | def train_on_batch(
method test_on_batch (line 305) | def test_on_batch(
method predict_on_batch (line 326) | def predict_on_batch(self, x):
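
`NumpyTrainer` implements `evaluate` and `predict`, but the backend computes no gradients, so training entry points such as `fit` are not usable for actual optimization on this backend. A hedged sketch (model and data are illustrative):

    import numpy as np
    import keras  # assumes KERAS_BACKEND="numpy" was set before import

    model = keras.Sequential([keras.Input((4,)), keras.layers.Dense(2)])
    model.compile(loss="mse")

    x = np.random.uniform(size=(8, 4)).astype("float32")
    preds = model.predict(x, verbose=0)                              # supported
    loss = model.evaluate(x, np.zeros((8, 2), "float32"), verbose=0) # supported
    # model.fit(...)  # no gradient computation on the numpy backend
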
FILE: keras/src/backend/openvino/core.py
function align_operand_types (line 78) | def align_operand_types(x1, x2, op_name):
function get_ov_output (line 99) | def get_ov_output(x, ov_type=None):
class OpenVINOKerasTensor (line 142) | class OpenVINOKerasTensor:
method __init__ (line 143) | def __init__(self, x, data=None):
method __add__ (line 165) | def __add__(self, other):
method __radd__ (line 173) | def __radd__(self, other):
method __sub__ (line 181) | def __sub__(self, other):
method __rsub__ (line 193) | def __rsub__(self, other):
method __mul__ (line 201) | def __mul__(self, other):
method __rmul__ (line 213) | def __rmul__(self, other):
method __truediv__ (line 225) | def __truediv__(self, other):
method __rtruediv__ (line 233) | def __rtruediv__(self, other):
method __floordiv__ (line 241) | def __floordiv__(self, other):
method __rfloordiv__ (line 249) | def __rfloordiv__(self, other):
method __neg__ (line 257) | def __neg__(self):
method __abs__ (line 261) | def __abs__(self):
method __invert__ (line 265) | def __invert__(self):
method __pow__ (line 269) | def __pow__(self, other):
method __rpow__ (line 277) | def __rpow__(self, other):
method __lt__ (line 285) | def __lt__(self, other):
method __gt__ (line 293) | def __gt__(self, other):
method __le__ (line 301) | def __le__(self, other):
method __ge__ (line 309) | def __ge__(self, other):
method __eq__ (line 319) | def __eq__(self, other):
method __ne__ (line 327) | def __ne__(self, other):
method __getitem__ (line 335) | def __getitem__(self, indices):
method __len__ (line 494) | def __len__(self):
method __iter__ (line 509) | def __iter__(self):
method __bool__ (line 515) | def __bool__(self):
method __mod__ (line 518) | def __mod__(self, other):
method __array__ (line 526) | def __array__(self, dtype=None):
method numpy (line 538) | def numpy(self):
method __rmod__ (line 541) | def __rmod__(self, other):
method __matmul__ (line 549) | def __matmul__(self, other):
method __rmatmul__ (line 559) | def __rmatmul__(self, other):
method __div__ (line 569) | def __div__(self, other):
method __rdiv__ (line 572) | def __rdiv__(self, other):
method __and__ (line 575) | def __and__(self, other):
method __rand__ (line 583) | def __rand__(self, other):
method __or__ (line 591) | def __or__(self, other):
method __ror__ (line 599) | def __ror__(self, other):
method __xor__ (line 607) | def __xor__(self, other):
method __rxor__ (line 615) | def __rxor__(self, other):
method __int__ (line 623) | def __int__(self):
method __float__ (line 632) | def __float__(self):
method __repr__ (line 641) | def __repr__(self):
method __round__ (line 644) | def __round__(self, ndigits=None):
method reshape (line 656) | def reshape(self, new_shape):
method squeeze (line 663) | def squeeze(self, axis=None):
function ov_to_keras_type (line 674) | def ov_to_keras_type(ov_type):
function device_scope (line 684) | def device_scope(device_name):
function get_device (line 688) | def get_device():
class Variable (line 692) | class Variable(KerasVariable):
method _initialize (line 693) | def _initialize(self, value):
method _direct_assign (line 707) | def _direct_assign(self, value):
method _convert_to_tensor (line 710) | def _convert_to_tensor(self, value, dtype=None):
method __array__ (line 713) | def __array__(self):
method __getitem__ (line 716) | def __getitem__(self, idx):
method __int__ (line 720) | def __int__(self):
method __float__ (line 729) | def __float__(self):
function _is_scalar (line 739) | def _is_scalar(elem):
function convert_to_tensor (line 743) | def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
function convert_to_numpy (line 799) | def convert_to_numpy(x):
function is_tensor (line 843) | def is_tensor(x):
function shape (line 851) | def shape(x):
function cast (line 855) | def cast(x, dtype):
function cond (line 862) | def cond(pred, true_fn, false_fn):
function vectorized_map (line 891) | def vectorized_map(function, elements):
function compute_output_spec (line 896) | def compute_output_spec(fn, *args, **kwargs):
function map (line 933) | def map(f, xs):
function scan (line 941) | def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
function associative_scan (line 1009) | def associative_scan(f, elems, reverse=False, axis=0):
function scatter (line 1150) | def scatter(indices, values, shape):
function scatter_update (line 1162) | def scatter_update(inputs, indices, updates, reduction=None):
function slice (line 1188) | def slice(inputs, start_indices, shape):
function slice_update (line 1254) | def slice_update(inputs, start_indices, updates):
function switch (line 1419) | def switch(index, branches, *operands):
function while_loop (line 1498) | def while_loop(
function fori_loop (line 1612) | def fori_loop(lower, upper, body_fun, init_val):
function stop_gradient (line 1620) | def stop_gradient(variable):
function unstack (line 1624) | def unstack(x, num=None, axis=0):
function random_seed_dtype (line 1640) | def random_seed_dtype():
function custom_gradient (line 1648) | def custom_gradient(fun):
function remat (line 1670) | def remat(f):
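
The OpenVINO backend is inference-oriented: `OpenVINOKerasTensor` wraps graph output nodes, `convert_to_numpy` materializes the accumulated subgraph, and training-only primitives are unsupported. A hedged sketch of the intended use (the architecture is illustrative; weights would typically be trained on another backend and loaded here):

    import os
    os.environ["KERAS_BACKEND"] = "openvino"  # inference-only backend

    import numpy as np
    import keras

    model = keras.Sequential([keras.Input((4,)), keras.layers.Dense(2)])
    x = np.random.uniform(size=(8, 4)).astype("float32")
    preds = model.predict(x, verbose=0)
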
FILE: keras/src/backend/openvino/export.py
class OpenvinoExportArchive (line 1) | class OpenvinoExportArchive:
method track (line 2) | def track(self, resource):
method add_endpoint (line 7) | def add_endpoint(self, name, fn, input_signature=None, **kwargs):
FILE: keras/src/backend/openvino/image.py
function rgb_to_grayscale (line 8) | def rgb_to_grayscale(images, data_format=None):
function rgb_to_hsv (line 48) | def rgb_to_hsv(images, data_format=None):
function hsv_to_rgb (line 142) | def hsv_to_rgb(images, data_format=None):
function resize (line 211) | def resize(
function affine_transform (line 225) | def affine_transform(
function perspective_transform (line 238) | def perspective_transform(
function map_coordinates (line 251) | def map_coordinates(
function gaussian_blur (line 259) | def gaussian_blur(
function elastic_transform (line 267) | def elastic_transform(
function scale_and_translate (line 282) | def scale_and_translate(
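
These functions implement the `keras.ops.image` namespace for the OpenVINO backend (several of the later entries may be unimplemented stubs here). A short usage sketch against the public API:

import numpy as np
from keras import ops

images = np.random.uniform(size=(2, 64, 64, 3)).astype("float32")

gray = ops.image.rgb_to_grayscale(images)        # -> (2, 64, 64, 1)
small = ops.image.resize(images, size=(32, 32))  # bilinear by default
hsv = ops.image.rgb_to_hsv(images)               # expects floats in [0, 1]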
FILE: keras/src/backend/openvino/layer.py
class OpenvinoLayer (line 1) | class OpenvinoLayer:
FILE: keras/src/backend/openvino/linalg.py
function cholesky (line 15) | def cholesky(a, upper=False):
function cholesky_inverse (line 21) | def cholesky_inverse(a, upper=False):
function det (line 38) | def det(a):
function eig (line 155) | def eig(a):
function eigh (line 159) | def eigh(a):
function inv (line 429) | def inv(a):
function lu_factor (line 436) | def lu_factor(a):
function norm (line 442) | def norm(x, ord=None, axis=None, keepdims=False):
function qr (line 614) | def qr(x, mode="reduced"):
function solve (line 618) | def solve(a, b):
function solve_triangular (line 635) | def solve_triangular(a, b, lower=False):
function svd (line 641) | def svd(x, full_matrices=True, compute_uv=True):
function lstsq (line 645) | def lstsq(a, b, rcond=None):
function jvp (line 649) | def jvp(fun, primals, tangents, has_aux=False):
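
The linear-algebra entries mirror the `keras.ops` linalg functions. A small sketch, assuming the usual NumPy-style semantics of `cholesky` and `solve`:

import numpy as np
from keras import ops

a = np.array([[4.0, 2.0], [2.0, 3.0]], dtype="float32")  # symmetric positive definite
b = np.array([[1.0], [2.0]], dtype="float32")

l = ops.cholesky(a)              # lower-triangular factor: a == l @ l.T
x = ops.solve(a, b)              # solves a @ x == b
residual = ops.matmul(a, x) - b  # should be ~0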
FILE: keras/src/backend/openvino/math.py
function _segment_reduction_fn (line 15) | def _segment_reduction_fn(
function segment_sum (line 103) | def segment_sum(data, segment_ids, num_segments=None, sorted=False):
function segment_max (line 107) | def segment_max(data, segment_ids, num_segments=None, sorted=False):
function top_k (line 111) | def top_k(x, k, sorted=True):
function in_top_k (line 122) | def in_top_k(targets, predictions, k):
function logsumexp (line 147) | def logsumexp(x, axis=None, keepdims=False):
function qr (line 172) | def qr(x, mode="reduced"):
function extract_sequences (line 176) | def extract_sequences(x, sequence_length, sequence_stride):
function _dft (line 241) | def _dft(x, axes_offsets, inverse=False):
function fft (line 284) | def fft(x):
function fft2 (line 289) | def fft2(x):
function ifft2 (line 294) | def ifft2(x):
function rfft (line 299) | def rfft(x, fft_length=None):
function irfft (line 339) | def irfft(x, fft_length=None):
function stft (line 360) | def stft(
function _overlap_sequences_ov (line 462) | def _overlap_sequences_ov(x, sequence_stride, fft_length):
function istft (line 616) | def istft(
function rsqrt (line 764) | def rsqrt(x):
function erf (line 771) | def erf(x):
function erfinv (line 777) | def erfinv(x):
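
math.py backs the `keras.ops` segment reductions, the FFT family, and related special functions. A sketch of two of the public entry points, `segment_sum` and `top_k`:

import numpy as np
from keras import ops

data = np.array([1.0, 2.0, 3.0, 4.0], dtype="float32")
segment_ids = np.array([0, 0, 1, 1], dtype="int32")

sums = ops.segment_sum(data, segment_ids, num_segments=2)  # [3., 7.]

values, indices = ops.top_k(
    np.array([1.0, 5.0, 3.0], dtype="float32"), k=2
)  # values == [5., 3.], indices == [1, 2]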
FILE: keras/src/backend/openvino/nn.py
function relu (line 15) | def relu(x):
function relu6 (line 20) | def relu6(x):
function celu (line 25) | def celu(x, alpha=1.0):
function sigmoid (line 42) | def sigmoid(x):
function tanh (line 47) | def tanh(x):
function tanh_shrink (line 52) | def tanh_shrink(x):
function hard_tanh (line 57) | def hard_tanh(x):
function soft_shrink (line 62) | def soft_shrink(x, threshold=0.5):
function hard_shrink (line 75) | def hard_shrink(x, threshold=0.5):
function softplus (line 85) | def softplus(x):
function softsign (line 90) | def softsign(x):
function silu (line 95) | def silu(x):
function log_sigmoid (line 101) | def log_sigmoid(x):
function leaky_relu (line 109) | def leaky_relu(x, negative_slope=0.2):
function sparse_sigmoid (line 118) | def sparse_sigmoid(x):
function hard_sigmoid (line 129) | def hard_sigmoid(x):
function hard_silu (line 136) | def hard_silu(x):
function elu (line 144) | def elu(x, alpha=1.0):
function selu (line 149) | def selu(x):
function gelu (line 158) | def gelu(x, approximate=True):
function softmax (line 166) | def softmax(x, axis=-1):
function log_softmax (line 187) | def log_softmax(x, axis=-1):
function squareplus (line 214) | def squareplus(x, b=4):
function sparse_plus (line 227) | def sparse_plus(x):
function threshold (line 246) | def threshold(x, threshold, default_value):
function max_pool (line 256) | def max_pool(
function average_pool (line 278) | def average_pool(
function _compute_adaptive_gather_indices (line 296) | def _compute_adaptive_gather_indices(
function _adaptive_pool_ov (line 319) | def _adaptive_pool_ov(
function adaptive_average_pool (line 418) | def adaptive_average_pool(inputs, output_size, data_format=None):
function adaptive_max_pool (line 430) | def adaptive_max_pool(inputs, output_size, data_format=None):
function _pool (line 442) | def _pool(
function _adjust_strides_dilation (line 477) | def _adjust_strides_dilation(
function _adjust_padding (line 489) | def _adjust_padding(
function _adjust_input (line 507) | def _adjust_input(inputs, num_spatial_dims, data_format):
function _adjust_kernel (line 520) | def _adjust_kernel(kernel, num_spatial_dims):
function _adjust_depthwise_kernel (line 531) | def _adjust_depthwise_kernel(kernel, num_spatial_dims):
function _adjust_outputs (line 546) | def _adjust_outputs(outputs, num_spatial_dims, data_format):
function conv (line 560) | def conv(
function depthwise_conv (line 641) | def depthwise_conv(
function separable_conv (line 679) | def separable_conv(
function conv_transpose (line 707) | def conv_transpose(
function one_hot (line 782) | def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
function multi_hot (line 801) | def multi_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
function categorical_crossentropy (line 812) | def categorical_crossentropy(target, output, from_logits=False, axis=-1):
function sparse_categorical_crossentropy (line 844) | def sparse_categorical_crossentropy(target, output, from_logits=False, a...
function binary_crossentropy (line 894) | def binary_crossentropy(target, output, from_logits=False):
function moments (line 938) | def moments(x, axes, keepdims=False, synchronized=False):
function batch_normalization (line 964) | def batch_normalization(
function ctc_loss (line 1007) | def ctc_loss(target, output, target_length, output_length, mask_index=0):
function ctc_decode (line 1019) | def ctc_decode(
function psnr (line 1033) | def psnr(x1, x2, max_val):
function dot_product_attention (line 1057) | def dot_product_attention(
function unfold (line 1110) | def unfold(input, kernel_size, dilation=1, padding=0, stride=1):
function fold (line 1142) | def fold(x, output_size, kernel_size, dilation=1, padding=0, stride=1):
function depth_to_space (line 1213) | def depth_to_space(x, block_size, data_format="channels_last"):
function space_to_depth (line 1249) | def space_to_depth(x, block_size, data_format="channels_last"):
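
nn.py provides the activation, pooling, convolution, and loss kernels behind `keras.ops`. A minimal sketch combining `softmax`, `one_hot`, and `categorical_crossentropy`:

import numpy as np
from keras import ops

logits = np.array([[2.0, 1.0, 0.1]], dtype="float32")

probs = ops.softmax(logits, axis=-1)                # rows sum to 1
target = ops.one_hot(np.array([0]), num_classes=3)  # [[1., 0., 0.]]
loss = ops.categorical_crossentropy(target, logits, from_logits=True)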
FILE: keras/src/backend/openvino/numpy.py
function _promote_binary_op_types (line 23) | def _promote_binary_op_types(x1, x2):
function add (line 40) | def add(x1, x2):
function einsum (line 48) | def einsum(subscripts, *operands, **kwargs):
function subtract (line 80) | def subtract(x1, x2):
function matmul (line 88) | def matmul(x1, x2):
function multiply (line 110) | def multiply(x1, x2):
function mean (line 118) | def mean(x, axis=None, keepdims=False):
function max (line 146) | def max(x, axis=None, keepdims=False, initial=None):
function _compute_extrema (line 150) | def _compute_extrema(x, operation, axis=None, keepdims=False, initial=No...
function ones (line 196) | def ones(shape, dtype=None):
function zeros (line 209) | def zeros(shape, dtype=None):
function absolute (line 222) | def absolute(x):
function abs (line 230) | def abs(x):
function all (line 235) | def all(x, axis=None, keepdims=False):
function allclose (line 246) | def allclose(x1, x2, rtol=1e-05, atol=1e-08, equal_nan=False):
function angle (line 268) | def angle(x):
function any (line 286) | def any(x, axis=None, keepdims=False):
function amax (line 297) | def amax(x, axis=None, keepdims=False):
function amin (line 310) | def amin(x, axis=None, keepdims=False):
function _resolve_axis (line 323) | def _resolve_axis(x, axis):
function _upcast_type_if_needed (line 336) | def _upcast_type_if_needed(x):
function append (line 347) | def append(x1, x2, axis=None):
function arange (line 358) | def arange(start, stop=None, step=None, dtype=None):
function arccos (line 388) | def arccos(x):
function arccosh (line 397) | def arccosh(x):
function arcsin (line 406) | def arcsin(x):
function arcsinh (line 415) | def arcsinh(x):
function arctan (line 424) | def arctan(x):
function arctan2 (line 433) | def arctan2(x1, x2):
function arctanh (line 483) | def arctanh(x):
function argmax (line 492) | def argmax(x, axis=None, keepdims=False):
function argmin (line 524) | def argmin(x, axis=None, keepdims=False):
function argsort (line 556) | def argsort(x, axis=-1):
function array (line 589) | def array(x, dtype=None):
function view (line 593) | def view(x, dtype=None):
function average (line 597) | def average(x, axis=None, weights=None):
function bartlett (line 629) | def bartlett(x):
function hamming (line 654) | def hamming(x):
function hanning (line 686) | def hanning(x):
function heaviside (line 729) | def heaviside(x1, x2):
function _i0_node (line 767) | def _i0_node(x):
function kaiser (line 813) | def kaiser(x, beta):
function bitwise_left_shift (line 853) | def bitwise_left_shift(x, y):
function left_shift (line 865) | def left_shift(x, y):
function bitwise_right_shift (line 869) | def bitwise_right_shift(x, y):
function right_shift (line 881) | def right_shift(x, y):
function bincount (line 885) | def bincount(x, weights=None, minlength=0, sparse=False):
function _bitwise_op_i8u8 (line 930) | def _bitwise_op_i8u8(ov_op, x, y):
function bitwise_and (line 949) | def bitwise_and(x, y):
function bitwise_xor (line 956) | def bitwise_xor(x, y):
function bitwise_invert (line 963) | def bitwise_invert(x):
function bitwise_not (line 968) | def bitwise_not(x):
function bitwise_or (line 972) | def bitwise_or(x, y):
function blackman (line 979) | def blackman(x):
function broadcast_to (line 1006) | def broadcast_to(x, shape):
function cbrt (line 1017) | def cbrt(x):
function ceil (line 1030) | def ceil(x):
function clip (line 1039) | def clip(x, x_min, x_max):
function concatenate (line 1051) | def concatenate(xs, axis=0):
function conjugate (line 1073) | def conjugate(x):
function conj (line 1079) | def conj(x):
function copy (line 1083) | def copy(x):
function cos (line 1087) | def cos(x):
function cosh (line 1096) | def cosh(x):
function count_nonzero (line 1105) | def count_nonzero(x, axis=None):
function cross (line 1117) | def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
function cumprod (line 1192) | def cumprod(x, axis=None, dtype=None):
function cumsum (line 1280) | def cumsum(x, axis=None, dtype=None):
function deg2rad (line 1291) | def deg2rad(x):
function diag (line 1312) | def diag(x, k=0):
function diagflat (line 1379) | def diagflat(x, k=0):
function diagonal (line 1441) | def diagonal(x, offset=0, axis1=0, axis2=1):
function diff (line 1491) | def diff(a, n=1, axis=-1):
function digitize (line 1564) | def digitize(x, bins):
function dot (line 1591) | def dot(x1, x2):
function dstack (line 1605) | def dstack(xs):
function empty (line 1647) | def empty(shape, dtype=None):
function empty_like (line 1660) | def empty_like(x, dtype=None):
function equal (line 1664) | def equal(x1, x2):
function exp (line 1676) | def exp(x):
function exp2 (line 1685) | def exp2(x):
function expand_dims (line 1696) | def expand_dims(x, axis):
function expm1 (line 1704) | def expm1(x):
function flip (line 1716) | def flip(x, axis=None):
function rot90 (line 1752) | def rot90(array, k=1, axes=(0, 1)):
function floor (line 1807) | def floor(x):
function full (line 1815) | def full(shape, fill_value, dtype=None):
function full_like (line 1827) | def full_like(x, fill_value, dtype=None):
function gcd (line 1839) | def gcd(x1, x2):
function geomspace (line 1897) | def geomspace(start, stop, num=50, endpoint=True, dtype=None, axis=0):
function greater (line 1941) | def greater(x1, x2):
function greater_equal (line 1953) | def greater_equal(x1, x2):
function hstack (line 1965) | def hstack(xs):
function hsplit (line 1980) | def hsplit(x, indices_or_sections):
function hypot (line 1987) | def hypot(x1, x2):
function identity (line 2018) | def identity(n, dtype=None):
function imag (line 2032) | def imag(x):
function inner (line 2038) | def inner(x1, x2):
function isclose (line 2060) | def isclose(x1, x2, rtol=1e-5, atol=1e-8, equal_nan=False):
function isfinite (line 2079) | def isfinite(x):
function isin (line 2092) | def isin(x1, x2, assume_unique=False, invert=False):
function isinf (line 2114) | def isinf(x):
function isnan (line 2121) | def isnan(x):
function isneginf (line 2129) | def isneginf(x):
function isposinf (line 2133) | def isposinf(x):
function _is_inf (line 2137) | def _is_inf(x, pos=True):
function isreal (line 2171) | def isreal(x):
function kron (line 2177) | def kron(x1, x2):
function lcm (line 2240) | def lcm(x1, x2):
function ldexp (line 2264) | def ldexp(x1, x2):
function less (line 2286) | def less(x1, x2):
function less_equal (line 2298) | def less_equal(x1, x2):
function linspace (line 2310) | def linspace(
function log (line 2440) | def log(x):
function log10 (line 2449) | def log10(x):
function log1p (line 2462) | def log1p(x):
function log2 (line 2476) | def log2(x):
function logaddexp (line 2489) | def logaddexp(x1, x2):
function logaddexp2 (line 2539) | def logaddexp2(x1, x2):
function logical_and (line 2578) | def logical_and(x1, x2):
function logical_not (line 2586) | def logical_not(x):
function logical_or (line 2592) | def logical_or(x1, x2):
function logspace (line 2600) | def logspace(start, stop, num=50, endpoint=True, base=10, dtype=None, ax...
function maximum (line 2626) | def maximum(x1, x2):
function median (line 2632) | def median(x, axis=None, keepdims=False):
function meshgrid (line 2767) | def meshgrid(*x, indexing="xy"):
function min (line 2811) | def min(x, axis=None, keepdims=False, initial=None):
function minimum (line 2815) | def minimum(x1, x2):
function mod (line 2821) | def mod(x1, x2):
function fmod (line 2828) | def fmod(x1, x2):
function moveaxis (line 2835) | def moveaxis(x, source, destination):
function nanargmax (line 2855) | def nanargmax(x, axis=None, keepdims=False):
function nanargmin (line 2899) | def nanargmin(x, axis=None, keepdims=False):
function nancumsum (line 2943) | def nancumsum(x, axis=None, dtype=None):
function nancumprod (line 2947) | def nancumprod(x, axis=None, dtype=None):
function nanmax (line 2951) | def nanmax(x, axis=None, keepdims=False):
function nanmean (line 2986) | def nanmean(x, axis=None, keepdims=False):
function nanmin (line 3017) | def nanmin(x, axis=None, keepdims=False):
function nanprod (line 3052) | def nanprod(x, axis=None, keepdims=False):
function nanstd (line 3077) | def nanstd(x, axis=None, keepdims=False):
function nansum (line 3081) | def nansum(x, axis=None, keepdims=False):
function nanvar (line 3100) | def nanvar(x, axis=None, keepdims=False):
function nan_to_num (line 3161) | def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
function ndim (line 3194) | def ndim(x):
function nonzero (line 3201) | def nonzero(x):
function not_equal (line 3207) | def not_equal(x1, x2):
function zeros_like (line 3219) | def zeros_like(x, dtype=None):
function ones_like (line 3231) | def ones_like(x, dtype=None):
function outer (line 3243) | def outer(x1, x2):
function pad (line 3261) | def pad(x, pad_width, mode="constant", constant_values=None):
function prod (line 3294) | def prod(x, axis=None, keepdims=False, dtype=None):
function ptp (line 3313) | def ptp(x, axis=None, keepdims=False):
function quantile (line 3326) | def quantile(x, q, axis=None, method="linear", keepdims=False):
function ravel (line 3477) | def ravel(x):
function real (line 3485) | def real(x):
function reciprocal (line 3491) | def reciprocal(x):
function repeat (line 3498) | def repeat(x, repeats, axis=None):
function reshape (line 3557) | def reshape(x, newshape):
function roll (line 3572) | def roll(x, shift, axis=None):
function searchsorted (line 3586) | def searchsorted(sorted_sequence, values, side="left"):
function sign (line 3620) | def sign(x):
function signbit (line 3625) | def signbit(x):
function sin (line 3640) | def sin(x):
function sinc (line 3649) | def sinc(x):
function sinh (line 3667) | def sinh(x):
function size (line 3676) | def size(x):
function sort (line 3687) | def sort(x, axis=-1):
function split (line 3727) | def split(x, indices_or_sections, axis=0):
function array_split (line 3778) | def array_split(x, indices_or_sections, axis=0):
function stack (line 3809) | def stack(x, axis=0):
function std (line 3827) | def std(x, axis=None, keepdims=False):
function swapaxes (line 3833) | def swapaxes(x, axis1, axis2):
function take (line 3850) | def take(x, indices, axis=None):
function take_along_axis (line 3862) | def take_along_axis(x, indices, axis=None):
function tan (line 3933) | def tan(x):
function tanh (line 3942) | def tanh(x):
function tensordot (line 3951) | def tensordot(x1, x2, axes=2):
function round (line 4030) | def round(x, decimals=0):
function trunc (line 4050) | def trunc(x):
function tile (line 4062) | def tile(x, repeats):
function trace (line 4096) | def trace(x, offset=0, axis1=0, axis2=1):
function tri (line 4101) | def tri(N, M=None, k=0, dtype=None):
function tril (line 4151) | def tril(x, k=0):
function triu (line 4169) | def triu(x, k=0):
function vdot (line 4192) | def vdot(x1, x2):
function vstack (line 4209) | def vstack(xs):
function vsplit (line 4223) | def vsplit(x, indices_or_sections):
function vectorize (line 4227) | def vectorize(pyfunc, *, excluded=None, signature=None):
function where (line 4235) | def where(condition, x1=None, x2=None):
function divide (line 4264) | def divide(x1, x2):
function divide_no_nan (line 4281) | def divide_no_nan(x1, x2):
function true_divide (line 4298) | def true_divide(x1, x2):
function power (line 4302) | def power(x1, x2):
function negative (line 4320) | def negative(x):
function nextafter (line 4325) | def nextafter(x1, x2):
function square (line 4384) | def square(x):
function sqrt (line 4393) | def sqrt(x):
function squeeze (line 4402) | def squeeze(x, axis=None):
function transpose (line 4415) | def transpose(x, axes=None):
function _helper_trapezoid (line 4435) | def _helper_trapezoid(y, axis):
function trapezoid (line 4467) | def trapezoid(y, x=None, dx=1.0, axis=-1):
function unravel_index (line 4500) | def unravel_index(indices, shape):
function vander (line 4535) | def vander(x, N=None, increasing=False):
function var (line 4577) | def var(x, axis=None, keepdims=False):
function sum (line 4608) | def sum(x, axis=None, keepdims=False):
function eye (line 4618) | def eye(N, M=None, k=0, dtype=None):
function floor_divide (line 4633) | def floor_divide(x1, x2):
function logical_xor (line 4652) | def logical_xor(x1, x2):
function corrcoef (line 4660) | def corrcoef(x):
function correlate (line 4687) | def correlate(x1, x2, mode="valid"):
function select (line 4744) | def select(condlist, choicelist, default=0):
function slogdet (line 4761) | def slogdet(x):
function argpartition (line 4957) | def argpartition(x, kth, axis=-1):
function histogram (line 5023) | def histogram(x, bins=10, range=None):
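
numpy.py is the bulk of the backend: each NumPy-style `keras.ops` function is implemented by composing OpenVINO opset nodes rather than computing eagerly. Usage through the public API is backend-independent, e.g.:

import numpy as np
from keras import ops

x = np.arange(6, dtype="float32").reshape(2, 3)

m = ops.mean(x, axis=0)                       # column means
y = ops.matmul(x, ops.transpose(x))           # (2, 2) Gram matrix
z = ops.where(x > 2.0, x, ops.zeros_like(x))  # keep entries above threshold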
FILE: keras/src/backend/openvino/random.py
function _rng_from_seed_data (line 16) | def _rng_from_seed_data(seed_data):
function _random_uniform (line 29) | def _random_uniform(shape, minval, maxval, dtype, seed1, seed2):
function normal (line 55) | def normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
function uniform (line 63) | def uniform(shape, minval=0.0, maxval=1.0, dtype=None, seed=None):
function categorical (line 75) | def categorical(logits, num_samples, dtype="int64", seed=None):
function randint (line 133) | def randint(shape, minval, maxval, dtype="int32", seed=None):
function truncated_normal (line 167) | def truncated_normal(shape, mean=0.0, stddev=1.0, dtype=None, seed=None):
function dropout (line 194) | def dropout(inputs, rate, noise_shape=None, seed=None):
function shuffle (line 266) | def shuffle(x, axis=0, seed=None):
function _const (line 295) | def _const(val, dtype):
function _random_normal (line 303) | def _random_normal(shape, dtype, seed1, seed2):
function gamma (line 320) | def gamma(shape, alpha, dtype=None, seed=None):
function binomial (line 396) | def binomial(shape, counts, probabilities, dtype=None, seed=None):
function beta (line 445) | def beta(shape, alpha, beta, dtype=None, seed=None):
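
random.py implements the `keras.random` ops on top of a seed-derived generator. A sketch using a `SeedGenerator` so repeated calls draw from a reproducible stream:

from keras import random

seed = random.SeedGenerator(42)  # stateful seed; advances on each draw

u = random.uniform((2, 3), minval=0.0, maxval=1.0, seed=seed)
n = random.normal((2, 3), mean=0.0, stddev=1.0, seed=seed)
d = random.dropout(u, rate=0.5, seed=seed)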
FILE: keras/src/backend/openvino/rnn.py
function rnn (line 10) | def rnn(
function _reorder_gates (line 296) | def _reorder_gates(x_ov, from_order, to_order, axis):
function _seq_lengths (line 311) | def _seq_lengths(inputs_ov):
function lstm (line 327) | def lstm(
function gru (line 432) | def gru(
function cudnn_ok (line 558) | def cudnn_ok(*args, **kwargs):
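
The `lstm`/`gru` kernels above are what the recurrent layers dispatch to at run time; from user code they are reached through the layer API:

import numpy as np
import keras

x = np.random.uniform(size=(4, 10, 8)).astype("float32")  # (batch, time, features)
layer = keras.layers.LSTM(16, return_sequences=True)
y = layer(x)  # shape (4, 10, 16)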
FILE: keras/src/backend/openvino/trainer.py
class OpenVINOTrainer (line 18) | class OpenVINOTrainer(base_trainer.Trainer):
method __init__ (line 19) | def __init__(self):
method _unpack_singleton (line 28) | def _unpack_singleton(self, x):
method test_step (line 33) | def test_step(self, data):
method predict_step (line 50) | def predict_step(self, data):
method make_test_function (line 61) | def make_test_function(self, force=False):
method _parameterize_data (line 81) | def _parameterize_data(self, data):
method _get_data_shapes (line 107) | def _get_data_shapes(self, data):
method _get_compiled_model (line 116) | def _get_compiled_model(self, data):
method make_predict_function (line 147) | def make_predict_function(self, force=False):
method fit (line 174) | def fit(
method predict (line 198) | def predict(
method evaluate (line 252) | def evaluate(
method train_on_batch (line 309) | def train_on_batch(
method test_on_batch (line 321) | def test_on_batch(
method predict_on_batch (line 337) | def predict_on_batch(self, x):
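
The trainer listing suggests an inference-oriented backend: the `predict`/`evaluate` paths compile and run an OpenVINO model, while the training entry points appear to be stubs. A hedged sketch of inference on this backend; `model.keras` is a hypothetical placeholder path, and training calls such as `fit()` should be expected to be unsupported here:

import os
os.environ["KERAS_BACKEND"] = "openvino"  # must be set before importing keras

import numpy as np
import keras

# Hypothetical model file saved from another backend.
model = keras.saving.load_model("model.keras")
preds = model.predict(np.random.uniform(size=(1, 28, 28, 1)).astype("float32"))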
FILE: keras/src/backend/tensorflow/core.py
class Variable (line 27) | class Variable(
method handle (line 35) | def handle(self):
method _initialize (line 38) | def _initialize(self, value):
method _initialize_with_initializer (line 51) | def _initialize_with_initializer(self, initializer):
method _deferred_initialize (line 54) | def _deferred_initialize(self):
method _direct_assign (line 69) | def _direct_assign(self, value):
method _convert_to_tensor (line 72) | def _convert_to_tensor(self, value, dtype=None):
method numpy (line 75) | def numpy(self): # noqa: F811
method shape (line 79) | def shape(self):
method __tf_tensor__ (line 83) | def __tf_tensor__(self, dtype=None, name=None):
method _shared_name (line 88) | def _shared_name(self):
method _serialize_to_tensors (line 91) | def _serialize_to_tensors(self):
method _restore_from_tensors (line 97) | def _restore_from_tensors(self, restored_tensors):
method _copy_trackable_to_cpu (line 104) | def _copy_trackable_to_cpu(self, object_map):
method _export_to_saved_model_graph (line 108) | def _export_to_saved_model_graph(
method _write_object_proto (line 117) | def _write_object_proto(self, proto, options):
method _map_aggregation (line 120) | def _map_aggregation(self, aggregation):
method _map_synchronization (line 129) | def _map_synchronization(self, synchronization):
function convert_to_tensor (line 139) | def convert_to_tensor(x, dtype=None, sparse=None, ragged=None):
function convert_to_numpy (line 164) | def convert_to_numpy(x):
function is_tensor (line 174) | def is_tensor(x):
function shape (line 178) | def shape(x):
function cast (line 209) | def cast(x, dtype):
function compute_output_spec (line 220) | def compute_output_spec(fn, *args, **kwargs):
function cond (line 253) | def cond(pred, true_fn, false_fn):
function vectorized_map (line 261) | def vectorized_map(function, elements):
function map (line 265) | def map(f, xs):
function scan (line 281) | def scan(f, init, xs=None, length=None, reverse=False, unroll=1):
function associative_scan (line 402) | def associative_scan(f, elems, reverse=False, axis=0):
function scatter (line 580) | def scatter(indices, values, shape):
function scatter_update (line 584) | def scatter_update(inputs, indices, updates, reduction=None):
function slice (line 616) | def slice(inputs, start_indices, shape):
function slice_update (line 620) | def slice_update(inputs, start_indices, updates):
function switch (line 624) | def switch(index, branches, *operands):
function while_loop (line 637) | def while_loop(
function fori_loop (line 659) | def fori_loop(lower, upper, body_fun, init_val):
function stop_gradient (line 667) | def stop_gradient(variable):
function unstack (line 671) | def unstack(x, num=None, axis=0):
function random_seed_dtype (line 675) | def random_seed_dtype():
function custom_gradient (line 680) | def custom_gradient(fun):
function remat (line 684) | def remat(f):
class name_scope (line 696) | class name_scope(base_name_scope):
method __init__ (line 697) | def __init__(self, name, **kwargs):
method __enter__ (line 701) | def __enter__(self):
method __exit__ (line 719) | def __exit__(self, *args, **kwargs):
function device_scope (line 725) | def device_scope(device_name):
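
The TensorFlow core.py entries are reachable through backend-agnostic public APIs such as `keras.device` and the tensor-conversion ops. A small sketch:

import numpy as np
import keras
from keras import ops

# keras.device maps onto the backend device_scope shown above.
with keras.device("cpu:0"):
    t = ops.convert_to_tensor(np.ones((2, 2), dtype="float32"))

arr = ops.convert_to_numpy(t)  # back to a NumPy array
print(ops.is_tensor(t), arr.dtype)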
FILE: keras/src/backend/tensorflow/distribute_test.py
class DistributeTest (line 20) | class DistributeTest(testing.TestCase):
method setUp (line 21) | def setUp(self):
method test_variable_creation (line 34) | def test_variable_creation(self):
method test_strategy_run (line 50) | def test_strategy_run(self):
method test_epoch_iterator (line 89) | def test_epoch_iterator(self):
method test_variable_aggregation (line 127) | def test_variable_aggregation(self):
method test_variable_synchronization (line 140) | def test_variable_synchronization(self):
method test_seed_generator (line 158) | def test_seed_generator(self):
method test_correctness_with_fit_and_regularizer (line 166) | def test_correctness_with_fit_and_regularizer(self):
FILE: keras/src/backend/tensorflow/distribution_lib.py
function list_devices (line 13) | def list_devices(device_type=None):
function distribute_value (line 48) | def distribute_value(value, tensor_layout):
function _to_backend_mesh (line 53) | def _to_backend_mesh(device_mesh):
function _to_backend_layout (line 68) | def _to_backend_layout(tensor_layout):
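
distribution_lib.py adapts the Keras distribution API to TensorFlow (converting device meshes and tensor layouts to backend objects). Enumerating devices goes through the public API:

import keras

# device_type can be "cpu", "gpu", or "tpu"; None selects the default
# accelerator type available on the machine.
devices = keras.distribution.list_devices()
print(devices)  # e.g. ['cpu:0'] on a CPU-only machine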
FILE: keras/src/backend/tensorflow/export.py
class TFExportArchive (line 6) | class TFExportArchive(SavedModelExportArchive):
method _backend_track_layer (line 9) | def _backend_track_layer(self, layer):
method _backend_add_endpoint (line 19) | def _backend_add_endpoint(self, name, fn, input_signature, **kwargs):
FILE: keras/src/backend/tensorflow/image.py
function rgb_to_grayscale (line 51) | def rgb_to_grayscale(images, data_format=None):
function rgb_to_hsv (line 75) | def rgb_to_hsv(images, data_format=None):
function hsv_to_rgb (line 104) | def hsv_to_rgb(images, data_format=None):
function resize (line 133) | def resize(
function affine_transform (line 335) | def affine_transform(
function perspective_transform (line 395) | def perspective_transform(
function compute_homography_matrix (line 473) | def compute_homography_matrix(start_points, end_points):
function _mirror_index_fixer (line 615) | def _mirror_index_fixer(index, size):
function _reflect_index_fixer (line 621) | def _reflect_index_fixer(index, size):
function _nearest_indices_and_weights (line 627) | def _nearest_indices_and_weights(coordinate):
function _linear_indices_and_weights (line 636) | def _linear_indices_and_weights(coordinate):
function map_coordinates (line 644) | def map_coordinates(
function gaussian_blur (line 742) | def gaussian_blur(
function elastic_transform (line 798) | def elastic_transform(
function _fill_triangle_kernel (line 947) | def _fill_triangle_kernel(x):
function _fill_keys_cubic_kernel (line 951) | def _fill_keys_cubic_kernel(x):
function _fill_lanczos_kernel (line 957) | def _fill_lanczos_kernel(radius, x):
function _compute_weight_mat (line 973) | def _compute_weight_mat(
function _scale_and_translate (line 1010) | def _scale_and_translate(
function scale_and_translate (line 1042) | def scale_and_translate(
FILE: keras/src/backend/tensorflow/layer.py
class TFLayer (line 11) | class TFLayer(KerasAutoTrackable):
method __init__ (line 12) | def __init__(self, *args, **kwargs):
method _set_save_spec (line 18) | def _set_save_spec(self, inputs, args=None, kwargs=None):
method _trackable_children (line 50) | def _trackable_children(self, save_type="checkpoint", **kwargs):
method _convert_tracked_collections (line 73) | def _convert_tracked_collections(self, children):
method _get_save_spec (line 86) | def _get_save_spec(self, dynamic_batch=True):
method _default_save_signature (line 115) | def _default_save_signature(self):
FILE: keras/src/backend/tensorflow/linalg.py
function cholesky (line 10) | def cholesky(a, upper=False):
function cholesky_inverse (line 19) | def cholesky_inverse(a, upper=False):
function det (line 29) | def det(a):
function eig (line 33) | def eig(a):
function eigh (line 37) | def eigh(a):
function inv (line 41) | def inv(a):
Condensed preview (JSON): 1011 files, each entry giving path, character count, and a content snippet; the full structured content totals 11,302K chars.
[
{
"path": ".devcontainer/README.md",
"chars": 1254,
"preview": "# Dev container configurations\n\nThis directory contains the configuration for dev containers, which is used to\ninitializ"
},
{
"path": ".devcontainer/devcontainer.json",
"chars": 869,
"preview": "{\n \"image\": \"mcr.microsoft.com/vscode/devcontainers/python:3.10\",\n \"postCreateCommand\": \"sh ./.devcontainer/setup."
},
{
"path": ".devcontainer/setup.sh",
"chars": 150,
"preview": "sudo pip install --upgrade pip\nsudo pip install -r requirements.txt\necho \"bash shell/lint.sh\" > .git/hooks/pre-commit\nch"
},
{
"path": ".gemini/config.yaml",
"chars": 253,
"preview": "have_fun: false\nmemory_config:\n disabled: false\ncode_review:\n disable: false\n comment_severity_threshold: MEDIUM\n ma"
},
{
"path": ".gemini/styleguide.md",
"chars": 19932,
"preview": "# Keras API design guidelines\n\nThese guidelines are meant to help focus design discussions and help us create delightful"
},
{
"path": ".github/dependabot.yml",
"chars": 1002,
"preview": "# To get started with Dependabot version updates, you'll need to specify which\n# package ecosystems to update and where "
},
{
"path": ".github/workflows/actions.yml",
"chars": 5931,
"preview": "name: Tests\n\n# TODO: Consider enabling all tests (pytest, applications, etc.) with NNX in the future\n# Currently only ba"
},
{
"path": ".github/workflows/auto-assignment.yaml",
"chars": 422,
"preview": "name: auto-assignment\non:\n issues:\n types:\n - opened\n\npermissions:\n contents: read\n issues: write\n pull-requ"
},
{
"path": ".github/workflows/config/jax/keras.json",
"chars": 140,
"preview": "{\n \"floatx\": \"float32\",\n \"epsilon\": 1e-07,\n \"backend\": \"jax\",\n \"image_data_format\": \"channels_last\",\n \"nn"
},
{
"path": ".github/workflows/config/numpy/keras.json",
"chars": 116,
"preview": "{\n \"floatx\": \"float32\",\n \"epsilon\": 1e-07,\n \"backend\": \"numpy\",\n \"image_data_format\": \"channels_last\"\n}\n"
},
{
"path": ".github/workflows/config/openvino/keras.json",
"chars": 119,
"preview": "{\n \"floatx\": \"float32\",\n \"epsilon\": 1e-07,\n \"backend\": \"openvino\",\n \"image_data_format\": \"channels_last\"\n}\n"
},
{
"path": ".github/workflows/config/tensorflow/keras.json",
"chars": 121,
"preview": "{\n \"floatx\": \"float32\",\n \"epsilon\": 1e-07,\n \"backend\": \"tensorflow\",\n \"image_data_format\": \"channels_last\"\n}"
},
{
"path": ".github/workflows/config/torch/keras.json",
"chars": 117,
"preview": "{\n \"floatx\": \"float32\",\n \"epsilon\": 1e-07,\n \"backend\": \"torch\",\n \"image_data_format\": \"channels_first\"\n}\n"
},
{
"path": ".github/workflows/gpu_tests.yml",
"chars": 2087,
"preview": "name: Keras GPU Tests\n\non:\n push:\n branches: [master]\n pull_request:\n types: [unlabeled]\n release:\n types: ["
},
{
"path": ".github/workflows/labeler.yaml",
"chars": 1398,
"preview": "# Copyright 2024 Google LLC. All Rights Reserved.\n#\n# Licensed under the Apache License, Version 2.0 (the \"License\");\n# "
},
{
"path": ".github/workflows/nightly.yml",
"chars": 3733,
"preview": "name: Nightly\n\non:\n workflow_dispatch: # To Generate wheels on demand outside of schedule.\n schedule:\n - cron: \"0 3"
},
{
"path": ".github/workflows/scorecard.yml",
"chars": 2461,
"preview": "name: Scorecard supply-chain security\non:\n # For Branch-Protection check. Only the default branch is supported. See\n #"
},
{
"path": ".github/workflows/scripts/auto-assignment.js",
"chars": 2071,
"preview": "/**\n * @license\n * Copyright 2023 Google LLC. All Rights Reserved.\n * Licensed under the Apache License, Version 2.0 (th"
},
{
"path": ".github/workflows/scripts/labeler.js",
"chars": 2030,
"preview": "/*\nCopyright 2024 Google LLC. All Rights Reserved.\nLicensed under the Apache License, Version 2.0 (the \"License\");\nyou m"
},
{
"path": ".github/workflows/stale-issue-pr.yaml",
"chars": 2559,
"preview": "name: Close inactive issues\non:\n schedule:\n - cron: \"30 1 * * *\"\njobs:\n close-issues:\n # Don't do this in forks\n"
},
{
"path": ".github/workflows/tpu_tests.yml",
"chars": 1289,
"preview": "name: Keras TPU Tests\n\non:\n push:\n branches: [master]\n pull_request:\n types: [unlabeled]\n release:\n types: ["
},
{
"path": ".gitignore",
"chars": 283,
"preview": ".DS_Store\n*.pyc\n.vscode-test\n__pycache__\n**/.vscode-test/**\n**/.vscode test/**\n**/.vscode-smoke/**\n**/.venv*/\nvenv\nbin/*"
},
{
"path": ".kokoro/README.md",
"chars": 36,
"preview": "CI to run on PR and merge to Master."
},
{
"path": ".kokoro/github/ubuntu/gpu/build.sh",
"chars": 3161,
"preview": "set -e\nset -x\n\ncd \"${KOKORO_ROOT}/\"\n\nsudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1\n\n"
},
{
"path": ".kokoro/github/ubuntu/gpu/jax/continuous.cfg",
"chars": 275,
"preview": "build_file: \"keras/.kokoro/github/ubuntu/gpu/build.sh\"\n\naction {\n define_artifacts {\n regex: \"**/sponge_log.log\"\n "
},
{
"path": ".kokoro/github/ubuntu/gpu/jax/presubmit.cfg",
"chars": 275,
"preview": "build_file: \"keras/.kokoro/github/ubuntu/gpu/build.sh\"\n\naction {\n define_artifacts {\n regex: \"**/sponge_log.log\"\n "
},
{
"path": ".kokoro/github/ubuntu/gpu/tensorflow/continuous.cfg",
"chars": 280,
"preview": "build_file: \"keras/.kokoro/github/ubuntu/gpu/build.sh\"\n\naction {\n define_artifacts {\n regex: \"**/sponge_log.log\"\n "
},
{
"path": ".kokoro/github/ubuntu/gpu/tensorflow/presubmit.cfg",
"chars": 280,
"preview": "build_file: \"keras/.kokoro/github/ubuntu/gpu/build.sh\"\n\naction {\n define_artifacts {\n regex: \"**/sponge_log.log\"\n "
},
{
"path": ".pre-commit-config.yaml",
"chars": 890,
"preview": "repos:\n - repo: local\n hooks:\n - id: api-gen\n name: api_gen\n entry: |\n bash shell/api_ge"
},
{
"path": "CONTRIBUTING.md",
"chars": 7147,
"preview": "Keras 3 is a high-velocity open-source project. We welcome contributions!\n\nContributions can be made in a variety of way"
},
{
"path": "LICENSE",
"chars": 11356,
"preview": " Apache License\n Version 2.0, January 2004\n "
},
{
"path": "README.md",
"chars": 5354,
"preview": "# Keras 3: Deep Learning for Humans\n\nKeras 3 is a multi-backend deep learning framework, with support for JAX, TensorFlo"
},
{
"path": "SECURITY.md",
"chars": 5158,
"preview": "# Security Policy\n\n - [**Using Keras Securely**](#using-keras-securely)\n - [Untrusted inputs](#untrusted-inputs)\n - "
},
{
"path": "api_gen.py",
"chars": 7543,
"preview": "\"\"\"Script to generate keras public API in `keras/api` directory.\n\nUsage:\n\nRun via `./shell/api_gen.sh`.\nIt generates API"
},
{
"path": "benchmarks/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "benchmarks/layer_benchmark/README.md",
"chars": 512,
"preview": "# Benchmark the layer performance\n\nThis directory contains benchmarks to compare the performance of\n`keras.layers.XXX` a"
},
{
"path": "benchmarks/layer_benchmark/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "benchmarks/layer_benchmark/activation_benchmark.py",
"chars": 3754,
"preview": "\"\"\"Benchmark activation layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to "
},
{
"path": "benchmarks/layer_benchmark/attention_benchmark.py",
"chars": 3052,
"preview": "\"\"\"Benchmark attention layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to y"
},
{
"path": "benchmarks/layer_benchmark/base_benchmark.py",
"chars": 9434,
"preview": "import time\n\nimport numpy as np\nimport tensorflow as tf\nfrom absl import flags\n\nimport keras\n\nFLAGS = flags.FLAGS\n\nflags"
},
{
"path": "benchmarks/layer_benchmark/conv_benchmark.py",
"chars": 7246,
"preview": "\"\"\"Benchmark conv layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your c"
},
{
"path": "benchmarks/layer_benchmark/core_benchmark.py",
"chars": 3033,
"preview": "\"\"\"Benchmark core layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your c"
},
{
"path": "benchmarks/layer_benchmark/merge_benchmark.py",
"chars": 5706,
"preview": "\"\"\"Benchmark merge layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your "
},
{
"path": "benchmarks/layer_benchmark/normalization_benchmark.py",
"chars": 3507,
"preview": "\"\"\"Benchmark normalization layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag "
},
{
"path": "benchmarks/layer_benchmark/pooling_benchmark.py",
"chars": 8179,
"preview": "\"\"\"Benchmark pooling layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to you"
},
{
"path": "benchmarks/layer_benchmark/random_rotation_benchmark.py",
"chars": 1563,
"preview": "\"\"\"Benchmark RandomRotation layer.\"\"\"\n\nfrom absl import app\nfrom absl import flags\n\nfrom benchmarks.layer_benchmark.base"
},
{
"path": "benchmarks/layer_benchmark/regularization_benchmark.py",
"chars": 4668,
"preview": "\"\"\"Benchmark regularization layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag"
},
{
"path": "benchmarks/layer_benchmark/reshaping_benchmark.py",
"chars": 7154,
"preview": "\"\"\"Benchmark reshaping layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to y"
},
{
"path": "benchmarks/layer_benchmark/rnn_benchmark.py",
"chars": 6189,
"preview": "\"\"\"Benchmark rnn layers.\n\nTo run benchmarks, see the following command for an example, please change the\nflag to your cu"
},
{
"path": "benchmarks/model_benchmark/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "benchmarks/model_benchmark/benchmark_utils.py",
"chars": 790,
"preview": "import time\n\nimport keras\n\n\nclass BenchmarkMetricsCallback(keras.callbacks.Callback):\n def __init__(self, start_batch"
},
{
"path": "benchmarks/model_benchmark/bert_benchmark.py",
"chars": 4661,
"preview": "\"\"\"Benchmark BERT model on GLUE/MRPC task.\n\nTo run the script, make sure you are in benchmarks/ directory, abd run the\nc"
},
{
"path": "benchmarks/model_benchmark/image_classification_benchmark.py",
"chars": 4498,
"preview": "\"\"\"Image classification benchmark.\n\nThis script runs image classification benchmark with \"dogs vs cats\" datasets.\nIt sup"
},
{
"path": "benchmarks/torch_ctl_benchmark/README.md",
"chars": 482,
"preview": "# Benchmark the performance of torch custom training loop\n\nThis directory contains benchmarks to compare the performance"
},
{
"path": "benchmarks/torch_ctl_benchmark/__init__.py",
"chars": 0,
"preview": ""
},
{
"path": "benchmarks/torch_ctl_benchmark/benchmark_utils.py",
"chars": 1087,
"preview": "import time\n\nimport numpy as np\nimport torch\n\n\ndef train_loop(model, train_loader, num_epochs, optimizer, loss_fn, frame"
},
{
"path": "benchmarks/torch_ctl_benchmark/conv_model_benchmark.py",
"chars": 2638,
"preview": "\"\"\"Benchmark Keras performance with torch custom training loop.\n\nIn this file we use a convolution model. Training loop "
},
{
"path": "benchmarks/torch_ctl_benchmark/dense_model_benchmark.py",
"chars": 2574,
"preview": "\"\"\"Benchmark Keras performance with torch custom training loop.\n\nIn this file we use a model with 3 dense layers. Traini"
},
{
"path": "codecov.yml",
"chars": 662,
"preview": "coverage:\n status:\n project:\n default:\n # `auto` compares coverage with the base-commit\n target: "
},
{
"path": "conftest.py",
"chars": 2816,
"preview": "try:\n # When using torch and tensorflow, torch needs to be imported first,\n # otherwise it will segfault upon impo"
},
{
"path": "examples/demo_custom_jax_workflow.py",
"chars": 3226,
"preview": "# flake8: noqa\nimport os\n\n# Set backend env to JAX\nos.environ[\"KERAS_BACKEND\"] = \"jax\"\n\nimport jax\nimport numpy as np\n\nf"
},
{
"path": "examples/demo_custom_layer_backend_agnostic.py",
"chars": 2354,
"preview": "import numpy as np\n\nimport keras\nfrom keras import Model\nfrom keras import initializers\nfrom keras import layers\nfrom ke"
},
{
"path": "examples/demo_custom_tf_workflow.py",
"chars": 2063,
"preview": "# flake8: noqa\nimport os\n\n# Set backend env to tensorflow\nos.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n\nimport numpy as np"
},
{
"path": "examples/demo_custom_torch_workflow.py",
"chars": 3817,
"preview": "# flake8: noqa\nimport os\n\n# Set backend env to torch\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\n\nimport torch\nimport torch.nn"
},
{
"path": "examples/demo_functional.py",
"chars": 1415,
"preview": "import numpy as np\n\nfrom keras import Model\nfrom keras import layers\nfrom keras import losses\nfrom keras import metrics\n"
},
{
"path": "examples/demo_jax_distributed.py",
"chars": 11243,
"preview": "# To run this demo, you will need to spin up a \"TPU VM\" on Google Cloud.\n# Please follow instructions here: https://clou"
},
{
"path": "examples/demo_mnist_convnet.py",
"chars": 1562,
"preview": "import numpy as np\nimport keras\nfrom keras import layers\nfrom keras.utils import to_categorical\n\n# Model / data paramete"
},
{
"path": "examples/demo_subclass.py",
"chars": 979,
"preview": "import numpy as np\n\nfrom keras import Model\nfrom keras import layers\nfrom keras import losses\nfrom keras import metrics\n"
},
{
"path": "examples/demo_torch_multi_gpu.py",
"chars": 6005,
"preview": "# flake8: noqa\nimport os\n\n# Set backend env to torch\nos.environ[\"KERAS_BACKEND\"] = \"torch\"\n\nimport torch\nimport torch.nn"
},
{
"path": "guides/custom_train_step_in_jax.py",
"chars": 10713,
"preview": "\"\"\"\nTitle: Customizing what happens in `fit()` with JAX\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2"
},
{
"path": "guides/custom_train_step_in_tensorflow.py",
"chars": 16172,
"preview": "\"\"\"\nTitle: Customizing what happens in `fit()` with TensorFlow\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate cre"
},
{
"path": "guides/custom_train_step_in_torch.py",
"chars": 16958,
"preview": "\"\"\"\nTitle: Customizing what happens in `fit()` with PyTorch\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate create"
},
{
"path": "guides/distributed_training_with_jax.py",
"chars": 9861,
"preview": "\"\"\"\nTitle: Multi-GPU distributed training with JAX\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2023/0"
},
{
"path": "guides/distributed_training_with_tensorflow.py",
"chars": 10064,
"preview": "\"\"\"\nTitle: Multi-GPU distributed training with TensorFlow\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created:"
},
{
"path": "guides/distributed_training_with_torch.py",
"chars": 8878,
"preview": "\"\"\"\nTitle: Multi-GPU distributed training with PyTorch\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 20"
},
{
"path": "guides/functional_api.py",
"chars": 28215,
"preview": "\"\"\"\nTitle: The Functional API\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2019/03/01\nLast modified: 2"
},
{
"path": "guides/making_new_layers_and_models_via_subclassing.py",
"chars": 20321,
"preview": "\"\"\"\nTitle: Making new layers and models via subclassing\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2"
},
{
"path": "guides/sequential_model.py",
"chars": 10186,
"preview": "\"\"\"\nTitle: The Sequential model\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2020/04/12\nLast modified:"
},
{
"path": "guides/training_with_built_in_methods.py",
"chars": 40459,
"preview": "\"\"\"\nTitle: Training & evaluation with the built-in methods\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created"
},
{
"path": "guides/transfer_learning.py",
"chars": 20523,
"preview": "\"\"\"\nTitle: Transfer learning & fine-tuning\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 2020/04/15\nLas"
},
{
"path": "guides/understanding_masking_and_padding.py",
"chars": 12229,
"preview": "\"\"\"\nTitle: Understanding masking & padding\nAuthors: Scott Zhu, Francois Chollet\nDate created: 2019/07/16\nLast modified: "
},
{
"path": "guides/writing_a_custom_training_loop_in_jax.py",
"chars": 17063,
"preview": "\"\"\"\nTitle: Writing a training loop from scratch in JAX\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created: 20"
},
{
"path": "guides/writing_a_custom_training_loop_in_tensorflow.py",
"chars": 18097,
"preview": "\"\"\"\nTitle: Writing a training loop from scratch in TensorFlow\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate crea"
},
{
"path": "guides/writing_a_custom_training_loop_in_torch.py",
"chars": 12610,
"preview": "\"\"\"\nTitle: Writing a training loop from scratch in PyTorch\nAuthor: [fchollet](https://twitter.com/fchollet)\nDate created"
},
{
"path": "guides/writing_your_own_callbacks.py",
"chars": 12876,
"preview": "\"\"\"\nTitle: Writing your own callbacks\nAuthors: Rick Chao, Francois Chollet\nDate created: 2019/03/20\nLast modified: 2023/"
},
{
"path": "integration_tests/basic_full_flow.py",
"chars": 1582,
"preview": "import numpy as np\nimport pytest\n\nimport keras\nfrom keras.src import layers\nfrom keras.src import losses\nfrom keras.src "
},
{
"path": "integration_tests/dataset_tests/boston_housing_test.py",
"chars": 887,
"preview": "from keras.src import testing\nfrom keras.src.datasets import boston_housing\n\n\nclass BostonHousingTest(testing.TestCase):"
},
{
"path": "integration_tests/dataset_tests/california_housing_test.py",
"chars": 1353,
"preview": "from keras.src import testing\nfrom keras.src.datasets import california_housing\n\n\nclass CaliforniaHousingTest(testing.Te"
},
{
"path": "integration_tests/dataset_tests/cifar100_test.py",
"chars": 1318,
"preview": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import cifar100\n\n\nclass Cifar100LoadDataTest(t"
},
{
"path": "integration_tests/dataset_tests/cifar10_test.py",
"chars": 1195,
"preview": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import cifar10\n\n\nclass Cifar10LoadDataTest(tes"
},
{
"path": "integration_tests/dataset_tests/fashion_mnist_test.py",
"chars": 1244,
"preview": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import fashion_mnist\n\n\nclass FashionMnistLoadD"
},
{
"path": "integration_tests/dataset_tests/imdb_test.py",
"chars": 1831,
"preview": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import imdb\n\n\nclass ImdbLoadDataTest(testing.T"
},
{
"path": "integration_tests/dataset_tests/mnist_test.py",
"chars": 1165,
"preview": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import mnist\n\n\nclass MnistLoadDataTest(testing"
},
{
"path": "integration_tests/dataset_tests/reuters_test.py",
"chars": 1953,
"preview": "import numpy as np\n\nfrom keras.src import testing\nfrom keras.src.datasets import reuters\n\n\nclass ReutersLoadDataTest(tes"
},
{
"path": "integration_tests/import_test.py",
"chars": 4109,
"preview": "import os\nimport re\nimport subprocess\n\nfrom keras.src import backend\nfrom keras.src.backend import config\n\n# For torch, "
},
{
"path": "integration_tests/jax_custom_fit_test.py",
"chars": 3305,
"preview": "import jax\nimport numpy as np\n\nimport keras\n\n\ndef test_custom_fit():\n class CustomModel(keras.Model):\n def __i"
},
{
"path": "integration_tests/model_visualization_test.py",
"chars": 30746,
"preview": "import re\n\nimport keras\nfrom keras.src import testing\nfrom keras.src.utils import model_to_dot\nfrom keras.src.utils impo"
},
{
"path": "integration_tests/numerical_test.py",
"chars": 4352,
"preview": "import keras # isort: skip, keep it on top for torch test\n\nimport sys\n\nimport numpy as np\nimport tf_keras\n\nkeras.backen"
},
{
"path": "integration_tests/pytorch_export_test.py",
"chars": 11419,
"preview": "\"\"\"\nIntegration tests for PyTorch model export with dynamic shapes.\n\nTests the complete fix for GitHub issue #22102 wher"
},
{
"path": "integration_tests/tf_custom_fit_test.py",
"chars": 1563,
"preview": "import numpy as np\nimport tensorflow as tf\n\nimport keras\n\n\ndef test_custom_fit():\n class CustomModel(keras.Model):\n "
},
{
"path": "integration_tests/tf_distribute_training_test.py",
"chars": 2150,
"preview": "import numpy as np\nimport tensorflow as tf\n\nimport keras\nfrom keras.src import layers\nfrom keras.src import losses\nfrom "
},
{
"path": "integration_tests/torch_custom_fit_test.py",
"chars": 1618,
"preview": "import numpy as np\nimport torch\n\nimport keras\n\n\ndef test_custom_fit():\n class CustomModel(keras.Model):\n def _"
},
{
"path": "integration_tests/torch_workflow_test.py",
"chars": 996,
"preview": "import torch\n\nfrom keras.src import layers\nfrom keras.src import testing\nfrom keras.src.backend.common import KerasVaria"
},
{
"path": "keras/__init__.py",
"chars": 446,
"preview": "# This file should NEVER be packaged! This is a hack to make \"import keras\" from\n# the base of the repo just import the "
},
{
"path": "keras/api/__init__.py",
"chars": 2951,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/__init__.py",
"chars": 34,
"preview": "from keras._tf_keras import keras\n"
},
{
"path": "keras/api/_tf_keras/keras/__init__.py",
"chars": 2955,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/activations/__init__.py",
"chars": 2413,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/__init__.py",
"chars": 4210,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/convnext/__init__.py",
"chars": 670,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/densenet/__init__.py",
"chars": 510,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/efficientnet/__init__.py",
"chars": 962,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/efficientnet_v2/__init__.py",
"chars": 993,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/imagenet_utils/__init__.py",
"chars": 318,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/inception_resnet_v2/__init__.py",
"chars": 431,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/inception_v3/__init__.py",
"chars": 389,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/mobilenet/__init__.py",
"chars": 376,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/mobilenet_v2/__init__.py",
"chars": 389,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/mobilenet_v3/__init__.py",
"chars": 314,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/nasnet/__init__.py",
"chars": 433,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/resnet/__init__.py",
"chars": 486,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/resnet50/__init__.py",
"chars": 356,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/resnet_v2/__init__.py",
"chars": 522,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/vgg16/__init__.py",
"chars": 347,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/vgg19/__init__.py",
"chars": 347,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/applications/xception/__init__.py",
"chars": 362,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/backend/__init__.py",
"chars": 8626,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/callbacks/__init__.py",
"chars": 1448,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/config/__init__.py",
"chars": 2279,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/constraints/__init__.py",
"chars": 893,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/__init__.py",
"chars": 530,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/boston_housing/__init__.py",
"chars": 191,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/california_housing/__init__.py",
"chars": 195,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/cifar10/__init__.py",
"chars": 184,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/cifar100/__init__.py",
"chars": 185,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/fashion_mnist/__init__.py",
"chars": 190,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/imdb/__init__.py",
"chars": 250,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/mnist/__init__.py",
"chars": 182,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/datasets/reuters/__init__.py",
"chars": 330,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/distillation/__init__.py",
"chars": 497,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/distribution/__init__.py",
"chars": 1063,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/dtype_policies/__init__.py",
"chars": 1063,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/export/__init__.py",
"chars": 194,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/initializers/__init__.py",
"chars": 3371,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/layers/__init__.py",
"chars": 16408,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/legacy/__init__.py",
"chars": 164,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/legacy/saving/__init__.py",
"chars": 342,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/losses/__init__.py",
"chars": 4027,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/metrics/__init__.py",
"chars": 6546,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/mixed_precision/__init__.py",
"chars": 727,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/models/__init__.py",
"chars": 501,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/ops/__init__.py",
"chars": 16649,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/ops/image/__init__.py",
"chars": 1034,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/ops/linalg/__init__.py",
"chars": 822,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/ops/nn/__init__.py",
"chars": 3158,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/ops/numpy/__init__.py",
"chars": 10545,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/optimizers/__init__.py",
"chars": 1368,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/optimizers/legacy/__init__.py",
"chars": 516,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/optimizers/schedules/__init__.py",
"chars": 1120,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/preprocessing/__init__.py",
"chars": 673,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/preprocessing/image/__init__.py",
"chars": 1656,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/preprocessing/sequence/__init__.py",
"chars": 479,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/preprocessing/text/__init__.py",
"chars": 543,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/quantizers/__init__.py",
"chars": 1790,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/random/__init__.py",
"chars": 763,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/regularizers/__init__.py",
"chars": 923,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/saving/__init__.py",
"chars": 1298,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/tree/__init__.py",
"chars": 967,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/utils/__init__.py",
"chars": 3673,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/utils/bounding_boxes/__init__.py",
"chars": 1275,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/utils/legacy/__init__.py",
"chars": 342,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/visualization/__init__.py",
"chars": 722,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/_tf_keras/keras/wrappers/__init__.py",
"chars": 407,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/activations/__init__.py",
"chars": 2413,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/__init__.py",
"chars": 4210,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/convnext/__init__.py",
"chars": 670,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/densenet/__init__.py",
"chars": 510,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/efficientnet/__init__.py",
"chars": 962,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/efficientnet_v2/__init__.py",
"chars": 993,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/imagenet_utils/__init__.py",
"chars": 318,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/inception_resnet_v2/__init__.py",
"chars": 431,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/inception_v3/__init__.py",
"chars": 389,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/mobilenet/__init__.py",
"chars": 376,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/mobilenet_v2/__init__.py",
"chars": 389,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/mobilenet_v3/__init__.py",
"chars": 314,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/nasnet/__init__.py",
"chars": 433,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/resnet/__init__.py",
"chars": 486,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/resnet50/__init__.py",
"chars": 356,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/resnet_v2/__init__.py",
"chars": 522,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/vgg16/__init__.py",
"chars": 347,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/vgg19/__init__.py",
"chars": 347,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/applications/xception/__init__.py",
"chars": 362,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/backend/__init__.py",
"chars": 1134,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/callbacks/__init__.py",
"chars": 1448,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/config/__init__.py",
"chars": 2279,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/constraints/__init__.py",
"chars": 893,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/datasets/__init__.py",
"chars": 530,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/datasets/boston_housing/__init__.py",
"chars": 191,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/datasets/california_housing/__init__.py",
"chars": 195,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
},
{
"path": "keras/api/datasets/cifar10/__init__.py",
"chars": 184,
"preview": "\"\"\"DO NOT EDIT.\n\nThis file was autogenerated. Do not edit it by hand,\nsince your modifications would be overwritten.\n\"\"\""
}
]
// ... and 811 more files (download for full content)
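The listing above is a uniform array of per-file records: path, chars (the file's size in characters), and preview (the opening lines of the file, here always the autogenerated "DO NOT EDIT" header). A minimal sketch of consuming it in Python, assuming the array is saved as valid JSON under a hypothetical name api_manifest.json (with the trailing "// ..." truncation comment stripped first, since JSON has no comment syntax):

import json

# Load the per-file manifest entries shown above: path, chars, preview.
with open("api_manifest.json") as f:  # hypothetical filename
    entries = json.load(f)

# Every file listed is an autogenerated API stub; check that each preview
# really starts with the "DO NOT EDIT" header before relying on that.
header = '"""DO NOT EDIT.'
assert all(e["preview"].startswith(header) for e in entries)

# Summarize: total size and the five largest generated modules.
total_chars = sum(e["chars"] for e in entries)
print(f"{len(entries)} files, {total_chars} chars total")
for e in sorted(entries, key=lambda e: e["chars"], reverse=True)[:5]:
    print(f'{e["chars"]:>6}  {e["path"]}')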
About this extraction
This page contains the full source code of the keras-team/keras GitHub repository, extracted and formatted as plain text for AI agents and large language models (LLMs). The extraction includes 1011 files (10.3 MB, approximately 2.8M tokens) and a symbol index of 14089 extracted functions, classes, methods, constants, and types. The output can be copied to the clipboard or downloaded as a .txt file and used with any AI tool that accepts text input.
Extracted by GitExtract, a free GitHub-repo-to-text converter built by Nikandr Surkov.
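A symbol index like the one described above can be approximated with Python's standard ast module. The sketch below illustrates the general technique only; it is not GitExtract's actual implementation, and the function name and traversal strategy are assumptions:

import ast
from pathlib import Path

def index_symbols(root):
    # Walk every .py file under root and record (file, name, line) for
    # each function, method, and class definition found in the source.
    symbols = []
    for path in Path(root).rglob("*.py"):
        try:
            tree = ast.parse(path.read_text(encoding="utf-8"))
        except (SyntaxError, UnicodeDecodeError):
            continue  # skip files that fail to parse or decode
        for node in ast.walk(tree):
            if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
                symbols.append((str(path), node.name, node.lineno))
    return symbols

# Usage: syms = index_symbols("keras"); print(len(syms), "symbols indexed")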