Repository: google-deepmind/android_env Branch: main Commit: 0cdf2711c4e9 Files: 148 Total size: 908.6 KB Directory structure: gitextract_70x6n_qo/ ├── .github/ │ └── workflows/ │ └── tests.yml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── android_env/ │ ├── __init__.py │ ├── apps/ │ │ ├── MODULE.bazel │ │ ├── java/ │ │ │ └── com/ │ │ │ └── google/ │ │ │ └── androidenv/ │ │ │ ├── accessibilityforwarder/ │ │ │ │ ├── AccessibilityForwarder.kt │ │ │ │ ├── AccessibilityForwarderTest.kt │ │ │ │ ├── AccessibilityTreeCreator.kt │ │ │ │ ├── AccessibilityTreeCreatorTest.kt │ │ │ │ ├── AndroidManifest.xml │ │ │ │ ├── AndroidManifest_lite.xml │ │ │ │ ├── FlagsBroadcastReceiver.kt │ │ │ │ ├── FlagsBroadcastReceiverTest.kt │ │ │ │ ├── LogFlags.kt │ │ │ │ ├── ParentChildNodePair.kt │ │ │ │ ├── UniqueIdsGenerator.kt │ │ │ │ └── res/ │ │ │ │ └── xml/ │ │ │ │ └── accessibility_forwarder_service.xml │ │ │ └── catch/ │ │ │ ├── AndroidManifest.xml │ │ │ ├── BUILD.bazel │ │ │ ├── GameLogic.kt │ │ │ ├── GameLogicThread.kt │ │ │ ├── MainActivity.kt │ │ │ ├── RenderThread.kt │ │ │ ├── res/ │ │ │ │ ├── layout/ │ │ │ │ │ └── main.xml │ │ │ │ └── values/ │ │ │ │ └── strings.xml │ │ │ └── sprite/ │ │ │ ├── BUILD.bazel │ │ │ ├── Background.kt │ │ │ ├── Ball.kt │ │ │ ├── LineSegment.kt │ │ │ ├── Paddle.kt │ │ │ ├── Point.kt │ │ │ └── Sprite.kt │ │ └── javatests/ │ │ └── com/ │ │ └── google/ │ │ └── androidenv/ │ │ └── catch/ │ │ ├── AndroidManifest.xml │ │ ├── BUILD.bazel │ │ ├── GameLogicTest.kt │ │ ├── GameLogicThreadTest.kt │ │ ├── MainActivityTest.kt │ │ ├── RenderThreadTest.kt │ │ └── sprite/ │ │ ├── BUILD.bazel │ │ ├── BackgroundTest.kt │ │ ├── BallTest.kt │ │ ├── PaddleTest.kt │ │ └── SpriteTest.kt │ ├── components/ │ │ ├── __init__.py │ │ ├── action_fns.py │ │ ├── action_fns_test.py │ │ ├── action_type.py │ │ ├── adb_call_parser.py │ │ ├── adb_call_parser_test.py │ │ ├── adb_controller.py │ │ ├── adb_controller_test.py │ │ ├── adb_log_stream.py │ │ ├── adb_log_stream_test.py │ │ ├── 
app_screen_checker.py │ │ ├── app_screen_checker_test.py │ │ ├── config_classes.py │ │ ├── coordinator.py │ │ ├── coordinator_test.py │ │ ├── device_settings.py │ │ ├── device_settings_test.py │ │ ├── dumpsys_thread.py │ │ ├── dumpsys_thread_test.py │ │ ├── errors.py │ │ ├── errors_test.py │ │ ├── log_stream.py │ │ ├── log_stream_test.py │ │ ├── logcat_thread.py │ │ ├── logcat_thread_test.py │ │ ├── pixel_fns.py │ │ ├── pixel_fns_test.py │ │ ├── setup_step_interpreter.py │ │ ├── setup_step_interpreter_test.py │ │ ├── simulators/ │ │ │ ├── __init__.py │ │ │ ├── base_simulator.py │ │ │ ├── base_simulator_test.py │ │ │ ├── emulator/ │ │ │ │ ├── __init__.py │ │ │ │ ├── emulator_launcher.py │ │ │ │ ├── emulator_launcher_test.py │ │ │ │ ├── emulator_simulator.py │ │ │ │ └── emulator_simulator_test.py │ │ │ └── fake/ │ │ │ ├── __init__.py │ │ │ ├── fake_simulator.py │ │ │ └── fake_simulator_test.py │ │ ├── specs.py │ │ ├── specs_test.py │ │ ├── task_manager.py │ │ └── task_manager_test.py │ ├── env_interface.py │ ├── environment.py │ ├── environment_test.py │ ├── loader.py │ ├── loader_test.py │ ├── proto/ │ │ ├── __init__.py │ │ ├── a11y/ │ │ │ ├── __init__.py │ │ │ ├── a11y.proto │ │ │ ├── android_accessibility_action.proto │ │ │ ├── android_accessibility_forest.proto │ │ │ ├── android_accessibility_node_info.proto │ │ │ ├── android_accessibility_node_info_clickable_span.proto │ │ │ ├── android_accessibility_tree.proto │ │ │ ├── android_accessibility_window_info.proto │ │ │ └── rect.proto │ │ ├── adb.proto │ │ ├── emulator_controller.proto │ │ ├── snapshot.proto │ │ ├── snapshot_service.proto │ │ ├── state.proto │ │ └── task.proto │ └── wrappers/ │ ├── __init__.py │ ├── a11y/ │ │ ├── __init__.py │ │ ├── a11y_events.py │ │ ├── a11y_events_test.py │ │ ├── a11y_forests.py │ │ ├── a11y_forests_test.py │ │ ├── a11y_servicer.py │ │ └── a11y_servicer_test.py │ ├── a11y_grpc_wrapper.py │ ├── a11y_grpc_wrapper_test.py │ ├── base_wrapper.py │ ├── base_wrapper_test.py │ ├── 
discrete_action_wrapper.py │ ├── discrete_action_wrapper_test.py │ ├── flat_interface_wrapper.py │ ├── flat_interface_wrapper_test.py │ ├── float_pixels_wrapper.py │ ├── float_pixels_wrapper_test.py │ ├── gym_wrapper.py │ ├── gym_wrapper_test.py │ ├── image_rescale_wrapper.py │ ├── image_rescale_wrapper_test.py │ ├── last_action_wrapper.py │ ├── last_action_wrapper_test.py │ ├── rate_limit_wrapper.py │ ├── rate_limit_wrapper_test.py │ ├── tap_action_wrapper.py │ └── tap_action_wrapper_test.py ├── docs/ │ ├── emulator_guide.md │ ├── environment.md │ ├── example_tasks.md │ ├── instructions.md │ └── tasks_guide.md ├── examples/ │ ├── __init__.py │ ├── run_acme_agent.py │ ├── run_human_agent.py │ └── run_random_agent.py ├── pyproject.toml └── setup.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/workflows/tests.yml ================================================ name: tests on: push: branches: [ main ] pull_request: branches: [ main ] workflow_dispatch: inputs: git-ref: description: Git Ref (Optional) required: false jobs: build: runs-on: ubuntu-latest env: TEST_TMPDIR: '/tmp' strategy: matrix: python-version: ["3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v6 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v6 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | pip install --upgrade pip setuptools python setup.py install pip install .[testing] - name: Run tests run: | # Find all test files, print their names and execute them in parallel # with a maximum of 20 processes. find . -type f -name "*_test.py" -print0 | xargs -t -0 -n1 -P 20 python3 ================================================ FILE: CONTRIBUTING.md ================================================ # How to Contribute # Pull Requests Please send in fixes or feature additions through Pull Requests.
## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement. You (or your employer) retain the copyright to your contribution, this simply gives us permission to use and redistribute your contributions as part of the project. Head over to <https://cla.developers.google.com/> to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ## Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." 
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. 
Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # AndroidEnv - The Android Learning Environment [AndroidEnv](https://github.com/deepmind/android_env) is a Python library that exposes an [Android](https://www.android.com/) device as a Reinforcement Learning (RL) environment. The library provides a flexible platform for defining custom tasks on top of the Android Operating System, including any Android application. Agents interact with the device through a universal action interface - the touchscreen - by sending localized touch and lift events to the system. The library processes these events and returns pixel observations and rewards as provided by specific [task definitions](docs/tasks_guide.md). For example, rewards might be given for events such as successfully scrolling down a page, sending an email, or achieving some score in a game, depending on the research purpose and how the user configures the task. 
[![tests](https://github.com/deepmind/android_env/actions/workflows/tests.yml/badge.svg?branch=main)](https://github.com/deepmind/android_env/actions/workflows/tests.yml) [![PyPI version](https://badge.fury.io/py/android-env.svg)](https://badge.fury.io/py/android-env) [![Downloads](https://pepy.tech/badge/android-env)](https://pepy.tech/project/android-env) ## Index * [Environment details](docs/environment.md) * [Running AndroidEnv](docs/instructions.md) * [Setting up a virtual Android device](docs/emulator_guide.md) * [Defining a task in AndroidEnv](docs/tasks_guide.md) * [Example tasks available for download](docs/example_tasks.md) ## Environment features There are a number of aspects that make AndroidEnv a challenging yet suitable environment for Reinforcement Learning research: * Allowing agents to interact with a system used daily by billions of users around the world, AndroidEnv offers a platform for RL agents to navigate, learn tasks and have direct impact in **real-world contexts**. The environment wraps a simulated Android device, which runs independently from the environment, completely unaltered, and works in exactly the same way as the devices that humans use, exposing exactly the same features and services. * The platform offers a virtually infinite **range of possible tasks**, all sharing a common action interface. The library facilitates the design of Reinforcement Learning tasks for any existing or custom built Android application. For example, it exposes the broad world of Android games, ranging from card games, puzzle games, time reactive games, all requiring a diverse set of action combinations and interaction types. * The environment runs on top of a **real-time simulation** of an Android device. In other words, the environment dynamics does not wait for the agent to deliberate, and the speed of the simulation cannot be increased. * The observation is a collection of **RGB values** corresponding to the displayed pixels on the screen. 
The exact screen resolution depends on the simulated device, but in general it will be considered relatively large in an RL context. However, users have the option of downsampling each observation. * The learning environment has an interesting, **complex action space** unique to the touchscreen interface of Android. * The raw, **hybrid action space** consists of a continuous tuple signifying the action location, and a discrete signal determining whether the agent wants to touch the screen or lift its virtual finger. * Raw actions are highly **composable**: the Android UI and most applications were designed so that they could be intuitively navigated via common [touchscreen gestures](https://developer.android.com/training/gestures/detector) such as tapping, scrolling, swiping, pinching, drag & drop etc. This is still the case in AndroidEnv: to trigger meaningful changes in the environment, the agent often has to perform carefully timed and positioned sequences of raw actions. For example, in order to navigate to the next image in a photo gallery, the agent would have to perform a *swipe*, touching the screen multiple times, gradually shifting the actions' positions to the right. Thus, in most contexts raw actions do not trigger changes in the state of the environment unless correctly chained together to make up a human gesture. * The action interface is **closely related to the observation space**, as meaningful touch and lift events are often either co-localized or strongly correlated to the location or movement of salient objects in the observation. For example, the position of a button on the screen aligns with the location of the actions that trigger the button press. * The library provides tools for flexibly **altering the action interface** if needed for particular studies, such as discretization or hard-coding gesture skills. 
Still, we believe that the real challenge remains in devising agents that are capable of dealing with a large suite of diverse tasks, through acting and learning in the complex unifying action interface. # Getting started ### Installation The easiest way to get AndroidEnv is with pip: ```shell $ python3 -m pip install android-env ``` Please note that `/examples` are not included in this package. Alternatively, you can clone the repository from git's `main` branch: ```shell $ git clone https://github.com/deepmind/android_env/ $ cd android_env $ python3 setup.py install ``` Update: the environment now runs on Windows, but please keep in mind that this option is not well-maintained or widely supported, as Unix-based systems are the primary target platforms of this project. ### Create a simulator Before running the environment, you will need access to an emulated Android device. For instructions on creating a virtual Android device, see the [Emulator guide](docs/emulator_guide.md). ### Define a task Then, you will want to define what the agent's *task* is. At this point, the agent will be able to communicate with the emulated device, but it will not yet have an objective, or access to signals such as rewards or RL episode ends. Learn [how to define an RL task](docs/tasks_guide.md) of your own, or use one of the [existing task definitions](docs/example_tasks.md) for training. ### Load and run To find out how to run and train agents on AndroidEnv, see these [detailed instructions](docs/instructions.md). Here you can also find example scripts demonstrating how to run a random agent, an [acme](https://github.com/deepmind/acme) agent, or a human agent on AndroidEnv. ## About This library is developed and maintained by [DeepMind](http://deepmind.com). \ You can find the [technical report](https://arxiv.org/abs/2105.13231) on Arxiv, as well as an introductory [blog post](https://www.deepmind.com/publications/androidenv-the-android-learning-environment) on DeepMind's website. 
If you use AndroidEnv in your research, you can cite the paper using the following BibTeX: ``` @article{ToyamaEtAl2021AndroidEnv, title = {{AndroidEnv}: A Reinforcement Learning Platform for Android}, author = {Daniel Toyama and Philippe Hamel and Anita Gergely and Gheorghe Comanici and Amelia Glaese and Zafarali Ahmed and Tyler Jackson and Shibl Mourad and Doina Precup}, year = {2021}, eprint = {2105.13231}, archivePrefix = {arXiv}, primaryClass = {cs.LG}, volume = {abs/2105.13231}, url = {http://arxiv.org/abs/2105.13231}, } ``` Disclaimer: This is not an official Google product. ================================================ FILE: android_env/__init__.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: android_env/apps/MODULE.bazel ================================================ # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. # Bazel dependencies for building Catch. module( name = "catch", version = "1.0", ) bazel_dep(name = "rules_android", version = "0.6.6") bazel_dep(name = "rules_kotlin", version = "2.1.8") bazel_dep(name = "rules_jvm_external", version = "6.7") bazel_dep(name = "rules_robolectric", version = "4.16", repo_name = "robolectric") bazel_dep(name = "rules_java", version = "9.0.3") bazel_dep(name = "protobuf", version = "30.0") # To avoid conflict with different protobuf versions. single_version_override( module_name = "protobuf", version = "30.0", ) maven = use_extension("@rules_jvm_external//:extensions.bzl", "maven") # Need to set testonly = True because the package depends on testonly targets. maven.artifact( testonly = True, artifact = "runner", group = "androidx.test", version = "1.7.0", ) maven.artifact( testonly = True, artifact = "junit", group = "androidx.test.ext", version = "1.3.0", ) maven.artifact( testonly = True, artifact = "mockito-kotlin", group = "org.mockito.kotlin", version = "6.1.0", ) maven.install( artifacts = [ "androidx.test.ext:junit:1.3.0", "androidx.test:runner:1.7.0", "com.google.guava:guava:32.0.1-jre", "com.google.truth:truth:1.4.0", "org.mockito.kotlin:mockito-kotlin:6.1.0", "org.mockito:mockito-core:5.20.0", "org.robolectric:robolectric:4.16", "org.yaml:snakeyaml:2.5", ], repositories = [ "https://maven.google.com", "https://repo1.maven.org/maven2", ], ) use_repo(maven, "maven") remote_android_extensions = use_extension( "@rules_android//bzlmod_extensions:android_extensions.bzl", "remote_android_tools_extensions", ) use_repo(remote_android_extensions, "android_tools") android_sdk_repository_extension = use_extension("@rules_android//rules/android_sdk_repository:rule.bzl", "android_sdk_repository_extension") use_repo(android_sdk_repository_extension, "androidsdk") ================================================ FILE: 
android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarder.kt ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.google.androidenv.accessibilityforwarder import android.accessibilityservice.AccessibilityService import android.util.Log import android.view.accessibility.AccessibilityEvent import android.view.accessibility.AccessibilityNodeInfo import android.view.accessibility.AccessibilityWindowInfo import com.google.androidenv.accessibilityforwarder.A11yServiceGrpcKt.A11yServiceCoroutineStub import io.grpc.ManagedChannel import io.grpc.ManagedChannelBuilder import io.grpc.ProxyDetector import io.grpc.StatusException import kotlinx.coroutines.TimeoutCancellationException import kotlinx.coroutines.runBlocking import kotlinx.coroutines.withTimeout /** * An Android service that listens to accessibility events and forwards them via gRPC. * * This service also logs the accessibility tree if [LogFlags.logAccessibilityTree] is set and if * [LogFlags.grpcPort] is positive. * * Please see * https://developer.android.com/reference/android/view/accessibility/AccessibilityEvent#getEventType() * for a comprehensive list of events emitted by Android. 
*/
class AccessibilityForwarder(
  // Factory for the gRPC channel; injectable so callers can supply their own channel.
  // The default builds a plaintext channel with proxy detection disabled.
  private val channelFactory: (host: String, port: Int) -> ManagedChannel = { host, port ->
    ManagedChannelBuilder.forAddress(host, port)
      .proxyDetector(ProxyDetector { _ -> null })
      .usePlaintext()
      .build()
  }
) : AccessibilityService() {

  init {
    // Spawn long-running thread for periodically logging the tree. The loop ends as soon as
    // [LogFlags.a11yTreePeriodMs] is observed to be non-positive (see [onInterrupt]).
    Thread(
        Runnable {
          while (LogFlags.a11yTreePeriodMs > 0) {
            try {
              logAccessibilityTree()
            } catch (e: ConcurrentModificationException) {
              // Retry immediately, without sleeping.
              continue
            }
            Thread.sleep(/* millis= */ LogFlags.a11yTreePeriodMs)
          }
        }
      )
      .start()
  }

  // grpcStub has a backing property that can be reset to null.
  private var _grpcStub: A11yServiceCoroutineStub? = null

  /** The gRPC stub, built lazily from [channelFactory] and rebuilt after [resetGrpcStub]. */
  val grpcStub: A11yServiceCoroutineStub
    get() =
      _grpcStub
        ?: run {
          Log.i(TAG, "Building channel on ${LogFlags.grpcHost}:${LogFlags.grpcPort}.")
          A11yServiceCoroutineStub(channelFactory(LogFlags.grpcHost, LogFlags.grpcPort)).also {
            _grpcStub = it
          }
        }

  /**
   * Drops the current stub so that the next access to [grpcStub] builds a fresh channel.
   *
   * NOTE(review): the channel behind the discarded stub is not shut down here; confirm whether
   * it should be.
   */
  private fun resetGrpcStub() {
    _grpcStub = null
  }

  override fun onInterrupt() {
    LogFlags.a11yTreePeriodMs = 0 // Turn off periodic tree forwarding.
  }

  /** Forwards the extras of [event] via gRPC and logs its event type; null events are ignored. */
  override fun onAccessibilityEvent(event: AccessibilityEvent?) {
    if (event == null) {
      Log.i(TAG, "`event` is null.")
      return
    }

    logExtrasForEvent(event)

    val eventType = event.eventType
    val eventTypeStr: String = AccessibilityEvent.eventTypeToString(eventType)
    if (eventTypeStr.isNotEmpty()) {
      Log.i(TAG, eventTypeStr)
    }
  }

  /**
   * Builds the accessibility forest for the current windows and sends it via gRPC.
   *
   * Does nothing (besides logging) if [LogFlags.logAccessibilityTree] is unset, if the windows
   * cannot be obtained, or if [LogFlags.grpcPort] is not positive. On a gRPC status error or a
   * timeout the stub is reset so that the next request rebuilds the channel.
   */
  private fun logAccessibilityTree() {
    if (!LogFlags.logAccessibilityTree) {
      Log.i(TAG, "Not logging accessibility tree")
      return
    }

    val windows = getWindowsOrNull()
    if (windows == null) {
      Log.i(TAG, "windows is null.")
      return
    }

    // Check gRPC port before actually building the forest.
    if (LogFlags.grpcPort <= 0) {
      Log.w(TAG, "Can't log accessibility tree because gRPC port has not been set.")
      return
    }

    val forest = creator.buildForest(windows)
    try {
      val grpcTimeoutMillis = 1000L
      val response: ForestResponse =
        with(grpcStub) {
          Log.i(TAG, "sending (blocking) gRPC request for tree.")
          runBlocking { withTimeout(grpcTimeoutMillis) { sendForest(forest) } }
        }
      if (response.error.isNotEmpty()) {
        Log.w(TAG, "gRPC response.error: ${response.error}")
      } else {
        Log.i(TAG, "gRPC request for tree succeeded.")
      }
    } catch (e: StatusException) {
      Log.w(TAG, "gRPC StatusException; are you sure networking is turned on?")
      Log.i(TAG, "extra: exception ['$e']")
      resetGrpcStub()
    } catch (e: TimeoutCancellationException) {
      Log.w(TAG, "gRPC TimeoutCancellationException; are you sure networking is turned on?")
      Log.i(TAG, "extra: exception ['$e']")
      resetGrpcStub()
    }
  }

  /** Returns the service's windows, or null if reading them throws [NullPointerException]. */
  private fun getWindowsOrNull(): List<AccessibilityWindowInfo>? =
    try {
      windows
    } catch (e: NullPointerException) {
      null
    }

  /** Logs extras for all event types. */
  private fun logExtrasForEvent(event: AccessibilityEvent) {
    val events: MutableMap<String, String> = mutableMapOf()

    val sourceDescription = event.source?.contentDescription()
    if (!sourceDescription.isNullOrEmpty()) {
      events["source_content_description"] = sourceDescription
    }

    // Output the event text.
    val eventText = event.text.joinToString(", ")
    if (eventText.isNotEmpty()) {
      events["event_text"] = eventText
    }

    // Output the source text.
    val sourceText = event.source?.text?.toString()
    if (!sourceText.isNullOrEmpty()) {
      events["source_text"] = sourceText
    }

    val eventTypeStr: String = AccessibilityEvent.eventTypeToString(event.eventType)
    if (eventTypeStr.isNotEmpty()) {
      events["event_type"] = eventTypeStr
    }

    val className = event.source?.className?.toString()
    if (!className.isNullOrEmpty()) {
      events["source_class_name"] = className
    }

    val packageName = event.packageName?.toString()
    if (!packageName.isNullOrEmpty()) {
      events["event_package_name"] = packageName
    }

    // Text editing properties.
    val beforeText = event.beforeText?.toString()
    if (!beforeText.isNullOrEmpty()) {
      events["before_text"] = beforeText
    }

    val fromIndex = event.fromIndex
    if (fromIndex != -1) {
      events["from_index"] = fromIndex.toString()
    }

    val toIndex = event.toIndex
    if (toIndex != -1) {
      events["to_index"] = toIndex.toString()
    }

    val addedCount = event.addedCount
    if (addedCount != -1) {
      events["added_count"] = addedCount.toString()
    }

    val removedCount = event.removedCount
    if (removedCount != -1) {
      events["removed_count"] = removedCount.toString()
    }

    // Text traversal properties
    val movementGranularity = event.movementGranularity
    if (movementGranularity != 0) {
      events["movement_granularity"] = movementGranularity.toString()
    }

    val action = event.action
    if (action != 0) {
      events["action"] = action.toString()
    }

    // Scrolling properties.
    if (eventTypeStr == "TYPE_VIEW_SCROLLED") {
      events["scroll_delta_x"] = event.scrollDeltaX.toString()
      events["scroll_delta_y"] = event.scrollDeltaY.toString()
    }

    // Report viewID so we know exactly where the event came from.
    val viewId = event.source?.viewIdResourceName?.toString()
    if (!viewId.isNullOrEmpty()) {
      events["view_id"] = viewId
    }

    // Forward the collected [events] via gRPC, but only if at least one extra was gathered.
    if (events.isNotEmpty()) {
      events["event_timestamp_ms"] = event.eventTime.toString(10)

      // Check if we want to use gRPC.
      if (LogFlags.grpcPort > 0) {
        try {
          val grpcTimeoutMillis = 1000L
          val request = eventRequest { this.event.putAll(events) }
          val response: EventResponse =
            with(grpcStub) {
              Log.i(TAG, "sending (blocking) gRPC request for event.")
              runBlocking { withTimeout(grpcTimeoutMillis) { sendEvent(request) } }
            }
          if (response.error.isNotEmpty()) {
            Log.w(TAG, "gRPC response.error: ${response.error}")
          } else {
            Log.i(TAG, "gRPC request for event succeeded.")
          }
        } catch (e: StatusException) {
          Log.w(TAG, "gRPC StatusException; are you sure networking is turned on?")
          Log.i(TAG, "extra: exception ['$e']")
          resetGrpcStub()
        } catch (e: TimeoutCancellationException) {
          Log.w(TAG, "gRPC TimeoutCancellationException; are you sure networking is turned on?")
          Log.i(TAG, "extra: exception ['$e']")
          resetGrpcStub()
        }
      } else {
        Log.w(TAG, "Can't log accessibility event because gRPC port has not been set.")
      }
    }
  }

  /**
   * Walks up the accessibility tree from this node to the root, collecting descriptions.
   *
   * Returns the non-null content descriptions joined with ", " (closest node first), or an empty
   * string if the receiver is null.
   */
  private fun AccessibilityNodeInfo?.contentDescription(): String {
    if (this == null) {
      return ""
    }

    val descriptions = mutableListOf<String>()
    var current: AccessibilityNodeInfo? = this
    while (current != null) {
      val description = current.contentDescription
      if (description != null) {
        descriptions.add(description.toString())
      }
      current = current.parent
    }
    return descriptions.joinToString(", ")
  }

  companion object {
    private const val TAG = "AndroidRLTask"
    private val creator = AccessibilityTreeCreator()
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarderTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.accessibilityforwarder

import android.view.accessibility.AccessibilityEvent
import android.view.accessibility.AccessibilityNodeInfo
import android.view.accessibility.AccessibilityWindowInfo
import com.google.common.truth.Truth.assertThat
import io.grpc.Status
import io.grpc.StatusException
import io.grpc.inprocess.InProcessChannelBuilder
import io.grpc.inprocess.InProcessServerBuilder
import io.grpc.testing.GrpcCleanupRule
import org.junit.Assert.assertFalse
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestParameterInjector
import org.robolectric.Shadows.shadowOf

/**
 * Tests for [AccessibilityForwarder], run against an in-process fake of the gRPC service.
 *
 * NOTE(review): most expectations live inside the fake service's checker lambdas, i.e. they run on
 * the server side of the RPC. Presumably a failure thrown there reaches the client as a gRPC error
 * status, which the forwarder catches and logs. Please confirm that these checks can actually
 * fail a test, or move them onto the test thread.
 */
@RunWith(RobolectricTestParameterInjector::class)
class AccessibilityForwarderTest {

  @get:Rule(order = 1) val cleanupRule = GrpcCleanupRule()

  /** Fake gRPC service whose responses carry whatever error string the current checker returns. */
  class FakeAccessibilityService : A11yServiceGrpcKt.A11yServiceCoroutineImplBase() {
    var sendForestChecker: (AndroidAccessibilityForest) -> String = { _ -> "" }
    var sendEventChecker: (EventRequest) -> String = { _ -> "" }

    override suspend fun sendForest(request: AndroidAccessibilityForest) = forestResponse {
      error = sendForestChecker(request)
    }

    override suspend fun sendEvent(request: EventRequest) = eventResponse {
      error = sendEventChecker(request)
    }
  }

  private lateinit var forwarder: AccessibilityForwarder
  private val fakeA11yService = FakeAccessibilityService()

  // In-process channel to [fakeA11yService]; server and channel are registered for cleanup.
  private val channel by lazy {
    val serverName: String = InProcessServerBuilder.generateName()
    cleanupRule.register(
      InProcessServerBuilder.forName(serverName)
        .directExecutor()
        .addService(fakeA11yService)
        .build()
        .start()
    )
    cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build())
  }

  /** Initializes [forwarder] and [LogFlags] from the given args. */
  fun createForwarder(
    logAccessibilityTree: Boolean = false,
    a11yTreePeriodMs: Long = 0,
    grpcHost: String = "10.0.2.2",
    grpcPort: Int = 0,
    a11yWindows: MutableList<AccessibilityWindowInfo>? = null,
  ) {
    LogFlags.logAccessibilityTree = logAccessibilityTree
    LogFlags.a11yTreePeriodMs = a11yTreePeriodMs
    LogFlags.grpcHost = grpcHost
    LogFlags.grpcPort = grpcPort
    forwarder = AccessibilityForwarder({ _, _ -> channel })
    // Default to a single freshly obtained window when the caller does not supply any.
    shadowOf(forwarder).setWindows(a11yWindows ?: mutableListOf(AccessibilityWindowInfo.obtain()))
  }

  /** Checker body for RPCs that must never be reached. */
  private fun failIfCalled(rpcName: String): Nothing =
    throw AssertionError("$rpcName should not be called.")

  /** Builds the source node shared by the event tests. */
  private fun createSourceNode(): AccessibilityNodeInfo {
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.contentDescription = "My Content Description"
    nodeInfo.text = "My Source Text"
    nodeInfo.className = "AwesomeClass"
    return nodeInfo
  }

  /** Builds an event of [eventType] with the shared fields set and [source] attached. */
  private fun createEvent(
    eventType: Int,
    source: AccessibilityNodeInfo = createSourceNode(),
  ): AccessibilityEvent {
    val event = AccessibilityEvent()
    event.setEventTime(1357924680)
    event.eventType = eventType
    event.text.add("Some text!")
    event.packageName = "some.loooong.package.name"
    shadowOf(event).setSourceNode(source)
    return event
  }

  /** Builds a window whose type is [windowType]. */
  private fun createWindow(windowType: Int): AccessibilityWindowInfo {
    val window = AccessibilityWindowInfo()
    shadowOf(window).setType(windowType)
    return window
  }

  /** Asserts the fields that [createEvent] and [createSourceNode] always populate. */
  private fun assertCommonFields(request: EventRequest, expectedEventType: String) {
    assertThat(request.eventMap["event_type"]).isEqualTo(expectedEventType)
    assertThat(request.eventMap["event_package_name"]).isEqualTo("some.loooong.package.name")
    assertThat(request.eventMap["source_content_description"])
      .isEqualTo("My Content Description")
    assertThat(request.eventMap["source_text"]).isEqualTo("My Source Text")
    assertThat(request.eventMap["source_class_name"]).isEqualTo("AwesomeClass")
    assertThat(request.eventMap["event_text"]).isEqualTo("Some text!")
    assertThat(request.eventMap["event_timestamp_ms"]).isEqualTo("1357924680")
  }

  @Test
  fun onInterrupt_doesNotCrash() {
    // Arrange.
    createForwarder(logAccessibilityTree = false)
    fakeA11yService.sendEventChecker = { _ -> failIfCalled("sendEvent") }

    // Act.
    forwarder.onInterrupt()

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_nullEventShouldBeIgnored() {
    // Arrange.
    createForwarder(logAccessibilityTree = false)
    fakeA11yService.sendEventChecker = { _ -> failIfCalled("sendEvent") }

    // Act.
    forwarder.onAccessibilityEvent(null)

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_knownEventWithNoInformationShouldNotBeEmitted() {
    // Arrange.
    createForwarder(logAccessibilityTree = false)
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.contentDescription = ""
    val event = AccessibilityEvent()
    shadowOf(event).setSourceNode(nodeInfo)
    fakeA11yService.sendEventChecker = { _ -> failIfCalled("sendEvent") }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_typeViewClicked_sendEventViaGrpc() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcPort = 1234)
    val event = createEvent(AccessibilityEvent.TYPE_VIEW_CLICKED)
    fakeA11yService.sendEventChecker = { request: EventRequest ->
      // Check that all fields are consistent with how they were set above.
      assertCommonFields(request, "TYPE_VIEW_CLICKED")
      // No error message
      ""
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_typeViewTextChanged_ensureAllFieldsForwarded() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcPort = 1234)
    val event = createEvent(AccessibilityEvent.TYPE_VIEW_TEXT_CHANGED)
    event.fromIndex = 7
    event.beforeText = "Old words"
    event.addedCount = 12
    event.removedCount = 9
    fakeA11yService.sendEventChecker = { request: EventRequest ->
      // Check that all fields are consistent with how they were set above.
      assertCommonFields(request, "TYPE_VIEW_TEXT_CHANGED")
      assertThat(request.eventMap["from_index"]).isEqualTo("7")
      assertThat(request.eventMap["before_text"]).isEqualTo("Old words")
      assertThat(request.eventMap["added_count"]).isEqualTo("12")
      assertThat(request.eventMap["removed_count"]).isEqualTo("9")
      assertFalse(request.eventMap.containsKey("to_index"))
      assertFalse(request.eventMap.containsKey("view_id"))
      assertFalse(request.eventMap.containsKey("action"))
      assertFalse(request.eventMap.containsKey("movement_granularity"))
      assertFalse(request.eventMap.containsKey("scroll_delta_x"))
      assertFalse(request.eventMap.containsKey("scroll_delta_y"))
      // No error message
      ""
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_typeViewScrolled_ensureAllFieldsForwarded() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcPort = 1234)
    val event = createEvent(AccessibilityEvent.TYPE_VIEW_SCROLLED)
    event.scrollDeltaX = 13
    event.scrollDeltaY = 27
    fakeA11yService.sendEventChecker = { request: EventRequest ->
      // Check that all fields are consistent with how they were set above.
      assertCommonFields(request, "TYPE_VIEW_SCROLLED")
      assertThat(request.eventMap["scroll_delta_x"]).isEqualTo("13")
      assertThat(request.eventMap["scroll_delta_y"]).isEqualTo("27")
      assertFalse(request.eventMap.containsKey("from_index"))
      assertFalse(request.eventMap.containsKey("to_index"))
      assertFalse(request.eventMap.containsKey("before_text"))
      assertFalse(request.eventMap.containsKey("added_count"))
      assertFalse(request.eventMap.containsKey("removed_count"))
      // No error message
      ""
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_typeViewTextTraversedAtMovementGranularity_ensureAllFieldsForwarded() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcPort = 1234)
    val nodeInfo = createSourceNode()
    nodeInfo.viewIdResourceName = "this.big.old.view.id"
    val event =
      createEvent(AccessibilityEvent.TYPE_VIEW_TEXT_TRAVERSED_AT_MOVEMENT_GRANULARITY, nodeInfo)
    event.movementGranularity = 5
    event.fromIndex = 6
    event.toIndex = 8
    event.action = 23
    fakeA11yService.sendEventChecker = { request: EventRequest ->
      // Check that all fields are consistent with how they were set above.
      assertCommonFields(request, "TYPE_VIEW_TEXT_TRAVERSED_AT_MOVEMENT_GRANULARITY")
      assertThat(request.eventMap["movement_granularity"]).isEqualTo("5")
      assertThat(request.eventMap["from_index"]).isEqualTo("6")
      assertThat(request.eventMap["to_index"]).isEqualTo("8")
      assertThat(request.eventMap["view_id"]).isEqualTo("this.big.old.view.id")
      assertThat(request.eventMap["action"]).isEqualTo("23")
      // No error message
      ""
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_sendingevent_grpcTimeout() {
    // Arrange.
    createForwarder(
      logAccessibilityTree = false,
      a11yTreePeriodMs = 0,
      grpcHost = "amazing.host",
      grpcPort = 4321,
    )
    val event = createEvent(AccessibilityEvent.TYPE_VIEW_CLICKED)
    fakeA11yService.sendEventChecker = { _ ->
      // Delay the request to prompt a timeout
      Thread.sleep(1500L)
      "" // Return no error.
    }

    // Act.
    forwarder.onAccessibilityEvent(event)
    // Run a second request to ensure that the channel gets rebuilt.
    fakeA11yService.sendEventChecker = { _ -> "" }
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_sendingevent_grpcStatusException() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcHost = "amazing.host", grpcPort = 4321)
    val event = createEvent(AccessibilityEvent.TYPE_VIEW_CLICKED)
    fakeA11yService.sendEventChecker = { _ -> throw StatusException(Status.UNAVAILABLE) }

    // Act.
    forwarder.onAccessibilityEvent(event)
    // Run a second request to ensure that the channel gets rebuilt.
    fakeA11yService.sendEventChecker = { _ -> "" }
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // See `sendEventChecker` above.
  }

  @Test
  fun logAccessibilityTreeFalse_doesNotLogAccessibilityTree() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, a11yTreePeriodMs = 10, grpcPort = 13579)
    fakeA11yService.sendForestChecker = { _ -> failIfCalled("sendForest") }

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // See `sendForestChecker` above.
  }

  @Test
  fun grpcPortZero_doesNotSendTree() {
    // Arrange.
    createForwarder(logAccessibilityTree = true, a11yTreePeriodMs = 10, grpcPort = 0)
    fakeA11yService.sendForestChecker = { _ -> failIfCalled("sendForest") }

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // See `sendForestChecker` above.
  }

  @Test
  fun grpcPortPositive_shouldSendTreeViaGrpc() {
    // Arrange.
    val window = createWindow(AccessibilityWindowInfo.TYPE_SYSTEM)
    createForwarder(
      logAccessibilityTree = true,
      a11yTreePeriodMs = 10,
      grpcPort = 1234,
      a11yWindows = mutableListOf(window),
    )
    fakeA11yService.sendForestChecker = { request: AndroidAccessibilityForest ->
      // Check that we get only a single window.
      assertThat(request.windowsList.size).isEqualTo(1)
      // And that its type is what we set above.
      assertThat(request.windowsList[0].windowType)
        .isEqualTo(AndroidAccessibilityWindowInfo.WindowType.TYPE_SYSTEM)
      // The error message
      "Something went wrong!"
    }

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // See `sendForestChecker` above.
  }

  @Test
  fun grpcPortPositiveAndHost_shouldSendTreeViaGrpc() {
    // Arrange.
    fakeA11yService.sendForestChecker = { request: AndroidAccessibilityForest ->
      // Check that we get only a single window.
      assertThat(request.windowsList.size).isEqualTo(1)
      // And that its type is what we set above.
      assertThat(request.windowsList[0].windowType)
        .isEqualTo(AndroidAccessibilityWindowInfo.WindowType.TYPE_ACCESSIBILITY_OVERLAY)
      "" // Return no error.
    }
    val window = createWindow(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    createForwarder(
      logAccessibilityTree = true,
      a11yTreePeriodMs = 500,
      grpcHost = "amazing.host",
      grpcPort = 4321,
      a11yWindows = mutableListOf(window),
    )

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // See `sendForestChecker` above.
  }

  @Test
  fun sendingForest_grpcTimeout() {
    // Arrange.
    fakeA11yService.sendForestChecker = { _ ->
      // Delay the request to prompt a timeout
      Thread.sleep(1500L)
      "" // Return no error.
    }
    val window = createWindow(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    createForwarder(
      logAccessibilityTree = true,
      a11yTreePeriodMs = 10,
      grpcHost = "amazing.host",
      grpcPort = 4321,
      a11yWindows = mutableListOf(window),
    )

    // Act.
    Thread.sleep(2000) // Sleep a bit to give time to trigger the tree logging function.
    // Run a second request to ensure that the channel gets rebuilt.
    fakeA11yService.sendForestChecker = { _ -> "" }
    Thread.sleep(2000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // See `sendForestChecker` above.
  }

  @Test
  fun sendingForest_grpcStatusException() {
    // Arrange.
    val window = createWindow(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    createForwarder(
      logAccessibilityTree = true,
      a11yTreePeriodMs = 10,
      grpcHost = "amazing.host",
      grpcPort = 4321,
      a11yWindows = mutableListOf(window),
    )
    fakeA11yService.sendForestChecker = { _ -> throw StatusException(Status.UNAVAILABLE) }

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.
    // Run a second request to ensure that the channel gets rebuilt.
    fakeA11yService.sendForestChecker = { _ -> "" }
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // See `sendForestChecker` above.
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreator.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.accessibilityforwarder

import android.graphics.Rect
import android.util.Log
import android.view.accessibility.AccessibilityNodeInfo
import android.view.accessibility.AccessibilityWindowInfo
import com.google.androidenv.accessibilityforwarder.AndroidAccessibilityWindowInfo.WindowType
import java.util.concurrent.ConcurrentHashMap
import java.util.stream.Collectors
import kotlin.collections.mutableListOf
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.runBlocking

/** Helper methods for creating the android accessibility info extra. */
class AccessibilityTreeCreator() {

  /** Creates an accessibility forest proto.
*/
  fun buildForest(windowInfos: List<AccessibilityWindowInfo>): AndroidAccessibilityForest {
    // Maps every emitted proto node back to the platform node it was built from. The map is
    // filled during the traversal but is not part of the returned forest.
    val sourcesMap: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo> =
      ConcurrentHashMap()
    val windows: List<AndroidAccessibilityWindowInfo> =
      processWindowsAndBlock(windowInfos, sourcesMap)
    return androidAccessibilityForest { this.windows += windows }
  }

  /** Blocking bridge into the suspending traversal; returns once every window is converted. */
  private fun processWindowsAndBlock(
    windowInfos: List<AccessibilityWindowInfo>,
    sourcesMap: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  ): List<AndroidAccessibilityWindowInfo> = runBlocking { processWindows(windowInfos, sourcesMap) }

  /** Converts [windowInfos] to protos, visiting them from the last element to the first. */
  private suspend fun processWindows(
    windowInfos: List<AccessibilityWindowInfo>,
    sourcesMap: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  ): List<AndroidAccessibilityWindowInfo> =
    windowInfos.asReversed().map { processWindow(it, sourcesMap) }

  /**
   * Converts one window to its proto form.
   *
   * A window without a root node yields a proto with an empty tree; otherwise the tree contains
   * every node reachable from the root.
   */
  private suspend fun processWindow(
    windowInfo: AccessibilityWindowInfo,
    sources: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  ): AndroidAccessibilityWindowInfo {
    val bounds = Rect()
    windowInfo.getBoundsInScreen(bounds)
    val root: AccessibilityNodeInfo? = windowInfo.root
    // This function already suspends, so the tree is computed directly rather than through a
    // nested `runBlocking`.
    val treeProto: AndroidAccessibilityTree =
      if (root == null) {
        Log.i(TAG, "window root is null")
        androidAccessibilityTree {}
      } else {
        processNodesInWindow(root, sources)
      }
    return androidAccessibilityWindowInfo {
      this.tree = treeProto
      this.isActive = windowInfo.isActive
      this.id = windowInfo.id
      this.layer = windowInfo.layer
      this.isAccessibilityFocused = windowInfo.isAccessibilityFocused
      this.isFocused = windowInfo.isFocused
      this.boundsInScreen = convertToRectProto(bounds)
      this.windowType = toWindowType(windowInfo.type)
    }
  }

  /**
   * Traverses the nodes reachable from [root] breadth-first and converts each one to a proto.
   *
   * Every node is visited at most once (tracked through `seenNodes`) and is tagged with the depth
   * of the layer in which it was dequeued; the root has depth 0.
   */
  private suspend fun processNodesInWindow(
    root: AccessibilityNodeInfo,
    sources: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  ): AndroidAccessibilityTree {
    Log.d(TAG, "processNodesInWindow()")
    val traversalQueue = ArrayDeque<ParentChildNodePair>()
    traversalQueue.add(ParentChildNodePair.builder().child(root).build())
    val uniqueIdsCache = UniqueIdsGenerator()
    val nodesDeferred = mutableListOf<Deferred<AndroidAccessibilityNodeInfo>>()
    val seenNodes = hashSetOf(root)
    var currentDepth = 0

    // `coroutineScope` waits for every `async` child without blocking the calling thread; it is
    // fully qualified because the file does not import it.
    kotlinx.coroutines.coroutineScope {
      while (traversalQueue.isNotEmpty()) {
        // Traverse the tree layer-by-layer.
        // The first layer has only the root and depth 0.
        // The second layer has all the root's children and depth 1.
        repeat(traversalQueue.size) {
          val nodePair: ParentChildNodePair = traversalQueue.removeFirst()
          for (i in 0 until nodePair.child().childCount) {
            val childNode: AccessibilityNodeInfo? = nodePair.child().getChild(i)
            // `add` returns false for nodes that were already enqueued, which skips them.
            if (childNode != null && seenNodes.add(childNode)) {
              traversalQueue.add(
                ParentChildNodePair.builder().child(childNode).parent(nodePair.child()).build()
              )
            }
          }
          // Capture the depth by value: `currentDepth` keeps changing while the task is pending.
          val thisDepth = currentDepth
          nodesDeferred.add(async { processNode(nodePair, sources, uniqueIdsCache, thisDepth) })
        }
        currentDepth++
      }
    }
    return androidAccessibilityTree { this.nodes += nodesDeferred.awaitAll() }
  }

  companion object {
    private const val TAG = "AndroidRLTask"
  }
}

/** Converts the child of [nodePair] to a proto and records the proto-to-node mapping. */
private fun processNode(
  nodePair: ParentChildNodePair,
  sourceBuilder: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  uniqueIdsCache: UniqueIdsGenerator,
  nodeDepth: Int,
): AndroidAccessibilityNodeInfo {
  val node: AccessibilityNodeInfo = nodePair.child()
  val immutableNode: AndroidAccessibilityNodeInfo =
    createAndroidAccessibilityNode(
      node,
      uniqueIdsCache.getUniqueId(node),
      nodeDepth,
      getChildUniqueIds(node, uniqueIdsCache),
    )
  sourceBuilder.put(immutableNode, node)
  return immutableNode
}

/** Copies the properties of [node] into a proto tagged with [nodeId], [depth] and [childIds]. */
private fun createAndroidAccessibilityNode(
  node: AccessibilityNodeInfo,
  nodeId: Int,
  depth: Int,
  childIds: List<Int>,
): AndroidAccessibilityNodeInfo {
  val bounds = Rect()
  node.getBoundsInScreen(bounds)
  val actions = node.actionList.map(::createAction)
  return androidAccessibilityNodeInfo {
    this.actions += actions
    this.boundsInScreen = convertToRectProto(bounds)
    this.isCheckable = node.isCheckable
    this.isChecked = node.isChecked
    this.className = stringFromNullableCharSequence(node.getClassName())
    this.isClickable = node.isClickable
    this.contentDescription = stringFromNullableCharSequence(node.getContentDescription())
    this.isEditable = node.isEditable
    this.isEnabled = node.isEnabled
    this.isFocusable = node.isFocusable
    this.hintText = stringFromNullableCharSequence(node.getHintText())
    this.isLongClickable = node.isLongClickable
    this.packageName = stringFromNullableCharSequence(node.getPackageName())
    this.isPassword = node.isPassword
    this.isScrollable = node.isScrollable
    this.isSelected = node.isSelected
    this.text = stringFromNullableCharSequence(node.getText())
    this.textSelectionEnd = node.getTextSelectionEnd().toLong()
    this.textSelectionStart = node.getTextSelectionStart().toLong()
    this.viewIdResourceName = node.getViewIdResourceName() ?: ""
    this.isVisibleToUser = node.isVisibleToUser
    this.windowId = node.windowId
    this.uniqueId = nodeId
    this.childIds += childIds
    this.drawingOrder = node.drawingOrder
    this.tooltipText = stringFromNullableCharSequence(node.getTooltipText())
    this.depth = depth
  }
}

/** Converts a platform accessibility action to its proto form (id and label). */
private fun createAction(
  action: AccessibilityNodeInfo.AccessibilityAction
): AndroidAccessibilityAction =
  AndroidAccessibilityAction.newBuilder()
    .setId(action.id)
    .setLabel(stringFromNullableCharSequence(action.label))
    .build()

/** Returns the unique ids of the non-null children of [node], in child order. */
private fun getChildUniqueIds(
  node: AccessibilityNodeInfo,
  uniqueIdsCache: UniqueIdsGenerator,
): List<Int> =
  (0 until node.getChildCount()).mapNotNull { index ->
    node.getChild(index)?.let { child -> uniqueIdsCache.getUniqueId(child) }
  }

/** Returns [cs] as a String, or "" when it is null. */
fun stringFromNullableCharSequence(cs: CharSequence?): String = cs?.toString() ?: ""

/** Copies the four edges of [rect] into a rect proto. */
fun convertToRectProto(rect: Rect) = protoRect {
  left = rect.left
  top = rect.top
  right = rect.right
  bottom = rect.bottom
}

/** Maps a platform window type constant to its proto enum; unrecognised values map to UNKNOWN_TYPE. */
private fun toWindowType(type: Int): WindowType =
  when (type) {
    AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY -> WindowType.TYPE_ACCESSIBILITY_OVERLAY
    AccessibilityWindowInfo.TYPE_APPLICATION -> WindowType.TYPE_APPLICATION
    AccessibilityWindowInfo.TYPE_INPUT_METHOD -> WindowType.TYPE_INPUT_METHOD
    AccessibilityWindowInfo.TYPE_SYSTEM -> WindowType.TYPE_SYSTEM
    AccessibilityWindowInfo.TYPE_SPLIT_SCREEN_DIVIDER -> WindowType.TYPE_SPLIT_SCREEN_DIVIDER
    else -> WindowType.UNKNOWN_TYPE
  }

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreatorTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.accessibilityforwarder

import android.view.accessibility.AccessibilityNodeInfo
import android.view.accessibility.AccessibilityWindowInfo
import kotlin.test.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.Shadows.shadowOf

/** Robolectric tests for [AccessibilityTreeCreator.buildForest]. */
@RunWith(RobolectricTestRunner::class)
class AccessibilityTreeCreatorTest {

  /** A window whose root has two children must produce one window with three nodes. */
  @Test
  fun buildForest_buildsAccessibilityForestCorrectly() {
    val creator = AccessibilityTreeCreator()

    val forest = creator.buildForest(mutableListOf(createWindowInfo()))

    // `assertEquals` takes (expected, actual); keeping that order makes failure messages correct.
    assertEquals(1, forest.windowsCount)
    assertEquals(3, forest.getWindows(0).tree.nodesCount)
    val nodes = forest.getWindows(0).tree.nodesList
    // The first node matching each predicate is the one under test.
    val rootNode: AndroidAccessibilityNodeInfo? = nodes.firstOrNull { it.text == "root node" }
    val checkableNode: AndroidAccessibilityNodeInfo? = nodes.firstOrNull { it.isCheckable }
    assertEquals(2, rootNode?.childIdsCount)
    assertEquals("Check box", checkableNode?.text)
  }

  /** A window without a root node must yield a tree with no nodes. */
  @Test
  fun buildForest_noRootInWindow_returnsEmptyTree() {
    val creator = AccessibilityTreeCreator()
    val windowInfo = AccessibilityWindowInfo.obtain()
    shadowOf(windowInfo).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)

    val forest = creator.buildForest(mutableListOf(windowInfo))

    assertEquals(0, forest.getWindows(0).tree.nodesList.size)
  }

  /** Builds a clickable root ("root node") with an editable child and a checkable child. */
  private fun createAccessibilityNodeInfo(): AccessibilityNodeInfo {
    val root = AccessibilityNodeInfo.obtain()
    root.text = "root node"
    root.isClickable = true

    val accessibilityNodeInfo = AccessibilityNodeInfo.obtain()
    accessibilityNodeInfo.viewIdResourceName = "test"
    accessibilityNodeInfo.isClickable = true
    accessibilityNodeInfo.isEditable = true
    accessibilityNodeInfo.hintText = "Please enter your address"
    shadowOf(root).addChild(accessibilityNodeInfo)

    val anotherChildNode = AccessibilityNodeInfo.obtain()
    anotherChildNode.isCheckable = true
    anotherChildNode.text = "Check box"
    shadowOf(root).addChild(anotherChildNode)

    return root
  }

  /** Builds an accessibility-overlay window rooted at [createAccessibilityNodeInfo]. */
  private fun createWindowInfo(): AccessibilityWindowInfo {
    val windowInfo = AccessibilityWindowInfo.obtain()
    shadowOf(windowInfo).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    shadowOf(windowInfo).setRoot(createAccessibilityNodeInfo())
    return windowInfo
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest.xml
================================================

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest_lite.xml
================================================

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiver.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.accessibilityforwarder

import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import android.util.Log

/** Broadcast receiver that toggles the global [LogFlags] in response to intents. */
class FlagsBroadcastReceiver() : BroadcastReceiver() {

  /**
   * Applies the flag change requested by [intent]'s action to [LogFlags].
   *
   * A null [intent], a null action or an unrecognized action only produces a warning log and
   * leaves every flag untouched. [context] is not used.
   */
  override fun onReceive(context: Context?, intent: Intent?) {
    val action = intent?.action
    Log.i(TAG, "Received broadcast intent with action: $action")
    when (action) {
      ACTION_ENABLE_ACCESSIBILITY_TREE_LOGS -> {
        Log.i(TAG, "Enabling Accessibility Tree logging.")
        LogFlags.logAccessibilityTree = true
      }
      ACTION_DISABLE_ACCESSIBILITY_TREE_LOGS -> {
        Log.i(TAG, "Disabling Accessibility Tree logging.")
        LogFlags.logAccessibilityTree = false
      }
      ACTION_SET_GRPC -> {
        // The Android Emulator uses 10.0.2.2 as a redirect to the workstation's IP. Most often the
        // gRPC server will be running locally so it makes sense to use this as the default value.
        // See https://developer.android.com/studio/run/emulator-networking#networkaddresses.
        val host = intent.getStringExtra("host") ?: "10.0.2.2"
        // The TCP port to connect. If <=0 gRPC is disabled.
        val port = intent.getIntExtra("port", 0)
        Log.i(TAG, "Setting gRPC endpoint to $host:$port.")
        LogFlags.grpcHost = host
        LogFlags.grpcPort = port
      }
      else -> Log.w(TAG, "Unknown action: $action")
    }
  }

  companion object {
    /** Tag used for every log line emitted by this receiver. */
    private const val TAG = "FlagsBroadcastReceiver"

    // Intent actions understood by [onReceive].
    private const val ACTION_ENABLE_ACCESSIBILITY_TREE_LOGS =
      "accessibility_forwarder.intent.action.ENABLE_ACCESSIBILITY_TREE_LOGS"
    private const val ACTION_DISABLE_ACCESSIBILITY_TREE_LOGS =
      "accessibility_forwarder.intent.action.DISABLE_ACCESSIBILITY_TREE_LOGS"
    private const val ACTION_SET_GRPC = "accessibility_forwarder.intent.action.SET_GRPC"
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiverTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.accessibilityforwarder

import android.content.Intent
import com.google.common.truth.Truth.assertThat
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner

@RunWith(RobolectricTestRunner::class)
class FlagsBroadcastReceiverTest {

  @Test
  fun onReceive_nullIntent_shouldNotLogAnything() {
    // Arrange.
    LogFlags.logAccessibilityTree = false
    val receiver = FlagsBroadcastReceiver()

    // Act.
    receiver.onReceive(context = null, intent = null)

    // Assert.
assertThat(LogFlags.logAccessibilityTree).isFalse()
  }

  /** An intent that carries no action must leave tree logging disabled. */
  @Test
  fun onReceive_nullIntent_actionShouldNotLogAnything() {
    // Arrange.
    LogFlags.logAccessibilityTree = false
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent()

    // Act.
    receiver.onReceive(context = null, intent = intent)

    // Assert.
    assertThat(LogFlags.logAccessibilityTree).isFalse()
  }

  /** An intent with an unrecognized action must leave tree logging disabled. */
  @Test
  fun onReceive_unknownIntent_actionShouldIssueWarning() {
    // Arrange.
    LogFlags.logAccessibilityTree = false
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent("SOME_WEIRD_ACTION")

    // Act.
    receiver.onReceive(context = null, intent = intent)

    // Assert.
    assertThat(LogFlags.logAccessibilityTree).isFalse()
  }

  /** The DISABLE action must switch tree logging off when it was on. */
  @Test
  fun onReceive_intentWithDisableAction_shouldDisableTreeLogging() {
    // Arrange.
    LogFlags.logAccessibilityTree = true
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent("accessibility_forwarder.intent.action.DISABLE_ACCESSIBILITY_TREE_LOGS")

    // Act.
    receiver.onReceive(context = null, intent = intent)

    // Assert.
    assertThat(LogFlags.logAccessibilityTree).isFalse()
  }

  /** The ENABLE action must switch tree logging on when it was off. */
  @Test
  fun onReceive_intentWithEnableAction_shouldEnableTreeLogging() {
    // Arrange.
    LogFlags.logAccessibilityTree = false
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent("accessibility_forwarder.intent.action.ENABLE_ACCESSIBILITY_TREE_LOGS")

    // Act.
    receiver.onReceive(context = null, intent = intent)

    // Assert.
    assertThat(LogFlags.logAccessibilityTree).isTrue()
  }

  /** SET_GRPC without extras must overwrite the previous endpoint with "10.0.2.2" and port 0. */
  @Test
  fun onReceive_intentWithSetGrpcActionNoArgs_shouldDefaultToEmuIpAndPortZero() {
    // Arrange.
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent("accessibility_forwarder.intent.action.SET_GRPC")

    // Act.
    receiver.onReceive(context = null, intent = intent)

    // Assert.
    assertThat(LogFlags.grpcHost).isEqualTo("10.0.2.2")
    assertThat(LogFlags.grpcPort).isEqualTo(0)
  }

  /** SET_GRPC with only a "host" extra must use that host and reset the port to 0. */
  @Test
  fun onReceive_intentWithSetGrpcActionWithHostNoPort_shouldDefaultPortToZero() {
    // Arrange.
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent =
      Intent("accessibility_forwarder.intent.action.SET_GRPC").apply {
        putExtra("host", "awesome.server.ca")
      }

    // Act.
    receiver.onReceive(context = null, intent = intent)

    // Assert.
    assertThat(LogFlags.grpcHost).isEqualTo("awesome.server.ca")
    assertThat(LogFlags.grpcPort).isEqualTo(0)
  }

  /** SET_GRPC with only a "port" extra must use that port and the default host "10.0.2.2". */
  @Test
  fun onReceive_intentWithSetGrpcActionWithPortNoHost_shouldDefaultHostToEmuIp() {
    // Arrange.
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent =
      Intent("accessibility_forwarder.intent.action.SET_GRPC").apply { putExtra("port", 54321) }

    // Act.
    receiver.onReceive(context = null, intent = intent)

    // Assert.
    assertThat(LogFlags.grpcHost).isEqualTo("10.0.2.2")
    assertThat(LogFlags.grpcPort).isEqualTo(54321)
  }

  /** SET_GRPC with both extras must store both values verbatim. */
  @Test
  fun onReceive_intentWithSetGrpcActionWithHostAndPort_shouldSetBoth() {
    // Arrange.
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent =
      Intent("accessibility_forwarder.intent.action.SET_GRPC").apply {
        putExtra("host", "grpc.ca")
        putExtra("port", 54321)
      }

    // Act.
    receiver.onReceive(context = null, intent = intent)

    // Assert.
    assertThat(LogFlags.grpcHost).isEqualTo("grpc.ca")
    assertThat(LogFlags.grpcPort).isEqualTo(54321)
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/LogFlags.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.accessibilityforwarder

/**
 * Controls global settings in AccessibilityForwarder.
 *
 * All fields are plain mutable properties on a singleton; [FlagsBroadcastReceiver] writes them
 * when it receives the corresponding intent.
 *
 * Please note that this class is not thread safe.
 */
object LogFlags {
  // Whether to log the accessibility tree.
  var logAccessibilityTree: Boolean = false
  // How frequently to emit a11y trees (in milliseconds).
  var a11yTreePeriodMs: Long = 100
  // The gRPC server host to connect to. Only meaningful when grpcPort > 0.
  var grpcHost: String = ""
  // If > 0 this is the gRPC port number to connect to; a value <= 0 means gRPC is disabled.
  var grpcPort: Int = 0
}

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/ParentChildNodePair.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.accessibilityforwarder

import android.view.accessibility.AccessibilityNodeInfo
import com.google.auto.value.AutoValue

/** Parent and child [AccessibilityNodeInfo] relationship.
*/
@AutoValue
internal abstract class ParentChildNodePair {
  /** The parent node, or null when [child] has no parent. */
  abstract fun parent(): AccessibilityNodeInfo?

  /** The child node of the pair. */
  abstract fun child(): AccessibilityNodeInfo

  /** [ParentChildNodePair] builder. */
  @AutoValue.Builder
  abstract class Builder {
    abstract fun parent(parent: AccessibilityNodeInfo?): Builder

    abstract fun child(child: AccessibilityNodeInfo): Builder

    abstract fun build(): ParentChildNodePair
  }

  companion object {
    /** Returns a fresh builder backed by the AutoValue-generated implementation. */
    @JvmStatic fun builder(): Builder = AutoValue_ParentChildNodePair.Builder()
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/UniqueIdsGenerator.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.accessibilityforwarder

import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import java.util.function.Function

/**
 * Thread-safe helper class for assigning a unique ID to an object.
 *
 * IDs start at 0 and are handed out in increasing order. Asking for the ID of an object that is
 * equal to one seen before returns the ID assigned the first time.
 *
 * @param A the type of the objects being identified. It is non-nullable because
 *   [ConcurrentHashMap] rejects null keys.
 */
internal class UniqueIdsGenerator<A : Any> {
  /** The ID that the next previously-unseen object will receive. */
  private val nextId = AtomicInteger(0)

  /** IDs already assigned, keyed by the object they were assigned to. */
  private val uniqueIdsByNode = ConcurrentHashMap<A, Int>()

  /** Returns the ID associated with [a], assigning a new one if [a] has not been seen yet. */
  fun getUniqueId(a: A): Int {
    return uniqueIdsByNode.computeIfAbsent(a, Function { _: A -> nextId.getAndIncrement() })!!
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/res/xml/accessibility_forwarder_service.xml
================================================

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/AndroidManifest.xml
================================================

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/BUILD.bazel
================================================
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Classic RL task implemented as an Android app.
load("@rules_android//rules:rules.bzl", "android_binary") load("@rules_kotlin//kotlin:android.bzl", "kt_android_library") package( default_visibility = [":catch_packages"], ) package_group( name = "catch_packages", packages = [ "//java/com/google/androidenv/catch/...", "//javatests/com/google/androidenv/catch/...", ], ) licenses(["notice"]) android_binary( name = "app", manifest = "AndroidManifest.xml", multidex = "native", deps = [":MainActivity"], ) kt_android_library( name = "GameLogic", srcs = ["GameLogic.kt"], deps = [ "//java/com/google/androidenv/catch/sprite:Background", "//java/com/google/androidenv/catch/sprite:Ball", "//java/com/google/androidenv/catch/sprite:LineSegment", "//java/com/google/androidenv/catch/sprite:Paddle", ], ) kt_android_library( name = "GameLogicThread", srcs = ["GameLogicThread.kt"], deps = [ ":GameLogic", ], ) kt_android_library( name = "MainActivity", srcs = ["MainActivity.kt"], manifest = "AndroidManifest.xml", resource_files = glob(["res/**"]), deps = [ ":GameLogic", ":GameLogicThread", ":RenderThread", "//java/com/google/androidenv/catch/sprite:Background", "//java/com/google/androidenv/catch/sprite:Ball", "//java/com/google/androidenv/catch/sprite:Paddle", ], ) kt_android_library( name = "RenderThread", srcs = ["RenderThread.kt"], deps = [ ":GameLogic", ], ) ================================================ FILE: android_env/apps/java/com/google/androidenv/catch/GameLogic.kt ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch

import android.graphics.Canvas
import android.view.MotionEvent
import com.google.androidenv.catch.sprite.Background
import com.google.androidenv.catch.sprite.Ball
import com.google.androidenv.catch.sprite.LineSegment
import com.google.androidenv.catch.sprite.Paddle
import java.time.Duration
import java.time.Instant
import kotlin.random.Random

/** The class that contains the game logic. */
open class GameLogic(
  // Expected number of frames per second.
  fps: Int = 60,
  // Pseudo random number generator.
  private val rand: Random = Random.Default,
  // Width and height of the game in pixels.
  private val width: Int,
  private val height: Int,
  // UI objects in the game.
  private var background: Background = Background(),
  private var ball: Ball = Ball(maxX = width, maxY = height, rand = rand),
  private var paddle: Paddle = Paddle(maxX = width, y = height),
) {

  /** How long [run] sleeps between two consecutive ball updates. */
  private val sleepTime: Duration = Duration.ofMillis((1000.0 / fps).toLong())

  /** Reinitializes the state of the game. */
  // Need to make this open to allow for testing.
  open fun reset() {
    this.ball.reset()
  }

  /**
   * Runs one "throw" of a [ball] that needs to be caught by the [paddle].
   *
   * Blocks the calling thread until the ball is out of bounds, then returns `true` if the ball
   * intersects the top edge of the paddle.
   */
  // Need to make this open to allow for testing.
  open fun run(): Boolean {
    var lastTimestamp = Instant.now()
    do {
      Thread.sleep(sleepTime.toMillis())
      // Advance the ball by the time that actually elapsed, not by the nominal sleep time.
      val now = Instant.now()
      val interval = Duration.between(lastTimestamp, now)
      lastTimestamp = now
      ball.update(interval)
    } while (!ball.isOutOfBounds())
    return ball.intersects(LineSegment(paddle.topLeft(), paddle.topRight()))
  }

  /** Processes a user event (e.g. a touchscreen event) and updates the [paddle] accordingly. */
  fun handleTouch(event: MotionEvent) {
    paddle.x = event.x.toInt()
  }

  /** Renders the game on [c]: background first, then the ball, then the paddle. */
  open fun render(c: Canvas) {
    background.draw(c)
    ball.draw(c)
    paddle.draw(c)
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/GameLogicThread.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch

import android.util.Log

/** A thread that continuously runs the game logic, resetting after each internal [run()]. */
class GameLogicThread(private val game: GameLogic, private val loggingTag: String) : Thread() {

  /**
   * Whether this thread should continuously run.
   *
   * Written by whichever thread calls [finish()] and read by this thread's loop, so it must be
   * volatile: without it the loop is allowed to never observe the update, and a caller that
   * `join()`s after [finish()] could block forever.
   */
  @Volatile private var shouldRun: Boolean = true

  /** A counter of game runs. */
  private var counter: Int = 0

  /**
   * Lets the current [run()] iteration complete and then exits this [Thread].
   *
   * Notice that [shouldRun] cannot have a private getter with a public setter (please see
   * https://youtrack.jetbrains.com/issue/KT-3110 for details), hence this public function. Also
   * notice that we cannot call this function [stop()] since it would shadow [Thread.stop()].
   */
  public fun finish() {
    shouldRun = false
  }

  /** Continuously runs the [game] until [finish()] is called.
*/
  public override fun run() {
    while (shouldRun) {
      game.reset()
      // Logs "<run index> - <result of game.run()>" once per completed run.
      Log.i(loggingTag, "${counter++} - ${game.run()}")
    }
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/MainActivity.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch

import android.app.Activity
import android.content.Intent
import android.graphics.Color
import android.os.Bundle
import android.util.Log
import android.view.SurfaceHolder
import android.view.SurfaceView
import android.view.View
import android.view.Window
import com.google.androidenv.catch.sprite.Background
import com.google.androidenv.catch.sprite.Ball
import com.google.androidenv.catch.sprite.Paddle

/**
 * The activity that allows users to play the RL game of Catch.
 *
 * The game is (re)started whenever the surface size changes or a new intent arrives. Game
 * parameters are read from the intent extras, falling back to built-in defaults.
 */
class MainActivity() : Activity(), SurfaceHolder.Callback {

  // The view the game is drawn on; set in [onCreate].
  private var surfaceView: SurfaceView? = null
  // Worker threads; the render thread lives as long as the surface, the logic thread as long as
  // the current game.
  private var renderThread: RenderThread? = null
  private var gameLogicThread: GameLogicThread? = null
  // Frame rate handed to both [RenderThread] and [GameLogic].
  private val fps: Int = 60
  // NOTE(review): not referenced anywhere else in this class.
  private var gameCounter: Int = 0
  // Surface dimensions in pixels; -1 until [surfaceChanged] reports them.
  private var width: Int = -1
  private var height: Int = -1
  // Extras of the most recent intent; source of the game configuration in [startGame].
  private var extras: Bundle? = null

  // [Activity] overrides.

  /** Initializes the Android [View] and sets up callbacks. */
  override fun onCreate(savedInstanceState: Bundle?) {
    super.onCreate(savedInstanceState)
    Log.i(TAG, "MainActivity::onCreate()")
    requestWindowFeature(Window.FEATURE_NO_TITLE)
    setContentView(R.layout.main)

    val surface: SurfaceView? = findViewById(R.id.surfaceView)
    if (surface == null) throw Exception("Could not create SurfaceView. Aborting...")
    surface.visibility = View.VISIBLE
    // Registers this activity for the surfaceCreated/Changed/Destroyed callbacks below.
    surface.holder.addCallback(this)
    surfaceView = surface
    extras = intent?.extras
  }

  /** Replaces the stored extras with those of [intent] and restarts the game with them. */
  override fun onNewIntent(intent: Intent?) {
    super.onNewIntent(intent)
    Log.i(TAG, "MainActivity::onNewIntent()")
    extras = intent?.extras
    startGame()
  }

  // [SurfaceHolder.Callback] overrides.

  /** Creates and starts a [RenderThread] for [holder]; no game is attached to it yet. */
  override fun surfaceCreated(holder: SurfaceHolder) {
    Log.i(TAG, "MainActivity::surfaceCreated()")
    renderThread = RenderThread(surfaceHolder = holder, fps = fps).also { it.start() }
  }

  /** Records the new surface size and (re)starts the game with it. */
  override fun surfaceChanged(holder: SurfaceHolder, format: Int, width: Int, height: Int) {
    Log.i(TAG, "MainActivity::surfaceChanged()")
    this.width = width
    this.height = height
    startGame()
  }

  /** Asks both worker threads to finish and blocks until each has exited. */
  override fun surfaceDestroyed(holder: SurfaceHolder) {
    Log.i(TAG, "MainActivity::surfaceDestroyed()")
    renderThread?.finish()
    renderThread?.join()
    gameLogicThread?.finish()
    gameLogicThread?.join()
  }

  /**
   * Builds a new [GameLogic] from [extras] and swaps it in for the current one.
   *
   * Does nothing (besides logging an error) until the surface size is known.
   */
  private fun startGame() {
    Log.i(TAG, "MainActivity::startGame()")

    if (width <= 0 || height <= 0) {
      Log.e(TAG, "MainActivity::startGame() - Width or height not initialized yet.")
      return
    }

    // Every setting falls back to its default when the extras bundle or the key is missing.
    // NOTE(review): the color strings are passed to Color.parseColor unvalidated; per the Android
    // docs it throws IllegalArgumentException for strings it cannot parse.
    val backgroundColor = Color.parseColor(extras?.getString("backgroundColor") ?: "BLACK")
    val ballColor = Color.parseColor(extras?.getString("ballColor") ?: "WHITE")
    val ballRadius = extras?.getFloat("ballRadius", 10.0f) ?: 10.0f
    val ballSpeed = extras?.getFloat("ballSpeed", 0.2f) ?: 0.2f
    val paddleColor = Color.parseColor(extras?.getString("paddleColor") ?: "WHITE")
    val paddleWidth = extras?.getInt("paddleWidth", 80) ?: 80
    val paddleHeight = extras?.getInt("paddleHeight", 10) ?: 10
    Log.i(TAG, "MainActivity::startGame() - extras bundle: $extras")

    val game =
      GameLogic(
        width = width,
        height = height,
        fps = fps,
        background = Background(color = backgroundColor),
        ball =
          Ball(
            maxX = width,
            maxY = height,
            color = ballColor,
            radius = ballRadius,
            speed = ballSpeed,
          ),
        paddle =
          Paddle(
            color = paddleColor,
            width = paddleWidth,
            height = paddleHeight,
            maxX = width,
            // Half a paddle height above the bottom edge of the surface.
            y = (height - paddleHeight / 2),
          ),
      )

    // Stop the previous game logic thread if it's running.
    // join() blocks until the previous thread's loop has observed finish() and returned.
    gameLogicThread?.finish()
    gameLogicThread?.join()

    // Create and start the new GameLogicThread, passing the game instance.
    gameLogicThread = GameLogicThread(game, TAG).also { it.start() }

    // Pass the same game instance to the render thread.
    renderThread?.game = game

    surfaceView?.setOnTouchListener(
      // Suppress warning for ClickableViewAccessibility since click handling
      // is not within an OnTouchListener.
      @SuppressWarnings("ClickableViewAccessibility")
      View.OnTouchListener { _, motionEvent ->
        game.handleTouch(motionEvent)
        // Returning true marks the touch event as handled.
        true
      }
    )
  }

  companion object {
    // Log tag; also handed to [GameLogicThread] as its logging tag.
    private const val TAG = "AndroidRLTask"
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/RenderThread.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch

import android.graphics.Canvas
import android.view.SurfaceHolder
import java.time.Duration

/** A thread that continuously renders the game logic onto a surface.
*/
class RenderThread(private val surfaceHolder: SurfaceHolder, private val fps: Int = 60) : Thread() {

  /**
   * Whether this thread should continuously run.
   *
   * Written by the thread calling [finish()] and read by this thread's loop, hence volatile so
   * that the update is guaranteed to become visible to the loop.
   */
  @Volatile private var shouldRun: Boolean = true

  /** How long to sleep at each [run()] iteration. */
  private val sleepTime: Duration = Duration.ofMillis((1000.0 / fps).toLong())

  /**
   * The class responsible for issuing rendering commands to the canvas.
   *
   * Assigned from outside this thread while [run()] is looping, hence volatile. While it is null
   * the locked canvas is posted back without anything drawn on it by this class.
   */
  @Volatile var game: GameLogic? = null

  /**
   * Lets the current [run()] iteration complete and then exits this [Thread].
   *
   * Notice that [shouldRun] cannot have a private getter with a public setter (please see
   * https://youtrack.jetbrains.com/issue/KT-3110 for details), hence this public function. Also
   * notice that we cannot call this function [stop()] since it would shadow [Thread.stop()].
   */
  public fun finish() {
    shouldRun = false
  }

  /** Continuously renders the [game] onto [surfaceHolder]. */
  public override fun run() {
    while (shouldRun) {
      if (surfaceHolder.surface?.isValid() ?: false) {
        // lockCanvas() returns null when the surface cannot be edited right now (e.g. it was torn
        // down between the validity check and this call); skip the frame instead of crashing.
        val c: Canvas? = surfaceHolder.lockCanvas()
        if (c != null) {
          try {
            game?.render(c)
          } finally {
            // Always release the canvas, even if rendering throws, so the surface is not left
            // locked.
            surfaceHolder.unlockCanvasAndPost(c)
          }
        }
      }
      Thread.sleep(sleepTime.toMillis())
    }
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/res/layout/main.xml
================================================

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/res/values/strings.xml
================================================
Catch

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/BUILD.bazel
================================================
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Sprites for the app. load("@rules_kotlin//kotlin:android.bzl", "kt_android_library") package( default_visibility = ["//java/com/google/androidenv/catch:catch_packages"], ) licenses(["notice"]) kt_android_library( name = "Background", srcs = ["Background.kt"], deps = [":Sprite"], ) kt_android_library( name = "Ball", srcs = ["Ball.kt"], deps = [ ":LineSegment", ":Point", ":Sprite", ], ) kt_android_library( name = "LineSegment", srcs = ["LineSegment.kt"], deps = [":Point"], ) kt_android_library( name = "Paddle", srcs = ["Paddle.kt"], deps = [ ":Point", ":Sprite", ], ) kt_android_library( name = "Point", srcs = ["Point.kt"], ) kt_android_library( name = "Sprite", srcs = ["Sprite.kt"], ) ================================================ FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Background.kt ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
package com.google.androidenv.catch.sprite

import android.graphics.Canvas
import android.graphics.Color

/** Represents the static background behind all objects. */
open class Background(private val color: Int = Color.BLACK) : Sprite() {

  /** Paints the canvas with the color given in the constructor. */
  override fun draw(c: Canvas) {
    c.drawColor(color)
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Ball.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch.sprite

import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import java.time.Duration
import kotlin.math.ceil
import kotlin.math.sqrt
import kotlin.random.Random

/**
 * Represents a ball that travels down in space with constant speed.
 *
 * @param maxX exclusive upper bound for the random x-coordinate of the center. It is passed
 *   straight to [Random.nextInt], which rejects non-positive bounds.
 * @param maxY the y-coordinate of the bottom edge checked by [isOutOfBounds].
 * @param color the fill color used by [draw].
 * @param radius the radius of the ball in pixels.
 * @param speed the vertical speed in pixels/ms.
 * @param rand the source of randomness for the x-coordinate.
 */
open class Ball(
  private val maxX: Int,
  private val maxY: Int,
  private val color: Int = Color.WHITE,
  private val radius: Float = 10.0f,
  // `speed`'s unit is in pixels/ms.
  private val speed: Float = 1.0f,
  private val rand: Random = Random.Default,
) : Sprite() {

  // `x` and `y` represent the position of the center of this ball.
  //
  // Drawn from [0, maxX) because `nextInt(maxX)` excludes its bound. 0==left.
  private var x: Int = rand.nextInt(maxX)
  // Starts one radius below the top edge (0==top) and grows towards maxY==bottom.
  private var y: Int = ceil(radius).toInt()

  private val paint: Paint =
    Paint(Paint.ANTI_ALIAS_FLAG).apply {
      style = Paint.Style.FILL
      // Inside `apply`, a bare `color` would refer to the Paint's own property.
      color = (this@Ball).color
    }

  /**
   * Returns `true` if this ball intersects the given line [segment].
   *
   * A segment that lies entirely inside the ball counts as intersecting, and so does a degenerate
   * segment (both end points equal) whose single point is inside or on the circumference.
   */
  fun intersects(segment: LineSegment): Boolean {
    /** A vector with two components. */
    data class Vector2D(val u: Int, val v: Int) {
      /**
       * Returns the dot product between two 2D vectors.
       *
       * The result is a [Long] because the caller squares and scales it: with coordinates in the
       * hundreds of pixels those products no longer fit in an `Int` and would silently wrap.
       */
      fun dot(other: Vector2D): Long = u.toLong() * other.u + v.toLong() * other.v
    }

    /** Returns the vector representing [p] minus [q]. */
    fun pointDiff(p: Point, q: Point): Vector2D = Vector2D(p.x - q.x, p.y - q.y)

    val direction = pointDiff(segment.p1, segment.p0) // p0 -> p1.
    val centerToP = pointDiff(segment.p0, Point(x, y)) // Ball center -> p0.

    // The `(centerToP + m * direction)` function models all the points in the line segment where
    // the independent variable `m` is a real number in [0,1]. Putting this function into the
    // formula for the circle (x ^ 2 + y ^ 2 = radius ^ 2) gives a quadratic equation
    // (am^2 + bm + c = 0) where:
    //   [a] = direction · direction
    //   [b] = 2 centerToP · direction
    //   [c] = centerToP · centerToP - radius ^ 2
    val a = direction.dot(direction)
    val b = 2 * centerToP.dot(direction)
    val c = centerToP.dot(centerToP) - radius * radius

    // A zero-length segment makes the equation degenerate ([a] == [b] == 0) and the root formula
    // below would evaluate 0/0. The segment is then the single point p0, which touches the ball
    // exactly when it is inside or on the circumference, i.e. when [c] <= 0.
    if (a == 0L) return c <= 0f

    val delta = b * b - 4 * a * c

    if (delta < 0) return false // No real roots means the (infinite) line does not intersect the ball.

    val d = sqrt(delta)
    val m1 = (-b - d) / (2 * a)
    val m2 = (-b + d) / (2 * a)

    // If a root is in [0,1], the line segment intersects the circumference.
    // If [m1] < 0 and [m2] > 1, the line segment is "within" the circle, meaning the circumference
    // intersects the infinite line but not the line segment. In this case, we consider that it
    // touched the ball.
    return (m1 >= 0 && m1 <= 1) || (m2 >= 0 && m2 <= 1) || (m1 < 0 && m2 > 1)
  }

  /** Places the ball at the top of the screen at a random x-coordinate. */
  fun reset() {
    x = rand.nextInt(maxX)
    y = ceil(radius).toInt()
  }

  /**
   * Moves the ball vertically by the distance covered at [speed] during [timeDelta].
   *
   * The displacement is truncated to whole pixels. A negative [timeDelta] moves the ball up.
   */
  open fun update(timeDelta: Duration) {
    y += (speed * timeDelta.toMillis()).toInt()
  }

  /** Returns whether any part of the ball is below [maxY] or above the top edge (y == 0). */
  fun isOutOfBounds(): Boolean = y + radius > maxY || y - radius < 0

  /** Draws this ball in `c`. */
  override fun draw(c: Canvas) {
    c.drawCircle(x.toFloat(), y.toFloat(), radius, paint)
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/LineSegment.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch.sprite

/** Represents a finite line segment in 2D connected by two points [p0] and [p1]. */
data class LineSegment(val p0: Point, val p1: Point)

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Paddle.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch.sprite

import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.Rect
import kotlin.ranges.coerceIn

/**
 * A rectangular paddle used to hit/catch a falling ball.
 *
 * The paddle is described by its center ([x], [y]) and its size. Only [x] can change after
 * construction, and it is always kept within `[0, maxX]`.
 */
open class Paddle(
  private val color: Int = Color.WHITE,
  // Size of the paddle in pixels.
  private val width: Int = 80,
  private val height: Int = 10,
  // Largest X value the center of the paddle may take.
  private val maxX: Int = 100,
  // Vertical position of the center of the paddle in pixels.
  val y: Int = 100,
) : Sprite() {

  // Half extents are computed once so that [draw()] and the corner accessors do not redo the
  // (integer) divisions on every call.
  private val halfWidth = width / 2
  private val halfHeight = height / 2

  private val paint =
    Paint(Paint.ANTI_ALIAS_FLAG).apply {
      style = Paint.Style.FILL
      // Inside `apply`, a bare `color` would refer to the Paint's own property.
      color = (this@Paddle).color
    }

  // Horizontal position of the center of the paddle. Starts in the middle; every later
  // assignment is clamped to [0, maxX].
  var x: Int = maxX / 2
    set(value) {
      field = value.coerceIn(0, maxX)
    }

  /** Returns the (x,y) coordinates of the top-left corner. */
  fun topLeft(): Point {
    return Point(x - halfWidth, y - halfHeight)
  }

  /** Returns the (x,y) coordinates of the top-right corner. */
  fun topRight(): Point {
    return Point(x + halfWidth, y - halfHeight)
  }

  /** Shifts the paddle horizontally by [deltaX] pixels, clamped by the setter of [x]. */
  fun move(deltaX: Int) {
    x += deltaX
  }

  /** Draws the paddle in `c` as a filled rectangle centered at ([x], [y]). */
  override fun draw(c: Canvas) {
    val bounds = Rect()
    bounds.left = x - halfWidth
    bounds.top = y - halfHeight
    bounds.right = x + halfWidth
    bounds.bottom = y + halfHeight
    c.drawRect(bounds, paint)
  }
}

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Point.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch.sprite

/** Represents a cartesian point in 2D. */
data class Point(val x: Int, val y: Int)

================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Sprite.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite import android.graphics.Canvas /** Represents something that can be drawn on the screen. */ open class Sprite { /** Draws the Sprite in the given canvas. */ open fun draw(c: Canvas) {} } ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/AndroidManifest.xml ================================================ ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/BUILD.bazel ================================================ # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Tests for the Android version of the RL Catch game. load("@rules_kotlin//kotlin:android.bzl", "kt_android_local_test") load("@rules_kotlin//kotlin:core.bzl", "kt_kotlinc_options") kt_kotlinc_options( name = "kt_kotlinc_options", jvm_target = "11", # Need to override default 1.8. 
x_no_param_assertions = True, ) kt_android_local_test( name = "GameLogicTest", srcs = ["GameLogicTest.kt"], kotlinc_opts = ":kt_kotlinc_options", deps = [ "//java/com/google/androidenv/catch:GameLogic", "//java/com/google/androidenv/catch/sprite:Background", "//java/com/google/androidenv/catch/sprite:Ball", "//java/com/google/androidenv/catch/sprite:Paddle", "@maven//:androidx_test_ext_junit", "@maven//:androidx_test_runner", "@maven//:com_google_truth_truth", "@maven//:org_mockito_kotlin_mockito_kotlin", "@maven//:org_robolectric_robolectric", "@robolectric//bazel:android-all", ], ) kt_android_local_test( name = "GameLogicThreadTest", srcs = ["GameLogicThreadTest.kt"], kotlinc_opts = ":kt_kotlinc_options", deps = [ "//java/com/google/androidenv/catch:GameLogic", "//java/com/google/androidenv/catch:GameLogicThread", "@maven//:androidx_test_ext_junit", "@maven//:com_google_truth_truth", "@maven//:org_mockito_kotlin_mockito_kotlin", "@maven//:org_robolectric_robolectric", "@robolectric//bazel:android-all", ], ) kt_android_local_test( name = "MainActivityTest", srcs = [ "MainActivityTest.kt", ], kotlinc_opts = ":kt_kotlinc_options", manifest = "AndroidManifest.xml", deps = [ "//java/com/google/androidenv/catch:MainActivity", "@maven//:androidx_test_ext_junit", "@maven//:junit_junit", "@maven//:org_robolectric_robolectric", "@robolectric//bazel:android-all", ], ) kt_android_local_test( name = "RenderThreadTest", srcs = ["RenderThreadTest.kt"], kotlinc_opts = ":kt_kotlinc_options", deps = [ "//java/com/google/androidenv/catch:GameLogic", "//java/com/google/androidenv/catch:RenderThread", "@maven//:androidx_test_ext_junit", "@maven//:org_mockito_kotlin_mockito_kotlin", "@maven//:org_mockito_mockito_core", "@maven//:org_robolectric_robolectric", "@robolectric//bazel:android-all", ], ) ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/GameLogicTest.kt ================================================ // Copyright 
2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.google.androidenv.catch import android.graphics.Canvas import androidx.test.core.view.MotionEventBuilder import androidx.test.ext.junit.runners.AndroidJUnit4 import com.google.androidenv.catch.sprite.Background import com.google.androidenv.catch.sprite.Ball import com.google.androidenv.catch.sprite.Paddle import com.google.common.truth.Truth.assertThat import java.time.Duration import java.time.Instant import kotlin.random.Random import org.junit.Test import org.junit.runner.RunWith import org.mockito.kotlin.any import org.mockito.kotlin.atLeast import org.mockito.kotlin.atMost import org.mockito.kotlin.doReturn import org.mockito.kotlin.mock import org.mockito.kotlin.spy import org.mockito.kotlin.times import org.mockito.kotlin.verify @RunWith(AndroidJUnit4::class) class GameLogicTest { @Test fun run_ballIsMissed() { // Arrange. val width = 123 val height = 33 val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 37 } val game = GameLogic( rand = mockRandom, width = width, height = height, ball = Ball(maxX = width, maxY = height, radius = 5.0f, rand = mockRandom), paddle = Paddle(maxX = width, y = height, width = 3, height = 2), ) game.reset() game.handleTouch( MotionEventBuilder.newBuilder().setPointer(/* x= */ 12.0f, /* y= */ 31.0f).build() ) // Act. val outcome = game.run() // Ball falls at x==37, ev.x==12 so ball is missed. // Assert. 
assertThat(outcome).isEqualTo(false)
  }

  @Test
  fun run_ballIsCaught() {
    // Arrange.
    val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 53 }
    val game = GameLogic(rand = mockRandom, width = 321, height = 47)
    game.reset()
    game.handleTouch(
      MotionEventBuilder.newBuilder().setPointer(/* x= */ 53.0f, /* y= */ 43.0f).build()
    )

    // Act.
    val outcome = game.run() // Ball falls at x==53, ev.x==53 so ball is caught.

    // Assert.
    assertThat(outcome).isEqualTo(true)
  }

  @Test
  fun run_resetAllowsMultipleGamesToBePlayedWithASingleObjectAndDoesNotHang() {
    // Arrange.
    val mockRandom: Random = mock()
    val game = GameLogic(width = 101, height = 59, rand = mockRandom)

    // Act.
    repeat(17) {
      game.reset()
      val unused = game.run() // Ignore the outcome since we only care about run() terminating.
    }

    // Assert.
    // [rand.nextInt()] should be called once at construction and then 17 times for [reset()].
    verify(mockRandom, times(18)).nextInt(any())
  }

  @Test
  fun run_inASeparateThread() {
    // Arrange.
    val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 23 }
    val game = GameLogic(rand = mockRandom, width = 321, height = 89)
    game.reset()
    game.handleTouch(
      MotionEventBuilder.newBuilder().setPointer(/* x= */ 23.0f, /* y= */ 29.0f).build()
    )
    // Written by the worker thread and read below after `join()`.
    var outcome: Boolean = false

    // The thread deliberately has no `outcome` property of its own: the assignment in `run()`
    // must target the captured local above, which is the variable the assertion checks.
    class MyThread(val g: GameLogic) : Thread() {
      public override fun run() {
        outcome = g.run()
      }
    }
    val someThread = MyThread(game)

    // Act.
    someThread.start() // Ball falls at x==23, ev.x==23 so ball is caught.
    someThread.join()

    // Assert.
    assertThat(outcome).isEqualTo(true)
  }

  @Test
  fun run_fpsLeadstoApproximatelyNumberOfElapsedTimeAndUpdateCalls() {
    // Arrange.
    val width = 123
    val height = 300
    val ball = spy(Ball(maxX = width, maxY = height, speed = 2.0f, radius = 1.0f))
    val game = GameLogic(fps = 100, width = width, height = height, ball = ball)
    game.reset()

    // Act.
    val start = Instant.now()
    val unused = game.run()
    val end = Instant.now()

    // Assert.
val elapsed = Duration.between(start, end) // The ball should take around `height / speed = 150` milliseconds to reach the bottom. Due to // timing non-determinism, we accept values between 100 and 200. assertThat(elapsed.toMillis()).isAtLeast(100L) assertThat(elapsed.toMillis()).isAtMost(200L) // At fps==100, we expect [update()] to be called every `1000 / 100 = 10` milliseconds. We // expect [elapsed] to be around 150ms (checked above) which should be around `150 / 10 = 15` // calls, so to account for timing non-determinism we accept between 5 and 25 calls. verify(ball, atLeast(5)).update(any()) verify(ball, atMost(25)).update(any()) } @Test fun render_drawCanBeCalledMultipleTimesWithinASingleRun() { // Arrange. val width = 321 val height = 89 val mockCanvas: Canvas = mock() val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 23 } val background = spy(Background()) val paddle = spy(Paddle()) val ball = spy(Ball(maxX = width, maxY = height)) val game = GameLogic( rand = mockRandom, width = width, height = height, background = background, ball = ball, paddle = paddle, ) game.reset() game.handleTouch( MotionEventBuilder.newBuilder().setPointer(/* x= */ 23.0f, /* y= */ 29.0f).build() ) class MyThread(val g: GameLogic) : Thread() { public override fun run() { val unused = g.run() } } val someThread = MyThread(game) // Act. someThread.start() repeat(11) { game.render(mockCanvas) } someThread.join() // Assert. verify(background, times(11)).draw(mockCanvas) verify(ball, times(11)).draw(mockCanvas) verify(paddle, times(11)).draw(mockCanvas) } } ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/GameLogicThreadTest.kt ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.google.androidenv.catch import android.util.Log import androidx.test.ext.junit.runners.AndroidJUnit4 import com.google.common.truth.Truth.assertThat import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.mockito.kotlin.atLeastOnce import org.mockito.kotlin.mock import org.mockito.kotlin.verify import org.robolectric.junit.rules.ExpectedLogMessagesRule @RunWith(AndroidJUnit4::class) class GameLogicThreadTest { // Rule to assert log messages, taken as a reference from MainActivityTest.kt @get:Rule val expectedLogMessagesRule = ExpectedLogMessagesRule() private val mockGame: GameLogic = mock() private val testTag = "TestAndroidRLTask" @Test fun run_iteratesGameAndLogs() { // Arrange val gameLogicThread = GameLogicThread(mockGame, testTag) // Act gameLogicThread.start() Thread.sleep(100) // Allow time for the thread to execute at least once. gameLogicThread.finish() gameLogicThread.join() // Wait for the thread to terminate. // Assert // Verify that the game's core methods were called at least once. verify(mockGame, atLeastOnce()).reset() verify(mockGame, atLeastOnce()).run() // Expect the log message from the run() loop. // The mock 'game.run()' returns false by default. expectedLogMessagesRule.expectLogMessage(Log.INFO, testTag, "0 - false") } @Test fun finish_stopsTheThread() { // Arrange val gameLogicThread = GameLogicThread(mockGame, testTag) // Act gameLogicThread.start() // Let it run for a moment before stopping it. 
Thread.sleep(50) gameLogicThread.finish() gameLogicThread.join() // Assert assertThat(gameLogicThread.isAlive).isFalse() } } ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/MainActivityTest.kt ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.google.androidenv.catch import android.content.Intent import android.util.Log import androidx.test.ext.junit.rules.ActivityScenarioRule import androidx.test.ext.junit.runners.AndroidJUnit4 import java.lang.reflect.Method import org.junit.Before import org.junit.Rule import org.junit.Test import org.junit.runner.RunWith import org.robolectric.junit.rules.ExpectedLogMessagesRule @RunWith(AndroidJUnit4::class) class MainActivityTest { @get:Rule(order = 0) val activityScenarioRule = ActivityScenarioRule(MainActivity::class.java) @get:Rule(order = 1) val expectedLogMessagesRule = ExpectedLogMessagesRule() @Before fun setUp() { expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::onCreate()") } @Test fun surfaceChanged_logsStartsGame() { activityScenarioRule.scenario.onActivity { activity -> // Arrange. val surfaceView = activity.findViewById(R.id.surfaceView) val surfaceHolder = surfaceView.holder // Act - Trigger the surfaceChanged callback with positive width and height. activity.surfaceChanged(surfaceHolder, 0, 100, 200) // Assert. 
expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::surfaceChanged()") expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::startGame()") } } @Test fun onNewIntent_logsStartsGame_errorsOnUninitializedWidthOrHeight() { // Arrange. val newIntent = Intent() // Find the onNewIntent method using reflection val onNewIntentMethod: Method = MainActivity::class.java.getDeclaredMethod("onNewIntent", Intent::class.java) // Enable access to protected method onNewIntentMethod.isAccessible = true activityScenarioRule.scenario.onActivity { activity -> // Act - Invoke the onNewIntent method using reflection. onNewIntentMethod.invoke(activity, newIntent) // Assert. expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::onNewIntent()") expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::startGame()") // In this test case where we don't call surfaceChanged(), default width and height // are -1 and should trigger this error to prevent Ball from initializing // with invalid negative values, since nextInt() expects a positive number. expectedLogMessagesRule.expectLogMessage( Log.ERROR, TAG, "MainActivity::startGame() - Width or height not initialized yet.", ) } } companion object { private const val TAG = "AndroidRLTask" } } ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/RenderThreadTest.kt ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and // limitations under the License. package com.google.androidenv.catch import android.graphics.Canvas import android.view.Surface import android.view.SurfaceHolder import androidx.test.ext.junit.runners.AndroidJUnit4 import org.junit.Test import org.junit.runner.RunWith import org.mockito.Mockito.verifyNoInteractions import org.mockito.kotlin.any import org.mockito.kotlin.atLeast import org.mockito.kotlin.atMost import org.mockito.kotlin.doReturn import org.mockito.kotlin.mock import org.mockito.kotlin.verify @RunWith(AndroidJUnit4::class) class RenderThreadTest { @Test fun run_finishBeforeStartResultsInNoRendering() { // Arrange. val surfaceHolder: SurfaceHolder = mock() val renderThread = RenderThread(surfaceHolder = surfaceHolder, fps = 1000) val game: GameLogic = mock() renderThread.game = game // Act. renderThread.finish() renderThread.start() // Assert. verifyNoInteractions(game) verifyNoInteractions(surfaceHolder) } @Test fun run_startResultsInSomeRendering() { // Arrange. val canvas: Canvas = mock() val surface: Surface = mock() { on { isValid() } doReturn true } val surfaceHolder: SurfaceHolder = mock() { on { getSurface() } doReturn surface on { lockCanvas() } doReturn canvas } val renderThread = RenderThread(surfaceHolder = surfaceHolder, fps = 1000) val game: GameLogic = mock() renderThread.game = game // Act. renderThread.start() Thread.sleep(/* millis= */ 500) // Sleep for at least one loop iteration. renderThread.finish() // Assert. verify(surfaceHolder, atLeast(1)).surface verify(surfaceHolder, atLeast(1)).lockCanvas() verify(surfaceHolder, atLeast(1)).unlockCanvasAndPost(any()) verify(game, atLeast(1)).render(canvas) } @Test fun run_finishStopsRendering() { // Arrange. 
val canvas: Canvas = mock() val surface: Surface = mock() { on { isValid() } doReturn true } val surfaceHolder: SurfaceHolder = mock() { on { getSurface() } doReturn surface on { lockCanvas() } doReturn canvas } val renderThread = RenderThread(surfaceHolder = surfaceHolder, fps = 20) val game: GameLogic = mock() renderThread.game = game // Act. renderThread.start() Thread.sleep(/* millis= */ 500) // Sleep for around 10 iterations renderThread.finish() Thread.sleep(/* millis= */ 500) // Sleep some more to ensure nothing runs after. // Assert. verify(surfaceHolder, atLeast(1)).surface verify(surfaceHolder, atLeast(1)).lockCanvas() verify(surfaceHolder, atLeast(1)).unlockCanvasAndPost(any()) // We expect [game.render()] to be executed for around 500 / (1000 / 20 = 50) = 10 times. To // allow for some timing non-determinism we allow it to execute up to 15 times, but not more // than that since [renderThread.finish()] should stop the thread from calling it. verify(game, atLeast(1)).render(canvas) verify(game, atMost(15)).render(canvas) } @Test fun run_expectedFramesPerSecond() { // Arrange. val canvas: Canvas = mock() val surface: Surface = mock() { on { isValid() } doReturn true } val surfaceHolder: SurfaceHolder = mock() { on { getSurface() } doReturn surface on { lockCanvas() } doReturn canvas } val renderThread = RenderThread(surfaceHolder = surfaceHolder, fps = 5) val game: GameLogic = mock() renderThread.game = game // Act. renderThread.start() Thread.sleep(/* millis= */ 2000) // Sleep for around 10 loop iterations. renderThread.finish() // Assert. verify(surfaceHolder, atLeast(1)).surface verify(surfaceHolder, atLeast(1)).lockCanvas() verify(surfaceHolder, atLeast(1)).unlockCanvasAndPost(any()) // We expect [game.render()] to be called around 2000ms * 5fps = 10 times but to account for // timing non-determinism we allow ±4 iterations.
verify(game, atLeast(6)).render(canvas) verify(game, atMost(14)).render(canvas) } } ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/BUILD.bazel ================================================ # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # Unit tests for Sprites in Catch. load("@rules_kotlin//kotlin:android.bzl", "kt_android_local_test") load("@rules_kotlin//kotlin:core.bzl", "kt_kotlinc_options") kt_kotlinc_options( name = "kt_kotlinc_options", jvm_target = "11", # Need to override default 1.8. 
x_no_param_assertions = True, ) kt_android_local_test( name = "BackgroundTest", srcs = ["BackgroundTest.kt"], kotlinc_opts = ":kt_kotlinc_options", deps = [ "//java/com/google/androidenv/catch/sprite:Background", "@maven//:com_google_guava_guava", "@maven//:com_google_testparameterinjector_test_parameter_injector", "@maven//:org_mockito_kotlin_mockito_kotlin", "@maven//:org_yaml_snakeyaml", ], ) kt_android_local_test( name = "BallTest", srcs = ["BallTest.kt"], kotlinc_opts = ":kt_kotlinc_options", tags = ["robolectric"], deps = [ "//java/com/google/androidenv/catch/sprite:Ball", "//java/com/google/androidenv/catch/sprite:LineSegment", "//java/com/google/androidenv/catch/sprite:Point", "@maven//:androidx_test_ext_junit", "@maven//:com_google_guava_guava", "@maven//:com_google_truth_truth", "@maven//:org_mockito_kotlin_mockito_kotlin", "@maven//:org_robolectric_robolectric", "@robolectric//bazel:android-all", ], ) kt_android_local_test( name = "PaddleTest", srcs = ["PaddleTest.kt"], kotlinc_opts = ":kt_kotlinc_options", tags = ["robolectric"], deps = [ "//java/com/google/androidenv/catch/sprite:Paddle", "//java/com/google/androidenv/catch/sprite:Point", "@maven//:androidx_test_ext_junit", "@maven//:com_google_guava_guava", "@maven//:com_google_truth_truth", "@maven//:org_mockito_kotlin_mockito_kotlin", "@maven//:org_robolectric_robolectric", "@robolectric//bazel:android-all", ], ) kt_android_local_test( name = "SpriteTest", srcs = ["SpriteTest.kt"], kotlinc_opts = ":kt_kotlinc_options", deps = [ "//java/com/google/androidenv/catch/sprite:Sprite", "@maven//:org_mockito_kotlin_mockito_kotlin", "@maven//:org_mockito_mockito_core", ], ) ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/BackgroundTest.kt ================================================ // Copyright 2026 DeepMind Technologies Limited. 
// // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package com.google.androidenv.catch.sprite import android.graphics.Canvas import android.graphics.Color import com.google.testing.junit.testparameterinjector.KotlinTestParameters.testValues import com.google.testing.junit.testparameterinjector.TestParameter import com.google.testing.junit.testparameterinjector.TestParameterInjector import org.junit.Test import org.junit.runner.RunWith import org.mockito.kotlin.mock import org.mockito.kotlin.times import org.mockito.kotlin.verify @RunWith(TestParameterInjector::class) class BackgroundTest { @Test fun draw_defaultConstructorIsBlack() { // Arrange. val mockCanvas: Canvas = mock() val background: Background = Background() // Act. background.draw(mockCanvas) // Assert. verify(mockCanvas, times(1)).drawColor(Color.BLACK) } @Test fun draw_customColors( @TestParameter colorInt: Int = testValues(0, 255, 13_579, 2_468, 12_384_173) ) { // Arrange. val mockCanvas: Canvas = mock() val background: Background = Background(color = colorInt) // Act. background.draw(mockCanvas) // Assert. verify(mockCanvas, times(1)).drawColor(colorInt) } } ================================================ FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/BallTest.kt ================================================ // Copyright 2026 DeepMind Technologies Limited. 
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch.sprite

import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.common.truth.Truth.assertThat
import java.time.Duration
import kotlin.random.Random
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Suite
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.eq
import org.mockito.kotlin.mock
import org.mockito.kotlin.verify
import org.robolectric.ParameterizedRobolectricTestRunner

/**
 * Test suite for [Ball].
 *
 * The suite is split into nested classes because each group needs a different JUnit runner: plain
 * tests use [AndroidJUnit4], while the table-driven ones use [ParameterizedRobolectricTestRunner].
 */
@RunWith(Suite::class)
@Suite.SuiteClasses(
  BallTest.UpdateAndResetTests::class,
  BallTest.ColorIntTest::class,
  BallTest.CheckBoundsTest::class,
  BallTest.IntersectsTest::class,
)
class BallTest {

  /** Tests for [Ball.update], [Ball.reset], [Ball.isOutOfBounds] and [Ball.intersects]. */
  @RunWith(AndroidJUnit4::class)
  class UpdateAndResetTests() {
    @Test
    fun isOutOfBounds_initialState_isFalse() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
        assertThat(isOutOfBounds()).isEqualTo(false)
      }
    }

    @Test
    fun isOutOfBounds_initialState_isTrueIfRadiusExceedsMaxY() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      with(Ball(maxX = 100, maxY = 10, radius = 11.0f, speed = 1.0f, rand = mockRandom)) {
        assertThat(isOutOfBounds()).isEqualTo(true)
      }
    }

    @Test
    fun isOutOfBounds_initialState_isFalseIfRadiusExceedsOnlyMaxX() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      with(Ball(maxX = 10, maxY = 100, radius = 11.0f, speed = 1.0f, rand = mockRandom)) {
        assertThat(isOutOfBounds()).isEqualTo(false)
      }
    }

    @Test
    fun update_zeroDurationDoesNotMove_withinBounds() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
        // Act.
        update(Duration.ofMillis(0)) // The ball should not move.
        // Assert.
        assertThat(isOutOfBounds()).isEqualTo(false) // It should still be within the bounds.
      }
    }

    @Test
    fun update_zeroDurationDoesNotMove_outOfBounds() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
        update(Duration.ofMillis(110)) // Place the ball out of bounds.
        assertThat(isOutOfBounds()).isEqualTo(true)
        // Act.
        update(Duration.ofMillis(0)) // The ball should not move.
        // Assert.
        assertThat(isOutOfBounds()).isEqualTo(true) // It should still be out of bounds.
      }
    }

    @Test
    fun update_negativeDurationsMovesUp() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
        update(Duration.ofMillis(30)) // Move the ball down 30 pixels.
        assertThat(isOutOfBounds()).isEqualTo(false)
        // Act.
        update(Duration.ofMillis(-50)) // Move the ball _up_ 50 pixels.
        // Assert.
        assertThat(isOutOfBounds()).isEqualTo(true) // Now it should be out-of-bounds.
      }
    }

    @Test
    fun update_singleThrow() {
      // Ensures that a complete throw of a ball with radius==3.0f and maxY=100 behaves as expected.
      // [isOutOfBounds()] should return [false] for the first (100-3.0f-3.0f)=94 [update()] calls,
      // but [true] afterwards.

      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
        // Act.
        repeat(94) {
          update(Duration.ofMillis(1))
          assertThat(isOutOfBounds()).isEqualTo(false)
        }
        update(Duration.ofMillis(1))
        // Assert.
        assertThat(isOutOfBounds()).isEqualTo(true)
      }
    }

    @Test
    fun intersects_afterUpdate() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      // Act & Assert.
      with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
        assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0)))).isEqualTo(true)
        update(Duration.ofMillis(1))
        assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0)))).isEqualTo(false)
      }
    }

    @Test
    fun reset_intersectsInitialPositionShouldBeTrue() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
        // Act.
        assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0)))).isEqualTo(true)
        update(Duration.ofMillis(1)) // Move the ball 1 pixels down.
        assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0))))
          .isEqualTo(false) // Segment is now outside of the ball.
        reset() // Resetting should move the ball up again.
        // Assert.
        assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0))))
          .isEqualTo(true) // Segment is now inside of the ball.
      }
    }

    @Test
    fun reset_differentInitialXCoordinates() {
      // Arrange.
      // No `rand` is injected here, so the ball uses its default source of randomness.
      val ball: Ball = Ball(maxX = 100, maxY = 100, radius = 3.0f)
      // Act.
      var pointInside: Boolean = false
      var pointOutside: Boolean = false
      // NOTE: this loop only terminates once both outcomes have been observed; if `reset()` never
      // produced a different position the test would spin until the runner times out.
      while (!pointInside || !pointOutside) {
        if (ball.intersects(LineSegment(Point(45, 0), Point(55, 0)))) {
          pointInside = true
        } else {
          pointOutside = true
        }
        ball.reset() // Sample a new initial position for the ball.
      }
      // Assert.
      // Eventually after many initial positions the ball should satisfy both conditions.
      assertThat(pointInside).isEqualTo(true)
      assertThat(pointOutside).isEqualTo(true)
    }
  }

  /** Checks that [Ball.draw] forwards the configured colour to the canvas, for several colours. */
  @RunWith(ParameterizedRobolectricTestRunner::class)
  class ColorIntTest(private val c: Int) {
    @Test
    fun draw_customBallColors() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 37 }
      val mockCanvas: Canvas = mock()
      // The explicit type argument is required: there is no expected type to infer it from.
      val paintCaptor = argumentCaptor<Paint>()
      val ball: Ball = Ball(maxX = 50, maxY = 80, radius = 1.23f, color = c, rand = mockRandom)
      // Act.
      ball.draw(mockCanvas)
      // Assert.
      verify(mockCanvas).drawCircle(eq(37.0f), eq(2.0f), eq(1.23f), paintCaptor.capture())
      with(paintCaptor.lastValue) {
        assertThat(color).isEqualTo(c)
        assertThat(style).isEqualTo(Paint.Style.FILL)
      }
    }

    companion object {
      @JvmStatic
      @ParameterizedRobolectricTestRunner.Parameters(name = "color = {0}")
      fun parameters() = listOf(0, 255, -1, 13579, 2468, 12384173, Color.WHITE, Color.BLUE)
    }
  }

  /**
   * Probes [Ball.intersects] with a short horizontal segment centred on (`x`, `y`) for a ball whose
   * mocked random X is `maxX / 2`.
   */
  @RunWith(ParameterizedRobolectricTestRunner::class)
  class CheckBoundsTest(private val p: ParamPack) {
    @Test
    fun intersects_checkBounds() {
      // Arrange.
      val mockRandom: Random =
        mock() { on { nextInt(any()) } doReturn p.maxX / 2 } // Horizontal middle.
      // Act.
      val ball: Ball = Ball(maxX = p.maxX, maxY = p.maxY, radius = p.radius, rand = mockRandom)
      // Assert.
      assertThat(ball.intersects(LineSegment(Point(p.x - 1, p.y), Point(p.x + 1, p.y))))
        .isEqualTo(p.expected)
    }

    data class ParamPack(
      val maxX: Int,
      val maxY: Int,
      val radius: Float,
      val x: Int,
      val y: Int,
      val expected: Boolean,
    )

    companion object {
      @JvmStatic
      @ParameterizedRobolectricTestRunner.Parameters(name = "param = {0}")
      fun parameters() =
        listOf(
          ParamPack(
            maxX = 100,
            maxY = 100,
            radius = 10.0f,
            x = 0,
            y = 0,
            expected = false,
          ), // Ball to the right of `x`.
          ParamPack(
            maxX = 100,
            maxY = 100,
            radius = 10.0f,
            x = 39,
            y = 0,
            expected = false,
          ), // Ball to the right of `x`.
          ParamPack(
            maxX = 100,
            maxY = 100,
            radius = 10.0f,
            x = 40,
            y = 10,
            expected = true,
          ), // Ball contains `x`.
          ParamPack(
            maxX = 100,
            maxY = 100,
            radius = 10.0f,
            x = 50,
            y = 0,
            expected = true,
          ), // Ball contains `x`.
          ParamPack(
            maxX = 100,
            maxY = 100,
            radius = 10.0f,
            x = 60,
            y = 10,
            expected = true,
          ), // Ball contains `x`.
          ParamPack(
            maxX = 100,
            maxY = 100,
            radius = 10.0f,
            x = 61,
            y = 0,
            expected = false,
          ), // Ball to the left of `x`.
          ParamPack(
            maxX = 100,
            maxY = 100,
            radius = 10.0f,
            x = 100,
            y = 0,
            expected = false,
          ), // Ball to the left of `x`.
          ParamPack(
            maxX = 100,
            maxY = 100,
            radius = 10.0f,
            x = 50,
            y = 21,
            expected = false,
          ), // Ball above `y`.
        )
    }
  }

  /** Table-driven checks of [Ball.intersects] against arbitrary segments for one fixed ball. */
  @RunWith(ParameterizedRobolectricTestRunner::class)
  class IntersectsTest(private val p: ParamPack) {
    @Test
    fun intersects_ballAtx50y10radius10() {
      // Arrange.
      val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
      // Act.
      val ball: Ball = Ball(maxX = 100, maxY = 100, radius = 10.0f, rand = mockRandom)
      // Assert.
      assertThat(ball.intersects(p.segment)).isEqualTo(p.expected)
    }

    data class ParamPack(val segment: LineSegment, val expected: Boolean)

    companion object {
      @JvmStatic
      @ParameterizedRobolectricTestRunner.Parameters(name = "param = {0}")
      fun parameters() =
        listOf(
          ParamPack(
            segment = LineSegment(Point(50, 10), Point(80, 40)),
            expected = true,
          ), // Segment that starts at the center of the ball so it should always intersect.
          ParamPack(
            segment = LineSegment(Point(49, 0), Point(51, 0)),
            expected = true,
          ), // Tangential segment that touches the bottom of the ball.
          ParamPack(
            segment = LineSegment(Point(40, 5), Point(65, 7)),
            expected = true,
          ), // Segment longer than diameter, touching the circumference twice.
          ParamPack(
            segment = LineSegment(Point(42, 2), Point(58, 1)),
            expected = true,
          ), // Segment shorter than diameter, touching the circumference twice.
          ParamPack(
            segment = LineSegment(Point(44, 4), Point(54, 3)),
            expected = true,
          ), // Segment shorter than diameter, fully inside the circle, not touching the
          // circumference.
          ParamPack(
            segment = LineSegment(Point(35, 4), Point(54, 3)),
            expected = true,
          ), // Segment that touches the circumference once "from the left".
          ParamPack(
            segment = LineSegment(Point(54, 7), Point(67, 13)),
            expected = true,
          ), // Segment that touches the circumference once "from the right".
          ParamPack(
            segment = LineSegment(Point(36, 7), Point(45, 0)),
            expected = false,
          ), // Segment "to the left of the ball". No intersection.
          ParamPack(
            segment = LineSegment(Point(58, -3), Point(60, 3)),
            expected = false,
          ), // Segment "to the right of the ball". No intersection.
        )
    }
  }
}

================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/PaddleTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package com.google.androidenv.catch.sprite

import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.Rect
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.common.truth.Truth.assertThat
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Suite
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.mock
import org.mockito.kotlin.verify
import org.robolectric.ParameterizedRobolectricTestRunner

/**
 * Test suite for [Paddle].
 *
 * Nested classes are used so that plain tests ([AndroidJUnit4]) and table-driven tests
 * ([ParameterizedRobolectricTestRunner]) can each run under their own runner.
 */
@RunWith(Suite::class)
@Suite.SuiteClasses(
  PaddleTest.ConstructorTests::class,
  PaddleTest.MoveTests::class,
  PaddleTest.XSetterTests::class,
  PaddleTest.DrawTests::class,
)
class PaddleTest {

  /** Checks the state of a freshly constructed [Paddle]. */
  @RunWith(AndroidJUnit4::class)
  class ConstructorTests() {
    @Test
    fun x_initialValueShouldBeAtCenter() {
      with(Paddle(maxX = 30)) { assertThat(x).isEqualTo(15) }
      // An odd `maxX` yields the same centre as the even value just below it.
      with(Paddle(maxX = 31)) { assertThat(x).isEqualTo(15) }
    }

    @Test
    fun topLeft_correspondsToGivenValues() {
      with(Paddle(width = 10, height = 6, maxX = 40, y = 33)) {
        assertThat(topLeft()).isEqualTo(Point(x = 15, y = 30))
      }
    }

    @Test
    fun topRight_correspondsToGivenValues() {
      with(Paddle(width = 10, height = 6, maxX = 40, y = 33)) {
        assertThat(topRight()).isEqualTo(Point(x = 25, y = 30))
      }
    }
  }

  /** Table-driven checks of [Paddle.move], including clamping at both walls. */
  @RunWith(ParameterizedRobolectricTestRunner::class)
  class MoveTests(private val p: ParamPack) {
    @Test
    fun move_expectedDestination() {
      // Arrange.
      with(Paddle(maxX = 50)) {
        // Act.
        move(deltaX = p.displacement)
        // Assert.
        assertThat(x).isEqualTo(p.expectedX)
      }
    }

    data class ParamPack(val displacement: Int, val expectedX: Int)

    companion object {
      @JvmStatic
      @ParameterizedRobolectricTestRunner.Parameters(name = "param = {0}")
      fun parameters() =
        listOf(
          // Initial position is x==25.
          ParamPack(displacement = 10, expectedX = 35),
          ParamPack(displacement = -10, expectedX = 15),
          ParamPack(displacement = 0, expectedX = 25),
          // Going beyond the left and right walls should clamp the values to 0 and 50.
          ParamPack(displacement = -26, expectedX = 0),
          ParamPack(displacement = 26, expectedX = 50),
        )
    }
  }

  /** Table-driven checks of assigning [Paddle.x] directly, including clamping at both walls. */
  @RunWith(ParameterizedRobolectricTestRunner::class)
  class XSetterTests(private val p: ParamPack) {
    @Test
    fun xSetter_expectedDestination() {
      // Arrange.
      with(Paddle(maxX = 50)) {
        // Act.
        x = p.target
        // Assert.
        assertThat(x).isEqualTo(p.expectedX)
      }
    }

    data class ParamPack(val target: Int, val expectedX: Int)

    companion object {
      @JvmStatic
      @ParameterizedRobolectricTestRunner.Parameters(name = "param = {0}")
      fun parameters() =
        listOf(
          // Initial position is x==25.
          ParamPack(target = 0, expectedX = 0),
          ParamPack(target = 15, expectedX = 15),
          ParamPack(target = 25, expectedX = 25),
          ParamPack(target = 35, expectedX = 35),
          ParamPack(target = 50, expectedX = 50),
          // Going beyond the left and right walls should clamp the values to 0 and 50.
          ParamPack(target = -1, expectedX = 0),
          ParamPack(target = 51, expectedX = 50),
        )
    }
  }

  /** Checks the rectangle and paint that [Paddle.draw] hands to the canvas. */
  @RunWith(AndroidJUnit4::class)
  class DrawTests() {
    @Test
    fun draw_initialPosition() {
      // Arrange.
      val mockCanvas: Canvas = mock()
      // Explicit type arguments are required: there is no expected type to infer them from.
      val rectCaptor = argumentCaptor<Rect>()
      val paintCaptor = argumentCaptor<Paint>()
      with(Paddle(color = Color.RED, width = 100, height = 20, maxX = 300, y = 400)) {
        // Act.
        draw(mockCanvas)
        // Assert.
        assertThat(x).isEqualTo(150)
        verify(mockCanvas).drawRect(rectCaptor.capture(), paintCaptor.capture())
        with(rectCaptor.lastValue) {
          assertThat(bottom).isEqualTo(400 + 10)
          assertThat(top).isEqualTo(400 - 10)
          assertThat(left).isEqualTo(150 - 50)
          assertThat(right).isEqualTo(150 + 50)
        }
        // The paint was captured but previously never checked; assert on it as draw_afterMove does.
        with(paintCaptor.lastValue) {
          assertThat(color).isEqualTo(Color.RED)
          assertThat(style).isEqualTo(Paint.Style.FILL)
        }
      }
    }

    @Test
    fun draw_afterMove() {
      // Arrange.
      val mockCanvas: Canvas = mock()
      val rectCaptor = argumentCaptor<Rect>()
      val paintCaptor = argumentCaptor<Paint>()
      with(Paddle(color = Color.RED, width = 100, height = 20, maxX = 300, y = 400)) {
        // Act.
        move(50)
        draw(mockCanvas)
        // Assert.
        assertThat(x).isEqualTo(200)
        verify(mockCanvas).drawRect(rectCaptor.capture(), paintCaptor.capture())
        with(rectCaptor.lastValue) {
          assertThat(bottom).isEqualTo(400 + 10)
          assertThat(top).isEqualTo(400 - 10)
          assertThat(left).isEqualTo(200 - 50)
          assertThat(right).isEqualTo(200 + 50)
        }
        with(paintCaptor.lastValue) {
          assertThat(color).isEqualTo(Color.RED)
          assertThat(style).isEqualTo(Paint.Style.FILL)
        }
      }
    }
  }
}

================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/SpriteTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite

import android.graphics.Canvas
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.mockito.Mockito.verifyNoInteractions
import org.mockito.kotlin.mock
import org.mockito.kotlin.times
import org.mockito.kotlin.verify

/** Minimal checks that the [Sprite] API has the expected shape. */
@RunWith(JUnit4::class)
class SpriteTest {

  /** Drawing a plain [Sprite] must leave the canvas completely untouched. */
  @Test
  fun defaultImplementationDoesNothing() {
    // Arrange.
    val canvas = mock<Canvas>()
    val plainSprite = Sprite()

    // Act.
    plainSprite.draw(canvas)

    // Assert.
    verifyNoInteractions(canvas) // No methods should be called on the canvas.
  }

  /** Calling `draw` on a mocked [Sprite] records exactly one call with the same canvas. */
  @Test
  fun draw_argumentsAreForwarded() {
    // Arrange.
    val spriteDouble = mock<Sprite>()
    val canvas = mock<Canvas>()

    // Act.
    spriteDouble.draw(canvas)

    // Assert.
    verify(spriteDouble, times(1)).draw(canvas)
  }
}

================================================
FILE: android_env/components/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


================================================
FILE: android_env/components/action_fns.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Functions to convert actions between different components' formats.""" from absl import logging from android_env.components import action_type as action_type_lib from android_env.components import errors from android_env.components import pixel_fns from android_env.components.simulators import base_simulator import numpy as np def send_action_to_simulator( action: dict[str, np.ndarray], simulator: base_simulator.BaseSimulator, screen_width: int, screen_height: int, num_fingers: int, ) -> bool: """Sends the selected action to the given simulator. The simulator will interpret the action according to `action["action_type"]`. The effect this action triggers in the Android OS will be determined by the currently running application. Args: action: action which will get interpreted as a touchscreen event. simulator: The simulator that will receive the action. screen_width: The width of the touchscreen in pixels. screen_height: The height of the touchscreen in pixels. num_fingers: The number of fingers used in this simulator. """ try: match action['action_type']: # If the action is a TOUCH or LIFT, send a touch event to the simulator. case action_type_lib.ActionType.TOUCH | action_type_lib.ActionType.LIFT: prepared_action = _prepare_touch_action( action, screen_width, screen_height, num_fingers ) simulator.send_touch(prepared_action) # If the action is a key event, send a key event to the simulator. 
case action_type_lib.ActionType.KEYDOWN: simulator.send_key(action['keycode'].item(0), event_type='keydown') case action_type_lib.ActionType.KEYUP: simulator.send_key(action['keycode'].item(0), event_type='keyup') case action_type_lib.ActionType.KEYPRESS: simulator.send_key(action['keycode'].item(0), event_type='keypress') except errors.SendActionError: logging.exception('Unable to execute action: %r', action) return False return True def _prepare_touch_action( action: dict[str, np.ndarray], screen_width: int, screen_height: int, num_fingers: int, ) -> list[tuple[int, int, bool, int]]: """Turns an AndroidEnv action into values that the simulator can interpret. Converts float-valued 'touch_position' to integer coordinates corresponding to specific pixels, and 'action_type' to booleans indicating whether the screen is touched at said location or not. The result of this function can be sent directly to the underlying simulator (e.g. the Android Emulator, virtual machine, or a phone). Args: action: An action containing 'action_type' and 'touch_position'. Returns: A tuple with the format (x: int, y: int, down/up: bool, finger_index: int). 
""" touch_events = [] for i, finger_action in enumerate(_split_touch_action(action, num_fingers)): is_touch = finger_action['action_type'] == action_type_lib.ActionType.TOUCH touch_position = finger_action['touch_position'] touch_pixels = pixel_fns.touch_position_to_pixel_position( touch_position, width_height=(screen_width, screen_height) ) touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i)) return touch_events def _split_touch_action( action: dict[str, np.ndarray], num_fingers: int ) -> list[dict[str, np.ndarray]]: """Splits a multitouch action into a list of single-touch actions.""" single_touch_actions = [{ 'action_type': action['action_type'], 'touch_position': action['touch_position'], }] for i in range(2, num_fingers + 1): single_touch_actions.append({ 'action_type': action[f'action_type_{i}'], 'touch_position': action[f'touch_position_{i}'], }) return single_touch_actions def lift_all_fingers_action(num_fingers: int) -> dict[str, np.ndarray]: """A lift action with each finger.""" # There's always at least one finger. lift_action = { 'action_type': np.array(action_type_lib.ActionType.LIFT), 'touch_position': np.array([0, 0]), } # Subsequent fingers have separate dict entries. for i in range(2, num_fingers + 1): lift_action |= { f'action_type_{i}': np.array(action_type_lib.ActionType.LIFT), f'touch_position_{i}': np.array([0, 0]), } return lift_action ================================================ FILE: android_env/components/action_fns_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from unittest import mock from absl.testing import absltest from absl.testing import parameterized from android_env.components import action_fns from android_env.components import action_type as action_type_lib from android_env.components import errors from android_env.components.simulators import base_simulator import numpy as np class ActionFnsTest(parameterized.TestCase): def test_send_action_to_simulator_missing_action_type(self): """A `KeyError` should be raised if the action is missing "action_type".""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) action = {'some_key': np.array(123, np.int32)} # Act & Assert. self.assertRaises( KeyError, action_fns.send_action_to_simulator, action, simulator, 800, 600, 1, ) def test_send_action_to_simulator_sendactionerror(self): """Returns `False` if the simulator raises a SendActionError.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) simulator.send_touch.side_effect = errors.SendActionError('oops!') action = { 'action_type': action_type_lib.ActionType.TOUCH, 'touch_position': np.array([0.3, 0.5], np.float32), } # Act. output = action_fns.send_action_to_simulator( action, simulator, 800, 600, 1, ) # Assert. self.assertFalse(output) simulator.send_touch.assert_called_once() def test_send_action_to_simulator_touch_success_one_finger(self): """Returns `True` with a proper 1-finger touch action.""" # Arrange. 
simulator = mock.create_autospec(base_simulator.BaseSimulator) action = { 'action_type': action_type_lib.ActionType.TOUCH, 'touch_position': np.array([0.2, 0.5], np.float32), } # Act. output = action_fns.send_action_to_simulator( action, simulator, 800, 600, 1, ) # Assert. self.assertTrue(output) simulator.send_touch.assert_called_once_with( [(np.int32(160), np.int32(300), True, 0)] ) def test_send_action_to_simulator_touch_success_multiple_finger(self): """Returns `True` with a proper 3-finger touch action.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) action = { 'action_type': action_type_lib.ActionType.TOUCH, 'touch_position': np.array([0.2, 0.5], np.float32), 'action_type_2': action_type_lib.ActionType.LIFT, 'touch_position_2': np.array([0.1, 0.2], np.float32), 'action_type_3': action_type_lib.ActionType.TOUCH, 'touch_position_3': np.array([0.5, 0.2], np.float32), } # Act. output = action_fns.send_action_to_simulator( action, simulator, 800, 600, 3, ) # Assert. self.assertTrue(output) simulator.send_touch.assert_called_once_with([ (np.int32(160), np.int32(300), True, 0), (np.int32(80), np.int32(120), False, 1), (np.int32(400), np.int32(120), True, 2), ]) def test_send_action_to_simulator_keydown_success(self): """Returns `True` with a proper keydown action.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) action = { 'action_type': action_type_lib.ActionType.KEYDOWN, 'keycode': np.array([21], np.int32), } # Act. output = action_fns.send_action_to_simulator( action, simulator, 800, 600, 1, ) # Assert. self.assertTrue(output) simulator.send_key.assert_called_once_with(21, event_type='keydown') def test_send_action_to_simulator_keyup_success(self): """Returns `True` with a proper keyup action.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) action = { 'action_type': action_type_lib.ActionType.KEYUP, 'keycode': np.array([42], np.int32), } # Act. 
output = action_fns.send_action_to_simulator( action, simulator, 800, 600, 1, ) # Assert. self.assertTrue(output) simulator.send_key.assert_called_once_with(42, event_type='keyup') def test_send_action_to_simulator_keypress_success(self): """Returns `True` with a proper keypress action.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) action = { 'action_type': action_type_lib.ActionType.KEYPRESS, 'keycode': np.array([96], np.int32), } # Act. output = action_fns.send_action_to_simulator( action, simulator, 800, 600, 1, ) # Assert. self.assertTrue(output) simulator.send_key.assert_called_once_with(96, event_type='keypress') @parameterized.named_parameters( ( 'one_finger', 1, { 'action_type': np.array(action_type_lib.ActionType.LIFT), 'touch_position': np.array([0, 0]), }, ), ( 'two_fingers', 2, { 'action_type': np.array(action_type_lib.ActionType.LIFT), 'touch_position': np.array([0, 0]), 'action_type_2': np.array(action_type_lib.ActionType.LIFT), 'touch_position_2': np.array([0, 0]), }, ), ) def test_lift_all_fingers_action( self, num_fingers: int, expected_action: dict[str, np.ndarray] ): """Returns the expected action.""" output = action_fns.lift_all_fingers_action(num_fingers) for k, v in expected_action.items(): np.testing.assert_array_equal(v, output[k]) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/components/action_type.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """The different kinds of actions that AndroidEnv supports. The native action space of AndroidEnv consists of a tuple consisting of - A position (x, y) ∈ [0, 1] x [0, 1], determining the location of the action on the screen, and - A discrete value, indicating the action type, which is in this file. See https://arxiv.org/abs/2105.13231, section 2.2 for details. """ import enum @enum.unique class ActionType(enum.IntEnum): """Integer values to describe each supported action in AndroidEnv. Note for KEY* types: - Only meaningful if connected to a _physical_ keyboard, _not_ virtual keyboard. - Added afterwards so they did not appear in the paper. Attributes: TOUCH: Touching the screen at a location. LIFE: Lifting the (imaginary) pointer from the screen at a location. REPEAT: Repeating the last chosen action. KEYDOWN: Sending a key down event. KEYUP: Sending a key up event. KEYPRESS: Sending a key down event, immediately followed by a key up event. """ TOUCH = 0 LIFT = 1 REPEAT = 2 KEYDOWN = 3 KEYUP = 4 KEYPRESS = 5 ================================================ FILE: android_env/components/adb_call_parser.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Processes adb_pb2.AdbRequest commands.""" import os import re import subprocess import sys import tempfile from absl import logging from android_env.components import adb_controller as adb_control from android_env.proto import adb_pb2 # A mapping from a Button enum to keycode strings. # # Please see https://developer.android.com/reference/android/view/KeyEvent # # We currently only accept the following entries: _BUTTON_TO_KEYCODE = { adb_pb2.AdbRequest.PressButton.Button.HOME: 'KEYCODE_HOME', adb_pb2.AdbRequest.PressButton.Button.BACK: 'KEYCODE_BACK', adb_pb2.AdbRequest.PressButton.Button.ENTER: 'KEYCODE_ENTER', } class AdbCallParser: """Parses AdbRequest messages and executes corresponding adb commands.""" def __init__(self, adb_controller: adb_control.AdbController): self._adb_controller = adb_controller self._handlers = { 'install_apk': self._install_apk, 'start_activity': self._start_activity, 'force_stop': self._force_stop, 'tap': self._tap, 'press_button': self._press_button, 'start_screen_pinning': self._start_screen_pinning, 'send_broadcast': self._send_broadcast, 'uninstall_package': self._handle_uninstall_package, 'get_current_activity': self._get_current_activity, 'get_orientation': self._get_orientation, 'push': self._push, 'pull': self._pull, 'input_text': self._input_text, 'settings': self._handle_settings, 'generic': self._handle_generic, 'package_manager': self._handle_package_manager, 'dumpsys': self._handle_dumpsys, } def _execute_command( self, command_args: list[str], timeout: float | None ) -> tuple[adb_pb2.AdbResponse, bytes]: """Executes the command, catches errors and populates the response status. Args: command_args: a list of arguments for the ADB request. timeout: Timeout in seconds. Returns: A tuple of the AdbResponse with the status populated, and the output bytes from the command. 
""" response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) command_output = b'' try: command_output = self._adb_controller.execute_command( command_args, timeout=timeout) except subprocess.CalledProcessError as adb_error: if adb_error.stdout is not None: response.status = adb_pb2.AdbResponse.Status.ADB_ERROR response.error_message = adb_error.stdout except subprocess.TimeoutExpired: response.status = adb_pb2.AdbResponse.Status.TIMEOUT response.error_message = 'Timeout' return response, command_output def parse(self, request: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse: """Executes `request` and returns an appropriate response.""" response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) command_type = request.WhichOneof('command') logging.debug('AdbRequest command type: %s', command_type) if command_type is None: response.status = adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND response.error_message = 'AdbRequest.command is None.' return response if request.timeout_sec < 0: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ('AdbRequest.timeout_sec cannot be negative. ' f'Got: {request.timeout_sec}') return response timeout: float | None = request.timeout_sec or None return self._handlers[command_type](request, timeout) def _force_stop( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Stops an application. Args: request: The external request containing the package to force stop. timeout: Optional time limit in seconds. Returns: An AdbResponse. """ force_stop = request.force_stop response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) if not force_stop.package_name: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = '`force_stop.package_name` cannot be empty.' 
return response response, _ = self._execute_command( ['shell', 'am', 'force-stop', force_stop.package_name], timeout) return response def _fetch_current_task_id( self, full_activity_name: str, timeout: float | None = None ) -> int: """Returns the task ID of the given `full_activity_name`. Args: full_activity_name: The full name of the activity whose corresponding task id we are looking for. timeout: Optional time limit in seconds. Returns: task_id: An integer corresponding to the specified activity. """ stack = self._adb_controller.execute_command( ['shell', 'am', 'stack', 'list'], timeout=timeout) lines = stack.decode('utf-8').splitlines() regex = re.compile( r'^\ *taskId=(?P[0-9]*): (?P[^\s]*) .*visible=true' r'.*topActivity=ComponentInfo{(?P[^\s]*)}$') for line in lines: match = regex.search(line) if match is None: continue current_task_id_str = match.group('id') base_activity = match.group('base_activity') top_activity = match.group('top_activity') # If neither of the matched activities equals the activity we are # looking for, we discard their task id and continue the search. if full_activity_name not in {base_activity, top_activity}: logging.info('Full activity %s was not found in current line %s', full_activity_name, line) continue # Otherwise return the integer task id. try: return int(current_task_id_str) except ValueError: logging.info('Failed to parse task ID [%r].', current_task_id_str) # At this point if we could not find a task ID, there's nothing we can do. logging.error('Could not find current activity in stack list: %r', lines) return -1 def _start_screen_pinning( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Pins an application. Args: request: The request containing the activity to pin. timeout: Optional time limit in seconds. Returns: An AdbResponse. 
""" full_activity = request.start_screen_pinning.full_activity response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) if not full_activity: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( '`start_screen_pinning.full_activity` cannot be empty.') return response current_task_id = self._fetch_current_task_id(full_activity, timeout) if current_task_id == -1: response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR response.error_message = ('Could not find task ID for activity ' f'[{full_activity}]') return response response, _ = self._execute_command( ['shell', 'am', 'task', 'lock', str(current_task_id)], timeout=timeout) return response def _send_broadcast( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Sends a broadcast. Args: request: The request with the information for the broadcast event. timeout: Optional time limit in seconds. Returns: An AdbResponse. """ send_broadcast = request.send_broadcast response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) if not send_broadcast.action: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ('`send_broadcast.{action}` cannot be empty.') return response if send_broadcast.component: component_args = ['-n', send_broadcast.component] else: component_args = [] response, _ = self._execute_command( ['shell', 'am', 'broadcast', '-a', send_broadcast.action] + component_args, timeout=timeout, ) return response def _install_apk( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Installs an app given its local path in the filesystem. Args: request: The external request with an install_apk field. Contains information for the .apk installation. timeout: Optional time limit in seconds. Returns: An AdbResponse. 
""" install_apk = request.install_apk response = adb_pb2.AdbResponse() location_type = install_apk.WhichOneof('location') logging.info('location_type: %s', location_type) match location_type: case 'filesystem': fpath = install_apk.filesystem.path if not os.path.exists(fpath): response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR response.error_message = f'Could not find local_apk_path: {fpath}' return response response, _ = self._execute_command( ['install', '-r', '-t', '-g', fpath], timeout=timeout ) case 'blob': # `delete_on_close` was only added in Python 3.12 so we add a switch # here to still support previous Python versions. if sys.version_info >= (3, 12): kwargs = {'suffix': '.apk', 'delete_on_close': False} else: kwargs = {'suffix': '.apk'} with tempfile.NamedTemporaryFile(**kwargs) as f: fpath = f.name f.write(install_apk.blob.contents) response, _ = self._execute_command( ['install', '-r', '-t', '-g', fpath], timeout=timeout ) case _: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( f'Unsupported `install_apk.location` type: {location_type}' ) return response return response def _start_activity( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Starts a given activity. Options for `start_activity`: `am start` command options: -D: enable debugging -W: wait for launch to complete --start-profiler : start profiler and send results to -P : like above, but profiling stops when app goes idle -R: repeat the activity launch times. Prior to each repeat, the top activity will be finished. -S: force stop the target app before starting the activity --opengl-trace: enable tracing of OpenGL functions Args: request: The request with information on what activity to start. timeout: Optional time limit in seconds. Returns: An AdbResponse. If successful, StartActivityResponse will contain the activity name and adb command output. 
""" activity = request.start_activity.full_activity if not activity: return adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, error_message='`start_activity.full_activity` cannot be empty.') force_stop = '-S' if request.start_activity.force_stop else '' response, command_output = self._execute_command( ['shell', 'am', 'start', force_stop, '-W', '-n', activity] + list(request.start_activity.extra_args or []), timeout=timeout) # Check command output for potential errors. expected_error = re.compile(r""".*Error.*""", re.VERBOSE) if expected_error.match(str(command_output)): return adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.INTERNAL_ERROR, error_message=f'start_activity failed with error: {command_output}') response.start_activity.full_activity = activity response.start_activity.output = command_output return response def _press_button( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Presses a keyboard key. Args: request: The request with information on what button to press. timeout: Optional time limit in seconds. Returns: An AdbResponse. """ button = request.press_button.button if button not in _BUTTON_TO_KEYCODE: return adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, error_message=('PressButton.button must be one of ' f'[{_BUTTON_TO_KEYCODE.keys()}]. ' f'Got: {button}. Please see `adb.proto`.')) keycode = _BUTTON_TO_KEYCODE[button] response, command_output = self._execute_command( ['shell', 'input', 'keyevent', keycode], timeout=timeout) response.press_button.output = command_output return response def _handle_uninstall_package( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Handles UninstallPackage messages. Args: request: The specification of what to uninstall. timeout: Optional time limit in seconds. 
Returns: An AdbResponse """ package_name = request.uninstall_package.package_name response = adb_pb2.AdbResponse() # Every UninstallPackage should have a package_name. if not package_name: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( '`uninstall_package.package_name` cannot be empty.') return response # Get list of installed packages and issue an uninstall only if it's # already installed. package_response = self._handle_package_manager( adb_pb2.AdbRequest( package_manager=adb_pb2.AdbRequest.PackageManagerRequest( list=adb_pb2.AdbRequest.PackageManagerRequest.List( packages=adb_pb2.AdbRequest.PackageManagerRequest.List .Packages())))) if package_name in package_response.package_manager.list.items: response, _ = self._execute_command(['uninstall', package_name], timeout) else: msg = (f'Cannot uninstall {package_name} since it is not installed.') logging.warning(msg) response.error_message = msg return response def _get_current_activity( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Fetches current activity. Args: request: The request with the `.get_current_activity` field set. This is unused, but it's in the signature so that all calls are uniform. timeout: Optional time limit in seconds. Returns: AdbResponse containing the current activity. """ del request # Unused. response, visible_task = self._execute_command( ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'], timeout=timeout) if response.status != adb_pb2.AdbResponse.Status.OK: return response if not visible_task: _, am_stack_list = self._execute_command(['shell', 'am', 'stack', 'list'], timeout=timeout) response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR response.error_message = ('Empty visible_task. 
`am stack list`: ' f'{am_stack_list}') return response visible_task = visible_task.decode('utf-8') if sys.platform == 'win32': visible_task_list = re.findall( r'visible=true topActivity=ComponentInfo{(.+?)}', visible_task) if not visible_task_list: visible_task = '' else: visible_task = 'ComponentInfo{' + visible_task_list[0] + '}' p = re.compile(r'.*\{(.*)\}') matches = p.search(visible_task) if matches is None: _, am_stack_list = self._execute_command(['shell', 'am', 'stack', 'list'], timeout=timeout) response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR response.error_message = ( 'Could not extract current activity. Will return nothing. ' f'`am stack list`: {am_stack_list}') return response response.get_current_activity.full_activity = matches.group(1) return response def _get_orientation( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Fetches current device orientation. Args: request: The request with the `.get_orientation` field set. timeout: Optional time limit in seconds. Returns: AdbResponse containing the current device orientation. This is unused, but it's in the signature so that all calls are uniform. """ del request # Unused. logging.info('Getting orientation...') response = self._handle_dumpsys( adb_pb2.AdbRequest( dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='input')), timeout=timeout) output = response.dumpsys.output if not output: logging.error('Empty dumpsys output.') response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR response.error_message = 'Failed to execute `dumpsys input`' return response output = output.decode('utf-8') lines = output.split('\n') # Split by lines. skip_next = False for line in lines: # There may be multiple devices in output. An invalid device can be # identified by negative PhysicalWidth. 
physical_width = re.match(r'\s+PhysicalWidth:\s+(-?\d+)px', line) if physical_width: skip_next = int(physical_width.group(1)) < 0 # Depending on the device type, the orientation could take these forms: # SurfaceOrientation: 0 # InputDeviceOrientation: Rotation0 surface_orientation = re.match( r'\s+(SurfaceOrientation|InputDeviceOrientation):\s+.*(\d)', line ) if surface_orientation is not None: if skip_next: continue if surface_orientation.re.groups < 2: continue orientation = surface_orientation.group(2) logging.info('Done getting orientation: %r', orientation) response.get_orientation.orientation = int(orientation) return response response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR response.error_message = ( 'Could not find SurfaceOrientation/InputDeviceOrientation in dumpsys ' 'output' ) return response def _push( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Uploads contents to the device. Args: request: The request with the contents to push to the device. timeout: Optional time limit in seconds. Returns: An empty AdbResponse. """ path = request.push.path if not path: return adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, error_message='Push.path is empty.') # Create temporary file with `push` contents. with tempfile.NamedTemporaryFile(delete=False) as f: fname = f.name f.write(request.push.content) # Issue `adb push` command to upload file. logging.info('Uploading %r to %r.', fname, path) response, _ = self._execute_command(['push', fname, path], timeout=timeout) # Delete it. os.remove(fname) return response def _pull( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Downloads file content from the device. Args: request: The request with the information on what to get from the device. timeout: Optional time limit in seconds. Returns: An AdbResponse with the contents of the specified file. 
""" path = request.pull.path if not path: return adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, error_message='Pull.path is empty.') # Issue `adb pull` command to copy it to a temporary file. with tempfile.NamedTemporaryFile(delete=False) as f: fname = f.name logging.debug('Downloading %r to %r.', path, fname) response, _ = self._execute_command(['pull', path, fname], timeout=timeout) # Read the content of the file. with open(fname, 'rb') as f: response.pull.content = f.read() # Delete it. os.remove(fname) return response def _input_text( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Inserts text as keyboard events. Args: request: The external request. timeout: Optional time limit in seconds. Returns: An AdbResponse """ text = request.input_text.text if not text: return adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, error_message='InputText.text is empty.') response, _ = self._execute_command(['shell', 'input', 'text', text], timeout=timeout) return response def _tap( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Taps the device screen. Args: request: The request with information on where to tap the screen. timeout: Optional time limit in seconds. Returns: An AdbResponse """ x = request.tap.x y = request.tap.y # Check for negative coordinates. # Notice that zero coordinates are valid coordinates (i.e. the first # column/row of the screen). if x < 0 or y < 0: return adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION, error_message=( f'Tap coordinates must be non-negative. Got: {request.tap}.')) response, _ = self._execute_command( ['shell', 'input', 'tap', str(x), str(y)], timeout=timeout) return response def _handle_settings( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Handles SettingsRequest messages. 
Args: request: The specification of what to do with settings. timeout: Optional time limit in seconds. Returns: An AdbResponse """ request = request.settings response = adb_pb2.AdbResponse() # Every SettingsRequest should have a namespace. if ( request.name_space == adb_pb2.AdbRequest.SettingsRequest.Namespace.UNKNOWN ): response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( f'Unknown SettingsRequest.name_space. Got: {request}.') return response namespace = adb_pb2.AdbRequest.SettingsRequest.Namespace.Name( request.name_space).lower() match request.WhichOneof('verb'): case 'get': get = request.get if not get.key: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( f'Empty SettingsRequest.get.key. Got: {request}.' ) return response response, command_output = self._execute_command( ['shell', 'settings', 'get', namespace, get.key], timeout=timeout ) response.settings.output = command_output case 'put': put = request.put if not put.key or not put.value: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( f'Empty SettingsRequest.put key or value. Got: {request}.' ) return response response, command_output = self._execute_command( ['shell', 'settings', 'put', namespace, put.key, put.value], timeout=timeout, ) response.settings.output = command_output case 'delete_key': delete = request.delete_key if not delete.key: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( f'Empty SettingsRequest.delete_key.key. Got: {request}.' ) return response response, command_output = self._execute_command( ['shell', 'settings', 'delete', namespace, delete.key], timeout=timeout, ) response.settings.output = command_output case 'reset': reset = request.reset # At least one of `package_name` or `mode` should be given. 
if ( not reset.package_name and reset.mode == adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNKNOWN ): response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( 'At least one of SettingsRequest.reset package_name or mode' f' should be given. Got: {request}.' ) return response mode = adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.Name( reset.mode ).lower() arg = reset.package_name or mode response, command_output = self._execute_command( ['shell', 'settings', 'reset', namespace, arg], timeout=timeout ) response.settings.output = command_output case 'list': response, command_output = self._execute_command( ['shell', 'settings', 'list', namespace], timeout=timeout ) response.settings.output = command_output case _: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( f'Unknown SettingsRequest.verb. Got: {request}.' ) return response def _handle_generic( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Handles GenericRequest messages. Args: request: The request with the `.generic` field set indicating what `adb` shell command to issue timeout: Optional time limit in seconds. Returns: An AdbResponse """ response, command_output = self._execute_command( list(request.generic.args), timeout) response.generic.output = command_output return response def _handle_package_manager( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Handles PackageManagerRequest messages. Args: request: The request with the `.package_manager` field set containing the sub-commands to issue to `adb pm`. timeout: Optional time limit in seconds. Returns: An AdbResponse. 
""" request = request.package_manager response = adb_pb2.AdbResponse() match request.WhichOneof('verb'): case 'list': what = request.list.WhichOneof('what') response, output = self._execute_command( ['shell', 'pm', 'list', what], timeout=timeout ) if output: items = output.decode('utf-8').split() # Remove prefix for each item. prefix = { 'features': 'feature:', 'libraries': 'library:', 'packages': 'package:', }[what] items = [x[len(prefix) :] for x in items if x.startswith(prefix)] response.package_manager.list.items.extend(items) response.package_manager.output = output case 'clear': package_name = request.clear.package_name if not package_name: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( f'Empty PackageManagerRequest.clear.package_name. Got: {request}.' ) return response args = ['shell', 'pm', 'clear', package_name] if request.clear.user_id: args.insert(3, '-f') args.insert(4, request.clear.user_id) response, response.package_manager.output = self._execute_command( args, timeout=timeout ) case 'grant': grant = request.grant if not grant.package_name: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = '`grant.package_name` cannot be empty.' return response if not grant.permissions: response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = '`grant.permissions` cannot be empty.' return response for permission in grant.permissions: logging.info('Granting permission: %r', permission) response, response.package_manager.output = self._execute_command( ['shell', 'pm', 'grant', grant.package_name, permission], timeout=timeout, ) return response def _handle_dumpsys( self, request: adb_pb2.AdbRequest, timeout: float | None = None ) -> adb_pb2.AdbResponse: """Handles DumpsysRequest messages. Args: request: The request with the `.dumpsys` field set containing sub-commands to `adb dumpsys` shell command.. timeout: Optional time limit in seconds. 
Returns: An AdbResponse. """ request = request.dumpsys cmd = ['shell', 'dumpsys'] if request.timeout_sec < 0 or request.timeout_ms < 0: response = adb_pb2.AdbResponse() response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( 'DumpsysRequest.timeout_{sec, ms} should be non-negative. ' f'Got: {request}.') return response if request.list_only: # `-l` cannot be combined with the following options. if request.service or request.args or request.skip_services: response = adb_pb2.AdbResponse() response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( 'DumpsysRequest.list_only cannot be combined with other options. ' f'Got: {request}.') return response cmd.append('-l') if request.timeout_sec > 0: cmd.append('-t') cmd.append(str(request.timeout_sec)) elif request.timeout_ms > 0: cmd.append('-T') cmd.append(str(request.timeout_ms)) if ( request.priority != adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.UNSET ): cmd.append('--priority') cmd.append(adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.Name( request.priority)) if request.skip_services: if request.service: response = adb_pb2.AdbResponse() response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION response.error_message = ( 'DumpsysRequest.skip_services cannot be combined with `service`. ' f'Got: {request}.') return response cmd.append('--skip') cmd.append(','.join(request.skip_services)) if request.service: cmd.append(request.service) if request.args: cmd += list(request.args) if request.proto: cmd.append('--proto') response, response.dumpsys.output = self._execute_command( cmd, timeout=timeout) return response ================================================ FILE: android_env/components/adb_call_parser_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import builtins import os import subprocess import sys import tempfile from unittest import mock from absl.testing import absltest from absl.testing import parameterized from android_env.components import adb_call_parser from android_env.components import adb_controller from android_env.proto import adb_pb2 class AdbCallParserTest(parameterized.TestCase): def test_unknown_command(self): """Gets UNKNOWN_COMMAND for an empty request.""" adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() response = parser.parse(request) self.assertEqual( response.status, adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND ) def test_invalid_timeout(self): """AdbRequest.timeout_sec must be positive.""" adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.tap.x = 123 request.timeout_sec = -5 response = parser.parse(request) self.assertEqual( response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION ) @mock.patch.object(os.path, 'exists', autospec=True) def test_install_apk_file_not_found(self, mock_exists): """Should fail installing APK when it is not found.""" adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.install_apk.filesystem.path = '/my/home/game.apk' mock_exists.return_value = False response = 
parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) self.assertNotEmpty(response.error_message) adb.execute_command.assert_not_called() @mock.patch.object(os.path, 'exists', autospec=True) def test_install_apk_successful(self, mock_exists): """Should succeed installing an arbitrary APK.""" adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.install_apk.filesystem.path = '/my/home/game.apk' mock_exists.return_value = True response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_once_with( ['install', '-r', '-t', '-g', '/my/home/game.apk'], None) @mock.patch.object(tempfile, 'NamedTemporaryFile', autospec=True) def test_install_apk_from_blob(self, mock_tempfile): """Should succeed installing APK from blob.""" adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() blob_content = b'A fake blob content' request.install_apk.blob.contents = blob_content mock_tempfile.return_value.__enter__.return_value.name = '/my/home/test.apk' mock_tempfile.return_value.__enter__.return_value.write.return_value = None response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_once_with( ['install', '-r', '-t', '-g', '/my/home/test.apk'], None ) # pytype: disable=attribute-error expected_tempfile_kwargs = ( {'suffix': '.apk', 'delete_on_close': False} if sys.version_info > (3, 12) else {'suffix': '.apk'} ) mock_tempfile.assert_has_calls([ mock.call(**expected_tempfile_kwargs), # Constructor mock.call().__enter__(), # Enter context mock.call().__enter__().write(blob_content), # Call write function mock.call().__exit__(None, None, None), # Exit context ]) 
# pytype: enable=attribute-error def test_start_activity_empty_full_activity(self): """A start_activity command should always have a nonempty activity.""" adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.start_activity.extra_args.extend(['blah']) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) def test_start_activity_successful(self): adb = mock.create_autospec(adb_controller.AdbController) command_output = (b'Stopping: my.project.SplashActivity\n' b'Starting: Intent { cmp=my.project.SplashActivity }\n') adb.execute_command.return_value = command_output parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.start_activity.full_activity = 'my.project.SplashActivity' request.start_activity.extra_args.extend(['blah']) request.start_activity.force_stop = True response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_has_calls([ mock.call([ 'shell', 'am', 'start', '-S', '-W', '-n', 'my.project.SplashActivity', 'blah' ], timeout=None), ]) def test_start_activity_successful_no_force_stop(self): adb = mock.create_autospec(adb_controller.AdbController) command_output = (b'Stopping: my.project.SplashActivity\n' b'Starting: Intent { cmp=my.project.SplashActivity }\n') adb.execute_command.return_value = command_output parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.start_activity.full_activity = 'my.project.SplashActivity' request.start_activity.extra_args.extend(['blah']) request.start_activity.force_stop = False response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_has_calls([ mock.call([ 'shell', 
'am', 'start', '', '-W', '-n', 'my.project.SplashActivity', 'blah' ], timeout=None), ]) def test_start_activity_error(self): adb = mock.create_autospec(adb_controller.AdbController) command_output = (b'Stopping: my.project.SplashActivity\n' b'Starting: Intent { cmp=my.project.SplashActivity }\n' b'Error: Activity not started, unknown error code 101\n') adb.execute_command.return_value = command_output parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.start_activity.full_activity = 'my.project.SplashActivity' request.start_activity.extra_args.extend(['blah']) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) self.assertEqual( response.error_message, f'start_activity failed with error: {str(command_output)}') def test_force_stop(self): adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.force_stop.package_name = 'my.project' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_once_with( ['shell', 'am', 'force-stop', 'my.project'], None) def test_grant_permissions_empty_package_name(self): adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.package_manager.grant.permissions.extend(['perm1', 'perm2']) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) def test_grant_permissions_empty_permissions(self): adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.package_manager.grant.package_name = 'my.project' response = parser.parse(request) self.assertEqual(response.status, 
adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) def test_grant_permissions_successful(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.package_manager.grant.package_name = 'my.project' request.package_manager.grant.permissions.extend(['perm1', 'perm2']) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_has_calls([ mock.call(['shell', 'pm', 'grant', 'my.project', 'perm1'], None), mock.call(['shell', 'pm', 'grant', 'my.project', 'perm2'], None), ]) def test_press_button_invalid_button(self): adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.press_button.button = 99999 response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) def test_press_button_successful(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'' parser = adb_call_parser.AdbCallParser(adb) # HOME. request = adb_pb2.AdbRequest() request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.HOME response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_with( ['shell', 'input', 'keyevent', 'KEYCODE_HOME'], None) # BACK. 
request = adb_pb2.AdbRequest() request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.BACK response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_with( ['shell', 'input', 'keyevent', 'KEYCODE_BACK'], None) # ENTER. request = adb_pb2.AdbRequest() request.press_button.button = adb_pb2.AdbRequest.PressButton.Button.ENTER response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_with( ['shell', 'input', 'keyevent', 'KEYCODE_ENTER'], None) def test_start_screen_pinning_package_not_found(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = ( b' taskId=12345: my.project.AnotherActivity visible=true' b' topActivity=ComponentInfo{my.project.AnotherActivity}') parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.start_screen_pinning.full_activity = 'my.project.AmazingActivity' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) self.assertNotEmpty(response.error_message) adb.execute_command.assert_called_once_with( ['shell', 'am', 'stack', 'list'], None) def test_start_screen_pinning_successful(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = ( b' taskId=12345: my.project.AmazingActivity visible=true' b' topActivity=ComponentInfo{my.project.AmazingActivity}') parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.start_screen_pinning.full_activity = 'my.project.AmazingActivity' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_has_calls([ mock.call(['shell', 'am', 'stack', 'list'], None), mock.call(['shell', 
'am', 'task', 'lock', '12345'], None), ]) def test_start_screen_pinning_base_activity(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = ( b' taskId=12345: my.project.MainActivity visible=true' b' topActivity=ComponentInfo{my.project.TopActivity}') parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.start_screen_pinning.full_activity = 'my.project.MainActivity' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_has_calls([ mock.call(['shell', 'am', 'stack', 'list'], None), mock.call(['shell', 'am', 'task', 'lock', '12345'], None), ]) def test_start_screen_pinning_top_activity(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = ( b' taskId=12345: my.project.MainActivity visible=true' b' topActivity=ComponentInfo{my.project.TopActivity}') parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.start_screen_pinning.full_activity = 'my.project.TopActivity' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_has_calls([ mock.call(['shell', 'am', 'stack', 'list'], None), mock.call(['shell', 'am', 'task', 'lock', '12345'], None), ]) def test_send_broadcast_empty_action(self): adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( send_broadcast=adb_pb2.AdbRequest.SendBroadcast()) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) def test_send_broadcast_successful(self): adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() 
request.send_broadcast.action = 'SOME-ACTION' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) def test_send_broadcast_with_component_successful(self): adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.send_broadcast.action = 'SOME-ACTION' request.send_broadcast.component = 'SOME-COMPONENT' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) def test_uninstall_package_empty_package_name(self): adb = mock.create_autospec(adb_controller.AdbController) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.uninstall_package.package_name = '' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) def test_uninstall_package_successful(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'package:my.package' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest() request.uninstall_package.package_name = 'my.package' response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) def test_get_current_activity_no_visible_task(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = None parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity()) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) self.assertNotEmpty(response.error_message) adb.execute_command.assert_has_calls([ mock.call( ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'], None), 
mock.call(['shell', 'am', 'stack', 'list'], None), ]) def test_get_orientation_empty_dumpsys(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) self.assertNotEmpty(response.error_message) adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'], None) def test_get_orientation_invalid_device_no_surface_orientation(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b' PhysicalWidth: -123px' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) self.assertNotEmpty(response.error_message) adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'], None) @parameterized.named_parameters( ('rotation_0', b""" SurfaceOrientation: 0""", 0), ('rotation_90', b""" SurfaceOrientation: 1""", 1), ('rotation_180', b""" SurfaceOrientation: 2""", 2), ('rotation_270', b""" SurfaceOrientation: 3""", 3), ('rotation_0_new', b""" InputDeviceOrientation: 0""", 0), ('rotation_90_new', b""" InputDeviceOrientation: 1""", 1), ('rotation_180_new', b""" InputDeviceOrientation: 2""", 2), ('rotation_270_new', b""" InputDeviceOrientation: 3""", 3), ) def test_get_orientation_success( self, orientation: bytes, expected_orientation: int ): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = ( b"""SomeRandomKey: 12345\n""" + orientation + b""" MoreRandomStuff: awesome_value """ ) parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( 
get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) self.assertEqual(response.get_orientation.orientation, expected_orientation) adb.execute_command.assert_called_once_with(['shell', 'dumpsys', 'input'], None) def test_get_current_activity_no_matches(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity()) for platform in ['win32', 'linux']: with mock.patch.object( sys, 'platform', autospec=True, return_value=platform): response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR) self.assertNotEmpty(response.error_message) adb.execute_command.assert_has_calls([ mock.call([ 'shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true' ], None), mock.call(['shell', 'am', 'stack', 'list'], None), ]) def test_get_current_activity_successful(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'{MyAwesomeActivity}' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity()) for platform in ['win32', 'linux']: with mock.patch.object( sys, 'platform', autospec=True, return_value=platform): response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) # `execute_command` will be called once for each platform. 
adb.execute_command.assert_called_with( ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'], None) self.assertEqual(response.get_current_activity.full_activity, 'MyAwesomeActivity') def test_push_no_path(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( push=adb_pb2.AdbRequest.Push(content=b'Has content but no path')) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) adb.execute_command.assert_not_called() def test_push_successful(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( push=adb_pb2.AdbRequest.Push( content=b'My text.', path='/sdcard/my_file.txt')) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_once() args, kwargs = adb.execute_command.call_args self.assertLen(args, 1) cmd_args = args[0] self.assertLen(cmd_args, 3) self.assertEqual(cmd_args[0], 'push') self.assertEqual(cmd_args[2], '/sdcard/my_file.txt') self.assertIn('timeout', kwargs) self.assertIsNone(kwargs['timeout']) def test_pull_no_path(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest(pull=adb_pb2.AdbRequest.Pull()) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) adb.execute_command.assert_not_called() @mock.patch.object(builtins, 'open', autospec=True) def test_pull_successful(self, mock_open): adb = 
mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' mock_open.return_value.__enter__ = mock_open mock_open.return_value.read.return_value = b'S3cR3t. dO nOt TeLl ANYONE' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( pull=adb_pb2.AdbRequest.Pull(path='/sdcard/my_file.txt')) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) self.assertEqual(response.pull.content, b'S3cR3t. dO nOt TeLl ANYONE') adb.execute_command.assert_called_once() args, kwargs = adb.execute_command.call_args self.assertLen(args, 1) cmd_args = args[0] self.assertLen(cmd_args, 3) self.assertEqual(cmd_args[0], 'pull') self.assertEqual(cmd_args[1], '/sdcard/my_file.txt') self.assertIn('timeout', kwargs) self.assertIsNone(kwargs['timeout']) def test_input_text_no_text(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest(input_text=adb_pb2.AdbRequest.InputText()) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) adb.execute_command.assert_not_called() def test_input_text_successful(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( input_text=adb_pb2.AdbRequest.InputText( text='The Greatest Text of All Time')) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_once_with( ['shell', 'input', 'text', 'The Greatest Text of All Time'], None) @parameterized.named_parameters( ('negative_x_and_negative_y', adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=-1, 
y=-1))), ('negative_x', adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=-1, y=123))), ('negative_y', adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=456, y=-1))), ) def test_tap_failed(self, request: adb_pb2.AdbRequest): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) adb.execute_command.assert_not_called() def test_tap_successful(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=135, y=246)) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_once_with( ['shell', 'input', 'tap', '135', '246'], None) @parameterized.named_parameters( ('empty_request', adb_pb2.AdbRequest.SettingsRequest()), ('no_namespace', adb_pb2.AdbRequest.SettingsRequest( get=adb_pb2.AdbRequest.SettingsRequest.Get(key='my_key'))), ('get_no_key', adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, get=adb_pb2.AdbRequest.SettingsRequest.Get())), ('put_no_key', adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, put=adb_pb2.AdbRequest.SettingsRequest.Put())), ('put_no_value', adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, put=adb_pb2.AdbRequest.SettingsRequest.Put(key='another_key'))), ('delete_no_key', adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, delete_key=adb_pb2.AdbRequest.SettingsRequest.Delete())), ('reset_no_package_name_and_no_mode', 
adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, reset=adb_pb2.AdbRequest.SettingsRequest.Reset())), ) def test_settings_failures(self, request): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'whatever' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest(settings=request) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION) self.assertNotEmpty(response.error_message) adb.execute_command.assert_not_called() def test_settings_success_get(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'here it is!' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, get=adb_pb2.AdbRequest.SettingsRequest.Get(key='some_key')) request = adb_pb2.AdbRequest(settings=request) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) self.assertEqual(response.settings.output, b'here it is!') adb.execute_command.assert_called_once_with( ['shell', 'settings', 'get', 'system', 'some_key'], None) def test_settings_success_put(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'Done for ya!' 
parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE, put=adb_pb2.AdbRequest.SettingsRequest.Put(key='key1', value='val2')) request = adb_pb2.AdbRequest(settings=request) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) self.assertEqual(response.settings.output, b'Done for ya!') adb.execute_command.assert_called_once_with( ['shell', 'settings', 'put', 'secure', 'key1', 'val2'], None) def test_settings_success_delete(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'Key deleted.' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, delete_key=adb_pb2.AdbRequest.SettingsRequest.Delete(key='useless_key')) request = adb_pb2.AdbRequest(settings=request) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) self.assertEqual(response.settings.output, b'Key deleted.') adb.execute_command.assert_called_once_with( ['shell', 'settings', 'delete', 'global', 'useless_key'], None) @parameterized.named_parameters( ('mode_untrusted_defaults', adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_DEFAULTS, '', 'untrusted_defaults'), ('mode_untrusted_clear', adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_CLEAR, '', 'untrusted_clear'), ('mode_trusted_defaults', adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.TRUSTED_DEFAULTS, '', 'trusted_defaults'), # If `package_name` is given, it takes precedence over `mode`. 
('mode_unknown_package_given', adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNKNOWN, 'great.package', 'great.package'), ('mode_untrusted_defaults_package_given', adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_DEFAULTS, 'great.package', 'great.package'), ('mode_untrusted_clear_package_given', adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_CLEAR, 'great.package', 'great.package'), ('mode_trusted_defaults_package_given', adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.TRUSTED_DEFAULTS, 'great.package', 'great.package'), ) def test_settings_success_reset(self, mode, package_name, expected_arg): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'Pkg reset.' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, reset=adb_pb2.AdbRequest.SettingsRequest.Reset( package_name=package_name, mode=mode)) request = adb_pb2.AdbRequest(settings=request) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) self.assertEqual(response.settings.output, b'Pkg reset.') adb.execute_command.assert_called_once_with( ['shell', 'settings', 'reset', 'global', expected_arg], None) def test_settings_success_list(self): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b'volume_ring=5\nvolume_system=7' parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM, list=adb_pb2.AdbRequest.SettingsRequest.List()) request = adb_pb2.AdbRequest(settings=request) response = parser.parse(request) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) self.assertEqual(response.settings.output, b'volume_ring=5\nvolume_system=7') adb.execute_command.assert_called_once_with( ['shell', 
'settings', 'list', 'system'], None) def test_generic_command(self): adb = mock.create_autospec(adb_controller.AdbController) expected_output = b'generic_output' args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action'] adb.execute_command.return_value = expected_output parser = adb_call_parser.AdbCallParser(adb) generic_request = adb_pb2.AdbRequest.GenericRequest(args=args) request = adb_pb2.AdbRequest(generic=generic_request) response = parser.parse(request) self.assertEqual(adb_pb2.AdbResponse.Status.OK, response.status) self.assertEmpty(response.error_message) self.assertEqual(response.generic.output, expected_output) adb.execute_command.assert_called_once_with(args, None) def test_generic_command_adb_error(self): adb = mock.create_autospec(adb_controller.AdbController) args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action'] adb.execute_command.side_effect = subprocess.CalledProcessError( cmd='cmd', output='adb_error', returncode=-1) parser = adb_call_parser.AdbCallParser(adb) generic_request = adb_pb2.AdbRequest.GenericRequest(args=args) request = adb_pb2.AdbRequest(generic=generic_request) response = parser.parse(request) self.assertEqual(adb_pb2.AdbResponse.Status.ADB_ERROR, response.status) self.assertEqual('adb_error', response.error_message) self.assertEmpty(response.generic.output) adb.execute_command.assert_called_once_with(args, None) def test_generic_command_timeout(self): adb = mock.create_autospec(adb_controller.AdbController) args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action'] adb.execute_command.side_effect = subprocess.TimeoutExpired( cmd='cmd', timeout=10) parser = adb_call_parser.AdbCallParser(adb) generic_request = adb_pb2.AdbRequest.GenericRequest(args=args) request = adb_pb2.AdbRequest(generic=generic_request) response = parser.parse(request) self.assertEqual(adb_pb2.AdbResponse.Status.TIMEOUT, response.status) self.assertEqual('Timeout', response.error_message) 
self.assertEmpty(response.generic.output) adb.execute_command.assert_called_once_with(args, None) @parameterized.named_parameters( ('features', adb_pb2.AdbRequest( package_manager=adb_pb2.AdbRequest.PackageManagerRequest( list=adb_pb2.AdbRequest.PackageManagerRequest.List( features=adb_pb2.AdbRequest.PackageManagerRequest.List .Features())))), ('libraries', adb_pb2.AdbRequest( package_manager=adb_pb2.AdbRequest.PackageManagerRequest( list=adb_pb2.AdbRequest.PackageManagerRequest.List( libraries=adb_pb2.AdbRequest.PackageManagerRequest.List .Libraries())))), ('packages', adb_pb2.AdbRequest( package_manager=adb_pb2.AdbRequest.PackageManagerRequest( list=adb_pb2.AdbRequest.PackageManagerRequest.List( packages=adb_pb2.AdbRequest.PackageManagerRequest.List .Packages())))), ) def test_package_manager_list_bad_output(self, request): adb = mock.create_autospec(adb_controller.AdbController) adb.execute_command.return_value = b"""Something irrelevant.""" parser = adb_call_parser.AdbCallParser(adb) response = parser.parse(request) response.package_manager.output = b"""Something irrelevant.""" self.assertEmpty(response.package_manager.list.items) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_once() def test_package_manager_list_features(self): adb = mock.create_autospec(adb_controller.AdbController) output = b""" feature:android.hardware.audio.output feature:android.hardware.bluetooth feature:android.hardware.camera feature:android.hardware.fingerprint feature:android.software.autofill feature:android.software.backup feature:android.software.webview """ adb.execute_command.return_value = output parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( package_manager=adb_pb2.AdbRequest.PackageManagerRequest( list=adb_pb2.AdbRequest.PackageManagerRequest.List( features=adb_pb2.AdbRequest.PackageManagerRequest.List.Features( )))) response = parser.parse(request) 
self.assertEqual(response.package_manager.output, output) self.assertEqual(response.package_manager.list.items, [ 'android.hardware.audio.output', 'android.hardware.bluetooth', 'android.hardware.camera', 'android.hardware.fingerprint', 'android.software.autofill', 'android.software.backup', 'android.software.webview', ]) self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK) self.assertEmpty(response.error_message) adb.execute_command.assert_called_once_with( ['shell', 'pm', 'list', 'features'], None) def test_package_manager_list_libraries(self): adb = mock.create_autospec(adb_controller.AdbController) output = b""" library:android.ext.shared library:android.hidl.base-V1.0-java library:android.hidl.manager-V1.0-java library:android.net.ipsec.ike library:android.test.base library:android.test.mock library:android.test.runner library:androidx.window.sidecar library:com.android.future.usb.accessory library:com.android.location.provider library:com.android.media.remotedisplay library:com.android.mediadrm.signer library:com.android.nfc_extras library:com.google.android.gms library:com.google.android.trichromelibrary library:javax.obex library:org.apache.http.legacy """ adb.execute_command.return_value = output parser = adb_call_parser.AdbCallParser(adb) request = adb_pb2.AdbRequest( package_manager=adb_pb2.AdbRequest.PackageManagerRequest( list=adb_pb2.AdbRequest.PackageManagerRequest.List( libraries=adb_pb2.AdbRequest.PackageManagerRequest.List .Libraries()))) response = parser.parse(request) self.assertEqual(response.package_manager.output, output) self.assertEqual(response.package_manager.list.items, [ 'android.ext.shared', 'android.hidl.base-V1.0-java', 'android.hidl.manager-V1.0-java', 'android.net.ipsec.ike', 'android.test.base', 'android.test.mock', 'android.test.runner', 'androidx.window.sidecar', 'com.android.future.usb.accessory', 'com.android.location.provider', 'com.android.media.remotedisplay', 'com.android.mediadrm.signer', 
def test_package_manager_list_packages(self):
  """`pm list packages` output is split into one item per package."""
  raw_output = b"""
package:com.android.phone
package:com.awesome.company
package:com.another.great.thingie
"""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = raw_output
  pm_request = adb_pb2.AdbRequest.PackageManagerRequest
  request = adb_pb2.AdbRequest(
      package_manager=pm_request(
          list=pm_request.List(packages=pm_request.List.Packages())
      )
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.package_manager.output, raw_output)
  self.assertEqual(
      result.package_manager.list.items,
      [
          'com.android.phone',
          'com.awesome.company',
          'com.another.great.thingie',
      ],
  )
  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'pm', 'list', 'packages'], None
  )

def test_package_manager_clear_no_package_name(self):
  """Clearing without a package name is rejected before adb is invoked."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b"""Something irrelevant."""
  pm_request = adb_pb2.AdbRequest.PackageManagerRequest
  request = adb_pb2.AdbRequest(
      package_manager=pm_request(clear=pm_request.Clear(package_name=''))
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEmpty(result.package_manager.output)
  self.assertEqual(
      result.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
  )
  self.assertNotEmpty(result.error_message)
  fake_adb.execute_command.assert_not_called()

def test_package_manager_clear_successful_no_user_id(self):
  """Without a user id, `pm clear` receives only the package name."""
  success_message = b"""Some successful message."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = success_message
  pm_request = adb_pb2.AdbRequest.PackageManagerRequest
  request = adb_pb2.AdbRequest(
      package_manager=pm_request(
          clear=pm_request.Clear(package_name='my.package')
      )
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.package_manager.output, success_message)
  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'pm', 'clear', 'my.package'], None
  )

def test_package_manager_clear_successful_with_user_id(self):
  """With a user id, `pm clear` receives `-f <user_id>` before the package."""
  success_message = b"""Some successful message."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = success_message
  pm_request = adb_pb2.AdbRequest.PackageManagerRequest
  request = adb_pb2.AdbRequest(
      package_manager=pm_request(
          clear=pm_request.Clear(
              package_name='my.package', user_id='mrawesome'
          )
      )
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.package_manager.output, success_message)
  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'pm', 'clear', '-f', 'mrawesome', 'my.package'], None
  )

def test_dumpsys_empty_request(self):
  """An empty `DumpsysRequest` is a valid request."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  request = adb_pb2.AdbRequest(dumpsys=adb_pb2.AdbRequest.DumpsysRequest())

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys'], timeout=None
  )

@parameterized.named_parameters(
    ('negative_timeout_sec',
     adb_pb2.AdbRequest(
         dumpsys=adb_pb2.AdbRequest.DumpsysRequest(timeout_sec=-1))),
    ('negative_timeout_ms',
     adb_pb2.AdbRequest(
         dumpsys=adb_pb2.AdbRequest.DumpsysRequest(timeout_ms=-2))),
)
def test_dumpsys_negative_timeouts(self, request):
  """`DumpsysRequest.timeout_{sec, ms}` if passed, should be positive."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(
      result.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
  )
  self.assertNotEmpty(result.error_message)
  fake_adb.execute_command.assert_not_called()

@parameterized.named_parameters(
    ('both_timeouts_zero', 0, 0, ['shell', 'dumpsys']),
    ('sec_takes_precedence_zero', 123, 0, ['shell', 'dumpsys', '-t', '123']),
    ('sec_takes_precedence', 123, 456, ['shell', 'dumpsys', '-t', '123']),
    ('ms_if_no_sec', 0, 456, ['shell', 'dumpsys', '-T', '456']),
)
def test_dumpsys_timeout_successful(self, timeout_sec, timeout_ms, expected):
  """Checks which timeout flag is forwarded for each sec/ms combination."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  dumpsys_request = adb_pb2.AdbRequest.DumpsysRequest(
      timeout_sec=timeout_sec, timeout_ms=timeout_ms
  )
  request = adb_pb2.AdbRequest(dumpsys=dumpsys_request)

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(expected, timeout=None)

@parameterized.named_parameters(
    ('priority_undefined',
     adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.UNSET,
     ['shell', 'dumpsys']),
    ('priority_normal',
     adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.NORMAL,
     ['shell', 'dumpsys', '--priority', 'NORMAL']),
    ('priority_high',
     adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.HIGH,
     ['shell', 'dumpsys', '--priority', 'HIGH']),
    ('priority_critical',
     adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.CRITICAL,
     ['shell', 'dumpsys', '--priority', 'CRITICAL']),
)
def test_dumpsys_priority_timeout_successful(self, priority, expected):
  """Checks the `--priority` flag emitted for each priority level."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  dumpsys_request = adb_pb2.AdbRequest.DumpsysRequest(priority=priority)
  request = adb_pb2.AdbRequest(dumpsys=dumpsys_request)

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(expected, timeout=None)

@parameterized.named_parameters(
    (
        'window_service',
        adb_pb2.AdbRequest.DumpsysRequest(list_only=True, service='window'),
    ),
    (
        'arbitrary_args',
        adb_pb2.AdbRequest.DumpsysRequest(
            list_only=True, args=['myoption', 'anotheroption']
        ),
    ),
    (
        'skip_usb',
        adb_pb2.AdbRequest.DumpsysRequest(
            list_only=True, skip_services=['usb']
        ),
    ),
)
def test_dumpsys_list_only_cannot_be_combined(
    self, dumpsys_request: adb_pb2.AdbRequest.DumpsysRequest
):
  """When `list_only==True`, the request cannot contain a few fields."""
  # Arrange.
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(fake_adb)

  # Act.
  result = call_parser.parse(adb_pb2.AdbRequest(dumpsys=dumpsys_request))

  # Assert.
  self.assertEqual(
      result.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
  )
  self.assertNotEmpty(result.error_message)
  fake_adb.execute_command.assert_not_called()

def test_dumpsys_list_only_success(self):
  """`list_only` on its own maps to `dumpsys -l`."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  request = adb_pb2.AdbRequest(
      dumpsys=adb_pb2.AdbRequest.DumpsysRequest(list_only=True)
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', '-l'], timeout=None
  )

def test_dumpsys_skip_services_cannot_combine_with_service(self):
  """When using `DumpsysRequest.skip_service`, it cannot contain `.service`."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  request = adb_pb2.AdbRequest(
      dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
          service='wifi', skip_services=['window', 'usb']
      )
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(
      result.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
  )
  self.assertNotEmpty(result.error_message)
  fake_adb.execute_command.assert_not_called()

def test_dumpsys_skip_services(self):
  """Skipped services are joined with commas after `--skip`."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  request = adb_pb2.AdbRequest(
      dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
          skip_services=['window', 'usb']
      )
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', '--skip', 'window,usb'], timeout=None
  )

def test_dumpsys_single_service(self):
  """A single service name is appended to the dumpsys command."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  request = adb_pb2.AdbRequest(
      dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='window')
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'window'], timeout=None
  )

def test_dumpsys_single_service_with_args(self):
  """Extra args follow the service name in the order given."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'whatever'
  request = adb_pb2.AdbRequest(
      dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
          service='window', args=['arg1', 'arg2']
      )
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'window', 'arg1', 'arg2'], timeout=None
  )

def test_dumpsys_single_service_with_proto(self):
  """`proto=True` appends `--proto` after the service name."""
  fake_adb = mock.create_autospec(adb_controller.AdbController)
  fake_adb.execute_command.return_value = b'some binary output'
  request = adb_pb2.AdbRequest(
      dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='window', proto=True)
  )

  result = adb_call_parser.AdbCallParser(fake_adb).parse(request)

  self.assertEqual(result.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(result.error_message)
  fake_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'window', '--proto'], timeout=None
  )
class AdbController:
  """Runs `adb` commands in subprocesses, restarting the server on failure."""

  def __init__(self, config: config_classes.AdbControllerConfig):
    """Stores `config` and snapshots the environment used to spawn `adb`.

    Unless `config.use_adb_server_port_from_os_env` is set, `ANDROID_HOME` and
    `ANDROID_ADB_SERVER_PORT` are removed from `os.environ` as a side effect.

    Args:
      config: Settings read by this controller (adb path, server port, device
        name and default timeout).
    """

    self._config = config
    logging.info('config: %r', self._config)

    if not self._config.use_adb_server_port_from_os_env:
      # These variables are normally exported by AndroidStudio and make ADB
      # commands fail, so they are dropped from the process environment.
      if 'ANDROID_HOME' in os.environ:
        logging.info('Removing ANDROID_HOME from os.environ')
        os.environ.pop('ANDROID_HOME')
      if 'ANDROID_ADB_SERVER_PORT' in os.environ:
        logging.info('Removing ANDROID_ADB_SERVER_PORT from os.environ')
        os.environ.pop('ANDROID_ADB_SERVER_PORT')

    # Snapshot of the environment handed to every subprocess, with any variable
    # references inside $HOME expanded explicitly.
    env_snapshot = dict(os.environ)
    env_snapshot['HOME'] = os.path.expandvars(env_snapshot.get('HOME', ''))
    self._os_env_vars = env_snapshot
    logging.info('self._os_env_vars: %r', self._os_env_vars)

  def command_prefix(self, include_device_name: bool = True) -> list[str]:
    """Returns the leading arguments of every adb invocation.

    Args:
      include_device_name: Whether to append `-s <device_name>`.

    Returns:
      The adb binary path, followed by `-P <port>` unless the server port comes
      from the OS environment, followed by the optional device selector.
    """

    prefix = [self._config.adb_path]
    if not self._config.use_adb_server_port_from_os_env:
      # The port comes from the config, so it must be passed explicitly.
      prefix += ['-P', str(self._config.adb_server_port)]
    if include_device_name:
      prefix += ['-s', self._config.device_name]
    return prefix

  def init_server(self, timeout: float | None = None):
    """Initializes the ADB server daemon on the given port.

    This function should be called immediately after initializing the first
    adb_controller, and before launching the simulator.

    Args:
      timeout: A timeout to use for this operation. If not set the default
        timeout set on the constructor will be used.
    """

    # A device-independent call is enough to make adb spawn its daemon.
    self.execute_command(['devices'], timeout, device_specific=False)
    time.sleep(0.2)

  def _restart_server(self, timeout: float | None = None):
    """Kills the adb server, starts it again and lists devices.

    Args:
      timeout: A timeout to use for this operation. If not set the default
        timeout set on the constructor will be used.
    """

    logging.info('Restarting adb server.')

    self.execute_command(
        ['kill-server'], timeout=timeout, device_specific=False
    )
    time.sleep(0.2)

    start_output = self.execute_command(
        ['start-server'], timeout=timeout, device_specific=False
    )
    logging.info('start-server output: %r', start_output.decode('utf-8'))
    time.sleep(2.0)

    self.execute_command(['devices'], timeout=timeout, device_specific=False)
    time.sleep(0.2)

  def execute_command(
      self,
      args: list[str],
      timeout: float | None = None,
      device_specific: bool = True,
  ) -> bytes:
    """Executes an adb command, retrying once after a failure.

    Args:
      args: A list of strings representing each adb argument. For example:
        ['install', '/my/app.apk']
      timeout: A timeout to use for this operation. If not set the default
        timeout set on the constructor will be used.
      device_specific: Whether the call is device-specific or independent.

    Returns:
      The output of running such command as a binary string.

    Raises:
      errors.AdbControllerError: If the command fails or times out on every
        attempt.
    """

    effective_timeout = (
        self._config.default_timeout if timeout is None else timeout
    )
    command = self.command_prefix(include_device_name=device_specific) + args
    command_str = 'adb ' + ' '.join(command[1:])

    n_tries = 2
    latest_error = None
    for attempt in range(1, n_tries + 1):
      logging.info('Executing ADB command: [%s]', command_str)
      try:
        cmd_output = subprocess.check_output(
            command,
            stderr=subprocess.STDOUT,
            timeout=effective_timeout,
            env=self._os_env_vars,
        )
      except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
        logging.exception(
            'Failed to execute ADB command (try %d of %d): [%s]',
            attempt,
            n_tries,
            command_str,
        )
        captured_streams = (
            ('**stdout**:', e.stdout),
            ('**stderr**:', e.stderr),
        )
        for header, captured in captured_streams:
          if captured is None:
            continue
          logging.error(header)
          for line in captured.splitlines():
            logging.error(' %s', line)
        latest_error = e
        # Only device-specific calls trigger a restart, and never after the
        # final attempt. The restart itself issues device-independent calls.
        if device_specific and attempt < n_tries:
          self._restart_server(timeout=effective_timeout)
      else:
        logging.debug('ADB command output: %s', cmd_output)
        return cmd_output

    raise errors.AdbControllerError(
        f'Error executing adb command: [{command_str}]\n'
        f'Caused by: {latest_error}'
    ) from latest_error
# Timeout to be used by default in tests below. Set to a small value to avoid
# hanging on a failed test.
_TIMEOUT = 2


class AdbControllerTest(absltest.TestCase):
  """Tests for `AdbController` command execution and server restarts."""

  def setUp(self):
    super().setUp()
    # Patch `os.environ` instead of assigning into it, so that every variable
    # set by a test is rolled back on cleanup and cannot leak into other tests.
    env_patcher = mock.patch.dict(
        os.environ, {'MY_ENV_VAR': '/some/path/', 'HOME': '$MY_ENV_VAR'}
    )
    env_patcher.start()
    self.addCleanup(env_patcher.stop)
    self._env_before = os.environ.copy()

  def _expected_env(self, **extra_vars: str) -> dict[str, str]:
    """Returns the env the controller should pass to `check_output`.

    Args:
      **extra_vars: Variables a test added to `os.environ` after `setUp()`.

    Returns:
      A fresh copy of the `setUp()` snapshot with `HOME` expanded, so that the
      snapshot itself is never mutated.
    """
    expected = dict(self._env_before)
    expected['HOME'] = '/some/path/'
    expected.update(extra_vars)
    return expected

  def _make_controller(self, **config_overrides):
    """Builds a controller that keeps the adb server port from the OS env."""
    return adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            use_adb_server_port_from_os_env=True,
            **config_overrides,
        )
    )

  def _adb_call(self, command: list[str], env: dict[str, str]):
    """Returns the `mock.call` expected for one adb subprocess invocation."""
    return mock.call(
        command, stderr=subprocess.STDOUT, timeout=_TIMEOUT, env=env
    )

  def _failed_then_restarted_calls(self, env: dict[str, str]):
    """Calls expected when a command fails once and the server is restarted."""
    device_command = ['my_adb', '-s', 'awesome_device', 'my_command']
    return [
        self._adb_call(device_command, env),
        self._adb_call(['my_adb', 'kill-server'], env),
        self._adb_call(['my_adb', 'start-server'], env),
        self._adb_call(['my_adb', 'devices'], env),
        self._adb_call(device_command, env),
    ]

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_init_server(self, mock_sleep, mock_check_output):
    """We expect an `adb devices` call when initializing the server."""
    # Arrange.
    adb_controller = self._make_controller()

    # Act.
    adb_controller.init_server(timeout=_TIMEOUT)

    # Assert.
    mock_check_output.assert_called_once_with(
        ['my_adb', 'devices'],
        stderr=subprocess.STDOUT,
        timeout=_TIMEOUT,
        env=self._expected_env(),
    )
    mock_sleep.assert_called_once()

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_init_server_with_adb_server_port_from_os_env(
      self, mock_sleep, mock_check_output
  ):
    """Use OS env vars if `use_adb_server_port_from_os_env` is True."""
    # Arrange.
    # Set the ADB server port to 1234 in the OS environment.
    os.environ['ANDROID_ADB_SERVER_PORT'] = '1234'
    os.environ['ANDROID_HOME'] = '/some/path/to/android'
    adb_controller = self._make_controller(adb_server_port=9999)

    # Act.
    adb_controller.init_server(timeout=_TIMEOUT)

    # Assert.
    mock_check_output.assert_called_once_with(
        ['my_adb', 'devices'],
        stderr=subprocess.STDOUT,
        timeout=_TIMEOUT,
        env=self._expected_env(
            ANDROID_HOME='/some/path/to/android',
            ANDROID_ADB_SERVER_PORT='1234',
        ),
    )
    mock_sleep.assert_called_once()

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_restart_server(self, mock_sleep, mock_check_output):
    """When an adb command fails, we expect the server to be restarted."""
    # Arrange.
    mock_check_output.side_effect = [
        subprocess.CalledProcessError(returncode=1, cmd='blah'),
    ] + ['fake_output'.encode('utf-8')] * 4
    adb_controller = self._make_controller()

    # Act.
    adb_controller.execute_command(['my_command'], timeout=_TIMEOUT)

    # Assert.
    mock_check_output.assert_has_calls(
        self._failed_then_restarted_calls(self._expected_env())
    )
    mock_sleep.assert_has_calls(
        [mock.call(0.2), mock.call(2.0), mock.call(0.2)]
    )

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_invalid_command(self, mock_sleep, mock_check_output):
    """Restart the server when given an invalid command."""
    # Arrange.
    restart_sequence = ['fake_output'.encode('utf-8')] * 3
    mock_check_output.side_effect = (
        [
            subprocess.CalledProcessError(returncode=1, cmd='blah'),
        ]
        + restart_sequence
        + [subprocess.CalledProcessError(returncode=1, cmd='blah')]
        # Don't restart if last call fails.
    )
    adb_controller = self._make_controller()

    # Act.
    with self.assertRaises(errors.AdbControllerError):
      adb_controller.execute_command(['my_command'], timeout=_TIMEOUT)

    # Assert.
    mock_check_output.assert_has_calls(
        self._failed_then_restarted_calls(self._expected_env()),
        any_order=False,
    )
    mock_sleep.assert_has_calls(
        [mock.call(0.2), mock.call(2.0), mock.call(0.2)]
    )

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_avoid_infinite_recursion(self, mock_sleep, mock_check_output):
    """Raise an error if the command fails even after restarts."""
    del mock_sleep
    mock_check_output.side_effect = subprocess.CalledProcessError(
        returncode=1, cmd='blah'
    )
    adb_controller = self._make_controller()

    self.assertRaises(
        errors.AdbControllerError,
        adb_controller.execute_command,
        ['my_command'],
        timeout=_TIMEOUT,
    )


class AdbControllerInitTest(absltest.TestCase):
  """Tests for how the `AdbController` constructor treats `os.environ`."""

  def setUp(self):
    super().setUp()
    # Both tests start from the same two variables. Patching keeps them (and
    # any deletion done by the constructor) from outliving the test.
    env_patcher = mock.patch.dict(
        os.environ,
        {
            'ANDROID_HOME': '/usr/local/Android/Sdk',
            'ANDROID_ADB_SERVER_PORT': '1337',
        },
    )
    env_patcher.start()
    self.addCleanup(env_patcher.stop)

  def test_deletes_problem_env_vars(self):
    """Without the OS-env flag, the constructor removes both variables."""
    adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            adb_server_port=9999,
            default_timeout=_TIMEOUT,
        )
    )

    self.assertNotIn('ANDROID_HOME', os.environ)
    self.assertNotIn('ANDROID_ADB_SERVER_PORT', os.environ)

  def test_use_adb_server_port_from_os_env_retains_os_env_vars(self):
    """With the OS-env flag, the constructor leaves both variables intact."""
    adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            adb_server_port=9999,
            default_timeout=_TIMEOUT,
            use_adb_server_port_from_os_env=True,
        )
    )

    self.assertIn('ANDROID_ADB_SERVER_PORT', os.environ)
    self.assertEqual(os.environ['ANDROID_ADB_SERVER_PORT'], '1337')
    self.assertIn('ANDROID_HOME', os.environ)
    self.assertEqual(os.environ['ANDROID_HOME'], '/usr/local/Android/Sdk')
_LOGCAT_COMMAND = ['logcat', '-v', 'epoch']


class AdbLogStream(log_stream.LogStream):
  """Log stream backed by an `adb logcat` subprocess."""

  def __init__(self, adb_command_prefix: list[str], verbose: bool = False):
    """Remembers the adb prefix that every spawned command starts with.

    Args:
      adb_command_prefix: Leading adb arguments prepended to each command.
      verbose: Forwarded unchanged to the `LogStream` base class.
    """
    super().__init__(verbose=verbose)
    self._adb_command_prefix = adb_command_prefix

  def _get_stream_output(self):
    """Clears the logcat buffers, then returns the stdout of a new logcat."""
    # Wipe all buffers first (`logcat -b all -c`) so that the long-lived
    # process started below does not replay output from previous runs.
    clear_command = self._adb_command_prefix + ['logcat', '-b', 'all', '-c']
    clear_buffer_output = subprocess.check_output(
        clear_command, stderr=subprocess.STDOUT, timeout=100
    )
    logging.info('clear_buffer_output: %r', clear_buffer_output)

    logcat_command = (
        self._adb_command_prefix + _LOGCAT_COMMAND + self._filters
    )
    self._adb_subprocess = subprocess.Popen(
        logcat_command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        bufsize=1,
        universal_newlines=True,
    )
    return self._adb_subprocess.stdout

  def stop_stream(self):
    """Kills the logcat subprocess, or logs an error if none was started."""
    # The attribute only exists once `_get_stream_output()` has run.
    process = getattr(self, '_adb_subprocess', None)
    if process is None:
      logging.error('`stop_stream()` called before `get_stream_output()`. '
                    'This violates the `LogStream` API.')
      return
    process.kill()
"""Tests for adb_log_stream.""" import subprocess from unittest import mock from absl.testing import absltest from android_env.components import adb_log_stream class FakeAdbSubprocess: @property def stdout(self): return [f'line_{i}' for i in range(100)] def kill(self): pass class AdbLogStreamTest(absltest.TestCase): @mock.patch.object(subprocess, 'check_output', return_value=b'') @mock.patch.object(subprocess, 'Popen', return_value=FakeAdbSubprocess()) def test_get_stream_output(self, mock_popen, unused_mock_check_output): stream = adb_log_stream.AdbLogStream(adb_command_prefix=['foo']) stream.set_log_filters(['bar']) stream_output = stream.get_stream_output() for i, line in enumerate(stream_output): self.assertEqual(line, f'line_{i}') mock_popen.assert_called_with( ['foo', 'logcat', '-v', 'epoch', 'bar', '*:S'], stderr=subprocess.STDOUT, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True) def test_stop_stream_before_get_stream_output(self): """Calling `stop_stream()` before `get_stream_output()` should not crash.""" # Arrange. stream = adb_log_stream.AdbLogStream(adb_command_prefix=['foo']) # Act. stream.stop_stream() # Assert. # Nothing to assert. The test should just finish without raising an # exception. if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/components/app_screen_checker.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. """Determines if the current app screen matches an expected app screen.""" from collections.abc import Callable, Sequence import enum import re import time from typing import Self from absl import logging from android_env.components import adb_call_parser as adb_call_parser_lib from android_env.components import errors from android_env.proto import adb_pb2 from android_env.proto import task_pb2 class _DumpsysNode: """A node in a dumpsys tree.""" def __init__(self, data: str): self._children = [] self._data = data @property def data(self) -> str: return self._data @property def children(self) -> list[Self]: return self._children def find_child( self, predicate: Callable[[Self], bool], max_levels: int = 0 ) -> Self | None: """Returns the first direct child that matches `predicate`, None otherwise. Args: predicate: Function-like that accepts a _DumpsysNode and returns boolean. max_levels: Maximum number of levels down the tree to search for a child. If non-positive, only direct children will be searched for. Returns: A _DumpsysNode or None. """ if not self.children: return None try: return next(x for x in self.children if predicate(x)) except StopIteration: logging.info('Failed to find child. max_levels: %i.', max_levels) # Search children. if max_levels: for child in self.children: child_result = child.find_child(predicate, max_levels - 1) if child_result is not None: return child_result return None def __repr__(self): return self._data def print_tree(self, indent: int = 2): """Prints this tree in logging.info().""" logging.info(' ' * indent + self.data) for c in self.children: c.print_tree(indent + 2) def build_tree_from_dumpsys_output(dumpsys_output: str) -> _DumpsysNode: """Constructs a tree from a dumpsys string output. Args: dumpsys_output: string Verbatim output from adb dumpsys. 
The expected format is a list where each line is a node and the indentation marks the relationship with its parent or sibling. Returns: _DumpsysNode The root of the tree. """ lines = dumpsys_output.split('\n') # Split by lines. lines = [x.rstrip(' \r') for x in lines] lines = [x for x in lines if len(x)] # Remove empty lines. root = _DumpsysNode('___root___') # The root of all nodes. parents_stack = [root] for line in lines: stripped_line = line.lstrip(' ') indent = len(line) - len(stripped_line) # Number of indent spaces. new_node = _DumpsysNode(stripped_line) # Create a node without indentation. parent = parents_stack.pop() if parent.data == '___root___': # The root is an exception for indentation. parent_indent = -2 else: parent_indent = (len(parents_stack) - 1) * 2 if indent == parent_indent: # `new_node` is a sibiling. parent = parents_stack.pop() elif indent < parent_indent: # Indentation reduced (i.e. a block finished) num_levels = (indent // 2) + 1 parents_stack = parents_stack[:num_levels] parent = parents_stack.pop() elif indent > parent_indent: # `new_node` is a child. pass # No need to change the current parent. parent.children.append(new_node) parents_stack.append(parent) parents_stack.append(new_node) return root def matches_path( dumpsys_activity_output: str, expected_view_hierarchy_path: Sequence[re.Pattern[str]], max_levels: int = 0, ) -> bool: """Returns True if the current dumpsys output matches the expected path. Args: dumpsys_activity_output: The output of running `dumpsys activity ...`. expected_view_hierarchy_path: [regex] A list of regular expressions to be tested at each level of the tree. max_levels: How many levels to search from root for View Hierarchy. Returns: True if the dumpsys tree contains one path that matches all regexes. """ root = build_tree_from_dumpsys_output(dumpsys_activity_output) # Find the View Hierarchy. 
class AppScreenChecker:
  """Verifies that the device currently shows an expected app screen.

  The expected screen is described by an activity name and, optionally, a list
  of regular expressions that must match a path in the view hierarchy.
  """

  class Outcome(enum.IntEnum):
    """Possible return values from checking the current app screen."""

    # The current app screen matches the expected app screen.
    SUCCESS = 0
    # There's no activity to check.
    EMPTY_EXPECTED_ACTIVITY = 1
    # We were unable to determine the current activity.
    FAILED_ACTIVITY_EXTRACTION = 2
    # The current activity does not match the expected activity.
    UNEXPECTED_ACTIVITY = 3
    # The current view hierarchy does not match the expected view hierarchy.
    UNEXPECTED_VIEW_HIERARCHY = 4

  def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser,
               expected_app_screen: task_pb2.AppScreen):
    """Stores the expected screen and pre-compiles its view hierarchy regexes.

    Args:
      adb_call_parser: Parser used to issue adb requests to the device.
      expected_app_screen: The screen (activity + view hierarchy path) that
        the device is expected to display.
    """
    self._adb_call_parser = adb_call_parser
    self._expected_app_screen = expected_app_screen
    self._expected_activity = expected_app_screen.activity
    # Compile once up front so that repeated checks do not recompile.
    self._expected_view_hierarchy_path = list(
        map(re.compile, expected_app_screen.view_hierarchy_path))

  # Return type is AppScreenChecker.Outcome, but pytype doesn't understand that.
  def matches_current_app_screen(self) -> enum.IntEnum:
    """Compares the screen currently shown against `expected_app_screen`.

    Returns:
      An `AppScreenChecker.Outcome` describing whether the activity and (if
      configured) the view hierarchy path match, or why they could not be
      checked.
    """
    if not self._expected_activity:
      return AppScreenChecker.Outcome.EMPTY_EXPECTED_ACTIVITY

    # Step 1: the foreground activity must be the expected one.
    activity_request = adb_pb2.AdbRequest(
        get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())
    activity_response = self._adb_call_parser.parse(activity_request)
    if activity_response.status != adb_pb2.AdbResponse.OK:
      return AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION

    observed_activity = activity_response.get_current_activity.full_activity
    if observed_activity != self._expected_activity:
      logging.error('current_activity: %s, expected_activity: %s',
                    observed_activity, self._expected_activity)
      return AppScreenChecker.Outcome.UNEXPECTED_ACTIVITY

    # Step 2: nothing else to verify if no view hierarchy path was requested.
    if not self._expected_view_hierarchy_path:
      return AppScreenChecker.Outcome.SUCCESS

    # The package name is the part of the full activity name before the '/'.
    package = self._expected_activity.split('/')[0]
    dumpsys_request = adb_pb2.AdbRequest(
        dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
            service='activity', args=[package, package]))
    dumpsys_response = self._adb_call_parser.parse(dumpsys_request)
    if dumpsys_response.status != adb_pb2.AdbResponse.OK:
      return AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION

    # An empty dumpsys output is not treated as a mismatch.
    raw_output = dumpsys_response.dumpsys.output
    if raw_output and not matches_path(
        raw_output.decode('utf-8'),
        self._expected_view_hierarchy_path,
        max_levels=3):
      return AppScreenChecker.Outcome.UNEXPECTED_VIEW_HIERARCHY

    return AppScreenChecker.Outcome.SUCCESS

  def wait_for_app_screen(self, timeout_sec: float) -> float:
    """Polls until `self._expected_app_screen` becomes the current screen.

    Args:
      timeout_sec: Maximum total time to wait for the screen to pop up.

    Returns:
      The total amount of time in seconds spent waiting for the screen to pop
      up.

    Raises:
      errors.WaitForAppScreenError if the screen does not pop up within
      `timeout_sec`.
    """
    logging.info('Waiting for app screen...')
    began = time.time()
    while time.time() - began < timeout_sec:
      outcome = self.matches_current_app_screen()
      if outcome != AppScreenChecker.Outcome.SUCCESS:
        # Not there yet: back off briefly before polling again.
        time.sleep(0.1)
        continue
      elapsed = time.time() - began
      logging.info('Successfully waited for app screen in %r seconds: [%r]',
                   elapsed, self._expected_app_screen)
      return elapsed

    elapsed = time.time() - began
    logging.error('Failed to wait for app screen in %r seconds: [%r].',
                  elapsed, self._expected_app_screen)
    raise errors.WaitForAppScreenError()
def _flatten_tree(
    tree: app_screen_checker._DumpsysNode, flat_tree: list[str], indent: int = 2
):
  """Appends one indented line per node of `tree` to `flat_tree`.

  Nodes are visited in pre-order (a node first, then its children from first
  to last); every level of depth adds two spaces of indentation.
  """
  # Explicit stack instead of recursion. Children are pushed in reverse so that
  # the first child is popped (and therefore emitted) first.
  pending = [(tree, indent)]
  while pending:
    node, depth = pending.pop()
    flat_tree.append(' ' * depth + node.data)
    pending.extend((child, depth + 2) for child in reversed(node.children))
def test_build_forest_from_dumpsys_output(self): dumpsys_output = """ Tree1 Branch1 Leaf1 Leaf2 Branch2 Leaf3 Leaf4 Leaf5 Tree2 Branch3 Leaf6 Leaf7 Branch4 Leaf8 Leaf9 Leaf10 Leaf11 """ tree = app_screen_checker.build_tree_from_dumpsys_output(dumpsys_output) flat_tree = [] _flatten_tree(tree, flat_tree, indent=2) self.assertEqual(flat_tree, [ ' ___root___', ' Tree1', ' Branch1', ' Leaf1', ' Leaf2', ' Branch2', ' Leaf3', ' Leaf4', ' Leaf5', ' Tree2', ' Branch3', ' Leaf6', ' Leaf7', ' Branch4', ' Leaf8', ' Leaf9', ' Leaf10', ' Leaf11', ]) def test_no_view_hierarchy_matches_path(self): dumpsys_output = """ TASK ACTIVITY Missing View Hierarchy A B C D E F """ expected_path = ['^A$', 'B$'] expected_view_hierarchy_path = [ re.compile(regex) for regex in expected_path ] self.assertFalse( app_screen_checker.matches_path(dumpsys_output, expected_view_hierarchy_path)) def test_matches_path(self): dumpsys_output = """ TASK ACTIVITY Some node we don't care Blah View Hierarchy Hirohito Akihito Naruhito Aiko Fumihito Mako Kako Hisahito Masahito """ expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$'] expected_view_hierarchy_path = [ re.compile(regex) for regex in expected_path ] self.assertTrue( app_screen_checker.matches_path( dumpsys_output, expected_view_hierarchy_path, max_levels=2)) # Also check that the following path does not match anything in the tree. 
expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kenji$'] expected_view_hierarchy_path = [ re.compile(regex) for regex in expected_path ] self.assertFalse( app_screen_checker.matches_path(dumpsys_output, expected_view_hierarchy_path)) def test_matches_path_one_level_deep(self): dumpsys_output = """ TASK ACTIVITY Some node we don't care Blah Some intermediate node View Hierarchy Hirohito Akihito Naruhito Aiko Fumihito Mako Kako Hisahito Masahito """ expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$'] expected_view_hierarchy_path = [ re.compile(regex) for regex in expected_path ] self.assertTrue( app_screen_checker.matches_path( dumpsys_output, expected_view_hierarchy_path, max_levels=3)) # Also check that the view hierarchy is not found when searching only grand # children of TASK. expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$'] expected_view_hierarchy_path = [ re.compile(regex) for regex in expected_path ] self.assertFalse( app_screen_checker.matches_path( dumpsys_output, expected_view_hierarchy_path, max_levels=2)) def test_wait_for_app_screen_zero_timeout(self): """Ensures that an exception is raised if the timeout is passed.""" app_screen = task_pb2.AppScreen(activity='whatever.MyActivity') call_parser = mock.create_autospec(adb_call_parser.AdbCallParser) screen_checker = app_screen_checker.AppScreenChecker( adb_call_parser=call_parser, expected_app_screen=app_screen) # With a zero timeout, the method should never be able to wait for the # screen to pop up and an exception should be raised. 
@dataclasses.dataclass
class AdbControllerConfig:
  """Settings for instantiating an `AdbController` instance."""

  # Filesystem path to the `adb` binary.
  # NOTE: This must be a full path and must not contain environment variables
  # or user folder shorthands (e.g. `~/some/path/to/adb`) since they will not be
  # expanded internally by AndroidEnv.
  # NOTE(review): the default below itself starts with `~`, which contradicts
  # the note above -- confirm whether callers are expected to always override
  # it or to expand it before use.
  adb_path: str = '~/Android/Sdk/platform-tools/adb'
  # Port for adb server.
  adb_server_port: int = 5037
  # Default timeout in seconds for internal commands.
  default_timeout: float = 120.0
  # Name of the device to communicate with.
  device_name: str = ''
  # Whether to use adb server port set in OS Environment variables.
  # When True, adb will use the ANDROID_ADB_SERVER_PORT OS environment variable
  # for selecting its server port if available (or the default 5037 if not set).
  # When False, the ANDROID_ADB_SERVER_PORT OS environment var will be unset
  # and adb_server_port will be used and supplied as an argument to all adb
  # commands.
  use_adb_server_port_from_os_env: bool = False
@dataclasses.dataclass
class EmulatorLauncherConfig:
  """Config class for EmulatorLauncher."""

  # NOTE: If `adb_port`, `emulator_console_port` and `grpc_port` are defined
  # (i.e. not all equal to 0), it is assumed that the emulator they point to
  # exists already and EmulatorLauncher will be skipped.

  # Filesystem path to the `emulator` binary.
  emulator_path: str = '~/Android/Sdk/emulator/emulator'
  # Filesystem path to the Android SDK root.
  android_sdk_root: str = '~/Android/Sdk'
  # Name of the AVD.
  avd_name: str = ''
  # Local directory for AVDs.
  android_avd_home: str = '~/.android/avd'
  # Name of the snapshot to load.
  snapshot_name: str = ''
  # Path to the KVM device.
  kvm_device: str = '/dev/kvm'
  # Path to directory which will hold temporary files.
  tmp_dir: str = '/tmp/android_env/simulator/'
  # GPU mode override. Stored as the string value of a `GPUMode` member
  # (the default is `GPUMode.SWANGLE_INDIRECT.value`).
  # Please see
  # https://developer.android.com/studio/run/emulator-acceleration#accel-graphics.
  gpu_mode: str = GPUMode.SWANGLE_INDIRECT.value
  # Whether to run in headless mode (i.e. without a graphical window).
  run_headless: bool = True
  # Whether to restrict network access.
  # If True, will disable networking on the device. This option is only
  # available for emulator version > 31.3.9 (June 2022).
  restrict_network: bool = False
  # Whether to set `SHOW_PERF_STATS=1` when launching the emulator to display
  # performance and memory statistics.
  show_perf_stats: bool = False

  # The three ports below default to 0, meaning "not defined" (see the NOTE at
  # the top of this class).
  # ADB port for the Android device.
  adb_port: int = 0
  # Port for telnet communication with the emulator.
  emulator_console_port: int = 0
  # Port for gRPC communication with the emulator.
  grpc_port: int = 0
@dataclasses.dataclass
class AndroidEnvConfig:
  """Config class for AndroidEnv.

  Aggregates the configs of the main components. Every field uses a
  `default_factory`, so each `AndroidEnvConfig` instance gets its own fresh
  sub-config objects rather than sharing mutable defaults.
  """

  # Configs for main components.
  # Task to load; defaults to the base `TaskConfig`.
  task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
  # Settings for the TaskManager.
  task_manager: TaskManagerConfig = dataclasses.field(
      default_factory=TaskManagerConfig
  )
  # Settings for the Coordinator.
  coordinator: CoordinatorConfig = dataclasses.field(
      default_factory=CoordinatorConfig
  )
  # Simulator settings. Annotated with the base `SimulatorConfig`, but the
  # default instance is an `EmulatorConfig`.
  simulator: SimulatorConfig = dataclasses.field(default_factory=EmulatorConfig)
class Coordinator:
  """Handles interaction between internal components of AndroidEnv.

  The Coordinator owns the simulator lifecycle (launch / relaunch), forwards
  agent actions to the simulator, gathers observations from it and delegates
  RL bookkeeping to the TaskManager.
  """

  def __init__(
      self,
      simulator: base_simulator.BaseSimulator,
      task_manager: task_manager_lib.TaskManager,
      device_settings: device_settings_lib.DeviceSettings,
      config: config_classes.CoordinatorConfig | None = None,
  ):
    """Handles communication between AndroidEnv and its components.

    Launches the simulator as part of construction.

    Args:
      simulator: A BaseSimulator instance.
      task_manager: The TaskManager, responsible for coordinating RL tasks.
      device_settings: The DeviceSettings used to apply device configuration
        and to query the screen dimensions and orientation.
      config: Settings to customize this Coordinator. A default
        `CoordinatorConfig` is used when omitted.

    Raises:
      errors.TooManyRestartsError: If the simulator cannot be launched (see
        `_launch_simulator`).
    """
    self._simulator = simulator
    self._task_manager = task_manager
    self._config = config or config_classes.CoordinatorConfig()
    self._device_settings = device_settings
    # Created in `_launch_simulator()` once the simulator is up.
    self._adb_call_parser: adb_call_parser.AdbCallParser = None

    # Initialize stats.
    self._stats = {
        'relaunch_count': 0,
        'relaunch_count_periodic': 0,
        'relaunch_count_setup_steps': 0,
        'relaunch_count_reset_steps': 0,
        'relaunch_count_simulator_launch': 0,
        'relaunch_count_simulator_reset': 0,
        'relaunch_count_execute_action': 0,
        'relaunch_count_fetch_observation': 0,
        'relaunch_count_update_settings': 0,
        'failed_task_updates': 0,
    }

    # Initialize counters.
    self._simulator_healthy = False
    self._latest_observation_time = 0
    self._simulator_start_time = None

    logging.info('Starting the simulator...')
    self._launch_simulator()

  def action_spec(self) -> dict[str, dm_env.specs.Array]:
    """Returns the action spec for the configured number of fingers/keys."""
    return specs.base_action_spec(
        num_fingers=self._config.num_fingers,
        enable_key_events=self._config.enable_key_events,
    )

  def observation_spec(self) -> dict[str, dm_env.specs.Array]:
    """Returns the observation spec for the current screen dimensions."""
    return specs.base_observation_spec(
        height=self._device_settings.screen_height(),
        width=self._device_settings.screen_width(),
    )

  def _should_periodic_relaunch(self) -> bool:
    """Checks if it is time to restart the simulator.

    If a periodic restart time was specified, the Coordinator will re-launch the
    simulator at regular time intervals. This helps to make sure that the
    simulator is not in a stale state even if the environment has been running
    for a significant amount of time.

    Returns:
      Boolean indicating if it is time to restart the simulator.
    """
    if self._config.periodic_restart_time_min and self._simulator_start_time:
      sim_alive_time = (time.time() - self._simulator_start_time) / 60.0
      logging.info('Simulator has been running for %f mins', sim_alive_time)
      if sim_alive_time > self._config.periodic_restart_time_min:
        logging.info('Maximum alive time reached. Restarting simulator.')
        self._stats['relaunch_count_periodic'] += 1
        return True
    return False

  def _launch_simulator(self, max_retries: int = 3):
    """Launches the simulator.

    Sets up the simulator and other task-related settings.

    Args:
      max_retries: Number of times to attempt a restart before raising an error.

    Raises:
      errors.TooManyRestartsError: If no attempt succeeded. The exception is
        chained (`__cause__`) to the error of the last failed attempt, if any.
    """
    self._simulator_healthy = False

    # Error of the most recent failed attempt; becomes the `__cause__` of the
    # `TooManyRestartsError` raised when all attempts are exhausted.
    latest_error = None
    for num_tries in range(1, max_retries + 1):
      logging.info('Simulator launch attempt %d of %d', num_tries, max_retries)

      self._task_manager.stop()

      # Launch the simulator.
      self._simulator.launch()
      self._simulator_start_time = time.time()

      # From here on, the simulator is assumed to be up and running.
      self._adb_call_parser = self._create_adb_call_parser()
      try:
        self._device_settings.update(self._config.device_settings)
      except errors.AdbControllerError as error:
        logging.exception('device_settings.update() failed.')
        self._stats['relaunch_count_update_settings'] += 1
        # Record in the local variable (not on `self`) so that this failure is
        # the one reported as the cause if we run out of attempts.
        latest_error = error
        continue

      # Start the task.
      self._task_manager.start(
          adb_call_parser_factory=self._create_adb_call_parser,
          log_stream=self._simulator.create_log_stream(),
      )
      try:
        self._task_manager.setup_task()
      except errors.StepCommandError as error:
        logging.exception('Failed to set up the task. Restarting simulator.')
        self._stats['relaunch_count_setup_steps'] += 1
        latest_error = error
        continue

      # Restart was successful.
      self._simulator_healthy = True
      self._stats['relaunch_count'] += 1
      return

    raise errors.TooManyRestartsError(
        'Maximum number of restart attempts reached.'
    ) from latest_error

  def _create_adb_call_parser(self):
    """Creates a new AdbCallParser instance."""
    return adb_call_parser.AdbCallParser(
        adb_controller=self._simulator.create_adb_controller()
    )

  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
    """Forwards `call` to the current AdbCallParser and returns its response."""
    return self._adb_call_parser.parse(call)

  def rl_reset(self) -> dm_env.TimeStep:
    """Resets the RL episode."""

    # Relaunch the simulator if necessary.
    if not self._simulator_healthy or self._should_periodic_relaunch():
      self._launch_simulator()

    # Reset counters.
    self._latest_observation_time = 0
    for key in self._stats:
      if key.startswith('episode'):
        self._stats[key] = 0.0

    # Execute a lift action before resetting the task.
    if not action_fns.send_action_to_simulator(
        action_fns.lift_all_fingers_action(self._config.num_fingers),
        self._simulator,
        self._device_settings.screen_width(),
        self._device_settings.screen_height(),
        self._config.num_fingers,
    ):
      # Mark as unhealthy so that the next reset relaunches the simulator.
      self._stats['relaunch_count_execute_action'] += 1
      self._simulator_healthy = False

    # Reset the task.
    self._task_manager.reset_task()
    self._device_settings.get_orientation()

    # Get data from the simulator.
    simulator_signals = self._gather_simulator_signals()

    return self._task_manager.rl_reset(simulator_signals)

  def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Executes the selected action and returns a timestep.

    Args:
      agent_action: Selected action to perform on the simulated Android device.
        If `agent_action` is `None` it means that this is an RL reset (to start
        a new episode).

    Returns:
      An RL timestep. If sending the action or reading the observation failed,
      a truncation timestep with zero reward and no observation is returned
      and the simulator is marked unhealthy.
    """

    if not action_fns.send_action_to_simulator(
        agent_action,
        self._simulator,
        self._device_settings.screen_width(),
        self._device_settings.screen_height(),
        self._config.num_fingers,
    ):
      self._stats['relaunch_count_execute_action'] += 1
      self._simulator_healthy = False

    # Get data from the simulator.
    try:
      simulator_signals = self._gather_simulator_signals()
    except errors.ReadObservationError:
      logging.exception('Unable to fetch observation. Restarting simulator.')
      self._stats['relaunch_count_fetch_observation'] += 1
      self._simulator_healthy = False

    # `simulator_signals` is only unbound when the read above failed, in which
    # case the simulator is unhealthy and we return before using it.
    if not self._simulator_healthy:
      return dm_env.truncation(reward=0.0, observation=None)

    return self._task_manager.rl_step(simulator_signals)

  def _gather_simulator_signals(self) -> dict[str, np.ndarray]:
    """Gathers data from various sources to assemble the RL observation."""

    # Get current timestamp and update the delta. The delta is 0 for the first
    # observation after a reset; otherwise seconds are scaled by 1e6.
    now = time.time()
    timestamp_delta = (
        0
        if self._latest_observation_time == 0
        else (now - self._latest_observation_time) * 1e6
    )
    self._latest_observation_time = now

    return {
        'pixels': self._simulator.get_screenshot(),
        'orientation': self._device_settings.get_orientation(),
        'timedelta': np.array(timestamp_delta, dtype=np.int64),
    }

  def __del__(self):
    self.close()

  def stats(self) -> dict[str, Any]:
    """Returns various statistics."""

    # Deep copy so that callers cannot mutate the internal counters.
    return copy.deepcopy(self._stats)

  def close(self):
    """Cleans up the state of this Coordinator."""

    # `close()` is also called from `__del__`, possibly on a partially
    # constructed object, hence the `hasattr` guards. Cleanup is best-effort:
    # failures are logged and do not prevent the remaining cleanup.
    if hasattr(self, '_task_manager'):
      try:
        self._task_manager.stop()
      except:  # pylint: disable=bare-except
        logging.exception('Failed to stop task manager. Continuing.')

    if hasattr(self, '_simulator'):
      try:
        self._simulator.close()
      except:  # pylint: disable=bare-except
        logging.exception('Failed to close simulator. Continuing.')
class CoordinatorTest(parameterized.TestCase):
  """Tests for `Coordinator` using autospec'd simulator and task manager."""

  def setUp(self):
    super().setUp()
    self.addCleanup(mock.patch.stopall)  # Disable previous patches.

    # The simulator mock always returns the same random 800x600 RGB screenshot
    # so that tests can compare observations against it.
    self._simulator = mock.create_autospec(base_simulator.BaseSimulator)
    self._random_screenshot = np.random.randint(
        low=0, high=255, size=(800, 600, 3), dtype=np.uint8)
    self._simulator.get_screenshot.return_value = self._random_screenshot
    self._task_manager = mock.create_autospec(task_manager.TaskManager)
    # Replace the `AdbCallParser` class so that any instance created through
    # `adb_call_parser.AdbCallParser(...)` is this mock.
    self._adb_call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
    self.enter_context(
        mock.patch.object(
            adb_call_parser,
            'AdbCallParser',
            autospec=True,
            return_value=self._adb_call_parser))
    # Constructing the Coordinator already launches the (mock) simulator once.
    self._coordinator = coordinator_lib.Coordinator(
        simulator=self._simulator,
        task_manager=self._task_manager,
        device_settings=device_settings_lib.DeviceSettings(self._simulator),
    )

  def tearDown(self):
    super().tearDown()
    self._coordinator.close()

  # An explicit launch increments 'relaunch_count' by exactly one.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_relaunch_simulator(self, unused_mock_sleep):
    relaunch_count = self._coordinator.stats()['relaunch_count']
    self._coordinator._launch_simulator()
    self.assertEqual(self._coordinator.stats()['relaunch_count'],
                     relaunch_count + 1)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_reset(self, unused_mock_sleep):
    """'relaunch_count_execute_action' should be zero if there's no error."""

    self._coordinator.rl_reset()
    stats = self._coordinator.stats()
    self.assertIn('relaunch_count_execute_action', stats)
    self.assertEqual(stats['relaunch_count_execute_action'], 0)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_reset_error_sending_action(self, unused_mock_sleep):
    """'relaunch_count_execute_action' should be positive if there's an error."""

    # Make the lift action sent during reset fail.
    self._simulator.send_touch.side_effect = errors.SendActionError()
    self._coordinator.rl_reset()
    stats = self._coordinator.stats()
    self.assertIn('relaunch_count_execute_action', stats)
    self.assertEqual(stats['relaunch_count_execute_action'], 1)

  # With three fingers configured, a reset sends a lift for each finger.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_lift_all_fingers(self, unused_mock_sleep):
    self._coordinator = coordinator_lib.Coordinator(
        simulator=self._simulator,
        task_manager=self._task_manager,
        device_settings=device_settings_lib.DeviceSettings(self._simulator),
        config=config_classes.CoordinatorConfig(num_fingers=3),
    )
    self._coordinator.rl_reset()
    expected_actions = [
        # (x, y, is_down, identifier).
        (0, 0, False, 0),
        (0, 0, False, 1),
        (0, 0, False, 2),
    ]
    # First positional argument of the most recent `send_touch` call.
    actual_actions = self._simulator.send_touch.call_args[0][0]
    for actual, expected in zip(actual_actions, expected_actions):
      np.testing.assert_array_equal(actual, expected)

  # A successful step returns whatever timestep the task manager produces from
  # the gathered simulator signals.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_process_action(self, unused_mock_sleep):

    def fake_rl_step(simulator_signals):
      return dm_env.transition(
          reward=10.0,
          observation={
              'pixels': simulator_signals['pixels'],
              'orientation': simulator_signals['orientation'],
              'timedelta': simulator_signals['timedelta'],
              'extras': {
                  'extra': [0.0]
              }
          })

    self._task_manager.rl_step.side_effect = fake_rl_step
    timestep = self._coordinator.rl_step(
        agent_action={
            'action_type': np.array(action_type.ActionType.LIFT),
            'touch_position': np.array([0.5, 0.5]),
        })
    obs = timestep.observation
    self.assertEqual(obs['pixels'].shape, (800, 600, 3))
    np.testing.assert_equal(obs['orientation'],
                            np.array([0, 0, 0, 0], dtype=np.uint8))
    self.assertEqual(timestep.reward, 10.0)
    self.assertEqual(obs['extras'], {'extra': [0.0]})
    self.assertFalse(timestep.last())

  # A failure to read the screenshot yields a truncation timestep.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_process_action_error(self, unused_mock_sleep):

    # NOTE(review): `Coordinator.rl_step` returns a truncation without calling
    # `task_manager.rl_step` when the observation cannot be read, so this fake
    # (and its assertion on a 'simulator_healthy' key) looks unreachable --
    # confirm and consider asserting `rl_step.assert_not_called()` instead.
    def fake_rl_step(simulator_signals):
      self.assertFalse(simulator_signals['simulator_healthy'])
      return dm_env.truncation(reward=0.0, observation=None)

    self._task_manager.rl_step.side_effect = fake_rl_step
    self._simulator.get_screenshot.side_effect = errors.ReadObservationError()
    timestep = self._coordinator.rl_step(
        agent_action={
            'action_type': np.array(action_type.ActionType.LIFT),
            'touch_position': np.array([0.5, 0.5]),
        })
    self.assertIsNone(timestep.observation)
    self.assertEqual(timestep.reward, 0.0)
    self.assertTrue(timestep.last())

  # A TOUCH at the normalized position (0.5, 0.5) is sent to the simulator as
  # the pixel position (300, 400) with `is_down=True` for finger 0.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_action_touch(self, unused_mock_sleep):

    def fake_rl_step(simulator_signals):
      return dm_env.transition(
          reward=123.0,
          observation={
              'pixels': simulator_signals['pixels'],
              'orientation': simulator_signals['orientation'],
              'timedelta': simulator_signals['timedelta'],
              'extras': {
                  'extra': [0.0]
              }
          })

    self._task_manager.rl_step.side_effect = fake_rl_step
    timestep = self._coordinator.rl_step(
        agent_action={
            'action_type': np.array(action_type.ActionType.TOUCH),
            'touch_position': np.array([0.5, 0.5])
        })
    self.assertEqual(timestep.reward, 123.0)
    np.testing.assert_equal(timestep.observation['pixels'],
                            self._random_screenshot)
    self._simulator.send_touch.assert_called_once_with([(300, 400, True, 0)])

  # With three fingers, the per-finger action keys (`*_2`, `*_3`) are combined
  # into a single `send_touch` call with one tuple per finger.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_multitouch_action(self, unused_mock_sleep):
    self._coordinator = coordinator_lib.Coordinator(
        simulator=self._simulator,
        task_manager=self._task_manager,
        device_settings=device_settings_lib.DeviceSettings(self._simulator),
        config=config_classes.CoordinatorConfig(num_fingers=3),
    )

    def fake_rl_step(simulator_signals):
      return dm_env.transition(
          reward=456.0,
          observation={
              'pixels': simulator_signals['pixels'],
              'orientation': simulator_signals['orientation'],
              'timedelta': simulator_signals['timedelta'],
              'extras': {
                  'extra': [0.0]
              }
          })

    self._task_manager.rl_step.side_effect = fake_rl_step
    action = {
        'action_type': np.array([action_type.ActionType.TOUCH]),
        'touch_position': np.array([0.25, 0.75]),
        'action_type_2': np.array([action_type.ActionType.TOUCH]),
        'touch_position_2': np.array([0.75, 0.25]),
        'action_type_3': np.array([action_type.ActionType.LIFT]),
        'touch_position_3': np.array([0.5, 0.5]),
    }
    timestep = self._coordinator.rl_step(action)
    self._simulator.send_touch.assert_called_once_with([(150, 600, True, 0),
                                                        (450, 200, True, 1),
                                                        (300, 400, False, 2)])
    self.assertEqual(timestep.reward, 456.0)
    np.testing.assert_equal(timestep.observation['pixels'],
                            self._random_screenshot)

  # A REPEAT action must not send any touch event to the simulator.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_action_repeat(self, unused_mock_sleep):

    def fake_rl_step(simulator_signals):
      return dm_env.transition(
          reward=10.0,
          observation={
              'pixels': simulator_signals['pixels'],
              'orientation': simulator_signals['orientation'],
              'timedelta': simulator_signals['timedelta'],
              'extras': {
                  'extra': [0.0]
              }
          })

    self._task_manager.rl_step.side_effect = fake_rl_step
    timestep = self._coordinator.rl_step(
        {'action_type': np.array(action_type.ActionType.REPEAT)})
    self._simulator.send_touch.assert_not_called()
    np.testing.assert_equal(timestep.observation['pixels'],
                            self._random_screenshot)

  # A failure while sending the action yields a timestep with no observation.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_action_error(self, unused_mock_sleep):

    # NOTE(review): presumably never invoked, since `Coordinator.rl_step`
    # returns a truncation itself once the simulator is marked unhealthy --
    # confirm.
    def fake_rl_step(simulator_signals):
      self.assertFalse(simulator_signals['simulator_healthy'])
      return dm_env.truncation(reward=0.0, observation=None)

    self._task_manager.rl_step.side_effect = fake_rl_step
    self._simulator.send_touch.side_effect = errors.SendActionError
    timestep = self._coordinator.rl_step({
        'action_type': np.array(action_type.ActionType.TOUCH),
        'touch_position': np.array([0.3, 0.8])
    })
    self.assertIsNone(timestep.observation)

  # If task setup always fails, launching gives up with TooManyRestartsError.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_max_restarts_setup_steps(self, unused_mock_sleep):
    # `setup_task` was already called while constructing the Coordinator in
    # `setUp()`, so compare against the count at this point.
    init_fn_call = self._task_manager.setup_task.call_count
    self._task_manager.setup_task.side_effect = errors.StepCommandError
    self.assertRaises(errors.TooManyRestartsError,
                      self._coordinator._launch_simulator)
    # The method was called three more times when attempting to relaunch.
    self.assertEqual(init_fn_call + 3,
                     self._task_manager.setup_task.call_count)

  # `execute_adb_call` forwards the request to the parser and returns its
  # response unchanged.
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_adb_call(self, unused_mock_sleep):
    call = adb_pb2.AdbRequest(
        force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah'))
    expected_response = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    self._adb_call_parser.parse.side_effect = [expected_response]

    response = self._coordinator.execute_adb_call(call)

    self.assertEqual(response, expected_response)
    self._adb_call_parser.parse.assert_called_with(call)
class DeviceSettings:
  """An abstraction for general properties and settings of an Android device.

  The `AdbCallParser` used to talk to the device is created lazily, on the first
  call to `update()` or `get_orientation()`, so constructing this object does
  not create an ADB controller.
  """

  def __init__(self, simulator: base_simulator.BaseSimulator):
    """Initializes this object without contacting the device.

    Args:
      simulator: The simulator used to take screenshots and to create the ADB
        controller on first use.
    """
    self._simulator = simulator
    # Sentinel of the right type; replaced by a real parser on first use.
    self._adb_call_parser = _PLACEHOLDER_ADB_CALL_PARSER

    # The size of the device screen in pixels.
    self._screen_width: int = 0
    self._screen_height: int = 0
    # The device orientation in one-hot format. All zeros means "not fetched".
    self._orientation = np.zeros(4, dtype=np.uint8)

  def update(self, config: config_classes.DeviceSettingsConfig) -> None:
    """Sets the configuration of the device according to `config`."""

    self._ensure_adb_call_parser()
    self._update_screen_size()
    self._set_show_touches(config.show_touches)
    self._set_show_pointer_location(config.show_pointer_location)
    self._set_status_navigation_bars(
        config.show_navigation_bar, config.show_status_bar
    )

  def screen_width(self) -> int:
    """The screen width in pixels. Only valid after `update()` is called."""

    return self._screen_width

  def screen_height(self) -> int:
    """The screen height in pixels. Only valid after `update()` is called."""

    return self._screen_height

  def get_orientation(self) -> np.ndarray:
    """Returns the device orientation. Please see specs.py for details."""

    self._ensure_adb_call_parser()
    self._update_orientation()
    return self._orientation

  def _ensure_adb_call_parser(self) -> None:
    """Replaces the placeholder parser with a real one on first use."""

    if self._adb_call_parser is _PLACEHOLDER_ADB_CALL_PARSER:
      self._adb_call_parser = adb_call_parser.AdbCallParser(
          adb_controller=self._simulator.create_adb_controller()
      )

  def _put_setting(self, name_space, key: str, value: str) -> None:
    """Issues a `settings put` request for `key` in `name_space`."""

    self._adb_call_parser.parse(
        adb_pb2.AdbRequest(
            settings=adb_pb2.AdbRequest.SettingsRequest(
                name_space=name_space,
                put=adb_pb2.AdbRequest.SettingsRequest.Put(
                    key=key, value=value
                ),
            )
        )
    )

  def _update_screen_size(self) -> None:
    """Sets the screen size from a screenshot ignoring the color channel."""

    screenshot = self._simulator.get_screenshot()
    # Only the first two axes are read; they are taken as (height, width).
    self._screen_height = screenshot.shape[0]
    self._screen_width = screenshot.shape[1]

  def _set_show_touches(self, show: bool) -> None:
    """Whether to display circles indicating the touch position."""

    self._put_setting(
        adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
        key='show_touches',
        value='1' if show else '0',
    )

  def _set_show_pointer_location(self, show: bool) -> None:
    """Whether to display blue lines on the screen indicating touch position."""

    self._put_setting(
        adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
        key='pointer_location',
        value='1' if show else '0',
    )

  def _set_status_navigation_bars(
      self, show_navigation: bool, show_status: bool
  ) -> None:
    """Whether to display the status (top) and navigation (bottom) bars."""

    # Keyed by (show_navigation, show_status).
    policy_control_values = {
        (True, True): 'null*',
        (True, False): 'immersive.status=*',
        (False, True): 'immersive.navigation=*',
        (False, False): 'immersive.full=*',
    }
    policy_control_value = policy_control_values[
        (bool(show_navigation), bool(show_status))
    ]

    self._put_setting(
        adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
        key='policy_control',
        value=policy_control_value,
    )

  def _update_orientation(self) -> None:
    """Updates the current device orientation.

    The orientation is fetched at most once: as soon as a valid value has been
    stored, later calls return without contacting the device. Bad responses
    are logged and leave the stored orientation untouched.
    """

    # Skip fetching the orientation if we already have it (any non-zero entry).
    if self._orientation.any():
      return

    orientation_response = self._adb_call_parser.parse(
        adb_pb2.AdbRequest(
            get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()
        )
    )

    if orientation_response.status != adb_pb2.AdbResponse.Status.OK:
      logging.error('Got bad orientation: %r', orientation_response)
      return

    orientation = orientation_response.get_orientation.orientation
    if orientation not in {0, 1, 2, 3}:
      logging.error('Got bad orientation: %r', orientation)
      return

    # Transform into one-hot format.
    orientation_onehot = np.zeros([4], dtype=np.uint8)
    orientation_onehot[orientation] = 1
    self._orientation = orientation_onehot
from unittest import mock from absl.testing import absltest from absl.testing import parameterized from android_env.components import config_classes from android_env.components import device_settings as device_settings_lib from android_env.components.simulators import base_simulator import numpy as np class DeviceSettingsTest(parameterized.TestCase): def test_screen_size_before_update(self): """The screen size should be 0x0 without calling `update()`.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) device_settings = device_settings_lib.DeviceSettings(simulator) # Act. height = device_settings.screen_height() width = device_settings.screen_width() # Assert. self.assertEqual(height, 0) self.assertEqual(width, 0) def test_screen_size_after_update(self): """The screen size should be set after calling `update()`.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) simulator.get_screenshot.return_value = np.random.randint( low=0, high=255, size=(123, 456, 3), dtype=np.uint8 ) adb_controller = simulator.create_adb_controller.return_value adb_controller.execute_command.return_value = b'' device_settings = device_settings_lib.DeviceSettings(simulator) # Act. device_settings.update(config_classes.DeviceSettingsConfig()) height = device_settings.screen_height() width = device_settings.screen_width() # Assert. 
self.assertEqual(height, 123) self.assertEqual(width, 456) @parameterized.named_parameters( ( 'show_touches', config_classes.DeviceSettingsConfig(show_touches=True), mock.call( ['shell', 'settings', 'put', 'system', 'show_touches', '1'], timeout=None, ), ), ( 'show_touches_false', config_classes.DeviceSettingsConfig(show_touches=False), mock.call( ['shell', 'settings', 'put', 'system', 'show_touches', '0'], timeout=None, ), ), ( 'show_pointer_location', config_classes.DeviceSettingsConfig(show_pointer_location=True), mock.call( ['shell', 'settings', 'put', 'system', 'pointer_location', '1'], timeout=None, ), ), ( 'show_pointer_location_false', config_classes.DeviceSettingsConfig(show_pointer_location=False), mock.call( ['shell', 'settings', 'put', 'system', 'pointer_location', '0'], timeout=None, ), ), ( 'show_navigation_and_status', config_classes.DeviceSettingsConfig( show_navigation_bar=True, show_status_bar=True ), mock.call( ['shell', 'settings', 'put', 'global', 'policy_control', 'null*'], timeout=None, ), ), ( 'show_navigation_and_no_status', config_classes.DeviceSettingsConfig( show_navigation_bar=True, show_status_bar=False ), mock.call( [ 'shell', 'settings', 'put', 'global', 'policy_control', 'immersive.status=*', ], timeout=None, ), ), ( 'show_no_navigation_and_status', config_classes.DeviceSettingsConfig( show_navigation_bar=False, show_status_bar=True ), mock.call( [ 'shell', 'settings', 'put', 'global', 'policy_control', 'immersive.navigation=*', ], timeout=None, ), ), ( 'show_no_navigation_and_no_status', config_classes.DeviceSettingsConfig( show_navigation_bar=False, show_status_bar=False ), mock.call( [ 'shell', 'settings', 'put', 'global', 'policy_control', 'immersive.full=*', ], timeout=None, ), ), ) def test_update( self, settings: config_classes.DeviceSettingsConfig, expected_call ): """We expect the right call for each setting.""" # Arrange. 
simulator = mock.create_autospec(base_simulator.BaseSimulator) adb_controller = simulator.create_adb_controller.return_value adb_controller.execute_command.return_value = b'' device_settings = device_settings_lib.DeviceSettings(simulator) # Act. device_settings.update(settings) # Assert. adb_controller.execute_command.assert_has_calls( [expected_call], any_order=True ) def test_get_orientation_bad_response(self): """The orientation should be unset if the underlying response is bad.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) adb_controller = simulator.create_adb_controller.return_value adb_controller.execute_command.return_value = b'' device_settings = device_settings_lib.DeviceSettings(simulator) # Act. orientation = device_settings.get_orientation() # Assert. np.testing.assert_array_equal(orientation, np.zeros(4)) def test_get_orientation_bad_orientation(self): """The orientation should be unset if the underlying orientation is bad.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) adb_controller = simulator.create_adb_controller.return_value adb_controller.execute_command.return_value = b' InputDeviceOrientation: 9' device_settings = device_settings_lib.DeviceSettings(simulator) # Act. orientation = device_settings.get_orientation() # Assert. np.testing.assert_array_equal(orientation, np.zeros(4)) def test_get_orientation_success(self): """Checks that the orientation comes back as expected.""" # Arrange. simulator = mock.create_autospec(base_simulator.BaseSimulator) adb_controller = simulator.create_adb_controller.return_value adb_controller.execute_command.return_value = b' InputDeviceOrientation: 3' device_settings = device_settings_lib.DeviceSettings(simulator) # Act. orientation = device_settings.get_orientation() # The output should be idempotent if the underlying system did not change. orientation_again = device_settings.get_orientation() # Assert. 
class DumpsysThread:
  """A thread that checks if the user is in the expected app screen."""

  def __init__(
      self,
      app_screen_checker: app_screen_checker_lib.AppScreenChecker,
      check_frequency: int = 10,
      max_failed_current_activity: int = 10,
  ):
    """Initializes the dumpsys reader thread.

    This loops forever checking if the user is in the expected screen dictated
    by `app_screen_checker`. These analyses are too expensive to be in the
    critical path of AndroidEnv::step() so we consume them async from this
    separate thread.

    Args:
      app_screen_checker: The class that actually determines if the current
        screen matches the expected screen.
      check_frequency: Integer. We only call dumpsys 1/check_frequency times in
        each iteration of the while loop below.
      max_failed_current_activity: Integer. We try to fetch the current activity
        but sometimes it fails. If it fails more than
        `max_failed_current_activity` consecutive times, we declare that the
        user has exited `expected_activity`.
    """

    self._app_screen_checker = app_screen_checker
    self._main_loop_counter = 0
    self._check_frequency = check_frequency
    self._max_failed_activity_extraction = max_failed_current_activity
    self._num_failed_activity_extraction = 0
    # The check currently in flight (or finished but not yet consumed).
    self._latest_check: concurrent.futures.Future | None = None

  def check_user_exited(self, timeout: float | None = None) -> bool:
    """Returns True if the user is not in the expected screen.

    Only every `check_frequency`-th call does any work. Such a call either
    starts a background check (and returns `False` right away), or consumes the
    result of the check started by an earlier call.

    Args:
      timeout: An optional time in seconds to block waiting for the result of
        the (expensive) checking operation. If None (or 0) and the check has
        not finished yet, the function returns immediately with `False`. If the
        check does not finish within `timeout`, `False` is returned and the
        check is left running for a later call to pick up.

    Returns:
      Whether the user of the Android device has exited the expected screen
      determined by `AppScreenChecker` given at __init__().
    """

    # Update and check loop_counter against check_frequency.
    self._main_loop_counter += 1
    if (self._check_frequency <= 0 or
        self._main_loop_counter < self._check_frequency):
      return False
    self._main_loop_counter = 0

    # If the latest check is None, perform a check and return.
    if self._latest_check is None:
      # NOTE: Do not use the executor as a context manager here. Leaving a
      # `with` block calls `shutdown(wait=True)`, which would block this call
      # until the expensive check is done. `shutdown(wait=False)` lets the
      # submitted check finish on its own and then frees the worker.
      executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
      self._latest_check = executor.submit(self._check_impl)
      executor.shutdown(wait=False)
      return False

    # If there's a check in flight, continue only if it's finished.
    if not timeout and not self._latest_check.done():
      return False

    try:
      user_exited = self._latest_check.result(timeout=timeout)
    except concurrent.futures.TimeoutError:
      # Still running: keep the future so that a later call can consume it.
      return False
    self._latest_check = None  # Reset the check.
    return user_exited

  def _check_impl(self) -> bool:
    """The synchronous implementation of Dumpsys.

    Returns:
      True if the user is considered to have left the expected screen, either
      because the checker reported an unexpected activity or view hierarchy, or
      because extracting the activity failed
      `max_failed_current_activity` times in a row.
    """

    outcome = self._app_screen_checker.matches_current_app_screen()

    # We were unable to determine the current activity.
    if outcome == _Outcome.FAILED_ACTIVITY_EXTRACTION:
      self._num_failed_activity_extraction += 1
      logging.info('self._num_failed_activity_extraction: %s',
                   self._num_failed_activity_extraction)
      if (self._num_failed_activity_extraction >=
          self._max_failed_activity_extraction):
        logging.error('Maximum number of failed activity extraction reached.')
        self._num_failed_activity_extraction = 0
        return True
    else:
      # Any other outcome breaks the streak of consecutive failures.
      self._num_failed_activity_extraction = 0

    # Only these two outcomes mean that the player has exited the app or the
    # main game. Everything else (including SUCCESS and
    # EMPTY_EXPECTED_ACTIVITY) keeps the episode going.
    return outcome in (
        _Outcome.UNEXPECTED_ACTIVITY,
        _Outcome.UNEXPECTED_VIEW_HIERARCHY,
    )
"""Tests for android_env.components.dumpsys_thread.""" from unittest import mock from absl.testing import absltest from android_env.components import app_screen_checker as screen_checker from android_env.components import dumpsys_thread class DumpsysThreadTest(absltest.TestCase): def setUp(self): super().setUp() self._app_screen_checker = mock.create_autospec( screen_checker.AppScreenChecker) def test_unexpected_activity(self): dumpsys = dumpsys_thread.DumpsysThread( app_screen_checker=self._app_screen_checker, check_frequency=1) outcome = screen_checker.AppScreenChecker.Outcome.UNEXPECTED_ACTIVITY self._app_screen_checker.matches_current_app_screen.return_value = outcome # The first time that `check_user_exited()` is called, it'll only trigger # the processing, but it should return immediately. self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) # The second time it should then wait for the result. self.assertTrue(dumpsys.check_user_exited(timeout=1.0)) def test_unexpected_view_hierarchy(self): dumpsys = dumpsys_thread.DumpsysThread( app_screen_checker=self._app_screen_checker, check_frequency=1) outcome = screen_checker.AppScreenChecker.Outcome.UNEXPECTED_VIEW_HIERARCHY self._app_screen_checker.matches_current_app_screen.return_value = outcome self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) self.assertTrue(dumpsys.check_user_exited(timeout=1.0)) def test_success(self): dumpsys = dumpsys_thread.DumpsysThread( app_screen_checker=self._app_screen_checker, check_frequency=1) outcome = screen_checker.AppScreenChecker.Outcome.SUCCESS self._app_screen_checker.matches_current_app_screen.return_value = outcome self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) def test_skipped(self): dumpsys = dumpsys_thread.DumpsysThread( app_screen_checker=self._app_screen_checker, check_frequency=5) self._app_screen_checker.matches_current_app_screen.side_effect = [ 
screen_checker.AppScreenChecker.Outcome.SUCCESS, screen_checker.AppScreenChecker.Outcome.FAILED_ACTIVITY_EXTRACTION ] for _ in range(17): self.assertFalse(dumpsys.check_user_exited(timeout=1.0)) # The first 4 calls will hit the early exit from `check_frequency`. # The 5th call will trigger the processing (increasing the call count to # matches_current_app_screen() by 1), but it should return early. # The 10th call will find a result of the previous processing, and it should # be SUCCESS. # The next 4 calls (11, 12, 13, 14) will hit the early exit from # `check_frequency`. # The 15th call should trigger the processing again (increasing the call # count to matches_current_app_screen() by 1), but it should return early. # The next 2 calls (16, 17) will hit the early exit from `check_frequency`. # In total there should be only two calls to `matches_current_app_screen()`. expected_call_count = 2 self.assertEqual( self._app_screen_checker.matches_current_app_screen.call_count, expected_call_count) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/components/errors.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Definitions of exceptions used by AndroidEnv.""" class AndroidEnvError(Exception): """Base class for all known errors generated by AndroidEnv.""" # An integer that identifies this class of error. 
def from_code(code: int, msg: str = '') -> AndroidEnvError | None:
  """Returns an AndroidEnvError instance from the given arguments.

  Args:
    code: The `ERROR_CODE` of the error class to instantiate.
    msg: The message passed to the constructor of the error.

  Returns:
    A new instance of the error class whose `ERROR_CODE` equals `code`, or
    `None` if no known error class uses that code.
  """

  known_errors = (
      AndroidEnvError,
      ReadObservationError,
      CoordinatorError,
      TooManyRestartsError,
      AdbControllerError,
      SimulatorError,
      SendActionError,
      StepCommandError,
      WaitForAppScreenError,
      CheckInstallError,
  )
  # Key by each class's own `ERROR_CODE` so that this lookup cannot drift
  # from the values declared on the classes.
  code_to_error = {error.ERROR_CODE: error for error in known_errors}

  error_class = code_to_error.get(code)
  if error_class is None:
    return None
  return error_class(msg)
class ErrorsTest(parameterized.TestCase):
  """Tests for the error classes and `errors.from_code()`."""

  @parameterized.parameters(
      (errors.ReadObservationError, 1),
      (errors.CoordinatorError, 2),
      (errors.TooManyRestartsError, 3),
      (errors.AdbControllerError, 4),
      (errors.SimulatorError, 5),
      (errors.SendActionError, 6),
      (errors.StepCommandError, 7),
      (errors.WaitForAppScreenError, 8),
      (errors.CheckInstallError, 9),
  )
  def test_error_codes(self, error, expected_error_code):
    """Each error class should carry its documented `ERROR_CODE`."""
    with self.assertRaises(error) as context:
      raise error()
    self.assertEqual(context.exception.ERROR_CODE, expected_error_code)

  def test_error_codes_unique(self):
    """No two error classes (including the base) should share a code."""
    error_codes = set()
    errors_list = (
        errors.AndroidEnvError,
        errors.ReadObservationError,
        errors.CoordinatorError,
        errors.TooManyRestartsError,
        errors.AdbControllerError,
        errors.SimulatorError,
        errors.SendActionError,
        errors.StepCommandError,
        errors.WaitForAppScreenError,
        errors.CheckInstallError,
    )
    for error in errors_list:
      self.assertNotIn(error.ERROR_CODE, error_codes)
      error_codes.add(error.ERROR_CODE)

  @parameterized.parameters([
      errors.ReadObservationError(),
      errors.CoordinatorError(),
      errors.TooManyRestartsError(),
      errors.AdbControllerError(),
      errors.SimulatorError(),
      errors.SendActionError(),
      errors.StepCommandError(),
      errors.WaitForAppScreenError(),
      errors.CheckInstallError(),
  ])
  def test_all_errors_are_androidenv_errors(self, error):
    self.assertIsInstance(error, errors.AndroidEnvError)

  @parameterized.named_parameters([
      ('less_than_zero', -1),
      # The largest `ERROR_CODE` is currently `CheckInstallError == 9`.
      ('greater_than_all_errors', 9 + 1),
      ('less_than_zero_float', -3.14159265),
      ('greater_than_all_errors_float', 123.456),
  ])
  def test_from_code_unsupported_code(self, code: int | float):
    """Unsupported codes should produce `None`."""
    self.assertIsNone(errors.from_code(code))

  @parameterized.parameters([
      (-1, None, 'No such error code.'),
      (0, errors.AndroidEnvError, 'hello'),
      (0, errors.AndroidEnvError, ''),
      (1, errors.ReadObservationError, 'Could not read obs.'),
      (2, errors.CoordinatorError, 'Some error'),
      (3, errors.TooManyRestartsError, 'Too many already...'),
      (4, errors.AdbControllerError, 'Some adb error...'),
      (5, errors.SimulatorError, 'Simulator is not coping.'),
      (6, errors.SendActionError, 'Could not send action.'),
      (7, errors.StepCommandError, 'Some issue setting up the task.'),
      (8, errors.WaitForAppScreenError, 'Waited for too long!'),
      (9, errors.CheckInstallError, 'App did not install correctly.'),
  ])
  def test_from_code(
      self,
      code: int,
      expected_class: type[errors.AndroidEnvError] | None,
      msg: str,
  ):
    """`from_code` should produce consistent outputs for known errors."""

    error = errors.from_code(code, msg)

    if expected_class is None:
      # Unknown codes must not produce an error instance.
      self.assertIsNone(error)
      return

    # Fail loudly rather than silently skipping the checks below.
    self.assertIsNotNone(error)
    self.assertIsInstance(error, expected_class)
    self.assertEqual(error.ERROR_CODE, code)
    self.assertEqual(str(error), msg)
class LogStream(metaclass=abc.ABCMeta):
  """Base class for objects exposing a simulator's logs as a stream of lines.

  Subclasses provide the raw lines via `_get_stream_output()` and implement
  `stop_stream()`. Lines are only passed on to the consumer while streaming is
  enabled with `resume_stream()`; it starts out disabled.
  """

  def __init__(self, verbose: bool = False):
    """Creates a stream with no filters and with streaming disabled.

    Args:
      verbose: If True, every line read from the underlying stream is logged.
    """
    self._verbose = verbose
    self._filters = []
    self._should_stream = threading.Event()

  def get_stream_output(self) -> Generator[str, None, None]:
    """Starts log process and yields its lines while streaming is enabled."""
    streaming_enabled = self._should_stream
    for entry in self._get_stream_output():
      if self._verbose:
        logging.info('line: %r', entry)
      # Lines read while streaming is disabled are dropped, not buffered.
      if not streaming_enabled.is_set():
        continue
      yield entry

  @abc.abstractmethod
  def _get_stream_output(self):
    """Starts log process and returns the stream of logs."""

  @abc.abstractmethod
  def stop_stream(self) -> None:
    """Terminates the log stream.

    NOTE: This should only be called _after_ `get_stream_output()`.
    """

  def pause_stream(self) -> None:
    """Disables streaming: no lines are yielded until `resume_stream()`."""
    logging.info('Pausing LogStream.')
    self._should_stream.clear()

  def resume_stream(self) -> None:
    """Enables streaming: lines read from now on are yielded again."""
    logging.info('Resuming LogStream.')
    self._should_stream.set()

  def set_log_filters(self, log_filters: Sequence[str]) -> None:
    """Stores `log_filters`, followed by a trailing `'*:S'` entry."""
    self._filters = [*log_filters, '*:S']
class LogStreamTest(absltest.TestCase):
  """Tests for the filtering and pause/resume behaviour of `LogStream`.

  All comparisons are made on fully materialized lists. Comparing with `zip()`
  would stop at the shorter sequence and therefore pass even if the stream
  yielded too few (or too many) lines.
  """

  def test_get_stream_output(self):
    filter_name = 'AndroidRLTask'
    stream = FakeLogStream(filter_name=filter_name)
    stream.resume_stream()
    expected_lines = [
        f'{filter_name} fake_line_1',
        'fake_line_2',
        f'{filter_name} fake_line_3',
        f'{filter_name} fake_line_4',
        'fake_line_5',
        'fake_line_6',
    ]
    self.assertEqual(list(stream.get_stream_output()), expected_lines)

  def test_set_log_filters(self):
    filter_name = 'AndroidRLTask'
    stream = FakeLogStream(filter_name=filter_name)
    stream.set_log_filters([f'{filter_name}:V'])
    stream.resume_stream()
    expected_lines = [
        f'{filter_name} fake_line_1',
        f'{filter_name} fake_line_3',
        f'{filter_name} fake_line_4',
    ]
    self.assertEqual(list(stream.get_stream_output()), expected_lines)

  def test_pause_resume_stream(self):
    filter_name = 'AndroidRLTask'
    stream = FakeLogStream(filter_name=filter_name)
    stream.resume_stream()
    expected_lines = [
        f'{filter_name} fake_line_1',
        'fake_line_2',
        f'{filter_name} fake_line_3',
        f'{filter_name} fake_line_4',
        'fake_line_5',
        'fake_line_6',
    ]
    self.assertEqual(list(stream.get_stream_output()), expected_lines)
    # If the stream is paused, we expect no lines to be yielded.
    stream.pause_stream()
    stream_output = list(stream.get_stream_output())
    self.assertEmpty(stream_output)
    # If the stream is resumed, we expect to see all lines yielded.
    stream.resume_stream()
    self.assertEqual(list(stream.get_stream_output()), expected_lines)
"""A class that launches a thread to read Android log outputs.""" from collections.abc import Callable import dataclasses import re import threading from absl import logging from android_env.components import log_stream as log_stream_lib @dataclasses.dataclass class EventListener: """A function that's called when an event is triggered.""" regexp: re.Pattern[str] handler_fn: Callable[[re.Pattern[str], re.Match[str]], None] class LogcatThread: """Reads log entries in a separate thread.""" def __init__(self, log_stream: log_stream_lib.LogStream): """Initializes this LogcatThread with optional filters. Please see https://developer.android.com/studio/command-line/logcat for more info on `logcat`. Args: log_stream: Stream of logs from simulator. """ self._log_stream = log_stream self._listeners = {} self._line_ready = threading.Event() self._line_ready.set() self._should_stop = threading.Event() self._thread = threading.Thread(target=self._process_logs) self._thread.daemon = True self._thread.start() def add_event_listener(self, event_listener: EventListener) -> None: """Adds `fn` to the list of handlers to call when `event` occurs.""" event_regexp = event_listener.regexp if event_regexp not in self._listeners: self._listeners[event_regexp] = [] self._listeners[event_regexp].append(event_listener.handler_fn) def remove_event_listener(self, event_listener: EventListener) -> None: """Removes `fn` from the list of handlers to call when `event` occurs.""" event_regexp = event_listener.regexp if event_regexp not in self._listeners: logging.error('Event: %r is not registered.', event_regexp) return self._listeners[event_regexp].remove(event_listener.handler_fn) def line_ready(self) -> threading.Event: """Indicates whether all listeners have been notified for a given line.""" return self._line_ready def pause(self): self._log_stream.pause_stream() def resume(self): """Resume or restart the thread if it's dead after resetting environment.""" if not self._thread.is_alive(): 
self._should_stop.clear() self._thread = threading.Thread(target=self._process_logs) self._thread.daemon = True self._thread.start() self._log_stream.resume_stream() def kill(self): self._should_stop.set() self._log_stream.stop_stream() self._thread.join(timeout=3.0) def _process_logs(self) -> None: """A loop that runs until `self._should_stop` is set().""" # pylint: disable=g-line-too-long # Format is: "TIME_SEC PID TID PRIORITY TAG: MESSAGE" # # Example: # ' 1553110400.424 5583 5658 D NostalgicRacer: com.google.example.games.nostalgicracer.views.renderers.OpenGLRenderDriver@912fb8.onSurfaceChanged 480x320' # # pylint: enable=g-line-too-long logline_regexp = r""" ^ # Beginning of the line. [ ]+(?P[0-9]+\.[0-9]+) # Spaces and a float. [ ]+(?P[0-9]+) # Spaces and an int. [ ]+(?P[0-9]+) # Spaces and an int. [ ]+(?P.) # Spaces and any single character. [ ]+(?P[^:]*): # Spaces and any char that's not ':'. [ ](?P.*)$ # The actual log message. """ logline_re = re.compile(logline_regexp, re.VERBOSE) for line in self._log_stream.get_stream_output(): if self._should_stop.is_set(): break if not line: # Skip empty lines. continue matches = logline_re.match(line) if not matches or len(matches.groups()) != 6: continue # Make sure that values are not read until all listeners are notified. self._line_ready.clear() # We're currently only consuming `message`, but we may use the other # fields in the future. content = matches.group('message') for ev, listeners in self._listeners.items(): ev_matches = ev.match(content) if ev_matches: for listener in listeners: # Notify listeners. listener(ev, ev_matches) self._line_ready.set() # Allow consumers to read values. ================================================ FILE: android_env/components/logcat_thread_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. 
class FakeStream:
  """This class simulates the logs coming from ADB."""

  def __init__(self):
    # Pending values, consumed in FIFO order by `__iter__`.
    self._values = []
    self._kill = False
    self._lock = threading.Lock()

  def send_value(self, value):
    """Queues `value` so that a later iteration step yields it."""
    with self._lock:
      self._values.append(value)

  def has_next_value(self):
    """Returns True if at least one queued value has not been consumed yet."""
    return bool(self._values)

  def kill(self):
    """Makes the iterator terminate the next time it checks the flag."""
    self._kill = True

  def reset(self):
    """Clears the kill flag so that the stream can be iterated again."""
    self._kill = False

  def __iter__(self):
    """Yields queued values until `kill()` is called.

    While the queue is empty this loop spins without sleeping.
    """
    while True:
      if self._kill:
        return
      if not self._values:
        continue
      else:
        with self._lock:
          next_value = self._values.pop(0)
        yield next_value


def make_stdout(data):
  """Returns a valid log output with given data as message."""
  return ' 1553110400.424 5583 5658 D Tag: %s' % data


class FakeLogStream(log_stream.LogStream):
  """FakeLogStream class that wraps a FakeStream."""

  def __init__(self):
    super().__init__(verbose=False)
    self.logs = FakeStream()
    self.stream_is_alive = True

  def _get_stream_output(self):
    """Returns the wrapped FakeStream, which is iterated for log lines."""
    return self.logs

  def stop_stream(self):
    """Marks the stream as dead and terminates the FakeStream iterator."""
    self.stream_is_alive = False
    self.logs.kill()

  def reset(self):
    """Marks the stream as alive again and re-enables the FakeStream."""
    self.stream_is_alive = True
    self.logs.reset()


class LogcatThreadTest(absltest.TestCase):
  """Tests for `LogcatThread` driven by a `FakeLogStream`."""

  def setUp(self):
    super().setUp()
    self.fake_log_stream = FakeLogStream()

  def tearDown(self):
    # Stop the fake stream so that the spinning iterator terminates.
    self.fake_log_stream.stop_stream()
    super().tearDown()

  def test_set_filters(self):
    log_parsing_config = task_pb2.LogParsingConfig(filters=['AndroidRLTask:V'])
    self.fake_log_stream.set_log_filters(log_parsing_config.filters)
    _ = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    expected_filters = ['AndroidRLTask:V', '*:S']
    self.assertEqual(expected_filters, self.fake_log_stream._filters)

  def test_kill(self):
    logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    self.assertTrue(self.fake_log_stream.stream_is_alive)
    logcat.kill()
    self.assertFalse(self.fake_log_stream.stream_is_alive)

  def test_listeners(self):
    """Ensures that we can wait for a specific message without polling."""
    logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    # Start yielding lines from LogStream.
    logcat.resume()

    # Set up a listener that modifies an arbitrary state.
    some_state = threading.Event()

    def my_handler(event: re.Pattern[str], match: re.Match[str]):
      del event, match
      nonlocal some_state
      some_state.set()

    # Create a desired event and hook up the listener.
    my_event = re.compile('Hello world')
    listener = logcat_thread.EventListener(my_event, my_handler)
    logcat.add_event_listener(listener)
    self.fake_log_stream.logs.send_value('Hi there!')  # This should not match.
    self.assertFalse(some_state.is_set())
    self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
    # Wait (up to 1s) for the reader thread to invoke the handler.
    some_state.wait(timeout=1.0)
    self.assertTrue(some_state.is_set())

    # Waiting for any events should also trigger the listener.
    some_state.clear()
    self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
    some_state.wait(timeout=1.0)
    self.assertTrue(some_state.is_set())

    # After removing the listener, it should not be called anymore.
    some_state.clear()
    logcat.remove_event_listener(listener)
    self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
    # Here the wait is expected to time out because nothing sets the event.
    some_state.wait(timeout=1.0)
    self.assertFalse(some_state.is_set())

  def test_resume_does_not_recreate_alive_thread(self):
    """Checks that `resume()` keeps the same thread while it is alive."""
    logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    thread_before = logcat._thread
    self.assertTrue(thread_before.is_alive())
    logcat.resume()
    thread_after = logcat._thread
    self.assertTrue(thread_after.is_alive())
    self.assertIs(thread_before, thread_after)

  def test_resume_recreates_thread(self):
    """Checks that `resume()` starts a new daemon thread after `kill()`."""
    logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    self.assertTrue(logcat._thread.is_alive())
    logcat.kill()
    self.assertFalse(logcat._thread.is_alive())
    self.assertTrue(logcat._should_stop.is_set())
    # Re-enable the fake stream; otherwise the new thread would exit at once.
    self.fake_log_stream.reset()
    logcat.resume()
    self.assertTrue(logcat._thread.is_alive())
    self.assertFalse(logcat._should_stop.is_set())
    self.assertTrue(logcat._thread.daemon)


if __name__ == '__main__':
  absltest.main()
"""Utils for AndroidEnv.""" from collections.abc import Sequence from dm_env import specs import numpy as np def touch_position_to_pixel_position( touch_position: np.ndarray, width_height: Sequence[int], ) -> tuple[int, int]: """Maps touch position in [0,1] to the corresponding pixel on the screen.""" touch_pixels = (touch_position * width_height).astype(np.int32) cap_idx = lambda v, idx_len: min(v, idx_len - 1) return tuple(map(cap_idx, touch_pixels, width_height)) def transpose_pixels(frame: np.ndarray) -> np.ndarray: """Converts image from shape (H, W, C) to (W, H, C) and vice-versa.""" return np.transpose(frame, axes=(1, 0, 2)) def orient_pixels(frame: np.ndarray, orientation: int) -> np.ndarray: """Rotates screen pixels according to the given orientation.""" match orientation: case 0: # PORTRAIT_90 return frame case 1: # LANDSCAPE_90 return np.rot90(frame, k=3, axes=(0, 1)) case 2: # PORTRAIT_180 return np.rot90(frame, k=2, axes=(0, 1)) case 3: # LANDSCAPE_270 return np.rot90(frame, k=1, axes=(0, 1)) case _: raise ValueError( 'Orientation must be an integer in [0, 3] but is %r' % orientation ) def convert_int_to_float(data: np.ndarray, data_spec: specs.Array): """Converts an array of int values to floats between 0 and 1.""" if not np.issubdtype(data.dtype, np.integer): raise TypeError(f'{data.dtype} is not an integer type') if isinstance(data_spec, specs.BoundedArray): value_min = data_spec.minimum value_max = data_spec.maximum else: # We use the int type to figure out the boundaries. iinfo = np.iinfo(data_spec.dtype) value_min = iinfo.min value_max = iinfo.max return np.float32(1.0 * (data - value_min) / (value_max - value_min)) ================================================ FILE: android_env/components/pixel_fns_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. 
class UtilsTest(parameterized.TestCase):
  """Unit tests for the helper functions in `pixel_fns`."""

  @parameterized.parameters(
      ([0.5, 0.5], [320, 480], (160, 240)),
      ([0.25, 0.75], [320, 480], (80, 360)),
      ([0.0, 0.0], [320, 480], (0, 0)),
      ([1.0, 1.0], [320, 480], (319, 479)),
  )
  def test_touch_position_to_pixel_position(
      self, touch_pos, width_height, pixel_pos):
    actual = pixel_fns.touch_position_to_pixel_position(
        np.array(touch_pos), width_height
    )
    self.assertEqual(actual, pixel_pos)

  def test_transpose_pixels(self):
    image = np.arange(12).reshape((3, 2, 2))
    expected = [[[0, 1], [4, 5], [8, 9]], [[2, 3], [6, 7], [10, 11]]]
    transposed = pixel_fns.transpose_pixels(image)
    self.assertEqual(transposed.shape, (2, 3, 2))
    self.assertTrue((transposed == expected).all())

  def test_orient_pixels(self):
    image = np.arange(12).reshape((3, 2, 2))
    # Each case is (orientation, expected shape, expected pixels).
    cases = (
        # LANDSCAPE_90
        (1, (2, 3, 2),
         [[[8, 9], [4, 5], [0, 1]], [[10, 11], [6, 7], [2, 3]]]),
        # PORTRAIT_180
        (2, (3, 2, 2),
         [[[10, 11], [8, 9]], [[6, 7], [4, 5]], [[2, 3], [0, 1]]]),
        # LANDSCAPE_270
        (3, (2, 3, 2),
         [[[2, 3], [6, 7], [10, 11]], [[0, 1], [4, 5], [8, 9]]]),
        # PORTRAIT_0
        (0, (3, 2, 2), image),
    )
    for orientation, expected_shape, expected_pixels in cases:
      rotated = pixel_fns.orient_pixels(image, orientation)
      self.assertEqual(rotated.shape, expected_shape)
      self.assertTrue((rotated == expected_pixels).all())

  def test_convert_int_to_float_bounded_array(self):
    bounded_spec = specs.BoundedArray(
        shape=(4,),
        dtype=np.int32,
        minimum=[0, 1, 10, -2],
        maximum=[5, 5, 20, 2],
        name='bounded_array')
    int_data = np.array([2, 2, 10, 0], dtype=np.int32)
    converted = pixel_fns.convert_int_to_float(int_data, bounded_spec)
    wanted = np.array([2.0 / 5.0, 1.0 / 4.0, 0.0, 0.5], dtype=np.float32)
    np.testing.assert_equal(wanted, converted)

  def test_convert_int_to_float_bounded_array_broadcast(self):
    bounded_spec = specs.BoundedArray(
        shape=(3,), dtype=np.int16, minimum=2, maximum=4, name='bounded_array')
    int_data = np.array([2, 3, 4], dtype=np.int16)
    converted = pixel_fns.convert_int_to_float(int_data, bounded_spec)
    wanted = np.array([0.0, 0.5, 1.0], dtype=np.float32)
    np.testing.assert_equal(wanted, converted)

  def test_convert_int_to_float_no_bounds(self):
    unbounded_spec = specs.Array(
        shape=(3,),
        dtype=np.int8,  # int8 implies min=-128, max=127
        name='bounded_array')
    int_data = np.array([-128, 0, 127], dtype=np.int16)
    converted = pixel_fns.convert_int_to_float(int_data, unbounded_spec)
    wanted = np.array([0.0, 128. / 255., 1.0], dtype=np.float32)
    np.testing.assert_equal(wanted, converted)


if __name__ == '__main__':
  absltest.main()
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """A component that parses and processes SetupSteps.""" from collections.abc import Sequence import copy import time from typing import Any from absl import logging from android_env.components import adb_call_parser as adb_call_parser_lib from android_env.components import app_screen_checker from android_env.components import errors from android_env.proto import adb_pb2 from android_env.proto import task_pb2 class SetupStepInterpreter: """An interpreter for SetupSteps.""" def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser): """Initializes this interpreter. Args: adb_call_parser: An object to communicate with Android via ADB. 
""" self._adb_call_parser = adb_call_parser self._stats = { 'error_count_adb_request': 0, 'error_count_wait_for_app_screen': 0, 'error_count_check_install': 0, 'error_count_wait_for_message': 0, 'total_time_waiting_for_app_screen': 0 } def stats(self) -> dict[str, Any]: return copy.deepcopy(self._stats) def interpret(self, setup_steps: Sequence[task_pb2.SetupStep]) -> None: """Returns True if parsing and processing `setup_steps` is successful.""" if setup_steps: logging.info('Executing setup steps: %s', setup_steps) for step in setup_steps: self._process_step_command(step) logging.info('Done executing setup steps.') def _process_step_command(self, step_cmd: task_pb2.SetupStep) -> None: """Processes a single step command from a reset or extra setup.""" if not step_cmd: logging.info('Empty step_cmd') return logging.info('Executing step_cmd: %r', step_cmd) step_type = step_cmd.WhichOneof('step') success_condition = step_cmd.success_condition success_check = success_condition.WhichOneof('check') assert step_type or success_check, ( 'At least one of step and success_condition must be defined.') num_tries = 0 max_retries = max(success_condition.num_retries, 3) latest_error = None while num_tries < max_retries: num_tries += 1 try: unused_adb_response = self._execute_step_cmd(step_cmd, step_type) time.sleep(0.5) self._check_success(success_check, success_condition) return except NotImplementedError: logging.exception('Not implemented error! Skipping this step command.') return except errors.AdbControllerError as error: latest_error = error self._stats['error_count_adb_request'] += 1 logging.exception('ADB call [%r] has failed. Try %d of %d.', step_cmd.adb_request, num_tries, max_retries) except errors.WaitForAppScreenError as error: latest_error = error self._stats['error_count_wait_for_app_screen'] += 1 logging.exception('Failed to wait for app screen. 
Try %d of %d.', num_tries, max_retries) except errors.CheckInstallError as error: latest_error = error self._stats['error_count_check_install'] += 1 logging.exception('Package [%r] not installed. Try %d of %d.', success_condition.check_install.package_name, num_tries, max_retries) raise errors.StepCommandError( f'Step failed: [{step_cmd}]') from latest_error def _execute_step_cmd( self, step_cmd: task_pb2.SetupStep, step_type: str | None ) -> adb_pb2.AdbResponse | None: """Executes a step command of given type.""" match step_type: case None: return None case 'sleep': time.sleep(step_cmd.sleep.time_sec) return None case 'adb_request': response = self._adb_call_parser.parse(step_cmd.adb_request) if response.status != adb_pb2.AdbResponse.Status.OK: raise errors.AdbControllerError( f'Failed to execute AdbRequest [{step_cmd.adb_request}].\n' f'Status: {response.status}\n' f'Error: {response.error_message}' ) return response case _: raise NotImplementedError(f'No step command of type [{step_type}].') def _check_success( self, success_check: str | None, success_condition: task_pb2.SuccessCondition, ) -> None: """Checks whether the given success condition was met.""" match success_check: case None: return None case 'wait_for_app_screen': wait_for_app_screen = success_condition.wait_for_app_screen screen_checker = app_screen_checker.AppScreenChecker( adb_call_parser=self._adb_call_parser, expected_app_screen=wait_for_app_screen.app_screen, ) wait_time = screen_checker.wait_for_app_screen( timeout_sec=wait_for_app_screen.timeout_sec ) self._stats['total_time_waiting_for_app_screen'] += wait_time case 'check_install': self._check_install(success_condition.check_install) case _: raise NotImplementedError(f'No success check called [{success_check}].') def _check_install(self, check_install: task_pb2.CheckInstall) -> None: """Checks that the given package is installed.""" package = check_install.package_name logging.info('Checking if package is installed: [%r]', package) request 
= adb_pb2.AdbRequest( package_manager=adb_pb2.AdbRequest.PackageManagerRequest( list=adb_pb2.AdbRequest.PackageManagerRequest.List( packages=adb_pb2.AdbRequest.PackageManagerRequest.List.Packages( )))) start_time = time.time() while time.time() - start_time < check_install.timeout_sec: response = self._adb_call_parser.parse(request) if package in response.package_manager.list.items: logging.info('Done confirming that package is installed.') return time.sleep(0.1) logging.error('Package not found.') raise errors.CheckInstallError() ================================================ FILE: android_env/components/setup_step_interpreter_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for android_env.components.setup_step_interpreter.""" from unittest import mock from absl.testing import absltest from android_env.components import adb_call_parser from android_env.components import errors from android_env.components import setup_step_interpreter from android_env.proto import adb_pb2 from android_env.proto import task_pb2 from google.protobuf import text_format def _to_proto(proto_class, text): proto = proto_class() text_format.Parse(text, proto) return proto class SetupStepInterpreterTest(absltest.TestCase): def setUp(self): super().setUp() self._parser = mock.create_autospec( adb_call_parser.AdbCallParser, instance=True) def test_empty_setup_steps(self): """Simple test where nothing should break, and nothing should be done. The test simply expects this test to not crash. """ interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([]) def test_none_setup_steps(self): """Simple test where nothing should break, and nothing should be done. The test simply expects this test to not crash. """ interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) # Empty setup steps should be ignored. interpreter.interpret([]) def test_invalid_setup_step(self): interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) # Empty setup steps should be ignored. 
self.assertRaises(AssertionError, interpreter.interpret, [task_pb2.SetupStep()]) def test_adb_install_apk_filesystem(self): self._parser.parse.return_value = adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.OK) interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ adb_request: { install_apk: { filesystem: { path: "/my/favorite/dir/my_apk.apk" } } }""") ]) self._parser.parse.assert_called_once_with( adb_pb2.AdbRequest( install_apk=adb_pb2.AdbRequest.InstallApk( filesystem=adb_pb2.AdbRequest.InstallApk.Filesystem( path='/my/favorite/dir/my_apk.apk')))) def test_adb_force_stop(self): self._parser.parse.return_value = adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.OK) interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ adb_request: { force_stop: { package_name: "my.app.Activity" } }""") ]) self._parser.parse.assert_called_once_with( adb_pb2.AdbRequest( force_stop=adb_pb2.AdbRequest.ForceStop( package_name='my.app.Activity'))) def test_adb_start_activity(self): self._parser.parse.return_value = adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.OK) interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ adb_request: { start_activity: { full_activity: "my.app.Activity" extra_args: "arg1" extra_args: "arg2" } }""") ]) self._parser.parse.assert_called_once_with( adb_pb2.AdbRequest( start_activity=adb_pb2.AdbRequest.StartActivity( full_activity='my.app.Activity', extra_args=['arg1', 'arg2']))) def test_adb_single_tap(self): self._parser.parse.return_value = adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.OK) interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([ _to_proto(task_pb2.SetupStep, """ adb_request: { tap: 
{ x: 321 y: 654 } }""") ]) self._parser.parse.assert_called_once_with( adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=321, y=654))) def test_adb_press_button(self): self._parser.parse.return_value = adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.OK) interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([ _to_proto(task_pb2.SetupStep, """ adb_request: { press_button: { button: HOME } }""") ]) self._parser.parse.assert_called_once_with( adb_pb2.AdbRequest( press_button=adb_pb2.AdbRequest.PressButton( button=adb_pb2.AdbRequest.PressButton.Button.HOME))) self._parser.reset_mock() interpreter.interpret([ _to_proto(task_pb2.SetupStep, """ adb_request: { press_button: { button: BACK } }""") ]) self._parser.parse.assert_called_once_with( adb_pb2.AdbRequest( press_button=adb_pb2.AdbRequest.PressButton( button=adb_pb2.AdbRequest.PressButton.Button.BACK))) def test_adb_start_screen_pinning(self): self._parser.parse.return_value = adb_pb2.AdbResponse( status=adb_pb2.AdbResponse.Status.OK) interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ adb_request: { start_screen_pinning: { full_activity: "my.app.HighlanderApp" # "There can be only one". 
} }""") ]) self._parser.parse.assert_called_once_with( adb_pb2.AdbRequest( start_screen_pinning=adb_pb2.AdbRequest.StartScreenPinning( full_activity='my.app.HighlanderApp'))) @mock.patch('time.sleep') def test_time_sleep(self, mock_sleep): interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret( [_to_proto(task_pb2.SetupStep, """sleep: { time_sec: 0.875 }""")]) assert mock_sleep.call_count == 2 mock_sleep.assert_has_calls([mock.call(0.875), mock.call(0.5)]) @mock.patch('time.sleep') def test_wait_for_app_screen_empty_activity(self, unused_mock_sleep): interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) with self.assertRaises(errors.StepCommandError): interpreter.interpret([ _to_proto(task_pb2.SetupStep, """success_condition: {wait_for_app_screen: { }}""") ]) @mock.patch('time.sleep') def test_check_install_not_installed(self, unused_mock_sleep): self._parser.parse.return_value = adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[ 'com.some.package', 'not.what.you.are.looking.for', ]))) interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) with self.assertRaises(errors.StepCommandError): interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ success_condition: { check_install: { package_name: "faz" timeout_sec: 0.0001 } } """) ]) def test_check_install_installed(self): self._parser.parse.return_value = adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[ 'com.some.package', 'baz', ]))) interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) # The test checks that this command raises no AssertionError. 
interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ success_condition: { check_install: { package_name: "baz" timeout_sec: 0.0001 } }""") ]) def test_num_retries_failure(self): self._parser.parse.side_effect = [ adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List( items=[]))), ] * 3 interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) with self.assertRaises(errors.StepCommandError): interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ success_condition: { check_install: { package_name: "faz" timeout_sec: 0.0001 } num_retries: 3 }""") ]) # We retried 3 times after the first call, so we expect 3+1 calls. self.assertEqual(self._parser.parse.call_count, 3) @mock.patch('time.sleep') def test_num_retries_success(self, unused_mock_sleep): self._parser.parse.side_effect = [ adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List( items=[]))), adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List( items=[]))), adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[ 'com.some.package', 'bar', ]))), adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[]))) ] interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ success_condition: { check_install: { package_name: "bar" timeout_sec: 0.0001 } num_retries: 5 }""") ]) # The check should succeed on the third try. 
self.assertEqual(self._parser.parse.call_count, 3) def test_retry_step(self): self._parser.parse.side_effect = [ adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List( items=[]))), adb_pb2.AdbResponse( package_manager=adb_pb2.AdbResponse.PackageManagerResponse( list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[ 'com.some.package', 'bar', ]))), ] interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=self._parser) interpreter.interpret([ _to_proto( task_pb2.SetupStep, """ success_condition: { check_install: { package_name: "bar" timeout_sec: 0.0001 } num_retries: 2 }""") ]) # We expect the check to fail once and succeed on the second pass. self.assertEqual(self._parser.parse.call_count, 2) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/components/simulators/__init__.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: android_env/components/simulators/base_simulator.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
class BaseSimulator(metaclass=abc.ABCMeta):
  """Abstract interface for talking to an Android simulator.

  Subclasses provide the platform-specific pieces (`_launch_impl()`,
  `_get_screenshot_impl()`, input injection, ADB wiring). This base class
  owns the launch bookkeeping and the optional background screenshot thread.
  """

  def __init__(self, config: config_classes.SimulatorConfig):
    """Stores the configuration; nothing is launched yet.

    The simulator may be an emulator, virtual machine or even a physical
    device. Each simulator has its own AdbController that is used for internal
    bookkeeping.

    Args:
      config: Settings for this simulator.
    """

    self._config = config
    # Only populated by `launch()` when `interaction_rate_sec > 0`.
    self._interaction_thread: InteractionThread | None = None
    # Counts how many times `launch()` has been attempted so far.
    self._num_launch_attempts: int = 0

  def get_logs(self) -> str:
    """Returns simulator logs; this default reports that none are available."""

    return 'No simulator logs provided.'

  @abc.abstractmethod
  def adb_device_name(self) -> str:
    """Returns the name of the device the adb client should connect to."""

  @abc.abstractmethod
  def create_adb_controller(self) -> adb_controller.AdbController:
    """Returns an ADB controller able to communicate with this simulator."""

  @abc.abstractmethod
  def create_log_stream(self) -> log_stream.LogStream:
    """Returns a stream of logs coming from the simulator."""

  def _stop_interaction_thread(self) -> None:
    """Stops and joins the background screenshot thread, if there is one."""

    thread = self._interaction_thread
    if thread is None:
      return
    thread.stop()
    thread.join()

  def launch(self) -> None:
    """Starts (or restarts) the simulator.

    Raises:
      errors.SimulatorError: If the platform-specific launch raised. The
        simulator logs are written to the error log before raising.
    """

    # A previous launch may have left a screenshot thread running.
    self._stop_interaction_thread()

    self._num_launch_attempts += 1
    try:
      self._launch_impl()
    except Exception as error:
      # Dump whatever the simulator recorded so the failure can be diagnosed.
      for line in self.get_logs().splitlines():
        logging.error(line)
      raise errors.SimulatorError(
          'Exception caught in simulator. Please see the simulator logs '
          'above for more details.'
      ) from error

    # Async mode: screenshots are collected by a background thread.
    rate_sec = self._config.interaction_rate_sec
    if rate_sec > 0:
      self._interaction_thread = InteractionThread(
          self._get_screenshot_impl, rate_sec
      )
      self._interaction_thread.start()

  @abc.abstractmethod
  def _launch_impl(self) -> None:
    """Platform specific launch implementation."""

  @abc.abstractmethod
  def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
    """Sends touch events to be executed on the simulator.

    Args:
      touches: A list of touch events. Each element is a tuple with:
        0 x: The horizontal coordinate of this event.
        1 y: The vertical coordinate of this event.
        2 is_down: Whether the finger is touching or not the screen.
        3 identifier: Identifies a particular finger in a multitouch event.
    """

  @abc.abstractmethod
  def send_key(self, keycode: np.int32, event_type: str) -> None:
    """Sends a keyboard event.

    Args:
      keycode: Represents a specific keyboard key. This is platform and
        simulator-specific.
      event_type: Type of key event to be sent.
    """

  def load_state(
      self, request: state_pb2.LoadStateRequest
  ) -> state_pb2.LoadStateResponse:
    """Loads a state; unsupported unless a subclass overrides it.

    Args:
      request: A `LoadStateRequest` containing any parameters necessary to
        specify how/what state to load.

    Returns:
      A `LoadStateResponse` containing the status, error message (if
      applicable), and any other relevant information.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """

    raise NotImplementedError('This simulator does not support load_state()')

  def save_state(
      self, request: state_pb2.SaveStateRequest
  ) -> state_pb2.SaveStateResponse:
    """Saves a state; unsupported unless a subclass overrides it.

    Args:
      request: A `SaveStateRequest` containing any parameters necessary to
        specify how/what state to save.

    Returns:
      A `SaveStateResponse` containing the status, error message (if
      applicable), and any other relevant information.

    Raises:
      NotImplementedError: Always, in this base implementation.
    """

    raise NotImplementedError('This simulator does not support save_state()')

  def get_screenshot(self) -> np.ndarray:
    """Returns pixels representing the current screenshot of the simulator."""

    if not self._config.interaction_rate_sec > 0:
      return self._get_screenshot_impl()  # Sync mode.

    assert self._interaction_thread is not None
    return self._interaction_thread.screenshot()  # Async mode.

  @abc.abstractmethod
  def _get_screenshot_impl(self) -> np.ndarray:
    """Actual implementation of `get_screenshot()`.

    The output numpy array should have shape [height, width, num_channels] and
    can be loaded into PIL using Image.fromarray(img, mode='RGB') and be saved
    as a PNG file using my_pil.save('/tmp/my_screenshot.png', 'PNG').
    """

  def close(self):
    """Frees up resources allocated by this object."""

    self._stop_interaction_thread()
class InteractionThread(threading.Thread):
  """A thread that refreshes a cached screenshot in the background.

  One screenshot is taken synchronously in the constructor, so `screenshot()`
  always has a value to return, even before `start()` is called. Once started,
  the thread takes a new screenshot roughly every `interaction_rate_sec`
  seconds until `stop()` is called.
  """

  def __init__(
      self,
      get_screenshot_fn: Callable[[], np.ndarray],
      interaction_rate_sec: float,
  ):
    """Initializes the thread and grabs the first screenshot.

    Args:
      get_screenshot_fn: Callable returning the current screenshot.
      interaction_rate_sec: Target time, in seconds, between the start of two
        consecutive screenshot captures.
    """

    super().__init__()
    self._get_screenshot_fn = get_screenshot_fn
    self._interaction_rate_sec = interaction_rate_sec
    self._should_stop = threading.Event()
    self._screenshot = self._get_screenshot_fn()

  def run(self):
    """Takes screenshots at the configured rate until `stop()` is called."""

    while not self._should_stop.is_set():
      # Time only the capture itself. Measuring from a timestamp taken before
      # the previous pause would count that pause as elapsed time and skip the
      # pause on every other iteration (two captures per interval).
      # A monotonic clock keeps the measurement immune to wall-clock changes.
      capture_start = time.monotonic()
      self._screenshot = self._get_screenshot_fn()
      remaining = self._interaction_rate_sec - (
          time.monotonic() - capture_start
      )
      if remaining > 0.0:
        # Waiting on the stop event (instead of `time.sleep()`) lets `stop()`
        # wake the thread immediately, so `join()` does not block for up to a
        # whole interval.
        self._should_stop.wait(remaining)
    logging.info('InteractionThread.run() finished.')

  def stop(self):
    """Asks the thread to finish; use `join()` to wait for it."""

    logging.info('Stopping InteractionThread.')
    self._should_stop.set()

  def screenshot(self) -> np.ndarray:
    """Returns the most recently captured screenshot."""

    return self._screenshot
class BaseSimulatorTest(absltest.TestCase):
  """Tests `BaseSimulator` behavior through the `FakeSimulator` subclass."""

  def test_launch(self):
    """Launching a simulator with explicit screen dimensions does not raise."""
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480))
    )
    # The simulator should launch and not crash.
    simulator.launch()

  def test_launch_close(self):
    """Launching and then closing a simulator does not raise."""
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig()
    )
    # The simulator should launch and not crash.
    simulator.launch()
    # Closing the simulator should also not crash.
    simulator.close()

  def test_get_screenshot(self):
    """The screenshot shape follows the configured screen dimensions."""
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480))
    )
    # The simulator should launch and not crash.
    simulator.launch()

    screenshot = simulator.get_screenshot()
    np.testing.assert_equal(screenshot.shape, [640, 480, 3])

  def test_print_logs_on_exception(self):
    """A failing `_launch_impl()` raises SimulatorError and reads the logs."""
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig()
    )
    with mock.patch.object(
        simulator, 'get_logs'
    ) as mock_get_logs, mock.patch.object(
        simulator, '_launch_impl', autospec=True
    ) as mock_launch:
      # Any exception from the platform-specific launch should be wrapped.
      mock_launch.side_effect = ValueError('Oh no!')
      self.assertRaises(errors.SimulatorError, simulator.launch)
      mock_get_logs.assert_called_once()

  def test_get_screenshot_error_async(self):
    """An exception in the underlying interaction thread should bubble up."""

    # Arrange.
    mock_interaction_thread = mock.create_autospec(
        base_simulator.InteractionThread
    )
    mock_interaction_thread.screenshot.side_effect = (
        errors.ReadObservationError()
    )
    # A positive interaction rate puts the simulator in async mode.
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(interaction_rate_sec=0.5)
    )
    with mock.patch.object(
        base_simulator,
        'InteractionThread',
        autospec=True,
        return_value=mock_interaction_thread,
    ):
      simulator.launch()

    # Act & Assert.
    self.assertRaises(errors.ReadObservationError, simulator.get_screenshot)

    # Cleanup.
    simulator.close()

  def test_get_screenshot_faster_than_screenshot_impl(self):
    """Return same screenshot when step is faster than the interaction rate."""

    # Arrange.
    slow_rate = 0.5
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(interaction_rate_sec=slow_rate)
    )

    # Act.
    with mock.patch.object(
        simulator, '_get_screenshot_impl', autospec=True
    ) as mock_get_screenshot_impl:
      # Every call yields a distinct array holding an increasing counter.
      mock_get_screenshot_impl.side_effect = (
          np.array(i, ndmin=3) for i in itertools.count(0, 1)
      )
      simulator.launch()
      # Get two screenshots one after the other without pausing.
      screenshot1 = simulator.get_screenshot()
      screenshot2 = simulator.get_screenshot()

    # Assert.
    self.assertAlmostEqual(screenshot1[0][0][0], screenshot2[0][0][0])

    # Cleanup.
    simulator.close()

  def test_get_screenshot_slower_than_screenshot_impl(self):
    """Return different screenshots when step slower than the interaction rate."""

    # Arrange.
    fast_rate = 0.01
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(interaction_rate_sec=fast_rate)
    )

    # Act.
    with mock.patch.object(
        simulator, '_get_screenshot_impl', autospec=True
    ) as mock_get_screenshot_impl:
      # Every call yields a distinct array holding an increasing counter.
      mock_get_screenshot_impl.side_effect = (
          np.array(i, ndmin=3) for i in itertools.count(0, 1)
      )
      simulator.launch()
      # Sleep for 500ms between two screenshots.
      screenshot1 = simulator.get_screenshot()
      time.sleep(0.5)
      screenshot2 = simulator.get_screenshot()

    # Assert.
    self.assertNotEqual(screenshot1[0][0][0], screenshot2[0][0][0])

    # Cleanup.
    simulator.close()

  def test_interaction_thread_closes_upon_relaunch(self):
    """Async interaction should kill the InteractionThread when relaunching."""

    # Arrange.
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(interaction_rate_sec=0.01)
    )
    mock_interaction_thread = mock.create_autospec(
        base_simulator.InteractionThread
    )

    # Act & Assert.
    with mock.patch.object(
        base_simulator,
        'InteractionThread',
        autospec=True,
        return_value=mock_interaction_thread,
    ):
      # The first launch has no previous thread to stop.
      simulator.launch()
      mock_interaction_thread.stop.assert_not_called()
      mock_interaction_thread.join.assert_not_called()
      # The second launch must stop and join the thread from the first one.
      simulator.launch()
      mock_interaction_thread.stop.assert_called_once()
      mock_interaction_thread.join.assert_called_once()
      simulator.close()
class EmulatorLauncher:
  """Prepares, spawns and tears down a single emulator process.

  Each instance owns a private temporary directory (created under
  `config.tmp_dir`) that holds the emulator's output log. The directory is
  removed by `close()`, after the emulator process has been waited on.
  """

  def __init__(
      self,
      config: config_classes.EmulatorLauncherConfig,
      adb_controller_config: config_classes.AdbControllerConfig,
  ):
    """Creates the temporary directory; the emulator is not started yet.

    Args:
      config: Emulator launch settings (paths, ports, AVD name, ...).
      adb_controller_config: ADB settings; the adb path and server port are
        handed over to the emulator process.
    """

    self._config = config
    self._adb_controller_config = adb_controller_config
    self._emulator = None
    self._emulator_output = None
    self._is_closed = False
    # Defined before anything that may raise, so that `close()` (also reached
    # through `__del__`) never trips over a missing attribute.
    self._local_tmp_dir_handle = None

    # Create directory for tmp files.
    # Note: this will be deleted once EmulatorLauncher instance is cleaned up.
    os.makedirs(config.tmp_dir, exist_ok=True)
    self._local_tmp_dir_handle = tempfile.TemporaryDirectory(
        dir=config.tmp_dir, prefix='simulator_instance_'
    )
    self._local_tmp_dir = self._local_tmp_dir_handle.name
    self._logfile_path = os.path.join(self._local_tmp_dir, 'emulator_output')
    logging.info('Simulator local_tmp_dir: %s', self._local_tmp_dir)

  def logfile_path(self) -> str:
    """Returns the path of the file receiving the emulator's stdout/stderr."""

    return self._logfile_path

  def launch_emulator_process(self) -> None:
    """Spawns the emulator process with the configured environment and flags.

    The process's stdout and stderr are redirected to `logfile_path()`.
    """

    logging.info('Booting new emulator: %s', self._config.emulator_path)

    # Set necessary environment variables.
    # NOTE(review): `[:-8]` presumably strips a trailing `emulator` binary name
    # to obtain the installation directory -- confirm for non-default paths.
    emulator_dir = self._config.emulator_path[:-8]
    base_lib_dir = emulator_dir + 'lib64/'
    ld_library_path = ':'.join([
        base_lib_dir + 'x11/',
        base_lib_dir + 'qt/lib/',
        base_lib_dir + 'gles_swiftshader/',
        base_lib_dir,
    ])
    extra_env_vars = {
        'ANDROID_HOME': '',
        'ANDROID_SDK_ROOT': self._config.android_sdk_root,
        'ANDROID_AVD_HOME': self._config.android_avd_home,
        'ANDROID_EMULATOR_KVM_DEVICE': self._config.kvm_device,
        'ANDROID_ADB_SERVER_PORT': str(
            self._adb_controller_config.adb_server_port
        ),
        'LD_LIBRARY_PATH': ld_library_path,
        'QT_XKB_CONFIG_ROOT': str(emulator_dir + 'qt_config/'),
        'ANDROID_EMU_ENABLE_CRASH_REPORTING': '1',
        'SHOW_PERF_STATS': str(1 if self._config.show_perf_stats else 0),
    }
    logging.info('extra_env_vars: %s',
                 ' '.join(f'{k}={v}' for k, v in extra_env_vars.items()))
    # The emulator inherits the current environment plus the overrides above.
    env_vars = dict(os.environ)
    env_vars.update(extra_env_vars)

    # Compile command.
    # A negative gRPC port means that gRPC is not requested.
    grpc_port = (
        ['-grpc', str(self._config.grpc_port)]
        if self._config.grpc_port >= 0
        else []
    )
    run_headless = (
        ['-no-skin', '-no-window'] if self._config.run_headless else []
    )
    ports = [
        '-ports',
        f'{self._config.emulator_console_port},{self._config.adb_port}',
    ]
    if self._config.snapshot_name:
      snapshot = [
          '-snapshot',
          self._config.snapshot_name,
          '-feature',
          'AllowSnapshotMigration,MigratableSnapshotSave',
      ]
    else:
      snapshot = ['-no-snapshot']
    network_args = (
        [
            '-network-user-mode-options',
            'restrict=y',
            '-wifi-user-mode-options',
            'restrict=y',
        ]
        if self._config.restrict_network
        else []
    )
    command = (
        [
            self._config.emulator_path,
            '-adb-path',
            self._adb_controller_config.adb_path,
            '-gpu',
            self._config.gpu_mode,
            '-no-audio',
            '-show-kernel',
            '-verbose',
            '-avd',
            self._config.avd_name,
        ]
        + grpc_port
        + run_headless
        + ports
        + snapshot
        + network_args
    )
    logging.info('Emulator launch command: %s', ' '.join(command))

    # Prepare logfile. It is closed in `confirm_shutdown()`.
    self._emulator_output = open(self._logfile_path, 'wb')

    # Spawn the emulator process.
    self._emulator = subprocess.Popen(
        command,
        env=env_vars,
        stdout=self._emulator_output,
        stderr=self._emulator_output)

  def confirm_shutdown(self) -> None:
    """Waits for the emulator process to exit, killing it after 30 seconds.

    Also closes the log file handle, even if no process was ever spawned (for
    instance when `subprocess.Popen` itself failed after the file was opened).
    """

    if self._emulator is not None:
      logging.info('Checking if emulator process has finished...')
      try:
        self._emulator.wait(timeout=30.0)
      except subprocess.TimeoutExpired:
        logging.exception(
            'The emulator process did not finish after 30s. '
            'returncode: %s. Will now try to kill() it.',
            self._emulator.returncode)
        self._emulator.kill()
        try:
          # Reap the killed process so that it does not linger as a zombie.
          self._emulator.wait(timeout=5.0)
        except subprocess.TimeoutExpired:
          logging.error('The emulator process did not exit after kill().')
      self._emulator = None
      logging.info('The emulator process has finished.')

    if self._emulator_output is not None:
      self._emulator_output.close()
      self._emulator_output = None

  def close(self):
    """Clean up launcher files and processes."""

    if self._is_closed:
      return
    try:
      # Wait for the process first: it writes its log into the temporary
      # directory, which must therefore outlive it.
      self.confirm_shutdown()
    finally:
      if self._local_tmp_dir_handle is not None:
        self._local_tmp_dir_handle.cleanup()
    self._is_closed = True

  def __del__(self):
    self.close()
class EmulatorLauncherTest(parameterized.TestCase):
  """Checks the command line and environment passed to the emulator process.

  `subprocess.Popen` and `open` are mocked in every test, so no emulator is
  actually started and no log file is created.
  """

  def setUp(self):
    """Builds the command line pieces and env vars shared by all tests."""
    super().setUp()
    self._emulator_path = 'fake/path/emulator'
    self._adb_path = 'fake/path/adb'
    self._adb_port = 5554
    self._adb_server_port = 1234
    self._emulator_console_port = 5555
    self._avd_name = 'my_avd_name'
    # Fixed prefix of the launch command; optional flags are appended to it.
    self._expected_command = [
        self._emulator_path,
        '-adb-path',
        'fake/path/adb',
        '-gpu',
        'swangle_indirect',
        '-no-audio',
        '-show-kernel',
        '-verbose',
        '-avd',
        self._avd_name,
    ]
    self._headless = ['-no-skin', '-no-window']
    self._ports = ['-ports', f'{self._emulator_console_port},{self._adb_port}']
    self._snapshot = ['-no-snapshot']
    # Mirrors how the launcher derives the library directories from the
    # emulator path.
    base_lib_dir = self._emulator_path[:-8] + 'lib64/'
    ld_library_path = ':'.join([
        base_lib_dir + 'x11/', base_lib_dir + 'qt/lib/',
        base_lib_dir + 'gles_swiftshader/', base_lib_dir
    ])

    # Instantiate the config to extract default values.
    config = config_classes.EmulatorLauncherConfig()
    # 'SHOW_PERF_STATS' is added by each test according to its parameter.
    self._expected_env_vars = {
        'ANDROID_HOME': '',
        'ANDROID_SDK_ROOT': config.android_sdk_root,
        'ANDROID_AVD_HOME': config.android_avd_home,
        'ANDROID_EMULATOR_KVM_DEVICE': '/dev/kvm',
        'ANDROID_ADB_SERVER_PORT': '1234',
        'LD_LIBRARY_PATH': ld_library_path,
        'QT_XKB_CONFIG_ROOT': str(self._emulator_path[:-8] + 'qt_config/'),
        'ANDROID_EMU_ENABLE_CRASH_REPORTING': '1',
    }

  @parameterized.named_parameters([
      ('hide_perf_stats', False),
      ('show_perf_stats', True),
  ])
  @mock.patch.object(os, 'makedirs')
  @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
  @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
  def test_launch(
      self,
      show_perf_stats: bool,
      mock_tmp_dir,
      unused_os_environ,
      unused_os_makedirs,
  ):
    """Default launch: headless, explicit ports, no snapshot, no gRPC flag."""
    mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir'

    # A negative gRPC port keeps the `-grpc` flag out of the command.
    config = config_classes.EmulatorLauncherConfig(
        adb_port=self._adb_port,
        emulator_console_port=self._emulator_console_port,
        emulator_path=self._emulator_path,
        avd_name=self._avd_name,
        grpc_port=-1,
        show_perf_stats=show_perf_stats,
    )
    adb_controller_config = config_classes.AdbControllerConfig(
        adb_path=self._adb_path,
        adb_server_port=self._adb_server_port,
    )
    launcher = emulator_launcher.EmulatorLauncher(
        config=config, adb_controller_config=adb_controller_config
    )

    expected_env_vars = self._expected_env_vars
    expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'

    with mock.patch.object(
        subprocess, 'Popen', autospec=True
    ) as emulator_init, mock.patch.object(builtins, 'open', autospec=True) as f:
      f.return_value.__enter__ = f()
      launcher.launch_emulator_process()
      emulator_init.assert_called_once_with(
          args=self._expected_command
          + self._headless
          + self._ports
          + self._snapshot,
          env=expected_env_vars,
          stdout=f(),
          stderr=f(),
      )

  @parameterized.named_parameters([
      ('hide_perf_stats', False),
      ('show_perf_stats', True),
  ])
  @mock.patch.object(os, 'makedirs')
  @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
  @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
  def test_grpc_port(
      self,
      show_perf_stats: bool,
      mock_tmp_dir,
      unused_os_environ,
      unused_os_makedirs,
  ):
    """A non-negative gRPC port adds `-grpc <port>` to the command."""
    mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir'

    config = config_classes.EmulatorLauncherConfig(
        adb_port=self._adb_port,
        emulator_console_port=self._emulator_console_port,
        emulator_path=self._emulator_path,
        avd_name=self._avd_name,
        grpc_port=8554,
        show_perf_stats=show_perf_stats,
    )
    adb_controller_config = config_classes.AdbControllerConfig(
        adb_path=self._adb_path,
        adb_server_port=self._adb_server_port,
    )
    launcher = emulator_launcher.EmulatorLauncher(
        config=config, adb_controller_config=adb_controller_config
    )

    expected_env_vars = self._expected_env_vars
    expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'

    with mock.patch.object(
        subprocess, 'Popen', autospec=True
    ) as emulator_init, mock.patch.object(builtins, 'open', autospec=True) as f:
      f.return_value.__enter__ = f()
      launcher.launch_emulator_process()
      # The gRPC flag comes right after the fixed command prefix.
      emulator_init.assert_called_once_with(
          args=self._expected_command
          + ['-grpc', '8554']
          + self._headless
          + self._ports
          + self._snapshot,
          env=expected_env_vars,
          stdout=f(),
          stderr=f(),
      )

  @parameterized.named_parameters([
      ('hide_perf_stats', False),
      ('show_perf_stats', True),
  ])
  @mock.patch.object(os, 'makedirs')
  @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
  @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
  def test_snapshot(
      self,
      show_perf_stats: bool,
      mock_tmp_dir,
      unused_os_environ,
      unused_os_makedirs,
  ):
    """A snapshot name replaces `-no-snapshot` with the snapshot flags."""
    mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir'

    config = config_classes.EmulatorLauncherConfig(
        adb_port=self._adb_port,
        emulator_console_port=self._emulator_console_port,
        emulator_path=self._emulator_path,
        avd_name=self._avd_name,
        grpc_port=-1,
        snapshot_name='my_snapshot',
        show_perf_stats=show_perf_stats,
    )
    adb_controller_config = config_classes.AdbControllerConfig(
        adb_path=self._adb_path,
        adb_server_port=self._adb_server_port,
    )
    launcher = emulator_launcher.EmulatorLauncher(
        config=config, adb_controller_config=adb_controller_config
    )

    expected_snapshot = [
        '-snapshot', 'my_snapshot', '-feature',
        'AllowSnapshotMigration,MigratableSnapshotSave'
    ]

    expected_env_vars = self._expected_env_vars
    expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'

    with mock.patch.object(
        subprocess, 'Popen', autospec=True) as emulator_init, \
        mock.patch.object(builtins, 'open', autospec=True) as f:
      f.return_value.__enter__ = f()
      launcher.launch_emulator_process()
      emulator_init.assert_called_once_with(
          args=self._expected_command
          + self._headless
          + self._ports
          + expected_snapshot,
          env=expected_env_vars,
          stdout=f(),
          stderr=f(),
      )

  @parameterized.named_parameters([
      ('hide_perf_stats', False),
      ('show_perf_stats', True),
  ])
  @mock.patch.object(os, 'makedirs')
  @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
  @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
  def test_network_restrict(
      self,
      show_perf_stats: bool,
      mock_tmp_dir,
      unused_os_environ,
      unused_os_makedirs,
  ):
    """`restrict_network=True` appends the network restriction flags."""
    mock_tmp_dir.return_value.name.return_value = 'local_tmp_dir'

    config = config_classes.EmulatorLauncherConfig(
        adb_port=self._adb_port,
        emulator_console_port=self._emulator_console_port,
        emulator_path=self._emulator_path,
        avd_name=self._avd_name,
        grpc_port=-1,
        restrict_network=True,
        show_perf_stats=show_perf_stats,
    )
    adb_controller_config = config_classes.AdbControllerConfig(
        adb_path=self._adb_path,
        adb_server_port=self._adb_server_port,
    )
    launcher = emulator_launcher.EmulatorLauncher(
        config=config, adb_controller_config=adb_controller_config
    )

    expected_snapshot = ['-no-snapshot']
    expected_network_restrict = [
        '-network-user-mode-options', 'restrict=y', '-wifi-user-mode-options',
        'restrict=y'
    ]

    expected_env_vars = self._expected_env_vars
    expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'

    with mock.patch.object(
        subprocess, 'Popen', autospec=True) as emulator_init, \
        mock.patch.object(builtins, 'open', autospec=True) as f:
      f.return_value.__enter__ = f()
      launcher.launch_emulator_process()
      # Unlike the other tests, the command is checked as a positional
      # argument here.
      emulator_init.assert_called_once_with(
          self._expected_command
          + self._headless
          + self._ports
          + expected_snapshot
          + expected_network_restrict,
          env=expected_env_vars,
          stdout=f(),
          stderr=f(),
      )
_DEFAULT_SNAPSHOT_NAME: str = 'default_snapshot'


def _is_existing_emulator_provided(
    launcher_config: config_classes.EmulatorLauncherConfig,
) -> bool:
  """Tells whether the config points at an already running emulator.

  Args:
    launcher_config: The emulator launcher settings to inspect.

  Returns:
    True only if the adb port, the console port and the gRPC port are all
    truthy.
  """

  required_ports = (
      launcher_config.adb_port,
      launcher_config.emulator_console_port,
      launcher_config.grpc_port,
  )
  return all(required_ports)


def _pick_adb_port() -> int:
  """Picks a free port for adb, preferring the recommended range 5555-5585.

  Only odd ports of the range are tried, lowest first. If none of them is
  free, a random unused port is returned instead. More info:
  https://developer.android.com/studio/command-line/adb#howadbworks.

  Returns:
    port: an available port for adb.
  """

  recommended_ports = range(5555, 5587, 2)
  first_free = next(
      (port for port in recommended_ports if portpicker.is_port_free(port)),
      None,
  )
  if first_free is not None:
    return first_free
  return portpicker.pick_unused_port()


def _pick_emulator_grpc_port() -> int:
  """Picks a free port for the emulator's gRPC server, preferring 8554.

  If the recommended port is taken, a random unused port is returned. More
  info:
  https://android.googlesource.com/platform/external/qemu/+/emu-master-dev/android/android-grpc/docs/.

  Returns:
    port: an available port for emulator grpc.
  """

  recommended_port = 8554
  return (
      recommended_port
      if portpicker.is_port_free(recommended_port)
      else portpicker.pick_unused_port()
  )


class EmulatorBootError(errors.SimulatorError):
  """Raised when an emulator failed to boot."""


class EmulatorCrashError(errors.SimulatorError):
  """Raised when a simulator crashed."""
def get_logs(self) -> str:
  """Returns logs recorded by the emulator.

  The log file contains the raw output of the emulator process, which is not
  guaranteed to be valid UTF-8. Undecodable bytes are replaced by U+FFFD
  instead of raising, so that reading the logs while handling a launch failure
  cannot mask that failure with a `UnicodeDecodeError`.

  Returns:
    The decoded content of the log file, or a message saying that the log file
    does not exist.
  """

  if self._logfile_path and os.path.exists(self._logfile_path):
    with open(self._logfile_path, 'rb') as f:
      return f.read().decode('utf-8', errors='replace')
  else:
    return f'Logfile does not exist: {self._logfile_path}.'
The behavior depends on `self._num_launch_attempts`'s value: * <= self._config.launch_n_times_without_reboot -> Normal boot behavior. * > self._config.launch_n_times_without_reboot but <= self._config.launch_n_times_without_reinstall -> reboot (i.e. process is killed and started again). * > self._config.launch_n_times_without_reinstall -> reinstall (i.e. process is killed, emulator files are deleted and the process started again). """ logging.info('Attempt %r at launching the Android Emulator (%r)', self._num_launch_attempts, self.adb_device_name()) if self._launcher is not None: # If not the first time, then shutdown the emulator first. if ( self._emulator_stub is not None and self._num_launch_attempts > self._config.launch_n_times_without_reboot ): self._shutdown_emulator() # Subsequent attempts cause the emulator files to be reinstalled. if ( self._num_launch_attempts > self._config.launch_n_times_without_reinstall ): logging.info('Closing emulator (%r)', self.adb_device_name()) self._launcher.close() self._launcher = emulator_launcher.EmulatorLauncher( config=self._config.emulator_launcher, adb_controller_config=self._config.adb_controller, ) self._launcher.launch_emulator_process() # Establish grpc connection to emulator process. self._emulator_stub, self._snapshot_stub = self._connect_to_emulator( self._config.emulator_launcher.grpc_port ) # Confirm booted status. try: self._confirm_booted() except EmulatorCrashError: logging.exception('Failed to confirm booted status of emulator.') logging.info('Done booting the Android Emulator.') def load_state( self, request: state_pb2.LoadStateRequest ) -> state_pb2.LoadStateResponse: """Loads a state using the emulator's snapshotting mechanism. Args: request: The `LoadStateRequest`. In this case, `args` should be a dict containing the key 'snapshot_name', representing the name of the snapshot to load. If `request.args.snapshot_name` is `None`, a default snapshot name is used. 
    Returns:
      A response indicating whether the snapshot was successfully loaded.
      * If the snapshot was loaded successfully, the status will be `OK`.
      * If no snapshot of the given name was found, the status will be
        `NOT_FOUND`.
      * If an error occurred during the snapshot loading process, the status
        will be `ERROR` and the `error_message` field will be filled.
    """

    assert self._snapshot_stub is not None
    snapshot_name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME)
    # List every snapshot first so that an unknown name can be reported as
    # NOT_FOUND rather than as a load error.
    snapshot_list = self._snapshot_stub.ListSnapshots(
        snapshot_service_pb2.SnapshotFilter(
            statusFilter=snapshot_service_pb2.SnapshotFilter.LoadStatus.All
        )
    )
    if any(
        snapshot.snapshot_id == snapshot_name
        for snapshot in snapshot_list.snapshots
    ):
      snapshot_result = self._snapshot_stub.LoadSnapshot(
          snapshot_service_pb2.SnapshotPackage(snapshot_id=snapshot_name)
      )
      if snapshot_result.success:
        return state_pb2.LoadStateResponse(
            status=state_pb2.LoadStateResponse.Status.OK
        )
      else:
        return state_pb2.LoadStateResponse(
            status=state_pb2.LoadStateResponse.Status.ERROR,
            error_message=snapshot_result.err.decode('utf-8'),
        )
    else:
      return state_pb2.LoadStateResponse(
          status=state_pb2.LoadStateResponse.Status.NOT_FOUND
      )

  def save_state(
      self, request: state_pb2.SaveStateRequest
  ) -> state_pb2.SaveStateResponse:
    """Saves a state using the emulator's snapshotting mechanism.

    Args:
      request: The `SaveStateRequest`. In this case, `args` should be a dict
        containing the key 'snapshot_name', representing the name of the
        snapshot to save. If `args` has no 'snapshot_name' key, a default
        snapshot name is used.

    Returns:
      A response indicating whether the snapshot was successfully saved.
      * If the snapshot was saved successfully, the status will be `OK`.
      * If an error occurred during the snapshot saving process, the status
        will be `ERROR` and the `error_message` field will be filled.
    """

    assert self._snapshot_stub is not None
    snapshot_name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME)
    snapshot_result = self._snapshot_stub.SaveSnapshot(
        snapshot_service_pb2.SnapshotPackage(snapshot_id=snapshot_name)
    )
    if snapshot_result.success:
      return state_pb2.SaveStateResponse(
          status=state_pb2.SaveStateResponse.Status.OK
      )
    else:
      return state_pb2.SaveStateResponse(
          status=state_pb2.SaveStateResponse.Status.ERROR,
          error_message=snapshot_result.err.decode('utf-8'),
      )

  def _connect_to_emulator(
      self,
      grpc_port: int,
      timeout_sec: int = 100,
  ) -> tuple[
      emulator_controller_pb2_grpc.EmulatorControllerStub,
      snapshot_service_pb2_grpc.SnapshotServiceStub,
  ]:
    """Connects to an emulator and returns the corresponding stubs.

    A new channel to `localhost:<grpc_port>` is created and stored in
    `self._channel`; both returned stubs share it.

    Args:
      grpc_port: Port on localhost where the emulator's gRPC server listens.
      timeout_sec: How long to wait for the channel to become ready.

    Returns:
      A `(emulator_controller_stub, snapshot_stub)` tuple.

    Raises:
      EmulatorBootError: If the channel fails with an RPC error or does not
        become ready within `timeout_sec`.
    """

    logging.info('Creating gRPC channel to the emulator on port %r', grpc_port)
    port = f'localhost:{grpc_port}'
    # -1 lifts gRPC's message-size limits in both directions.
    options = [('grpc.max_send_message_length', -1),
               ('grpc.max_receive_message_length', -1)]
    creds = grpc.local_channel_credentials()

    try:
      self._channel = grpc.secure_channel(port, creds, options=options)
      # Block until the channel is usable, or time out.
      grpc.channel_ready_future(self._channel).result(timeout=timeout_sec)
    except (grpc.RpcError, grpc.FutureTimeoutError) as grpc_error:
      logging.exception('Failed to connect to the emulator.')
      raise EmulatorBootError(
          'Failed to connect to the emulator.') from grpc_error

    logging.info('Added gRPC channel for the Emulator on port %s', port)
    emulator_controller_stub = (
        emulator_controller_pb2_grpc.EmulatorControllerStub(self._channel)
    )
    snapshot_stub = snapshot_service_pb2_grpc.SnapshotServiceStub(self._channel)
    return emulator_controller_stub, snapshot_stub

  @_reconnect_on_grpc_error
  def _confirm_booted(self, startup_wait_time_sec: int = 300):
    """Waits until the emulator is fully booted."""

    assert (
        self._emulator_stub is not None
    ), 'Emulator stub has not been initialized yet.'
    start_time = time.time()
    deadline = start_time + startup_wait_time_sec
    success = False
    # Poll the emulator status every 5 seconds until it reports `booted` or the
    # deadline passes.
    while time.time() < deadline:
      emu_status = self._emulator_stub.getStatus(empty_pb2.Empty())
      logging.info('Waiting for emulator (%r) to start... (%rms)',
                   self.adb_device_name(), emu_status.uptime)
      if emu_status.booted:
        success = True
        break
      time.sleep(5.0)

    elapsed_time = time.time() - start_time
    if not success:
      raise EmulatorCrashError(
          f'The emulator failed to boot after {startup_wait_time_sec} seconds')

    logging.info('Done booting the emulator (in %f seconds).', elapsed_time)
    # Dump the emulator's own log into ours, one line per record.
    logging.info('********** Emulator logs **********')
    for line in self.get_logs().splitlines():
      logging.info(line)
    logging.info('******* End of emulator logs *******')
    logging.info('See the full emulator logs at %r', self._logfile_path)

  @_reconnect_on_grpc_error
  def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
    """Sends a touch event to be executed on the simulator.

    Args:
      touches: A list of touch events. Each element in the list corresponds to a
          single touch event. Each touch event tuple should have:
          0 x: The horizontal coordinate of this event.
          1 y: The vertical coordinate of this event.
          2 is_down: Whether the finger is touching or not the screen.
          3 identifier: Identifies a particular finger in a multitouch event.
    """

    assert (
        self._emulator_stub is not None
    ), 'Emulator stub has not been initialized yet.'
    # `is_down` is sent as the touch pressure: 1 when down, 0 when lifted.
    touch_events = [
        emulator_controller_pb2.Touch(
            x=t[0], y=t[1], pressure=int(t[2]), identifier=t[3])
        for t in touches
    ]
    self._emulator_stub.sendTouch(
        emulator_controller_pb2.TouchEvent(touches=touch_events))

  @_reconnect_on_grpc_error
  def send_key(self, keycode: np.int32, event_type: str) -> None:
    """Sends a key event to the emulator.

    Args:
      keycode: Code representing the desired key press in XKB format.
        See the emulator_controller_pb2 for details.
      event_type: Type of key event to be sent. Must be one of the names of
        `emulator_controller_pb2.KeyboardEvent.KeyEventType`.

    Raises:
      ValueError: If `event_type` is not a valid `KeyEventType` name.
    """

    event_types = emulator_controller_pb2.KeyboardEvent.KeyEventType.keys()
    if event_type not in event_types:
      raise ValueError(
          f'Event type must be one of {event_types} but is {event_type}.')

    assert (
        self._emulator_stub is not None
    ), 'Emulator stub has not been initialized yet.'
    self._emulator_stub.sendKey(
        emulator_controller_pb2.KeyboardEvent(
            codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
            eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType.Value(
                event_type
            ),
            keyCode=int(keycode),
        )
    )

  @_reconnect_on_grpc_error
  def _get_screenshot_impl(self) -> np.ndarray:
    """Fetches the latest screenshot from the emulator.

    Returns:
      A `uint8` array of shape (height, width, 3): the image is read as four
      bytes per pixel and the fourth channel is dropped.
    """

    assert (
        self._emulator_stub is not None
    ), 'Emulator stub has not been initialized yet.'
    assert self._image_format, 'ImageFormat has not been initialized yet.'
    image_proto = self._emulator_stub.getScreenshot(self._image_format)
    h, w = image_proto.format.height, image_proto.format.width
    image = np.frombuffer(image_proto.image, dtype='uint8', count=h * w * 4)
    image.shape = (h, w, 4)
    # Keep only the first three channels (presumably RGB of an RGBA image).
    return image[:, :, :3]

  @_reconnect_on_grpc_error
  def _shutdown_emulator(self):
    """Sends a signal to trigger emulator shutdown."""

    if self._emulator_stub is None:
      logging.info('Emulator (%r) is not up.', self.adb_device_name())
      return

    assert self._launcher is not None, 'Launcher is already down.'
logging.info('Shutting down the emulator (%r)...', self.adb_device_name()) self._emulator_stub.setVmState( emulator_controller_pb2.VmRunState( state=emulator_controller_pb2.VmRunState.RunState.SHUTDOWN)) self._launcher.confirm_shutdown() def close(self): super().close() if self._launcher is not None: self._shutdown_emulator() logging.info('Closing emulator (%r)', self.adb_device_name()) self._launcher.close() self._emulator_stub = None self._snapshot_stub = None if self._channel is not None: self._channel.close() super().close() ================================================ FILE: android_env/components/simulators/emulator/emulator_simulator_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for android_env.components.emulator_simulator.""" import builtins import os import time from unittest import mock from absl.testing import absltest from android_env.components import adb_call_parser from android_env.components import adb_controller from android_env.components import config_classes from android_env.components.simulators.emulator import emulator_launcher from android_env.components.simulators.emulator import emulator_simulator from android_env.proto import state_pb2 import grpc from PIL import Image import portpicker from android_env.proto import emulator_controller_pb2 from android_env.proto import emulator_controller_pb2_grpc from android_env.proto import snapshot_service_pb2 class EmulatorSimulatorTest(absltest.TestCase): def setUp(self): super().setUp() self.addCleanup(mock.patch.stopall) # Disable previous patches. self._adb_controller = mock.create_autospec(adb_controller.AdbController) self._adb_call_parser = mock.create_autospec(adb_call_parser.AdbCallParser) self._launcher = mock.create_autospec(emulator_launcher.EmulatorLauncher) self._launcher.logfile_path.return_value = 'logfile_path' self._emulator_stub = mock.create_autospec( emulator_controller_pb2_grpc.EmulatorControllerStub) self._grpc_channel = mock.create_autospec(grpc.Channel) mock.patch.object( grpc.aio, 'secure_channel', return_value=self._grpc_channel).start() mock.patch.object( grpc, 'secure_channel', return_value=self._grpc_channel).start() mock.patch.object( grpc, 'local_channel_credentials', return_value=self._grpc_channel).start() self._mock_future = mock.create_autospec(grpc.Future) mock.patch.object( grpc, 'channel_ready_future', return_value=self._mock_future).start() mock.patch.object(time, 'time', return_value=12345).start() mock.patch.object( adb_controller, 'AdbController', return_value=self._adb_controller).start() mock.patch.object( adb_call_parser, 'AdbCallParser', autospec=True, return_value=self._adb_call_parser).start() mock.patch.object( 
emulator_launcher, 'EmulatorLauncher', return_value=self._launcher).start() def test_adb_device_name_not_empty(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) self.assertNotEmpty(simulator.adb_device_name()) def test_logfile_path(self): """The log file's path should correspond to the one from the config.""" config = config_classes.EmulatorConfig( logfile_path='fake/logfile/path', emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) with mock.patch.object( os.path, 'exists', autospec=True, return_value=True ), mock.patch.object(builtins, 'open', autospec=True) as mock_open: mock_file = mock_open.return_value.__enter__.return_value mock_file.read.return_value = b'fake_logs' logs = simulator.get_logs() mock_open.assert_called_once_with('fake/logfile/path', 'rb') self.assertEqual(logs, 'fake_logs') @mock.patch.object(portpicker, 'is_port_free', return_value=True) def test_grpc_port(self, unused_mock_portpicker): launcher_config = config_classes.EmulatorLauncherConfig( tmp_dir=self.create_tempdir().full_path ) config = config_classes.EmulatorConfig( emulator_launcher=launcher_config, adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) self.assertEqual(launcher_config.grpc_port, 8554) @mock.patch.object(portpicker, 'is_port_free', return_value=False) def test_grpc_port_unavailable(self, unused_mock_portpicker): launcher_config = config_classes.EmulatorLauncherConfig( 
tmp_dir=self.create_tempdir().full_path ) config = config_classes.EmulatorConfig( emulator_launcher=launcher_config, adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) self.assertNotEqual(launcher_config.grpc_port, 8554) def test_launch_operation_order(self): """Makes sure that adb_controller is started before Emulator is launched.""" # Arrange. call_order = [] self._adb_controller.init_server.side_effect = lambda: call_order.append( 'init_server' ) self._launcher.launch_emulator_process.side_effect = ( lambda: call_order.append('launch_emulator_process') ) config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) # Act. simulator.launch() # The simulator should launch and not crash. # Assert. # The adb server should be initialized before launching the emulator. self.assertEqual(call_order, ['init_server', 'launch_emulator_process']) def test_close(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should launch and not crash. simulator.launch() # For whatever reason clients may want to close the EmulatorSimulator. # We just want to check that the simulator does not crash and/or leak # resources. 
simulator.close() def test_value_error_if_launch_attempt_params_incorrect(self): self.assertRaises( ValueError, emulator_simulator.EmulatorSimulator, config=config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), launch_n_times_without_reboot=2, launch_n_times_without_reinstall=1, ), ) def test_launch_attempt_reboot(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), launch_n_times_without_reboot=1, launch_n_times_without_reinstall=2, ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should launch and not crash. simulator.launch() self._launcher.launch_emulator_process.assert_called_once() self._launcher.reset_mock() # Launch attempt 2. simulator.launch() self._launcher.confirm_shutdown.assert_called_once() self._launcher.close.assert_not_called() self._launcher.launch_emulator_process.assert_called_once() def test_launch_attempt_reinstall_after_zero_attempts(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), launch_n_times_without_reboot=0, launch_n_times_without_reinstall=0, ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should not reboot or reinstall on its very first launch. simulator.launch() self._launcher.launch_emulator_process.assert_called_once() self._launcher.confirm_shutdown.assert_not_called() self._launcher.close.assert_not_called() # Every subsequent attempt should reboot and reinstall. 
self._launcher.reset_mock() simulator.launch() self._launcher.confirm_shutdown.assert_called_once() self._launcher.close.assert_called_once() # Now this should `close()`. self._launcher.launch_emulator_process.assert_called_once() def test_launch_attempt_reinstall(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), launch_n_times_without_reboot=1, launch_n_times_without_reinstall=2, ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should launch and not crash. simulator.launch() self._launcher.launch_emulator_process.assert_called_once() # Launch attempt 2. self._launcher.reset_mock() simulator.launch() self._launcher.confirm_shutdown.assert_called_once() self._launcher.close.assert_not_called() # Reboots don't `close()`. self._launcher.launch_emulator_process.assert_called_once() # Launch attempt 3. self._launcher.reset_mock() simulator.launch() self._launcher.confirm_shutdown.assert_called_once() self._launcher.close.assert_called_once() # Now this should `close()`. self._launcher.launch_emulator_process.assert_called_once() def test_get_screenshot(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should launch and not crash. 
simulator.launch() simulator._emulator_stub.getScreenshot = mock.MagicMock( return_value=emulator_controller_pb2.Image( format=emulator_controller_pb2.ImageFormat(width=5678, height=1234), image=Image.new('RGBA', (1234, 5678)).tobytes(), timestampUs=123)) screenshot = simulator.get_screenshot() # The screenshot should have the same screen dimensions as reported by ADB # and it should have 3 channels (RGB). self.assertEqual(screenshot.shape, (1234, 5678, 3)) def test_load_state(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should launch and not crash. simulator.launch() with mock.patch.object( simulator, '_snapshot_stub', create_autospec=True ) as mock_snapshot_stub: snapshot_list = snapshot_service_pb2.SnapshotList() snapshot_list.snapshots.add(snapshot_id='snapshot_name_foo') snapshot_list.snapshots.add(snapshot_id='snapshot_name_bar') mock_snapshot_stub.ListSnapshots.return_value = snapshot_list mock_snapshot_stub.LoadSnapshot.return_value = ( snapshot_service_pb2.SnapshotPackage(success=True) ) load_response = simulator.load_state( request=state_pb2.LoadStateRequest( args={'snapshot_name': 'snapshot_name_foo'} ) ) self.assertEqual( load_response.status, state_pb2.LoadStateResponse.Status.OK ) load_response = simulator.load_state( request=state_pb2.LoadStateRequest( args={'snapshot_name': 'snapshot_name_baz'} ) ) self.assertEqual( load_response.status, state_pb2.LoadStateResponse.Status.NOT_FOUND ) mock_snapshot_stub.LoadSnapshot.return_value = ( snapshot_service_pb2.SnapshotPackage(success=False, err=b'error') ) load_response = simulator.load_state( request=state_pb2.LoadStateRequest( args={'snapshot_name': 'snapshot_name_bar'} ) ) self.assertEqual( load_response.status, 
state_pb2.LoadStateResponse.Status.ERROR ) self.assertEqual(load_response.error_message, 'error') def test_save_state(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should launch and not crash. simulator.launch() with mock.patch.object( simulator, '_snapshot_stub', create_autospec=True ) as mock_snapshot_stub: mock_snapshot_stub.SaveSnapshot.return_value = ( snapshot_service_pb2.SnapshotPackage(success=True) ) save_response = simulator.save_state( request=state_pb2.SaveStateRequest( args={'snapshot_name': 'snapshot_name_foo'} ) ) self.assertEqual( save_response.status, state_pb2.SaveStateResponse.Status.OK ) mock_snapshot_stub.SaveSnapshot.return_value = ( snapshot_service_pb2.SnapshotPackage(success=False, err=b'error') ) save_response = simulator.save_state( request=state_pb2.SaveStateRequest( args={'snapshot_name': 'snapshot_name_bar'} ) ) self.assertEqual( save_response.status, state_pb2.SaveStateResponse.Status.ERROR ) self.assertEqual(save_response.error_message, 'error') def test_send_touch(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should launch and not crash. 
simulator.launch() simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None) simulator.send_touch([(123, 456, True, 0), (135, 246, True, 1)]) simulator.send_touch([(1, 2, True, 0), (3, 4, True, 1)]) simulator.send_touch([(321, 654, False, 0), (531, 642, False, 1)]) simulator._emulator_stub.sendTouch.assert_has_calls([ mock.call( emulator_controller_pb2.TouchEvent(touches=[{ 'x': 123, 'y': 456, 'pressure': 1 }, { 'x': 135, 'y': 246, 'pressure': 1, 'identifier': 1 }])), mock.call( emulator_controller_pb2.TouchEvent(touches=[{ 'x': 1, 'y': 2, 'pressure': 1 }, { 'x': 3, 'y': 4, 'pressure': 1, 'identifier': 1 }])), mock.call( emulator_controller_pb2.TouchEvent(touches=[{ 'x': 321, 'y': 654, 'pressure': 0 }, { 'x': 531, 'y': 642, 'pressure': 0, 'identifier': 1 }])), ]) def test_send_key(self): config = config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( grpc_port=1234, tmp_dir=self.create_tempdir().full_path ), adb_controller=config_classes.AdbControllerConfig( adb_path='/my/adb', adb_server_port=5037, ), ) simulator = emulator_simulator.EmulatorSimulator(config) # The simulator should launch and not crash. 
simulator.launch() simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None) simulator.send_key(123, 'keydown') simulator.send_key(321, 'keydown') simulator.send_key(321, 'keyup') simulator.send_key(123, 'keyup') simulator.send_key(321, 'keypress') simulator.send_key(123, 'keypress') simulator._emulator_stub.sendKey.assert_has_calls([ mock.call( emulator_controller_pb2.KeyboardEvent( codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType .keydown, keyCode=123, )), mock.call( emulator_controller_pb2.KeyboardEvent( codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType .keydown, keyCode=321, )), mock.call( emulator_controller_pb2.KeyboardEvent( codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType .keyup, keyCode=321, )), mock.call( emulator_controller_pb2.KeyboardEvent( codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType .keyup, keyCode=123, )), mock.call( emulator_controller_pb2.KeyboardEvent( codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType .keypress, keyCode=321, )), mock.call( emulator_controller_pb2.KeyboardEvent( codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB, eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType .keypress, keyCode=123, )) ]) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/components/simulators/fake/__init__.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: android_env/components/simulators/fake/fake_simulator.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Fake Simulator for testing AndroidEnv infrastructure.""" import random import threading import time from absl import logging from android_env.components import adb_controller from android_env.components import config_classes from android_env.components import log_stream from android_env.components.simulators import base_simulator import numpy as np class FakeStream: """This class simulates the logs coming from ADB.""" def __init__(self): self._values = [ '', self._make_stdout('reward: 0.5'), self._make_stdout('reward: 1.0'), self._make_stdout('extra: my_extra [1.0]'), self._make_stdout('episode end'), ] self._kill = False self._lock = threading.Lock() def _make_stdout(self, data): """Returns a valid log output with given data as message.""" return f' 1553110400.424 5583 5658 D Tag: {data}' def kill(self): self._kill = True def __iter__(self): while True: if self._kill: return else: with self._lock: next_value = random.choices( self._values, weights=[0.49, 0.15, 0.15, 0.15, 0.01], k=1)[0] time.sleep(0.1) yield next_value class FakeLogStream(log_stream.LogStream): """FakeLogStream class that wraps a FakeStream.""" def __init__(self): super().__init__(verbose=False) self.stream = FakeStream() def _get_stream_output(self): return self.stream def stop_stream(self): self.stream.kill() class FakeAdbController(adb_controller.AdbController): """Fake adb controller for FakeSimulator.""" def execute_command( self, args: list[str], timeout: float | None = None, device_specific: bool = True, ) -> bytes: """Returns fake output for adb commands.""" del timeout, device_specific # Fake "service is ready" output. if args[:3] == ['shell', 'service', 'check']: return f'Service {args[-1]}: found'.encode('utf-8') # Fake dumpsys output for getting orientation. if args == ['shell', 'dumpsys', 'input']: return b' SurfaceOrientation: 0' # app_screen_checker: fake_task expects 'fake_activity'. 
if args[:4] == ['shell', 'am', 'stack', 'list']: return (b'taskId=0 fake_activity visible=true ' b'topActivity=ComponentInfo{fake_activity}') return b'fake output' class FakeSimulator(base_simulator.BaseSimulator): """FakeSimulator class.""" def __init__(self, config: config_classes.FakeSimulatorConfig): """FakeSimulator class that can replace EmulatorSimulator in AndroidEnv.""" super().__init__(config) self._screen_dimensions = np.array(config.screen_dimensions) logging.info('Created FakeSimulator.') def get_logs(self) -> str: return 'FakeSimulator: fake logs' def adb_device_name(self) -> str: return 'fake_simulator' def create_adb_controller(self): return FakeAdbController(config_classes.AdbControllerConfig()) def create_log_stream(self) -> log_stream.LogStream: return FakeLogStream() def _launch_impl(self) -> None: pass def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None: del touches def send_key(self, keycode: np.int32, event_type: str) -> None: del keycode, event_type def _get_screenshot_impl(self) -> np.ndarray: return np.random.randint( low=0, high=255, size=(*self._screen_dimensions, 3), dtype=np.uint8) ================================================ FILE: android_env/components/simulators/fake/fake_simulator_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for fake_simulator.""" import re from absl.testing import absltest from android_env.components import config_classes from android_env.components.simulators.fake import fake_simulator import numpy as np class FakeSimulatorTest(absltest.TestCase): def test_device_name(self): simulator = fake_simulator.FakeSimulator( config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) ) self.assertEqual(simulator.adb_device_name(), 'fake_simulator') def test_launch_close(self): # The simulator should launch and not crash. simulator = fake_simulator.FakeSimulator( config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) ) simulator.launch() # Closing the simulator should also not crash. simulator.close() def test_get_screenshot(self): simulator = fake_simulator.FakeSimulator( config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) ) simulator.launch() screenshot = simulator.get_screenshot() np.testing.assert_equal(screenshot.shape, [320, 480, 3]) np.testing.assert_equal(screenshot.dtype, np.uint8) def test_log_stream(self): simulator = fake_simulator.FakeSimulator( config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) ) simulator.launch() log_stream = simulator.create_log_stream() # Start yielding lines from LogStream. 
log_stream.resume_stream() lines = [ '', ' 1553110400.424 5583 5658 D Tag: reward: 0.5', ' 1553110400.424 5583 5658 D Tag: reward: 1.0', ' 1553110400.424 5583 5658 D Tag: extra: my_extra [1.0]', ' 1553110400.424 5583 5658 D Tag: episode end', ] for i, line in enumerate(log_stream.get_stream_output()): self.assertIn(line, lines) if i > 10: break def test_adb_output(self): simulator = fake_simulator.FakeSimulator( config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) ) simulator.launch() adb_controller = simulator.create_adb_controller() line = adb_controller.execute_command(['shell', 'dumpsys', 'input']) line = line.decode('utf-8') matches = re.match(r'\s+SurfaceOrientation:\s+(\d)', line) self.assertIsNotNone(matches) orientation = matches.group(1) self.assertEqual(orientation, '0') line = adb_controller.execute_command(['shell', 'service', 'check', 'foo']) line = line.decode('utf-8') self.assertEqual(line, 'Service foo: found') line = adb_controller.execute_command(['shell', 'am', 'stack', 'list']) line = line.decode('utf-8') self.assertEqual(line, 'taskId=0 fake_activity visible=true ' 'topActivity=ComponentInfo{fake_activity}') def test_send_touch(self): simulator = fake_simulator.FakeSimulator( config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) ) simulator.launch() simulator.send_touch([(0, 1, True, 0)]) simulator.send_touch([(0, 1, False, 0)]) # No assertions, we just want to ensure that `send_touch()` can be called # without crashing anything. def test_send_key(self): simulator = fake_simulator.FakeSimulator( config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480)) ) simulator.launch() simulator.send_key(np.int32(123), 'keydown') simulator.send_key(np.int32(123), 'keyup') simulator.send_key(np.int32(123), 'keypress') # No assertions, we just want to ensure that `send_key()` can be called # without crashing anything. 
if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/components/specs.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Base specs for AndroidEnv.""" from android_env.components import action_type from android_env.proto import task_pb2 from dm_env import specs import numpy as np _PROTO_DTYPE_TO_NUMPY_DTYPE = { task_pb2.ArraySpec.DataType.FLOAT: np.float32, task_pb2.ArraySpec.DataType.DOUBLE: np.float64, task_pb2.ArraySpec.DataType.INT8: np.int8, task_pb2.ArraySpec.DataType.INT16: np.int16, task_pb2.ArraySpec.DataType.INT32: np.int32, task_pb2.ArraySpec.DataType.INT64: np.int64, task_pb2.ArraySpec.DataType.UINT8: np.uint8, task_pb2.ArraySpec.DataType.UINT16: np.uint16, task_pb2.ArraySpec.DataType.UINT32: np.uint32, task_pb2.ArraySpec.DataType.UINT64: np.uint64, task_pb2.ArraySpec.DataType.BOOL: np.bool_, task_pb2.ArraySpec.DataType.STRING_U1: np.dtype(('U1')), task_pb2.ArraySpec.DataType.STRING_U16: np.dtype((' dict[str, specs.Array]: """Default action spec for AndroidEnv. Args: num_fingers: Number of virtual fingers of the agent. enable_key_events: Whether keyboard key events are enabled. Returns: A dict of action specs, each item corresponding to a virtual finger. 
action_type: An integer of type ActionType: TOUCH=0, LIFT=1, REPEAT=2 touch_position: Position [x, y] of the touch action, where x, y are float values between 0.0 and 1.0 corresponding to the relative position on the screen. IGNORED when (action_type != ActionType.TOUCH). keycode: code representing the desired key press in XKB format. See the emulator_controller_pb2 for details. action_type_i: Action type for additional fingers (i>1). touch_position_i: Touch position for additional fingers (i>1). """ num_actions = len(action_type.ActionType) if enable_key_events else 3 action_spec = { 'action_type': specs.DiscreteArray(num_values=num_actions, name='action_type'), 'touch_position': specs.BoundedArray( shape=(2,), dtype=np.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0], name='touch_position'), } for i in range(2, num_fingers + 1): action_spec.update({ f'action_type_{i}': specs.DiscreteArray( num_values=len(action_type.ActionType), name=f'action_type_{i}'), f'touch_position_{i}': specs.BoundedArray( shape=(2,), dtype=np.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0], name=f'touch_position_{i}'), }) if enable_key_events: action_spec['keycode'] = specs.DiscreteArray( num_values=(1 << 16) - 1, name='keycode') return action_spec def base_observation_spec(height: int, width: int) -> dict[str, specs.Array]: """Default observation spec for AndroidEnv. Args: height: Height of the device screen in pixels. width: Width of the device screen in pixels. Returns: pixels: Spec for the RGB screenshot of the device. Has shape (H, W, 3) timedelta: Spec for time delta since the last observation (in microseconds). The first timestep immediately after reset() will have this value set to 0. 
orientation: Spec for the latest orientation in a one-hot representation: [1, 0, 0, 0]: PORTRAIT (0 degrees) [0, 1, 0, 0]: LANDSCAPE (90 degrees clockwise) [0, 0, 1, 0]: PORTRAIT (180 degrees) ("upside down") [0, 0, 0, 1]: LANDSCAPE (270 degrees clockwise) """ return { 'pixels': specs.BoundedArray( shape=(height, width, 3), dtype=np.uint8, name='pixels', minimum=0, maximum=255), 'timedelta': specs.Array(shape=(), dtype=np.int64, name='timedelta'), 'orientation': specs.BoundedArray( shape=np.array([4]), dtype=np.uint8, name='orientation', minimum=0, maximum=1), } ================================================ FILE: android_env/components/specs_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for specs.py.""" from absl.testing import absltest from absl.testing import parameterized from android_env.components import specs from android_env.proto import task_pb2 from dm_env import specs as dm_env_specs import numpy as np class SpecsTest(parameterized.TestCase): def test_base_action_spec(self): action_spec = specs.base_action_spec(num_fingers=1) for spec in action_spec.values(): self.assertIsInstance(spec, dm_env_specs.Array) self.assertEqual(action_spec['action_type'].shape, ()) self.assertEqual(action_spec['action_type'].dtype, np.int32) self.assertEqual(action_spec['touch_position'].shape, (2,)) self.assertEqual(action_spec['touch_position'].dtype, np.float32) def test_base_action_spec_with_key_events(self): action_spec = specs.base_action_spec(num_fingers=1, enable_key_events=True) for spec in action_spec.values(): self.assertIsInstance(spec, dm_env_specs.Array) self.assertEqual(action_spec['action_type'].shape, ()) self.assertEqual(action_spec['action_type'].dtype, np.int32) self.assertEqual(action_spec['touch_position'].shape, (2,)) self.assertEqual(action_spec['touch_position'].dtype, np.float32) self.assertEqual(action_spec['keycode'].shape, ()) self.assertEqual(action_spec['keycode'].dtype, np.int32) def test_base_action_spec_multitouch(self): action_spec = specs.base_action_spec(num_fingers=3) self.assertLen(action_spec.keys(), 6) for spec in action_spec.values(): self.assertIsInstance(spec, dm_env_specs.Array) self.assertEqual(action_spec['action_type'].shape, ()) self.assertEqual(action_spec['action_type'].dtype, np.int32) self.assertEqual(action_spec['touch_position'].shape, (2,)) self.assertEqual(action_spec['touch_position'].dtype, np.float32) self.assertEqual(action_spec['action_type_2'].shape, ()) self.assertEqual(action_spec['action_type_2'].dtype, np.int32) self.assertEqual(action_spec['touch_position_2'].shape, (2,)) self.assertEqual(action_spec['touch_position_2'].dtype, np.float32) 
self.assertEqual(action_spec['action_type_3'].shape, ()) self.assertEqual(action_spec['action_type_3'].dtype, np.int32) self.assertEqual(action_spec['touch_position_3'].shape, (2,)) self.assertEqual(action_spec['touch_position_3'].dtype, np.float32) @parameterized.parameters( (480, 320), (100, 100), (1440, 1960), ) def test_base_observation_spec(self, height, width): observation_spec = specs.base_observation_spec(height, width) for spec in observation_spec.values(): self.assertIsInstance(spec, dm_env_specs.Array) self.assertEqual(observation_spec['pixels'].shape, (height, width, 3)) self.assertEqual(observation_spec['pixels'].dtype, np.uint8) self.assertEqual(observation_spec['timedelta'].shape, ()) self.assertEqual(observation_spec['timedelta'].dtype, np.int64) self.assertEqual(observation_spec['orientation'].shape, (4,)) self.assertEqual(observation_spec['orientation'].dtype, np.uint8) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/components/task_manager.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""TaskManager handles all events and information related to the task.""" import ast from collections.abc import Callable, Iterable, Sequence import copy import datetime import itertools import json import re import threading import time from typing import Any from absl import logging from android_env.components import adb_call_parser as adb_call_parser_lib from android_env.components import app_screen_checker from android_env.components import config_classes from android_env.components import dumpsys_thread from android_env.components import log_stream as log_stream_lib from android_env.components import logcat_thread from android_env.components import setup_step_interpreter from android_env.proto import task_pb2 import dm_env import numpy as np class TaskManager: """Handles all events and information related to the task.""" def __init__( self, task: task_pb2.Task, config: config_classes.TaskManagerConfig | None = None, ): """Controls task-relevant events and information. Args: task: A task proto defining the RL task. config: Configuration for this instance. """ self._task = task self._config = config or config_classes.TaskManagerConfig() self._lock = threading.Lock() self._logcat_thread = None self._dumpsys_thread = None self._setup_step_interpreter = None # Initialize stats. self._stats = { 'episode_steps': 0, 'reset_count_step_timeout': 0, 'reset_count_user_exited': 0, 'reset_count_episode_end': 0, 'reset_count_max_duration_reached': 0, 'restart_count_max_bad_states': 0, 'task_updates': 0, } # Initialize internal state self._task_start_time = None self._bad_state_counter = 0 self._is_bad_episode = False self._latest_values = { 'reward': 0.0, 'score': 0.0, 'extra': {}, 'episode_end': False, } logging.info('Task config: %s', self._task) def stats(self) -> dict[str, Any]: """Returns a dictionary of stats. This method is expected to be called after setup_task() has been called. 
""" output = copy.deepcopy(self._stats) if self._setup_step_interpreter is not None: output.update(self._setup_step_interpreter.stats()) return output def setup_task(self) -> None: """Performs one-off task setup..""" self._setup_step_interpreter.interpret(self._task.setup_steps) def stop(self) -> None: """Suspends task processing.""" n_tries = 3 for i in range(n_tries): try: self._stop_logcat_thread() break except: # pylint: disable=bare-except logging.exception( 'Failed to stop logcat thread [%d/%d]. Continuing.', i + 1, n_tries ) time.sleep(1) def start( self, adb_call_parser_factory: Callable[[], adb_call_parser_lib.AdbCallParser], log_stream: log_stream_lib.LogStream, ) -> None: """Starts task processing.""" self._start_logcat_thread(log_stream=log_stream) self._logcat_thread.resume() self._start_dumpsys_thread(adb_call_parser_factory()) self._start_setup_step_interpreter(adb_call_parser_factory()) def reset_task(self) -> None: """Resets a task for a new run.""" self._logcat_thread.pause() self._setup_step_interpreter.interpret(self._task.reset_steps) self._logcat_thread.resume() # Reset some other variables. 
if not self._is_bad_episode: self._bad_state_counter = 0 self._is_bad_episode = False self._task_start_time = datetime.datetime.now() with self._lock: self._latest_values = { 'reward': 0.0, 'score': 0.0, 'extra': {}, 'episode_end': False, } def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep: """Performs one RL step.""" self._stats['episode_steps'] = 0 self._logcat_thread.line_ready().wait() with self._lock: extras = self._get_current_extras() observation['extras'] = extras return dm_env.TimeStep( step_type=dm_env.StepType.FIRST, reward=0.0, discount=0.0, observation=observation, ) def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep: """Performs one RL step.""" self._stats['episode_steps'] += 1 self._logcat_thread.line_ready().wait() with self._lock: reward = self._get_current_reward() extras = self._get_current_extras() transition_fn = self._determine_transition_fn() observation['extras'] = extras return transition_fn(reward=reward, observation=observation) def _get_current_reward(self) -> float: """Returns total reward accumulated since the last step.""" reward = self._latest_values['reward'] self._latest_values['reward'] = 0.0 return reward def _get_current_extras(self) -> dict[str, Any]: """Returns task extras accumulated since the last step.""" extras = {} for name, values in self._latest_values['extra'].items(): extras[name] = np.stack(values) self._latest_values['extra'] = {} return extras def _determine_transition_fn(self) -> Callable[..., dm_env.TimeStep]: """Determines the type of RL transition will be used.""" # Check if user existed the task if self._dumpsys_thread.check_user_exited(): self._increment_bad_state() self._stats['reset_count_user_exited'] += 1 logging.warning('User exited the task. 
Truncating the episode.') logging.info('************* END OF EPISODE *************') return dm_env.truncation # Check if episode has ended if self._latest_values['episode_end']: self._stats['reset_count_episode_end'] += 1 logging.info('End of episode from logcat! Ending episode.') return dm_env.termination # Check if step limit or time limit has been reached if self._task.max_episode_steps > 0: if self._stats['episode_steps'] > self._task.max_episode_steps: self._stats['reset_count_max_duration_reached'] += 1 logging.info( 'Maximum task duration (%r steps) reached. Truncating the episode.', self._task.max_episode_steps, ) return dm_env.truncation if self._task.max_episode_sec > 0.0: task_duration = datetime.datetime.now() - self._task_start_time max_episode_sec = self._task.max_episode_sec if task_duration > datetime.timedelta(seconds=int(max_episode_sec)): self._stats['reset_count_max_duration_reached'] += 1 logging.info( 'Maximum task duration (%r sec) reached. Truncating the episode.', max_episode_sec, ) return dm_env.truncation return dm_env.transition def _start_setup_step_interpreter( self, adb_call_parser: adb_call_parser_lib.AdbCallParser ): self._setup_step_interpreter = setup_step_interpreter.SetupStepInterpreter( adb_call_parser=adb_call_parser ) def _start_logcat_thread(self, log_stream: log_stream_lib.LogStream): log_stream.set_log_filters(list(self._task.log_parsing_config.filters)) self._logcat_thread = logcat_thread.LogcatThread(log_stream=log_stream) for event_listener in self._logcat_listeners(): self._logcat_thread.add_event_listener(event_listener) def _start_dumpsys_thread( self, adb_call_parser: adb_call_parser_lib.AdbCallParser ): self._dumpsys_thread = dumpsys_thread.DumpsysThread( app_screen_checker=app_screen_checker.AppScreenChecker( adb_call_parser=adb_call_parser, expected_app_screen=self._task.expected_app_screen, ), check_frequency=self._config.dumpsys_check_frequency, 
max_failed_current_activity=self._config.max_failed_current_activity, ) def _stop_logcat_thread(self): if self._logcat_thread is not None: self._logcat_thread.kill() self._logcat_thread = None def _increment_bad_state(self) -> None: """Increments the bad state counter. Bad states are errors that shouldn't happen and that trigger an episode reset. If enough bad states have been seen consecutively, we restart the simulation in the hope of returning the simulation to a good state. """ logging.warning('Bad state detected.') if self._config.max_bad_states: self._is_bad_episode = True self._bad_state_counter += 1 logging.warning('Bad state counter: %d.', self._bad_state_counter) if self._bad_state_counter >= self._config.max_bad_states: logging.error('Too many consecutive bad states. Restarting simulator.') self._stats['restart_count_max_bad_states'] += 1 self._should_restart = True else: logging.warning('Max bad states not set, bad states will be ignored.') def _logcat_listeners(self) -> Iterable[logcat_thread.EventListener]: """Creates list of EventListeners for logcat thread.""" # Defaults to 'a^' since that regex matches no string by definition. 
regexps = self._task.log_parsing_config.log_regexps return itertools.chain( self._reward_listeners(regexps), self._reward_event_listeners(regexps), self._score_listeners(regexps), self._episode_end_listeners(regexps), self._extras_listeners(regexps), self._json_extras_listeners(regexps), ) def _reward_listeners( self, regexps: task_pb2.LogParsingConfig.LogRegexps ) -> Iterable[logcat_thread.EventListener]: """Creates an iterable of reward listeners.""" def _reward_handler(event: re.Pattern[str], match: re.Match[str]): del event reward = float(match.group(1)) with self._lock: self._latest_values['reward'] += reward for regexp in regexps.reward: yield logcat_thread.EventListener( regexp=re.compile(regexp or 'a^'), handler_fn=_reward_handler ) def _reward_event_listeners( self, regexps: task_pb2.LogParsingConfig.LogRegexps ) -> Iterable[logcat_thread.EventListener]: """Creates an iterable of reward event listeners.""" for reward_event in regexps.reward_event: def get_reward_event_handler(reward): def _reward_event_handler(event: re.Pattern[str], match: re.Match[str]): del event, match with self._lock: self._latest_values['reward'] += reward return _reward_event_handler yield logcat_thread.EventListener( regexp=re.compile(reward_event.event or 'a^'), handler_fn=get_reward_event_handler(reward_event.reward), ) def _score_listeners( self, regexps: task_pb2.LogParsingConfig.LogRegexps ) -> Iterable[logcat_thread.EventListener]: """Creates an iterable of score listeners.""" def _score_handler(event: re.Pattern[str], match: re.Match[str]): del event current_score = float(match.group(1)) with self._lock: current_reward = current_score - self._latest_values['score'] self._latest_values['score'] = current_score self._latest_values['reward'] += current_reward yield logcat_thread.EventListener( regexp=re.compile(regexps.score or 'a^'), handler_fn=_score_handler ) def _episode_end_listeners( self, regexps: task_pb2.LogParsingConfig.LogRegexps ) -> 
Iterable[logcat_thread.EventListener]: """Creates an iterable of episode end listeners.""" def _episode_end_handler(event: re.Pattern[str], match: re.Match[str]): del event, match with self._lock: self._latest_values['episode_end'] = True for regexp in regexps.episode_end: yield logcat_thread.EventListener( regexp=re.compile(regexp or 'a^'), handler_fn=_episode_end_handler ) def _process_extra(self, extra_name: str, extra: Sequence[int | float]): extra = np.array(extra) with self._lock: latest_extras = self._latest_values['extra'] if extra_name in latest_extras: # If latest extra is not flushed, append. if ( len(latest_extras[extra_name]) >= self._config.extras_max_buffer_size ): latest_extras[extra_name].pop(0) latest_extras[extra_name].append(extra) else: latest_extras[extra_name] = [extra] self._latest_values['extra'] = latest_extras def _extras_listeners( self, regexps: task_pb2.LogParsingConfig.LogRegexps ) -> Iterable[logcat_thread.EventListener]: """Creates an iterable of extras listeners.""" def _extras_handler(event: re.Pattern[str], match: re.Match[str]): del event extra_name = match.group('name') extra = match.group('extra') if extra: try: extra = ast.literal_eval(extra) except ( ValueError, TypeError, SyntaxError, MemoryError, RecursionError, ): logging.exception('Could not parse extra: %s', extra) # Don't try to process the extra as text; that would probably crash. return else: # No extra value provided for boolean extra. Setting value to True. 
extra = 1 self._process_extra(extra_name, extra) for regexp in regexps.extra: yield logcat_thread.EventListener( regexp=re.compile(regexp or 'a^'), handler_fn=_extras_handler ) def _json_extras_listeners( self, regexps: task_pb2.LogParsingConfig.LogRegexps ) -> Iterable[logcat_thread.EventListener]: """Creates an iterable of JSON extras listeners.""" def _json_extras_handler(event: re.Pattern[str], match: re.Match[str]): del event extra_data = match.group('json_extra') try: extra = dict(json.loads(extra_data)) except ValueError: logging.error('JSON string could not be parsed: %s', extra_data) return for extra_name, extra_value in extra.items(): self._process_extra(extra_name, extra_value) for regexp in regexps.json_extra: yield logcat_thread.EventListener( regexp=re.compile(regexp or 'a^'), handler_fn=_json_extras_handler ) ================================================ FILE: android_env/components/task_manager_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for android_env.components.task_manager.py.""" import json from unittest import mock from absl.testing import absltest from android_env.components import adb_call_parser as adb_call_parser_lib from android_env.components import dumpsys_thread from android_env.components import log_stream from android_env.components import logcat_thread from android_env.components import setup_step_interpreter from android_env.components import task_manager from android_env.proto import task_pb2 import numpy as np class TaskManagerTest(absltest.TestCase): def setUp(self): super().setUp() self.addCleanup(mock.patch.stopall) # Disable previous patches. self._setup_step_interpreter = mock.create_autospec( setup_step_interpreter.SetupStepInterpreter) self._dumpsys_thread = mock.create_autospec(dumpsys_thread.DumpsysThread) self._logcat_thread = mock.create_autospec(logcat_thread.LogcatThread) self._log_stream = mock.create_autospec(log_stream.LogStream) mock.patch.object( setup_step_interpreter, 'SetupStepInterpreter', return_value=self._setup_step_interpreter).start() mock.patch.object( dumpsys_thread, 'DumpsysThread', return_value=self._dumpsys_thread).start() mock.patch.object( logcat_thread, 'LogcatThread', return_value=self._logcat_thread).start() mock.patch.object( log_stream, 'LogStream', return_value=self._log_stream).start() def test_start(self): task_mgr = task_manager.TaskManager(task=task_pb2.Task()) adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) self.assertIsNotNone(task_mgr._logcat_thread) self.assertIsNotNone(task_mgr._dumpsys_thread) self.assertIsNotNone(task_mgr._setup_step_interpreter) def test_setup_task(self): task_mgr = task_manager.TaskManager(task=task_pb2.Task()) adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) task_mgr.setup_task() 
self._setup_step_interpreter.interpret.assert_called_once() def test_step_count(self): task_mgr = task_manager.TaskManager(task=task_pb2.Task()) adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) task_mgr.setup_task() task_mgr.rl_reset(observation={}) self.assertEqual(task_mgr.stats()['episode_steps'], 0) task_mgr.rl_step(observation={}) self.assertEqual(task_mgr.stats()['episode_steps'], 1) task_mgr.rl_step(observation={}) self.assertEqual(task_mgr.stats()['episode_steps'], 2) task_mgr.rl_reset(observation={}) self.assertEqual(task_mgr.stats()['episode_steps'], 0) def test_get_current_reward(self): # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` # right away. def my_add_ev_listener(event_listener: logcat_thread.EventListener): # Check that the event matches what's expected. match = event_listener.regexp.match('Reward: 123.0') if match is None: # Ignore events that are not rewards. return event_listener.handler_fn(event_listener.regexp, match) task = task_pb2.Task() task.log_parsing_config.log_regexps.reward.extend([ '^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$' ]) task_mgr = task_manager.TaskManager(task=task) self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) task_mgr.setup_task() timestep = task_mgr.rl_step( observation={ 'pixels': np.array([1, 2, 3]), }) self.assertEqual(timestep.reward, 123.0) np.testing.assert_equal(timestep.observation['pixels'], np.array([1, 2, 3])) def test_reward_event(self): # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` # right away. def my_add_ev_listener(event_listener: logcat_thread.EventListener): # Check that the event matches what's expected. 
match_1 = event_listener.regexp.match('foo_1') match_2 = event_listener.regexp.match('foo_2') match_3 = event_listener.regexp.match('Reward: 2.0') if match_1: event_listener.handler_fn(event_listener.regexp, match_1) if match_2: event_listener.handler_fn(event_listener.regexp, match_2) if match_3: event_listener.handler_fn(event_listener.regexp, match_3) task = task_pb2.Task() reward_event_1 = task_pb2.LogParsingConfig.LogRegexps.RewardEvent( event='foo_1', reward=5.0) reward_event_2 = task_pb2.LogParsingConfig.LogRegexps.RewardEvent( event='foo_2', reward=-1.0) task.log_parsing_config.log_regexps.reward_event.extend( [reward_event_1, reward_event_2]) task.log_parsing_config.log_regexps.reward.extend( ['^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$']) task_mgr = task_manager.TaskManager(task=task) self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) task_mgr.setup_task() timestep = task_mgr.rl_step( observation={ 'pixels': np.array([1, 2, 3]), }) self.assertEqual(timestep.reward, 6.0) def test_get_current_reward_via_score(self): # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` # right away. def my_add_ev_listener(event_listener: logcat_thread.EventListener): # Check that the event matches what's expected. event = event_listener.regexp match = event.match('score: 200.0') if match is None: # Ignore events that are not scores. return event_listener.handler_fn(event, match) # Scores are accumulated by their differences, so a subsequent lower score # means that the final reward decreases. 
match = event.match('score: 185') event_listener.handler_fn(event, match) task = task_pb2.Task() task.log_parsing_config.log_regexps.score = ( '^score: ([-+]?[0-9]*\\.?[0-9]*)$') task_mgr = task_manager.TaskManager(task=task) self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) task_mgr.setup_task() timestep = task_mgr.rl_step( observation={ 'pixels': np.array([1, 2, 3]), }) self.assertEqual(timestep.reward, 185.0) def test_get_current_extras(self): # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` # right away. def my_add_ev_listener(event_listener: logcat_thread.EventListener): # Check that the event matches what's expected. event = event_listener.regexp match = event.match('extra: some_extra [1, 2]') if match is None: # Ignore events that are not extras. return # Emit events. fn = event_listener.handler_fn fn(event, event.match('extra: an_extra [1, 2, 3]')) fn(event, event.match('extra: an_extra [4, 5, 6]')) fn(event, event.match('extra: another_extra 0.5')) fn(event, event.match('extra: multi_dimension_extra [[9,8,7],[6,5,4]]')) fn(event, event.match('extra: boolean_extra')) # Setup the task and trigger the listener. task = task_pb2.Task() task.log_parsing_config.log_regexps.extra.extend([ '^extra: (?P[^ ]*)[ ]?(?P.*)$' ]) task_mgr = task_manager.TaskManager(task=task) self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) task_mgr.setup_task() timestep = task_mgr.rl_step( observation={ 'pixels': np.array([1, 2, 3]), }) # Check expectations. 
self.assertIn('extras', timestep.observation) extras = timestep.observation['extras'] np.testing.assert_almost_equal([[1, 2, 3], [4, 5, 6]], extras.get('an_extra')) np.testing.assert_almost_equal([0.5], extras.get('another_extra')) np.testing.assert_almost_equal([[[9, 8, 7], [6, 5, 4]]], extras.get('multi_dimension_extra')) np.testing.assert_equal([1], extras.get('boolean_extra')) def test_get_current_extras_json_format(self): # Replace `LogcatThread.add_event_listener` with one that simply calls `fn` # right away. def my_add_ev_listener(event_listener: logcat_thread.EventListener): # Check that the event matches what's expected. event = event_listener.regexp match = event.match('json_extra: {}') if match is None: # Ignore events that are not extras. return # Emit events. extra = { 'extra_scalar': 0, 'extra_list': [1, 2, 3, 4], 'extra_dict': { 'foo': 'bar' }, 'extra_string': 'a_string' } extra_update = {'extra_string': 'a_new_string', 'extra_float': 0.6} fn = event_listener.handler_fn fn(event, event.match(f'json_extra: {json.dumps(extra)}')) fn(event, event.match(f'json_extra: {json.dumps(extra_update)}')) # Setup the task and trigger the listener. task = task_pb2.Task() task.log_parsing_config.log_regexps.json_extra.extend([ '^json_extra: (?P.*)$' ]) task_mgr = task_manager.TaskManager(task=task) self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser) task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream) task_mgr.setup_task() timestep = task_mgr.rl_step( observation={ 'pixels': np.array([1, 2, 3]), }) # Check expectations. 
def test_get_current_extras_failed_to_parse(self):
  """Checks that unparsable extras are dropped and parsable ones are kept."""

  # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
  # right away.
  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    # Check that the event matches what's expected.
    event = event_listener.regexp
    match = event.match('extra: some_extra [1, 2]')
    if match is None:  # Ignore events that are not extras.
      return

    # Emit events. Well-formed and malformed payloads are interleaved under
    # the same extra name on purpose.
    fn = event_listener.handler_fn
    fn(event, event.match('extra: extra_with_malformed_1 [1]'))
    fn(event, event.match('extra: extra_with_malformed_1 [\'this is \\ bad]'))
    fn(event, event.match('extra: extra_with_malformed_1 [2]'))
    fn(event, event.match('extra: extra_with_malformed_2 [\'this is bad]'))
    fn(event, event.match('extra: extra_with_malformed_2 [1]'))
    fn(event, event.match('extra: extra_malformed_only [_very_bad_news]'))

  # Setup the task and trigger the listener.
  task = task_pb2.Task()
  # NOTE(review): the group names had been stripped from this pattern
  # (`(?P[^ ]*)` does not compile). `name` / `extra` are presumably what the
  # extras handler in `task_manager` reads -- confirm against that module.
  task.log_parsing_config.log_regexps.extra.extend([
      '^extra: (?P<name>[^ ]*)[ ]?(?P<extra>.*)$'
  ])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(
      observation={
          'pixels': np.array([1, 2, 3]),
      })

  # Check expectations: only the parsable values survive, and an extra that
  # never produced a parsable value is absent altogether.
  self.assertIn('extras', timestep.observation)
  extras = timestep.observation['extras']
  np.testing.assert_almost_equal(
      extras.get('extra_with_malformed_1'), [[1], [2]])
  np.testing.assert_almost_equal(extras.get('extra_with_malformed_2'), [[1]])
  self.assertNotIn('extra_malformed_only', extras)

def test_multi_log_regexp(self):
  """Checks that a reward is picked up when it matches the second regexp."""

  # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
  # right away.
  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    # Check that the event matches what's expected.
    match = event_listener.regexp.match('Reward_2: 123.0')
    if match is None:  # Ignore events that are not rewards.
      return

    event_listener.handler_fn(event_listener.regexp, match)

  task = task_pb2.Task()
  task.log_parsing_config.log_regexps.reward.extend([
      '^[Rr]eward_1: ([-+]?[0-9]*\\.?[0-9]*)$',
      '^[Rr]eward_2: ([-+]?[0-9]*\\.?[0-9]*)$'
  ])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(
      observation={
          'pixels': np.array([1, 2, 3]),
      })

  self.assertEqual(timestep.reward, 123.0)

def test_multi_reward_regexp(self):
  """Checks that rewards matched by different regexps are summed."""

  # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
  # right away.
  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    # Feed one sample line per configured regexp; each listener only fires
    # for the line its own regexp matches.
    match_1 = event_listener.regexp.match('Reward_1: 5.0')
    match_2 = event_listener.regexp.match('Reward_2: 10.0')

    if match_1:
      event_listener.handler_fn(event_listener.regexp, match_1)

    if match_2:
      event_listener.handler_fn(event_listener.regexp, match_2)

  task = task_pb2.Task()
  task.log_parsing_config.log_regexps.reward.extend([
      '^[Rr]eward_1: ([-+]?[0-9]*\\.?[0-9]*)$',
      '^[Rr]eward_2: ([-+]?[0-9]*\\.?[0-9]*)$',
  ])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(
      observation={
          'pixels': np.array([1, 2, 3]),
      })

  # 5.0 (Reward_1) + 10.0 (Reward_2).
  self.assertEqual(timestep.reward, 15.0)
class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
  """Abstract contract shared by AndroidEnv implementations.

  The `dm_env.Environment` methods are declared abstract here. The
  AndroidEnv-specific extensions come with inert defaults (empty results or
  `NotImplementedError`) so that subclasses only override what they support.
  """

  # Methods required by dm_env.Environment.

  @abc.abstractmethod
  def action_spec(self) -> dict[str, specs.Array]:
    """Describes the actions accepted by this environment."""

  @abc.abstractmethod
  def observation_spec(self) -> dict[str, specs.Array]:
    """Describes the observations produced by this environment."""

  @abc.abstractmethod
  def reset(self) -> dm_env.TimeStep:
    """Starts a new episode and returns its first `TimeStep`."""

  @abc.abstractmethod
  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Applies `action` and returns the resulting `TimeStep`."""

  @abc.abstractmethod
  def close(self) -> None:
    """Releases the resources held by this environment."""

  # Extensions provided by AndroidEnv.

  def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
    """Returns extra info provided by tasks; empty in this base class."""

    del latest_only  # Unused by the default implementation.
    no_extras: dict[str, np.ndarray] = {}
    return no_extras

  @property
  def raw_action(self) -> Any:
    """The latest action; `None` in this base class."""

    return None

  @property
  def raw_observation(self) -> Any:
    """The latest observation; `None` in this base class."""

    return None

  def stats(self) -> dict[str, Any]:
    """Returns implementation-internal information; empty in this base class."""

    no_stats: dict[str, Any] = {}
    return no_stats

  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
    """Executes `call`; this base class returns a default `AdbResponse`."""

    default_response = adb_pb2.AdbResponse()
    return default_response

  def load_state(
      self, request: state_pb2.LoadStateRequest
  ) -> state_pb2.LoadStateResponse:
    """Loads a state.

    Args:
      request: A `LoadStateRequest` containing any parameters necessary to
        specify how/what state to load.

    Returns:
      A `LoadStateResponse` containing the status, error message (if
      applicable), and any other relevant information.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """

    raise NotImplementedError('This environment does not support loading state')

  def save_state(
      self, request: state_pb2.SaveStateRequest
  ) -> state_pb2.SaveStateResponse:
    """Saves a state.

    Args:
      request: A `SaveStateRequest` containing any parameters necessary to
        specify how/what state to save.

    Returns:
      A `SaveStateResponse` containing the status, error message (if
      applicable), and any other relevant information.

    Raises:
      NotImplementedError: Always, unless overridden by a subclass.
    """

    raise NotImplementedError('This environment does not support saving state')
class AndroidEnv(env_interface.AndroidEnvInterface):
  """An RL environment that interacts with Android apps.

  Stepping and resetting are delegated to the `Coordinator`. This class keeps
  the most recent action, observation and task extras so they can be queried
  between steps, and so the last observation can be re-served when the
  coordinator returns a `TimeStep` without one.
  """

  def __init__(
      self,
      simulator: base_simulator.BaseSimulator,
      coordinator: coordinator_lib.Coordinator,
      task_manager: task_manager_lib.TaskManager,
  ):
    """Initializes the state of this AndroidEnv object."""

    self._simulator = simulator
    self._coordinator = coordinator
    self._task_manager = task_manager
    self._latest_action = {}
    self._latest_observation = {}
    self._latest_extras = {}
    # The first call to `step()` must trigger a `reset()`.
    self._reset_next_step = True
    self._is_closed = False

    logging.info('Action spec: %s', self.action_spec())
    logging.info('Observation spec: %s', self.observation_spec())

  def __del__(self) -> None:
    self.close()

  # Methods required by dm_env.Environment.

  def action_spec(self) -> dict[str, dm_env.specs.Array]:
    """Returns the action spec reported by the coordinator."""
    return self._coordinator.action_spec()

  def observation_spec(self) -> dict[str, dm_env.specs.Array]:
    """Returns the observation spec reported by the coordinator."""
    return self._coordinator.observation_spec()

  def _absorb_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    """Caches the extras/observation of `timestep` and fills in a missing one.

    Args:
      timestep: The `TimeStep` returned by the coordinator.

    Returns:
      `timestep` itself with `'extras'` removed from its observation, or, if
      its observation is `None`, a copy of it carrying a copy of the latest
      cached observation.
    """

    if timestep.observation is None:
      # If the observation is None, we return the latest observation again.
      return timestep._replace(observation=self._latest_observation.copy())

    # `pop` strips the extras from the dict that is handed back to the caller;
    # they stay reachable through `task_extras()`. A missing key is treated
    # as "no extras" instead of failing the whole step.
    self._latest_extras = timestep.observation.pop('extras', {})
    self._latest_observation = timestep.observation.copy()
    return timestep

  def reset(self) -> dm_env.TimeStep:
    """Resets the environment for a new RL episode."""

    logging.info('Resetting AndroidEnv...')

    # Execute a reset. Timestep will be of type FIRST.
    timestep = self._absorb_timestep(self._coordinator.rl_reset())

    self._latest_action = {}
    self._reset_next_step = False

    logging.info('Done resetting AndroidEnv.')
    logging.info('************* NEW EPISODE *************')

    return timestep

  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Takes a step in the environment.

    Args:
      action: The action to execute. It is ignored when the previous step
        ended the episode (or no episode was started yet); in that case the
        environment is reset instead.

    Returns:
      The `TimeStep` produced by the coordinator, or by `reset()`.
    """

    # Check if it's time to reset the episode.
    if self._reset_next_step:
      return self.reset()

    # Execute selected action.
    timestep = self._absorb_timestep(self._coordinator.rl_step(action))

    self._latest_action = action.copy()

    if timestep.last():
      self._reset_next_step = True
      logging.info('************* END OF EPISODE *************')

    return timestep

  def close(self) -> None:
    """Cleans up running processes, threads and local files.

    Safe to call more than once; only the first call closes the coordinator.
    """

    # `close()` is also reached from `__del__`, which may run on an object
    # whose `__init__` never completed, so no attribute is assumed to exist.
    if getattr(self, '_is_closed', False):
      return

    logging.info('Cleaning up AndroidEnv...')
    if hasattr(self, '_coordinator'):
      self._coordinator.close()
    logging.info('Done cleaning up AndroidEnv.')
    self._is_closed = True

  # Extensions provided by AndroidEnv.

  def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
    """Returns latest task extras.

    Args:
      latest_only: If True, only the last entry of each extra is returned;
        otherwise the full array cached from the latest timestep.

    Returns:
      A new dict mapping each extra name to a copy of its value(s).
    """

    task_extras = {}  # Build a copy to avoid reusing objects.
    for name, values in self._latest_extras.items():
      values_copy = values.astype(values.dtype)
      task_extras[name] = values_copy[-1] if latest_only else values_copy
    return task_extras

  @property
  def raw_action(self):
    """A shallow copy of the latest action passed to `step()`."""
    return self._latest_action.copy()

  @property
  def raw_observation(self):
    """A shallow copy of the latest observation, without `'extras'`."""
    return self._latest_observation.copy()

  def stats(self) -> dict[str, Any]:
    """Returns coordinator stats merged with task manager stats.

    On key collisions the task manager's value wins.
    """
    coordinator_stats = self._coordinator.stats()
    task_manager_stats = self._task_manager.stats()
    return coordinator_stats | task_manager_stats

  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
    """Forwards `call` to the coordinator and returns its response."""
    return self._coordinator.execute_adb_call(call)

  def load_state(
      self, request: state_pb2.LoadStateRequest
  ) -> state_pb2.LoadStateResponse:
    """Loads a state.

    The task manager is stopped before the simulator loads the state and is
    started again afterwards with a fresh ADB call parser and log stream.

    Args:
      request: A `LoadStateRequest` containing any parameters necessary to
        specify how/what state to load.

    Returns:
      A `LoadStateResponse` containing the status, error message (if
      applicable), and any other relevant information.
    """

    self._task_manager.stop()
    # NOTE(review): if `load_state` raises, the task manager stays stopped.
    response = self._simulator.load_state(request)
    self._task_manager.start(
        adb_call_parser_factory=lambda: adb_call_parser.AdbCallParser(
            self._simulator.create_adb_controller()
        ),
        log_stream=self._simulator.create_log_stream(),
    )
    return response

  def save_state(
      self, request: state_pb2.SaveStateRequest
  ) -> state_pb2.SaveStateResponse:
    """Saves a state.

    Args:
      request: A `SaveStateRequest` containing any parameters necessary to
        specify how/what state to save.

    Returns:
      A `SaveStateResponse` containing the status, error message (if
      applicable), and any other relevant information.
    """

    return self._simulator.save_state(request)
"""Unit tests for AndroidEnv.""" from unittest import mock from absl.testing import absltest from android_env import environment from android_env.components import config_classes from android_env.components import coordinator as coordinator_lib from android_env.components import task_manager as task_manager_lib from android_env.components.simulators import base_simulator from android_env.components.simulators.fake import fake_simulator from android_env.proto import adb_pb2 from android_env.proto import state_pb2 import dm_env import numpy as np def _create_mock_coordinator() -> coordinator_lib.Coordinator: coordinator = mock.create_autospec(coordinator_lib.Coordinator) coordinator.action_spec.return_value = { 'action_type': dm_env.specs.DiscreteArray(num_values=3), 'touch_position': dm_env.specs.BoundedArray( shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0), } coordinator.observation_spec.return_value = { 'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8), 'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64), 'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8), } return coordinator def _create_fake_simulator() -> fake_simulator.FakeSimulator: return fake_simulator.FakeSimulator( config=config_classes.FakeSimulatorConfig(screen_dimensions=(123, 456)) ) class AndroidEnvTest(absltest.TestCase): def test_specs(self): simulator = _create_fake_simulator() coordinator = _create_mock_coordinator() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager ) # Check action spec. self.assertNotEmpty(env.action_spec()) self.assertIn('action_type', env.action_spec()) self.assertIsInstance(env.action_spec()['action_type'], dm_env.specs.DiscreteArray) self.assertIn('touch_position', env.action_spec()) self.assertIsInstance(env.action_spec()['touch_position'], dm_env.specs.BoundedArray) # Check observation spec. 
def test_reset_and_step(self):
  """Exercises `reset()`, `step()` and the accessors that depend on them."""

  simulator = _create_fake_simulator()
  # `_create_mock_coordinator()` already configures the action and
  # observation specs, so they are not set up again here.
  coordinator = _create_mock_coordinator()
  task_manager = mock.create_autospec(task_manager_lib.TaskManager)
  env = environment.AndroidEnv(
      simulator=simulator, coordinator=coordinator, task_manager=task_manager
  )
  coordinator.rl_reset.return_value = dm_env.TimeStep(
      step_type=dm_env.StepType.FIRST,
      reward=0.0,
      discount=0.0,
      observation={
          'pixels': np.random.rand(987, 654, 3),
          'timedelta': 123456,
          'orientation': np.array((1, 0, 0, 0)),
          'extras': {
              'click': np.array([[246]], dtype=np.int64)
          }
      },
  )

  ts = env.reset()
  self.assertIsInstance(ts, dm_env.TimeStep)
  # After a `reset()` the TimeStep should follow some expectations.
  self.assertTrue(ts.first())
  self.assertEqual(ts.reward, 0.0)
  self.assertEqual(ts.discount, 0.0)
  obs = ts.observation
  self.assertIn('pixels', obs)
  self.assertEqual(obs['pixels'].shape, (987, 654, 3))
  self.assertIn('timedelta', obs)
  self.assertEqual(obs['timedelta'], 123456)
  self.assertIn('orientation', obs)
  self.assertEqual(obs['orientation'].shape, (4,))
  np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))
  # Extras should also be provided. Arrays are compared with `np.testing`
  # rather than `assertEqual`, which relies on the truthiness of an ndarray.
  extras = env.task_extras()
  self.assertIn('click', extras)
  np.testing.assert_equal(extras['click'], np.array([246], dtype=np.int64))
  coordinator.stats.return_value = {'my_measurement': 135}
  task_manager.stats.return_value = {'another_measurement': 79}

  # Step again in the environment and check expectations again.
  pixels = np.random.rand(987, 654, 3)
  latest_observation = {
      'pixels': pixels,
      'timedelta': 123456,
      'orientation': np.array((1, 0, 0, 0)),
      'extras': {
          'click': np.array([[246]], dtype=np.int64)
      }
  }
  coordinator.rl_step.return_value = dm_env.transition(
      reward=0.0,
      discount=0.0,
      observation=latest_observation,
  )
  ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
  self.assertIsInstance(ts, dm_env.TimeStep)
  # The StepType now should NOT be FIRST.
  self.assertFalse(ts.first())
  self.assertEqual(ts.reward, 0.0)
  self.assertEqual(ts.discount, 0.0)
  obs = ts.observation
  self.assertIn('pixels', obs)
  self.assertEqual(obs['pixels'].shape, (987, 654, 3))
  self.assertIn('timedelta', obs)
  self.assertEqual(obs['timedelta'], 123456)
  self.assertIn('orientation', obs)
  self.assertEqual(obs['orientation'].shape, (4,))
  np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))
  # Extras should still be provided.
  extras = env.task_extras()
  self.assertIn('click', extras)
  np.testing.assert_equal(extras['click'], np.array([246], dtype=np.int64))

  # At this point these methods and properties should return something.
  self.assertNotEmpty(env.stats())
  self.assertNotEmpty(env.raw_observation)
  self.assertNotIn('extras', env.raw_observation)
  self.assertNotEmpty(env.raw_action)

  # If the observation is None, we want to return the latest observation.
  coordinator.rl_step.return_value = dm_env.truncation(
      reward=0.0,
      observation=None,
  )
  ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
  self.assertIsInstance(ts, dm_env.TimeStep)
  # Assert the observation matches the latest observation.
  obs = ts.observation
  self.assertIn('pixels', obs)
  self.assertEqual(obs['pixels'].shape, (987, 654, 3))
  np.testing.assert_equal(obs['pixels'], pixels)
  self.assertIn('timedelta', obs)
  self.assertEqual(obs['timedelta'], 123456)
  self.assertIn('orientation', obs)
  self.assertEqual(obs['orientation'].shape, (4,))
  np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))
response = env.load_state(request) self.assertEqual(response, expected_response) simulator.load_state.assert_called_once_with(request) task_manager.stop.assert_called_once() task_manager.start.assert_called_once() def test_save_state(self): simulator = mock.create_autospec(base_simulator.BaseSimulator) coordinator = _create_mock_coordinator() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager ) expected_response = state_pb2.SaveStateResponse( status=state_pb2.SaveStateResponse.Status.OK ) request = state_pb2.SaveStateRequest(args={'foo': 'bar'}) simulator.save_state.return_value = expected_response response = env.save_state(request) self.assertEqual(response, expected_response) simulator.save_state.assert_called_once_with(request) def test_double_close(self): simulator = _create_fake_simulator() coordinator = _create_mock_coordinator() task_manager = mock.create_autospec(task_manager_lib.TaskManager) env = environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager ) env.close() env.close() coordinator.close.assert_called_once() if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/loader.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Function for loading AndroidEnv.""" import os from absl import logging from android_env import environment from android_env.components import config_classes from android_env.components import coordinator as coordinator_lib from android_env.components import device_settings as device_settings_lib from android_env.components import task_manager as task_manager_lib from android_env.components.simulators.emulator import emulator_simulator from android_env.components.simulators.fake import fake_simulator from android_env.proto import task_pb2 from google.protobuf import text_format def _load_task(task_config: config_classes.TaskConfig) -> task_pb2.Task: """Returns the task according to `task_config`.""" task = task_pb2.Task() match task_config: case config_classes.FilesystemTaskConfig(): with open(task_config.path, 'r') as proto_file: text_format.Parse(proto_file.read(), task) case _: logging.error('Unsupported TaskConfig: %r', task_config) return task def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv: """Loads an AndroidEnv instance.""" task = _load_task(config.task) task_manager = task_manager_lib.TaskManager(task) match config.simulator: case config_classes.EmulatorConfig(): _process_emulator_launcher_config(config.simulator) simulator = emulator_simulator.EmulatorSimulator(config=config.simulator) case config_classes.FakeSimulatorConfig(): simulator = fake_simulator.FakeSimulator(config=config.simulator) case _: raise ValueError('Unsupported simulator config: {config.simulator}') device_settings = device_settings_lib.DeviceSettings(simulator) coordinator = coordinator_lib.Coordinator( simulator, task_manager, device_settings ) return environment.AndroidEnv( simulator=simulator, coordinator=coordinator, task_manager=task_manager ) def _process_emulator_launcher_config( emulator_config: config_classes.EmulatorConfig, ) -> None: """Adjusts the configuration of the emulator depending on some conditions.""" # Expand the user directory if 
specified. launcher_config = emulator_config.emulator_launcher launcher_config.android_avd_home = os.path.expanduser( launcher_config.android_avd_home ) launcher_config.android_sdk_root = os.path.expanduser( launcher_config.android_sdk_root ) launcher_config.emulator_path = os.path.expanduser( launcher_config.emulator_path ) emulator_config.adb_controller.adb_path = os.path.expanduser( emulator_config.adb_controller.adb_path ) ================================================ FILE: android_env/loader_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for loader.""" import builtins import os from unittest import mock from absl.testing import absltest from android_env import env_interface from android_env import loader from android_env.components import config_classes from android_env.components import coordinator as coordinator_lib from android_env.components import device_settings as device_settings_lib from android_env.components import task_manager as task_manager_lib from android_env.components.simulators.emulator import emulator_simulator from android_env.components.simulators.fake import fake_simulator from android_env.proto import task_pb2 class LoaderTest(absltest.TestCase): @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True) @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True) @mock.patch.object(device_settings_lib, 'DeviceSettings', autospec=True) @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True) @mock.patch.object(builtins, 'open', autospec=True) def test_load_emulator( self, mock_open, mock_coordinator, mock_device_settings, mock_simulator_class, mock_task_manager, ): # Arrange. mock_open.return_value.__enter__ = mock_open mock_open.return_value.read.return_value = '' config = config_classes.AndroidEnvConfig( task=config_classes.FilesystemTaskConfig(path='some/path/'), simulator=config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( avd_name='my_avd', android_avd_home='~/.android/avd', android_sdk_root='~/Android/Sdk', emulator_path='~/Android/Sdk/emulator/emulator', run_headless=False, ), adb_controller=config_classes.AdbControllerConfig( adb_path='~/Android/Sdk/platform-tools/adb', ), ), ) # Act. env = loader.load(config) # Assert. 
@mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
@mock.patch.object(fake_simulator, 'FakeSimulator', autospec=True)
@mock.patch.object(device_settings_lib, 'DeviceSettings', autospec=True)
@mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
@mock.patch.object(builtins, 'open', autospec=True)
def test_load_fake_simulator(
    self,
    mock_open,
    mock_coordinator,
    mock_device_settings,
    mock_simulator_class,
    mock_task_manager,
):
  """Loading with a `FakeSimulatorConfig` builds and wires a `FakeSimulator`."""

  # Arrange: the task file is read through the patched `open`, and is empty.
  mock_open.return_value.__enter__ = mock_open
  mock_open.return_value.read.return_value = ''
  task_config = config_classes.FilesystemTaskConfig(path='some/path/')
  simulator_config = config_classes.FakeSimulatorConfig(
      screen_dimensions=(1234, 5678)
  )
  env_config = config_classes.AndroidEnvConfig(
      task=task_config, simulator=simulator_config
  )

  # Act.
  loaded_env = loader.load(env_config)

  # Assert: the simulator receives an equal config, and the coordinator is
  # built from the three patched collaborators, in this order.
  self.assertIsInstance(loaded_env, env_interface.AndroidEnvInterface)
  expected_simulator_config = config_classes.FakeSimulatorConfig(
      screen_dimensions=(1234, 5678)
  )
  mock_simulator_class.assert_called_with(config=expected_simulator_config)
  mock_coordinator.assert_called_with(
      mock_simulator_class.return_value,
      mock_task_manager.return_value,
      mock_device_settings.return_value,
  )

@mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
@mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True)
@mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
@mock.patch.object(builtins, 'open', autospec=True)
def test_task(
    self, mock_open, mock_coordinator, mock_simulator, mock_task_manager
):
  """The text-format task read from disk is what `TaskManager` receives."""

  # Arrange.
  del mock_coordinator, mock_simulator
  mock_open.return_value.__enter__ = mock_open
  mock_open.return_value.read.return_value = r'''
    id: "fake_task"
    name: "Fake Task"
    description: "Task for testing loader."
    max_episode_sec: 0
  '''
  launcher_config = config_classes.EmulatorLauncherConfig(avd_name='my_avd')
  adb_config = config_classes.AdbControllerConfig(
      adb_path='~/Android/Sdk/platform-tools/adb',
  )
  env_config = config_classes.AndroidEnvConfig(
      task=config_classes.FilesystemTaskConfig(path='some/path/'),
      simulator=config_classes.EmulatorConfig(
          emulator_launcher=launcher_config,
          adb_controller=adb_config,
      ),
  )

  # Act.
  loaded_env = loader.load(env_config)

  # Assert.
  expected_task = task_pb2.Task()
  expected_task.id = 'fake_task'
  expected_task.name = 'Fake Task'
  expected_task.description = 'Task for testing loader.'
  expected_task.max_episode_sec = 0
  mock_task_manager.assert_called_with(expected_task)
  self.assertIsInstance(loaded_env, env_interface.AndroidEnvInterface)
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: android_env/proto/a11y/__init__.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: android_env/proto/a11y/a11y.proto ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. 
syntax = "proto3";

package android_env;

import "android_env/proto/a11y/android_accessibility_forest.proto";

option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";

// A service to send Accessibility information to a remote server.
//
// The client is assumed to be running inside an Android device (e.g. emulator
// or real device) while the server is assumed to be running outside (e.g. in a
// Python process).
service A11yService {
  // Sends a forest of Accessibility trees to a server.
  rpc SendForest(AndroidAccessibilityForest) returns (ForestResponse) {}

  // Sends an a11y event to a server.
  rpc SendEvent(EventRequest) returns (EventResponse) {}

  // Long-lived bidirectional communication between the client and the server.
  rpc Bidi(stream ClientToServer) returns (stream ServerToClient) {}
}

// TODO(b/334952387): Remove `ForestResponse`, `EventRequest` and
// `EventResponse` once bidi communication is in-place.

message ForestResponse {
  // The error if anything.
  string error = 1;
}

// An Accessibility event.
message EventRequest {
  // A single event as a dictionary.
  // NOTE(review): the key/value types were missing from this declaration
  // (`map event = 1;` is not valid proto3); string-to-string is presumed --
  // confirm against the Kotlin client and the Python server.
  map<string, string> event = 1;
}

message EventResponse {
  // The error if anything.
  string error = 1;
}

// The message sent from the Android device to the server running outside of the
// device.
message ClientToServer {
  oneof payload {
    EventRequest event = 1;
    AndroidAccessibilityForest forest = 2;
  }
}

// The message sent from the server running outside of the device to the Android
// device.
message ServerToClient {
  // A request to obtain the Accessibility forest.
  message GetA11yForest {}

  oneof payload {
    GetA11yForest get_forest = 1;
  }
}
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package android_env;

option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";

// An Android Accessibility Action.
// Next index: 3
message AndroidAccessibilityAction {
  // Required ID that uniquely identifies the action for this node.
  // Can be one of the standard action IDs listed in the documentation.
  // https://developer.android.com/reference/android/view/accessibility/AccessibilityNodeInfo.AccessibilityAction
  // NOTE: proto3 has no `required` label, so "Required" is a convention for
  // producers and is not enforced by the schema.
  int32 id = 1;

  // Optional label describing what the action is.
  string label = 2;
}

================================================
FILE: android_env/proto/a11y/android_accessibility_forest.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";

package android_env;

import "android_env/proto/a11y/android_accessibility_window_info.proto";

option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";

// A forest of Android accessibility trees. Each tree belongs to a single
// window. Next index: 2
message AndroidAccessibilityForest {
  // All of the windows present on screen.
  repeated AndroidAccessibilityWindowInfo windows = 1;
}

================================================
FILE: android_env/proto/a11y/android_accessibility_node_info.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package android_env;

import "android_env/proto/a11y/android_accessibility_action.proto";
import "android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto";
import "android_env/proto/a11y/rect.proto";

option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";

// An Android AccessibilityNodeInfo.
// Next index: 32
message AndroidAccessibilityNodeInfo {
  // Unique monotonically-increasing ID.
  int32 unique_id = 1;

  // The bounds of this node within the device's screen.
  ProtoRect bounds_in_screen = 2;

  // The name of the View class that created this node.
  string class_name = 3;

  // The content description of the node.
  string content_description = 4;

  // The hint text of the node.
  string hint_text = 5;

  // The name of the package this node comes from.
  string package_name = 6;

  // The text of this node.
  string text = 7;

  // The start index of the text selection.
  int64 text_selection_start = 8;

  // The end index of the text selection.
  int64 text_selection_end = 9;

  // The view ID resource name of the node.
  string view_id_resource_name = 10;

  // The ID of the window this node belongs to.
  int32 window_id = 11;

  // If true, this node can be checked.
  bool is_checkable = 12;

  // If true, this node is currently checked.
  bool is_checked = 13;

  // If true, this node (probably) responds to being clicked.
  bool is_clickable = 14;

  // If true, this node's text can be edited by the user.
  bool is_editable = 15;

  // If true, this node is enabled (e.g., if it is a button).
  bool is_enabled = 16;

  // If true, this node can be focused (e.g., a text input).
  bool is_focusable = 17;

  // If true, this node is currently focused.
  bool is_focused = 18;

  // If true, this node (probably) responds to being long pressed.
  bool is_long_clickable = 19;

  // If true, this node is a password input.
  bool is_password = 20;

  // If true, this node can be scrolled.
  bool is_scrollable = 21;

  // If true, this node is currently selected.
  bool is_selected = 22;

  // If true, this node is (probably) visible to the user.
  bool is_visible_to_user = 23;

  // List of actions that can be performed on this node.
  repeated AndroidAccessibilityAction actions = 24;

  // Ordered list of child IDs (i.e., unique_id).
  // NOTE: proto3 already packs repeated scalar fields by default, so the
  // explicit `packed = true` is redundant but harmless.
  repeated int32 child_ids = 25 [packed = true];

  // List of clickable spans present in the node's text or content description.
  repeated AndroidAccessibilityNodeInfoClickableSpan clickable_spans = 26;

  // The depth of this node in the accessibility tree.
  int32 depth = 27;

  // Unique ID of the node that this node is declaring itself to be labeled by.
  int32 labeled_by_id = 28;

  // Unique ID of the node that this node is declaring itself to be a label
  // for.
  int32 label_for_id = 29;

  // The drawing order for the node.
  int32 drawing_order = 30;

  // The tooltip text of the node.
  string tooltip_text = 31;
}

================================================
FILE: android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package android_env;

option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";

// A single clickable span found in the accessibility node's text.
// Next index: 6
message AndroidAccessibilityNodeInfoClickableSpan {
  // The source of the span (so the client can find the correct spannable string
  // in the node).
  // Next index: 3
  enum SpanSource {
    UNKNOWN_TYPE = 0;         // Catch all type for forward compatibility.
    TEXT = 1;                 // The span is from node#getText.
    CONTENT_DESCRIPTION = 2;  // The span is from node#getContentDescription.
  }

  // The text of the span (a substring of the spannable string).
  string text = 1;

  // The URL attached to the span if specified.
  string url = 2;

  // The source of the span.
  SpanSource source = 3;

  // The index of the first character of the span in the spannable string.
  // The end of the span is the sum of `start` and text.length().
  int32 start = 4;

  // The unique_id from the corresponding AndroidAccessibilityNodeInfo.
  int32 node_id = 5;
}

================================================
FILE: android_env/proto/a11y/android_accessibility_tree.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package android_env;

import "android_env/proto/a11y/android_accessibility_node_info.proto";

option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";

// A tree (actually a graph) of Android accessibility nodes.
// Next index: 3
// NOTE(review): only tag 1 is defined below. Tag 2 may have belonged to a
// since-removed field; confirm before reusing it.
message AndroidAccessibilityTree {
  // All of the nodes in the graph. The root node is the node whose ID is 0.
  repeated AndroidAccessibilityNodeInfo nodes = 1;
}

================================================
FILE: android_env/proto/a11y/android_accessibility_window_info.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package android_env;

import "android_env/proto/a11y/android_accessibility_tree.proto";
import "android_env/proto/a11y/rect.proto";

option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";

// An Android AccessibilityWindowInfo.
// Next index: 12
message AndroidAccessibilityWindowInfo {
  // Type of the window.
  // Next index: 8
  enum WindowType {
    // The window type is an unknown value.
    UNKNOWN_TYPE = 0;

    // A standard application window.
    TYPE_APPLICATION = 1;

    // An IME window (e.g. GBoard).
    TYPE_INPUT_METHOD = 2;

    // A system window (e.g., a notification).
    TYPE_SYSTEM = 3;

    // An accessibility overlay.
    TYPE_ACCESSIBILITY_OVERLAY = 4;

    // A system window used to divide the screen in split-screen mode. This type
    // of window is present only in split-screen mode.
    TYPE_SPLIT_SCREEN_DIVIDER = 5;

    // Used to show the UI for window-based magnification.
    TYPE_MAGNIFICATION_OVERLAY = 6;

    // System window that has the function to control an associated window.
    TYPE_WINDOW_CONTROL = 7;
  }

  // Bounds of this window in the device's screen.
  ProtoRect bounds_in_screen = 1;

  // A unique ID identifying the display in which this window is shown.
  int32 display_id = 2;

  // Unique ID as defined by the Android platform.
  int32 id = 3;

  // Z-index of the window. Windows with a greater z-index appear in front of
  // those with a lesser z-index.
  int32 layer = 4;

  // The title of the window, if set.
  string title = 5;

  // The type of the window.
  WindowType window_type = 6;

  // If true, the window is currently accessibility-focused.
  bool is_accessibility_focused = 7;

  // If true, the window is currently active.
  bool is_active = 8;

  // If true, the window is currently focused.
  bool is_focused = 9;

  // If true, the window is in Picture in Picture mode.
  bool is_in_picture_in_picture_mode = 10;

  // The associated accessibility tree for this window.
  AndroidAccessibilityTree tree = 11;
}

================================================
FILE: android_env/proto/a11y/rect.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package android_env;

option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";

// Proto representation of Android Rect.
// https://developer.android.com/reference/android/graphics/Rect
// Next index: 5
message ProtoRect {
  int32 left = 1;
  int32 top = 2;
  int32 right = 3;
  int32 bottom = 4;
}

================================================
FILE: android_env/proto/adb.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

syntax = "proto3";

package android_env;

// A single ADB command. At most one variant of the `command` oneof is set,
// optionally accompanied by a soft deadline in `timeout_sec`.
message AdbRequest {
  // Installs an APK into the simulator.
  message InstallApk {
    // A location in the filesystem.
    message Filesystem {
      string path = 1;
    }

    // A byte sequence of a single APK file.
    message Blob {
      // The serialized file as bytes.
      bytes contents = 1;
    }

    oneof location {
      Filesystem filesystem = 2;
      Blob blob = 6;
    }
  }

  message StartActivity {
    string full_activity = 1;

    repeated string extra_args = 2;

    // Whether to stop the current app before starting the activity.
    // Notice that if this option is `true`, the activity probably needs the
    // `android:launchMode="singleTop"` attribute in its `AndroidManifest.xml`,
    // otherwise intents may not be received by `onNewIntent()`. Please see more
    // info on `android:launchMode` at
    // https://developer.android.com/guide/topics/manifest/activity-element.
    bool force_stop = 3;
  }

  message SendBroadcast {
    // Action to send during the broadcast event.
    string action = 1;

    // Specify the component name with package name prefix to create an explicit
    // intent, such as com.example.app/.ExampleActivity (see -n specification at
    // https://developer.android.com/tools/adb#IntentSpec).
    string component = 2;
  }

  message UninstallPackage {
    string package_name = 1;
  }

  message ForceStop {
    string package_name = 1;
  }

  message Tap {
    // NOTE: These are absolute coordinates in the range of the screen
    // resolution. They are NOT floats in [0,1].
    // Precondition: `x` and `y` must be non-negative.
    int32 x = 1;
    int32 y = 2;
  }

  message PressButton {
    enum Button {
      HOME = 0;
      BACK = 1;
      ENTER = 2;
    }

    Button button = 1;
  }

  // Pins the given activity to the screen.
  // This essentially locks the user into a single app mode (aka "Kiosk mode").
  message StartScreenPinning {
    string full_activity = 1;
  }

  // Returns the full activity name that is currently opened to the user.
  // If successful, a GetCurrentActivityResponse is returned.
  message GetCurrentActivity {}

  // Returns the orientation of the device.
  message GetOrientationRequest {}

  // Performs `adb push`.
  // Please see https://developer.android.com/studio/command-line/adb#copyfiles.
  //
  // Notice that a source path for the file is not sent, but raw bytes in
  // `content` instead. Obviously, the `content` can be set from a real
  // file, but this is done to ensure Task definitions are as hermetic as
  // possible, without depending on the environment from where they're run.
  message Push {
    // The contents of the file.
    bytes content = 1;

    // Destination path _inside_ Android. E.g. /sdcard/my_file.txt.
    string path = 2;
  }

  // Performs `adb pull`.
  // Please see https://developer.android.com/studio/command-line/adb#copyfiles.
  //
  // Notice that a local destination for the copied file is not sent, as raw
  // bytes are returned instead (please see PullResponse). Obviously, these
  // bytes can be written to disk by the caller of this command.
  message Pull {
    // Path _inside_ Android. E.g. /sdcard/my_file.txt.
    string path = 1;
  }

  // Inserts text into the current text field (if any).
  // Essentially `adb shell input text <text>`.
  message InputText {
    string text = 1;
  }

  // Issues an `adb shell settings` command.
  message SettingsRequest {
    // Each request has an associated namespace.
    enum Namespace {
      UNKNOWN = 0;
      SYSTEM = 1;
      SECURE = 2;
      GLOBAL = 3;
    }

    // Retrieves the current value for `key`.
    message Get {
      string key = 1;
    }

    // Changes the contents of `key` to `value`.
    message Put {
      string key = 1;
      string value = 2;
    }

    // Deletes the entry for `key`.
    message Delete {
      string key = 1;
    }

    // Resets the global/secure table for a package with the given mode.
    message Reset {
      enum Mode {
        UNKNOWN = 0;
        UNTRUSTED_DEFAULTS = 1;
        UNTRUSTED_CLEAR = 2;
        TRUSTED_DEFAULTS = 3;
      }

      string package_name = 1;
      Mode mode = 2;
    }

    // Prints all defined keys in the given namespace.
    message List {}

    // The part of the system where this command will take place.
    // NOTE: We avoid the identifier `namespace` because it's a keyword in C++.
    Namespace name_space = 1;

    // The subcommand to issue to `adb settings`.
    // NOTE: We avoid the identifiers `delete` and `del` because they're
    // keywords in C++ and Python respectively.
    oneof verb {
      Get get = 2;
      Put put = 3;
      Delete delete_key = 4;
      Reset reset = 5;
      List list = 6;
    }
  }

  // Generic ADB command. Use this for commands that are not
  // explicitly implemented.
  // Calls `adb [args...]`.
  message GenericRequest {
    repeated string args = 1;
  }

  message PackageManagerRequest {
    message List {
      // Lists all features of the system.
      message Features {}

      // Lists all system libraries.
      message Libraries {}

      // Lists all packages; optionally only those whose name contains the text
      // in `filter`.
      message Packages {
        string filter = 1;

        // Extra options that control the output. Please see `pm help` for
        // details.
        repeated string options = 2;
      }

      oneof what {
        Features features = 1;
        Libraries libraries = 2;
        Packages packages = 3;
      }
    }

    // Deletes all data associated with a package.
    message Clear {
      // The package name to clear its cache.
      string package_name = 1;

      // Optional USER_ID.
      string user_id = 2;
    }

    message Grant {
      string package_name = 1;

      // Possible values listed at
      // https://developer.android.com/reference/android/Manifest.permission
      // To query an app's required permissions, use the following adb command:
      // > adb shell dumpsys package <package_name>
      // The output will contain things like
      // android.permission.WRITE_SECURE_SETTINGS
      repeated string permissions = 2;
    }

    // The subcommand to issue to `pm`.
    oneof verb {
      List list = 1;
      Clear clear = 2;
      Grant grant = 3;
    }
  }

  // For executing `dumpsys` commands.
  message DumpsysRequest {
    enum PriorityLevel {
      UNSET = 0;
      NORMAL = 1;
      HIGH = 2;
      CRITICAL = 3;
    }

    // The service to dump. If empty, all services will be dumped.
    string service = 1;

    // Optional arguments to pass to the specific service dump.
    repeated string args = 2;

    // Lists services, does not dump them.
    // This effectively disables dumping information about any particular
    // service.
    bool list_only = 3;

    // Timeouts natively supported by `dumpsys`.
    int32 timeout_sec = 4;
    int32 timeout_ms = 5;

    // Whether to dump the process ID instead of the usual dump.
    bool pid = 6;

    // Whether dumps will be in proto format. Only works for services that
    // support dumping data in proto format.
    bool proto = 7;

    // Filters services based on specified priority.
    PriorityLevel priority = 8;

    // Excludes services from the dump.
    repeated string skip_services = 9;
  }

  // The command to run.
  // NOTE: tag numbers are not contiguous (4-5, 8-9 and 11-15 are unused here)
  // and not listed in numeric order. Never renumber existing fields: tags are
  // what identify fields on the wire.
  oneof command {
    InstallApk install_apk = 1;
    StartActivity start_activity = 2;
    ForceStop force_stop = 3;
    Tap tap = 6;
    PressButton press_button = 7;
    StartScreenPinning start_screen_pinning = 10;
    UninstallPackage uninstall_package = 16;
    GetCurrentActivity get_current_activity = 17;
    GetOrientationRequest get_orientation = 24;
    Push push = 18;
    Pull pull = 19;
    InputText input_text = 20;
    SettingsRequest settings = 21;
    GenericRequest generic = 22;
    PackageManagerRequest package_manager = 23;
    DumpsysRequest dumpsys = 26;
    SendBroadcast send_broadcast = 25;
  }

  // Optional (soft) deadline in seconds for completing this command.
  // Expected to be >0. If ==0 (the default), it's ignored.
  // Notice that not all commands accept timeouts, but because it's such a
  // common parameter, we include it here instead of in each separate command.
  float timeout_sec = 100;
}

// The result of an AdbRequest: a `status`, an `error_message` on failure,
// optional `stats`, and at most one command-specific `payload`.
message AdbResponse {
  enum Status {
    // Reserved value for unset statuses.
    UNDEFINED = 0;
    // Returned when everything goes well.
    OK = 1;
    // Returned when handling unknown AdbRequest commands.
    UNKNOWN_COMMAND = 2;
    // Returned when an argument does not respect a precondition.
    FAILED_PRECONDITION = 3;
    // Returned when something internal did not work as expected.
    INTERNAL_ERROR = 4;
    // Returned when the adb command failed.
    ADB_ERROR = 5;
    // Returned when the adb command timed out.
    TIMEOUT = 6;
  }

  Status status = 1;

  // `error_message` is only populated in case of errors.
  string error_message = 2;

  // General stats that components may optionally report.
  // NOTE(review): the `<key, value>` type arguments were missing from this
  // declaration (a bare `map stats = 3;` does not compile). They are restored
  // here as string -> float; confirm against the generated stubs in use.
  map<string, float> stats = 3;

  // Response for GetCurrentActivity requests.
  message GetCurrentActivityResponse {
    // The format of the output is `package/package.ActivityName`, for example:
    // "com.example.vokram/com.example.vokram.MainActivity"
    string full_activity = 1;
  }

  // Response for GetOrientationRequests.
  message GetOrientationResponse {
    // Possible values are {0, 1, 2, 3} corresponding to {0, 90, 180, 270}
    // degrees respectively.
    // Please see https://developer.android.com/reference/android/view/Surface.
    int32 orientation = 1;
  }

  // Response for StartActivity requests.
  message StartActivityResponse {
    // The activity that was actually started. On a failed request, this will be
    // empty.
    string full_activity = 1;

    bytes output = 2;
  }

  // Response for PressButton requests.
  message PressButtonResponse {
    // The output, if any, by `adb` after sending a key press.
    // This is intentionally left as `bytes` instead of `string` so that content
    // other than `UTF-8` can be transmitted.
    bytes output = 1;
  }

  // Response for Push requests.
  message PushResponse {}

  // Response for Pull requests.
  message PullResponse {
    // The contents of the file.
    // This is intentionally left as `bytes` instead of `string` so that content
    // other than `UTF-8` can be transmitted.
    bytes content = 1;
  }

  // Response for InputText requests.
  message InputTextResponse {}

  // Response for SettingsRequests.
  message SettingsResponse {
    // The output, if any, of the `adb shell settings` command.
    bytes output = 1;
  }

  // Response for GenericRequests.
  message GenericResponse {
    // The output, if any, of the generic adb command.
    bytes output = 1;
  }

  // Response for PackageManagerRequests.
  message PackageManagerResponse {
    // The output, if any, of the `adb shell pm` command.
    bytes output = 1;

    message List {
      // A list of items. The actual content depends on the request, but it
      // could be things like features, libraries or package names.
      repeated string items = 1;
    }

    oneof verb {
      List list = 2;
    }
  }

  // Response for DumpsysRequests.
  message DumpsysResponse {
    // The output, if any, of the `dumpsys` command.
    bytes output = 1;
  }

  // Command-specific result. Not every AdbRequest command has a variant here.
  oneof payload {
    GetCurrentActivityResponse get_current_activity = 10;
    StartActivityResponse start_activity = 11;
    PressButtonResponse press_button = 12;
    PushResponse push = 13;
    PullResponse pull = 14;
    InputTextResponse input_text = 15;
    SettingsResponse settings = 16;
    GenericResponse generic = 17;
    PackageManagerResponse package_manager = 18;
    GetOrientationResponse get_orientation = 19;
    DumpsysResponse dumpsys = 21;
  }
}

================================================
FILE: android_env/proto/emulator_controller.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Copyright (C) 2018 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Note that if you add/remove methods in this file you must update // the metrics sql as well ./android/scripts/gen-grpc-sql.py // // Please group deleted methods in a block including the date (MM/DD/YY) // it was removed. This enables us to easily keep metrics around after removal // // list of deleted methods // rpc iWasDeleted (03/12/12) // ... // LINT: LEGACY_NAMES syntax = "proto3"; package android.emulation.control; import "google/protobuf/empty.proto"; option java_multiple_files = true; option java_package = "com.android.emulator.control"; option objc_class_prefix = "AEC"; // An EmulatorController service lets you control the emulator. // Note that this is currently an experimental feature, and that the // service definition might change without notice. Use at your own risk! // // We use the following rough conventions: // // streamXXX --> streams values XXX (usually for emulator lifetime). Values // are updated as soon as they become available. // getXXX --> gets a single value XXX // setXXX --> sets a single value XXX, does not returning state, these // usually have an observable lasting side effect. // sendXXX --> send a single event XXX, possibly returning state information. // android usually responds to these events. service EmulatorController { // Set the sensor data rpc streamSensor(SensorValue) returns (stream SensorValue) {} // Get the sensor data rpc getSensor(SensorValue) returns (SensorValue) {} // Stream the sensor data rpc setSensor(SensorValue) returns (google.protobuf.Empty) {} // Set the physical model, this is likely the one you are // looking for when you wish to modify the device state. rpc setPhysicalModel(PhysicalModelValue) returns (google.protobuf.Empty) {} // Get the physical model rpc getPhysicalModel(PhysicalModelValue) returns (PhysicalModelValue) {} // Stream the physical model rpc streamPhysicalModel(PhysicalModelValue) returns (stream PhysicalModelValue) {} // Atomically set the current primary clipboard data. 
rpc setClipboard(ClipData) returns (google.protobuf.Empty) {} // Atomically get the current primary clipboard data. rpc getClipboard(google.protobuf.Empty) returns (ClipData) {} // Streams the current data on the clipboard. This will immediately produce // a result with the current state of the clipboard after which the stream // will block and wait until a new clip event is available from the guest. // Calling the setClipboard method above will not result in generating a clip // event. It is possible to lose clipboard events if the clipboard updates // very rapidly. rpc streamClipboard(google.protobuf.Empty) returns (stream ClipData) {} // Set the battery to the given state. rpc setBattery(BatteryState) returns (google.protobuf.Empty) {} // Get the battery to the given state. rpc getBattery(google.protobuf.Empty) returns (BatteryState) {} // Set the state of the gps, gps support will only work // properly if: // // - no location ui is active. That is the emulator // is launched in headless mode (-no-window) or the location // ui is disabled (-no-location-ui). // - the passiveUpdate is set to false. Setting this to false // will disable/break the LocationUI. // // Keep in mind that android usually only samples the gps at 1 hz. rpc setGps(GpsState) returns (google.protobuf.Empty) {} // Gets the latest gps state as delivered by the setGps call, or location ui // if active. // // Note: this is not necessarily the actual gps coordinate visible at the // time, due to gps sample frequency (usually 1hz). rpc getGps(google.protobuf.Empty) returns (GpsState) {} // Simulate a touch event on the finger print sensor. rpc sendFingerprint(Fingerprint) returns (google.protobuf.Empty) {} // Send a keyboard event. Translating the event. rpc sendKey(KeyboardEvent) returns (google.protobuf.Empty) {} // Send touch events. Note that mouse events can be simulated by touch events. rpc sendTouch(TouchEvent) returns (google.protobuf.Empty) {} // Send mouse events. 
rpc sendMouse(MouseEvent) returns (google.protobuf.Empty) {} // Make a phone call. rpc sendPhone(PhoneCall) returns (PhoneResponse) {} // Sends an sms message to the emulator. rpc sendSms(SmsMessage) returns (PhoneResponse) {} // Retrieve the status of the emulator. This will contain general // hardware information, and whether the device has booted or not. rpc getStatus(google.protobuf.Empty) returns (EmulatorStatus) {} // Gets an individual screenshot in the desired format. // // The image will be scaled to the desired ImageFormat, while maintaining // the aspect ratio. The returned image will never exceed the provided width // and height. Not setting the width or height (i.e. they are 0) will result // in using the device width and height. // // The resulting image will be properly oriented and can be displayed // directly without post processing. For example, if the device has a // 1080x1920 screen and is in landscape mode and called with no width or // height parameter, it will return an 1920x1080 image. // // This method will return an empty image if the display is not visible. rpc getScreenshot(ImageFormat) returns (Image) {} // Streams a series of screenshots in the desired format. // A new frame will be delivered whenever the device produces a new frame. // (Beware that this can produce a significant amount of data, and that // certain translations are (png transform) can be costly). // // If the requested display is not visible it will send a single empty image // and wait start producing images once the display becomes active, again // producing a single empty image when the display becomes inactive. rpc streamScreenshot(ImageFormat) returns (stream Image) {} // Streams a series of audio packets in the desired format. // A new frame will be delivered whenever the emulated device // produces a new audio frame. You can expect packets to be // delivered in intervals of 20-30ms. 
// // Be aware that this can block when the emulator does not // produce any audio whatsoever! rpc streamAudio(AudioFormat) returns (stream AudioPacket) {} // Injects a series of audio packets to the android microphone. // A new frame will be delivered whenever the emulated device // requests a new audio frame. Audio is usually delivered at a rate // that the emulator is requesting frames. Audio will be stored in a // temporary buffer that can hold 500ms of audio. // // Note: Currently the emulator will downsample to 16khz. // // - INVALID_ARGUMENT (code 3) The sampling rate was too high // - INVALID_ARGUMENT (code 3) The audio packet was too large to handle. // - FAILED_PRECONDITION (code 9) If there was a microphone registered // already. rpc injectAudio(stream AudioPacket) returns (google.protobuf.Empty) {} // Returns the last 128Kb of logcat output from the emulator // Note that parsed logcat messages are only available after L (Api >23). // it is possible that the logcat buffer gets overwritten, or falls behind. rpc getLogcat(LogMessage) returns (LogMessage) {} // Streams the logcat output from the emulator. The first call // can retrieve up to 128Kb. This call will not return. // Note that parsed logcat messages are only available after L (Api >23) // it is possible that the logcat buffer gets overwritten, or falls behind. rpc streamLogcat(LogMessage) returns (stream LogMessage) {} // Transition the virtual machine to the desired state. Note that // some states are only observable. For example you cannot transition // to the error state. rpc setVmState(VmRunState) returns (google.protobuf.Empty) {} // Gets the state of the virtual machine. rpc getVmState(google.protobuf.Empty) returns (VmRunState) {} // Atomically changes the current multi-display configuration. // After this call the given display configurations will be activated. You // can only update secondary displays. Displays with id 0 will be ignored. 
// // This call can result in the removal or addition of secondary displays, the // final display state can be observed by the returned configuration. // // The following gRPC error codes can be returned: // - FAILED_PRECONDITION (code 9) if the AVD does not support a configurable // secondary display. // - INVALID_ARGUMENT (code 3) if: // - The same display id is defined multiple times. // - The display configurations are outside valid ranges // (see DisplayConfiguration) // - INTERNAL (code 13) if there was an internal emulator failure. rpc setDisplayConfigurations(DisplayConfigurations) returns (DisplayConfigurations) {} // Returns all currently valid logical displays. // The gRPC error code FAILED_PRECONDITION (code 9) is returned if the AVD // does not support a configurable secondary display. rpc getDisplayConfigurations(google.protobuf.Empty) returns (DisplayConfigurations) {} // Notifies client of the following changes: // // - Virtual scene camera status change. // - Display configuration changes from extended ui. This will only be fired // if the user makes modifications to the extended displays through the // extended control tab. // // Note that this method will send the initial virtual scene state // immediately. rpc streamNotification(google.protobuf.Empty) returns (stream Notification) {} // RotationRadian is relative to the camera's current orientation. rpc rotateVirtualSceneCamera(RotationRadian) returns (google.protobuf.Empty) { } // Velocity is absolute. rpc setVirtualSceneCameraVelocity(Velocity) returns (google.protobuf.Empty) {} // Set foldable posture. rpc setPosture(Posture) returns (google.protobuf.Empty) {} } // A Run State that describes the state of the Virtual Machine. message VmRunState { enum RunState { // The emulator is in an unknown state. You cannot transition to this state. UNKNOWN = 0; // Guest is actively running. You can transition to this state from the // paused state. RUNNING = 1; // Guest is paused to load a snapshot.
You cannot transition to this state. RESTORE_VM = 2; // Guest has been paused. Transitioning to this state will pause the // emulator; the guest will not be consuming any cpu cycles. PAUSED = 3; // Guest is paused to take or export a snapshot. You cannot // transition to this state. SAVE_VM = 4; // System shutdown, note that it is similar to power off. It tries to set // the system status and notify guest. The system is likely going to // disappear soon and do proper cleanup of resources, possibly taking // a snapshot. This is the same behavior as closing the emulator by clicking // the X (close) in the user interface. SHUTDOWN = 5; // Immediately terminate the emulator. No resource cleanup will take place. // There is a good chance to corrupt the system. TERMINATE = 7; // Will cause the emulator to reset. This is not a state you can observe. RESET = 9; // Guest experienced some error state, you cannot transition to this state. INTERNAL_ERROR = 10; } RunState state = 1; } message ParameterValue { repeated float data = 1 [packed = true]; } message PhysicalModelValue { enum State { OK = 0; NO_SERVICE = -3; // qemud service is not available/initiated. DISABLED = -2; // Sensor is disabled. UNKNOWN = -1; // Unknown sensor (should not happen) } // Details on the sensors documentation can be found here: // https://developer.android.com/reference/android/hardware/Sensor.html#TYPE_ // The types must follow the order defined in // "external/qemu/android/hw-sensors.h" enum PhysicalType { POSITION = 0; // All values are angles in degrees. // values = [x,y,z] ROTATION = 1; MAGNETIC_FIELD = 2; // Temperature in °C TEMPERATURE = 3; // Proximity sensor distance measured in centimeters PROXIMITY = 4; // Ambient light level in SI lux units LIGHT = 5; // Atmospheric pressure in hPa (millibar) PRESSURE = 6; // Relative ambient air humidity in percent HUMIDITY = 7; VELOCITY = 8; AMBIENT_MOTION = 9; // Describing a hinge angle sensor in degrees.
HINGE_ANGLE0 = 10; HINGE_ANGLE1 = 11; HINGE_ANGLE2 = 12; ROLLABLE0 = 13; ROLLABLE1 = 14; ROLLABLE2 = 15; } PhysicalType target = 1; // [Output Only] State status = 2; // Value interpretation depends on sensor, will contain at most 3 values. ParameterValue value = 3; } // A single sensor value. message SensorValue { enum State { OK = 0; NO_SERVICE = -3; // qemud service is not available/initiated. DISABLED = -2; // Sensor is disabled. UNKNOWN = -1; // Unknown sensor (should not happen) } // These are the various sensors that can be available in an emulated // device. enum SensorType { // Measures the acceleration force in m/s2 that is applied to a device // on all three physical axes (x, y, and z), including the force of // gravity. ACCELERATION = 0; // Measures a device's rate of rotation in rad/s around each of the // three physical axes (x, y, and z). GYROSCOPE = 1; // Measures the ambient geomagnetic field for all three physical axes // (x, y, z) in μT. MAGNETIC_FIELD = 2; // Measures degrees of rotation that a device makes around all three // physical axes (x, y, z) ORIENTATION = 3; // Measures the temperature of the device in degrees Celsius (°C). TEMPERATURE = 4; // Measures the proximity of an object in cm relative to the view screen // of a device. This sensor is typically used to determine whether a // handset is being held up to a person's ear. PROXIMITY = 5; // Measures the ambient light level (illumination) in lx. LIGHT = 6; // Measures the ambient air pressure in hPa or mbar. PRESSURE = 7; // Measures the relative ambient humidity in percent (%). HUMIDITY = 8; MAGNETIC_FIELD_UNCALIBRATED = 9; GYROSCOPE_UNCALIBRATED = 10; } // Type of sensor SensorType target = 1; // [Output Only] State status = 2; // Value interpretation depends on sensor enum, will contain at most 3 // values. ParameterValue value = 3; } message LogMessage { // [Output Only] The contents of the log output.
string contents = 1; // The starting byte position of the output that was returned. This // should match the start parameter sent with the request. If the serial // console output exceeds the size of the buffer, older output will be // overwritten by newer content and the start values will be mismatched. int64 start = 2; // [Output Only] The position of the next byte of content from the serial // console output. Use this value in the next request as the start // parameter. int64 next = 3; // Set the sort of response you are interested in. // If the type is "Parsed" the entries field will contain the parsed // results, otherwise the contents field will be set. LogType sort = 4; // [Output Only] The parsed logcat entries so far. Only set if sort is // set to Parsed repeated LogcatEntry entries = 5; enum LogType { Text = 0; Parsed = 1; } } // A parsed logcat entry. message LogcatEntry { // The possible log levels. enum LogLevel { UNKNOWN = 0; DEFAULT = 1; VERBOSE = 2; DEBUG = 3; INFO = 4; WARN = 5; ERR = 6; FATAL = 7; SILENT = 8; } // A Unix timestamp in milliseconds (The number of milliseconds that // have elapsed since January 1, 1970 (midnight UTC/GMT), not counting // leap seconds) uint64 timestamp = 1; // Process id. uint32 pid = 2; // Thread id. uint32 tid = 3; LogLevel level = 4; string tag = 5; string msg = 6; } // Information about the hypervisor that is currently in use. message VmConfiguration { enum VmHypervisorType { // An unknown hypervisor UNKNOWN = 0; // No hypervisor is in use. This usually means that the guest is // running on a different CPU than the host, or you are using a // platform where no hypervisor is available. NONE = 1; // The Kernel based Virtual Machine // (https://www.linux-kvm.org/page/Main_Page) KVM = 2; // Intel® Hardware Accelerated Execution Manager (Intel® HAXM) // https://github.com/intel/haxm HAXM = 3; // Hypervisor Framework.
// https://developer.apple.com/documentation/hypervisor HVF = 4; // Windows Hypervisor Platform // https://docs.microsoft.com/en-us/virtualization/api/ WHPX = 5; GVM = 6; } VmHypervisorType hypervisorType = 1; int32 numberOfCpuCores = 2; int64 ramSizeBytes = 3; } // Representation of a clipped data object on the clipboard. message ClipData { // UTF-8 Encoded text. string text = 1; } // The Touch interface represents a single contact point on a // touch-sensitive device. The contact point is commonly a finger or stylus // and the device may be a touchscreen or trackpad. message Touch { // The horizontal coordinate. This is the physical location on the // screen. For example 0 indicates the leftmost coordinate. int32 x = 1; // The vertical coordinate. This is the physical location on the screen. // For example 0 indicates the top left coordinate. int32 y = 2; // The identifier is an arbitrary non-negative integer that is used to // identify and track each tool independently when multiple tools are // active. For example, when multiple fingers are touching the device, // each finger should be assigned a distinct tracking id that is used as // long as the finger remains in contact. Tracking ids may be reused // when their associated tools move out of range. // // The emulator currently supports up to 10 concurrent touch events. The // identifier can be any unique value and will be mapped to the next // available internal identifier. int32 identifier = 3; // Reports the physical pressure applied to the tip of the tool or the // signal strength of the touch contact. // // The values reported must be non-zero when the tool is touching the // device and zero otherwise to indicate that the touch event is // completed. // // Make sure to deliver a pressure of 0 for the given identifier when // the touch event is completed, otherwise the touch identifier will not // be unregistered!
int32 pressure = 4; // Optionally reports the cross-sectional area of the touch contact, or // the length of the longer dimension of the touch contact. int32 touch_major = 5; // Optionally reports the length of the shorter dimension of the touch // contact. This axis will be ignored if touch_major is reporting an // area measurement greater than 0. int32 touch_minor = 6; enum EventExpiration { // The system will use the default time of 120s to track // the touch event with the given identifier. If no update happens // within this timeframe the identifier is considered expired // and can be made available for re-use. This means that a touch event // with pressure 0 for this identifier will be sent to the emulator. EVENT_EXPIRATION_UNSPECIFIED = 0; // Never expire the given slot. You must *ALWAYS* close the identifier // by sending a touch event with 0 pressure. NEVER_EXPIRE = 1; } EventExpiration expiration = 7; } // A TouchEvent contains a list of Touch objects that are in contact with // the touch surface. // // Touch events are delivered in sequence as specified in the touchList. // // TouchEvents are delivered to the emulated devices using ["Protocol // B"](https://www.kernel.org/doc/Documentation/input/multi-touch-protocol.txt) message TouchEvent { // The list of Touch objects, note that these do not need to be unique repeated Touch touches = 1; // The display device where the touch event occurred. // Omitting or using the value 0 indicates the main display. // // Touch events cannot be sent to displays other than 0, due to // https://issuetracker.google.com/issues/150699691 int32 display = 2; } // The MouseEvent interface represents events that occur due to the user // interacting with a pointing device (such as a mouse). message MouseEvent { // The horizontal coordinate. This is the physical location on the // screen. For example 0 indicates the leftmost coordinate. int32 x = 1; // The vertical coordinate.
This is the physical location on the screen. // For example 0 indicates the top left coordinate. int32 y = 2; // Indicates which buttons are pressed. // 0: No button was pressed // 1: Primary button (left) // 2: Secondary button (right) int32 buttons = 3; // The display device where the mouse event occurred. // Omitting or using the value 0 indicates the main display. int32 display = 4; } // KeyboardEvent objects describe a user interaction with the keyboard; each // event describes a single interaction between the user and a key (or // combination of a key with modifier keys) on the keyboard. // This follows the pattern as set by // (javascript)[https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent] // // Note: that only keyCode, key, or text can be set and that the semantics // will slightly vary. message KeyboardEvent { // Code types that the emulator can receive. Note that the emulator // will do its best to translate the code to an evdev value that // will be sent to the emulator. This translation is based on // the chromium translation tables. See // (this)[https://android.googlesource.com/platform/external/qemu/+/refs/heads/emu-master-dev/android/android-grpc/android/emulation/control/keyboard/keycode_converter_data.inc] // for details on the translation. enum KeyCodeType { Usb = 0; Evdev = 1; XKB = 2; Win = 3; Mac = 4; } enum KeyEventType { // Indicates that this keyevent should be sent to the emulator // as a key down event. Meaning that the key event will be // translated to an EvDev event type and bit 11 (0x400) will be // set before it is sent to the emulator. keydown = 0; // Indicates that the keyevent should be sent to the emulator // as a key up event. Meaning that the key event will be // translated to an EvDev event type and // sent to the emulator. keyup = 1; // Indicates that the keyevent will be sent to the emulator // as a key down event and immediately followed by a keyup event.
keypress = 2; } // Type of keycode contained in the keyCode field. KeyCodeType codeType = 1; // The type of keyboard event that should be sent to the emulator KeyEventType eventType = 2; // This property represents a physical key on the keyboard (as opposed // to the character generated by pressing the key). In other words, this // property is a value which isn't altered by keyboard layout or the // state of the modifier keys. This value will be interpreted by the // emulator depending on the KeyCodeType. The incoming key code will be // translated to an evdev code type and sent to the emulator. // The values in key and text will be ignored. int32 keyCode = 3; // The value of the key pressed by the user, taking into consideration // the state of modifier keys such as Shift as well as the keyboard // locale and layout. This follows the w3c standard used in browsers. // You can find an accurate description of valid values // [here](https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values) // // Note that some keys can result in multiple evdev events that are // delivered to the emulator. For example the Key "A" will result in a // sequence: // ["Shift", "a"] -> [0x2a, 0x1e] whereas "a" results in ["a"] -> [0x1e]. // // Not all documented keys are understood by android, and only printable // ASCII [32-127) characters are properly translated. // // Keep in mind that there are a set of key values that result in android // specific behavior // [see](https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values#Phone_keys): // // - "AppSwitch": Behaves as the "Overview" button in android. // - "GoBack": The Back button. // - "GoHome": The Home button, which takes the user to the phone's main // screen (usually an application launcher). // - "Power": The Power button. string key = 4; // Series of utf8 encoded characters to send to the emulator.
An attempt // will be made to translate every character to an EvDev event type and // send it to the emulator as a keypress event. The values in keyCode, // eventType, codeType and key will be ignored. // // Note that most printable ASCII characters (range [32-127) can be sent // individually with the "key" param. Do not expect arbitrary UTF symbols to // arrive in the emulator (most will be ignored). // // Note that it is possible to overrun the keyboard buffer by slamming this // endpoint with large quantities of text (>1kb). The clipboard api is better // suited for transferring large quantities of text. string text = 5; } message Fingerprint { // True when the fingerprint is touched. bool isTouching = 1; // The identifier of the registered fingerprint. int32 touchId = 2; } message GpsState { // Setting this to false will disable auto updating from the LocationUI, // otherwise the location UI will override the location at a frequency of 1hz. // // - This is unused if the emulator is launched with -no-window, or when the // location ui is disabled. // - This will BREAK the location ui experience if it is set to false. For // example routing will no longer function. bool passiveUpdate = 1; // The latitude, in degrees. double latitude = 2; // The longitude, in degrees. double longitude = 3; // The speed if it is available, in meters/second over ground double speed = 4; // gets the horizontal direction of travel of this device, and is not // related to the device orientation. It is guaranteed to be in the // range [0.0, 360.0] if the device has a bearing. 0=North, 90=East, // 180=South, etc.. double bearing = 5; // The altitude if available, in meters above the WGS 84 reference // ellipsoid.
double altitude = 6; // The number of satellites used to derive the fix int32 satellites = 7; } message BatteryState { enum BatteryStatus { UNKNOWN = 0; CHARGING = 1; DISCHARGING = 2; NOT_CHARGING = 3; FULL = 4; } enum BatteryCharger { NONE = 0; AC = 1; USB = 2; WIRELESS = 3; } enum BatteryHealth { GOOD = 0; FAILED = 1; DEAD = 2; OVERVOLTAGE = 3; OVERHEATED = 4; } bool hasBattery = 1; bool isPresent = 2; BatteryCharger charger = 3; int32 chargeLevel = 4; BatteryHealth health = 5; BatteryStatus status = 6; } // An ImageTransport allows for specifying a side channel for // delivering image frames versus using the standard bytes array that is // returned with the gRPC request. message ImageTransport { enum TransportChannel { // Return full frames over the gRPC transport TRANSPORT_CHANNEL_UNSPECIFIED = 0; // Write images to a file/shared memory handle. MMAP = 1; } // The desired transport channel used for delivering image frames. Only // relevant when streaming screenshots. TransportChannel channel = 1; // Handle used for writing image frames if transport is mmap. The client sets // and owns this handle. It can be either a shm region, or a mmap. A mmap // should be a url that starts with `file:///` // Note: the mmap can result in tearing. string handle = 2; } // The aspect ratio (width/height) will be different from the one // where the device is unfolded. message FoldedDisplay { uint32 width = 1; uint32 height = 2; // It is possible for the screen to be folded in different ways depending // on which surface is shown to the user. So xOffset and yOffset indicate // the top left corner of the folded screen within the original unfolded // screen. uint32 xOffset = 3; uint32 yOffset = 4; } message ImageFormat { enum ImgFormat { // Portable Network Graphics format // (https://en.wikipedia.org/wiki/Portable_Network_Graphics) PNG = 0; // Three-channel RGB color model supplemented with a fourth alpha // channel.
https://en.wikipedia.org/wiki/RGBA_color_model // Each pixel consists of 4 bytes. RGBA8888 = 1; // Three-channel RGB color model, each pixel consists of 3 bytes RGB888 = 2; } // The (desired) format of the resulting bytes. ImgFormat format = 1; // [Output Only] The rotation of the image. The image will be rotated // based upon the coarse grained orientation of the device. Rotation rotation = 2; // The (desired) width of the image. When passed as input // the image will be scaled to match the given // width, while maintaining the aspect ratio of the device. // The returned image will never exceed the given width, but can be less. // Omitting this value (or passing in 0) will result in no scaling, // and the width of the actual device will be used. uint32 width = 3; // The (desired) height of the image. When passed as input // the image will be scaled to match the given // height, while maintaining the aspect ratio of the device. // The returned image will never exceed the given height, but can be less. // Omitting this value (or passing in 0) will result in no scaling, // and the height of the actual device will be used. uint32 height = 4; // The (desired) display id of the device. Setting this to 0 (or omitting) // indicates the main display. uint32 display = 5; // Set this if you wish to use a different transport channel to deliver image // frames. ImageTransport transport = 6; // [Output Only] Display configuration when screen is folded. The value is the // original configuration before scaling. FoldedDisplay foldedDisplay = 7; } message Image { ImageFormat format = 1; uint32 width = 2 [deprecated = true]; // width is contained in format. uint32 height = 3 [deprecated = true]; // height is contained in format. // The organization of the pixels in the image buffer is from left to // right and bottom up. This will be empty if an alternative image transport // is requested in the image format. In that case the side channel should // be used to obtain the image data. 
bytes image = 4; // [Output Only] Monotonically increasing sequence number in a stream of // screenshots. The first screenshot will have a sequence of 0. A single // screenshot will always have a sequence number of 0. The sequence is not // necessarily contiguous, and can be used to detect how many frames were // dropped. An example sequence could be: [0, 3, 5, 7, 9, 11]. uint32 seq = 5; // [Output Only] Unix timestamp in microseconds when the emulator estimates // the frame was generated. The timestamp is before the actual frame is // copied and transformed. This can be used to calculate variance between // frame production time, and frame depiction time. uint64 timestampUs = 6; } message Rotation { enum SkinRotation { PORTRAIT = 0; // 0 degrees LANDSCAPE = 1; // 90 degrees REVERSE_PORTRAIT = 2; // -180 degrees REVERSE_LANDSCAPE = 3; // -90 degrees } // The rotation of the device, derived from the sensor state // of the emulator. The derivation reflects how android observes // the rotation state. SkinRotation rotation = 1; // Specifies the angle of rotation, in degrees [-180, 180] double xAxis = 2; double yAxis = 3; double zAxis = 4; } message PhoneCall { enum Operation { InitCall = 0; AcceptCall = 1; RejectCallExplicit = 2; RejectCallBusy = 3; DisconnectCall = 4; PlaceCallOnHold = 5; TakeCallOffHold = 6; } Operation operation = 1; string number = 2; } message PhoneResponse { enum Response { OK = 0; BadOperation = 1; // Enum out of range BadNumber = 2; // Mal-formed telephone number InvalidAction = 3; // E.g., disconnect when no call is in progress ActionFailed = 4; // Internal error RadioOff = 5; // Radio power off } Response response = 1; } message Entry { string key = 1; string value = 2; } message EntryList { repeated Entry entry = 1; } message EmulatorStatus { // The emulator version string. string version = 1; // The time the emulator has been active in ms uint64 uptime = 2; // True if the device has completed booting.
// For P and later this information will be accurate, // for older images we rely on adb. bool booted = 3; // The current vm configuration VmConfiguration vmConfig = 4; // The hardware configuration of the running emulator as // key value pairs. EntryList hardwareConfig = 5; } message AudioFormat { enum SampleFormat { AUD_FMT_U8 = 0; // Unsigned 8 bit AUD_FMT_S16 = 1; // Signed 16 bit (little endian) } enum Channels { Mono = 0; Stereo = 1; } // Sampling rate to use, defaulting to 44100 if this is not set. // Note, that android devices typically will not use a sampling // rate higher than 48kHz. See https://developer.android.com/ndk/guides/audio. uint64 samplingRate = 1; Channels channels = 2; SampleFormat format = 3; } message AudioPacket { AudioFormat format = 1; // Unix epoch in us when this frame was captured. uint64 timestamp = 2; // Contains a sample in the given audio format. bytes audio = 3; } message SmsMessage { // The source address where this message came from. // // The address should be a valid GSM-formatted address as specified by // 3GPP 23.040 Sec 9.1.2.5. // // For example: +3106225412 or (650) 555-1221 string srcAddress = 1; // A utf8 encoded text message that should be delivered. string text = 2; } // A DisplayConfiguration describes a primary or secondary // display available to the emulator. The screen aspect ratio // cannot be longer (or wider) than 21:9 (or 9:21). Screen sizes // larger than 4k will be rejected. // // Common configurations (w x h) are: // - 480p (480x720) 142 dpi // - 720p (720x1280) 213 dpi // - 1080p (1080x1920) 320 dpi // - 4K (2160x3840) 320 dpi // - 4K (2160x3840) 640 dpi (upscaled) // // The behavior of the virtual display depends on the flags that are provided to // this method. By default, virtual displays are created to be private, // non-presentation and unsecure. message DisplayConfiguration { // These are the set of known android flags and their respective values.
// you can combine the int values to (de)construct the flags field below. enum DisplayFlags { DISPLAYFLAGS_UNSPECIFIED = 0; // When this flag is set, the virtual display is public. // A public virtual display behaves just like most any other display // that is connected to the system such as an external or wireless // display. Applications can open windows on the display and the system // may mirror the contents of other displays onto it. see: // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_PUBLIC VIRTUAL_DISPLAY_FLAG_PUBLIC = 1; // When this flag is set, the virtual display is registered as a // presentation display in the presentation display category. // Applications may automatically project their content to presentation // displays to provide richer second screen experiences. // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_PRESENTATION VIRTUAL_DISPLAY_FLAG_PRESENTATION = 2; // When this flag is set, the virtual display is considered secure as // defined by the Display#FLAG_SECURE display flag. The caller promises // to take reasonable measures, such as over-the-air encryption, to // prevent the contents of the display from being intercepted or // recorded on a persistent medium. // see: // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_SECURE VIRTUAL_DISPLAY_FLAG_SECURE = 4; // This flag is used in conjunction with VIRTUAL_DISPLAY_FLAG_PUBLIC. // Ordinarily public virtual displays will automatically mirror the // content of the default display if they have no windows of their own. // When this flag is specified, the virtual display will only ever show // its own content and will be blanked instead if it has no windows. 
See // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY = 8; // Allows content to be mirrored on private displays when no content is // being shown. // This flag is mutually exclusive with // VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY. If both flags are specified // then the own-content only behavior will be applied. // see: // https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR = 16; } // The width of the display, restricted to: // 320 * (dpi / 160) <= width uint32 width = 1; // The height of the display, restricted to: // 320 * (dpi / 160) <= height uint32 height = 2; // The pixel density (dpi). // See https://developer.android.com/training/multiscreen/screendensities // for details. This value should be in the range [120, ..., 640] uint32 dpi = 3; // A combination of virtual display flags. These flags can be constructed // by combining the DisplayFlags enum described above. // // The behavior of the virtual display depends on the flags. By default // virtual displays are created to be private, non-presentation and // unsecure. uint32 flags = 4; // The id of the display. // The primary (default) display has the display ID of 0. // A secondary display has a display ID not 0. // // The id can be used to get or stream a screenshot. uint32 display = 5; } message DisplayConfigurations { repeated DisplayConfiguration displays = 1; } message Notification { enum EventType { VIRTUAL_SCENE_CAMERA_INACTIVE = 0; VIRTUAL_SCENE_CAMERA_ACTIVE = 1; // Fired when an update to a display event has been fired through // the extended ui. This does not fire events when the display // is changed through the console or gRPC endpoint.
DISPLAY_CONFIGURATIONS_CHANGED_UI = 2; // Keep adding more for other event types } EventType event = 1; } message RotationRadian { float x = 1; // x axis is horizontal and orthogonal to the view direction. float y = 2; // y axis points up and is perpendicular to the floor. float z = 3; // z axis is the view direction and is set to 0.0 in // rotateVirtualSceneCamera call. } message Velocity { float x = 1; // x axis is horizontal and orthogonal to the view direction. float y = 2; // y axis points up and is perpendicular to the floor. float z = 3; // z axis is the view direction } // must follow the definition in "external/qemu/android/hw-sensors.h" message Posture { enum PostureValue { POSTURE_UNKNOWN = 0; POSTURE_CLOSED = 1; POSTURE_HALF_OPENED = 2; POSTURE_OPENED = 3; POSTURE_FLIPPED = 4; POSTURE_TENT = 5; POSTURE_MAX = 6; } PostureValue value = 3; } ================================================ FILE: android_env/proto/snapshot.proto ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Copyright (C) 2018 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. syntax = "proto2"; // This file must be synchronized between // Emulator (branch aosp/emu-master-dev): // external/qemu/android/android-emu/android/snapshot/proto/snapshot.proto // // Android Studio (branch goog/studio-master-dev): // tools/adt/idea/android/src/com/android/emulator/snapshot.proto // // If you modify one, please modify the other. package emulator_snapshot; option java_package = "com.android.emulator.snapshot"; message Image { enum Type { IMAGE_TYPE_UNKNOWN = 0; IMAGE_TYPE_KERNEL = 1; IMAGE_TYPE_KERNEL_RANCHU = 2; IMAGE_TYPE_SYSTEM = 3; IMAGE_TYPE_SYSTEM_COPY = 4; IMAGE_TYPE_DATA = 5; IMAGE_TYPE_DATA_COPY = 6; IMAGE_TYPE_RAMDISK = 7; IMAGE_TYPE_SDCARD = 8; IMAGE_TYPE_CACHE = 9; IMAGE_TYPE_VENDOR = 10; IMAGE_TYPE_ENCRYPTION_KEY = 11; } optional Type type = 1; optional string path = 2; optional bool present = 3; optional int64 size = 4; optional int64 modification_time = 5; } message Host { optional string gpu_driver = 4; optional int32 hypervisor = 5; } message Config { // Features are int32, not enums here to make sure we don't have to update // one more protobuf definition with every single new feature flag, even // when the code doesn't really care about the actual meaning for them, // only for the values. repeated int32 enabled_features = 1; // This holds the renderer; int32 for the same reason as |enabled_features|. 
optional int32 selected_renderer = 2; optional int32 cpu_core_count = 3; optional int64 ram_size_bytes = 4; } message SaveStats { // Type of save // 0: non-incremental // 1: incremental optional uint32 incremental = 1; // Time taken to save. optional uint64 duration = 2; // How many changed bytes in RAM. optional uint64 ram_changed_bytes = 3; } message Snapshot { // Update every time when introducing some breaking changes that make the // previous loading code break when trying to load the new snapshot. // NOTE: if the old code is fine with just skipping the new fields or not // getting the meaning of new values, |version| should remain // unchanged. optional int32 version = 1; // Purely informative: when this snapshot was created, Unix timestamp. optional int64 creation_time = 2; // list of mounted disk images used during the snapshot creation. repeated Image images = 3; // Description of the host machine properties needed to load this snapshot. optional Host host = 4; // Description of the emulator configuration needed for this snapshot. // NOTE: try not to duplicate the configuration that's already in // hardware-qemu.ini; only add what's either not there or what // could've been overridden during process initialization. optional Config config = 5; // Set if the snapshot failed to load during the last attempt. // Code is up to the application to define, with 0 meaning 'not failed' just // in case. optional int64 failed_to_load_reason_code = 7; // Set if data image is mounted. // User build and userdebug build mount data partition at different time. // But it should be done before boot finished, so this field is very likely // to be true. // We snapshot it here just in case someday we support snapshot during // booting. optional bool guest_data_partition_mounted = 8; // Emulator rotation angle, in right angles (e.g. 1 is 90 degrees, 2 is 180 // etc). optional int32 rotation = 9; // Number of invalid loads / crashes that happened under this snapshot. 
optional int32 invalid_loads = 10; // Number of successful loads. optional int32 successful_loads = 11; // The name given to the snapshot by the user. Independent of the // file name. optional string logical_name = 12; // The file name of this snapshot's parent. The parent is the // snapshot that was loaded into the AVD prior to this snapshot // being taken optional string parent = 13; // Arbitrary description added by the user optional string description = 14; // Record of save stats. repeated SaveStats save_stats = 15; // Folded state. optional bool folded = 16; // Emulator boot parameters repeated string launch_parameters = 17; // Emulator build ID optional string emulator_build_id = 18; // System image build ID optional string system_image_build_id = 19; } ================================================ FILE: android_env/proto/snapshot_service.proto ================================================ // Copyright 2026 DeepMind Technologies Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Copyright (C) 2018 The Android Open Source Project // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Note that if you add/remove methods in this file you must update // the metrics sql as well by running ./android/scripts/gen-grpc-sql.py // // Please group deleted methods in a block including the date (MM/DD/YY) // it was removed. This enables us to easily keep metrics around after removal // // list of deleted methods // rpc iWasDeleted (03/12/12) // ... syntax = "proto3"; package android.emulation.control; import "android_env/proto/snapshot.proto"; option java_multiple_files = true; option java_package = "com.android.emulator.control"; option objc_class_prefix = "AEC"; // The SnapshotService enables you to list, insert, store, and retrieve // snapshots. // // Currently there are two types of snapshots: // // - Local (default): These are snapshots that are created locally. They are // stored internally inside qcow2 files and are very efficient. These are // the snapshots usually created by interacting with the UI. // // - Remote: These are snapshots that have been exported at a certain point. // An exported snapshot is normalized (completely self-contained) and // can be imported into an emulator with a similar hardware configuration. // // Currently the emulator has limited support for importing snapshots: // - Once an imported snapshot has been loaded into an emulator it is no longer // possible to create new snapshots. // - The hardware configuration of the emulator you are pushing a snapshot to // must match (or be very similar) to the one you pulled the snapshot from.
// // For example do not expect to be able to restore a snapshot on created on an // Intel cpu on an AMD cpu. service SnapshotService { // Lists all the snapshots, filtered by the given query, that are stored // locally for the currently running avd. This includes all the snapshots that // were imported (pushed) into this emulator. // // Returns a list of snapshot_id's and associated details that describes // the hardware configuration, logical name, etc of the snapshot. rpc ListSnapshots(SnapshotFilter) returns (SnapshotList) {} // Pulls down the snapshot stored inside the AVD as a tar.gz/tar stream // This will normalize the snapshot, all relevant data to push a snapshot // into a similar emulator will be placed inside the tar file. // // Pulling down a snapshot will pause the emulator until the snapshots // are rebased and ready for exporting. Once the snapshot is rebased // the emulator will continue and downloading should commence. // // Note that pulling .gz stream is slow. // // You must provide the snapshot_id and (desired) format. // // If SnapshotPackage.path is set, the gRPC service will directly write the // exported snapshot to SnapshotPackage.path without streaming, which is // usually significantly faster. It would require emulator to have direct // access to SnapshotPackage.path, which usually means it can only be used // when pulling from a local emulator. rpc PullSnapshot(SnapshotPackage) returns (stream SnapshotPackage) {} // Push a tar.gz stream contain the snapshot. The tar file should // be a snapshot that was exported through the PullSnapshot in the past. // The emulator will try to import the snapshot. The hardware configuration // of the current emulator should match the one used for pulling. // // A detailed description of the snapshot (emulator_snapshot.Snapshot) // is stored in the snapshot.pb file inside the tar. // // You must provide the snapshot_id and format in the first message. 
// Will return success and a possible error message when a failure occurs. // // If SnapshotPackage.path is set, the gRPC service will directly unzip the // exported snapshot from SnapshotPackage.path without streaming, which is // usually significantly faster. It would require emulator to have direct // access to SnapshotPackage.path, which usually means it can only be used // when pushing to a local emulator. rpc PushSnapshot(stream SnapshotPackage) returns (SnapshotPackage) {} // Loads the given snapshot inside the emulator and activates it. // The device will be in the state as it was when the snapshot was created. // // You will no longer be able to call Save if this was an imported // snapshot that was pushed into this emulator. // // You must provide the snapshot_id to indicate which snapshot to load // Will return success and a possible error message when a failure occurs. rpc LoadSnapshot(SnapshotPackage) returns (SnapshotPackage) {} // Creates as a snapshot of the current state of the emulator. // You can only save a snapshot if you never activated (Load) an imported // snapshot (Push). // // For example: // - PushSnapshot("some_snap.tar.gz"); // - LoadSnapshot("some_snap"); // - SaveSnapshot("same_newer_snap"); // <--- Will currently fail. // // You can provide the snapshot_id to indicate the name used for storing. // Will return success and a possible error message when a failure occurs. rpc SaveSnapshot(SnapshotPackage) returns (SnapshotPackage) {} // Deletes the snapshot with the given snapshot_id from the avd. // // You must provide the snapshot_id to indicate which snapshot to delete. // Will return success and a possible error message when a failure occurs. rpc DeleteSnapshot(SnapshotPackage) returns (SnapshotPackage) {} // Tracks the given process for automated snapshot creation in case of // assert failures. // // Will return success and a possible error message when a failure occurs. 
// The snapshot_id field will contain the name of the snapshot that // will be created. The pid field will contain the process id that is // being tracked. rpc TrackProcess(IceboxTarget) returns (IceboxTarget) {} } // Sets options for SnapshotService. Used for both request and response // messages. message SnapshotPackage { enum Format { TARGZ = 0; TAR = 1; DIRECTORY = 2; } // The identifier of the snapshot, only required for request messages. For // streaming service, only used in the first stream message of a gRPC call // (would be ignored in subsequent stream messages of the same call). string snapshot_id = 1; // A stream of bytes. Encoded as a tar (possibly gzipped) file depending on the // value of format. bytes payload = 2; // [response only] status fields, usually indicates end of transmission. bool success = 3; bytes err = 4; // [request only] Format of the payload. Only used in request messages. For // streaming service, only used in the first stream message of a gRPC call // (would be ignored in subsequent stream messages of the same call). Format format = 5; // [request only] Path to the snapshot package file. Only used in request // messages. // // When set in a request, the PullSnapshot/PushSnapshot operation will // directly write/read the exported snapshot in path without streaming, which // is usually significantly faster. It would require emulator to have direct // access to path, which usually means it can only be used with a local // emulator. string path = 6; } // A snapshot filter can be used to filter the results produced by ListSnapshots message SnapshotFilter { enum LoadStatus { // Only return compatible snapshots CompatibleOnly = 0; // Return all snapshots. All = 1; } // Filter snapshots by load status. LoadStatus statusFilter = 1; } // Provides detailed information regarding the snapshot. message SnapshotDetails { enum LoadStatus { // The emulator believes that the snapshot is compatible with the emulator // that provided this information.
The emulator will attempt to load this // snapshot when requested. // // A snapshot is usually compatible when the following statements are true: // - The snapshot was taken by the current emulator version. i.e. // emulator_build_id in the details field matches the build_id of the // emulator that provided this information. // // - The snapshot was taken on the current running machine, and no hardware // changes have taken place between taking and loading the snapshot. // // - The avd configuration has not changed between when this snapshot was // taken and when the snapshot was loaded. // // - The system images on which the avd is based have not changed. Compatible = 0; // The emulator will not allow loading of the snapshot, as it deems the // snapshot to be incompatible. Loading of snapshots can be forced by // launching the emulator with the feature "AllowSnapshotMigration" enabled. Incompatible = 1; // This snapshot was successfully loaded in the emulator, and was used as // the starting point of the current running emulator. The following holds: // // A loaded snapshot is a compatible snapshot // There is at most one snapshot_id that is in the "Loaded" state Loaded = 2; } // The id of this snapshot. Use this id to load/delete/pull the // snapshot. string snapshot_id = 1; // Detailed information about this snapshot. This contains a detailed // hardware description of the snapshot. These details are the same // as the "snapshot.pb" file found in an exported snapshot. // Look at the import file for a detailed description of the available // fields. emulator_snapshot.Snapshot details = 2; // Provides information about the ability to restore this snapshot. LoadStatus status = 3; // The size of the folder that stores required information to load a snapshot. uint64 size = 4; } // A list of snapshot details.
// NOTE(review): the map fields below were declared as bare `map args = 1;` /
// `map additional_info = 3;`, which is not valid proto3 (a map field needs
// explicit key and value types). `map<string, string>` is assumed here;
// confirm against the code that populates these fields.

// Request to save state; `args` carries free-form named arguments.
message SaveStateRequest {
  map<string, string> args = 1;
}

// Request to load state; `args` carries free-form named arguments.
message LoadStateRequest {
  map<string, string> args = 1;
}

message SaveStateResponse {
  enum Status {
    // Reserved value for unset statuses.
    UNDEFINED = 0;
    // Returned when everything goes well.
    OK = 1;
    // Returned when something internal did not work as expected.
    ERROR = 2;
  }
  Status status = 1;

  // `error_message` is only populated in case of errors.
  string error_message = 2;

  // Any additional info returned during the request; e.g., file paths or sizes.
  map<string, string> additional_info = 3;
}

message LoadStateResponse {
  enum Status {
    // Reserved value for unset statuses.
    UNDEFINED = 0;
    // Returned when everything goes well.
    OK = 1;
    // Returned when there is no state to load.
    NOT_FOUND = 2;
    // Returned when something internal did not work as expected.
    ERROR = 3;
  }
  Status status = 1;

  // `error_message` is only populated in case of errors.
  string error_message = 2;

  // Any additional info returned during the request; e.g., file paths or sizes.
  map<string, string> additional_info = 3;
}
// Example: [ // "^DecorView@.*\[MainActivity\]$", // "^android.widget.LinearLayout\{.*\}$", // "^android.widget.FrameLayout\{.*android\:id\/content\}", // "^android.widget.RelativeLayout\{.*\}", // "^android.widget.FrameLayout\{.*app\:id\/fragment_holder\}", // "^android.widget.RelativeLayout\{.*\}", // "^com.google.example.games.nostalgicracer.views.RaceView3D\{.*app\:id\/gameplay_screen_3d\}", // ], repeated string view_hierarchy_path = 2; } // Waits for `app_screen` to be the current app screen shown to the user. message WaitForAppScreen { AppScreen app_screen = 1; // Maximum time in seconds to wait for the activity to become the current one. float timeout_sec = 2; } message CheckInstall { string package_name = 1; // Maximum time in seconds to wait. float timeout_sec = 2; } message Sleep { float time_sec = 1; } message SuccessCondition { int32 num_retries = 1; oneof check { WaitForAppScreen wait_for_app_screen = 2; CheckInstall check_install = 3; } } message SetupStep { SuccessCondition success_condition = 1; oneof step { AdbRequest adb_request = 2; Sleep sleep = 3; } } // A specification of structured observations // Analogous to dm_env.specs.Array() message ArraySpec { // An identifier for this ArraySpec. string name = 1; // The shape of the multi-dimensional values associated with this ArraySpec, repeated int32 shape = 2; enum DataType { INVALID_DATA_TYPE = 0; FLOAT = 1; DOUBLE = 2; INT8 = 3; INT16 = 4; INT32 = 5; INT64 = 6; UINT8 = 7; UINT16 = 8; UINT32 = 9; UINT64 = 10; BOOL = 11; STRING_U1 = 12; STRING_U16 = 13; STRING_U25 = 14; STRING_U250 = 15; STRING = 16; // String without max length OBJECT = 17; } // Data type of elements we expect to see in an array of this spec. DataType dtype = 3; } message LogParsingConfig { // `filters` are tags used by the app's logging system so that we can // identify them in logcat's output. It's the first argument to logging calls // such as Log.e("ActivityManager", "My message"). 
// Example: "ActivityManager" repeated string filters = 1; // Regular expressions that define how we can extract RL information such as // score, extras and episode end from raw logcat messages. message LogRegexps { // Regexp expected to match: // ...a floating point value which gets accumulated over time. // A delta in 'score' corresponds to the reward. string score = 1; // Regexp expected to match: // ...a floating point value directly forwarded by the environment. repeated string reward = 2; // Regexp expected to match: // ...a signal marking the end of an episode. repeated string episode_end = 3; // Regexp expected to match: // ...a string representing pairs of extra names and values. repeated string extra = 4; // Regexp expected to match: // ...a dict of extra names and values in json format. repeated string json_extra = 5; // Attaches rewards to arbitrary log messages, for example: // {event: "coin_collected" reward: 2.3} // {event: "car_crashed" reward: -1.4} message RewardEvent { // If `event` is matched, the environment will give `reward`. string event = 1; // Numerical value to give as reward if `event` is matched. float reward = 2; } repeated RewardEvent reward_event = 6; } LogRegexps log_regexps = 2; } // Description of a reinforcement learning task to be solved by an agent. message Task { // A globally unique identifier for this task. string id = 1; // A human readable name for this task. string name = 2; // A description of the task. string description = 3; repeated SetupStep setup_steps = 4; repeated SetupStep reset_steps = 5; AppScreen expected_app_screen = 6; // AndroidEnv resets the episode after `max_episode_sec` is passed since the // last reset(). Recommended for time sensitive tasks (e.g. reactive games). // Note that this is real time as measured by AndroidEnv and is independent of // the speed of simulation of Android. // If <= 0.0, this logic is disabled. 
float max_episode_sec = 7; // The maximum number of interactions in a single episode between the // environment and an agent. // This setting is appropriate for tasks that are not time-dependent or when // the performance of the simulation varies dramatically between runs. // If <= 0, this logic is disabled. int32 max_episode_steps = 8; // Defines parameters for parsing messages from logcat. LogParsingConfig log_parsing_config = 9; // NOTE: This field is deprecated and will be removed from this Task // definition soon. // // (Optional): The task may also define extras to help the RL agent. // An Extra in AndroidEnv is any information that apps may send to aid the // understanding of the task. The type of information sent through this // channel is usually something difficult to obtain from raw pixels and may // include things such as: // // - The current board configuration (e.g. of a chess game or a tetris game) // - The position of the avatar in a map // - Events (e.g. whether a button was pressed or a checkpoint was achieved) // // Notice that these are entirely optional and may not be available at all. // This specification ensures that only extras specified in the Task // definition will be passed to the agent, everything else is excluded. // The name of an extra must be unique across all extras. repeated ArraySpec extras_spec = 10; } ================================================ FILE: android_env/wrappers/__init__.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: android_env/wrappers/a11y/__init__.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: android_env/wrappers/a11y/a11y_events.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tools for accessing accessibility events.""" from collections.abc import Mapping from typing import Any from absl import logging from android_env.proto.a11y import a11y_pb2 import numpy as np from google.protobuf import any_pb2 _A11Y_EVENT_KEY = 'full_event' def package_events_to_task_extras( events: list[a11y_pb2.EventRequest], ) -> Mapping[str, np.ndarray]: if not events: return {} events = np.stack(events, axis=0) return {_A11Y_EVENT_KEY: events} def extract_events_from_task_extras( task_extras: Mapping[str, Any] | None = None, ) -> list[Mapping[str, str]]: """Inspects task_extras and extracts all accessibility events detected. Args: task_extras: Task extras forwarded by AndroidEnv. If 'full_event' is not a key in task_extras, then this function returns an empty string. Otherwise, full_event is expected to be list to be a numpy array with one dimension, and contains a list of dictionary describing accessibility events that are present in the given task extras. e.g. 'event_type: TYPE_WINDOW_CONTENT_CHANGED // event_package_name: com.google.android.deskclock // source_class_name: android.widget.ImageView'. Returns: List of all events detected """ if task_extras is None or _A11Y_EVENT_KEY not in task_extras: return [] if ( not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray) or task_extras[_A11Y_EVENT_KEY].ndim != 1 ): raise ValueError( f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one' ' dimension.' ) if task_extras[_A11Y_EVENT_KEY].size == 0: return [] events = [] for e in task_extras[_A11Y_EVENT_KEY]: if isinstance(e, a11y_pb2.EventRequest): events.append(dict(e.event)) elif isinstance(e, dict): events.append(e) logging.warning( 'The event should come only from the a11y_grpc_wrapper. ' 'Please verify that the upacking operation has not been ' 'called twice. 
See here for full task_extras: %s', task_extras, ) elif isinstance(e, any_pb2.Any): ev = a11y_pb2.EventRequest() new_any = any_pb2.Any() new_any.CopyFrom(e) new_any.Unpack(ev) events.append(dict(ev.event)) else: raise TypeError( f'Unexpected event type: {type(e)}. See here for full ' f'task_extras: {task_extras}.' ) return events def keep_latest_event_only(task_extras: dict[str, Any]): """Removes all a11y events except the last one observed.""" if task_extras is None or 'full_event' not in task_extras: return if ( not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray) or task_extras[_A11Y_EVENT_KEY].ndim != 1 ): raise ValueError( f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one' ' dimension.' ) if task_extras[_A11Y_EVENT_KEY].size == 0: return [] task_extras[_A11Y_EVENT_KEY] = task_extras[_A11Y_EVENT_KEY][-1:] ================================================ FILE: android_env/wrappers/a11y/a11y_events_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for a11y_events.""" from absl.testing import absltest from absl.testing import parameterized from android_env.proto.a11y import a11y_pb2 from android_env.wrappers.a11y import a11y_events import numpy as np from google.protobuf import any_pb2 def _event_request(d: dict[str, str]) -> a11y_pb2.EventRequest: event_request = a11y_pb2.EventRequest() for k, v in d.items(): event_request.event[k] = v return event_request def _event_request_as_any(d: dict[str, str]) -> any_pb2.Any: event_request = _event_request(d) response = any_pb2.Any() response.Pack(event_request) return response class A11yEventsTest(parameterized.TestCase): @parameterized.parameters( dict(task_extras={}), dict( task_extras={'no_full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}]}, ), dict( task_extras={'full_event': np.array([])}, ), dict( task_extras={}, ), ) def test_no_events_in_task_extras(self, task_extras): events = a11y_events.extract_events_from_task_extras(task_extras) self.assertEmpty(events) @parameterized.parameters( dict( task_extras={'full_event': [{'1': '1'}, {'2': '2'}]}, expected_events=[{'1': '1'}, {'2': '2'}], ), dict( task_extras={'full_event': [{}]}, expected_events=[{}], ), dict( task_extras={ 'full_event_wrong_key': [1, 2, 3], 'full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}], }, expected_events=[{'1': '1'}, {'2': '2'}, {'3': '3'}], ), ) def test_task_extras(self, task_extras, expected_events): event_requests = [_event_request(e) for e in task_extras['full_event']] task_extras['full_event'] = np.stack(event_requests, axis=0) events = a11y_events.extract_events_from_task_extras(task_extras) self.assertEqual(len(events), len(expected_events)) for i, event in enumerate(expected_events): self.assertEqual(len(event), len(expected_events[i])) for k, v in event.items(): self.assertIn(k, expected_events[i]) self.assertEqual(v, expected_events[i][k]) def test_events_key_has_dict_event_requrests(self): event_requests = [ _event_request({'1': '1'}), {'2': '2'}, _event_request({'3': 
'3'}), ] expected_events = [ {'1': '1'}, {'2': '2'}, {'3': '3'}, ] task_extras = {'full_event': np.stack(event_requests, axis=0)} events = a11y_events.extract_events_from_task_extras(task_extras) self.assertEqual(len(events), len(expected_events)) for i, event in enumerate(expected_events): self.assertEqual(len(event), len(expected_events[i])) for k, v in event.items(): self.assertIn(k, expected_events[i]) self.assertEqual(v, expected_events[i][k]) def test_events_key_has__event_requrests_packed_as_any(self): event_requests = [ _event_request_as_any({'1': '1'}), {'2': '2'}, _event_request_as_any({'3': '3'}), ] expected_events = [ {'1': '1'}, {'2': '2'}, {'3': '3'}, ] task_extras = {'full_event': np.stack(event_requests, axis=0)} events = a11y_events.extract_events_from_task_extras(task_extras) self.assertEqual(len(events), len(expected_events)) for i, event in enumerate(expected_events): self.assertEqual(len(event), len(expected_events[i])) for k, v in event.items(): self.assertIn(k, expected_events[i]) self.assertEqual(v, expected_events[i][k]) def test_events_key_has_non_event_requrests(self): event_requests = [ _event_request({'1': '1'}), 3, # Not an even and not a dict. 
_event_request({'3': '3'}), ] task_extras = {'full_event': np.stack(event_requests, axis=0)} with self.assertRaises(TypeError): _ = a11y_events.extract_events_from_task_extras(task_extras) @parameterized.parameters( dict(task_extras={}, expected_extras={}), dict( task_extras={ 'no_full_event': 42, }, expected_extras={ 'no_full_event': 42, }, ), dict( task_extras={'full_event': np.array([1, 2]), 'no_full_event': 43}, expected_extras={'full_event': np.array([2]), 'no_full_event': 43}, ), dict( task_extras={'full_event': np.array([1, 2, 3])}, expected_extras={'full_event': np.array([3])}, ), dict( task_extras={'full_event': np.array([]), 'no_full_event': 44}, expected_extras={'full_event': np.array([]), 'no_full_event': 44}, ), ) def test_keep_latest_only(self, task_extras, expected_extras): a11y_events.keep_latest_event_only(task_extras) self.assertEqual(len(task_extras), len(expected_extras)) for k, v in task_extras.items(): self.assertIn(k, expected_extras) if k == 'full_event': np.testing.assert_array_equal(v, expected_extras['full_event']) else: self.assertEqual(v, expected_extras[k]) pass if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/a11y/a11y_forests.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tools for accessing accessibility events.""" from collections.abc import Mapping from typing import Any from android_env.proto.a11y import android_accessibility_forest_pb2 import numpy as np from google.protobuf import any_pb2 _A11Y_FORESTS_KEY = 'accessibility_tree' def package_forests_to_task_extras( forests: list[android_accessibility_forest_pb2.AndroidAccessibilityForest], ) -> Mapping[str, np.ndarray]: if not forests: return {} forests = np.stack(forests, axis=0) return {_A11Y_FORESTS_KEY: forests} def task_extras_has_forests(task_extras: Mapping[str, Any]) -> bool: """Checks that the task_extras has any a11y forest information.""" if _A11Y_FORESTS_KEY not in task_extras: return False payload = task_extras[_A11Y_FORESTS_KEY] if not isinstance(payload, np.ndarray) or payload.ndim != 1: raise ValueError( f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one' f' dimension. payload: {payload}' ) if payload.size == 0: return False if any(isinstance(f, any_pb2.Any) for f in payload): # Forests were packed as Any. return True return any( isinstance(f, android_accessibility_forest_pb2.AndroidAccessibilityForest) for f in payload ) def convert_to_forest( forest: android_accessibility_forest_pb2.AndroidAccessibilityForest | any_pb2.Any | None, ) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None: """Takes an object and attempts to convert it to a forest.""" if forest is None: return None if isinstance(forest, any_pb2.Any): output = android_accessibility_forest_pb2.AndroidAccessibilityForest() new_any = any_pb2.Any() new_any.CopyFrom(forest) new_any.Unpack(output) return output elif isinstance( forest, android_accessibility_forest_pb2.AndroidAccessibilityForest ): return forest else: return None def extract_forests_from_task_extras( task_extras: Mapping[str, Any] | None = None, ) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]: """Inspects task_extras and extracts all accessibility forests detected. 
Args: task_extras: Task extras forwarded by AndroidEnv. If 'full_event' is not a key in task_extras, then this function returns an empty string. Otherwise, full_event is expected to be list to be a numpy array with one dimension, and contains a list of dictionary describing accessibility forests that are present in the given task extras. Returns: List of all forests detected """ if task_extras is None or not task_extras_has_forests(task_extras): return [] forests = [] for f in task_extras[_A11Y_FORESTS_KEY]: f = convert_to_forest(f) if f is not None: forests.append(f) return forests def keep_latest_forest_only(task_extras: dict[str, Any]): """Removes all a11y forests except the last one observed.""" if _A11Y_FORESTS_KEY not in task_extras.keys(): return payload = task_extras[_A11Y_FORESTS_KEY] if not isinstance(payload, np.ndarray) or payload.ndim != 1: raise ValueError( f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one' f' dimension. payload: {payload}' ) if payload.size == 0: return task_extras[_A11Y_FORESTS_KEY] = payload[-1:] ================================================ FILE: android_env/wrappers/a11y/a11y_forests_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for a11y_forests.""" from absl.testing import absltest from absl.testing import parameterized from android_env.proto.a11y import android_accessibility_forest_pb2 from android_env.wrappers.a11y import a11y_forests import numpy as np from google.protobuf import any_pb2 def _pack_any(proto_message) -> any_pb2.Any: response = any_pb2.Any() response.Pack(proto_message) return response def _empty_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): return android_accessibility_forest_pb2.AndroidAccessibilityForest() def _one_empty_window_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() forest.windows.add() return forest def _two_window_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() window = forest.windows.add() window.tree.nodes.add( class_name='foo', is_clickable=True, hint_text='Foo hint' ) forest.windows.add() return forest class A11YForestsTest(parameterized.TestCase): @parameterized.parameters( dict(task_extras={}, expected_forests=[], convert_to_np=[]), dict( task_extras={'accessibility_tree': []}, convert_to_np=['accessibility_tree'], expected_forests=[], ), dict( task_extras={ 'not_accessibility_tree': [ _empty_forest(), _one_empty_window_forest(), _two_window_forest(), ], }, convert_to_np=['not_accessibility_tree'], expected_forests=[], ), dict( task_extras={ 'accessibility_tree': [ _empty_forest(), {'not_a_forest_key': 'nor_a_forest_value'}, _two_window_forest(), ] }, convert_to_np=['accessibility_tree'], expected_forests=[_empty_forest(), _two_window_forest()], ), dict( task_extras={ 'accessibility_tree': [ {'not_a_forest_key': 'nor_a_forest_value'}, 3, 4, {'not_a_forest_key': _empty_forest()}, ], }, convert_to_np=['accessibility_tree'], expected_forests=[], ), dict( task_extras={'accessibility_tree': []}, 
convert_to_np=['accessibility_tree'], expected_forests=[], ), dict( task_extras={ 'accessibility_tree_wrong_key': [1, 2, 3], 'accessibility_tree': [ _empty_forest(), None, None, _one_empty_window_forest(), _two_window_forest(), ], }, convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'], expected_forests=[ _empty_forest(), _one_empty_window_forest(), _two_window_forest(), ], ), dict( task_extras={ 'accessibility_tree_wrong_key': [1, 2, 3], 'accessibility_tree': [ None, _pack_any(_empty_forest()), _pack_any(_one_empty_window_forest()), _pack_any(_two_window_forest()), ], }, convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'], expected_forests=[ _empty_forest(), _one_empty_window_forest(), _two_window_forest(), ], ), dict( task_extras={ 'accessibility_tree': [ _pack_any(_empty_forest()), {'not_a_forest_key': 'nor_a_forest_value'}, None, _two_window_forest(), None, ] }, convert_to_np=['accessibility_tree'], expected_forests=[_empty_forest(), _two_window_forest()], ), ) def test_task_extras(self, task_extras, expected_forests, convert_to_np): for k in convert_to_np: if task_extras[k]: task_extras[k] = np.stack(task_extras[k], axis=0) else: task_extras[k] = np.array([]) forests = a11y_forests.extract_forests_from_task_extras(task_extras) self.assertEqual(len(forests), len(expected_forests)) for idx, f in enumerate(forests): self.assertEqual(f, expected_forests[idx]) @parameterized.parameters( dict(task_extras={}, expected_extras={}), dict( task_extras={ 'no_accessibility_tree': 42, }, expected_extras={ 'no_accessibility_tree': 42, }, ), dict( task_extras={'accessibility_tree': []}, expected_extras={'accessibility_tree': []}, ), dict( task_extras={ 'accessibility_tree': [ _empty_forest(), _one_empty_window_forest(), ], 'no_accessibility_tree': 43, }, expected_extras={ 'accessibility_tree': [_one_empty_window_forest()], 'no_accessibility_tree': 43, }, ), dict( task_extras={ 'accessibility_tree': [ _empty_forest(), 
_one_empty_window_forest(), _two_window_forest(), ] }, expected_extras={'accessibility_tree': [_two_window_forest()]}, ), dict( task_extras={ 'accessibility_tree': [], 'no_accessibility_tree': 44, }, expected_extras={ 'accessibility_tree': [], 'no_accessibility_tree': 44, }, ), ) def test_keep_latest_only(self, task_extras, expected_extras): if 'accessibility_tree' in task_extras: if task_extras['accessibility_tree']: task_extras['accessibility_tree'] = np.stack( task_extras['accessibility_tree'], axis=0 ) else: task_extras['accessibility_tree'] = np.array([]) a11y_forests.keep_latest_forest_only(task_extras) self.assertSameElements(task_extras.keys(), expected_extras.keys()) for k in task_extras.keys(): if k == 'accessibility_tree': self.assertEqual(len(task_extras[k]), len(expected_extras[k])) for idx, f in enumerate(task_extras[k]): self.assertEqual(f, expected_extras[k][idx]) else: self.assertEqual(task_extras[k], expected_extras[k]) pass if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/a11y/a11y_servicer.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Accessibility Servicer implementation.""" import asyncio from collections.abc import AsyncIterator, Generator, Iterable import threading from absl import logging from android_env.proto.a11y import a11y_pb2 from android_env.proto.a11y import a11y_pb2_grpc from android_env.proto.a11y import android_accessibility_forest_pb2 import grpc class A11yServicer(a11y_pb2_grpc.A11yServiceServicer): """Services the A11yService requests.""" def __init__(self, latest_forest_only: bool = False): self._received_forests: list[ android_accessibility_forest_pb2.AndroidAccessibilityForest ] = [] self._received_events: list[a11y_pb2.EventRequest] = [] self._lock_forests = threading.Lock() self._lock_events = threading.Lock() self._latest_forest_only = latest_forest_only self._paused = True # A11y Forest bookkeeping. self._get_forest = asyncio.Event() # Whether to request a forest. self._forest_ready = asyncio.Event() # Whether the forest is ready. self._latest_forest: ( android_accessibility_forest_pb2.AndroidAccessibilityForest | None ) = None def SendForest( self, request: android_accessibility_forest_pb2.AndroidAccessibilityForest, context: grpc.ServicerContext, ) -> a11y_pb2.ForestResponse: self._process_forest(request) return a11y_pb2.ForestResponse() def SendEvent( self, request: a11y_pb2.EventRequest, context: grpc.ServicerContext, ) -> a11y_pb2.EventResponse: self._process_event(request) return a11y_pb2.EventResponse() async def Bidi( self, request_iterator: AsyncIterator[a11y_pb2.ClientToServer], context: grpc.aio.ServicerContext, ) -> AsyncIterator[a11y_pb2.ServerToClient]: """Processes incoming ClientToServer requests.""" logging.info('Starting A11yServicer.Bidi()') # Send a dummy message to unblock clients in their loop. 
yield a11y_pb2.ServerToClient() # This block defines two coroutines: # # * `read_client_requests()` # * `check_forest()` # # They cooperate with each other and both populate a queue `q` which is # consumed in a loop below, which actually yields requests which are sent to # the client. The processing finishes when the clients "closes" the # connection, which causes `read_client_requests()` to put a special value, # `STOP_ITERATION`, in the queue. # Queue for communicating from coroutines to `Bidi()`. q = asyncio.Queue() should_run = True async def read_client_requests(): """Coroutine for reading client requests.""" nonlocal should_run async for request in request_iterator: field_name = request.WhichOneof('payload') match field_name: case 'event': self._process_event(request.event) case 'forest': self._latest_forest = request.forest self._forest_ready.set() self._get_forest.clear() # Reset the `Event`. case _: logging.error('Unknown field %r', field_name) await q.put(a11y_pb2.ServerToClient()) # Send a special value to stop processing this `Bidi` connection. await q.put('STOP_ITERATION') should_run = False async def check_forest(): """Coroutine for sending "get forest" requests.""" nonlocal should_run while should_run: await self._get_forest.wait() await q.put(a11y_pb2.ServerToClient(get_forest={})) tasks = asyncio.gather(read_client_requests(), check_forest()) while should_run: v = await q.get() if v == 'STOP_ITERATION': break else: yield v await tasks logging.info('Finishing A11yServicer.Bidi()') async def get_forest( self, ) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None: """Issues a request to get the a11y forest from the client.""" self._get_forest.set() # Unblocks coroutine to send a request. await self._forest_ready.wait() # Wait for forest to be ready. self._forest_ready.clear() # Reset the `Event`. 
return self._latest_forest def gather_forests( self, ) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]: forests = [] with self._lock_forests: forests = self._received_forests self._received_forests = [] return forests def gather_events(self) -> list[a11y_pb2.EventRequest]: events = [] with self._lock_events: events = self._received_events self._received_events = [] return events def pause_and_clear(self) -> None: """Temporarily stop receiving events/forests and clear the queue. Used when resetting the environment; in this case: - all events/forests that have been received since last timestep are things that happened in the last episode after its `LAST` timestep (so we should ignore them, done by clearing the lists). - we're about to receive a bunch of events/forests just as a result of resetting the environment. We don't want to count these either; thus we temporarily stop receiving new ones. """ self._paused = True with self._lock_forests: self._received_forests = [] with self._lock_events: self._received_events = [] def resume(self) -> None: """Start receiving events/forests (e.g., after a reset).""" self._paused = False def _process_event(self, event: a11y_pb2.EventRequest) -> None: """Adds the given event to the internal buffer of events.""" if not self._paused: with self._lock_events: self._received_events.append(event) def _process_forest( self, forest: android_accessibility_forest_pb2.AndroidAccessibilityForest ) -> None: """Adds the given forest to the internal buffer of forests.""" if not self._paused: with self._lock_forests: if self._latest_forest_only: self._received_forests = [forest] else: self._received_forests.append(forest) ================================================ FILE: android_env/wrappers/a11y/a11y_servicer_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for a11y_servicer.""" import asyncio from collections.abc import AsyncIterator, Iterable from typing import TypeVar from unittest import IsolatedAsyncioTestCase, mock from absl.testing import absltest from absl.testing import parameterized from android_env.proto.a11y import a11y_pb2 from android_env.proto.a11y import android_accessibility_forest_pb2 from android_env.wrappers.a11y import a11y_servicer import grpc _T = TypeVar('_T') async def _aiter(xs: Iterable[_T]) -> AsyncIterator[_T]: """Utility to make an AsyncIterator from Iterable.""" for x in xs: yield x def one_window_one_node_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() window = forest.windows.add() node = window.tree.nodes.add() node.class_name = 'foo' node.is_clickable = True node.hint_text = 'Foo hint' return forest def one_window_two_nodes_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() window = forest.windows.add() node = window.tree.nodes.add() node.class_name = 'bar' node.is_clickable = True node.hint_text = 'Bar hint' node = window.tree.nodes.add() node.class_name = 'bar' node.is_clickable = False node.hint_text = 'Bar hint 2' return forest def empty_dict() -> dict[str, str]: return {} def single_item_dict_with_special_chars() -> dict[str, str]: return {'foo': 
'bar\r\t\nbaz'} class A11yServicerTest(parameterized.TestCase, IsolatedAsyncioTestCase): def test_servicer_sendforest(self): mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) servicer = a11y_servicer.A11yServicer() servicer.resume() response = servicer.SendForest(one_window_one_node_forest(), mock_context) self.assertEqual(response.error, '') response = servicer.SendForest(one_window_two_nodes_forest(), mock_context) self.assertEqual(response.error, '') forests = servicer.gather_forests() self.assertLen(forests, 2) self.assertEqual(forests[0], one_window_one_node_forest()) self.assertEqual(forests[1], one_window_two_nodes_forest()) async def test_servicer_bidi_forests(self): """Checks that the bidirectional interface accepts forests.""" # Arrange. mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) servicer = a11y_servicer.A11yServicer() # Act. servicer.resume() responses = [ x async for x in servicer.Bidi( _aiter([ a11y_pb2.ClientToServer( event=a11y_pb2.EventRequest( event=single_item_dict_with_special_chars() ) ), a11y_pb2.ClientToServer(forest=one_window_two_nodes_forest()), ]), mock_context, ) ] forest = await servicer.get_forest() # Assert. 
self.assertEqual(responses[0], a11y_pb2.ServerToClient()) self.assertEqual(responses[1], a11y_pb2.ServerToClient()) self.assertIsNotNone(forest) self.assertEqual(forest, one_window_two_nodes_forest()) def test_servicer_sendforest_latest_only(self): mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) servicer = a11y_servicer.A11yServicer(latest_forest_only=True) servicer.resume() response = servicer.SendForest(one_window_one_node_forest(), mock_context) self.assertEqual(response.error, '') response = servicer.SendForest(one_window_two_nodes_forest(), mock_context) self.assertEqual(response.error, '') forests = servicer.gather_forests() self.assertLen(forests, 1) self.assertEqual(forests[0], one_window_two_nodes_forest()) def test_servicer_sendevent(self): mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) servicer = a11y_servicer.A11yServicer() servicer.resume() response = servicer.SendEvent( a11y_pb2.EventRequest(event=empty_dict()), mock_context ) self.assertEqual(response.error, '') response = servicer.SendEvent( a11y_pb2.EventRequest(event=single_item_dict_with_special_chars()), mock_context, ) self.assertEqual(response.error, '') events = servicer.gather_events() self.assertLen(events, 2) self.assertEqual(events[0].event, empty_dict()) self.assertEqual(events[1].event, single_item_dict_with_special_chars()) async def test_servicer_bidi_events(self): """Checks that the bidirectional interface accepts events.""" # Arrange. mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) servicer = a11y_servicer.A11yServicer() # Act. servicer.resume() responses = [ x async for x in servicer.Bidi( _aiter([ a11y_pb2.ClientToServer( event=a11y_pb2.EventRequest(event=empty_dict()) ), a11y_pb2.ClientToServer( event=a11y_pb2.EventRequest( event=single_item_dict_with_special_chars() ) ), ]), mock_context, ) ] events = servicer.gather_events() # Assert. 
self.assertEqual(responses[0], a11y_pb2.ServerToClient()) self.assertEqual(responses[1], a11y_pb2.ServerToClient()) self.assertLen(events, 2) self.assertEqual(events[0].event, empty_dict()) self.assertEqual(events[1].event, single_item_dict_with_special_chars()) def test_servicer_pause_and_clear_pauses(self): mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) servicer = a11y_servicer.A11yServicer() servicer.resume() servicer.pause_and_clear() response = servicer.SendEvent( a11y_pb2.EventRequest(event=empty_dict()), mock_context ) self.assertEqual(response.error, '') response = servicer.SendForest(one_window_one_node_forest(), mock_context) self.assertEqual(response.error, '') events = servicer.gather_events() self.assertEmpty(events) forests = servicer.gather_forests() self.assertEmpty(forests) def test_servicer_pause_and_clear_clears(self): mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) servicer = a11y_servicer.A11yServicer() servicer.resume() response = servicer.SendEvent( a11y_pb2.EventRequest(event=empty_dict()), mock_context ) self.assertEqual(response.error, '') response = servicer.SendForest(one_window_one_node_forest(), mock_context) self.assertEqual( response.error, '', ) servicer.pause_and_clear() events = servicer.gather_events() self.assertEmpty(events) forests = servicer.gather_forests() self.assertEmpty(forests) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/a11y_grpc_wrapper.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Wraps AndroidEnv to retrieve accessibility messages from gRPC.""" from concurrent import futures import time from typing import Any import urllib.request from absl import logging from android_env import env_interface from android_env.components import action_type as android_action_type_lib from android_env.proto import adb_pb2 from android_env.proto.a11y import a11y_pb2_grpc from android_env.wrappers import base_wrapper from android_env.wrappers.a11y import a11y_events from android_env.wrappers.a11y import a11y_forests from android_env.wrappers.a11y import a11y_servicer import dm_env import grpc import numpy as np import portpicker def _get_accessibility_forwarder_apk() -> bytes: logging.info('Downloading accessibility forwarder apk....') with urllib.request.urlopen( 'https://storage.googleapis.com/android_env-tasks/2024.05.13-accessibility_forwarder.apk' ) as response: return response.read() class EnableNetworkingError(ValueError): pass class A11yGrpcWrapper(base_wrapper.BaseWrapper): """Wrapper which receives A11y events and forests over gRPC. A11y forest protobufs and event dicts are sent from the Android emulator via gRPC from the `AccessibilityForwarder` (for use in developing reward functions, etc). This wrapper constructs a server which receives these messages and channels them into `task_extras`. The downside of forwarding this information through gRPC is that no messages will be sent if networking is turned off (e.g., if the AVD is in airplane mode). 
To mitigate this problem, the `AccessibilityForwarder` logs an error message if it fails to contact the server. This wrapper monitors the logs for such error messages, and attempts (in another thread, to not block environment transitions) to reconnect the AVD to the network. If this fails to fix the problem, this wrapper ends the episode. This wrapper is implemented to be robust to multiple upstream callers of `task_extras`, and to ensure they each receive the same extras at every timestep. Thus, the logic is the following: * New a11y events/forests are fetched during `reset` and `step`, *not* during `task_extras()` calls. * If no one has called `task_extras()` since the last `step` or `reset`, the extras are accumulated (so that no extras are missed because someone called `step()` twice without calling `task_extras()`). * If someone *has* called `task_extras()` since last step, the newly fetched extras replace the old extras. """ def __init__( self, env: env_interface.AndroidEnvInterface, disable_other_network_traffic: bool = False, install_a11y_forwarding: bool = False, start_a11y_service: bool = True, enable_a11y_tree_info: bool = False, add_latest_a11y_info_to_obs: bool = False, a11y_info_timeout: float | None = None, max_enable_networking_attempts: int = 10, latest_a11y_info_only: bool = False, grpc_server_ip: str = '10.0.2.2', ): """Initializes wrapper. Args: env: Environment to wrap. disable_other_network_traffic: When True, all network traffic, other than the connection to the servicer, is disabled. NOTE: This requires root access on the device (i.e. it uses the `su` command). An `AdbControllerError` exception will be raised if the underlying command fails. install_a11y_forwarding: If True, the wrapper handles the installation of all packages required for the servicer to collect a11y information. start_a11y_service: If True, starts the a11y forwarding services. NOTE: The packages must be installed beforehand, e.g., using the install_a11y_forwarding flag. 
enable_a11y_tree_info: When False, this wrapper collects only a11y events and not a11y tree. add_latest_a11y_info_to_obs: When True, the latest observed a11y forest is added to the observation. a11y_info_timeout: When larger than zero and add_latest_a11y_info_to_obs is set to True, the wrapper will wait the corresponding amount of time, measured in seconds, to collect the latest a11y forest. max_enable_networking_attempts: When the a11y gRPC service fails to provide a11y information, we attempt this many times to re-enable the networking. If all these attempts fail, fetching task_extras will raise an EnableNetworkingError. latest_a11y_info_only: When True, the a11y servicer is setup to save only the latest tree it has received from the Android app. grpc_server_ip: The IP address of the gRPC server which will be broadcasted to the AccessibilityForwarder app where it should log the a11y info. By default, this is set to the IP address of the AVD's host machine which is 10.0.2.2: See https://developer.android.com/studio/run/emulator-networking#networkaddresses. 
""" self._env = env self._grpc_server_ip = grpc_server_ip if install_a11y_forwarding: self._install_a11y_forwarding_apk() time.sleep(10.0) if start_a11y_service: self._start_a11y_services() time.sleep(3.0) if enable_a11y_tree_info: self._enable_a11y_tree_logs() self._relaunch_count = 0 self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) self._servicer = a11y_servicer.A11yServicer( latest_forest_only=latest_a11y_info_only ) a11y_pb2_grpc.add_A11yServiceServicer_to_server( self._servicer, self._server ) server_credentials = grpc.local_server_credentials() self._port = portpicker.pick_unused_port() logging.info('Using port %s', self._port) uri_address = f'[::]:{self._port}' self._server.add_secure_port(uri_address, server_credentials) logging.info('Starting server') self._server.start() logging.info('Server now running.') self._max_enable_networking_attempts = max_enable_networking_attempts self._reset_enable_networking_attempts() self._disable_other_network_traffic = disable_other_network_traffic self._should_accumulate = False self._accumulated_extras = None self._add_latest_a11y_info_to_obs = add_latest_a11y_info_to_obs self._a11y_info_timeout = a11y_info_timeout self._parent_action_spec = self._env.action_spec() if self._a11y_info_timeout is not None and self._a11y_info_timeout > 0.0: if 'action_type' not in self._parent_action_spec.keys(): raise ValueError( 'action_type not in the parent action spec: ' f'{self._parent_action_spec}. This is a strong requirement when ' f'a11y_info_timeout = {a11y_info_timeout} > 0' ) def _start_a11y_services(self) -> None: """Starts the accessibility forwarder services. Raises: RuntimeError: If accessibility service is not started. 
""" start_service_request = adb_pb2.AdbRequest( settings=adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE, put=adb_pb2.AdbRequest.SettingsRequest.Put( key='enabled_accessibility_services', value=( 'com.google.androidenv.accessibilityforwarder/com.google.' 'androidenv.accessibilityforwarder.AccessibilityForwarder' ), ), ) ) start_service_response = self._env.execute_adb_call(start_service_request) if start_service_response.status != adb_pb2.AdbResponse.Status.OK: raise RuntimeError( 'Could not start accessibility forwarder ' 'service: ' f'{start_service_response}.' ) def _install_a11y_forwarding_apk(self) -> None: """Enables accessibility information forwarding.""" a11y_fwd_apk = _get_accessibility_forwarder_apk() # Install and setup the Accesssibility Forwarder. install_request = adb_pb2.AdbRequest( install_apk=adb_pb2.AdbRequest.InstallApk( blob=adb_pb2.AdbRequest.InstallApk.Blob(contents=a11y_fwd_apk), ) ) install_response = self._env.execute_adb_call(install_request) if install_response.status != adb_pb2.AdbResponse.Status.OK: raise ValueError( f'Could not install accessibility_forwarder.apk: {install_response}.' ) def _enable_a11y_tree_logs(self) -> None: enable_tree_logs_request = adb_pb2.AdbRequest( send_broadcast=adb_pb2.AdbRequest.SendBroadcast( action=( 'accessibility_forwarder.intent.action.' 'ENABLE_ACCESSIBILITY_TREE_LOGS' ), component=( 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver' ), ) ) enable_tree_logs_response = self._env.execute_adb_call( enable_tree_logs_request ) if enable_tree_logs_response.status != adb_pb2.AdbResponse.Status.OK: raise ValueError( 'Could not enable accessibility tree logging: ' f'{enable_tree_logs_response}.' 
) def _reset_enable_networking_attempts(self) -> None: self._enable_networking_attempts_left = self._max_enable_networking_attempts self._enabling_networking_future = None self._a11y_exception = None def get_port(self): return self._port def close(self): self._server.stop(None) logging.info('gRPC server stopped') self._env.close() def attempt_enable_networking(self) -> None: """Attempts to turn on networking within the Android device. Attempt to turn on the networking in the Android device, by: - turning off airplane mode; - turning on the wifi connection. """ self.execute_adb_call( adb_pb2.AdbRequest( settings=adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, put=adb_pb2.AdbRequest.SettingsRequest.Put( key='airplane_mode_on', value='0' ), ) ) ) time.sleep(1.0) self.execute_adb_call( adb_pb2.AdbRequest( generic=adb_pb2.AdbRequest.GenericRequest( args=[ 'shell', 'svc', 'wifi', 'enable', ] ) ) ) time.sleep(1.0) def _configure_grpc(self) -> None: """Configure networking and set the gRPC ip and port on AVD or device.""" if self._disable_other_network_traffic: self.execute_adb_call( adb_pb2.AdbRequest( generic=adb_pb2.AdbRequest.GenericRequest( args=[ 'shell', 'su', '0', 'iptables', '-A', 'OUTPUT', '-p', 'tcp', '-d', self._grpc_server_ip, '--dport', str(self._port), '-j', 'ACCEPT', ] ) ) ) time.sleep(3.0) self.execute_adb_call( adb_pb2.AdbRequest( generic=adb_pb2.AdbRequest.GenericRequest( args=[ 'shell', 'su', '0', 'iptables', '-A', 'OUTPUT', '-j', 'DROP', ] ) ) ) time.sleep(3.0) self.execute_adb_call( adb_pb2.AdbRequest( settings=adb_pb2.AdbRequest.SettingsRequest( name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL, put=adb_pb2.AdbRequest.SettingsRequest.Put( key='no_proxy', value=f'{self._grpc_server_ip}:{self._port}' ), ) ) ) self.attempt_enable_networking() self.execute_adb_call( adb_pb2.AdbRequest( send_broadcast=adb_pb2.AdbRequest.SendBroadcast( action=( 'accessibility_forwarder.intent.action.SET_GRPC 
--ei' f' "port" {self._port} --es "host" {self._grpc_server_ip}' ), component=( 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver' ), ) ) ) def _accumulate_and_return_a11y_info( self, timer: float | None = None, get_env_observation: bool = True ) -> dict[str, Any]: """Accumulates and returns the latest a11y tree info and observation. Args: timer: If larger than 0, the system will wait this long for a11y info to accumulate before it returns a value. get_env_observation: If False, the corresponding observation is not introduced here. Returns: a dict with a11y forest under key 'a11y_forest'. All other fields will provide the observation, if requested. """ timer = timer or 0.0 if timer > 0.0: time.sleep(timer) if get_env_observation: # Fetch observation. new_ts = self._env.step({ 'action_type': np.array( android_action_type_lib.ActionType.REPEAT, dtype=self._parent_action_spec['action_type'].dtype, ), }) observation = new_ts.observation else: observation = {} extras = self.accumulate_new_extras() forests = a11y_forests.extract_forests_from_task_extras(extras) if forests: observation['a11y_forest'] = forests[-1] else: observation['a11y_forest'] = None return observation def _fetch_task_extras_and_update_observation( self, observation: dict[str, Any], timeout: float = 0.0 ) -> dict[str, Any]: if timeout > 0.0: observation = self._accumulate_and_return_a11y_info( timeout, get_env_observation=True ) if not self._add_latest_a11y_info_to_obs: observation.pop('a11y_forest') else: new_obs = self._accumulate_and_return_a11y_info(get_env_observation=False) if self._add_latest_a11y_info_to_obs: observation.update(new_obs) return observation def reset(self) -> dm_env.TimeStep: self._reset_enable_networking_attempts() self._servicer.pause_and_clear() timestep = self._env.reset() self._servicer.resume() if self._env.stats()['relaunch_count'] > self._relaunch_count: self._configure_grpc() self._relaunch_count = 
self._env.stats()['relaunch_count'] self._accumulated_extras = {} timeout = self._a11y_info_timeout or 0.0 new_observation = self._fetch_task_extras_and_update_observation( timestep.observation, timeout ) timestep = timestep._replace(observation=new_observation) return timestep def step(self, action: Any) -> dm_env.TimeStep: timeout = float(action.pop('wait_time', self._a11y_info_timeout or 0.0)) timestep = self._env.step(action) new_observation = self._fetch_task_extras_and_update_observation( timestep.observation, timeout=timeout ) timestep = timestep._replace(observation=new_observation) return timestep def accumulate_new_extras(self) -> dict[str, Any]: new_extras = self._fetch_task_extras() if self._should_accumulate: for key in new_extras: if key in self._accumulated_extras: self._accumulated_extras[key] = np.concatenate( (self._accumulated_extras[key], new_extras[key]), axis=0 ) else: self._accumulated_extras[key] = new_extras[key] else: self._accumulated_extras = new_extras self._should_accumulate = True return self._accumulated_extras def _fetch_task_extras(self) -> dict[str, Any]: """Fetches task_extras from the services. NOTE: If you want to access the latest a11y information, please use accumulate_and_return_a11y_info instead. This function has the side effect of clearing the content from the servicer, hence all the a11y info returned here won't be accumulated. Returns: A dict with the corresponding task_extras. Raises: EnableNetworkingError: after a fixed number of attempts to revive the a11y services by re-enabling the network connection. """ base_extras = self._env.task_extras(latest_only=False).copy() # If the previous future is done, reset it to the initial state. 
if ( self._enabling_networking_future is not None and self._enabling_networking_future.done() ): self._enabling_networking_future = None self._enable_networking_attempts_left -= 1 logging.info('Finished enabling networking.') if ( self._enabling_networking_future is None and 'exception' in base_extras and base_extras['exception'].shape[0] ): self._a11y_exception = base_extras['exception'] logging.warning( 'AccessibilityForwarder logged exceptions: %s', self._a11y_exception ) if self._enable_networking_attempts_left > 0: logging.warning( 'Attempting to enable networking. %s attempts left.', self._enable_networking_attempts_left - 1, ) executor = futures.ThreadPoolExecutor(max_workers=1) self._enabling_networking_future = executor.submit( self.attempt_enable_networking ) else: raise EnableNetworkingError( 'A11y service failed multiple times with' f' exception.{self._a11y_exception}.' ) forests = self._servicer.gather_forests() if forests: base_extras.update(a11y_forests.package_forests_to_task_extras(forests)) self._reset_enable_networking_attempts() events = self._servicer.gather_events() if events: base_extras.update(a11y_events.package_events_to_task_extras(events)) self._reset_enable_networking_attempts() return base_extras def task_extras(self, latest_only: bool = False) -> dict[str, Any]: if self._accumulated_extras is None: raise RuntimeError('You must call .reset() before calling .task_extras()') self._should_accumulate = False extras = self._accumulated_extras.copy() if latest_only: a11y_events.keep_latest_event_only(extras) a11y_forests.keep_latest_forest_only(extras) return extras ================================================ FILE: android_env/wrappers/a11y_grpc_wrapper_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for a11y_grpc_wrapper.""" import time from unittest import mock from absl.testing import absltest from absl.testing import parameterized from android_env import env_interface from android_env.proto import adb_pb2 from android_env.proto.a11y import a11y_pb2 from android_env.proto.a11y import a11y_pb2_grpc from android_env.proto.a11y import android_accessibility_forest_pb2 from android_env.wrappers import a11y_grpc_wrapper import dm_env import grpc import numpy as np def empty_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): return android_accessibility_forest_pb2.AndroidAccessibilityForest() def one_empty_window_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() _ = forest.windows.add() return forest def one_window_one_node_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() window = forest.windows.add() node = window.tree.nodes.add() node.class_name = 'foo' node.is_clickable = True node.hint_text = 'Foo hint' return forest def one_window_two_nodes_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() window = forest.windows.add() node = window.tree.nodes.add() node.class_name = 'bar' node.is_clickable = True node.hint_text = 'Bar hint' node = window.tree.nodes.add() node.class_name = 'bar' node.is_clickable = False node.hint_text = 'Bar hint 2' 
return forest def three_windows_forest() -> ( android_accessibility_forest_pb2.AndroidAccessibilityForest ): forest = android_accessibility_forest_pb2.AndroidAccessibilityForest() _ = forest.windows.add() window = forest.windows.add() node = window.tree.nodes.add() node.class_name = 'foo' node.is_clickable = True node.hint_text = 'hint' window = forest.windows.add() node = window.tree.nodes.add() node.class_name = 'baz' node.is_clickable = True node.hint_text = 'hint' node = window.tree.nodes.add() node.class_name = 'foobar' node.is_clickable = False node.hint_text = 'hint' return forest def empty_dict() -> dict[str, str]: return {} def single_item_dict() -> dict[str, str]: return {'foo': 'bar'} def several_long_items_dict() -> dict[str, str]: return { 'first_key': 'Lorem ipsum ' * 100, 'second_key': 'the beginning is the end is' * 100, } def single_item_dict_with_special_chars() -> dict[str, str]: return {'foo': 'bar\r\t\nbaz'} def _ok_response(): return adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK) class A11yGrpcWrapperTest(parameterized.TestCase): def test_server(self): base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.task_extras.return_value = {} base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) wrapped_env.reset() channel_creds = grpc.local_channel_credentials() with grpc.secure_channel( f'[::]:{wrapped_env.get_port()}', channel_creds ) as channel: grpc.channel_ready_future(channel).result() stub = a11y_pb2_grpc.A11yServiceStub(channel) stub.SendForest(one_window_one_node_forest()) stub.SendForest(one_window_two_nodes_forest()) wrapped_env.step({}) extras = wrapped_env.task_extras(latest_only=False) self.assertIn('accessibility_tree', extras) self.assertEqual(extras['accessibility_tree'].shape[0], 2) # tests of fetch_task_extras: # exception occurs (ensure attempt to enable networking) and 
recovers # exception occurs and enable networking doesn't help # exception occurs twice but with a forest sent between @parameterized.named_parameters( ('no_events_or_forests', [], []), ( 'no_events', [], [one_window_one_node_forest(), one_window_two_nodes_forest()], ), ('no_forests', [empty_dict(), single_item_dict()], []), ( 'events_and_forests', [empty_dict(), single_item_dict()], [one_window_one_node_forest(), one_window_two_nodes_forest()], ), ) @mock.patch.object(time, 'sleep', autospec=True) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) def test_fetch_task_extras( self, received_events, received_forests, mock_server, mock_add_servicer, mock_sleep, ): del mock_server, mock_add_servicer, mock_sleep mock_context = mock.create_autospec(grpc.ServicerContext, instance=True) base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.task_extras.return_value = { 'foo': np.array(['bar', 'baz'], dtype='U'), 'some_key': np.array(['some_value'], dtype='U'), } base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) wrapped_env.reset() for forest in received_forests: wrapped_env._servicer.SendForest(forest, mock_context) for event in received_events: wrapped_env._servicer.SendEvent( a11y_pb2.EventRequest(event=event), mock_context ) with mock.patch.object( wrapped_env, 'attempt_enable_networking' ) as mock_attempt_enable_networking: extras = wrapped_env._fetch_task_extras() mock_attempt_enable_networking.assert_not_called() self.assertIn('foo', extras) np.testing.assert_array_equal(extras['foo'], ['bar', 'baz']) self.assertIn('some_key', extras) np.testing.assert_array_equal(extras['some_key'], ['some_value']) if received_events: self.assertIn('full_event', extras) self.assertLen(extras['full_event'], len(received_events)) for i, event in 
enumerate(received_events): event = a11y_pb2.EventRequest(event=event) np.testing.assert_array_equal(extras['full_event'][i], event) else: self.assertNotIn('full_event', extras) if received_forests: self.assertIn('accessibility_tree', extras) self.assertLen(extras['accessibility_tree'], len(received_forests)) for i, forest in enumerate(received_forests): np.testing.assert_array_equal(extras['accessibility_tree'][i], forest) else: self.assertNotIn('accessibility_tree', extras) @mock.patch.object(time, 'sleep', autospec=True) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) def test_fetch_task_extras_enable_networking( self, mock_server, mock_add_servicer, mock_sleep, ): del mock_server, mock_add_servicer, mock_sleep base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.task_extras.return_value = { 'foo': np.array(['bar'], dtype='U'), 'some_key': np.array(['some_value'], dtype='U'), 'exception': np.array(['fake exception'], dtype='U'), } base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) with mock.patch.object( wrapped_env, 'attempt_enable_networking' ) as mock_attempt_enable_networking: extras = wrapped_env._fetch_task_extras() self.assertNotIn('accessibility_tree', extras) self.assertNotIn('full_event', extras) future = wrapped_env._enabling_networking_future if future is not None: future.result() mock_attempt_enable_networking.assert_called_once() @mock.patch.object(time, 'sleep', autospec=True) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) def test_fetch_task_extras_enable_networking_twice( self, mock_server, mock_add_servicer, mock_sleep, ): del mock_server, mock_add_servicer, mock_sleep mock_context = 
mock.create_autospec(grpc.ServicerContext, instance=True) base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.task_extras.return_value = { 'foo': np.array(['bar'], dtype='U'), 'some_key': np.array(['some_value'], dtype='U'), } base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) wrapped_env.reset() base_env.task_extras.return_value = { 'foo': np.array(['bar'], dtype='U'), 'some_key': np.array(['some_value'], dtype='U'), 'exception': np.array(['fake exception'], dtype='U'), } with mock.patch.object( wrapped_env, 'attempt_enable_networking' ) as mock_attempt_enable_networking: extras = wrapped_env._fetch_task_extras() self.assertNotIn('accessibility_tree', extras) self.assertNotIn('full_event', extras) future = wrapped_env._enabling_networking_future if future is not None: future.result() mock_attempt_enable_networking.assert_called_once() # Fixed networking; send a forest so the wrapper knows it worked. 
wrapped_env._servicer.SendForest(one_window_one_node_forest(), mock_context) base_env.task_extras.return_value = { 'foo': np.array(['bar'], dtype='U'), 'some_key': np.array(['some_value'], dtype='U'), } extras = wrapped_env._fetch_task_extras() self.assertIn('accessibility_tree', extras) self.assertEqual(extras['accessibility_tree'].shape[0], 1) self.assertEqual( extras['accessibility_tree'][0], one_window_one_node_forest() ) base_env.task_extras.return_value = { 'foo': np.array(['bar'], dtype='U'), 'some_key': np.array(['some_value'], dtype='U'), 'exception': np.array(['fake exception'], dtype='U'), } with mock.patch.object( wrapped_env, 'attempt_enable_networking' ) as mock_attempt_enable_networking: extras = wrapped_env._fetch_task_extras() self.assertNotIn('accessibility_tree', extras) self.assertNotIn('full_event', extras) future = wrapped_env._enabling_networking_future if future is not None: future.result() mock_attempt_enable_networking.assert_called_once() @mock.patch.object(time, 'sleep', autospec=True) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) def test_task_extras_raises_a11y_info_exception( self, mock_sleep, mock_add_servicer, mock_server ): del mock_server, mock_add_servicer, mock_sleep base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.task_extras.return_value = { 'foo': np.array(['bar'], dtype='U'), 'some_key': np.array(['some_value'], dtype='U'), } base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) base_env.step.return_value = dm_env.transition( observation={'dummy': 42}, reward=0.0 ) wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( base_env, add_latest_a11y_info_to_obs=True, max_enable_networking_attempts=1, ) wrapped_env.reset() base_env.task_extras.return_value = { 'foo': 
np.array(['bar'], dtype='U'),
        'some_key': np.array(['some_value'], dtype='U'),
        'exception': np.array(['fake exception'], dtype='U'),
    }
    with mock.patch.object(
        wrapped_env, 'attempt_enable_networking'
    ) as mock_attempt_enable_networking:
      extras = wrapped_env._fetch_task_extras()
      self.assertNotIn('accessibility_tree', extras)
      self.assertNotIn('full_event', extras)
      # Wait for the attempt to finish.
      future = wrapped_env._enabling_networking_future
      if future is not None:
        future.result()
      mock_attempt_enable_networking.assert_called_once()
      # The _fetch_task_extras() call inside the next step will force a restart
      self.assertRaises(
          a11y_grpc_wrapper.EnableNetworkingError, wrapped_env.step, {}
      )

  @mock.patch.object(
      a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
  )
  @mock.patch.object(grpc, 'server', autospec=True)
  def test_configure_grpc(
      self,
      mock_server,
      mock_add_servicer,
  ):
    """reset() configures gRPC when the env reports a new relaunch."""
    del mock_server, mock_add_servicer
    base_env = mock.create_autospec(
        env_interface.AndroidEnvInterface, instance=True
    )
    base_env.task_extras.return_value = {
        'foo': np.array(['bar'], dtype='U'),
        'some_key': np.array(['some_value'], dtype='U'),
    }
    # A relaunch_count above the wrapper's initial 0 triggers configuration.
    base_env.stats.return_value = {'relaunch_count': 1}
    base_env.execute_adb_call.return_value = _ok_response()
    wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
    with mock.patch.object(
        wrapped_env, '_configure_grpc'
    ) as mock_configure_grpc:
      wrapped_env.reset()
      mock_configure_grpc.assert_called_once()

  @mock.patch.object(
      a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
  )
  @mock.patch.object(grpc, 'server', autospec=True)
  def test_task_extras_raises_before_reset(
      self, unused_mock_server, unused_mock_add_servicer
  ):
    """task_extras() raises RuntimeError when called before reset()."""
    base_env = mock.create_autospec(
        env_interface.AndroidEnvInterface, instance=True
    )
    base_env.stats.return_value = {'relaunch_count': 0}
    base_env.execute_adb_call.return_value = _ok_response()
    wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
    with self.assertRaisesRegex(
        RuntimeError,
        r'You must call \.reset\(\) before calling \.task_extras\(\)',
    ):
      wrapped_env.task_extras(latest_only=False)

  @mock.patch.object(
      a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
  )
  @mock.patch.object(grpc, 'server', autospec=True)
  def test_extras_accumulate_between_steps(
      self, mock_server, mock_add_servicer
  ):
    """Extras from reset and every step accumulate until task_extras()."""
    del mock_server, mock_add_servicer
    base_env = mock.create_autospec(
        env_interface.AndroidEnvInterface, instance=True
    )
    base_env.stats.return_value = {'relaunch_count': 0}
    base_env.execute_adb_call.return_value = _ok_response()
    base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
    base_env.step.return_value = dm_env.transition(
        observation={'dummy': 42}, reward=0.0
    )
    wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
        base_env, add_latest_a11y_info_to_obs=True
    )
    with mock.patch.object(wrapped_env, '_fetch_task_extras'):
      wrapped_env._fetch_task_extras.return_value = {
          'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
          'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
      }
      timestep = wrapped_env.reset()
      self.assertIn('a11y_forest', timestep.observation)
      self.assertEqual(timestep.observation['a11y_forest'], empty_forest())
      wrapped_env._fetch_task_extras.return_value = {
          'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
          'accessibility_tree': np.array(
              one_window_two_nodes_forest(), ndmin=1, dtype=object
          ),
      }
      timestep = wrapped_env.step({})
      self.assertIn('a11y_forest', timestep.observation)
      self.assertEqual(
          timestep.observation['a11y_forest'], one_window_two_nodes_forest()
      )
      timestep = wrapped_env.step({})
      self.assertIn('a11y_forest', timestep.observation)
      self.assertEqual(
          timestep.observation['a11y_forest'], one_window_two_nodes_forest()
      )
      # This fetch carries no forest; the observation keeps reporting the
      # latest accumulated one.
      wrapped_env._fetch_task_extras.return_value = {
          'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
      }
      timestep = wrapped_env.step({})
      self.assertIn('a11y_forest', timestep.observation)
      self.assertEqual(
          timestep.observation['a11y_forest'],
one_window_two_nodes_forest() ) expected_task_extras = { 'full_event': np.array( [ single_item_dict(), empty_dict(), empty_dict(), single_item_dict(), ], dtype=object, ), 'accessibility_tree': np.array( [ empty_forest(), one_window_two_nodes_forest(), one_window_two_nodes_forest(), ], dtype=object, ), } expected_task_extras_latest = { 'full_event': np.array([single_item_dict()], dtype=object), 'accessibility_tree': np.array( [one_window_two_nodes_forest()], dtype=object ), } task_extras = wrapped_env.task_extras(latest_only=False) np.testing.assert_equal( task_extras['full_event'], expected_task_extras['full_event'] ) np.testing.assert_equal( task_extras['accessibility_tree'], expected_task_extras['accessibility_tree'], ) task_extras = wrapped_env.task_extras(latest_only=True) np.testing.assert_equal( task_extras['full_event'], expected_task_extras_latest['full_event'] ) np.testing.assert_equal( task_extras['accessibility_tree'], expected_task_extras_latest['accessibility_tree'], ) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) def test_a11y_info_disabled( self, unused_mock_server, unused_mock_add_servicer, ): base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.action_spec.return_value = { 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) } base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) base_env.step.return_value = dm_env.transition( observation={'dummy': 42}, reward=0.0 ) wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( base_env, add_latest_a11y_info_to_obs=False, a11y_info_timeout=1.0 ) with mock.patch.object(wrapped_env, '_fetch_task_extras'): wrapped_env._fetch_task_extras.return_value = { 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), } timestep = wrapped_env.reset() 
self.assertNotIn('a11y_forest', timestep.observation) timestep = wrapped_env.step({}) self.assertNotIn('a11y_forest', timestep.observation) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) def test_a11y_info_with_timer_info_present( self, unused_mock_server, unused_mock_add_servicer, ): base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.action_spec.return_value = { 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) } base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) base_env.step.return_value = dm_env.transition( observation={'dummy': 42}, reward=0.0 ) wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0 ) with mock.patch.object(wrapped_env, '_fetch_task_extras'): wrapped_env._fetch_task_extras.side_effect = [{ 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), }] timestep = wrapped_env.reset() self.assertIn('a11y_forest', timestep.observation) self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) @mock.patch.object(time, 'sleep', autospec=True) def test_a11y_info_with_timer_task_extra_returned( self, unused_mock_server, unused_mock_add_servicer, unused_mock_sleep ): base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.action_spec.return_value = { 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) } base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) base_env.step.return_value = 
dm_env.transition( observation={'dummy': 42}, reward=0.0 ) wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0 ) with mock.patch.object(wrapped_env, '_fetch_task_extras'): wrapped_env._fetch_task_extras.side_effect = [ { 'accessibility_tree': np.array( empty_forest(), ndmin=1, dtype=object ), }, ] timestep = wrapped_env.reset() self.assertIn('a11y_forest', timestep.observation) self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) @mock.patch.object(time, 'sleep', autospec=True) def test_a11y_info_with_timer_from_action( self, unused_mock_server, unused_mock_add_servicer, mock_sleep ): base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.action_spec.return_value = { 'action_type': dm_env.specs.Array(shape=(), dtype=np.int32) } base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() base_env.reset.return_value = dm_env.restart(observation={'dummy': 42}) base_env.step.return_value = dm_env.transition( observation={'dummy': 42}, reward=0.0 ) wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=0.0 ) with mock.patch.object(wrapped_env, '_fetch_task_extras'): wrapped_env._fetch_task_extras.side_effect = [ { 'accessibility_tree': np.array( empty_forest(), ndmin=1, dtype=object ), }, ] timestep = wrapped_env.step(action={'wait_time': 1.0}) self.assertIn('a11y_forest', timestep.observation) mock_sleep.assert_called_once() self.assertEqual(timestep.observation['a11y_forest'], empty_forest()) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True ) @mock.patch.object(grpc, 'server', autospec=True) def test_task_extras_same_between_calls(self, mock_server, mock_add_servicer): del 
mock_server, mock_add_servicer base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.stats.return_value = {'relaunch_count': 0} base_env.execute_adb_call.return_value = _ok_response() wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env) expected_task_extras = { 'full_event': np.array(single_item_dict(), ndmin=1, dtype=object), 'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object), } with mock.patch.object(wrapped_env, '_fetch_task_extras'): wrapped_env._fetch_task_extras.return_value = expected_task_extras wrapped_env.reset() task_extras = wrapped_env.task_extras(latest_only=False) np.testing.assert_equal( task_extras['full_event'], expected_task_extras['full_event'] ) np.testing.assert_equal( task_extras['accessibility_tree'], expected_task_extras['accessibility_tree'], ) task_extras = wrapped_env.task_extras(latest_only=False) np.testing.assert_equal( task_extras['full_event'], expected_task_extras['full_event'] ) np.testing.assert_equal( task_extras['accessibility_tree'], expected_task_extras['accessibility_tree'], ) expected_task_extras = { 'full_event': np.array(empty_dict(), ndmin=1, dtype=object), 'accessibility_tree': np.array( one_window_two_nodes_forest(), ndmin=1, dtype=object ), } with mock.patch.object(wrapped_env, '_fetch_task_extras'): wrapped_env._fetch_task_extras.return_value = expected_task_extras wrapped_env.step({}) task_extras = wrapped_env.task_extras(latest_only=False) np.testing.assert_equal( task_extras['full_event'], expected_task_extras['full_event'] ) np.testing.assert_equal( task_extras['accessibility_tree'], expected_task_extras['accessibility_tree'], ) task_extras = wrapped_env.task_extras(latest_only=False) np.testing.assert_equal( task_extras['full_event'], expected_task_extras['full_event'] ) np.testing.assert_equal( task_extras['accessibility_tree'], expected_task_extras['accessibility_tree'], ) @mock.patch.object( a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', 
autospec=True
  )
  @mock.patch.object(grpc, 'server', autospec=True)
  def test_task_extras_clear_if_called_between_step(
      self, mock_server, mock_add_servicer
  ):
    """After task_extras() is called, the next step starts from fresh extras."""
    del mock_server, mock_add_servicer
    base_env = mock.create_autospec(
        env_interface.AndroidEnvInterface, instance=True
    )
    base_env.stats.return_value = {'relaunch_count': 0}
    base_env.execute_adb_call.return_value = _ok_response()
    wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
    with mock.patch.object(wrapped_env, '_fetch_task_extras'):
      expected_task_extras = {
          'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
          'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
      }
      wrapped_env._fetch_task_extras.return_value = expected_task_extras
      wrapped_env.reset()
      task_extras = wrapped_env.task_extras(latest_only=False)
      np.testing.assert_equal(
          task_extras['full_event'], expected_task_extras['full_event']
      )
      np.testing.assert_equal(
          task_extras['accessibility_tree'],
          expected_task_extras['accessibility_tree'],
      )
      # Each step below is followed by task_extras(), so only that step's
      # extras are expected back.
      expected_task_extras = {
          'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
          'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
      }
      wrapped_env._fetch_task_extras.return_value = expected_task_extras
      wrapped_env.step({})
      task_extras = wrapped_env.task_extras(latest_only=False)
      np.testing.assert_equal(
          task_extras['full_event'], expected_task_extras['full_event']
      )
      np.testing.assert_equal(
          task_extras['accessibility_tree'],
          expected_task_extras['accessibility_tree'],
      )
      expected_task_extras = {
          'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
          'accessibility_tree': np.array(
              one_window_two_nodes_forest(), ndmin=1, dtype=object
          ),
      }
      wrapped_env._fetch_task_extras.return_value = expected_task_extras
      wrapped_env.step({})
      task_extras = wrapped_env.task_extras(latest_only=False)
      np.testing.assert_equal(
          task_extras['full_event'], expected_task_extras['full_event']
      )
      np.testing.assert_equal(
          task_extras['accessibility_tree'],
expected_task_extras['accessibility_tree'], ) @parameterized.named_parameters( ('none_true', False, False, False, 0), ('only_install', True, False, False, 1), ('only_start', False, True, False, 1), ('only_enable_a11y_tree', False, False, True, 1), ('install_and_start_no_a11y_tree', True, True, False, 2), ('install_and_a11y_tree', True, False, True, 2), ('start_and_a11y_tree', False, True, True, 2), ('all_true', True, True, True, 3), ) @mock.patch.object(time, 'sleep', autospec=True) def test_apk_install_and_start( self, install_a11y_forwarding: bool, start_a11y_service: bool, enable_a11y_tree_logs: bool, expected_adb_calls: int, unused_mock_sleep, ): base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) side_effects = [] if install_a11y_forwarding: side_effects.append(_ok_response()) # install response if start_a11y_service: side_effects.append(_ok_response()) # start service response if enable_a11y_tree_logs: side_effects.append(_ok_response()) # enable_tree_request base_env.execute_adb_call.side_effect = side_effects _ = a11y_grpc_wrapper.A11yGrpcWrapper( base_env, install_a11y_forwarding=install_a11y_forwarding, start_a11y_service=start_a11y_service, enable_a11y_tree_info=enable_a11y_tree_logs, ) self.assertEqual(base_env.execute_adb_call.call_count, expected_adb_calls) @mock.patch.object(time, 'sleep', autospec=True) def test_component_and_start(self, unused_mock_sleep): base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) side_effects = [] side_effects.append(_ok_response()) # install response side_effects.append(_ok_response()) # start service response side_effects.append(_ok_response()) # enable_tree_request base_env.execute_adb_call.side_effect = side_effects _ = a11y_grpc_wrapper.A11yGrpcWrapper( base_env, install_a11y_forwarding=True, start_a11y_service=True, enable_a11y_tree_info=True, ) # call_args returns a tuple of which the first member is a tuple containing # the most recent args the mock 
was called with, and execute_adb_call only # has one arg (so [0][0] to access the AdbRequest). self.assertEqual( base_env.execute_adb_call.call_args[0][0].send_broadcast.component, 'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver', ) def test_broadcast_sent_default_grpc_server_ip(self): """Tests that the broadcast sets the default grpc server ip.""" base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.execute_adb_call.return_value = _ok_response() base_env.stats.return_value = {'relaunch_count': 1} wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( env=base_env, disable_other_network_traffic=False, install_a11y_forwarding=False, start_a11y_service=False, enable_a11y_tree_info=False, ) wrapped_env.reset() self.assertStartsWith( base_env.execute_adb_call.call_args[0][0].send_broadcast.action, 'accessibility_forwarder.intent.action.SET_GRPC', ) self.assertIn( '--es "host" 10.0.2.2', base_env.execute_adb_call.call_args[0][0].send_broadcast.action, ) @parameterized.parameters(('127.0.0.1',), ('1.2.3.4',), 'localhost') def test_broadcast_sent_custom_grpc_server_ip(self, grpc_server_ip): """Tests that the broadcast sets the custom grpc server ip.""" base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.execute_adb_call.return_value = _ok_response() base_env.stats.return_value = {'relaunch_count': 1} wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( env=base_env, disable_other_network_traffic=False, install_a11y_forwarding=False, start_a11y_service=False, enable_a11y_tree_info=False, grpc_server_ip=grpc_server_ip, ) wrapped_env.reset() self.assertStartsWith( base_env.execute_adb_call.call_args[0][0].send_broadcast.action, 'accessibility_forwarder.intent.action.SET_GRPC', ) self.assertIn( f'--es "host" {grpc_server_ip}', base_env.execute_adb_call.call_args[0][0].send_broadcast.action, ) def test_broadcast_sent_port(self): """Tests that 
the broadcast sets the correct port.""" base_env = mock.create_autospec( env_interface.AndroidEnvInterface, instance=True ) base_env.execute_adb_call.return_value = _ok_response() base_env.stats.return_value = {'relaunch_count': 1} wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper( env=base_env, disable_other_network_traffic=False, install_a11y_forwarding=False, start_a11y_service=False, enable_a11y_tree_info=False, ) wrapped_env.reset() self.assertStartsWith( base_env.execute_adb_call.call_args[0][0].send_broadcast.action, 'accessibility_forwarder.intent.action.SET_GRPC', ) self.assertIn( f'--ei "port" {wrapped_env.get_port()}', base_env.execute_adb_call.call_args[0][0].send_broadcast.action, ) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/base_wrapper.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Base class for AndroidEnv wrappers.""" from typing import Any from absl import logging from android_env import env_interface from android_env.proto import adb_pb2 from android_env.proto import state_pb2 import dm_env from dm_env import specs import numpy as np class BaseWrapper(env_interface.AndroidEnvInterface): """AndroidEnv wrapper.""" def __init__(self, env: env_interface.AndroidEnvInterface) -> None: self._env = env logging.info('Wrapping with %s', self.__class__.__name__) def reset(self) -> dm_env.TimeStep: self._reset_state() timestep = self._process_timestep(self._env.reset()) return timestep def step(self, action: Any) -> dm_env.TimeStep: action = self._process_action(action) return self._process_timestep(self._env.step(action)) def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]: return self._env.task_extras(latest_only=latest_only) def _reset_state(self): pass def _process_action(self, action: Any) -> Any: return action def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: return timestep def observation_spec(self) -> dict[str, specs.Array]: return self._env.observation_spec() def action_spec(self) -> dict[str, specs.Array]: return self._env.action_spec() def reward_spec(self) -> specs.Array: return self._env.reward_spec() def discount_spec(self) -> specs.Array: return self._env.discount_spec() def _wrapper_stats(self) -> dict[str, Any]: """Add wrapper specific logging here.""" return {} def stats(self) -> dict[str, Any]: info = self._env.stats() info.update(self._wrapper_stats()) return info def load_state( self, request: state_pb2.LoadStateRequest ) -> state_pb2.LoadStateResponse: """Loads a state.""" return self._env.load_state(request) def save_state( self, request: state_pb2.SaveStateRequest ) -> state_pb2.SaveStateResponse: """Saves a state. Args: request: A `SaveStateRequest` containing any parameters necessary to specify how/what state to save. 
Returns: A `SaveStateResponse` containing the status, error message (if applicable), and any other relevant information. """ return self._env.save_state(request) def execute_adb_call( self, adb_call: adb_pb2.AdbRequest ) -> adb_pb2.AdbResponse: return self._env.execute_adb_call(adb_call) @property def raw_action(self) -> Any: return self._env.raw_action @property def raw_observation(self) -> Any: return self._env.raw_observation @property def raw_env(self) -> env_interface.AndroidEnvInterface: """Recursively unwrap until we reach the true 'raw' env.""" wrapped = self._env if hasattr(wrapped, 'raw_env'): return wrapped.raw_env return wrapped def __getattr__(self, attr) -> Any: """Delegate attribute access to underlying environment.""" return getattr(self._env, attr) def close(self) -> None: self._env.close() ================================================ FILE: android_env/wrappers/base_wrapper_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for android_env.wrappers.base_wrapper.""" from unittest import mock from absl import logging from absl.testing import absltest from android_env import env_interface from android_env.proto import state_pb2 from android_env.wrappers import base_wrapper class BaseWrapperTest(absltest.TestCase): @mock.patch.object(logging, 'info') def test_base_function_forwarding(self, mock_info): base_env = mock.create_autospec(env_interface.AndroidEnvInterface) wrapped_env = base_wrapper.BaseWrapper(base_env) mock_info.assert_called_with('Wrapping with %s', 'BaseWrapper') fake_ts = 'fake_ts' base_env.reset.return_value = fake_ts self.assertEqual(fake_ts, wrapped_env.reset()) base_env.reset.assert_called_once() fake_ts = 'fake_ts' fake_action = 'fake_action' base_env.step.return_value = fake_ts self.assertEqual(fake_ts, wrapped_env.step(fake_action)) base_env.step.assert_called_once_with(fake_action) fake_extras = 'fake_task_extras' base_env.task_extras.return_value = fake_extras self.assertEqual(fake_extras, wrapped_env.task_extras(latest_only=True)) base_env.task_extras.assert_called_once_with(latest_only=True) fake_obs_spec = 'fake_obs_spec' base_env.observation_spec.return_value = fake_obs_spec self.assertEqual(fake_obs_spec, wrapped_env.observation_spec()) base_env.observation_spec.assert_called_once() fake_action_spec = 'fake_action_spec' base_env.action_spec.return_value = fake_action_spec self.assertEqual(fake_action_spec, wrapped_env.action_spec()) base_env.action_spec.assert_called_once() fake_raw_action = 'fake_raw_action' type(base_env).raw_action = mock.PropertyMock(return_value=fake_raw_action) self.assertEqual(fake_raw_action, wrapped_env.raw_action) fake_raw_observation = 'fake_raw_observation' type(base_env).raw_observation = mock.PropertyMock( return_value=fake_raw_observation) self.assertEqual(fake_raw_observation, wrapped_env.raw_observation) load_request = state_pb2.LoadStateRequest(args={}) expected_response = state_pb2.LoadStateResponse( 
status=state_pb2.LoadStateResponse.Status.OK ) base_env.load_state.return_value = expected_response self.assertEqual(wrapped_env.load_state(load_request), expected_response) base_env.load_state.assert_called_once_with(load_request) save_request = state_pb2.SaveStateRequest(args={}) expected_response = state_pb2.SaveStateResponse( status=state_pb2.SaveStateResponse.Status.OK ) base_env.save_state.return_value = expected_response self.assertEqual(wrapped_env.save_state(save_request), expected_response) base_env.save_state.assert_called_once_with(save_request) wrapped_env.close() base_env.close.assert_called_once() fake_return_value = 'fake' # AndroidEnv::some_random_function() does not exist and calling it should # raise an AttributeError. with self.assertRaises(AttributeError): base_env.some_random_function.return_value = fake_return_value def test_multiple_wrappers(self): base_env = mock.create_autospec(env_interface.AndroidEnvInterface) wrapped_env_1 = base_wrapper.BaseWrapper(base_env) wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1) wrapped_env_2.close() base_env.close.assert_called_once() def test_raw_env(self): base_env = mock.create_autospec(env_interface.AndroidEnvInterface) wrapped_env_1 = base_wrapper.BaseWrapper(base_env) wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1) self.assertEqual(base_env, wrapped_env_2.raw_env) def test_stats(self): base_env = mock.create_autospec(env_interface.AndroidEnvInterface) wrapped_env = base_wrapper.BaseWrapper(base_env) base_stats = {'base': 'stats'} base_env.stats.return_value = base_stats self.assertEqual(base_stats, wrapped_env.stats()) @mock.patch.object(logging, 'info') def test_wrapped_stats(self, mock_info): base_env = mock.create_autospec(env_interface.AndroidEnvInterface) class LoggingWrapper1(base_wrapper.BaseWrapper): def _wrapper_stats(self): return { 'wrapper1': 'stats', 'shared': 1, } class LoggingWrapper2(base_wrapper.BaseWrapper): def _wrapper_stats(self): return { 'wrapper2': 'stats', 
'shared': 2, } wrapped_env = LoggingWrapper2(LoggingWrapper1(base_env)) mock_info.assert_has_calls([ mock.call('Wrapping with %s', 'LoggingWrapper1'), mock.call('Wrapping with %s', 'LoggingWrapper2'), ]) base_stats = {'base': 'stats'} base_env.stats.return_value = base_stats expected_stats = { 'base': 'stats', 'wrapper1': 'stats', 'wrapper2': 'stats', 'shared': 2, } self.assertEqual(expected_stats, wrapped_env.stats()) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/discrete_action_wrapper.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Wraps the AndroidEnv environment to provide discrete actions.""" from collections.abc import Sequence from typing import cast from android_env import env_interface from android_env.components import action_type from android_env.wrappers import base_wrapper import dm_env from dm_env import specs import numpy as np _NOISE_CLIP_VALUE = 0.4999 class DiscreteActionWrapper(base_wrapper.BaseWrapper): """AndroidEnv with discrete actions.""" def __init__( self, env: env_interface.AndroidEnvInterface, action_grid: Sequence[int] = (10, 10), redundant_actions: bool = True, noise: float = 0.1, ) -> None: super().__init__(env) self._parent_action_spec = self._env.action_spec() self._assert_base_env() self._action_grid = action_grid # [height, width] self._grid_size = np.prod(self._action_grid) action_types = cast( specs.DiscreteArray, self._parent_action_spec['action_type'] ) self._num_action_types = action_types.num_values self._redundant_actions = redundant_actions self._noise = noise def _assert_base_env(self) -> None: """Checks that the wrapped env has the right action spec format.""" assert len(self._parent_action_spec) == 2 assert not self._parent_action_spec['action_type'].shape assert self._parent_action_spec['touch_position'].shape == (2,) @property def num_actions(self) -> int: """Number of discrete actions.""" if self._redundant_actions: return self._grid_size * self._num_action_types else: return self._grid_size + self._num_action_types - 1 def step(self, action: dict[str, int]) -> dm_env.TimeStep: """Take a step in the base environment.""" return self._env.step(self._process_action(action)) def _process_action(self, action: dict[str, int]) -> dict[str, np.ndarray]: """Transforms action so that it agrees with AndroidEnv's action spec.""" return { 'action_type': np.array(self._get_action_type(action['action_id']), dtype=self._parent_action_spec['action_type'].dtype), 'touch_position': np.array(self._get_touch_position(action['action_id']), 
dtype=self._parent_action_spec['touch_position'].dtype) } def _get_action_type(self, action_id: int) -> action_type.ActionType: """Compute action type corresponding to the given action_id. When `self._redundant_actions` == True the `grid_size` is "broadcast" over all the possible actions so you end up with `grid_size` discrete actions of type 0, `grid_size` discrete actions of type 1, etc. for all action types. When `self._redundant_actions` == False the first `grid_size` actions are reserved for "touch" and the rest are just added (NOT multiplied) to the total number of discrete actions (exactly one of LIFT and REPEAT). Args: action_id: A discrete action. Returns: action_type: The action_type of the action. """ if self._redundant_actions: assert action_id < self._num_action_types * self._grid_size return action_id // self._grid_size else: assert action_id <= self._grid_size + 1 if action_id < self._grid_size: return action_type.ActionType.TOUCH elif action_id == self._grid_size: return action_type.ActionType.LIFT else: return action_type.ActionType.REPEAT def _get_touch_position(self, action_id: int) -> Sequence[float]: """Compute the position corresponding to the given action_id. Note: in the touch_position (x, y) of an action, x corresponds to the horizontal axis (width), and y corresponds to the vertical axis (height) of the screen. BUT, the screen has dimensions (height, width), i.e. the first coordinate corresponds to y, and the second coordinate corresponds to x. Pay attention to this mismatch in the calculations below. Args: action_id: A discrete action. Returns: touch_position: The [0,1]x[0,1] coordinate of the action. """ position_idx = action_id % self._grid_size x_pos_grid = position_idx % self._action_grid[1] # WIDTH y_pos_grid = position_idx // self._action_grid[1] # HEIGHT noise_x = np.random.normal(loc=0.0, scale=self._noise) noise_y = np.random.normal(loc=0.0, scale=self._noise) # Noise is clipped so that the action will strictly stay in the cell. 
noise_x = max(min(noise_x, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE) noise_y = max(min(noise_y, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE) x_pos = (x_pos_grid + 0.5 + noise_x) / self._action_grid[1] # WIDTH y_pos = (y_pos_grid + 0.5 + noise_y) / self._action_grid[0] # HEIGHT # Project action space to action_spec ranges. For the default case of # minimum = [0, 0] and maximum = [1, 1], this will not do anything. x_min, y_min = cast( specs.BoundedArray, self._parent_action_spec['touch_position'] ).minimum x_max, y_max = cast( specs.BoundedArray, self._parent_action_spec['touch_position'] ).maximum x_pos = x_min + x_pos * (x_max - x_min) y_pos = y_min + y_pos * (y_max - y_min) return [x_pos, y_pos] def action_spec(self) -> dict[str, specs.Array]: """Action spec of the wrapped environment.""" return { 'action_id': specs.DiscreteArray( num_values=self.num_actions, name='action_id') } ================================================ FILE: android_env/wrappers/discrete_action_wrapper_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for android_env.wrappers.discrete_action_wrapper.""" from unittest import mock from absl.testing import absltest from android_env import env_interface from android_env.components import action_type as action_type_lib from android_env.wrappers import discrete_action_wrapper from dm_env import specs import numpy as np ActionType = action_type_lib.ActionType def _make_array_spec(shape, dtype, name): assert len(shape) == 1 return specs.BoundedArray( name=name, shape=shape, dtype=dtype, minimum=np.zeros(shape), maximum=(shape[0] - 1) * np.ones(shape), # maximum is inclusive. ) def _valid_shape(action): assert len(action) == 2, action assert not action['action_type'].shape, ( 'action: %r, shape: %r' % (action['action_type'], action['action_type'].shape)) assert action['touch_position'].shape == ( 2,), ('action: %r, shape: %r' % (action['touch_position'], action['touch_position'].shape)) def _valid_types(action, types): for a, t in zip(action.values(), types): assert a.dtype == t, '%r is not of dtype %r' % (a, t) class DiscreteActionWrapperTest(absltest.TestCase): def setUp(self): super().setUp() self._num_action_types = 3 # Only TOUCH, LIFT, REPEAT. self._base_action_spec = { 'action_type': specs.DiscreteArray( num_values=self._num_action_types, name='action_type'), 'touch_position': _make_array_spec( shape=(2,), dtype=np.float32, name='touch_position'), } self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface) self.base_env.action_spec.return_value = self._base_action_spec def test_num_actions(self): wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(3, 3), redundant_actions=True) # 27 = 3 * 3 * 2 (H * W * self._num_action_types). self.assertEqual(27, wrapped_env.num_actions) def test_num_actions_non_redundant(self): # Check that with `redundant_actions`==False we get an additive term instead # of a multiplier in the number of actions. 
non_redudant_wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(3, 3), redundant_actions=False) # 11 = 3 * 3 + 2 (H * W + (self._num_action_types - 1)). self.assertEqual(11, non_redudant_wrapped_env.num_actions) def test_reset(self): wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, redundant_actions=True) fake_timestep = 'ts' self.base_env.reset.return_value = fake_timestep ts = wrapped_env.reset() self.base_env.reset.assert_called_once() self.assertEqual(fake_timestep, ts) def test_step_no_noise(self): height = 4 width = 3 wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(height, width), noise=0.0, redundant_actions=True) self.assertEqual(height * width * self._num_action_types, wrapped_env.num_actions) vertical_half_step = 1. / float(height) / 2. horizontal_half_step = 1. / float(width) / 2. delta = 0.0001 # Testing the four corners with each finger position def get_verifier(expected_action_type, lower_x, lower_y): def verifier(x): _valid_shape(x) _valid_types(x, [np.int32, np.float32]) self.assertEqual( expected_action_type, x['action_type']) if lower_y: self.assertAlmostEqual( vertical_half_step, x['touch_position'][1], delta=delta) else: self.assertAlmostEqual( 1 - vertical_half_step, x['touch_position'][1], delta=delta) if lower_x: self.assertAlmostEqual( horizontal_half_step, x['touch_position'][0], delta=delta) else: self.assertAlmostEqual( 1 - horizontal_half_step, x['touch_position'][0], delta=delta) return True return verifier action_tests = { 0: get_verifier(0, lower_x=True, lower_y=True), 2: get_verifier(0, lower_x=False, lower_y=True), 9: get_verifier(0, lower_x=True, lower_y=False), 11: get_verifier(0, lower_x=False, lower_y=False), 12: get_verifier(1, lower_x=True, lower_y=True), 14: get_verifier(1, lower_x=False, lower_y=True), 21: get_verifier(1, lower_x=True, lower_y=False), 23: get_verifier(1, lower_x=False, lower_y=False), 24: get_verifier(2, 
lower_x=True, lower_y=True), 26: get_verifier(2, lower_x=False, lower_y=True), 33: get_verifier(2, lower_x=True, lower_y=False), 35: get_verifier(2, lower_x=False, lower_y=False), } fake_timestep = 'ts' self.base_env.step.return_value = fake_timestep for action_id, verifier in action_tests.items(): ts = wrapped_env.step({'action_id': action_id}) verifier(self.base_env.step.call_args[0][0]) self.assertEqual(fake_timestep, ts) def test_step_redundant_actions_invalid_action_id(self): wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(4, 3), noise=0.0, redundant_actions=True) with self.assertRaises(AssertionError): _ = wrapped_env.step({'action_id': 36}) def test_step_no_noise_no_redudant_actions(self): height = 4 width = 3 wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(height, width), noise=0.0, redundant_actions=False) self.assertEqual(height * width + (self._num_action_types - 1), wrapped_env.num_actions) vertical_half_step = 1. / float(height) / 2. horizontal_half_step = 1. / float(width) / 2. delta = 0.0001 # Testing the four corners with each finger position def get_verifier(expected_action_type, lower_x, lower_y): def verifier(x): _valid_shape(x) _valid_types(x, [np.int32, np.float32]) self.assertEqual(expected_action_type, x['action_type']) # If the action type == TOUCH, then check the coordinate values. 
if x['action_type'] == ActionType.TOUCH: if lower_y: self.assertAlmostEqual( vertical_half_step, x['touch_position'][1], delta=delta) else: self.assertAlmostEqual( 1 - vertical_half_step, x['touch_position'][1], delta=delta) if lower_x: self.assertAlmostEqual( horizontal_half_step, x['touch_position'][0], delta=delta) else: self.assertAlmostEqual( 1 - horizontal_half_step, x['touch_position'][0], delta=delta) return True return verifier action_tests = { # Touch type actions 0: get_verifier(0, lower_x=True, lower_y=True), 2: get_verifier(0, lower_x=False, lower_y=True), 9: get_verifier(0, lower_x=True, lower_y=False), 11: get_verifier(0, lower_x=False, lower_y=False), # Actions > grid_size return non-touch actions with (0,0) coordinates. 12: get_verifier(1, lower_x=False, lower_y=False), 13: get_verifier(2, lower_x=False, lower_y=False), } fake_timestep = 'ts' self.base_env.step.return_value = fake_timestep for action_id, verifier in action_tests.items(): ts = wrapped_env.step({'action_id': action_id}) verifier(self.base_env.step.call_args[0][0]) self.assertEqual(fake_timestep, ts) def test_step_no_redundant_actions_invalid_action_id(self): wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(4, 3), noise=0.0, redundant_actions=False) with self.assertRaises(AssertionError): _ = wrapped_env.step({'action_id': 14}) def test_step_with_noise(self): height = 4 width = 3 wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(height, width), noise=1.0) self.assertEqual(height * width * self._num_action_types, wrapped_env.num_actions) vertical_grid_step = 1. / float(height) horizontal_grid_step = 1. 
/ float(width) # Testing the four corners with each finger position def get_verifier(expected_up_down, lower_x, lower_y): def verifier(x): _valid_shape(x) _valid_types(x, [np.int32, np.float32]) self.assertEqual(expected_up_down, x['action_type']) if lower_y: self.assertGreater(vertical_grid_step, x['touch_position'][1]) else: self.assertLess(1 - vertical_grid_step, x['touch_position'][1]) if lower_x: self.assertGreater(horizontal_grid_step, x['touch_position'][0]) else: self.assertLess(1 - horizontal_grid_step, x['touch_position'][0]) return True return verifier action_tests = { 0: get_verifier(0, lower_x=True, lower_y=True), 2: get_verifier(0, lower_x=False, lower_y=True), 9: get_verifier(0, lower_x=True, lower_y=False), 11: get_verifier(0, lower_x=False, lower_y=False), 12: get_verifier(1, lower_x=True, lower_y=True), 14: get_verifier(1, lower_x=False, lower_y=True), 21: get_verifier(1, lower_x=True, lower_y=False), 23: get_verifier(1, lower_x=False, lower_y=False), 24: get_verifier(2, lower_x=True, lower_y=True), 26: get_verifier(2, lower_x=False, lower_y=True), 33: get_verifier(2, lower_x=True, lower_y=False), 35: get_verifier(2, lower_x=False, lower_y=False), } fake_timestep = 'ts' self.base_env.step.return_value = fake_timestep for action_id, verifier in action_tests.items(): ts = wrapped_env.step({'action_id': action_id}) verifier(self.base_env.step.call_args[0][0]) self.assertEqual(fake_timestep, ts) def test_parent_spec_type(self): base_action_spec = { 'action_type': specs.DiscreteArray( num_values=self._num_action_types, name='action_type'), 'touch_position': _make_array_spec( shape=(2,), dtype=np.float64, name='touch_position'), } base_env = mock.create_autospec(env_interface.AndroidEnvInterface) base_env.action_spec.return_value = base_action_spec wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( base_env, noise=0.0) fake_timestep = 'ts' base_env.step.return_value = fake_timestep def verifier(x): _valid_types(x, [np.int32, np.float64]) 
return True ts = wrapped_env.step({'action_id': 1}) verifier(base_env.step.call_args[0][0]) self.assertEqual(fake_timestep, ts) def test_observation_spec(self): wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env) fake_obs_spec = 'fake_obs_spec' self.base_env.observation_spec.return_value = fake_obs_spec observation_spec = wrapped_env.observation_spec() self.base_env.observation_spec.assert_called_once() self.assertEqual(fake_obs_spec, observation_spec) def test_action_spec(self): wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(4, 5), redundant_actions=True) expected_action_spec = { 'action_id': specs.DiscreteArray( num_values=4 * 5 * self._num_action_types, name='action_type') } self.assertEqual(expected_action_spec, wrapped_env.action_spec()) def test_action_spec_non_redundant(self): wrapped_env = discrete_action_wrapper.DiscreteActionWrapper( self.base_env, action_grid=(4, 5), redundant_actions=False) num_non_touch_actions = self._num_action_types - 1 expected_action_spec = { 'action_id': specs.DiscreteArray( num_values=4 * 5 + num_non_touch_actions, name='action_type') } self.assertEqual(expected_action_spec, wrapped_env.action_spec()) def test_assert_base_env_action_spec_too_short(self): self.base_env.action_spec.return_value = { 'action_type': specs.DiscreteArray( num_values=self._num_action_types, name='action_type'), } with self.assertRaises(AssertionError): _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env) def test_assert_base_env_action_spec_too_long(self): self.base_env.action_spec.return_value = { 'action_type': specs.DiscreteArray( num_values=self._num_action_types, name='action_type'), 'touch_position': _make_array_spec( shape=(2,), dtype=np.float32, name='touch_position'), 'too_long': _make_array_spec( shape=(1,), dtype=np.float32, name='too_long'), } with self.assertRaises(AssertionError): _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env) def 
test_assert_base_env_action_spec_wrong_shapes(self): self.base_env.action_spec.return_value = { 'action_type': _make_array_spec( shape=(2,), dtype=np.float32, name='action_type'), 'touch_position': _make_array_spec( shape=(1,), dtype=np.float32, name='touch_position') } with self.assertRaises(AssertionError): _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env) def test_assert_base_env_ok(self): self.base_env.action_spec.return_value = { 'action_type': specs.DiscreteArray( num_values=self._num_action_types, name='action_type'), 'touch_position': _make_array_spec( shape=(2,), dtype=np.float32, name='touch_position'), } _ = discrete_action_wrapper.DiscreteActionWrapper(self.base_env) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/flat_interface_wrapper.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Wraps the AndroidEnv environment to make its interface flat.""" from typing import Any, cast from android_env import env_interface from android_env.wrappers import base_wrapper import dm_env from dm_env import specs import numpy as np RGB_CHANNELS = (0, 1, 2) def _extract_screen_pixels(obs: np.ndarray) -> np.ndarray: """Get only screen pixels by removing previous action layer.""" is_grayscale_image = obs.shape[-1] == 2 if is_grayscale_image: return np.expand_dims(obs[..., 0], -1) return obs[..., RGB_CHANNELS] def _get_no_action_observation_spec( obs_spec: specs.BoundedArray, ) -> specs.BoundedArray: """Create an observation spec without the action layer.""" shape = np.array(obs_spec.shape) shape[2] -= 1 minimum = obs_spec.minimum maximum = obs_spec.maximum is_scalar = lambda x: np.isscalar(x) or np.ndim(x) == 0 if not is_scalar(minimum): minimum = _extract_screen_pixels(minimum) if not is_scalar(maximum): maximum = _extract_screen_pixels(maximum) return obs_spec.replace(shape=shape, minimum=minimum, maximum=maximum) class FlatInterfaceWrapper(base_wrapper.BaseWrapper): """Simple interface for AndroidEnv. Removes the structure from observations and actions, keeping only the pixel observations. Also exposes action as an int32 scalar, making it easier to use with conventional discrete agents. This wrapper expects a discretized action space. 
""" def __init__( self, env: env_interface.AndroidEnvInterface, flat_actions: bool = True, flat_observations: bool = True, keep_action_layer: bool = True, ) -> None: super().__init__(env) self._flat_actions = flat_actions self._flat_observations = flat_observations self._keep_action_layer = keep_action_layer self._action_name = list(self._env.action_spec())[0] self._assert_base_env() def _assert_base_env(self) -> None: base_action_spec = self._env.action_spec() assert len(base_action_spec) == 1, self._env.action_spec() assert isinstance(base_action_spec, dict) assert isinstance(base_action_spec[self._action_name], specs.BoundedArray) def _process_action( self, action: int | np.ndarray | dict[str, Any] ) -> int | np.ndarray | dict[str, Any]: if self._flat_actions: return {self._action_name: action} else: return action def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: if self._flat_observations: step_type, reward, discount, observation = timestep # Keep only the pixels. 
pixels = observation['pixels'] pixels = ( pixels if self._keep_action_layer else _extract_screen_pixels(pixels) ) return dm_env.TimeStep( step_type=step_type, reward=reward, discount=discount, observation=pixels, ) else: return timestep def reset(self) -> dm_env.TimeStep: timestep = self._env.reset() return self._process_timestep(timestep) def step(self, action: int) -> dm_env.TimeStep: timestep = self._env.step(self._process_action(action)) return self._process_timestep(timestep) def observation_spec(self) -> specs.Array | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks if self._flat_observations: pixels_spec = cast( specs.BoundedArray, self._env.observation_spec()['pixels'] ) if not self._keep_action_layer: return _get_no_action_observation_spec(pixels_spec) return pixels_spec else: return self._env.observation_spec() def action_spec(self) -> specs.BoundedArray | dict[str, specs.Array]: # pytype: disable=signature-mismatch # overriding-return-type-checks if self._flat_actions: return self._env.action_spec()[self._action_name] # pytype: disable=bad-return-type else: return self._env.action_spec() ================================================ FILE: android_env/wrappers/flat_interface_wrapper_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Tests for android_env.wrappers.flat_interface_wrapper.""" from typing import cast from unittest import mock from absl.testing import absltest from android_env.wrappers import flat_interface_wrapper import dm_env from dm_env import specs import numpy as np def _make_array_spec(shape, dtype=np.float32, name=None, maximum=3, minimum=0): return specs.BoundedArray( shape=shape, dtype=dtype, name=name, maximum=np.ones(shape) * maximum, minimum=np.ones(shape) * minimum) def _make_timestep(observation): return dm_env.TimeStep( step_type='fake_step_type', reward='fake_reward', discount='fake_discount', observation=observation, ) class FlatInterfaceWrapperTest(absltest.TestCase): def setUp(self): super().setUp() self.action_shape = (1,) self.base_action_spec: dict[str, specs.DiscreteArray] = { 'action_id': specs.DiscreteArray(name='action_id', num_values=4) } self.int_obs_shape = (3, 4, 2) self.float_obs_shape = (2,) self.base_observation_spec = { 'pixels': _make_array_spec( shape=self.int_obs_shape, dtype=np.uint8, name='pixels'), 'obs1': _make_array_spec( shape=self.float_obs_shape, dtype=np.float32, name='obs1'), } # Expected. self.expected_observation_spec = _make_array_spec( shape=self.int_obs_shape, dtype=np.uint8, name='pixels') self.image_obs = np.ones(self.int_obs_shape, dtype=np.uint8) self.expected_timestep = _make_timestep(self.image_obs) # Expected for no new action layer shape. expected_new_shape_no_action_layer = (3, 4, 1) self.expected_observation_spec_no_action_layer = _make_array_spec( shape=expected_new_shape_no_action_layer, dtype=np.uint8, name='pixels') self.expected_timestep_no_action_layer = _make_timestep( np.ones(expected_new_shape_no_action_layer, dtype=np.uint8)) # Base environment. 
self.other_obs = np.ones(self.float_obs_shape, dtype=np.float32) self.base_timestep = _make_timestep({ 'pixels': self.image_obs, 'obs1': self.other_obs}) self.base_env = mock.create_autospec(dm_env.Environment) self.base_env.action_spec.return_value = self.base_action_spec self.base_env.observation_spec.return_value = self.base_observation_spec self.base_env.reset.return_value = self.base_timestep self.base_env.step.return_value = self.base_timestep def test_reset(self): wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env) ts = wrapped_env.reset() self.base_env.reset.assert_called_once() self.assertEqual(self.expected_timestep, ts) def test_reset_no_action_layer(self): wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper( self.base_env, keep_action_layer=False) ts = wrapped_env.reset() self.base_env.reset.assert_called_once() self.assertEqual( self.expected_timestep_no_action_layer.observation.tolist(), ts.observation.tolist()) def test_step(self): wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env) action = 2 ts = wrapped_env.step(action) def verifier(x): self.assertIsInstance(x, dict) self.assertIsInstance(x['action_id'], int) self.assertEqual(x['action_id'], action) return True verifier(self.base_env.step.call_args[0][0]) self.assertEqual(self.expected_timestep, ts) def test_step_no_action_layer(self): wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper( self.base_env, keep_action_layer=False) action = 2 ts = wrapped_env.step(action) def verifier(x): self.assertIsInstance(x, dict) self.assertIsInstance(x['action_id'], int) self.assertEqual(x['action_id'], action) return True verifier(self.base_env.step.call_args[0][0]) self.assertEqual( self.expected_timestep_no_action_layer.observation.tolist(), ts.observation.tolist()) def test_observation_spec(self): wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env) observation_spec = wrapped_env.observation_spec() 
self.base_env.observation_spec.assert_called_once() self.assertEqual(self.expected_observation_spec, observation_spec) def test_observation_spec_no_action_layer(self): wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper( self.base_env, keep_action_layer=False) observation_spec = wrapped_env.observation_spec() self.base_env.observation_spec.assert_called_once() self.assertEqual(self.expected_observation_spec_no_action_layer, observation_spec) def test_action_spec(self): wrapped_env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env) action_spec = cast(specs.BoundedArray, wrapped_env.action_spec()) parent_action_spec = self.base_action_spec['action_id'] self.assertEqual(parent_action_spec.name, action_spec.name) self.assertEqual((), action_spec.shape) self.assertEqual(np.int32, action_spec.dtype) self.assertEqual(0, action_spec.minimum) def test_bad_action_spec_structured_action(self): bad_base_env = mock.create_autospec(dm_env.Environment) bad_base_env.action_spec.return_value = { 'action_id': _make_array_spec((1,)), 'too_many': _make_array_spec((1,)) } with self.assertRaises(AssertionError): _ = flat_interface_wrapper.FlatInterfaceWrapper(bad_base_env) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/float_pixels_wrapper.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Converts pixel observation to from int to float32 between 0.0 and 1.0.""" from android_env import env_interface from android_env.components import pixel_fns from android_env.wrappers import base_wrapper import dm_env from dm_env import specs import numpy as np class FloatPixelsWrapper(base_wrapper.BaseWrapper): """Wraps AndroidEnv for Panultimate agent.""" def __init__(self, env: env_interface.AndroidEnvInterface) -> None: super().__init__(env) self._input_spec = self._env.observation_spec()['pixels'] self._should_convert_int_to_float = np.issubdtype(self._input_spec.dtype, np.integer) def _process_observation( self, observation: dict[str, np.ndarray] ) -> dict[str, np.ndarray]: if self._should_convert_int_to_float: float_pixels = pixel_fns.convert_int_to_float( observation['pixels'], self._input_spec ) observation['pixels'] = float_pixels return observation def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: step_type, reward, discount, observation = timestep return dm_env.TimeStep( step_type=step_type, reward=reward, discount=discount, observation=self._process_observation(observation)) def reset(self) -> dm_env.TimeStep: return self._process_timestep(self._env.reset()) def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep: return self._process_timestep(self._env.step(action)) def observation_spec(self) -> dict[str, specs.Array]: if self._should_convert_int_to_float: observation_spec = self._env.observation_spec() observation_spec['pixels'] = specs.BoundedArray( shape=self._env.observation_spec()['pixels'].shape, dtype=np.float32, minimum=0.0, maximum=1.0, name=self._env.observation_spec()['pixels'].name) return observation_spec return self._env.observation_spec() ================================================ FILE: android_env/wrappers/float_pixels_wrapper_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. 
# # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for android_env.wrappers.float_pixels_wrapper.""" from unittest import mock from absl.testing import absltest from android_env.wrappers import float_pixels_wrapper import dm_env from dm_env import specs import numpy as np def _make_array_spec(shape, dtype=np.float32, name=None): return specs.Array( shape=shape, dtype=dtype, name=name, ) def _make_bounded_array_spec( shape, dtype=np.float32, name=None, maximum=1.0, minimum=0.0): return specs.BoundedArray( shape=shape, dtype=dtype, name=name, maximum=maximum, minimum=minimum, ) def _simple_timestep(obs_shape, obs_type): return dm_env.TimeStep( step_type=dm_env.StepType.MID, reward=3.14, discount=0.9, observation=(np.ones(shape=obs_shape, dtype=obs_type),), ) class FloatPixelsWrapperTest(absltest.TestCase): def setUp(self): super().setUp() self.pixels_shape = (3, 4) base_pixel_spec = _make_array_spec( shape=self.pixels_shape, dtype=np.uint8, name='pixels') self.other_obs_spec = _make_array_spec( shape=(1,), dtype=np.float32, name='other_obs') base_observation_spec = { 'pixels': base_pixel_spec, 'other_obs': self.other_obs_spec } self.base_env = mock.create_autospec(dm_env.Environment) self.base_env.observation_spec.return_value = base_observation_spec self.base_timestep = dm_env.TimeStep( step_type=dm_env.StepType.MID, reward=3.14, discount=0.9, observation={ 'pixels': np.ones(shape=self.pixels_shape, dtype=np.uint8), 'other_obs': [42.2]}) self.base_env.step.return_value = 
self.base_timestep self.base_env.reset.return_value = self.base_timestep def test_float_pixels_wrapper_spec(self): expected_pixel_spec = _make_bounded_array_spec( shape=self.pixels_shape, dtype=np.float32, name='pixels', minimum=0.0, maximum=1.0) wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env) self.assertLen(wrapped_env.observation_spec(), 2) self.assertEqual(expected_pixel_spec, wrapped_env.observation_spec()['pixels']) self.assertEqual(self.other_obs_spec, wrapped_env.observation_spec()['other_obs']) def test_float_pixels_wrapper_step(self): wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env) ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])}) self.assertEqual(self.base_timestep.step_type, ts.step_type) self.assertEqual(self.base_timestep.reward, ts.reward) self.assertEqual(self.base_timestep.discount, ts.discount) self.assertEqual(self.base_timestep.observation['other_obs'], ts.observation['other_obs']) expected_pixel_value = 1. / 255. # original values are unit8 expected_pixels = np.ones( self.pixels_shape, dtype=np.float32) * expected_pixel_value np.testing.assert_equal(expected_pixels, ts.observation['pixels']) def test_float_pixels_wrapper_reset(self): wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env) ts = wrapped_env.reset() self.assertEqual(self.base_timestep.step_type, ts.step_type) self.assertEqual(self.base_timestep.reward, ts.reward) self.assertEqual(self.base_timestep.discount, ts.discount) self.assertEqual(self.base_timestep.observation['other_obs'], ts.observation['other_obs']) expected_pixel_value = 1. / 255. 
# original values are unit8 expected_pixels = np.ones( self.pixels_shape, dtype=np.float32) * expected_pixel_value np.testing.assert_equal(expected_pixels, ts.observation['pixels']) def test_float_pixels_wrapper_already_float(self): base_pixel_spec = _make_array_spec( shape=self.pixels_shape, dtype=np.float64, name='pixels') base_observation_spec = { 'pixels': base_pixel_spec, 'other_obs': self.other_obs_spec } base_env = mock.create_autospec(dm_env.Environment) base_env.observation_spec.return_value = base_observation_spec wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(base_env) # If the pixels are already float values, then obs_spec does not change. self.assertEqual(base_env.observation_spec(), wrapped_env.observation_spec()) # The wrapper should not touch the timestep in this case. fake_timestep = ('step_type', 'reward', 'discount', 'obs') base_env.step.return_value = fake_timestep ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])}) self.assertEqual(fake_timestep, ts) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/gym_wrapper.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Wraps the AndroidEnv to expose an OpenAI Gym interface.""" from typing import Any from android_env.wrappers import base_wrapper import dm_env from dm_env import specs import gym from gym import spaces import numpy as np class GymInterfaceWrapper(gym.Env): """AndroidEnv with OpenAI Gym interface.""" def __init__(self, env: dm_env.Environment): self._env = env self.spec = None self.action_space = self._spec_to_space(self._env.action_spec()) self.observation_space = self._spec_to_space(self._env.observation_spec()) self.metadata = {'render.modes': ['rgb_array']} self._latest_observation = None def _spec_to_space(self, spec: specs.Array) -> spaces.Space: """Converts dm_env specs to OpenAI Gym spaces.""" if isinstance(spec, list): return spaces.Tuple([self._spec_to_space(s) for s in spec]) if isinstance(spec, dict): return spaces.Dict( {name: self._spec_to_space(s) for name, s in spec.items()} ) if isinstance(spec, specs.DiscreteArray): return spaces.Box( shape=(), dtype=spec.dtype, low=0, high=spec.num_values-1) if isinstance(spec, specs.BoundedArray): return spaces.Box( shape=spec.shape, dtype=spec.dtype, low=spec.minimum, high=spec.maximum) if isinstance(spec, specs.Array): if spec.dtype == np.uint8: low = 0 high = 255 else: low = -np.inf high = np.inf return spaces.Box(shape=spec.shape, dtype=spec.dtype, low=low, high=high) raise ValueError('Unknown type for specs: {}'.format(spec)) def render(self, mode='rgb_array'): """Renders the environment.""" if mode == 'rgb_array': if self._latest_observation is None: return return self._latest_observation['pixels'] else: raise ValueError('Only supported render mode is rgb_array.') def reset(self) -> np.ndarray: self._latest_observation = None timestep = self._env.reset() return timestep.observation def step(self, action: dict[str, int]) -> tuple[Any, ...]: """Take a step in the base environment.""" timestep = self._env.step(action) observation = timestep.observation self._latest_observation = observation reward = 
timestep.reward done = timestep.step_type == dm_env.StepType.LAST info = {'discount': timestep.discount} return observation, reward, done, info ================================================ FILE: android_env/wrappers/gym_wrapper_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for android_env.wrappers.gym_wrapper.""" from unittest import mock from absl.testing import absltest from android_env import env_interface from android_env.wrappers import gym_wrapper import dm_env from dm_env import specs from gym import spaces import numpy as np class GymInterfaceWrapperTest(absltest.TestCase): def setUp(self): super().setUp() self._base_env = mock.create_autospec(env_interface.AndroidEnvInterface) self._base_env.action_spec.return_value = { 'action_type': specs.DiscreteArray( num_values=3, name='action_type'), 'touch_position': specs.BoundedArray( shape=(2,), dtype=np.float32, minimum=[0.0, 0.0], maximum=[1.0, 1.0], name='touch_position'), } self._base_env.observation_spec.return_value = { 'pixels': specs.Array( shape=(480, 320, 3), dtype=np.uint8, name='pixels'), 'timedelta': specs.Array(shape=(), dtype=np.int64, name='timedelta'), 'orientation': specs.Array( shape=np.array([4]), dtype=np.uint8, name='orientation'), } self._wrapped_env = gym_wrapper.GymInterfaceWrapper(self._base_env) self._fake_ts = dm_env.TimeStep( step_type=dm_env.StepType.MID, 
observation={'pixels': np.ones(shape=(2, 3))}, reward=10.0, discount=1.0) def test_render(self): self._base_env.step.return_value = self._fake_ts _ = self._wrapped_env.step(action=np.zeros(shape=(1,))) image = self._wrapped_env.render(mode='rgb_array') self.assertTrue(np.array_equal(image, np.ones(shape=(2, 3)))) def test_render_error(self): with self.assertRaises(ValueError): _ = self._wrapped_env.render(mode='human') def test_reset(self): self._base_env.reset.return_value = dm_env.TimeStep( step_type=dm_env.StepType.FIRST, observation={'pixels': np.ones(shape=(2, 3))}, reward=10.0, discount=1.0) obs = self._wrapped_env.reset() self._base_env.reset.assert_called_once() self.assertTrue(np.array_equal(obs['pixels'], np.ones(shape=(2, 3)))) def test_step(self): self._base_env.step.return_value = self._fake_ts obs, _, _, _ = self._wrapped_env.step(action=np.zeros(shape=(1,))) self._base_env.step.assert_called_once() self.assertTrue(np.array_equal(obs['pixels'], np.ones(shape=(2, 3)))) def test_spec_to_space(self): spec = specs.Array( shape=(2, 3), dtype=np.float32) space = self._wrapped_env._spec_to_space(spec) self.assertEqual(space, spaces.Box( low=-np.inf, high=np.inf, shape=spec.shape, dtype=spec.dtype)) spec = specs.BoundedArray( shape=(), dtype=np.float32, minimum=4, maximum=5) space = self._wrapped_env._spec_to_space(spec) self.assertEqual(space, spaces.Box( low=4, high=5, shape=spec.shape, dtype=spec.dtype)) spec = specs.DiscreteArray(num_values=4) space = self._wrapped_env._spec_to_space(spec) self.assertEqual(space, spaces.Box( low=0, high=3, shape=(), dtype=np.int32)) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/image_rescale_wrapper.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Wraps the AndroidEnv environment to rescale the observations.""" from collections.abc import Sequence from typing import cast from android_env import env_interface from android_env.wrappers import base_wrapper import dm_env from dm_env import specs import numpy as np from PIL import Image # Taken from https://pillow.readthedocs.io/en/3.2.x/reference/Image.html#PIL.Image.Image.convert # # This array maps an RGB image to a grayscale image using the ITU-R 709 # specification which is good for computer displays and HDTV. RGB_TO_GRAYSCALE_COEFFICIENTS = [0.2126, 0.7152, 0.0722] class ImageRescaleWrapper(base_wrapper.BaseWrapper): """AndroidEnv with rescaled observations.""" def __init__( self, env: env_interface.AndroidEnvInterface, zoom_factors: Sequence[float] | None = (0.5, 0.5), grayscale: bool = False, ) -> None: super().__init__(env) assert 'pixels' in self._env.observation_spec() assert self._env.observation_spec()['pixels'].shape[-1] in [ 1, 3, ], 'Number of pixel channels should be 1 or 3.' self._grayscale = grayscale if zoom_factors is None: zoom_factors = (1.0, 1.0) # We only zoom the width and height of each layer, and we explicitly do not # want to zoom the number of channels so we just multiply it by 1.0. 
self._zoom_factors = tuple(zoom_factors) + (1.0,) def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep: observation = timestep.observation processed_observation = observation.copy() processed_observation['pixels'] = self._process_pixels( observation['pixels'] ) return timestep._replace(observation=processed_observation) def _process_pixels(self, raw_observation: np.ndarray) -> np.ndarray: # We expect `raw_observation` to have shape (W, H, 3) - 3 for RGB new_shape = np.array( self._zoom_factors[0:2] * np.array(raw_observation.shape[0:2]), dtype=np.int32, )[::-1] if self._grayscale: # When self._grayscale == True, we squash the RGB into a single layer image = np.dot(raw_observation, RGB_TO_GRAYSCALE_COEFFICIENTS) else: image = raw_observation return self._resize_image_array(image, new_shape) def _resize_image_array( self, grayscale_or_rbg_array: np.ndarray, new_shape: np.ndarray ) -> np.ndarray: """Resize color or grayscale/action_layer array to new_shape.""" assert new_shape.ndim == 1 assert len(new_shape) == 2 resized_array = np.array( Image.fromarray(grayscale_or_rbg_array.astype('uint8')).resize( tuple(new_shape) ) ) if resized_array.ndim == 2: return np.expand_dims(resized_array, axis=-1) return resized_array def reset(self) -> dm_env.TimeStep: timestep = self._env.reset() return self._process_timestep(timestep) def step(self, action) -> dm_env.TimeStep: timestep = self._env.step(action) return self._process_timestep(timestep) def observation_spec(self) -> dict[str, specs.Array]: parent_spec = self._env.observation_spec().copy() parent_pixels = cast(specs.BoundedArray, parent_spec['pixels']) out_shape = np.multiply(parent_pixels.shape, self._zoom_factors).astype( np.int32 ) if self._grayscale: # In grayscale mode we want the output shape to be [W, H, 1] out_shape[-1] = 1 parent_spec['pixels'] = specs.BoundedArray( shape=out_shape, dtype=parent_pixels.dtype, name=parent_pixels.name, minimum=parent_pixels.minimum, 
maximum=parent_pixels.maximum, ) return parent_spec ================================================ FILE: android_env/wrappers/image_rescale_wrapper_test.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Tests for android_env.wrappers.image_rescale_wrapper.""" from typing import Any from unittest import mock from absl.testing import absltest from android_env import env_interface from android_env.wrappers import image_rescale_wrapper import dm_env from dm_env import specs import numpy as np def _simple_spec(): return specs.BoundedArray( shape=np.array([300, 300, 3]), dtype=np.uint8, name='pixels', minimum=0, maximum=255) def _simple_timestep(): observation = np.ones(shape=[300, 300, 3]) return dm_env.TimeStep( step_type=dm_env.StepType.MID, reward=3.14, discount=0.9, observation={'pixels': observation}) class ImageRescaleWrapperTest(absltest.TestCase): def test_100x50_grayscale(self): fake_timestep = _simple_timestep() fake_env = mock.create_autospec(env_interface.AndroidEnvInterface) fake_env.observation_spec.return_value = {'pixels': _simple_spec()} fake_env.reset.return_value = fake_timestep fake_env.step.return_value = fake_timestep wrapper = image_rescale_wrapper.ImageRescaleWrapper( fake_env, zoom_factors=(1.0 / 3, 1.0 / 6.0), grayscale=True) self.assertIsNotNone(wrapper) self.assertEqual(wrapper.observation_spec()['pixels'].shape, (100, 50, 1)) reset_timestep 
= wrapper.reset() reset_image = reset_timestep.observation['pixels'] self.assertEqual(reset_image.shape, (100, 50, 1)) step_timestep = wrapper.step(action='fake_action') step_image = step_timestep.observation['pixels'] self.assertEqual(step_image.shape, (100, 50, 1)) def test_150x60_full_channels(self): fake_timestep = _simple_timestep() fake_env = mock.create_autospec(env_interface.AndroidEnvInterface) fake_env.observation_spec.return_value = {'pixels': _simple_spec()} fake_env.reset.return_value = fake_timestep fake_env.step.return_value = fake_timestep wrapper = image_rescale_wrapper.ImageRescaleWrapper( fake_env, zoom_factors=(1.0 / 2.0, 1.0 / 5.0)) self.assertIsNotNone(wrapper) self.assertEqual(wrapper.observation_spec()['pixels'].shape, (150, 60, 3)) reset_timestep = wrapper.reset() reset_image = reset_timestep.observation['pixels'] self.assertEqual(reset_image.shape, (150, 60, 3)) step_timestep = wrapper.step(action='fake_action') step_image = step_timestep.observation['pixels'] self.assertEqual(step_image.shape, (150, 60, 3)) def test_list_zoom_factor(self): fake_timestep = _simple_timestep() fake_env = mock.create_autospec(env_interface.AndroidEnvInterface) fake_env.observation_spec.return_value = {'pixels': _simple_spec()} fake_env.reset.return_value = fake_timestep fake_env.step.return_value = fake_timestep wrapper = image_rescale_wrapper.ImageRescaleWrapper( fake_env, zoom_factors=[0.5, 0.2]) self.assertIsNotNone(wrapper) self.assertEqual(wrapper.observation_spec()['pixels'].shape, (150, 60, 3)) reset_timestep = wrapper.reset() reset_image = reset_timestep.observation['pixels'] self.assertEqual(reset_image.shape, (150, 60, 3)) step_timestep = wrapper.step(action='fake_action') step_image = step_timestep.observation['pixels'] self.assertEqual(step_image.shape, (150, 60, 3)) if __name__ == '__main__': absltest.main() ================================================ FILE: android_env/wrappers/last_action_wrapper.py 
class LastActionWrapper(base_wrapper.BaseWrapper):
  """Adds a "last action" marker layer to Android observations.

  The marker is an all-zero (black) layer in which the pixel at the position of
  the most recent TOUCH action is set to 255 (white). For any other action the
  layer stays entirely black.

  Because the wrapper reads the last action stored by the wrapped environment,
  it should be applied on the environment side rather than the agent side. It
  is recommended not to apply it before an ImageRescaleWrapper, so that the
  single marker pixel is not distorted by rescaling.
  """

  def __init__(self,
               env: env_interface.AndroidEnvInterface,
               concat_to_pixels: bool = True) -> None:
    """Sets up the wrapper.

    Args:
      env: The environment to wrap.
      concat_to_pixels: When True the marker layer is appended as an extra
        channel of the `pixels` observation. When False it is exposed as a
        separate `last_action` observation.
    """
    super().__init__(env)
    self._concat_to_pixels = concat_to_pixels
    # First two entries of the pixel spec shape, i.e. (height, width).
    pixels_spec = self._env.observation_spec()['pixels']
    self._screen_dimensions = pixels_spec.shape[:2]

  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    """Returns `timestep` with its observation extended by the marker layer."""
    extended = self._process_observation(timestep.observation.copy())
    return timestep._replace(observation=extended)

  def _process_observation(
      self, observation: dict[str, np.ndarray]
  ) -> dict[str, np.ndarray]:
    """Returns a copy of `observation` that includes the last-action data."""
    result = observation.copy()
    marker_layer = self._get_last_action_layer(observation['pixels'])
    if not self._concat_to_pixels:
      result['last_action'] = marker_layer
      return result
    # Stack the marker as one additional channel behind the existing ones.
    result['pixels'] = np.dstack((observation['pixels'].copy(), marker_layer))
    return result

  def _get_last_action_layer(self, pixels: np.ndarray) -> np.ndarray:
    """Builds the screen-sized layer marking the last touch position.

    Args:
      pixels: The current pixel observation; only its dtype is used.

    Returns:
      An array of shape `self._screen_dimensions` filled with zeros, with a
      single 255 entry at the touched pixel if the last action was a TOUCH.
    """
    last_action = self._env.raw_action
    layer = np.zeros(self._screen_dimensions, dtype=pixels.dtype)
    is_touch = (
        'action_type' in last_action
        and last_action['action_type'] == action_type.ActionType.TOUCH
    )
    if not is_touch:
      return layer
    # `_screen_dimensions` is (height, width); the helper wants (width, height).
    x, y = pixel_fns.touch_position_to_pixel_position(
        last_action['touch_position'],
        width_height=self._screen_dimensions[::-1],
    )
    layer[y, x] = 255
    return layer

  def reset(self) -> dm_env.TimeStep:
    """Resets the wrapped environment and extends the first observation."""
    return self._process_timestep(self._env.reset())

  def step(self, action) -> dm_env.TimeStep:
    """Steps the wrapped environment and extends the resulting observation."""
    return self._process_timestep(self._env.step(action))

  def observation_spec(self) -> dict[str, specs.Array]:
    """Returns the wrapped observation spec adjusted for the marker layer."""
    spec = self._env.observation_spec().copy()
    pixels_spec = cast(specs.BoundedArray, spec['pixels'])
    dims = pixels_spec.shape
    if self._concat_to_pixels:
      # Same spec as the parent pixels, with one more channel.
      spec['pixels'] = specs.BoundedArray(
          shape=(dims[0], dims[1], dims[2] + 1),
          dtype=pixels_spec.dtype,
          name=pixels_spec.name,
          minimum=pixels_spec.minimum,
          maximum=pixels_spec.maximum)
    else:
      spec['last_action'] = specs.BoundedArray(
          shape=(dims[0], dims[1]),
          dtype=pixels_spec.dtype,
          name='last_action',
          minimum=pixels_spec.minimum,
          maximum=pixels_spec.maximum)
    return spec
class LastActionWrapperTest(absltest.TestCase):
  """Checks both output modes of LastActionWrapper against a mocked env."""

  def _make_env_and_wrapper(self, concat_to_pixels):
    """Returns a mocked env and a LastActionWrapper built around it."""
    timestep = _simple_timestep()
    env = mock.create_autospec(env_interface.AndroidEnvInterface)
    env.observation_spec.return_value = {'pixels': _simple_spec()}
    env.reset.return_value = timestep
    env.step.return_value = timestep
    wrapper = last_action_wrapper.LastActionWrapper(
        env, concat_to_pixels=concat_to_pixels)
    return env, wrapper

  def _step(self, env, wrapper, kind, position):
    """Makes `env` report the given action as raw_action, then steps."""
    action = {
        'action_type': kind,
        'touch_position': np.array(position, dtype=np.float32),  # (W x H)
    }
    type(env).raw_action = mock.PropertyMock(return_value=action)
    return wrapper.step(action=action)

  def _assert_marker_at(self, layer, expected_yx):
    """Asserts that `layer` holds exactly one 255 pixel at `expected_yx`."""
    self.assertEqual(np.sum(layer), 255)
    y, x = np.where(layer == 255)
    self.assertEqual((y.item(), x.item()), expected_yx)

  def test_concat_to_pixels(self):
    expected_shape = (120, 80, 4)  # (H x W x C), one channel added.
    env, wrapper = self._make_env_and_wrapper(concat_to_pixels=True)
    self.assertIsNotNone(wrapper)
    self.assertEqual(wrapper.observation_spec()['pixels'].shape, expected_shape)

    # No action has been taken on reset, so the extra channel is all black.
    image = wrapper.reset().observation['pixels']
    self.assertEqual(image.shape, expected_shape)
    self.assertEqual(np.sum(image[:, :, -1]), 0)

    # A TOUCH lights up exactly one pixel.
    image = self._step(
        env, wrapper, action_type.ActionType.TOUCH, [0.25, 0.75]
    ).observation['pixels']
    self.assertEqual(image.shape, expected_shape)
    self._assert_marker_at(image[:, :, -1], (90, 20))

    # A LIFT leaves the channel black.
    image = self._step(
        env, wrapper, action_type.ActionType.LIFT, [0.25, 0.75]
    ).observation['pixels']
    self.assertEqual(image.shape, expected_shape)
    self.assertEqual(np.sum(image[:, :, -1]), 0)

    # A TOUCH at the bottom edge maps to the last row.
    image = self._step(
        env, wrapper, action_type.ActionType.TOUCH, [0.25, 1.0]
    ).observation['pixels']
    self.assertEqual(image.shape, expected_shape)
    self._assert_marker_at(image[:, :, -1], (119, 20))

  def test_no_concat_to_pixels(self):
    expected_shape = (120, 80, 3)  # Pixels are left untouched in this mode.
    env, wrapper = self._make_env_and_wrapper(concat_to_pixels=False)
    self.assertIsNotNone(wrapper)
    self.assertEqual(wrapper.observation_spec()['pixels'].shape, expected_shape)
    self.assertEqual(wrapper.observation_spec()['last_action'].shape, (120, 80))

    # No action has been taken on reset, so the layer is all black.
    observation = wrapper.reset().observation
    self.assertEqual(observation['pixels'].shape, expected_shape)
    self.assertEqual(np.sum(observation['last_action']), 0)

    # A TOUCH lights up exactly one pixel.
    observation = self._step(
        env, wrapper, action_type.ActionType.TOUCH, [0.25, 0.75]
    ).observation
    self.assertEqual(observation['pixels'].shape, expected_shape)
    self._assert_marker_at(observation['last_action'], (90, 20))

    # A LIFT leaves the layer black.
    observation = self._step(
        env, wrapper, action_type.ActionType.LIFT, [0.25, 0.75]
    ).observation
    self.assertEqual(observation['pixels'].shape, expected_shape)
    self.assertEqual(np.sum(observation['last_action']), 0)

    # A TOUCH at the right edge maps to the last column.
    observation = self._step(
        env, wrapper, action_type.ActionType.TOUCH, [1.0, 0.75]
    ).observation
    self.assertEqual(observation['pixels'].shape, expected_shape)
    self._assert_marker_at(observation['last_action'], (90, 79))
class RateLimitWrapper(base_wrapper.BaseWrapper):
  """Limits interactions with the environment to a given rate."""

  class SleepType(enum.IntEnum):
    """Determines how the wrapper interacts with the underlying environment."""

    # The wrapper sleeps before calling `step()` on the underlying environment.
    BEFORE = 0

    # The wrapper sleeps after calling `step()` on the underlying environment.
    AFTER = 1

    # The wrapper first calls `step()`, obtaining a TimeStep which is ignored,
    # then it sleeps, and then it calls `step(REPEAT)` to obtain a TimeStep
    # that's as fresh as possible.
    #
    # Note that for both BEFORE and AFTER_WITH_REPEAT, the _total_ amount of
    # time inside this wrapper may go beyond the rate specified in `rate`
    # because the sleep does not account for the time taken by step().
    AFTER_WITH_REPEAT = 2

  def __init__(self,
               env: env_interface.AndroidEnvInterface,
               rate: float,
               sleep_type: SleepType = SleepType.AFTER_WITH_REPEAT):
    """Initializes this wrapper.

    Args:
      env: The underlying environment to which this wrapper is applied.
      rate: The desired rate in Hz to interact with the environment. If <=0.0,
        this wrapper will be disabled.
      sleep_type: This determines how the wrapper will interact with the
        underlying AndroidEnv environment.
    """
    super().__init__(env)
    self._assert_base_env()
    # Wall-clock time of the last reset()/completed step(); None until reset().
    self._last_step_time = None
    # Target period between interactions in seconds; 0.0 disables the wrapper.
    self._max_wait = 1.0 / rate if rate > 0.0 else 0.0
    self._sleep_type = sleep_type

  def _assert_base_env(self):
    """Checks that the wrapped env has the right action spec format."""
    parent_action_spec = self._env.action_spec()
    assert len(parent_action_spec) == 2
    assert not parent_action_spec['action_type'].shape
    assert parent_action_spec['touch_position'].shape == (2,)

  def reset(self):
    """Resets the wrapped environment and restarts the rate-limit clock."""
    timestep = self._env.reset()
    self._last_step_time = time.time()
    return timestep

  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Takes a step while maintaining a steady interaction rate.

    Args:
      action: The agent's action. It is forwarded unchanged to the wrapped
        environment and is never modified by this method.

    Returns:
      The TimeStep from the wrapped environment. With `AFTER_WITH_REPEAT` this
      is the TimeStep of the second (REPEAT) step, carrying the summed reward
      of both steps. If the first step ends the episode it is returned right
      away, without sleeping or repeating.
    """
    # If max_wait is non-positive, the wrapper has no effect.
    if self._max_wait <= 0.0:
      return self._env.step(action)

    if self._sleep_type == RateLimitWrapper.SleepType.BEFORE:
      self._wait()

    timestep = self._env.step(action)
    if timestep.last():
      return timestep

    if self._sleep_type == RateLimitWrapper.SleepType.AFTER_WITH_REPEAT:
      # Build a separate REPEAT action instead of overwriting entries of the
      # caller's dictionary, which the caller may still hold on to.
      repeat_action = {
          key: (np.array(action_type.ActionType.REPEAT, dtype=np.uint8)
                if key.startswith('action_type') else value)
          for key, value in action.items()
      }
      self._wait()
      first_reward = timestep.reward or 0.0
      timestep = self._env.step(repeat_action)
      second_reward = timestep.reward or 0.0
      # Accumulate rewards over the two steps taken.
      timestep = timestep._replace(reward=first_reward + second_reward)
    elif self._sleep_type == RateLimitWrapper.SleepType.AFTER:
      self._wait()

    self._last_step_time = time.time()
    return timestep

  def _wait(self) -> None:
    """Sleeps for whatever remains of the period since the last step."""
    if self._max_wait > 0.0 and self._last_step_time is not None:
      time_since_step = time.time() - self._last_step_time
      sec_to_wait = self._max_wait - time_since_step
      if sec_to_wait > 0.0:
        time.sleep(sec_to_wait)
class RateLimitWrapperTest(parameterized.TestCase):
  """Tests for RateLimitWrapper with `time.sleep` replaced by a mock.

  All enabled-mode tests use rate=1/33.33, i.e. a period of 33.33 seconds, so
  every requested sleep duration is expected to fall within [0, 33.33].
  """

  @parameterized.named_parameters(
      ('zero_rate', 0),
      ('negative_rate', -50),
  )
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_disabled(self, rate, mock_sleep):
    """With a non-positive rate, this wrapper should do nothing."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(env, rate=rate)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()
    _ = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })
    mock_sleep.assert_not_called()
    # When the wrapper is disabled, base step should only be called once.
    env.step.assert_called_once()

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled(self, mock_sleep):
    """When enabled, the wrapper should sleep for a period in [0, 1/rate]."""
    env = _get_base_env()
    env.step.return_value = dm_env.transition(reward=None, observation=None)
    wrapper = rate_limit_wrapper.RateLimitWrapper(env, rate=1/33.33)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    # Step for 100 steps.
    for _ in range(100):
      _ = wrapper.step({
          'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
          'touch_position': np.array([0.123, 0.456])
      })

    # Check that there are 100 calls and that they're all within [0, 1/rate].
    self.assertLen(mock_sleep.call_args_list, 100)
    for call in mock_sleep.call_args_list:
      args, unused_kwargs = call
      sleep_time = args[0]
      self.assertBetween(sleep_time, 0.0, 33.33)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled_sleep_type_before(self, mock_sleep):
    """When sleep_type==BEFORE, sleep should come before step()."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(
        env,
        rate=1/33.33,
        sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType.BEFORE)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    # Stand-in for time.sleep: records when it was called instead of sleeping.
    @_with_timestamp
    def _sleep_fn(sleep_time):
      _sleep_fn.timestamp = time.time()
      self.assertBetween(sleep_time, 0.0, 33.33)

    mock_sleep.side_effect = _sleep_fn

    # Stand-in for env.step: checks the action and records the call time.
    def _step_fn(action):
      self.assertEqual(
          action['action_type'],
          np.array(action_type.ActionType.LIFT, dtype=np.uint8))
      _step_fn.timestamps.append(time.time())
      return dm_env.transition(reward=None, observation=None)

    _step_fn.timestamps = []
    env.step.side_effect = _step_fn

    _ = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })

    self.assertLen(_step_fn.timestamps, 1)
    # We expect sleep to have been executed BEFORE a single `step()`.
    self.assertGreaterEqual(_step_fn.timestamps[0], _sleep_fn.timestamp)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled_sleep_type_after(self, mock_sleep):
    """When sleep_type==AFTER, sleep should come after step()."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(
        env,
        rate=1/33.33,
        sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType.AFTER)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    # Stand-in for time.sleep: records when it was called instead of sleeping.
    @_with_timestamp
    def _sleep_fn(sleep_time):
      _sleep_fn.timestamp = time.time()
      self.assertBetween(sleep_time, 0.0, 33.33)

    mock_sleep.side_effect = _sleep_fn

    # Stand-in for env.step: checks the action and records the call time.
    def _step_fn(action):
      self.assertEqual(
          action['action_type'],
          np.array(action_type.ActionType.LIFT, dtype=np.uint8))
      _step_fn.timestamps.append(time.time())
      return dm_env.transition(reward=None, observation=None)

    _step_fn.timestamps = []
    env.step.side_effect = _step_fn

    _ = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })

    # We expect sleep to have been executed AFTER a single `step()`.
    self.assertLen(_step_fn.timestamps, 1)
    self.assertLessEqual(_step_fn.timestamps[0], _sleep_fn.timestamp)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled_sleep_type_after_with_repeat(self, mock_sleep):
    """When sleep_type==AFTER_WITH_REPEAT, sleep should be between 2 steps()."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(
        env,
        rate=1/33.33,
        sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType
        .AFTER_WITH_REPEAT)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    # Stand-in for time.sleep: records when it was called instead of sleeping.
    @_with_timestamp
    def _sleep_fn(sleep_time):
      _sleep_fn.timestamp = time.time()
      self.assertBetween(sleep_time, 0.0, 33.33)

    mock_sleep.side_effect = _sleep_fn

    # Stand-in for env.step: the number of timestamps recorded so far tells
    # whether this is the first (agent) call or the second (REPEAT) call.
    @_with_timestamp
    def _step_fn(action):
      # On even calls the action should be the actual agent action, but on odd
      # calls they should be REPEATs.
      if len(_step_fn.timestamps) % 2 == 0:
        self.assertEqual(
            action['action_type'],
            np.array(action_type.ActionType.LIFT, dtype=np.uint8))
      else:
        self.assertEqual(
            action['action_type'],
            np.array(action_type.ActionType.REPEAT, dtype=np.uint8))
      _step_fn.timestamps.append(time.time())
      return dm_env.transition(reward=1.0, observation=None)

    _step_fn.timestamps = []
    env.step.side_effect = _step_fn

    timestep = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })

    # When the wrapper is enabled, base step should be called twice.
    self.assertEqual(env.step.call_count, 2)

    # `step()` should be called twice: before `sleep()` and after it.
    self.assertLen(_step_fn.timestamps, 2)
    self.assertGreaterEqual(_sleep_fn.timestamp, _step_fn.timestamps[0])
    self.assertLessEqual(_sleep_fn.timestamp, _step_fn.timestamps[1])

    # Rewards should accumulate over the two step() calls
    self.assertEqual(timestep.reward, 2.0)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled_sleep_type_after_with_repeat_last(self, mock_sleep):
    """If the first step is a LAST, second step should not be taken."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(
        env,
        rate=1/33.33,
        sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType
        .AFTER_WITH_REPEAT)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().
    # The wrapped env ends the episode on the very first step.
    env.step.return_value = dm_env.termination(reward=None, observation=None)

    _ = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })

    # Second step call should be skipped.
    env.step.assert_called_once()
    mock_sleep.assert_not_called()
class TapActionWrapper(base_wrapper.BaseWrapper):
  """AndroidEnv with tap actions.

  Every agent action is expanded into `num_frames + 1` steps of the wrapped
  environment. A TOUCH becomes `num_frames` TOUCH steps followed by one LIFT
  at the same position; any other action is simply repeated.
  """

  def __init__(self,
               env: env_interface.AndroidEnvInterface,
               num_frames: int = 5,
               touch_only: bool = False) -> None:
    """Initializes this wrapper.

    Args:
      env: The environment to wrap. Its action spec must contain
        `action_type`.
      num_frames: Number of steps for which a touch is held before the LIFT.
      touch_only: If True, the agent-facing action spec has a single
        `action_type` value (0) and every action is executed as a tap.
    """
    super().__init__(env)
    assert 'action_type' in env.action_spec()
    self._touch_only = touch_only
    self._num_frames = num_frames
    # Number of steps actually taken in the wrapped environment.
    self._env_steps = 0

  def stats(self) -> dict[str, Any]:
    """Returns a dictionary of metrics logged by the environment."""
    # Work on a copy so the wrapped environment's own dictionary is not
    # modified as a side effect of reading the stats.
    logs = dict(self._env.stats())
    logs['env_steps'] = self._env_steps
    return logs

  def _with_action_type(
      self, action: dict[str, np.ndarray], new_type: action_type.ActionType
  ) -> dict[str, np.ndarray]:
    """Returns a shallow copy of `action` with `action_type` replaced."""
    updated = action.copy()
    updated['action_type'] = np.array(new_type).astype(
        self.action_spec()['action_type'].dtype
    )
    return updated

  def _process_action(
      self, action: dict[str, np.ndarray]
  ) -> Sequence[dict[str, np.ndarray]]:
    """Expands one agent action into the sequence of low-level actions.

    Args:
      action: The agent's action.

    Returns:
      A list of `num_frames + 1` actions: `num_frames` touches plus a final
      LIFT for taps, or the unchanged action repeated for everything else.
    """
    if self._touch_only:
      # The only valid agent action in touch-only mode is index 0 (see
      # `action_spec()`); it is translated into a real TOUCH.
      assert action['action_type'] == 0
      touch_action = self._with_action_type(
          action, action_type.ActionType.TOUCH
      )
    elif action['action_type'] == action_type.ActionType.TOUCH:
      touch_action = action
    else:
      return [action] * (self._num_frames + 1)

    actions = [touch_action] * self._num_frames
    actions.append(self._with_action_type(action, action_type.ActionType.LIFT))
    return actions

  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Takes a step in the environment.

    Args:
      action: The agent's action, expanded by `_process_action()`.

    Returns:
      The last TimeStep obtained from the wrapped environment, with the reward
      summed over all low-level steps taken. The sequence is cut short as soon
      as the wrapped environment returns a LAST step.
    """
    actions = self._process_action(action)
    total_reward = 0.0
    for low_level_action in actions:
      step_type, reward, discount, observation = self._env.step(
          low_level_action
      )
      # Count only steps that were really executed, so `env_steps` stays
      # accurate when the episode terminates in the middle of a tap.
      self._env_steps += 1
      if reward:
        total_reward += reward
      if step_type == dm_env.StepType.LAST:
        break
    return dm_env.TimeStep(
        step_type=step_type,
        reward=total_reward,
        discount=discount,
        observation=observation)

  def action_spec(self) -> dict[str, dm_env.specs.Array]:
    """Returns the action spec exposed to the agent."""
    if self._touch_only:
      return {
          'action_type':
              dm_env.specs.DiscreteArray(num_values=1, name='action_type'),
          'touch_position':
              self._env.action_spec()['touch_position'],
      }
    else:
      return self._env.action_spec()
class TapActionWrapperTest(absltest.TestCase):
  """Tests for TapActionWrapper built around a mocked base environment."""

  def setUp(self):
    super().setUp()
    action_type_spec = specs.DiscreteArray(num_values=3, name='action_type')
    touch_position_spec = _make_array_spec(
        shape=(2,), dtype=np.float32, name='touch_position')
    self._base_action_spec = {
        'action_type': action_type_spec,
        'touch_position': touch_position_spec,
    }
    self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    self.base_env.action_spec.return_value = self._base_action_spec

  def _wrap(self, num_frames):
    """Returns a TapActionWrapper around the mocked base environment."""
    return tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=num_frames)

  def _centre_action(self, kind):
    """Returns an action of the given type at the centre of the screen."""
    return {
        'action_type': np.array(kind, dtype=np.int32),
        'touch_position': np.array([0.5, 0.5], dtype=np.float32),
    }

  def test_process_action_repeat(self):
    wrapper = self._wrap(3)
    repeat = self._centre_action(action_type.ActionType.REPEAT)
    expanded = wrapper._process_action(repeat)
    self.assertLen(expanded, wrapper._num_frames + 1)
    self.assertEqual(repeat, expanded[-1])

  def test_process_action_lift(self):
    wrapper = self._wrap(3)
    lift = self._centre_action(action_type.ActionType.LIFT)
    expanded = wrapper._process_action(lift)
    self.assertLen(expanded, wrapper._num_frames + 1)
    self.assertEqual(lift, expanded[-1])

  def test_process_action_touch(self):
    wrapper = self._wrap(3)
    touch = self._centre_action(action_type.ActionType.TOUCH)
    expanded = wrapper._process_action(touch)
    self.assertLen(expanded, wrapper._num_frames + 1)
    # A touch sequence always ends with a LIFT.
    self.assertEqual(
        expanded[-1]['action_type'], np.array(action_type.ActionType.LIFT)
    )

  def test_reset(self):
    wrapper = self._wrap(5)
    expected = 'ts'
    self.base_env.reset.return_value = expected
    actual = wrapper.reset()
    self.base_env.reset.assert_called_once()
    self.assertEqual(expected, actual)

  def test_step(self):
    # Arrange.
    wrapper = self._wrap(5)
    self.base_env.step.return_value = dm_env.TimeStep(
        step_type='fake_type',
        reward=0.0,
        discount=1.0,
        observation='fake_obs')
    self.base_env.stats.return_value = {}

    # Act.
    result = wrapper.step(self._centre_action(action_type.ActionType.REPEAT))
    reported = wrapper.stats()

    # Assert.
    self.assertEqual(wrapper._num_frames+1, self.base_env.step.call_count)
    self.assertIsInstance(result, dm_env.TimeStep)
    self.assertIsInstance(reported, dict)
    self.assertIn('env_steps', reported)
    self.assertEqual(reported['env_steps'], 6)

  def test_observation_spec(self):
    wrapper = self._wrap(5)
    expected = 'fake_obs_spec'
    self.base_env.observation_spec.return_value = expected
    actual = wrapper.observation_spec()
    self.base_env.observation_spec.assert_called_once()
    self.assertEqual(expected, actual)

  def test_action_spec(self):
    wrapper = self._wrap(5)
    self.base_env.action_spec.return_value = self._base_action_spec
    actual = wrapper.action_spec()
    self.base_env.action_spec.assert_called()
    self.assertEqual(self.base_env.action_spec(), actual)

  def test_stats(self):
    """Checks that returned stats have expected properties."""
    # Arrange.
    base_stats = {
        'some_key': 12345,
        'another_key': 5.4321,
    }
    self.base_env.stats.return_value = base_stats
    wrapper = self._wrap(5)

    # Act.
    reported = wrapper.stats()

    # Assert.
    self.assertIsInstance(reported, dict)
    # Original entries should still be present.
    self.assertIn('some_key', reported)
    self.assertEqual(reported['some_key'], 12345)
    self.assertIn('another_key', reported)
    self.assertEqual(reported['another_key'], 5.4321)
    # TapActionWrapper inserts its own `env_steps`.
    self.assertIn('env_steps', reported)
    self.assertEqual(reported['env_steps'], 0)
You can select the model or choose from more advanced settings (refer to the [Android docs](https://developer.android.com/studio/run/managing-avds) for step-by-step instructions). ![Screenshot of 'android_studio_4'](images/android_studio/android_studio_4.png) Name your AVD and take note of this value. It will be necessary for connecting AndroidEnv to this virtual device. ![Screenshot of 'android_studio_6'](images/android_studio/android_studio_6.png) Once you are done, you will see the new AVD show up in the **AVD Manager**. Click on **View details** to inspect some of its properties. ![Screenshot of 'android_studio_8'](images/android_studio/android_studio_8.png) Take note of the `AVD Path`. This value will be necessary for connecting AndroidEnv to this device. We recommend that you set this to be in your home directory (for instance, on Linux or macOS it may be `~/.android/avd`). ![Screenshot of 'android_studio_9'](images/android_studio/android_studio_9.png) ## Ready to use With SDK and AVD both set up, you are now ready to use this emulated device with AndroidEnv. Don't forget to take note of the following three values: your AVD name, the AVD path, and the SDK path. For example, on Linux they may be: ``` --avd_name=my_avd --avd_package_path=~/.android/avd --android_sdk_root=~/Android/Sdk ``` Next, once you have set up the AVD, follow the [Task steps](instructions.md#the-task) in the [Running the environment guide](instructions.md) to finish setting up AndroidEnv. However, if you want to just interact with the newly created device, click on the run button next to your AVD in the **AVD Manager** in Android Studio (this step is optional). ![Screenshot of 'android_studio_7'](images/android_studio/android_studio_7.png) You will see an emulator window pop up. You can interact with it by clicking on the screen.
![Screenshot of 'android_studio_10'](images/android_studio/android_studio_10.png) There are many other features in Android Studio that let you customize your device. For example, you can create custom images with pre-installed applications or configured settings. ================================================ FILE: docs/environment.md ================================================ # AndroidEnv - Environment features AndroidEnv is a complex environment that, while offering an almost endless range of possibilities for RL research and investigation, poses multiple kinds of challenges simultaneously. In this document we outline AndroidEnv's main features that render it such a unique learning environment. ## Real-time environment AndroidEnv is built on top of an emulated Android device, allowing the agent to communicate with the emulator through touch actions. Android emulators are created independently from our environment implementation and simulate real Android devices in the most realistic manner. This simulation runs in real time, independently of the agent, meaning that the simulation will not wait for agent input between frames. This aspect of the environment renders the RL setup similar to a robotics problem, where the challenges of real-time interaction and consequent noise in observations have to be overcome. Please note there is currently no straightforward way to slow down the simulation either. ## Action space Perhaps one of the most interesting features of AndroidEnv is its large and complex action interface. The raw action space of the environment consists of a tuple `(x,y) in [0,1]x[0,1]` determining the location of the action on the Android screen, and a discrete value `ActionType in {LIFT, TOUCH, REPEAT}` indicating whether the agent wants to touch the screen at this chosen location or not. This action space is uniform across all tasks/apps. 
**Gestures.** The complexity of the interface arises from the fact that individual raw actions on their own do not necessarily trigger a meaningful change in the environment. Most Android applications are designed such that they can be controlled/navigated through common touchscreen gestures such as pressing buttons, swiping, scrolling, pinching, drag and drop etc. Each of these can be thought of as particular sequences of raw actions: for example, *touching* the screen at a particular location, then immediately *lifting* the imaginary finger might be interpreted as a *press of a button*; while a sequence of *touches* aligned in a vertical line might be interpreted as *scrolling*. We note that AndroidEnv does not support multitouch actions at the moment, but it is a possible feature to add. It is important to point out that it is out of the environment's control to determine how particular sequences of raw actions get interpreted by the Android simulator - much like when humans interact with physical devices, a certain gesture on the screen might be interpreted differently if it is performed at a slightly different angle or speed. 
Tap | Double Tap | Touch & Hold | Flick Left | Flick Right | Scroll (H) | Scroll (V) | Drag & Drop -------------------------------------------------------- | ---------------------------------------------------------------------- | ------------------------------------------------------------------------ | ---------------------------------------------------------------------- | ------------------------------------------------------------------------ | ------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------- | ----------- ![Screenshot of 'tap'](images/gestures/1-Finger-Tap.gif) | ![Screenshot of 'double_tap'](images/gestures/1-Finger-Double-Tap.gif) | ![Screenshot of 'touch_hold'](images/gestures/1-Finger-Touch-&-Hold.gif) | ![Screenshot of 'flick_left'](images/gestures/1-Finger-Flick-Left.gif) | ![Screenshot of 'flick_right'](images/gestures/1-Finger-Flick-Right.gif) | ![Screenshot of 'horizontal_scroll'](images/gestures/1-Finger-Horizontal-Scroll.gif) | ![Screenshot of 'vertical_scroll'](images/gestures/1-Finger-Vertical-Scroll.gif) | ![Screenshot of 'move'](images/gestures/1-Finger-Move.gif) **Wrappers.** It is possible to alter the raw action space of the environment by applying [wrappers](#wrappers). For example one might discretize the action space by splitting the screen up into a grid of a desired size; restrict the ActionType to *touch* only; or fix certain gesture skills. We note here that these wrappers, again, will not alter how the particular sequence of performed raw actions gets interpreted by the Android simulator. ## Observation space The observation space of AndroidEnv consists of three main components: (`pixels`, `timedelta`, `orientation`), the most notable of these being `pixels`. 
The original screen size will depend on the type of emulator used, but given that it will correspond to real device screen sizes, this will usually be quite large (of course, this can be scaled down, e.g. with wrappers). The `timedelta` component captures the amount of time passed since the last observation was fetched. The `orientation`, even though it does not affect the layout of the RGB image in the observation, might carry relevant information for the agent. For example, if there is text on the screen, it is important to know how it is oriented. Again, a benefit of this observation space is that it is uniform across all tasks. As mentioned above, observations often carry spatial cues and are suggestive of the kind of actions/gestures that are meaningful to perform in a given state. ## Task extras On top of the default observations (`pixels`, `timedelta`, `orientation`), some tasks might expose additional structured observations after each step. An *extra* in AndroidEnv is any information that an app may send to aid the understanding of the task. The type of information sent through this channel is usually something difficult to obtain from raw pixels and may include meaningful information such as: * The current board configuration (e.g. of a chess game or of a tetris game) in matrix or string form. * The position of the avatar in a map. * Events such as whether a button was pressed or whether a checkpoint was achieved. Note that these are entirely optional and may not be available at all. To request extras from the environment, you can call `env.task_extras()` after each `env.step()`, which will return a dictionary of all the extra observations observed during the previous step (or an empty dict if there are none available). 
For example: ```python for _ in range(FLAGS.n_steps): action = agent.select_action(timestep.observation) timestep = env.step(action) logging.info('observation: %s', timestep.observation) logging.info('extra observations: %s', env.task_extras()) ``` Please note however that the env might not return extras at every timestep, only when something meaningful happened (e.g. only when a button was pressed, or when the state of the board has changed). When integrating your own APK as a new task for the environment, you can define your own extras by following the instructions [here](tasks_guide.md#log-messages-and-custom-apks). ## Wrappers AndroidEnv's action- and observation spaces can be altered by applying suitable wrappers. While custom wrappers can be built easily, we have provided a number of useful wrappers that demonstrate their usage: * `discrete_action_wrapper`: Discretizes the action space into an `nxk` grid. * `flat_interface_wrapper`: Removes the dictionary structure from the observation and action specs. * `float_pixels_wrapper`: Projects the pixel RGB values from the integer range `[0, 255]` to the float range `[0, 1]`. * `image_rescale_wrapper`: Resizes the pixel observations by the selected ratio. * `gym_wrapper`: Changes the environment interface from [dm_env](https://github.com/deepmind/dm_env) to [OpenAI](https://gym.openai.com/) gym interface. * `last_action_wrapper`: Extends the observation with a one-hot encoded location of the previously taken action, in order to aid agents without built-in memory. ## Internal structure of AndroidEnv The chart below gives an overview of the internal workings of the system, illustrating how different classes interact with each other and what their individual roles are. See the source code for more details. 
![Components Chart](images/misc/components_chart.svg) ================================================ FILE: docs/example_tasks.md ================================================ # AndroidEnv - Available tasks This page gives a detailed overview of the example tasks provided with AndroidEnv. The purpose is to give researchers an idea of the different kinds of challenges that AndroidEnv poses. To use any of these tasks in your own experiments, click on **Download** to download a ZIP file containing textprotos and the corresponding APKs. After downloading, move the `.apk` and `.textproto` files to a directory of your choice and take note of their path. This information is needed for [running](instructions.md#create-the-env) an AndroidEnv instance with the given task. | App / Game | Interface | Time reactive | Multi-level | Rewards | Extras | Download | | --------------------------------------------------- | -------------------- | -------------- | ----------------------------- | ---------- | ------------ | -------- | | [Vokram (MDP)](#vokram) | Tapping (buttons) | No | Yes (4 variants) | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/mdp.tar.gz) | | [Apple Flinger](#apple-flinger) | Drag & drop | No | Yes (6 variants) | Dense | No | [Download](https://storage.googleapis.com/android_env-tasks/apple_flinger.tar.gz) | | [Blockinger](#blockinger) | Tapping (buttons) | Yes | No | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/blockinger.tar.gz) | | [Catch](#catch) | Touch | Yes | No | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/catch_the_ball.tar.gz) | | [Classic 2048](#classic-2048) | Swiping | No | No | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/classic_2048.tar.gz) | | [Dodge](#dodge) | Tapping | Yes | No | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/dodge.tar.gz) | | [DroidFish (Chess)](#droidfish) | Tapping | No | Yes (3 
levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/droidfish.tar.gz) | | [FlappyDroid](#flappydroid) | Tapping | Yes | Yes (2 levels) | Dense | No | [Download](https://storage.googleapis.com/android_env-tasks/systemui_egg_land.tar.gz) | | [FloodIt](#floodit) | Tapping (buttons) | No | Yes (4 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/floodit.tar.gz) | | [Frozen Bubble](#frozen-bubble) | Dragging, tapping | No | No | Sparse | No | [Download](https://storage.googleapis.com/android_env-tasks/frozen_bubble.tar.gz) | | [Memory Game](#memory-game) | Tapping | No | Yes (6 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/memory_game.tar.gz) | | [Minesweeper](#minesweeper) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/minesweeper.tar.gz) | | [Nostalgic Racer](#nostalgic-racer) | Touch | Yes | Yes (2 variants) | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/nostalgic_racer.tar.gz) | | [Open Sudoku](#open-sudoku) | Tapping (buttons) | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/open_sudoku.tar.gz) | | [Perfection](#perfection) | Drag & drop | No | Yes (3 game types) | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/perfection.tar.gz) | | [Rocket Sleigh](#rocket-sleigh) | Tapping | Yes | No | Dense | No | [Download](https://storage.googleapis.com/android_env-tasks/rocket_sleigh.tar.gz) | | [Pong](#pong) | Drag | Yes | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/pong.tar.gz) | | [SGT Puzzles - Blackbox](#sgt-puzzles-blackbox) | Tapping | No | Yes (4 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Bridge](#sgt-puzzles-bridge) | Drag & drop | No | Yes (5 levels) | Sparse | Yes | 
[Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Cube](#sgt-puzzles-cube) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Dominosa](#sgt-puzzles-dominosa) | Tapping | No | Yes (5 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Fifteen](#sgt-puzzles-fifteen) | Tapping | No | Yes (4 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Flip](#sgt-puzzles-flip) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Flood](#sgt-puzzles-flood) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Galaxies](#sgt-puzzles-galaxies) | Tapping | No | Yes (6 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Guess](#sgt-puzzles-guess) | Tapping | No | Yes (4 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Inertia](#sgt-puzzles-inertia) | Tapping | No | Yes (2 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Light Up](#sgt-puzzles-light-up) | Tapping | No | Yes (5 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Loopy](#sgt-puzzles-loopy) | Tapping | No | Yes (3 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [SGT Puzzles - Net](#sgt-puzzles-net) | Tapping | No | Yes (5 sizes) | Sparse | Yes | 
[Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) | | [Shattered Pixel Dungeon](#shattered-pixel-dungeon) | Tapping | Yes | Yes (4 variants) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/shattered_pixel_dungeon.tar.gz) | | [Simple Solitaire](#simple-solitaire) | Drag & drop | No | Yes (19 tasks) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/simple_solitaire.tar.gz) | | [Snake](#snake) | Tapping (buttons) | Yes | No | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/aosp_samples_snake.tar.gz) | | [Vector Pinball](#vector-pinball) | Tapping | Yes | Yes (5 variants) | Sparse | No | [Download](https://storage.googleapis.com/android_env-tasks/vector_pinball.tar.gz) | ## Vokram Vokram is our in-house implementation of an Android app that displays a Markov-Decision-Process (MDP) graph as buttons on the screen which the agent must use to select its actions. The observation is simply the color of the background, and the actions are the buttons themselves which are presented in different colors. * **mdp_0000**: This is a task that presents the agent with two colored, but unlabeled buttons on the screen. Pressing one of the buttons gives the agent a reward of `-1` and redraws the buttons on the screen. The other button gives a reward of `+1` and terminates the episode. The color of the buttons is the same throughout the episode. The sizes of the buttons are randomized at each screen draw. Pressing anywhere else on the screen gives a reward of zero. The task lasts up to 60 seconds, at which point the episode is restarted. The underlying dynamics governing the buttons is a simple 2-state 2-action Markov Decision Process (MDP). The MDP is an intentionally simple environment that can be used to debug agents. * **mdp_0001**: This is similar to `mdp_0000` but it's even simpler. 
It presents the agent with a single button which gives a reward of `+1` and terminates the episode when pressed. This task can be used for example to train agents to click buttons. * **mdp_0002**: In this task there are two buttons, pressing either of which will terminate the episode with a return of `+1`. * **mdp_0003**: An equivalent of `mdp_0000` with rewards reversed: the episode ends when the wrong button is clicked, and carries on with a new set of buttons when the correct one is clicked.
Extras returned * `actions`: - Set of all buttons present, e.g. `['A', 'B']`. - Returned when any button is pressed. - Has `shape=[2], dtype=STRING_U1`. * `clicks`: - Character representing the button pressed. - Returned when any button is pressed. - Has `shape=[1], dtype=STRING_U1`. * `buttons`: - Coordinates of the top left and bottom right corners of each button, e.g. `[[x_a_0, y_a_0, x_a_1, y_a_1], [x_b_0, y_b_0, x_b_1, y_b_1]]`. - Returned when any button is pressed. - Has `shape=[2, 4], dtype=INT32`.
**mdp_0000** | **mdp_0001** | **mdp_0002** | **mdp_0003** ------------------------------------------------ | ------------------------------------------------ | ------------------------------------------------ | ------------ ![Screenshot of 'mdp_0000'](images/mdp_0000.gif) | ![Screenshot of 'mdp_0001'](images/mdp_0001.gif) | ![Screenshot of 'mdp_0002'](images/mdp_0002.gif) | ![Screenshot of 'mdp_0003'](images/mdp_0003.gif) ## Apple Flinger A clone of Angry Birds. Even though the game offers many levels, we currently expose six levels. See the original github repo for more info: https://gitlab.com/ar-/apple-flinger.
Extras returned Returns no extras.
**apple_flinger_M_1_1** | **apple_flinger_M_1_2** | **apple_flinger_M_1_18** ---------------------------------------------------------------------- | ---------------------------------------------------------------------- | ------------------------ ![Screenshot of 'apple_flinger_M_1_1'](images/apple_flinger_M_1_1.gif) | ![Screenshot of 'apple_flinger_M_1_2'](images/apple_flinger_M_1_2.gif) | ![Screenshot of 'apple_flinger_M_1_18'](images/apple_flinger_M_1_18.gif) **apple_flinger_M_2_1** | **apple_flinger_M_2_2** | **apple_flinger_M_2_18** ----------------------------------------------------------------------- | ---------------------------------------------------------------------- | ------------------------ ![Screenshot of 'apple_flinger_M_2_1'](images/apple_flinger_M_2_1.gif) | ![Screenshot of 'apple_flinger_M_2_2'](images/apple_flinger_M_2_2.gif) | ![Screenshot of 'apple_flinger_M_2_18'](images/apple_flinger_M_2_18.gif) ## Blockinger This is a Tetris clone implemented with on-screen controls. See the original github repo for more info: https://github.com/tasioleiva/Blockinger.git.
Extras returned * `down_pressed`, `left_pressed`, `right_pressed`, `rotate_right_pressed`, `drop_pressed`: - Indicates that said button has been pressed. - Returned when said button has been pressed. - Has `shape=[1], dtype=INT32`. * `current_board`: - One-hot encoded state of the board. - Has `shape=[18, 10], dtype=INT32`. * `current_line`, `cleared_lines`: - Index of the relevant line. - Has `shape=[1], dtype=INT32`. * `current_piece`, `next_piece`: - Index representing the type of piece. - Has `shape=[1], dtype=INT32`.
![Screenshot of 'blockinger'](images/blockinger.gif) ## Catch Classic Catch game.
Extras returned * `ball`: - `x, y` coordinates of the ball. - Returned every timestep. - Has `shape=[2], dtype=INT32`. * `paddle`: - `x, y` coordinates of the paddle. - Returned every timestep. - Has `shape=[2], dtype=INT32`. * `paddle_width`: - Width of the paddle. - Returned every timestep. - Has `shape=[1], dtype=INT32`. * `lives`: - Number of lives left. - Returned every timestep. - Has `shape=[1], dtype=INT32`.
![Screenshot of 'catch_the_ball_default'](images/catch_the_ball_default.gif) ## Classic 2048 This is an Android implementation of a popular game in the 2010s. See the original github repo for more info: https://github.com/tpcstld/2048.git.
Extras returned * `grid`: - State of the board. - Returned when the board changes. - Has `shape=[4, 4], dtype=INT32`. * `direction`: - Index representing the direction of the last swipe (between 0-3). - Returned when the swipe prompted a board change. - Has `shape=[1], dtype=INT32`.
![Screenshot of 'classic_2048'](images/classic_2048.gif) ## Dodge Guide the ball from the red line to the green line without getting hit by the floating dots.
Extras returned * `lives`: - Number of lives left. - Returned when its value changes. - Has `shape=[1], dtype=INT32`. * `level`: - Current level. - Returned when its value changes. - Has `shape=[1], dtype=INT32`.
![Screenshot of 'dodge_default'](images/dodge_default.gif) ## DroidFish Standard chess game. You can choose whether to play as a specific player (black/white), or have the player colour randomly assigned at the beginning of each episode. The numbers 1, 10 and 100 indicate the level of difficulty. Take a look at a few sample moves below to get an idea of roughly how well the bot plays for each level of difficulty. You can see that the 1% and 10% bots often make very obvious mistakes. See the original github repo for more info: https://github.com/peterosterlund2/droidfish.
Extras returned * `board`: - State of the board, representing pieces by indices. No piece - 0 - White pieces - 1: king, 2: queen, 3: rook, 4: bishop, 5: knight, 6: pawn - Black pieces - 7: king, 8: queen, 9: rook, 10: bishop, 11: knight, 12: pawn - Returned when the board changes. - Has `shape=[8, 8], dtype=INT32`. * `selection`: - Coordinate of selected piece (between 0-64, -1 if selection is removed) - Returned when a piece is selected (or unselected). - Has `shape=[1], dtype=INT32`. * `moved`: - Coordinates "from" and "to" cells when a piece is moved (between 0-64) - Returned when a piece is moved. - Has `shape=[2], dtype=INT32`. * `invalid`: - Coordinates "from" and "to" cells of an invalid move attempt (between 0-64) - Returned upon invalid move request - Has `shape=[2], dtype=INT32`.
**droidfish_black_1** | **droidfish_black_10** | **droidfish_black_100** ------------------------------------------------------------------ | -------------------------------------------------------------------- | ----------------------- ![Screenshot of 'droidfish_black_1'](images/droidfish_black_1.gif) | ![Screenshot of 'droidfish_black_10'](images/droidfish_black_10.gif) | ![Screenshot of 'droidfish_black_100'](images/droidfish_black_100.gif) **droidfish_white_1** | **droidfish_white_10** | **droidfish_white_100** ------------------------------------------------------------------ | -------------------------------------------------------------------- | ----------------------- ![Screenshot of 'droidfish_white_1'](images/droidfish_white_1.gif) | ![Screenshot of 'droidfish_white_10'](images/droidfish_white_10.gif) | ![Screenshot of 'droidfish_white_100'](images/droidfish_white_100.gif) **droidfish_random_1** | **droidfish_random_10** | **droidfish_random_100** -------------------------------------------------------------------- | ---------------------------------------------------------------------- | ------------------------ ![Screenshot of 'droidfish_random_1'](images/droidfish_random_1.gif) | ![Screenshot of 'droidfish_random_10'](images/droidfish_random_10.gif) | ![Screenshot of 'droidfish_random_100'](images/droidfish_random_100.gif) ## FlappyDroid A clone of the well-known game Flappy Birds.
Extras returned Returns no extras.
**systemui_egg_land_default** | **systemui_egg_land_half_speed** ---------------------------------------------------------------------------------- | -------------------------------- ![Screenshot of 'systemui_egg_land_default'](images/systemui_egg_land_default.gif) | ![Screenshot of 'systemui_egg_land_half_speed'](images/systemui_egg_land_half_speed.gif) ## FloodIt FloodIt is a game where the player needs to fill the board with a single color. The dynamics of the game are driven by a few colorful buttons at the bottom of the screen, which when pressed cause the currently active region to change its color to the color of the pressed button. When this active region changes color it absorbs neighboring squares that have the same color, thus expanding the active region. The active region starts as a single square at the top-left corner of the board. The game gives a single reward at the end of the game if the player manages to fill the entire board with the same color within the maximum number of steps, otherwise the reward is just zero. This is a very hard-exploration game because the game does not give intermediate rewards until the very end of the game, and the number of possible moves is incredibly big. You can find another implementation of this game in the task [SGT Puzzles - Flood](#sgt-puzzles-flood).
Extras returned * `board`: - State of the board, representing colours by their indices. - 0: purple, 1: blue, 2: green, 3: yellow, 4: red, 5: pink - Returned when the board changes. - Has `shape=[board_size, board_size], dtype=INT32`. * `clicked`: - Index of the colour clicked (between 0-5) - Returned when said colour is clicked. - Has `shape=[1], dtype=INT32`. * `flipped`: - The number of new cells that just got merged into the big blob (0 or more) - Returned when the board state changes - Has `shape=[1], dtype=INT32`.
**floodit_easy** | **floodit_medium** | **floodit_hard** -------------------------------------------------------- | ------------------------------------------------------------ | ---------------- ![Screenshot of 'floodit_easy'](images/floodit_easy.gif) | ![Screenshot of 'floodit_medium'](images/floodit_medium.gif) | ![Screenshot of 'floodit_hard'](images/floodit_hard.gif) ### Task `mdp_flood_it` Custom task created for pretraining agents to locate and press FloodIt buttons on the screen. ![Screenshot of 'mdp_flood_it'](images/mdp_flood_it.gif) ## Frozen Bubble Shoot the coloured bubbles in a direction of your choice. Groups of bubbles with the same colour will drop. Remove all bubbles from the board before the time runs out. See the original github repo for more info: https://github.com/robinst/frozen-bubble-android.git.
Extras returned Returns no extras.
![Screenshot of 'frozen_bubble'](images/frozen_bubble.gif) ## Memory Game Classic memory game. Find the pairs of images. See the original github repo for more info: https://github.com/sromku/memory-game/.
Extras returned * `flip`: - Index of the card flipped - Returned when a card is clicked. - Has `shape=[1], dtype=INT32`. * `cards`: - Number of cards still on the board. - Returned upon finding a pair. - Has `shape=[1], dtype=INT32`. * `remained`: - Number of cards remaining at the end of the episode. - Returned when an episode is over. - Has `shape=[1], dtype=INT32`. * `stars`: - Number of stars achieved at the end of the game. - Returned upon finishing the game. - Has `shape=[1], dtype=INT32`. * `achieved`: - Score obtained by the end of the episode. - Returned when an episode is over. - Has `shape=[1], dtype=INT32`.
**memory_game_animals_beginner** | **memory_game_animals_easy** | **memory_game_monsters_medium** ---------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------- | ------------------------------- ![Screenshot of 'memory_game_animals_beginner'](images/memory_game_animals_beginner.gif) | ![Screenshot of 'memory_game_animals_easy'](images/memory_game_animals_easy.gif) | ![Screenshot of 'memory_game_monsters_medium'](images/memory_game_monsters_medium.gif) **memory_game_monsters_hard** | **memory_game_emojis_hardest** | **memory_game_emojis_master** ---------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | ----------------------------- ![Screenshot of 'memory_game_monsters_hard'](images/memory_game_monsters_hard.gif) | ![Screenshot of 'memory_game_emojis_hardest'](images/memory_game_emojis_hardest.gif) | ![Screenshot of 'memory_game_emojis_master'](images/memory_game_emojis_master.gif) ## Minesweeper This is an Android implementation of a popular game on Desktop in the 1990s. See the original github repo for more info: https://gitlab.com/ar-/apple-flinger.
Extras returned * `hidden`: - Number of hidden cells. - Returned whenever the board changes. - Has `shape=[1], dtype=INT32`. * `revealed`: - Number of revealed cells. - Returned whenever the board changes. - Has `shape=[1], dtype=INT32`. * `bombs`: - Number of bombs in the game. - Returned whenever the board changes. - Has `shape=[1], dtype=INT32`. * `click`: - Coordinates of the cell clicked (row, column). - Returned whenever the board changes. - Has `shape=[2], dtype=INT32`. * `grid`: - State of the board. - -1 = hidden, -2 = marked, 9 = bomb, 0-8 = number of nearby bombs - Returned whenever the board changes. - Has `shape=[grid_height, grid_width], dtype=INT32`.
**minesweeper_easy** | **minesweeper_medium** | **minesweeper_hard** ---------------------------------------------------------------- | -------------------------------------------------------------------- | -------------------- ![Screenshot of 'minesweeper_easy'](images/minesweeper_easy.gif) | ![Screenshot of 'minesweeper_medium'](images/minesweeper_medium.gif) | ![Screenshot of 'minesweeper_hard'](images/minesweeper_hard.gif) ## Nostalgic Racer NostalgicRacer is a racing game that offers Atari-like graphics and controls. The objective is to maximize the score which increases as the car moves forward and by collecting coins and speed-ups.
Extras returned Returns no extras.
### Task `nostalgic_racer` The player can only control whether the car should move left, move right or stay put by touching on the screen. If the touch is on the right pane the car moves right, if the touch is on the left pane the car moves left and no touches leaves the car in the same position. Pressing for too little time moves the car by minuscule amounts, with effects similar to staying put. ### Task `nostalgic_racer_2d` This is the same underlying game as NostalgicRacer with the same objective. However, the interface is very different. The observation is given as a 2D view from the top with no perspective and the touchscreen determines the position that the car should move to (sideways). **nostalgic_racer** | **nostalgic_racer_2d** -------------------------------------------------------------- | ---------------------- ![Screenshot of 'nostalgic_racer'](images/nostalgic_racer.gif) | ![Screenshot of 'nostalgic_racer_2d'](images/nostalgic_racer_2d.gif) ## Open Sudoku Classic Sudoku game with different levels of difficulty. The board is randomised over a set of 30 boards for each level. See the original github repo for more info: https://github.com/ogarcia/opensudoku.
Extras returned * `value`: - Number pressed (between 1-9, 0 if the "delete" button is pressed). - Returned upon clicking said button. - Has `shape=[1], dtype=INT32`.
**open_sudoku_easy** | **open_sudoku_medium** | **open_sudoku_hard** ---------------------------------------------------------------- | -------------------------------------------------------------------- | -------------------- ![Screenshot of 'open_sudoku_easy'](images/open_sudoku_easy.gif) | ![Screenshot of 'open_sudoku_medium'](images/open_sudoku_medium.gif) | ![Screenshot of 'open_sudoku_hard'](images/open_sudoku_hard.gif) ## Perfection Drag the items corresponding to the targets with the same shape.
Extras returned * `moving`: - The ID of the piece being dragged on the screen or 0. - Returned when its value changes. - Has `shape=[1], dtype=INT32`. * `todo`: - Number of pieces yet to be moved to a hole. - Returned when its value changes. - Has `shape=[1], dtype=INT32`. * `done`: - Number of pieces correctly moved to a hole. - Returned when its value changes. - Has `shape=[1], dtype=INT32`.
**perfection_1_circle_static** | **perfection_1_cube_static** | **perfection_1_plus_static** | **perfection_1_triangle_static** ------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------- | -------------------------------- ![Screenshot of 'perfection_1_circle_static'](images/perfection_1_circle_static.gif) | ![Screenshot of 'perfection_1_square_static'](images/perfection_1_square_static.gif) | ![Screenshot of 'perfection_1_plus_static'](images/perfection_1_plus_static.gif) | ![Screenshot of 'perfection_1_triangle_static'](images/perfection_1_triangle_static.gif) **perfection_default** | **perfection_4_colors_square_static** | **perfection_4_pieces_static** -------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------- | ------------------------------ ![Screenshot of 'perfection_default'](images/perfection_default.gif) | ![Screenshot of 'perfection_4_colors_square_static'](images/perfection_4_colors_square_static.gif) | ![Screenshot of 'perfection_4_pieces_static'](images/perfection_4_pieces_static.gif) ## Rocket Sleigh A Flappy Bird-like game where you have to collect Christmas presents while avoiding trees. The sleigh is powered by a rocket that needs to recharge over time after you use up its fuel.
Extras returned Returns no extras.
![Screenshot of 'rocket_sleigh_default'](images/rocket_sleigh.gif) ## Pong Classic Pong game.
Extras returned * `ball`: - The ball coordinates: [left, top, right, bottom]. - Returned when its value changes. - Has `shape=[4], dtype=INT32`. * `computer`: - The computer paddle coordinates: [left, top, right, bottom]. - Returned when its value changes. - Has `shape=[4], dtype=INT32`. * `human`: - The human paddle coordinates: [left, top, right, bottom]. - Returned when its value changes. - Has `shape=[4], dtype=INT32`. * `collision`: - Indicates collision of paddle and ball: (0=no collision, 1=collision). - Returned when its value changes. - Has `shape=[1], dtype=INT32`. * `state`: - The current state of the game: (0=pause, 1=ready, 2=running, 3=lose, 4=win). - Returned when its value changes. - Has `shape=[1], dtype=INT32`.
**pong_easy** | **pong_default** | **pong_hard** -------------------------------------------------- | -------------------------------------------------------- | ------------- ![Screenshot of 'pong_easy'](images/pong_easy.gif) | ![Screenshot of 'pong_default'](images/pong_default.gif) | ![Screenshot of 'pong_hard'](images/pong_hard.gif) ## SGT Puzzles - Blackbox There's an invisible laser beam originating from each of the cells at the edge of the grid. There are also a given number of balls inside the grid, hidden from the player whose aim is to guess where those balls are. The player can figure out where those balls might be by looking at how they *deflect* the laser beams. Clicking on an edge cell the player can reveal information about how that particular laser beam travels. Click on a cell; if the cell reveals an `H`, it means the straight laser beam leaving this cell hits a ball frontally. If the cell reveals a *number* along with another cell with the same number, that means the laser beam originating in the first cell ends up getting absorbed in the corresponding pair cell. If the cell reveals an `R`, that means the laser beam was *reflected*: either its origin and the cell it gets absorbed in is the *same*, or the beam gets bent before entering the grid. See the description below. The balls affect the travel of the laser beam in the following way: * If a laser beam hits it straight, it gets absorbed. This is denoted by the letter `H`. * If a laser beam hits *its corner*, the beam gets deflected by 90 degrees. * If a laser beam hits a ball's corner right at the edge of the grid, i.e. before it enters the grid, it is considered *reflected*. * If a laser beam enters the same cell that it originally left, it is considered *reflected* too. Once the player has placed the given number of balls on the screen, a green dot appears that allows the player to check if their solution was correct. 
See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `balls`: - The number of balls in the game arena. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `guesses`: - The number of guessed balls made by the agent. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `wrong`: - 1 if the guesses are wrong, 0 otherwise. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `lasers`: - The number of lasers in the grid - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `grid`: - Representation of the grid cells: - In the arena: `G`=guessed ball `' '`=empty - In the range: `[0-9]`=the number of lasers, `H`=beam hit, `R`=beam reflected, - `?` unknown, `' '` for the corners - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**blackbox_3x3_1_ball** | **blackbox_5x5_3_balls** | **blackbox_8x8_5_balls** | **blackbox_10x10_5_balls** --------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- | -------------------------- ![Screenshot of 'blackbox_3x3_1_ball'](images/sgtpuzzles_blackbox_3x3_1_ball.gif) | ![Screenshot of 'blackbox_5x5_3_balls'](images/sgtpuzzles_blackbox_5x5_3_balls.gif) | ![Screenshot of 'blackbox_8x8_5_balls'](images/sgtpuzzles_blackbox_8x8_5_balls.gif) | ![Screenshot of 'blackbox_10x10_5_balls'](images/sgtpuzzles_blackbox_10x10_5_balls.gif) ## SGT Puzzles - Bridge Connect nodes on the board so that each number denotes the degree of the given vertex. Edges are not allowed to cross each other and the graph has to be connected. Edges have to be horizontal or vertical, and there can be at most two parallel bridges between any pair of nodes. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `islands`: - Number of nodes on the board. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `grid`: - Representation of the current state of the board. - `[0-9]=island, ' '=empty` - `'|'=vertical line, '"'=double vertical line, '!'=wrong vertical line` - `'-'=horizontal line, '='=double horizontal line, '~'=wrong horizontal line` - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**bridge_7x7_easy** | **bridge_7x7_medium** | **bridge_7x7_hard** | **bridge_10x10_medium** | **bridge_15x15_medium** ------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------- | ----------------------- ![Screenshot of 'sgtpuzzles_bridge_7x7_easy'](images/sgtpuzzles_bridge_7x7_easy.gif) | ![Screenshot of 'sgtpuzzles_bridge_7x7_medium'](images/sgtpuzzles_bridge_7x7_medium.gif) | ![Screenshot of 'sgtpuzzles_bridge_7x7_hard'](images/sgtpuzzles_bridge_7x7_hard.gif) | ![Screenshot of 'sgtpuzzles_bridge_10x10_medium'](images/sgtpuzzles_bridge_10x10_medium.gif) | ![Screenshot of 'sgtpuzzles_bridge_15x15_medium'](images/sgtpuzzles_bridge_15x15_medium.gif) ## SGT Puzzles - Cube There are six coloured squares that you have to collect with a moving cube. If the cube rolls on top of a coloured cell, the colour gets attached to that side of the cube; and if a coloured side of the cube rolls on an empty cell, the colour is removed from the cube. The goal is to have all six sides of the cube coloured. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `current`: - Index of the current grid cell the cube is on. - Returned whenever the cube moves. - Has `shape=[1], dtype=INT32`. * `previous`: - Index of the previous grid cell the cube was on. - Returned whenever the cube moves. - Has `shape=[1], dtype=INT32`. * `grid`: - The grid state (0 = dark, 1 = blue) - Returned whenever the cube moves. - Has `shape=[grid_size, grid_size], dtype=INT32`. * `face_count`: - The number of dark faces on the cube. - Returned whenever the cube moves. - Has `shape=[1], dtype=INT32`. * `face_colour_count`: - The number of blue faces on the cube. - Returned whenever the cube moves. - Has `shape=[1], dtype=INT32`. * `faces`: - The cube faces (0 = dark, 1 = blue) - Returned whenever the cube moves. - Has `shape=[6], dtype=INT32`.
**cube_c3x3** | **cube_c4x4** | **cube_c8x8** ------------------------------------------------------------------------ | ------------------------------------------------------------------------ | ------------- ![Screenshot of 'sgtpuzzles_cube_c3x3'](images/sgtpuzzles_cube_c3x3.gif) | ![Screenshot of 'sgtpuzzles_cube_c4x4'](images/sgtpuzzles_cube_c4x4.gif) | ![Screenshot of 'sgtpuzzles_cube_c8x8'](images/sgtpuzzles_cube_c8x8.gif) ## SGT Puzzles - Dominosa Place 2x1 size dominoes on the board such that the full board is covered, making sure that no two dominoes have the same pair of numbers on them. There needs to be exactly one of (0, 0), (0, 1) (0, 2), ... (1, 1), (1, 2) etc. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `numbers`: - Numbers as they appear in the grid. - Returned whenever the grid changes. - Has `shape=[height, width], dtype=INT32`. * `grid`: - Letters representing the dominoes currently placed on the board. - 'R=right, L=left, T=top, B=bottom' - Returned whenever the grid changes. - Has `shape=[height, width], dtype=INT32`. * `clash`: - Represents clashes on the board (i.e. if two dominoes have the same pair) - '1=clash, 0=no clash' - Returned whenever the grid changes. - Has `shape=[height, width], dtype=INT32`.
**dominosa_1** | **dominosa_3** | **dominosa_3a** | **dominosa_6** | **dominosa_9** -------------------------------------------------------------------------- | -------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | -------------------------------------------------------------------------- | -------------- ![Screenshot of 'sgtpuzzles_dominosa_1'](images/sgtpuzzles_dominosa_1.gif) | ![Screenshot of 'sgtpuzzles_dominosa_3'](images/sgtpuzzles_dominosa_3.gif) | ![Screenshot of 'sgtpuzzles_dominosa_3a'](images/sgtpuzzles_dominosa_3a.gif) | ![Screenshot of 'sgtpuzzles_dominosa_6'](images/sgtpuzzles_dominosa_6.gif) | ![Screenshot of 'sgtpuzzles_dominosa_9'](images/sgtpuzzles_dominosa_9.gif) ## SGT Puzzles - Fifteen Order the tiles in increasing order, starting from the top left corner. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `grid`: - Current state of the grid. - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=INT32`. * `empty`: - Index of the single empty cell in the grid. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `movecount`: - Number of moves made so far. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`.
**fifteen_2x2** | **fifteen_3x3** | **fifteen_4x4** | **fifteen_6x6** ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | --------------- ![Screenshot of 'sgtpuzzles_fifteen_2x2'](images/sgtpuzzles_fifteen_2x2.gif) | ![Screenshot of 'sgtpuzzles_fifteen_3x3'](images/sgtpuzzles_fifteen_3x3.gif) | ![Screenshot of 'sgtpuzzles_fifteen_4x4'](images/sgtpuzzles_fifteen_4x4.gif) | ![Screenshot of 'sgtpuzzles_fifteen_6x6'](images/sgtpuzzles_fifteen_6x6.gif) ## SGT Puzzles - Flip Clicking on a cell will flip the colour of some of its neighbours, which are determined by the symbol in the cell. The goal is to make all the cells have the same colour. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `light`: - The number of light cells. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `dark`: - The number of dark cells. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `moves`: - The number of moves made by the player. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `grid`: - State of the board (0 = dark, 1 = light). - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=INT32`. * `gridMatrix`: - The grid matrix of square neighbours (-1 = outside, 1 = neighbour, 0 = not neighbour) - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size, 3, 3], dtype=INT32`.
**flip_3x3c** | **flip_4x4c** | **flip_5x5r** ------------------------------------------------------------------------ | ------------------------------------------------------------------------ | ------------- ![Screenshot of 'sgtpuzzles_flip_3x3c'](images/sgtpuzzles_flip_3x3c.gif) | ![Screenshot of 'sgtpuzzles_flip_4x4c'](images/sgtpuzzles_flip_4x4c.gif) | ![Screenshot of 'sgtpuzzles_flip_5x5r'](images/sgtpuzzles_flip_5x5r.gif) ## SGT Puzzles - Flood FloodIt is a game where the player needs to fill the board with a single color. The dynamics of the game are driven by colored areas of the board, which when pressed cause the currently active region to change its color to the color of the pressed button. When this active region changes color it absorbs neighboring squares that have the same color, thus expanding the active region. The active region starts as a single square at the top-left corner of the board. The game gives a single reward at the end of the game if the player manages to fill the entire board with the same color within the maximum number steps, otherwise the reward is just zero. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `board`: - State of the board, representing colours by their indices. - 0: red, 1: yellow, 2: green, 3: blue, 4: orange, 5: purple, - 6: brown, 7: light blue, 8: light green, 9: pink - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=INT32`.
**sgtpuzzles_flood_3x3_easy** | **sgtpuzzles_flood_12x12_medium** | **sgtpuzzles_flood_16x16_hard** ---------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | ------------------------------- ![Screenshot of 'sgtpuzzles_flood_3x3_easy'](images/sgtpuzzles_flood_3x3_easy.gif) | ![Screenshot of 'sgtpuzzles_flood_12x12_medium'](images/sgtpuzzles_flood_12x12_medium.gif) | ![Screenshot of 'sgtpuzzles_flood_16x16_hard'](images/sgtpuzzles_flood_16x16_hard.gif) ## SGT Puzzles - Galaxies Split the grid up into centrally symmetric areas. The centre of symmetry for each area is denoted by a dot on the grid. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `dot`: - Number of dots on the board. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `grid`: - String representation of the board: - `o`=dot, `' '`=empty, `+`, `-`, `|` = cell corners. - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**galaxies_3x3_normal** | **galaxies_5x5_normal** | **galaxies_7x7_normal** --------------------------------------------------------------------------------- | --------------------------------------------------------------------------------- | ----------------------- ![Screenshot of 'galaxies_3x3_normal'](images/sgtpuzzles_galaxies_3x3_normal.gif) | ![Screenshot of 'galaxies_5x5_normal'](images/sgtpuzzles_galaxies_5x5_normal.gif) | ![Screenshot of 'galaxies_7x7_normal'](images/sgtpuzzles_galaxies_7x7_normal.gif) **galaxies_7x7_unreasonable** | **galaxies_10x10_normal** | **galaxies_15x15_normal** --------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- | ------------------------- ![Screenshot of 'galaxies_7x7_normal'](images/sgtpuzzles_galaxies_7x7_normal.gif) | ![Screenshot of 'galaxies_10x10_normal'](images/sgtpuzzles_galaxies_10x10_normal.gif) | ![Screenshot of 'galaxies_15x15_normal'](images/sgtpuzzles_galaxies_15x15_normal.gif) ## SGT Puzzles - Guess The computer has thought of a sequence of colours that you have to guess. Fill the top row with colours of your choice, and wait for the computer to give you feedback about your sequence. It will show a black dot for each colour that is placed in the correct position, and a white dot for each that is present in the hidden sequence, but not at the position of your guess. Try to figure out the hidden sequence before you run out of guesses! See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `peg`: - Indices representing the colours selected in the latest row. - Returned after the row is completed and evaluated. - Has `shape=[row_length], dtype=INT32`. * `feedback`: - Evaluation of the latest guess (0: incorrect, 1: correct place, 2: correct colour) - Returned after the row is completed and evaluated. - Has `shape=[row_length], dtype=INT32`.
**guess_basic** | **guess_quick** | **guess_standard** | **guess_super** ---------------------------------------------------------------------------- | ----------------------------------------------------------------- | ----------------------------------------------------------------------- | --------------- ![Screenshot of 'sgtpuzzles_guess_basic'](images/sgtpuzzles_guess_basic.gif) | ![Screenshot of 'guess_quick'](images/sgtpuzzles_guess_quick.gif) | ![Screenshot of 'guess_standard'](images/sgtpuzzles_guess_standard.gif) | ![Screenshot of 'guess_super'](images/sgtpuzzles_guess_super.gif) ## SGT Puzzles - Inertia Collect all the blue diamonds on the board without colliding into a bomb. You can move the ball in the 8 main directions (including the diagonals). The ball will keep on moving in that direction until it hits a wall, a bomb, a diamond or a circle. Circles and diamonds have grip, i.e. it will stop the ball from continuing to move in the direction it was going towards. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `gems`: - Current number of gems still on the board. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `distancemoved`: - Number of cells just moved. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `grid`: - Symbols of grid cells (b=blank, g=gem, m=mine, s=stop, w=wall) - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=INT32`.
**inertia_5x5** | **inertia_10x10** ----------------------------------------------------------------- | ----------------- ![Screenshot of 'inertia_5x5'](images/sgtpuzzles_inertia_5x5.gif) | ![Screenshot of 'inertia_10x10'](images/sgtpuzzles_inertia_10x10.gif) ## SGT Puzzles - Light Up You have a grid of squares. Some are empty (black) and some are *walls* (grey); some of the walls are numbered. Your aim is to *light up* all the empty squares by placing light bulbs in some of them. The numbers denote how many bulbs' light hits these directly in a straight sight. Meanwhile, no two bulbs should light up each other (i.e. only one bulb allowed in straight sight). See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `grid`: - String representation of the board: - '#' = blocked, ' ' = empty/black, 'L' = light, 'l' = illuminated, - 'X' = impossible, 'number' = number of bulbs hitting this cell. - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**light_up_3x3_easy** | **light_up_5x5_easy** | **light_up_7x7_easy** | **light_up_10x10_tricky** | **light_up_14x14_easy** ---------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- | ----------------------- ![Screenshot of 'sgtpuzzles_light_up_3x3_easy'](images/sgtpuzzles_light_up_3x3_easy.gif) | ![Screenshot of 'sgtpuzzles_light_up_5x5_easy'](images/sgtpuzzles_light_up_5x5_easy.gif) | ![Screenshot of 'sgtpuzzles_light_up_7x7_easy'](images/sgtpuzzles_light_up_7x7_easy.gif) | ![Screenshot of 'light_up_10x10_tricky'](images/sgtpuzzles_light_up_10x10_tricky.gif) | ![Screenshot of 'sgtpuzzles_light_up_14x14_easy'](images/sgtpuzzles_light_up_14x14_easy.gif) ## SGT Puzzles - Loopy Draw a closed loop along the edges of the grid. A number in a cell denotes the number of edges adjacent to that cell. The loop cannot intersect itself. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `grid`: - String representation of the board: - The grid lines and cells: - `.` = dots (cell corners) - `0-9` = number on cell face or ` ` for empty face - `?` = unknown (default),`x` = no line,`-` = `|` = line,`~` = `/` = error - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**loopy_3x3_easy** | **loopy_5x5_easy** | **loopy_7x7_easy** | **loopy_7x7_normal** | **loopy_7x7_hard** ---------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------- | ------------------ ![Screenshot of 'sgtpuzzles_loopy_3x3_easy'](images/sgtpuzzles_loopy_3x3_easy.gif) | ![Screenshot of 'sgtpuzzles_loopy_5x5_easy'](images/sgtpuzzles_loopy_5x5_easy.gif) | ![Screenshot of 'sgtpuzzles_loopy_7x7_easy'](images/sgtpuzzles_loopy_7x7_easy.gif) | ![Screenshot of 'sgtpuzzles_loopy_7x7_normal'](images/sgtpuzzles_loopy_7x7_normal.gif) | ![Screenshot of 'sgtpuzzles_loopy_7x7_hard'](images/sgtpuzzles_loopy_7x7_hard.gif) ## SGT Puzzles - Net There are a number of light bulbs, wires and a single light source in the middle. Connect the wires such that all bulbs are lit up, without loose ends or loops in the wiring. You can rotate the tiles by clicking on them. If the task name has a *w* suffix in it then it is allowed to connect wires on opposing edges of the grid. See the original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned * `active`: - The number of active/completed cells. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `total`: - The total number of cells - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `grid`: - The grid cells represented by numbers. - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=INT32`. * `gridCompleted`: - The grid cells' active/completed status (0 = false, 1 = true). - Returned whenever the grid changes. - Has `shape=[grid_size, grid_size], dtype=INT32`.
**net_3x3** | **net_5x5** | **net_7x7w** | **net_9x9** | **net_11x11w** -------------------------------------------------------------------- | -------------------------------------------------------------------- | ---------------------------------------------------------------------- | -------------------------------------------------------------------- | -------------- ![Screenshot of 'sgtpuzzles_net_3x3'](images/sgtpuzzles_net_3x3.gif) | ![Screenshot of 'sgtpuzzles_net_5x5'](images/sgtpuzzles_net_5x5.gif) | ![Screenshot of 'sgtpuzzles_net_7x7w'](images/sgtpuzzles_net_7x7w.gif) | ![Screenshot of 'sgtpuzzles_net_9x9'](images/sgtpuzzles_net_9x9.gif) | ![Screenshot of 'sgtpuzzles_net_11x11w'](images/sgtpuzzles_net_11x11w.gif) ## Shattered Pixel Dungeon Shattered Pixel Dungeon is a Roguelike RPG, with pixel art graphics and lots of variety and replayability. Every game is unique, with four different playable characters, randomized levels and enemies, and over 150 items to collect and use. The game is simple to get into, but has lots of depth. Strategy is required if you want to win! See the original github repo for more info: https://github.com/00-Evan/shattered-pixel-dungeon.git.
Extras returned * `action`: - The action just completed by the hero. - Returned whenever an action is taken. - Has `shape=[1], dtype=STRING_U25`. * `dst`: - The destination of the action. - Returned whenever an action is taken. - Has `shape=[1], dtype=INT32`. * `level`: - The level reached. - Returned whenever a new level is reached. - Has `shape=[1], dtype=INT32`. * `depth`: - The depth of the level reached. - Returned whenever a new floor is reached. - Has `shape=[1], dtype=INT32`. * `deepestFloor`: - The deepest reached floor statistic. - Returned whenever a new floor is reached. - Has `shape=[1], dtype=INT32`. * `gold`: - The gold level reached. - Returned whenever gold is acquired. - Has `shape=[1], dtype=INT32`. * `totalGold`: - The gold collected statistic. - Returned whenever gold is acquired. - Has `shape=[1], dtype=INT32`. * `addedGold`: - The gold acquired at the most recent step. - Returned whenever gold is acquired. - Has `shape=[1], dtype=INT32`. * `newlyVisited`: - Number of new squares uncovered. - Returned whenever new squares are uncovered. - Has `shape=[1], dtype=INT32`. * `damageDealt`: - Damage dealt by hero. - Returned whenever the hero deals damage. - Has `shape=[1], dtype=INT32`. * `damageTaken`: - Damage taken by hero. - Returned whenever the hero takes damage. - Has `shape=[1], dtype=INT32`. * `HP`: - Hit Points (Health) of hero. - Returned whenever its value changes. - Has `shape=[1], dtype=INT32`. * `exp`: - Experience gained by hero. - Returned whenever the hero gains experience points. - Has `shape=[1], dtype=INT32`. * `dew`: - How many dew (an immediately consumed item) the user picks up. - Returned whenever a dew is picked up. - Has `shape=[1], dtype=INT32`. * `heal`: - Amount the hero is healed. - Returned whenever the hero is healed. - Has `shape=[1], dtype=INT32`. * `itemPickup`: - The name of an item that is picked up. - Returned whenever an item is picked up. - Has `shape=[1], dtype=STRING_U25`. 
* `itemDrop`: - The name of an item that is dropped. - Returned whenever an item is dropped. - Has `shape=[1], dtype=STRING_U25`. * `destroy`: - The name of the enemy that is killed. - Returned whenever an enemy is killed. - Has `shape=[1], dtype=STRING_U25`. * `search`: - Whether the user clicked the "Search" button. - Returned whenever the button is clicked. - Has `shape=[1], dtype=INT32`. * `wait`: - Whether the user clicked the "Wait" button. - Returned whenever the button is clicked. - Has `shape=[1], dtype=INT32`. * `openedInventory`: - Whether the user clicked the "Inventory" button. - Returned whenever the button is clicked. - Has `shape=[1], dtype=INT32`.
huntress | mage | rogue | warrior ------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | ------- ![Screenshot of 'shattered_pixel_dungeon_huntress'](images/shattered_pixel_dungeon_huntress.gif) | ![Screenshot of 'shattered_pixel_dungeon_mage'](images/shattered_pixel_dungeon_mage.gif) | ![Screenshot of 'shattered_pixel_dungeon_rogue'](images/shattered_pixel_dungeon_rogue.gif) | ![Screenshot of 'shattered_pixel_dungeon_warrior'](images/shattered_pixel_dungeon_warrior.gif) ## Simple Solitaire This is an Android implementation of [Solitaire](https://en.wikipedia.org/wiki/Solitaire) card games. We currently support 19 variants listed here in alphabetical order. Note that the full task_IDs take the form `simple_solitaire_aces_up`, `simple_solitaire_calculation` etc. See the original github repo for more info: https://github.com/TobiasBielefeld/Simple-Solitaire.git.
Extras returned * `card`: - The new visible card `[kind, suit]`: - kind: `a=ace, k=king, q=queen, j=jack, x=10, 2-9=digit`. - suit: `c=clubs, d=diamonds, h=hearts, s=spades`. - Returned when a card is moved. - Has `shape=[2], dtype=STRING_U1`. * `stack_i`: - A non-empty stack of visible cards `[kind, suit]`. - `i` different extras (`stack_0`, `stack_1`, `...`), one corresponding to each stack. - Returned when a card is moved. - Has `shape=[52, 2], dtype=STRING_U1`.
**aces_up** | **calculation** | **canfield** | **forty_eight** | **freecell** -------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | ------------ ![Screenshot of 'simple_solitaire_aces_up'](images/simple_solitaire_aces_up.gif) | ![Screenshot of 'simple_solitaire_calculation'](images/simple_solitaire_calculation.gif) | ![Screenshot of 'simple_solitaire_canfield'](images/simple_solitaire_canfield.gif) | ![Screenshot of 'simple_solitaire_forty_eight'](images/simple_solitaire_forty_eight.gif) | ![Screenshot of 'simple_solitaire_freecell'](images/simple_solitaire_freecell.gif) **golf** | **grandfathers_clock** | **gypsy** | **klondike** | **maze** -------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | -------- ![Screenshot of 'simple_solitaire_golf'](images/simple_solitaire_golf.gif) | ![Screenshot of 'simple_solitaire_grandfathers_clock'](images/simple_solitaire_grandfathers_clock.gif) | ![Screenshot of 'simple_solitaire_gypsy'](images/simple_solitaire_gypsy.gif) | ![Screenshot of 'simple_solitaire_klondike'](images/simple_solitaire_klondike.gif) | ![Screenshot of 'simple_solitaire_maze'](images/simple_solitaire_maze.gif) **mod3** | **napoleons_tomb** | **pyramid** | **simple_simon** | **spider** -------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- | 
-------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | ---------- ![Screenshot of 'simple_solitaire_mod3'](images/simple_solitaire_mod3.gif) | ![Screenshot of 'simple_solitaire_napoleons_tomb'](images/simple_solitaire_napoleons_tomb.gif) | ![Screenshot of 'simple_solitaire_pyramid'](images/simple_solitaire_pyramid.gif) | ![Screenshot of 'simple_solitaire_simple_simon'](images/simple_solitaire_simple_simon.gif) | ![Screenshot of 'simple_solitaire_spider'](images/simple_solitaire_spider.gif) **spiderette** | **tri_peaks** | **vegas** | **yukon** -------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------- | --------- ![Screenshot of 'simple_solitaire_spiderette'](images/simple_solitaire_spiderette.gif) | ![Screenshot of 'simple_solitaire_tri_peaks'](images/simple_solitaire_tri_peaks.gif) | ![Screenshot of 'simple_solitaire_vegas'](images/simple_solitaire_vegas.gif) | ![Screenshot of 'simple_solitaire_yukon'](images/simple_solitaire_yukon.gif) ## Snake Classic Snake game.
Extras returned * `move`: - The desired direction of movement: `0`=left, `1`=up, `2`=down, `3`=right. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `direction`: - The direction of the snake: `1`=north, `2`=south, `3`=east, `4`=west. - Returned whenever the grid changes. - Has `shape=[1], dtype=INT32`. * `grid`: - The grid cells: `x`=border, `' '`=empty, `s`=snake, `a`=apple. - Returned whenever the grid changes. - Has `shape=[13, 19], dtype=STRING_U1`.
![Screenshot of 'aosp_samples_snake_default'](images/aosp_samples_snake_default.gif) ## Vector Pinball A simple vector-based Pinball game with realistic physics. See the original github repo for more info: https://github.com/dozingcat/Vector-Pinball.
Extras returned Returns no extras.
**vector_pinball_table_1** | **vector_pinball_table_2** | **vector_pinball_table_3** ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | -------------------------- ![Screenshot of 'vector_pinball_table_1'](images/vector_pinball_table_1.gif) | ![Screenshot of 'vector_pinball_table_2'](images/vector_pinball_table_2.gif) | ![Screenshot of 'vector_pinball_table_3'](images/vector_pinball_table_3.gif) **vector_pinball_table_4** | **vector_pinball_table_5** ---------------------------------------------------------------------------- | -------------------------- ![Screenshot of 'vector_pinball_table_4'](images/vector_pinball_table_4.gif) | ![Screenshot of 'vector_pinball_table_5'](images/vector_pinball_table_5.gif) ================================================ FILE: docs/instructions.md ================================================ # AndroidEnv - Running the environment In order to create an AndroidEnv instance you will need to provide two main components: a [simulator](#the-simulator) and a [task](#the-task). In the following sections you will learn how you can create them. ### The simulator First, you will need to provide your Android virtual device (AVD) that the environment (and through it, the agent) can communicate with. While this can also be a physical device, in most cases you will need a virtual emulated device. There are many ways to emulate an AVD - in our examples, we will use [Android Studio](https://developer.android.com/studio) to create one. 1. In Android Studio, create a virtual device by following this step-by-step [guide](emulator_guide.md). 2. Follow the steps below to attach the AVD to your environment. ### The task and examples with games and other apps A `task` is a particular definition of an RL problem that the agent will be interacting with. 
A `task` may include critical RL information, such as what the rewards are, when the episodes are supposed to terminate, and what the reset procedures are that the environment should perform upon episode termination (e.g. start or relaunch an app, clear cache, etc.). This information is packaged into a `Task()` proto message, which gets passed to AndroidEnv. * For ready-made example tasks provided with AndroidEnv, check out the [Available tasks](example_tasks.md), featuring Vokram (with Markov Decision Process (MDP)), Pong, DroidFish (a chess clone), Blockinger (a tetris clone), and more. * See the [Tasks guide](tasks_guide.md) for details on features and capabilities of tasks, as well as how to create custom ones. ### Create the environment After setting up the simulator and creating a task, you may find the [`loader.load()`](https://github.com/deepmind/android_env/blob/main/android_env/loader.py) function handy for creating an environment instance by providing relevant arguments, such as: * `task_path`: the path pointing to the `.textproto` file describing the desired task. * `avd_name`: the name of the AVD specified when you created it in Android Studio. * `android_avd_home` (Optional): the path to where the AVD is installed. (default value: `~/.android/avd`). * `android_sdk_root` (Optional): the root directory of the Android SDK. (default value: `~/Android/Sdk`). * `emulator_path` (Optional): the path to the emulator binary. (default: `~/Android/Sdk/emulator/emulator`). * `adb_path` (Optional): the path to the ADB ([Android Debug Bridge](https://developer.android.com/studio/command-line/adb)). (default value: `~/Android/Sdk/platform-tools/adb`). For the AVD name and path, in Android Studio, go to **Tools** > **AVD Manager**, right click on your virtual device, and select **View Details**, where you will find the `avd_name` and its path.
For the Android SDK location, in Android Studio, go to **Preferences** > **Appearance & Behavior** > **System Settings** > **Android SDKs** and note the _Android SDK Location_. In the SDK folder, you will find `/emulator/emulator` as well as the ADB path (`/platform-tools/adb`). Your example configuration may look like this, depending on how you set up your emulator: ```python from android_env import loader env = loader.load( avd_name='my_avd', android_avd_home='/Users/username/.android/avd', android_sdk_root='/Users/username/Library/Android/sdk', emulator_path='/Users/username/Library/Android/sdk/emulator/emulator', adb_path='/Users/username/Library/Android/sdk/platform-tools/adb', task_path='/Users/username/android_env/my_tasks/my_task.textproto', ) ``` ## Example RL agent scripts The `examples` directory contains a few simple example agent setups, such as: * [`run_random_agent.py`](https://github.com/deepmind/android_env/blob/main/examples/run_random_agent.py): Runs a simple loop performing randomly selected actions in the environment. * [`run_acme_agent.py`](https://github.com/deepmind/android_env/blob/main/examples/run_acme_agent.py): Runs a training loop with an [Acme](https://deepmind.com/research/publications/Acme) DQN agent, implemented in the popular DeepMind RL framework. This will require to install the [`acme`](https://github.com/deepmind/acme) dependency. * [`run_human_agent.py`](https://github.com/deepmind/android_env/blob/main/examples/run_human_agent.py): Creates a [`pygame`](https://www.pygame.org) instance that lets a human user interact with the environment and observe environment mechanics, such as rewards or task extras. You will need to install the [PyGame] dependency. 
For instance, here is how you can run [`run_random_agent.py`](https://github.com/deepmind/android_env/blob/main/examples/run_random_agent.py) in a folder where you have your APK file, such as [Apple Flinger](https://github.com/deepmind/android_env/blob/main/docs/example_tasks.md#apple-flinger) from [Example tasks](https://github.com/deepmind/android_env/blob/main/docs/example_tasks.md). (The downloaded TAR file contains the APK file and `.textproto` definitions.) ```shell python3 run_random_agent.py \ --avd_name='my_avd' \ --android_avd_home=/Users/username/.android/avd \ --android_sdk_root=/Users/username/Library/Android/sdk \ --emulator_path=/Users/username/Library/Android/sdk/emulator/emulator \ --adb_path=/Users/username/Library/Android/sdk/platform-tools/adb \ --num_steps=1000 \ --task_path=/Users/username//apple_flinger_M_1_1.textproto ``` ================================================ FILE: docs/tasks_guide.md ================================================ # AndroidEnv - Tasks With AndroidEnv we provide a mechanism for easily defining RL tasks for the agent to learn. This includes various types of information such as what app/game it should train on, what rewards the environment returns, or the start state distribution and the episode end criteria. ## Task structure A *task* definition is captured in the form of a `Task()` proto message. These are most easily created by writing a `.textproto` file, then parsing it into a proto message. In this section you can find a detailed description about the types of information that make up a task, and an example demonstrating exactly how to put these into code.
Expand this tab to view the main types of information captured in these messages: * `id`: An ID used to identify the task. * `setup_steps`: These are steps the environment will perform right after launching the simulator. Possible steps include: * `install_apk`: Installs an application from a specified path to the APK file. * `start_activity`: Launches the requested app/activity. * `rotate`: Sets the orientation of the device (landscape/portrait). * `reset_steps`: These are steps the environment will perform right at the beginning of a new RL episode. Possible steps include: * `force_stop`: Stops a given app. * `start_activity`: Launches the requested app/activity. * `start_screen_pinning`: Restricts the agent's interaction to a particular activity through [screen pinning](https://support.google.com/android/answer/9455138?hl=en), meaning the agent will not be able to quit the given app. * `clear_cache`: Clears the cache of a given app. * `success_conditions`: For each success condition defined, the environment will make sure that these conditions were met after finishing `setup_steps` and `reset_steps`. They might include conditions such as: * `check_install`: Makes sure that the request app was successfully installed. * `wait_for_app_screen`: Waits until the request app was successfully launched. * `expected_app_screen`: If this value is set to a particular activity, the environment will periodically check if the agent is still interacting with said activity, making sure it has not accidentally quit the application we want it to be training on. * `max_episode_sec`: Puts a time limit on the episodes, triggering an episode reset if the current episode has lasted too long. * `max_duration_steps`: Puts a step limit on the episodes, triggering an episode reset once the agent has reached the specified limit. * `log_parsing_config`: If the environment is parsing logcat messages, this field will determine what information it should listen for using regular expressions. 
* `filters`: The environment filters log messages for these labels which signify that such messages were meant to be parsed by AndroidEnv. * `log_regexps`: Once a log message was identified as relevant using the filters, the environment parses its contents using these regular expressions. For example, an application might be sending log messages of the form `reward: 1.0`, then the task will capture this info using the regexp `^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$`.
Expand this tab to see what an example `.textproto` file might look like in practice: ```python id: "classic_2048" name: "Classic 2048 - Default" description: "Slide numbered tiles on a grid to combine them to create a tile with the number 2048" package_name: "com.tpcstld.twozerogame" full_activity_name: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity" # Perform these upon launching the environment setup_steps: [ { # Install the 2048 app adb_call: { install_apk: { filesystem: { path: path/to/classic_2048.apk } } } # Check if it was installed correctly success_condition: { check_install: { package_name: "com.tpcstld.twozerogame" timeout_sec: 10.0 } } }, # Orient the screen in portrait mode { adb_call: { rotate: { orientation: PORTRAIT_0 } } } ] # Perform these upon episode resets reset_steps: [ # Stop the 2048 app { adb_call: { force_stop: { package_name: "com.tpcstld.twozerogame" } } }, { adb_call: { clear_cache: { package_name: "com.tpcstld.twozerogame" } } }, # Start the 2048 app { adb_call: { start_activity: { full_activity: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity" extra_args: [ "--ez", '"RL_TASK_ENABLED"', '"true"', "--es", '"RL_TASK_GAME_CONFIG"', '"{}"' ] } } # Wait until the app has launched successfully success_condition: { wait_for_app_screen: { app_screen: { activity: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity" view_hierarchy_path: [ ] } timeout_sec: 10.0 } num_retries: 10 } }, # Make sure the agent cannot quit the 2048 app { adb_call: { start_screen_pinning: { full_activity: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity" } } } ] # Periodically check if the agent has accidentally quit the app expected_app_screen: { activity: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity" view_hierarchy_path: [] } max_episode_steps: 500 # Capture expected format of log messages log_parsing_config: { filters: ["AndroidRLTask:V"] log_regexps: { score: "^[Ss]core: ([-+]?[0-9]*\\.?[0-9]*)$"
reward: "^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$" episode_end: "^episode[ _]end$" extra: "^extra: (?P<name>[^ ]*)[ ]?(?P<extra>.*)$" json_extra: "^json_extra: (?P<json_extra>.*)$" } } ```
## Log messages and custom APKs You might have noticed that tasks often rely on log messages exposed by the Android system, which AndroidEnv can intercept and translate into items such as rewards, episode end signals or task extras. One way to define rewards is by using `log_parsing_config.LogRegexps.RewardEvent` messages in the task proto. These consist of a regular expression and a numeric value indicating the intended reward. If the regexp is matched in any of the lines of the logcat stream, the agent will receive the given reward. It is also possible to have multiple of these RewardEvents, allowing us to give rewards for different log messages. The same applies for episode end signals: logcat messages that match the regexps defined in `log_parsing_config.LogRegexps.episode_end` will trigger an episode reset. Of course, applications might not send suitable messages by default, so in order to have access to such messages, we often add them to the apps' source code to match our expectations. For example, in the case of the 2048 app, we find in the game's source code the exact lines where the score is computed, and add a line to log this value in the format that is expected by the textproto (or conversely, make sure the textproto matches the format you specified here). For example: ```java // Make sure the LOG_FILTER matches 'filters' in the textproto public static final String LOG_FILTER = "AndroidRLTask"; // Make sure that the corresponding part of 'log_regexps' will match this string Log.i(LOG_FILTER, String.format(Locale.ENGLISH, "reward: %f", reward_value)) ``` You can take a look at example APKs extended with log messages in the example tasks (see the section below). ## Example tasks Along with the environment implementation we provide a set of example task definitions. These were chosen so that they would demonstrate the large variety of different challenges (e.g. app navigation, puzzle games, time-reactive games, adventure games, card games...)
and corresponding interfaces (e.g. button pressing, swiping, drag-and-drop...) available in AndroidEnv. You can find a list and detailed description of each of these tasks in [example_tasks.md](example_tasks.md). ================================================ FILE: examples/__init__.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. ================================================ FILE: examples/run_acme_agent.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Acme DQN agent interacting with AndroidEnv.""" from absl import app from absl import flags from absl import logging import acme from acme import specs from acme import wrappers as acme_wrappers from acme.agents.tf import dqn from acme.tf import networks from android_env import loader from android_env.components import config_classes from android_env.wrappers import discrete_action_wrapper from android_env.wrappers import float_pixels_wrapper from android_env.wrappers import image_rescale_wrapper # Simulator args flags.DEFINE_string('avd_name', None, 'Name of AVD to use.') flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.') flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.') flags.DEFINE_string('emulator_path', '~/Android/Sdk/emulator/emulator', 'Path to emulator.') flags.DEFINE_string('adb_path', '~/Android/Sdk/platform-tools/adb', 'Path to ADB.') # Environment args flags.DEFINE_string('task_path', None, 'Path to task textproto file.') # Experiment args flags.DEFINE_integer('num_episodes', 100, 'Number of episodes.') FLAGS = flags.FLAGS def apply_wrappers(env): """Applies a series of wrappers to the environment.""" env = discrete_action_wrapper.DiscreteActionWrapper(env, action_grid=(10, 10)) env = image_rescale_wrapper.ImageRescaleWrapper( env, zoom_factors=(0.25, 0.25)) env = float_pixels_wrapper.FloatPixelsWrapper(env) env = acme_wrappers.SinglePrecisionWrapper(env) return env def main(_): config = config_classes.AndroidEnvConfig( task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), simulator=config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( emulator_path=FLAGS.emulator_path, android_sdk_root=FLAGS.android_sdk_root, android_avd_home=FLAGS.android_avd_home, avd_name=FLAGS.avd_name, run_headless=FLAGS.run_headless, ), adb_controller=config_classes.AdbControllerConfig( adb_path=FLAGS.adb_path ), ), ) with loader.load(config) as env: env = apply_wrappers(env) env_spec = 
specs.make_environment_spec(env) agent = dqn.DQN( environment_spec=env_spec, network=networks.DQNAtariNetwork( num_actions=env_spec.actions.num_values), batch_size=10, samples_per_insert=2, min_replay_size=10) loop = acme.EnvironmentLoop(env, agent) loop.run(num_episodes=FLAGS.num_episodes) if __name__ == '__main__': logging.set_verbosity('info') logging.set_stderrthreshold('info') flags.mark_flags_as_required(['task_path', 'avd_name']) app.run(main) ================================================ FILE: examples/run_human_agent.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Loads an interactive session where a human acts on behalf of an agent.""" import time from typing import Any from absl import app from absl import flags from absl import logging from android_env import loader from android_env.components import action_type from android_env.components import config_classes from android_env.components import pixel_fns import dm_env import numpy as np import pygame # Simulator args. 
flags.DEFINE_string('avd_name', None, 'Name of AVD to use.') flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.') flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.') flags.DEFINE_string('emulator_path', '~/Android/Sdk/emulator/emulator', 'Path to emulator.') flags.DEFINE_string('adb_path', '~/Android/Sdk/platform-tools/adb', 'Path to ADB.') flags.DEFINE_boolean('run_headless', True, 'Optionally turn off display.') # Environment args. flags.DEFINE_string('task_path', None, 'Path to task textproto file.') # Pygame args. flags.DEFINE_list('screen_size', '480,720', 'Screen width, height in pixels.') flags.DEFINE_float('frame_rate', 1.0/30.0, 'Frame rate in seconds.') FLAGS = flags.FLAGS def _get_action_from_event( event: pygame.event.Event, screen: pygame.Surface, orientation: int ) -> dict[str, Any]: """Returns the current action by reading data from a pygame Event object.""" act_type = action_type.ActionType.LIFT if event.type == pygame.MOUSEBUTTONDOWN: act_type = action_type.ActionType.TOUCH return { 'action_type': np.array(act_type, dtype=np.int32), 'touch_position': _scale_position(event.pos, screen, orientation), } def _get_action_from_mouse( screen: pygame.Surface, orientation: int ) -> dict[str, Any]: """Returns the current action by reading data from the mouse.""" act_type = action_type.ActionType.LIFT if pygame.mouse.get_pressed()[0]: act_type = action_type.ActionType.TOUCH return { 'action_type': np.array(act_type, dtype=np.int32), 'touch_position': _scale_position(pygame.mouse.get_pos(), screen, orientation), } def _scale_position(position: np.ndarray, screen: pygame.Surface, orientation: int) -> np.ndarray: """AndroidEnv accepts mouse inputs as floats so we need to scale it.""" scaled_pos = np.divide(position, screen.get_size(), dtype=np.float32) if orientation == 1: # LANDSCAPE_90 scaled_pos = scaled_pos[::-1] scaled_pos[0] = 1 - scaled_pos[0] return scaled_pos def _accumulate_reward( timestep: dm_env.TimeStep, 
episode_return: float) -> float: """Accumulates rewards collected over the course of an episode.""" if timestep.reward and timestep.reward != 0: logging.info('Reward: %s', timestep.reward) episode_return += timestep.reward if timestep.first(): episode_return = 0 elif timestep.last(): logging.info('Episode return: %s', episode_return) return episode_return def _render_pygame_frame(surface: pygame.Surface, screen: pygame.Surface, orientation: int, timestep: dm_env.TimeStep) -> None: """Displays latest observation on pygame surface.""" frame = timestep.observation['pixels'][:, :, :3] # (H x W x C) (RGB) frame = pixel_fns.transpose_pixels(frame) # (W x H x C) frame = pixel_fns.orient_pixels(frame, orientation) pygame.surfarray.blit_array(surface, frame) pygame.transform.smoothscale(surface, screen.get_size(), screen) pygame.display.flip() def main(_): pygame.init() pygame.display.set_caption('android_human_agent') config = config_classes.AndroidEnvConfig( task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), simulator=config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( emulator_path=FLAGS.emulator_path, android_sdk_root=FLAGS.android_sdk_root, android_avd_home=FLAGS.android_avd_home, avd_name=FLAGS.avd_name, run_headless=FLAGS.run_headless, ), adb_controller=config_classes.AdbControllerConfig( adb_path=FLAGS.adb_path ), ), ) with loader.load(config) as env: # Reset environment. first_timestep = env.reset() orientation = np.argmax(first_timestep.observation['orientation']) # Create pygame canvas. screen_size = list(map(int, FLAGS.screen_size)) # (W x H) obs_shape = env.observation_spec()['pixels'].shape[:2] # (H x W) if (orientation == 1 or orientation == 3): # LANDSCAPE_90 | LANDSCAPE_270 screen_size = screen_size[::-1] obs_shape = obs_shape[::-1] screen = pygame.display.set_mode(screen_size) # takes (W x H) surface = pygame.Surface(obs_shape[::-1]) # takes (W x H) # Start game loop. 
prev_frame = time.time() episode_return = 0 while True: if pygame.key.get_pressed()[pygame.K_ESCAPE]: return all_events = pygame.event.get() for event in all_events: if event.type == pygame.QUIT: return # Filter event queue for mouse click events. mouse_click_events = [ event for event in all_events if event.type in [pygame.MOUSEBUTTONDOWN, pygame.MOUSEBUTTONUP] ] # Process all mouse click events. for event in mouse_click_events: action = _get_action_from_event(event, screen, orientation) timestep = env.step(action) episode_return = _accumulate_reward(timestep, episode_return) _render_pygame_frame(surface, screen, orientation, timestep) # Sample the current position of the mouse either way. action = _get_action_from_mouse(screen, orientation) timestep = env.step(action) episode_return = _accumulate_reward(timestep, episode_return) _render_pygame_frame(surface, screen, orientation, timestep) # Limit framerate. now = time.time() frame_time = now - prev_frame if frame_time < FLAGS.frame_rate: time.sleep(FLAGS.frame_rate - frame_time) prev_frame = now if __name__ == '__main__': logging.set_verbosity('info') logging.set_stderrthreshold('info') flags.mark_flags_as_required(['avd_name', 'task_path']) app.run(main) ================================================ FILE: examples/run_random_agent.py ================================================ # coding=utf-8 # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Example script demonstrating usage of AndroidEnv.""" from absl import app from absl import flags from absl import logging from android_env import loader from android_env.components import config_classes from dm_env import specs import numpy as np FLAGS = flags.FLAGS # Simulator args. flags.DEFINE_string('avd_name', None, 'Name of AVD to use.') flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.') flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.') flags.DEFINE_string('emulator_path', '~/Android/Sdk/emulator/emulator', 'Path to emulator.') flags.DEFINE_string('adb_path', '~/Android/Sdk/platform-tools/adb', 'Path to ADB.') flags.DEFINE_bool('run_headless', False, 'Whether to display the emulator window.') # Environment args. flags.DEFINE_string('task_path', None, 'Path to task textproto file.') # Experiment args. flags.DEFINE_integer('num_steps', 1000, 'Number of steps to take.') def main(_): config = config_classes.AndroidEnvConfig( task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path), simulator=config_classes.EmulatorConfig( emulator_launcher=config_classes.EmulatorLauncherConfig( emulator_path=FLAGS.emulator_path, android_sdk_root=FLAGS.android_sdk_root, android_avd_home=FLAGS.android_avd_home, avd_name=FLAGS.avd_name, run_headless=FLAGS.run_headless, ), adb_controller=config_classes.AdbControllerConfig( adb_path=FLAGS.adb_path ), ), ) with loader.load(config) as env: action_spec = env.action_spec() def get_random_action() -> dict[str, np.ndarray]: """Returns a random AndroidEnv action.""" action = {} for k, v in action_spec.items(): if isinstance(v, specs.DiscreteArray): action[k] = np.random.randint(low=0, high=v.num_values, dtype=v.dtype) else: action[k] = np.random.random(size=v.shape).astype(v.dtype) return action _ = env.reset() for step in range(FLAGS.num_steps): action = get_random_action() timestep = env.step(action=action) reward = timestep.reward logging.info('Step %r, action: %r, reward: %r', step, 
action, reward) if __name__ == '__main__': logging.set_verbosity('info') logging.set_stderrthreshold('info') flags.mark_flags_as_required(['avd_name', 'task_path']) app.run(main) ================================================ FILE: pyproject.toml ================================================ [build-system] requires = [ "setuptools", "wheel" ] build-backend = "setuptools.build_meta" [project] name = "android-env" version = "1.2.2" description = "AndroidEnv environment and library for training agents." authors = [{name = "DeepMind"}] license = {file = "LICENSE"} readme = {text = "Read the README at https://github.com/deepmind/android_env for more information.", content-type = "text/plain"} keywords = ["Android", "OS", "reinforcement-learning"] requires-python = ">=3.10" dependencies = [ "absl-py>=0.1.0", "dm_env", "grpcio", "numpy>=1.21", "portpicker>=1.2.0", "protobuf>=2.6", "pygame", ] [project.optional-dependencies] testing = [ "gym", "pillow", "pytype", ] acme = ["dm-acme"] gym = ["gym"] [project.urls] repository = "https://github.com/deepmind/android_env" deepmind = "https://www.deepmind.com/publications/androidenv-the-android-learning-environment" arxiv = "https://arxiv.org/abs/2105.13231" ================================================ FILE: setup.py ================================================ # Copyright 2026 DeepMind Technologies Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Simple package definition for using with `pip`."""

import importlib
# `import importlib` alone does not guarantee that the `resources` submodule
# is loaded, so import it explicitly.
import importlib.resources
import os

import setuptools
from setuptools import find_packages
from setuptools import setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py

_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))

# Tuple of proto message definitions to build Python bindings for. Paths must
# be relative to root directory.
_ANDROID_ENV_PROTOS = (
    'android_env/proto/adb.proto',
    'android_env/proto/emulator_controller.proto',
    'android_env/proto/snapshot.proto',
    'android_env/proto/snapshot_service.proto',
    'android_env/proto/state.proto',
    'android_env/proto/task.proto',
    'android_env/proto/a11y/a11y.proto',
    'android_env/proto/a11y/android_accessibility_action.proto',
    'android_env/proto/a11y/android_accessibility_forest.proto',
    'android_env/proto/a11y/android_accessibility_node_info.proto',
    'android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto',
    'android_env/proto/a11y/android_accessibility_tree.proto',
    'android_env/proto/a11y/android_accessibility_window_info.proto',
    'android_env/proto/a11y/rect.proto',
)


class _GenerateProtoFiles(setuptools.Command):
  """Command to generate protobuf bindings for AndroidEnv protos."""

  # setuptools reads the singular `description` attribute for command help.
  description = 'Generates Python protobuf bindings for AndroidEnv protos.'
  user_options = []

  def initialize_options(self):
    """No options to initialize; required by the `Command` interface."""
    pass

  def finalize_options(self):
    """No options to finalize; required by the `Command` interface."""
    pass

  def run(self):
    """Runs `protoc` on every file listed in `_ANDROID_ENV_PROTOS`.

    Raises:
      RuntimeError: If `protoc` returns a non-zero exit code for any proto.
    """
    # Import grpc_tools here, after setuptools has installed setup_requires
    # dependencies.
    from grpc_tools import protoc  # pylint: disable=g-import-not-at-top

    proto_include = importlib.resources.files('grpc_tools').joinpath('_proto')
    # Keep the `as_file` context open while protoc runs: the path it yields
    # may be a temporary one that is removed once the context exits.
    with importlib.resources.as_file(proto_include) as path:
      grpc_protos_include = str(path)

      for proto_path in _ANDROID_ENV_PROTOS:
        proto_args = [
            'grpc_tools.protoc',
            '--proto_path={}'.format(grpc_protos_include),
            '--proto_path={}'.format(_ROOT_DIR),
            '--python_out={}'.format(_ROOT_DIR),
            '--pyi_out={}'.format(_ROOT_DIR),
            '--grpc_python_out={}'.format(_ROOT_DIR),
            os.path.join(_ROOT_DIR, proto_path),
        ]
        if protoc.main(proto_args) != 0:
          raise RuntimeError('ERROR: {}'.format(proto_args))


class _BuildExt(build_ext):
  """Generate protobuf bindings in build_ext stage."""

  def run(self):
    """Generates the proto bindings, then runs the regular build_ext."""
    self.run_command('generate_protos')
    build_ext.run(self)


class _BuildPy(build_py):
  """Generate protobuf bindings in build_py stage."""

  def run(self):
    """Generates the proto bindings, then runs the regular build_py."""
    self.run_command('generate_protos')
    build_py.run(self)


setup(
    packages=find_packages(exclude=['examples']),
    package_data={'': ['proto/*.proto']},  # Copy protobuf files.
    include_package_data=True,
    setup_requires=['grpcio-tools'],
    cmdclass={
        'build_ext': _BuildExt,
        'build_py': _BuildPy,
        'generate_protos': _GenerateProtoFiles,
    },
)