Repository: google-deepmind/android_env
Branch: main
Commit: 0cdf2711c4e9
Files: 148
Total size: 908.6 KB
Directory structure:
gitextract_70x6n_qo/
├── .github/
│ └── workflows/
│ └── tests.yml
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── android_env/
│ ├── __init__.py
│ ├── apps/
│ │ ├── MODULE.bazel
│ │ ├── java/
│ │ │ └── com/
│ │ │ └── google/
│ │ │ └── androidenv/
│ │ │ ├── accessibilityforwarder/
│ │ │ │ ├── AccessibilityForwarder.kt
│ │ │ │ ├── AccessibilityForwarderTest.kt
│ │ │ │ ├── AccessibilityTreeCreator.kt
│ │ │ │ ├── AccessibilityTreeCreatorTest.kt
│ │ │ │ ├── AndroidManifest.xml
│ │ │ │ ├── AndroidManifest_lite.xml
│ │ │ │ ├── FlagsBroadcastReceiver.kt
│ │ │ │ ├── FlagsBroadcastReceiverTest.kt
│ │ │ │ ├── LogFlags.kt
│ │ │ │ ├── ParentChildNodePair.kt
│ │ │ │ ├── UniqueIdsGenerator.kt
│ │ │ │ └── res/
│ │ │ │ └── xml/
│ │ │ │ └── accessibility_forwarder_service.xml
│ │ │ └── catch/
│ │ │ ├── AndroidManifest.xml
│ │ │ ├── BUILD.bazel
│ │ │ ├── GameLogic.kt
│ │ │ ├── GameLogicThread.kt
│ │ │ ├── MainActivity.kt
│ │ │ ├── RenderThread.kt
│ │ │ ├── res/
│ │ │ │ ├── layout/
│ │ │ │ │ └── main.xml
│ │ │ │ └── values/
│ │ │ │ └── strings.xml
│ │ │ └── sprite/
│ │ │ ├── BUILD.bazel
│ │ │ ├── Background.kt
│ │ │ ├── Ball.kt
│ │ │ ├── LineSegment.kt
│ │ │ ├── Paddle.kt
│ │ │ ├── Point.kt
│ │ │ └── Sprite.kt
│ │ └── javatests/
│ │ └── com/
│ │ └── google/
│ │ └── androidenv/
│ │ └── catch/
│ │ ├── AndroidManifest.xml
│ │ ├── BUILD.bazel
│ │ ├── GameLogicTest.kt
│ │ ├── GameLogicThreadTest.kt
│ │ ├── MainActivityTest.kt
│ │ ├── RenderThreadTest.kt
│ │ └── sprite/
│ │ ├── BUILD.bazel
│ │ ├── BackgroundTest.kt
│ │ ├── BallTest.kt
│ │ ├── PaddleTest.kt
│ │ └── SpriteTest.kt
│ ├── components/
│ │ ├── __init__.py
│ │ ├── action_fns.py
│ │ ├── action_fns_test.py
│ │ ├── action_type.py
│ │ ├── adb_call_parser.py
│ │ ├── adb_call_parser_test.py
│ │ ├── adb_controller.py
│ │ ├── adb_controller_test.py
│ │ ├── adb_log_stream.py
│ │ ├── adb_log_stream_test.py
│ │ ├── app_screen_checker.py
│ │ ├── app_screen_checker_test.py
│ │ ├── config_classes.py
│ │ ├── coordinator.py
│ │ ├── coordinator_test.py
│ │ ├── device_settings.py
│ │ ├── device_settings_test.py
│ │ ├── dumpsys_thread.py
│ │ ├── dumpsys_thread_test.py
│ │ ├── errors.py
│ │ ├── errors_test.py
│ │ ├── log_stream.py
│ │ ├── log_stream_test.py
│ │ ├── logcat_thread.py
│ │ ├── logcat_thread_test.py
│ │ ├── pixel_fns.py
│ │ ├── pixel_fns_test.py
│ │ ├── setup_step_interpreter.py
│ │ ├── setup_step_interpreter_test.py
│ │ ├── simulators/
│ │ │ ├── __init__.py
│ │ │ ├── base_simulator.py
│ │ │ ├── base_simulator_test.py
│ │ │ ├── emulator/
│ │ │ │ ├── __init__.py
│ │ │ │ ├── emulator_launcher.py
│ │ │ │ ├── emulator_launcher_test.py
│ │ │ │ ├── emulator_simulator.py
│ │ │ │ └── emulator_simulator_test.py
│ │ │ └── fake/
│ │ │ ├── __init__.py
│ │ │ ├── fake_simulator.py
│ │ │ └── fake_simulator_test.py
│ │ ├── specs.py
│ │ ├── specs_test.py
│ │ ├── task_manager.py
│ │ └── task_manager_test.py
│ ├── env_interface.py
│ ├── environment.py
│ ├── environment_test.py
│ ├── loader.py
│ ├── loader_test.py
│ ├── proto/
│ │ ├── __init__.py
│ │ ├── a11y/
│ │ │ ├── __init__.py
│ │ │ ├── a11y.proto
│ │ │ ├── android_accessibility_action.proto
│ │ │ ├── android_accessibility_forest.proto
│ │ │ ├── android_accessibility_node_info.proto
│ │ │ ├── android_accessibility_node_info_clickable_span.proto
│ │ │ ├── android_accessibility_tree.proto
│ │ │ ├── android_accessibility_window_info.proto
│ │ │ └── rect.proto
│ │ ├── adb.proto
│ │ ├── emulator_controller.proto
│ │ ├── snapshot.proto
│ │ ├── snapshot_service.proto
│ │ ├── state.proto
│ │ └── task.proto
│ └── wrappers/
│ ├── __init__.py
│ ├── a11y/
│ │ ├── __init__.py
│ │ ├── a11y_events.py
│ │ ├── a11y_events_test.py
│ │ ├── a11y_forests.py
│ │ ├── a11y_forests_test.py
│ │ ├── a11y_servicer.py
│ │ └── a11y_servicer_test.py
│ ├── a11y_grpc_wrapper.py
│ ├── a11y_grpc_wrapper_test.py
│ ├── base_wrapper.py
│ ├── base_wrapper_test.py
│ ├── discrete_action_wrapper.py
│ ├── discrete_action_wrapper_test.py
│ ├── flat_interface_wrapper.py
│ ├── flat_interface_wrapper_test.py
│ ├── float_pixels_wrapper.py
│ ├── float_pixels_wrapper_test.py
│ ├── gym_wrapper.py
│ ├── gym_wrapper_test.py
│ ├── image_rescale_wrapper.py
│ ├── image_rescale_wrapper_test.py
│ ├── last_action_wrapper.py
│ ├── last_action_wrapper_test.py
│ ├── rate_limit_wrapper.py
│ ├── rate_limit_wrapper_test.py
│ ├── tap_action_wrapper.py
│ └── tap_action_wrapper_test.py
├── docs/
│ ├── emulator_guide.md
│ ├── environment.md
│ ├── example_tasks.md
│ ├── instructions.md
│ └── tasks_guide.md
├── examples/
│ ├── __init__.py
│ ├── run_acme_agent.py
│ ├── run_human_agent.py
│ └── run_random_agent.py
├── pyproject.toml
└── setup.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/workflows/tests.yml
================================================
name: tests
on:
push:
branches: [ main ]
pull_request:
branches: [ main ]
workflow_dispatch:
inputs:
git-ref:
description: Git Ref (Optional)
required: false
jobs:
build:
runs-on: ubuntu-latest
env:
TEST_TMPDIR: '/tmp'
strategy:
matrix:
python-version: ["3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v6
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
pip install --upgrade pip setuptools
python setup.py install
pip install .[testing]
- name: Run tests
run: |
# Find all test files, print their names and execute them in parallel
# with a maximum of 20 processes.
find . -type f -name "*_test.py" -print0 | xargs -t -0 -n1 -P 20 python3
================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute
## Pull Requests
Please send in fixes or feature additions through Pull Requests.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement. You (or your employer) retain the copyright to your contribution,
this simply gives us permission to use and redistribute your contributions as
part of the project. Head over to <https://cla.developers.google.com/> to see
your current agreements on file or to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# AndroidEnv - The Android Learning Environment
[AndroidEnv](https://github.com/deepmind/android_env) is a Python library that
exposes an [Android](https://www.android.com/) device as a Reinforcement
Learning (RL) environment. The library provides a flexible platform for defining
custom tasks on top of the Android Operating System, including any Android
application. Agents interact with the device through a universal action
interface - the touchscreen - by sending localized touch and lift events to the
system. The library processes these events and returns pixel observations and
rewards as provided by specific [task definitions](docs/tasks_guide.md). For
example, rewards might be given for events such as successfully scrolling down a
page, sending an email, or achieving some score in a game, depending on the
research purpose and how the user configures the task.
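[![Unittests](https://github.com/deepmind/android_env/actions/workflows/tests.yml/badge.svg)](https://github.com/deepmind/android_env/actions/workflows/tests.yml)
[![PyPI version](https://badge.fury.io/py/android-env.svg)](https://badge.fury.io/py/android-env)
[![Downloads](https://pepy.tech/badge/android-env)](https://pepy.tech/project/android-env)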
## Index
* [Environment details](docs/environment.md)
* [Running AndroidEnv](docs/instructions.md)
* [Setting up a virtual Android device](docs/emulator_guide.md)
* [Defining a task in AndroidEnv](docs/tasks_guide.md)
* [Example tasks available for download](docs/example_tasks.md)
## Environment features
There are a number of aspects that make AndroidEnv a challenging yet suitable
environment for Reinforcement Learning research:
* Allowing agents to interact with a system used daily by billions of users
around the world, AndroidEnv offers a platform for RL agents to navigate,
learn tasks and have direct impact in **real-world contexts**. The
environment wraps a simulated Android device, which runs independently from
the environment, completely unaltered, and works in exactly the same way as
the devices that humans use, exposing exactly the same features and
services.
* The platform offers a virtually infinite **range of possible tasks**, all
sharing a common action interface. The library facilitates the design of
Reinforcement Learning tasks for any existing or custom built Android
application. For example, it exposes the broad world of Android games,
ranging from card games, puzzle games, time reactive games, all requiring a
diverse set of action combinations and interaction types.
* The environment runs on top of a **real-time simulation** of an Android
device. In other words, the environment dynamics do not wait for the agent
to deliberate, and the speed of the simulation cannot be increased.
* The observation is a collection of **RGB values** corresponding to the
displayed pixels on the screen. The exact screen resolution depends on the
simulated device, but in general it will be considered relatively large in
an RL context. However, users have the option of downsampling each
observation.
* The learning environment has an interesting, **complex action space** unique
to the touchscreen interface of Android.
* The raw, **hybrid action space** consists of a continuous tuple
signifying the action location, and a discrete signal determining
whether the agent wants to touch the screen or lift its virtual finger.
* Raw actions are highly **composable**: the Android UI and most
applications were designed so that they could be intuitively navigated
via common
[touchscreen gestures](https://developer.android.com/training/gestures/detector)
such as tapping, scrolling, swiping, pinching, drag & drop etc. This is
still the case in AndroidEnv: to trigger meaningful changes in the
environment, the agent often has to perform carefully timed and
positioned sequences of raw actions. For example, in order to navigate
to the next image in a photo gallery, the agent would have to perform a
*swipe*, touching the screen multiple times, gradually shifting the
actions' positions to the right. Thus, in most contexts raw actions do
not trigger changes in the state of the environment unless correctly
chained together to make up a human gesture.
* The action interface is **closely related to the observation space**, as
meaningful touch and lift events are often either co-localized or
strongly correlated to the location or movement of salient objects in
the observation. For example, the position of a button on the screen
aligns with the location of the actions that trigger the button press.
* The library provides tools for flexibly **altering the action
interface** if needed for particular studies, such as discretization or
hard-coding gesture skills. Still, we believe that the real challenge
remains in devising agents that are capable of dealing with a large
suite of diverse tasks, through acting and learning in the complex
unifying action interface.
# Getting started
### Installation
The easiest way to get AndroidEnv is with pip:
```shell
$ python3 -m pip install android-env
```
Please note that `/examples` are not included in this package.
Alternatively, you can clone the repository from git's `main` branch:
```shell
$ git clone https://github.com/deepmind/android_env/
$ cd android_env
$ python3 setup.py install
```
Update: the environment now runs on Windows, but please keep in mind that this
option is not well-maintained or widely supported, as Unix-based systems are the
primary target platforms of this project.
### Create a simulator
Before running the environment, you will need access to an emulated Android
device. For instructions on creating a virtual Android device, see the
[Emulator guide](docs/emulator_guide.md).
### Define a task
Then, you will want to define what the agent's *task* is. At this point, the
agent will be able to communicate with the emulated device, but it will not yet
have an objective, or access to signals such as rewards or RL episode ends.
Learn [how to define an RL task](docs/tasks_guide.md) of your own, or use one of
the [existing task definitions](docs/example_tasks.md) for training.
### Load and run
To find out how to run and train agents on AndroidEnv, see these
[detailed instructions](docs/instructions.md). Here you can also find example
scripts demonstrating how to run a random agent, an
[acme](https://github.com/deepmind/acme) agent, or a human agent on AndroidEnv.
## About
This library is developed and maintained by [DeepMind](http://deepmind.com). \
You can find the [technical report](https://arxiv.org/abs/2105.13231) on arXiv,
as well as an introductory
[blog
post](https://www.deepmind.com/publications/androidenv-the-android-learning-environment)
on DeepMind's website.
If you use AndroidEnv in your research, you can cite the paper using the
following BibTeX:
```
@article{ToyamaEtAl2021AndroidEnv,
title = {{AndroidEnv}: A Reinforcement Learning Platform for Android},
author = {Daniel Toyama and Philippe Hamel and Anita Gergely and
Gheorghe Comanici and Amelia Glaese and Zafarali Ahmed and Tyler
Jackson and Shibl Mourad and Doina Precup},
year = {2021},
eprint = {2105.13231},
archivePrefix = {arXiv},
primaryClass = {cs.LG},
volume = {abs/2105.13231},
url = {http://arxiv.org/abs/2105.13231},
}
```
Disclaimer: This is not an official Google product.
================================================
FILE: android_env/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/apps/MODULE.bazel
================================================
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Bazel module (bzlmod) definition with the dependencies needed to build and
# test the Catch sample app.
module(
name = "catch",
version = "1.0",
)
# Rule sets used by the build: Android, Kotlin, Maven artifact resolution,
# Robolectric (exposed to BUILD files as @robolectric), Java and protobuf.
bazel_dep(name = "rules_android", version = "0.6.6")
bazel_dep(name = "rules_kotlin", version = "2.1.8")
bazel_dep(name = "rules_jvm_external", version = "6.7")
bazel_dep(name = "rules_robolectric", version = "4.16", repo_name = "robolectric")
bazel_dep(name = "rules_java", version = "9.0.3")
bazel_dep(name = "protobuf", version = "30.0")
# To avoid conflict with different protobuf versions: force a single protobuf
# version across the whole dependency graph.
single_version_override(
module_name = "protobuf",
version = "30.0",
)
# Maven artifacts are resolved through rules_jvm_external's `maven` extension.
maven = use_extension("@rules_jvm_external//:extensions.bzl", "maven")
# Need to set testonly = True because the package depends on testonly targets.
maven.artifact(
testonly = True,
artifact = "runner",
group = "androidx.test",
version = "1.7.0",
)
maven.artifact(
testonly = True,
artifact = "junit",
group = "androidx.test.ext",
version = "1.3.0",
)
maven.artifact(
testonly = True,
artifact = "mockito-kotlin",
group = "org.mockito.kotlin",
version = "6.1.0",
)
# Full list of Maven coordinates for the @maven repository.
# NOTE(review): androidx.test runner/junit and mockito-kotlin also appear in
# the maven.artifact(testonly = True) declarations above; keep the versions
# in both places in sync.
maven.install(
artifacts = [
"androidx.test.ext:junit:1.3.0",
"androidx.test:runner:1.7.0",
"com.google.guava:guava:32.0.1-jre",
"com.google.truth:truth:1.4.0",
"org.mockito.kotlin:mockito-kotlin:6.1.0",
"org.mockito:mockito-core:5.20.0",
"org.robolectric:robolectric:4.16",
"org.yaml:snakeyaml:2.5",
],
repositories = [
"https://maven.google.com",
"https://repo1.maven.org/maven2",
],
)
use_repo(maven, "maven")
# Android build tools repository (@android_tools) provided by rules_android.
remote_android_extensions = use_extension(
"@rules_android//bzlmod_extensions:android_extensions.bzl",
"remote_android_tools_extensions",
)
use_repo(remote_android_extensions, "android_tools")
# Android SDK repository (@androidsdk).
# NOTE(review): presumably resolved from the locally installed SDK
# (e.g. via ANDROID_HOME) — confirm.
android_sdk_repository_extension = use_extension("@rules_android//rules/android_sdk_repository:rule.bzl", "android_sdk_repository_extension")
use_repo(android_sdk_repository_extension, "androidsdk")
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarder.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
import android.accessibilityservice.AccessibilityService
import android.util.Log
import android.view.accessibility.AccessibilityEvent
import android.view.accessibility.AccessibilityNodeInfo
import android.view.accessibility.AccessibilityWindowInfo
import com.google.androidenv.accessibilityforwarder.A11yServiceGrpcKt.A11yServiceCoroutineStub
import io.grpc.ManagedChannel
import io.grpc.ManagedChannelBuilder
import io.grpc.ProxyDetector
import io.grpc.StatusException
import kotlinx.coroutines.TimeoutCancellationException
import kotlinx.coroutines.runBlocking
import kotlinx.coroutines.withTimeout
/**
* An Android service that listens to accessibility events and forwards them via gRPC.
*
* This service also logs the accessibility tree if [LogFlags.logAccessibilityTree] is set and if
* [LogFlags.grpcPort] is positive.
*
* Please see
* https://developer.android.com/reference/android/view/accessibility/AccessibilityEvent#getEventType()
* for a comprehensive list of events emitted by Android.
*/
class AccessibilityForwarder(
private val channelFactory: (host: String, port: Int) -> ManagedChannel = { host, port ->
ManagedChannelBuilder.forAddress(host, port)
.proxyDetector(ProxyDetector { _ -> null })
.usePlaintext()
.build()
}
) : AccessibilityService() {
init {
  // Background worker that periodically forwards the accessibility tree. It keeps looping while
  // LogFlags.a11yTreePeriodMs is positive; onInterrupt() sets it to zero, which ends the loop.
  val periodicTreeLogger = Thread {
    while (LogFlags.a11yTreePeriodMs > 0) {
      val completed =
        try {
          logAccessibilityTree()
          true
        } catch (e: ConcurrentModificationException) {
          // The tree changed while it was being read: retry right away, without sleeping.
          false
        }
      if (completed) {
        Thread.sleep(/* millis= */ LogFlags.a11yTreePeriodMs)
      }
    }
  }
  periodicTreeLogger.start()
}
// grpcStub has a backing property that can be reset to null, so that the next access to
// grpcStub builds a fresh channel (e.g. after a failed gRPC call).
private var _grpcStub: A11yServiceCoroutineStub? = null

/**
 * gRPC stub connected to [LogFlags.grpcHost]:[LogFlags.grpcPort].
 *
 * Built lazily on first access through `channelFactory` and cached until [resetGrpcStub] is
 * called.
 */
val grpcStub: A11yServiceCoroutineStub
  get() {
    // Read the mutable backing field exactly once. The previous null-check followed by `!!`
    // read it twice, so a reset in between would throw a NullPointerException.
    val cached = _grpcStub
    if (cached != null) {
      return cached
    }
    Log.i(TAG, "Building channel on ${LogFlags.grpcHost}:${LogFlags.grpcPort}.")
    val stub = A11yServiceCoroutineStub(channelFactory(LogFlags.grpcHost, LogFlags.grpcPort))
    _grpcStub = stub
    return stub
  }

/**
 * Drops the cached stub so that the next [grpcStub] access builds a new channel.
 *
 * NOTE(review): the channel behind the dropped stub is not shut down here. If `channelFactory`
 * returns a new ManagedChannel on every call, each reset leaves an idle channel behind. Consider
 * shutting it down once it is confirmed that no caller reuses the same channel instance.
 */
private fun resetGrpcStub() {
  _grpcStub = null
}
/** Stops periodic tree forwarding: the background loop only runs while the period is > 0. */
override fun onInterrupt() {
  LogFlags.a11yTreePeriodMs = 0L
}
/**
 * Logs an incoming accessibility [event].
 *
 * A null event is logged and ignored. Otherwise the event's extras are logged through
 * [logExtrasForEvent], followed by the event's type name when that name is non-empty.
 */
override fun onAccessibilityEvent(event: AccessibilityEvent?) {
  val received =
    event
      ?: run {
        Log.i(TAG, "`event` is null.")
        return
      }
  logExtrasForEvent(received)
  val typeName: String = AccessibilityEvent.eventTypeToString(received.eventType)
  if (typeName.isEmpty()) {
    return
  }
  Log.i(TAG, typeName)
}
/**
 * Builds the accessibility forest for the current windows and forwards it over gRPC.
 *
 * Returns early, after logging why, in three cases:
 * - [LogFlags.logAccessibilityTree] is false;
 * - the window list is unavailable;
 * - [LogFlags.grpcPort] is not positive.
 *
 * If the blocking gRPC call fails with a [StatusException] or exceeds its 1-second timeout, the
 * failure is logged and the cached stub is reset, so the next call builds a new channel.
 */
private fun logAccessibilityTree() {
  if (!LogFlags.logAccessibilityTree) {
    Log.i(TAG, "Not logging accessibility tree")
    return
  }
  val currentWindows =
    getWindowsOrNull()
      ?: run {
        Log.i(TAG, "windows is null.")
        return
      }
  // Bail out on a missing gRPC port before paying for the forest construction.
  if (LogFlags.grpcPort <= 0) {
    Log.w(TAG, "Can't log accessibility tree because gRPC port has not been set.")
    return
  }
  val forest = creator.buildForest(currentWindows)
  val timeoutMillis = 1000L
  try {
    val stub = grpcStub
    Log.i(TAG, "sending (blocking) gRPC request for tree.")
    // Blocks the calling thread until the RPC completes or the timeout elapses.
    val reply: ForestResponse =
      runBlocking { withTimeout(timeoutMillis) { stub.sendForest(forest) } }
    if (reply.error.isEmpty()) {
      Log.i(TAG, "gRPC request for tree succeeded.")
    } else {
      Log.w(TAG, "gRPC response.error: ${reply.error}")
    }
  } catch (e: StatusException) {
    Log.w(TAG, "gRPC StatusException; are you sure networking is turned on?")
    Log.i(TAG, "extra: exception ['$e']")
    resetGrpcStub()
  } catch (e: TimeoutCancellationException) {
    Log.w(TAG, "gRPC TimeoutCancellationException; are you sure networking is turned on?")
    Log.i(TAG, "extra: exception ['$e']")
    resetGrpcStub()
  }
}
/** Returns the service's current windows, or null when reading them throws a NullPointerException. */
private fun getWindowsOrNull(): List? {
  return try {
    windows
  } catch (npe: NullPointerException) {
    null
  }
}
/**
 * Collects the informative fields of [event] and, if there are any, sends them together with the
 * event timestamp to the gRPC server through `sendEvent`.
 *
 * Events without any informative field are dropped silently. When [LogFlags.grpcPort] is not
 * positive the extras are dropped with a warning.
 */
private fun logExtrasForEvent(event: AccessibilityEvent) {
  val events = collectEventExtras(event)
  if (events.isEmpty()) {
    return
  }
  events["event_timestamp_ms"] = event.eventTime.toString(10)
  // Check if we want to use gRPC.
  if (LogFlags.grpcPort > 0) {
    sendEventExtras(events)
  } else {
    Log.w(TAG, "Can't log accessibility event because gRPC port has not been set.")
  }
}

/**
 * Returns the fields of [event] and of its source node that carry information, keyed by the names
 * sent to the gRPC server. Null/empty strings and the sentinel values checked below are omitted.
 */
private fun collectEventExtras(event: AccessibilityEvent): MutableMap<String, String> {
  val events = mutableMapOf<String, String>()
  // Each `getSource()` call builds a new AccessibilityNodeInfo (and may query the system), so fetch
  // the source node once instead of once per field.
  val source = event.source

  val sourceDescription = source?.contentDescription()
  if (!sourceDescription.isNullOrEmpty()) {
    events["source_content_description"] = sourceDescription
  }
  // Output the event text.
  val eventText = event.text.joinToString(", ")
  if (eventText.isNotEmpty()) {
    events["event_text"] = eventText
  }
  // Output the source text.
  val sourceText = source?.text?.toString()
  if (!sourceText.isNullOrEmpty()) {
    events["source_text"] = sourceText
  }
  val eventTypeStr: String = AccessibilityEvent.eventTypeToString(event.eventType)
  if (eventTypeStr.isNotEmpty()) {
    events["event_type"] = eventTypeStr
  }
  val className = source?.className?.toString()
  if (!className.isNullOrEmpty()) {
    events["source_class_name"] = className
  }
  val packageName = event.packageName?.toString()
  if (!packageName.isNullOrEmpty()) {
    events["event_package_name"] = packageName
  }

  // Text editing properties (index/count fields are omitted when -1).
  val beforeText = event.beforeText?.toString()
  if (!beforeText.isNullOrEmpty()) {
    events["before_text"] = beforeText
  }
  if (event.fromIndex != -1) {
    events["from_index"] = event.fromIndex.toString()
  }
  if (event.toIndex != -1) {
    events["to_index"] = event.toIndex.toString()
  }
  if (event.addedCount != -1) {
    events["added_count"] = event.addedCount.toString()
  }
  if (event.removedCount != -1) {
    events["removed_count"] = event.removedCount.toString()
  }

  // Text traversal properties (omitted when 0).
  if (event.movementGranularity != 0) {
    events["movement_granularity"] = event.movementGranularity.toString()
  }
  if (event.action != 0) {
    events["action"] = event.action.toString()
  }

  // Scrolling properties.
  if (event.eventType == AccessibilityEvent.TYPE_VIEW_SCROLLED) {
    events["scroll_delta_x"] = event.scrollDeltaX.toString()
    events["scroll_delta_y"] = event.scrollDeltaY.toString()
  }

  // Report viewID so we know exactly where the event came from.
  val viewId = source?.viewIdResourceName?.toString()
  if (!viewId.isNullOrEmpty()) {
    events["view_id"] = viewId
  }
  return events
}

/**
 * Sends [events] to the gRPC server, waiting at most 1 second. gRPC status errors and timeouts are
 * logged and reset the cached stub so the next request rebuilds the channel.
 */
private fun sendEventExtras(events: Map<String, String>) {
  try {
    val grpcTimeoutMillis = 1000L
    val request = eventRequest { this.event.putAll(events) }
    val response: EventResponse =
      with(grpcStub) {
        Log.i(TAG, "sending (blocking) gRPC request for event.")
        runBlocking { withTimeout(grpcTimeoutMillis) { sendEvent(request) } }
      }
    if (response.error.isNotEmpty()) {
      Log.w(TAG, "gRPC response.error: ${response.error}")
    } else {
      Log.i(TAG, "gRPC request for event succeeded.")
    }
  } catch (e: StatusException) {
    Log.w(TAG, "gRPC StatusException; are you sure networking is turned on?")
    Log.i(TAG, "extra: exception ['$e']")
    resetGrpcStub()
  } catch (e: TimeoutCancellationException) {
    Log.w(TAG, "gRPC TimeoutCancellationException; are you sure networking is turned on?")
    Log.i(TAG, "extra: exception ['$e']")
    resetGrpcStub()
  }
}
/**
 * Walks from this node up through its ancestors to the root and joins every non-null content
 * description with ", ", starting with this node's own. Returns "" for a null receiver.
 */
private fun AccessibilityNodeInfo?.contentDescription(): String =
  generateSequence(this) { node -> node.parent }
    .mapNotNull { node -> node.contentDescription?.toString() }
    .joinToString(", ")
companion object {
  /** Shared forest builder used by logAccessibilityTree. */
  private val creator = AccessibilityTreeCreator()

  /** Log tag for this service's messages. */
  private const val TAG = "AndroidRLTask"
}
}
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityForwarderTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
import android.view.accessibility.AccessibilityEvent
import android.view.accessibility.AccessibilityNodeInfo
import android.view.accessibility.AccessibilityWindowInfo
import com.google.common.truth.Truth.assertThat
import io.grpc.Status
import io.grpc.StatusException
import io.grpc.inprocess.InProcessChannelBuilder
import io.grpc.inprocess.InProcessServerBuilder
import io.grpc.testing.GrpcCleanupRule
import org.junit.Assert.assertFalse
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestParameterInjector
import org.robolectric.Shadows.shadowOf
@RunWith(RobolectricTestParameterInjector::class)
class AccessibilityForwarderTest {
  @get:Rule(order = 1) val cleanupRule = GrpcCleanupRule()

  /**
   * In-process fake of the A11y gRPC service whose replies come from settable checkers.
   *
   * The checkers run on gRPC server threads. Whatever they throw reaches the forwarder only as a
   * gRPC error, which [AccessibilityForwarder] catches and logs, so an assertion failing inside a
   * checker cannot fail the test on its own. Such failures are recorded in [checkerFailures] and
   * re-thrown on the test thread by [verifyNoCheckerFailures].
   */
  class FakeAccessibilityService : A11yServiceGrpcKt.A11yServiceCoroutineImplBase() {
    var sendForestChecker: (AndroidAccessibilityForest) -> String = { _ -> "" }
    var sendEventChecker: (EventRequest) -> String = { _ -> "" }
    /** Number of `sendForest` requests received so far. */
    val sendForestCount = java.util.concurrent.atomic.AtomicInteger(0)
    /** Number of `sendEvent` requests received so far. */
    val sendEventCount = java.util.concurrent.atomic.AtomicInteger(0)
    /** Assertion failures thrown by the checkers, oldest first. */
    val checkerFailures = java.util.concurrent.ConcurrentLinkedQueue<AssertionError>()

    override suspend fun sendForest(request: AndroidAccessibilityForest) = forestResponse {
      sendForestCount.incrementAndGet()
      error = recordingFailures { sendForestChecker(request) }
    }

    override suspend fun sendEvent(request: EventRequest) = eventResponse {
      sendEventCount.incrementAndGet()
      error = recordingFailures { sendEventChecker(request) }
    }

    /** Runs [check]; an [AssertionError] it throws is remembered and then re-thrown. */
    private fun recordingFailures(check: () -> String): String =
      try {
        check()
      } catch (e: AssertionError) {
        checkerFailures.add(e)
        throw e
      }
  }

  protected lateinit var forwarder: AccessibilityForwarder
  protected val fakeA11yService = FakeAccessibilityService()
  protected val channel by lazy {
    val serverName: String = InProcessServerBuilder.generateName()
    cleanupRule.register(
      InProcessServerBuilder.forName(serverName)
        .directExecutor()
        .addService(fakeA11yService)
        .build()
        .start()
    )
    cleanupRule.register(InProcessChannelBuilder.forName(serverName).directExecutor().build())
  }

  /** Re-throws, on the test thread, the first assertion that failed inside a checker. */
  @org.junit.After
  fun verifyNoCheckerFailures() {
    fakeA11yService.checkerFailures.peek()?.let { throw it }
  }

  /** Initializes [forwarder] and [LogFlags] from the given args. */
  fun createForwarder(
    logAccessibilityTree: Boolean = false,
    a11yTreePeriodMs: Long = 0,
    grpcHost: String = "10.0.2.2",
    grpcPort: Int = 0,
    a11yWindows: MutableList<AccessibilityWindowInfo>? = null,
  ) {
    LogFlags.logAccessibilityTree = logAccessibilityTree
    LogFlags.a11yTreePeriodMs = a11yTreePeriodMs
    LogFlags.grpcHost = grpcHost
    LogFlags.grpcPort = grpcPort
    forwarder = AccessibilityForwarder({ _, _ -> channel })
    if (a11yWindows == null) {
      shadowOf(forwarder).setWindows(mutableListOf(AccessibilityWindowInfo.obtain()))
    } else {
      shadowOf(forwarder).setWindows(a11yWindows)
    }
  }

  @Test
  fun onInterrupt_doesNotCrash() {
    // Arrange.
    createForwarder(logAccessibilityTree = false)
    fakeA11yService.sendEventChecker = { _: EventRequest ->
      assertFalse(true) // This should not be called.
      "" // This should be unreachable
    }

    // Act.
    forwarder.onInterrupt()

    // Assert.
    assertThat(fakeA11yService.sendEventCount.get()).isEqualTo(0)
  }

  @Test
  fun onAccessibilityEvent_nullEventShouldBeIgnored() {
    // Arrange.
    createForwarder(logAccessibilityTree = false)
    fakeA11yService.sendEventChecker = { _: EventRequest ->
      assertFalse(true) // This should not be called.
      "" // This should be unreachable
    }

    // Act.
    forwarder.onAccessibilityEvent(null)

    // Assert.
    assertThat(fakeA11yService.sendEventCount.get()).isEqualTo(0)
  }

  @Test
  fun onAccessibilityEvent_knownEventWithNoInformationShouldNotBeEmitted() {
    // Arrange.
    createForwarder(logAccessibilityTree = false)
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.setContentDescription("")
    val event = AccessibilityEvent()
    shadowOf(event).setSourceNode(nodeInfo)
    fakeA11yService.sendEventChecker = { _: EventRequest ->
      assertFalse(true) // This should not be called.
      "" // This should be unreachable
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert.
    assertThat(fakeA11yService.sendEventCount.get()).isEqualTo(0)
  }

  @Test
  fun onAccessibilityEvent_typeViewClicked_sendEventViaGrpc() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcPort = 1234)
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.setContentDescription("My Content Description")
    nodeInfo.setText("My Source Text")
    nodeInfo.setClassName("AwesomeClass")
    val event = AccessibilityEvent()
    event.setEventTime(1357924680)
    event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED)
    event.getText().add("Some text!")
    event.setPackageName("some.loooong.package.name")
    shadowOf(event).setSourceNode(nodeInfo)
    fakeA11yService.sendEventChecker = { request: EventRequest ->
      // Check that all fields are consistent with how they were set above.
      assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_CLICKED")
      assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name")
      assertThat(request.eventMap.get("source_content_description"))
        .isEqualTo("My Content Description")
      assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text")
      assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass")
      assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!")
      assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680")
      // No error message
      ""
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert. Field checks live in `sendEventChecker`; make sure it actually ran.
    assertThat(fakeA11yService.sendEventCount.get()).isEqualTo(1)
  }

  @Test
  fun onAccessibilityEvent_typeViewTextChanged_ensureAllFieldsForwarded() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcPort = 1234)
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.setContentDescription("My Content Description")
    nodeInfo.setText("My Source Text")
    nodeInfo.setClassName("AwesomeClass")
    val event = AccessibilityEvent()
    event.setEventTime(1357924680)
    event.setEventType(AccessibilityEvent.TYPE_VIEW_TEXT_CHANGED)
    event.getText().add("Some text!")
    event.fromIndex = 7
    event.beforeText = "Old words"
    event.addedCount = 12
    event.removedCount = 9
    event.setPackageName("some.loooong.package.name")
    shadowOf(event).setSourceNode(nodeInfo)
    fakeA11yService.sendEventChecker = { request: EventRequest ->
      // Check that all fields are consistent with how they were set above.
      assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_TEXT_CHANGED")
      assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name")
      assertThat(request.eventMap.get("source_content_description"))
        .isEqualTo("My Content Description")
      assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text")
      assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass")
      assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!")
      assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680")
      assertThat(request.eventMap.get("from_index")).isEqualTo("7")
      assertThat(request.eventMap.get("before_text")).isEqualTo("Old words")
      assertThat(request.eventMap.get("added_count")).isEqualTo("12")
      assertThat(request.eventMap.get("removed_count")).isEqualTo("9")
      assertFalse(request.eventMap.containsKey("to_index"))
      assertFalse(request.eventMap.containsKey("view_id"))
      assertFalse(request.eventMap.containsKey("action"))
      assertFalse(request.eventMap.containsKey("movement_granularity"))
      assertFalse(request.eventMap.containsKey("scroll_delta_x"))
      assertFalse(request.eventMap.containsKey("scroll_delta_y"))
      // No error message
      ""
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert. Field checks live in `sendEventChecker`; make sure it actually ran.
    assertThat(fakeA11yService.sendEventCount.get()).isEqualTo(1)
  }

  @Test
  fun onAccessibilityEvent_typeViewScrolled_ensureAllFieldsForwarded() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcPort = 1234)
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.setContentDescription("My Content Description")
    nodeInfo.setText("My Source Text")
    nodeInfo.setClassName("AwesomeClass")
    val event = AccessibilityEvent()
    event.setEventTime(1357924680)
    event.setEventType(AccessibilityEvent.TYPE_VIEW_SCROLLED)
    event.getText().add("Some text!")
    event.scrollDeltaX = 13
    event.scrollDeltaY = 27
    event.setPackageName("some.loooong.package.name")
    shadowOf(event).setSourceNode(nodeInfo)
    fakeA11yService.sendEventChecker = { request: EventRequest ->
      // Check that all fields are consistent with how they were set above.
      assertThat(request.eventMap.get("event_type")).isEqualTo("TYPE_VIEW_SCROLLED")
      assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name")
      assertThat(request.eventMap.get("source_content_description"))
        .isEqualTo("My Content Description")
      assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text")
      assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass")
      assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!")
      assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680")
      assertThat(request.eventMap.get("scroll_delta_x")).isEqualTo("13")
      assertThat(request.eventMap.get("scroll_delta_y")).isEqualTo("27")
      assertFalse(request.eventMap.containsKey("from_index"))
      assertFalse(request.eventMap.containsKey("to_index"))
      assertFalse(request.eventMap.containsKey("before_text"))
      assertFalse(request.eventMap.containsKey("added_count"))
      assertFalse(request.eventMap.containsKey("removed_count"))
      // No error message
      ""
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert. Field checks live in `sendEventChecker`; make sure it actually ran.
    assertThat(fakeA11yService.sendEventCount.get()).isEqualTo(1)
  }

  @Test
  fun onAccessibilityEvent_typeViewTextTraversedAtMovementGranularity_ensureAllFieldsForwarded() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcPort = 1234)
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.setContentDescription("My Content Description")
    nodeInfo.setText("My Source Text")
    nodeInfo.setClassName("AwesomeClass")
    nodeInfo.viewIdResourceName = "this.big.old.view.id"
    val event = AccessibilityEvent()
    event.setEventTime(1357924680)
    event.setEventType(AccessibilityEvent.TYPE_VIEW_TEXT_TRAVERSED_AT_MOVEMENT_GRANULARITY)
    event.getText().add("Some text!")
    event.setPackageName("some.loooong.package.name")
    event.movementGranularity = 5
    event.fromIndex = 6
    event.toIndex = 8
    event.action = 23
    shadowOf(event).setSourceNode(nodeInfo)
    fakeA11yService.sendEventChecker = { request: EventRequest ->
      // Check that all fields are consistent with how they were set above.
      assertThat(request.eventMap.get("event_type"))
        .isEqualTo("TYPE_VIEW_TEXT_TRAVERSED_AT_MOVEMENT_GRANULARITY")
      assertThat(request.eventMap.get("event_package_name")).isEqualTo("some.loooong.package.name")
      assertThat(request.eventMap.get("source_content_description"))
        .isEqualTo("My Content Description")
      assertThat(request.eventMap.get("source_text")).isEqualTo("My Source Text")
      assertThat(request.eventMap.get("source_class_name")).isEqualTo("AwesomeClass")
      assertThat(request.eventMap.get("event_text")).isEqualTo("Some text!")
      assertThat(request.eventMap.get("event_timestamp_ms")).isEqualTo("1357924680")
      assertThat(request.eventMap.get("movement_granularity")).isEqualTo("5")
      assertThat(request.eventMap.get("from_index")).isEqualTo("6")
      assertThat(request.eventMap.get("to_index")).isEqualTo("8")
      assertThat(request.eventMap.get("view_id")).isEqualTo("this.big.old.view.id")
      assertThat(request.eventMap.get("action")).isEqualTo("23")
      // No error message
      ""
    }

    // Act.
    forwarder.onAccessibilityEvent(event)

    // Assert. Field checks live in `sendEventChecker`; make sure it actually ran.
    assertThat(fakeA11yService.sendEventCount.get()).isEqualTo(1)
  }

  @Test
  fun onAccessibilityEvent_sendingevent_grpcTimeout() {
    // Arrange.
    createForwarder(
      logAccessibilityTree = false,
      a11yTreePeriodMs = 0,
      grpcHost = "amazing.host",
      grpcPort = 4321,
    )
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.setContentDescription("My Content Description")
    nodeInfo.setText("My Source Text")
    nodeInfo.setClassName("AwesomeClass")
    val event = AccessibilityEvent()
    event.setEventTime(1357924680)
    event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED)
    event.getText().add("Some text!")
    event.setPackageName("some.loooong.package.name")
    shadowOf(event).setSourceNode(nodeInfo)
    fakeA11yService.sendEventChecker = { _ ->
      // Delay the request to prompt a timeout
      Thread.sleep(1500L)
      "" // Return no error.
    }

    // Act.
    forwarder.onAccessibilityEvent(event)
    // Run a second request to ensure that the channel gets rebuilt.
    fakeA11yService.sendEventChecker = { _ -> "" }
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // The forwarder must swallow the timeout; see `sendEventChecker` above.
  }

  @Test
  fun onAccessibilityEvent_sendingevent_grpcStatusException() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, grpcHost = "amazing.host", grpcPort = 4321)
    val nodeInfo = AccessibilityNodeInfo()
    nodeInfo.setContentDescription("My Content Description")
    nodeInfo.setText("My Source Text")
    nodeInfo.setClassName("AwesomeClass")
    val event = AccessibilityEvent()
    event.setEventTime(1357924680)
    event.setEventType(AccessibilityEvent.TYPE_VIEW_CLICKED)
    event.getText().add("Some text!")
    event.setPackageName("some.loooong.package.name")
    shadowOf(event).setSourceNode(nodeInfo)
    fakeA11yService.sendEventChecker = { _ -> throw StatusException(Status.UNAVAILABLE) }

    // Act.
    forwarder.onAccessibilityEvent(event)
    // Run a second request to ensure that the channel gets rebuilt.
    fakeA11yService.sendEventChecker = { _ -> "" }
    forwarder.onAccessibilityEvent(event)

    // Assert.
    // The forwarder must swallow the StatusException; see `sendEventChecker` above.
  }

  @Test
  fun logAccessibilityTreeFalse_doesNotLogAccessibilityTree() {
    // Arrange.
    createForwarder(logAccessibilityTree = false, a11yTreePeriodMs = 10, grpcPort = 13579)
    fakeA11yService.sendForestChecker = { _: AndroidAccessibilityForest ->
      assertFalse(true) // This should not be called.
      "" // This should be unreachable
    }

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    assertThat(fakeA11yService.sendForestCount.get()).isEqualTo(0)
  }

  @Test
  fun grpcPortZero_doesNotSendTree() {
    // Arrange.
    createForwarder(logAccessibilityTree = true, a11yTreePeriodMs = 10, grpcPort = 0)
    fakeA11yService.sendForestChecker = { _: AndroidAccessibilityForest ->
      assertFalse(true) // This should not be called.
      "" // This should be unreachable
    }

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    assertThat(fakeA11yService.sendForestCount.get()).isEqualTo(0)
  }

  @Test
  fun grpcPortPositive_shouldSendTreeViaGrpc() {
    // Arrange.
    val window = AccessibilityWindowInfo()
    shadowOf(window).setType(AccessibilityWindowInfo.TYPE_SYSTEM)
    createForwarder(
      logAccessibilityTree = true,
      a11yTreePeriodMs = 10,
      grpcPort = 1234,
      a11yWindows = mutableListOf(window),
    )
    fakeA11yService.sendForestChecker = { request: AndroidAccessibilityForest ->
      // Check that we get only a single window.
      assertThat(request.windowsList.size).isEqualTo(1)
      // And that its type is what we set above.
      assertThat(request.windowsList[0].windowType)
        .isEqualTo(AndroidAccessibilityWindowInfo.WindowType.TYPE_SYSTEM)
      // The error message
      "Something went wrong!"
    }

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // See `sendForestChecker` above; failures are re-thrown by verifyNoCheckerFailures().
  }

  @Test
  fun grpcPortPositiveAndHost_shouldSendTreeViaGrpc() {
    // Arrange.
    fakeA11yService.sendForestChecker = { request: AndroidAccessibilityForest ->
      // Check that we get only a single window.
      assertThat(request.windowsList.size).isEqualTo(1)
      // And that its type is what we set above.
      assertThat(request.windowsList[0].windowType)
        .isEqualTo(AndroidAccessibilityWindowInfo.WindowType.TYPE_ACCESSIBILITY_OVERLAY)
      "" // Return no error.
    }
    val window = AccessibilityWindowInfo()
    shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    createForwarder(
      logAccessibilityTree = true,
      a11yTreePeriodMs = 500,
      grpcHost = "amazing.host",
      grpcPort = 4321,
      a11yWindows = mutableListOf(window),
    )

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // See `sendForestChecker` above; failures are re-thrown by verifyNoCheckerFailures().
  }

  @Test
  fun sendingForest_grpcTimeout() {
    // Arrange.
    fakeA11yService.sendForestChecker = { _ ->
      // Delay the request to prompt a timeout
      Thread.sleep(1500L)
      "" // Return no error.
    }
    val window = AccessibilityWindowInfo()
    shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    createForwarder(
      logAccessibilityTree = true,
      a11yTreePeriodMs = 10,
      grpcHost = "amazing.host",
      grpcPort = 4321,
      a11yWindows = mutableListOf(window),
    )

    // Act.
    Thread.sleep(2000) // Sleep a bit to give time to trigger the tree logging function.
    // Run a second request to ensure that the channel gets rebuilt.
    fakeA11yService.sendForestChecker = { _ -> "" }
    Thread.sleep(2000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // The forwarder must swallow the timeout; see `sendForestChecker` above.
  }

  @Test
  fun sendingForest_grpcStatusException() {
    // Arrange.
    val window = AccessibilityWindowInfo()
    shadowOf(window).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    createForwarder(
      logAccessibilityTree = true,
      a11yTreePeriodMs = 10,
      grpcHost = "amazing.host",
      grpcPort = 4321,
      a11yWindows = mutableListOf(window),
    )
    fakeA11yService.sendForestChecker = { _ -> throw StatusException(Status.UNAVAILABLE) }

    // Act.
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.
    // Run a second request to ensure that the channel gets rebuilt.
    fakeA11yService.sendForestChecker = { _ -> "" }
    Thread.sleep(1000) // Sleep a bit to give time to trigger the tree logging function.

    // Assert.
    // The forwarder must swallow the StatusException; see `sendForestChecker` above.
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreator.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
import android.graphics.Rect
import android.util.Log
import android.view.accessibility.AccessibilityNodeInfo
import android.view.accessibility.AccessibilityWindowInfo
import com.google.androidenv.accessibilityforwarder.AndroidAccessibilityWindowInfo.WindowType
import java.util.concurrent.ConcurrentHashMap
import java.util.stream.Collectors
import kotlin.collections.mutableListOf
import kotlinx.coroutines.Deferred
import kotlinx.coroutines.async
import kotlinx.coroutines.awaitAll
import kotlinx.coroutines.runBlocking
/** Helper methods for creating the android accessibility info extra. */
class AccessibilityTreeCreator() {

  /**
   * Creates an accessibility forest proto from [windowInfos].
   *
   * Windows appear in the forest in reverse order of [windowInfos]. Blocks the calling thread
   * until every window has been converted.
   */
  fun buildForest(windowInfos: List<AccessibilityWindowInfo>): AndroidAccessibilityForest {
    // Filled by processNode() with proto -> source-node pairs.
    // NOTE(review): this map is never read after the forest is built; it could be dropped.
    val sourcesMap = ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>()
    val windowProtos = processWindowsAndBlock(windowInfos, sourcesMap)
    return androidAccessibilityForest { this.windows += windowProtos }
  }

  /** Runs [processWindows] to completion on the calling thread. */
  private fun processWindowsAndBlock(
    windowInfos: List<AccessibilityWindowInfo>,
    sourcesMap: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  ): List<AndroidAccessibilityWindowInfo> = runBlocking { processWindows(windowInfos, sourcesMap) }

  /** Converts [windowInfos] to protos, last window first, skipping windows that yield null. */
  private suspend fun processWindows(
    windowInfos: List<AccessibilityWindowInfo>,
    sourcesMap: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  ): List<AndroidAccessibilityWindowInfo> =
    windowInfos.asReversed().mapNotNull { windowInfo -> processWindow(windowInfo, sourcesMap) }

  /**
   * Converts a single window (and, if it has a root, its whole node tree) into a proto. A window
   * without a root gets an empty tree.
   */
  private suspend fun processWindow(
    windowInfo: AccessibilityWindowInfo,
    sources: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  ): AndroidAccessibilityWindowInfo? {
    val bounds = Rect()
    windowInfo.getBoundsInScreen(bounds)
    val root: AccessibilityNodeInfo? = windowInfo.root
    // Called directly: the previous `runBlocking { async { ... } }` wrapper ran on this same thread
    // and was awaited immediately, so it only nested a blocking call inside a coroutine.
    val nodeTree =
      if (root == null) {
        Log.i(TAG, "window root is null")
        androidAccessibilityTree {}
      } else {
        processNodesInWindow(root, sources)
      }
    return androidAccessibilityWindowInfo {
      this.tree = nodeTree
      this.isActive = windowInfo.isActive
      this.id = windowInfo.id
      this.layer = windowInfo.layer
      this.isAccessibilityFocused = windowInfo.isAccessibilityFocused
      this.isFocused = windowInfo.isFocused
      this.boundsInScreen = convertToRectProto(bounds)
      this.windowType = toWindowType(windowInfo.type)
    }
  }

  /**
   * Traverses the tree under [root] breadth-first and converts every reachable node into a proto.
   * Node ids are unique within this window; each node records its depth (the root has depth 0).
   */
  private suspend fun processNodesInWindow(
    root: AccessibilityNodeInfo,
    sources: ConcurrentHashMap<AndroidAccessibilityNodeInfo, AccessibilityNodeInfo>,
  ): AndroidAccessibilityTree {
    Log.d(TAG, "processNodesInWindow()")
    val traversalQueue = ArrayDeque<ParentChildNodePair>()
    traversalQueue.add(ParentChildNodePair.builder().child(root).build())
    val uniqueIdsCache = UniqueIdsGenerator()
    var currentDepth = 0
    val nodesDeferred = mutableListOf<Deferred<AndroidAccessibilityNodeInfo>>()
    // Nodes already enqueued; a node reachable through several parents is visited only once.
    val seenNodes = HashSet<AccessibilityNodeInfo>()
    seenNodes.add(root)
    // NOTE(review): this runBlocking body never suspends, so the `async` blocks only start after
    // the traversal finishes and then run one after another on this thread. They add no
    // parallelism; calling processNode directly would yield the same nodes in the same order.
    runBlocking {
      while (traversalQueue.isNotEmpty()) {
        // Traverse the tree layer-by-layer: the queue currently holds exactly one layer.
        repeat(traversalQueue.size) {
          val nodePair: ParentChildNodePair = traversalQueue.removeFirst()
          val parent = nodePair.child()
          for (i in 0 until parent.childCount) {
            val childNode: AccessibilityNodeInfo? = parent.getChild(i)
            // `add` returns false for nodes that were already seen.
            if (childNode != null && seenNodes.add(childNode)) {
              traversalQueue.add(
                ParentChildNodePair.builder().child(childNode).parent(parent).build()
              )
            }
          }
          val thisDepth = currentDepth
          nodesDeferred.add(async { processNode(nodePair, sources, uniqueIdsCache, thisDepth) })
        }
        currentDepth++
      }
    }
    return androidAccessibilityTree { this.nodes += nodesDeferred.awaitAll() }
  }

  companion object {
    private const val TAG = "AndroidRLTask"
  }
}
/**
 * Converts the child node of [nodePair] into a proto and records the proto -> source-node pair in
 * [sourceBuilder].
 *
 * @param uniqueIdsCache supplies the ids of the node and of its children.
 * @param nodeDepth depth of the node below the window root.
 */
private fun processNode(
  nodePair: ParentChildNodePair,
  sourceBuilder: ConcurrentHashMap,
  uniqueIdsCache: UniqueIdsGenerator,
  nodeDepth: Int,
): AndroidAccessibilityNodeInfo {
  val source: AccessibilityNodeInfo = nodePair.child()
  val proto =
    createAndroidAccessibilityNode(
      node = source,
      nodeId = uniqueIdsCache.getUniqueId(source),
      depth = nodeDepth,
      childIds = getChildUniqueIds(source, uniqueIdsCache),
    )
  sourceBuilder[proto] = source
  return proto
}
/**
 * Builds the proto description of [node].
 *
 * @param nodeId unique id assigned to [node].
 * @param depth depth of [node] below its window root.
 * @param childIds unique ids of the children of [node].
 */
private fun createAndroidAccessibilityNode(
  node: AccessibilityNodeInfo,
  nodeId: Int,
  depth: Int,
  childIds: List,
): AndroidAccessibilityNodeInfo {
  val screenBounds = Rect().also { rect -> node.getBoundsInScreen(rect) }
  val protoActions = node.actionList.stream().map(::createAction).collect(Collectors.toList())
  return androidAccessibilityNodeInfo {
    // Identity and position.
    this.uniqueId = nodeId
    this.depth = depth
    this.childIds += childIds
    this.windowId = node.windowId
    this.drawingOrder = node.drawingOrder
    this.boundsInScreen = convertToRectProto(screenBounds)
    // Textual properties; null values are stored as "".
    this.className = stringFromNullableCharSequence(node.className)
    this.packageName = stringFromNullableCharSequence(node.packageName)
    this.text = stringFromNullableCharSequence(node.text)
    this.contentDescription = stringFromNullableCharSequence(node.contentDescription)
    this.hintText = stringFromNullableCharSequence(node.hintText)
    this.tooltipText = stringFromNullableCharSequence(node.tooltipText)
    this.viewIdResourceName = node.viewIdResourceName ?: ""
    this.textSelectionStart = node.textSelectionStart.toLong()
    this.textSelectionEnd = node.textSelectionEnd.toLong()
    // Boolean state flags.
    this.isCheckable = node.isCheckable
    this.isChecked = node.isChecked
    this.isClickable = node.isClickable
    this.isLongClickable = node.isLongClickable
    this.isEditable = node.isEditable
    this.isEnabled = node.isEnabled
    this.isFocusable = node.isFocusable
    this.isPassword = node.isPassword
    this.isScrollable = node.isScrollable
    this.isSelected = node.isSelected
    this.isVisibleToUser = node.isVisibleToUser
    // Available actions.
    this.actions += protoActions
  }
}
/** Converts an Android accessibility action into its proto form; a null label becomes "". */
private fun createAction(
  action: AccessibilityNodeInfo.AccessibilityAction
): AndroidAccessibilityAction {
  val actionLabel = stringFromNullableCharSequence(action.label)
  return AndroidAccessibilityAction.newBuilder()
    .apply {
      setId(action.id)
      setLabel(actionLabel)
    }
    .build()
}
/** Returns the unique ids of the children of [node], in child order, skipping null children. */
private fun getChildUniqueIds(
  node: AccessibilityNodeInfo,
  uniqueIdsCache: UniqueIdsGenerator,
): List {
  return (0 until node.childCount).mapNotNull { index ->
    node.getChild(index)?.let { child -> uniqueIdsCache.getUniqueId(child) }
  }
}
/** Returns [cs] as a String, or "" when it is null. */
fun stringFromNullableCharSequence(cs: CharSequence?): String = cs?.toString().orEmpty()
/** Copies the four edges of [rect] into a rect proto. */
fun convertToRectProto(rect: Rect) =
  rect.let { src ->
    protoRect {
      top = src.top
      left = src.left
      bottom = src.bottom
      right = src.right
    }
  }
/** Android window-type constants that have a dedicated [WindowType] counterpart. */
private val WINDOW_TYPES_BY_ANDROID_TYPE =
  mapOf(
    AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY to WindowType.TYPE_ACCESSIBILITY_OVERLAY,
    AccessibilityWindowInfo.TYPE_APPLICATION to WindowType.TYPE_APPLICATION,
    AccessibilityWindowInfo.TYPE_INPUT_METHOD to WindowType.TYPE_INPUT_METHOD,
    AccessibilityWindowInfo.TYPE_SYSTEM to WindowType.TYPE_SYSTEM,
    AccessibilityWindowInfo.TYPE_SPLIT_SCREEN_DIVIDER to WindowType.TYPE_SPLIT_SCREEN_DIVIDER,
  )

/** Maps an [AccessibilityWindowInfo] type constant to [WindowType], defaulting to UNKNOWN_TYPE. */
private fun toWindowType(type: Int): WindowType =
  WINDOW_TYPES_BY_ANDROID_TYPE[type] ?: WindowType.UNKNOWN_TYPE
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AccessibilityTreeCreatorTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
import android.view.accessibility.AccessibilityNodeInfo
import android.view.accessibility.AccessibilityWindowInfo
import kotlin.test.assertEquals
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
import org.robolectric.Shadows.shadowOf
@RunWith(RobolectricTestRunner::class)
class AccessibilityTreeCreatorTest {
  // Note: kotlin.test.assertEquals takes (expected, actual).

  @Test
  fun buildForest_buildsAccessibilityForestCorrectly() {
    val creator = AccessibilityTreeCreator()
    val forest = creator.buildForest(mutableListOf(createWindowInfo()))
    assertEquals(1, forest.windowsCount)
    assertEquals(3, forest.getWindows(0).tree.nodesCount)
    val nodes = forest.getWindows(0).tree.nodesList
    // `single` fails loudly if the node is missing or duplicated.
    val rootNode = nodes.single { it.text == "root node" }
    val checkableNode = nodes.single { it.isCheckable }
    assertEquals(2, rootNode.childIdsCount)
    assertEquals("Check box", checkableNode.text)
  }

  @Test
  fun buildForest_noRootInWindow_returnsEmptyTree() {
    val creator = AccessibilityTreeCreator()
    val windowInfo = AccessibilityWindowInfo.obtain()
    shadowOf(windowInfo).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    val forest = creator.buildForest(mutableListOf(windowInfo))
    assertEquals(0, forest.getWindows(0).tree.nodesList.size)
  }

  /** Builds a root node with an editable child and a checkable "Check box" child. */
  private fun createAccessibilityNodeInfo(): AccessibilityNodeInfo {
    val root = AccessibilityNodeInfo.obtain()
    root.text = "root node"
    root.isClickable = true
    val accessibilityNodeInfo = AccessibilityNodeInfo.obtain()
    accessibilityNodeInfo.viewIdResourceName = "test"
    accessibilityNodeInfo.isClickable = true
    accessibilityNodeInfo.isEditable = true
    accessibilityNodeInfo.hintText = "Please enter your address"
    shadowOf(root).addChild(accessibilityNodeInfo)
    val anotherChildNode = AccessibilityNodeInfo.obtain()
    anotherChildNode.isCheckable = true
    anotherChildNode.text = "Check box"
    shadowOf(root).addChild(anotherChildNode)
    return root
  }

  /** Builds an accessibility-overlay window whose root is [createAccessibilityNodeInfo]. */
  private fun createWindowInfo(): AccessibilityWindowInfo {
    val windowInfo = AccessibilityWindowInfo.obtain()
    shadowOf(windowInfo).setType(AccessibilityWindowInfo.TYPE_ACCESSIBILITY_OVERLAY)
    shadowOf(windowInfo).setRoot(createAccessibilityNodeInfo())
    return windowInfo
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest.xml
================================================
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/AndroidManifest_lite.xml
================================================
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiver.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
import android.content.BroadcastReceiver
import android.content.Context
import android.content.Intent
import android.util.Log
/**
 * Broadcast receiver that updates [LogFlags] at runtime.
 *
 * Recognized actions:
 * - ENABLE_ACCESSIBILITY_TREE_LOGS / DISABLE_ACCESSIBILITY_TREE_LOGS toggle tree logging.
 * - SET_GRPC reads an optional "host" string extra and an optional "port" int extra and stores
 *   them as the gRPC endpoint.
 *
 * Any other action, including a missing intent or action, only logs a warning.
 */
class FlagsBroadcastReceiver() : BroadcastReceiver() {
  override fun onReceive(context: Context?, intent: Intent?) {
    val action: String? = intent?.action
    Log.i(TAG, "Received broadcast intent with action: $action")
    when {
      action == ACTION_ENABLE_ACCESSIBILITY_TREE_LOGS -> setTreeLogging(enabled = true)
      action == ACTION_DISABLE_ACCESSIBILITY_TREE_LOGS -> setTreeLogging(enabled = false)
      // A matching action implies a non-null intent; the explicit check lets the compiler
      // smart-cast it for the helper call.
      action == ACTION_SET_GRPC && intent != null -> updateGrpcEndpoint(intent)
      else -> Log.w(TAG, "Unknown action: ${action}")
    }
  }

  /** Logs the transition and then stores [enabled] in [LogFlags.logAccessibilityTree]. */
  private fun setTreeLogging(enabled: Boolean) {
    val message =
      if (enabled) "Enabling Accessibility Tree logging." else "Disabling Accessibility Tree logging."
    Log.i(TAG, message)
    LogFlags.logAccessibilityTree = enabled
  }

  /** Copies the gRPC host/port extras of [intent] into [LogFlags], applying defaults. */
  private fun updateGrpcEndpoint(intent: Intent) {
    // Inside the Android Emulator, 10.0.2.2 is an alias for the host workstation. The gRPC server
    // usually runs there, so it is the default host.
    // See https://developer.android.com/studio/run/emulator-networking#networkaddresses.
    val host = intent.getStringExtra("host") ?: "10.0.2.2"
    // A port <= 0 (the default) disables gRPC.
    val port = intent.getIntExtra("port", 0)
    Log.i(TAG, "Setting gRPC endpoint to ${host}:${port}.")
    LogFlags.grpcHost = host
    LogFlags.grpcPort = port
  }

  companion object {
    private const val TAG = "FlagsBroadcastReceiver"
    private const val ACTION_ENABLE_ACCESSIBILITY_TREE_LOGS =
      "accessibility_forwarder.intent.action.ENABLE_ACCESSIBILITY_TREE_LOGS"
    private const val ACTION_DISABLE_ACCESSIBILITY_TREE_LOGS =
      "accessibility_forwarder.intent.action.DISABLE_ACCESSIBILITY_TREE_LOGS"
    private const val ACTION_SET_GRPC = "accessibility_forwarder.intent.action.SET_GRPC"
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/FlagsBroadcastReceiverTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
import android.content.Intent
import com.google.common.truth.Truth.assertThat
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
/**
 * Tests for [FlagsBroadcastReceiver].
 *
 * [LogFlags] is a process-wide singleton, so every test sets the flags it inspects in its Arrange
 * step instead of relying on state left by other tests.
 */
@RunWith(RobolectricTestRunner::class)
class FlagsBroadcastReceiverTest {
  @Test
  fun onReceive_nullIntent_shouldNotLogAnything() {
    // Arrange.
    LogFlags.logAccessibilityTree = false
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    // Act.
    receiver.onReceive(context = null, intent = null)
    // Assert: no flag, including the gRPC endpoint, is modified.
    assertThat(LogFlags.logAccessibilityTree).isFalse()
    assertThat(LogFlags.grpcHost).isEqualTo("some_host")
    assertThat(LogFlags.grpcPort).isEqualTo(9999)
  }

  @Test
  fun onReceive_nullIntent_actionShouldNotLogAnything() {
    // Arrange.
    LogFlags.logAccessibilityTree = false
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent()
    // Act.
    receiver.onReceive(context = null, intent = intent)
    // Assert: an intent without an action modifies no flag.
    assertThat(LogFlags.logAccessibilityTree).isFalse()
    assertThat(LogFlags.grpcHost).isEqualTo("some_host")
    assertThat(LogFlags.grpcPort).isEqualTo(9999)
  }

  @Test
  fun onReceive_unknownIntent_actionShouldIssueWarning() {
    // Arrange.
    LogFlags.logAccessibilityTree = false
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent("SOME_WEIRD_ACTION")
    // Act.
    receiver.onReceive(context = null, intent = intent)
    // Assert: an unknown action modifies no flag.
    assertThat(LogFlags.logAccessibilityTree).isFalse()
    assertThat(LogFlags.grpcHost).isEqualTo("some_host")
    assertThat(LogFlags.grpcPort).isEqualTo(9999)
  }

  @Test
  fun onReceive_intentWithDisableAction_shouldDisableTreeLogging() {
    // Arrange.
    LogFlags.logAccessibilityTree = true
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent("accessibility_forwarder.intent.action.DISABLE_ACCESSIBILITY_TREE_LOGS")
    // Act.
    receiver.onReceive(context = null, intent = intent)
    // Assert.
    assertThat(LogFlags.logAccessibilityTree).isFalse()
  }

  @Test
  fun onReceive_intentWithEnableAction_shouldEnableTreeLogging() {
    // Arrange.
    LogFlags.logAccessibilityTree = false
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent("accessibility_forwarder.intent.action.ENABLE_ACCESSIBILITY_TREE_LOGS")
    // Act.
    receiver.onReceive(context = null, intent = intent)
    // Assert.
    assertThat(LogFlags.logAccessibilityTree).isTrue()
  }

  @Test
  fun onReceive_intentWithSetGrpcActionNoArgs_shouldDefaultToEmuIpAndPortZero() {
    // Arrange.
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent = Intent("accessibility_forwarder.intent.action.SET_GRPC")
    // Act.
    receiver.onReceive(context = null, intent = intent)
    // Assert.
    assertThat(LogFlags.grpcHost).isEqualTo("10.0.2.2")
    assertThat(LogFlags.grpcPort).isEqualTo(0)
  }

  @Test
  fun onReceive_intentWithSetGrpcActionWithHostNoPort_shouldDefaultPortToZero() {
    // Arrange.
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent =
      Intent("accessibility_forwarder.intent.action.SET_GRPC").apply {
        putExtra("host", "awesome.server.ca")
      }
    // Act.
    receiver.onReceive(context = null, intent = intent)
    // Assert.
    assertThat(LogFlags.grpcHost).isEqualTo("awesome.server.ca")
    assertThat(LogFlags.grpcPort).isEqualTo(0)
  }

  @Test
  fun onReceive_intentWithSetGrpcActionWithPortNoHost_shouldDefaultHostToEmuIp() {
    // Arrange.
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent =
      Intent("accessibility_forwarder.intent.action.SET_GRPC").apply { putExtra("port", 54321) }
    // Act.
    receiver.onReceive(context = null, intent = intent)
    // Assert.
    assertThat(LogFlags.grpcHost).isEqualTo("10.0.2.2")
    assertThat(LogFlags.grpcPort).isEqualTo(54321)
  }

  @Test
  fun onReceive_intentWithSetGrpcActionWithHostAndPort_shouldSetBoth() {
    // Arrange.
    LogFlags.grpcHost = "some_host"
    LogFlags.grpcPort = 9999
    val receiver = FlagsBroadcastReceiver()
    val intent =
      Intent("accessibility_forwarder.intent.action.SET_GRPC").apply {
        putExtra("host", "grpc.ca")
        putExtra("port", 54321)
      }
    // Act.
    receiver.onReceive(context = null, intent = intent)
    // Assert.
    assertThat(LogFlags.grpcHost).isEqualTo("grpc.ca")
    assertThat(LogFlags.grpcPort).isEqualTo(54321)
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/LogFlags.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
/**
 * Process-wide, mutable settings for the AccessibilityForwarder.
 *
 * The properties are plain fields with no synchronization, so this object is not thread safe.
 */
object LogFlags {
  /** When `true`, accessibility trees are logged. */
  var logAccessibilityTree: Boolean = false

  /** How often, in milliseconds, an a11y tree is emitted. */
  var a11yTreePeriodMs: Long = 100

  /** Host of the gRPC server to connect to; only used when [grpcPort] is greater than zero. */
  var grpcHost: String = ""

  /** Port of the gRPC server; any value <= 0 means gRPC is disabled. */
  var grpcPort: Int = 0
}
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/ParentChildNodePair.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
import android.view.accessibility.AccessibilityNodeInfo
import com.google.auto.value.AutoValue
/**
 * Parent and child [AccessibilityNodeInfo] relationship.
 *
 * The concrete class (`AutoValue_ParentChildNodePair`) is generated by the AutoValue annotation
 * processor from the abstract accessors and builder declared here. Keep their names and shapes in
 * sync with that generated code.
 */
@AutoValue
internal abstract class ParentChildNodePair {
// The parent node. It is nullable; presumably null for a root node — confirm at call sites.
abstract fun parent(): AccessibilityNodeInfo?
// The child node (non-null).
abstract fun child(): AccessibilityNodeInfo
/** [ParentChildNodePair] builder (implementation generated by AutoValue). */
@AutoValue.Builder
abstract class Builder {
// Sets the (possibly null) parent node.
abstract fun parent(parent: AccessibilityNodeInfo?): Builder
// Sets the child node.
abstract fun child(child: AccessibilityNodeInfo): Builder
// Creates the immutable pair from the values set so far.
abstract fun build(): ParentChildNodePair
}
companion object {
// Entry point for constructing a pair; returns a fresh generated builder.
@JvmStatic fun builder(): Builder = AutoValue_ParentChildNodePair.Builder()
}
}
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/UniqueIdsGenerator.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.accessibilityforwarder
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger
import java.util.function.Function
/**
 * Thread-safe helper class for assigning a unique ID to an object.
 *
 * The first time a given object (compared with `equals`/`hashCode`) is passed to [getUniqueId],
 * it receives the next ID from a counter that starts at 0. Later calls with an equal object return
 * the same ID.
 *
 * @param A type of the objects that receive IDs. Instances must not be null at runtime, because
 *   [ConcurrentHashMap] rejects null keys.
 */
internal class UniqueIdsGenerator<A> {
  // Next ID to hand out; atomic so concurrent callers never receive the same value.
  private val nextId = AtomicInteger(0)
  // Object -> assigned ID.
  private val uniqueIdsByNode = ConcurrentHashMap<A, Int>()

  /** Returns the ID of [a], assigning a new one if [a] has not been seen before. */
  fun getUniqueId(a: A): Int {
    // computeIfAbsent runs the mapping function atomically and at most once per absent key, so
    // an ID is never assigned twice or skipped.
    return uniqueIdsByNode.computeIfAbsent(a, Function { _: A -> nextId.getAndIncrement() })!!
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/accessibilityforwarder/res/xml/accessibility_forwarder_service.xml
================================================
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/AndroidManifest.xml
================================================
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/BUILD.bazel
================================================
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Classic RL task implemented as an Android app.
load("@rules_android//rules:rules.bzl", "android_binary")
load("@rules_kotlin//kotlin:android.bzl", "kt_android_library")
# Targets in this package are visible only to the Catch app and its tests by default.
package(
default_visibility = [":catch_packages"],
)
# Packages allowed to depend on Catch targets: the app sources and their tests.
package_group(
name = "catch_packages",
packages = [
"//java/com/google/androidenv/catch/...",
"//javatests/com/google/androidenv/catch/...",
],
)
licenses(["notice"])
# The installable Catch APK. All app code is reached through the MainActivity library.
android_binary(
name = "app",
manifest = "AndroidManifest.xml",
# "native" selects the platform's built-in multidex support (see rules_android `multidex`).
multidex = "native",
deps = [":MainActivity"],
)
# Game rules: ball movement, paddle input, catch detection and rendering of all sprites.
kt_android_library(
name = "GameLogic",
srcs = ["GameLogic.kt"],
deps = [
"//java/com/google/androidenv/catch/sprite:Background",
"//java/com/google/androidenv/catch/sprite:Ball",
"//java/com/google/androidenv/catch/sprite:LineSegment",
"//java/com/google/androidenv/catch/sprite:Paddle",
],
)
# Worker thread that repeatedly resets and runs a GameLogic instance.
kt_android_library(
name = "GameLogicThread",
srcs = ["GameLogicThread.kt"],
deps = [
":GameLogic",
],
)
# The activity: owns the SurfaceView, builds the game from intent extras and starts the threads.
kt_android_library(
name = "MainActivity",
srcs = ["MainActivity.kt"],
manifest = "AndroidManifest.xml",
resource_files = glob(["res/**"]),
deps = [
":GameLogic",
":GameLogicThread",
":RenderThread",
"//java/com/google/androidenv/catch/sprite:Background",
"//java/com/google/androidenv/catch/sprite:Ball",
"//java/com/google/androidenv/catch/sprite:Paddle",
],
)
# Worker thread that periodically draws the current GameLogic onto the surface.
kt_android_library(
name = "RenderThread",
srcs = ["RenderThread.kt"],
deps = [
":GameLogic",
],
)
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/GameLogic.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch
import android.graphics.Canvas
import android.view.MotionEvent
import com.google.androidenv.catch.sprite.Background
import com.google.androidenv.catch.sprite.Ball
import com.google.androidenv.catch.sprite.LineSegment
import com.google.androidenv.catch.sprite.Paddle
import java.time.Duration
import java.time.Instant
import kotlin.random.Random
/**
 * The class that contains the game logic.
 *
 * A game is a sequence of "throws":
 * - [reset] re-positions the ball at the top of the screen.
 * - [run] advances the ball frame by frame until it leaves the screen and reports whether the
 *   paddle caught it.
 * - [render] draws the current state.
 * - [handleTouch] moves the paddle.
 */
open class GameLogic(
  // Expected number of frames per second. Must be positive.
  fps: Int = 60,
  // Pseudo random number generator.
  private val rand: Random = Random.Default,
  // Width and height of the game in pixels.
  private val width: Int,
  private val height: Int,
  // UI objects in the game. They are never reassigned after construction.
  private val background: Background = Background(),
  private val ball: Ball = Ball(maxX = width, maxY = height, rand = rand),
  private val paddle: Paddle = Paddle(maxX = width, y = height),
) {
  init {
    // fps == 0 would turn the frame delay into Long.MAX_VALUE ms, an effectively endless sleep.
    // fps < 0 would make it negative, which makes Thread.sleep() throw later inside run().
    // Fail fast here instead.
    require(fps > 0) { "fps must be positive but was $fps" }
  }

  // Delay between two simulation steps, truncated to whole milliseconds.
  private val sleepTime: Duration = Duration.ofMillis((1000.0 / fps).toLong())

  /** Reinitializes the state of the game by resetting the ball. */
  // Need to make this open to allow for testing.
  open fun reset() {
    this.ball.reset()
  }

  /**
   * Runs one "throw" of a [ball] that needs to be caught by the [paddle].
   *
   * Blocks the calling thread, sleeping about one frame between updates, until the ball is out of
   * bounds.
   *
   * @return `true` if the ball then touches the paddle's top edge, i.e. it was caught.
   */
  // Need to make this open to allow for testing.
  open fun run(): Boolean {
    var lastTimestamp = Instant.now()
    do {
      Thread.sleep(sleepTime.toMillis())
      // Advance by the real elapsed time rather than the nominal frame time, so that oversleeping
      // does not slow the ball down.
      val now = Instant.now()
      val interval = Duration.between(lastTimestamp, now)
      lastTimestamp = now
      ball.update(interval)
    } while (!ball.isOutOfBounds())
    return ball.intersects(LineSegment(paddle.topLeft(), paddle.topRight()))
  }

  /**
   * Processes a user event (e.g. a touchscreen event) and moves the [paddle] to the event's x.
   *
   * NOTE(review): this is called from a touch listener while other threads run and render the
   * game, and the paddle is updated without synchronization. Confirm that this is acceptable.
   */
  fun handleTouch(event: MotionEvent) {
    paddle.x = event.x.toInt()
  }

  /** Renders the game on [c]: background first, then the ball, then the paddle. */
  open fun render(c: Canvas) {
    background.draw(c)
    ball.draw(c)
    paddle.draw(c)
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/GameLogicThread.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch
import android.util.Log
/** A thread that continuously runs the game logic, resetting after each internal [run()]. */
class GameLogicThread(private val game: GameLogic, private val loggingTag: String) : Thread() {
  /**
   * Whether this thread should continuously run.
   *
   * [finish] writes this flag from another thread, and [run] reads it on this thread. Without
   * @Volatile the write is not guaranteed to become visible to the loop, so the thread might
   * never stop and a subsequent join() could hang.
   */
  @Volatile private var shouldRun: Boolean = true

  /** A counter of game runs; only touched from [run] on this thread. */
  private var counter: Int = 0

  /**
   * Lets the current [run()] iteration complete and then exits this [Thread].
   *
   * Notice that [shouldRun] cannot have a private getter with a public setter (please see
   * https://youtrack.jetbrains.com/issue/KT-3110 for details), hence this public function. Also
   * notice that we cannot call this function [stop()] since it would shadow [Thread.stop()].
   */
  public fun finish() {
    shouldRun = false
  }

  /** Continuously runs the [game] until [finish()] is called, logging each throw's outcome. */
  public override fun run() {
    while (shouldRun) {
      game.reset()
      Log.i(loggingTag, "${counter++} - ${game.run()}")
    }
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/MainActivity.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch
import android.app.Activity
import android.content.Intent
import android.graphics.Color
import android.os.Bundle
import android.util.Log
import android.view.SurfaceHolder
import android.view.SurfaceView
import android.view.View
import android.view.Window
import com.google.androidenv.catch.sprite.Background
import com.google.androidenv.catch.sprite.Ball
import com.google.androidenv.catch.sprite.Paddle
/**
 * The activity that allows users to play the RL game of Catch.
 *
 * Lifecycle as wired below:
 * - [onCreate] inflates the layout and registers this activity as the [SurfaceHolder] callback.
 * - [surfaceCreated] starts a [RenderThread].
 * - [surfaceChanged] and [onNewIntent] (re)start a game sized to the surface.
 * - [surfaceDestroyed] stops both worker threads.
 *
 * Game parameters come from the extras of the most recent intent; see [startGame].
 */
class MainActivity() : Activity(), SurfaceHolder.Callback {
// The view the game is drawn on; set in onCreate.
private var surfaceView: SurfaceView? = null
// Thread drawing frames onto the surface; created in surfaceCreated.
private var renderThread: RenderThread? = null
// Thread running the current game; replaced on every startGame.
private var gameLogicThread: GameLogicThread? = null
// Frame rate shared by the game logic and the render thread.
private val fps: Int = 60
// NOTE(review): not read or written anywhere in this class; possibly dead state.
private var gameCounter: Int = 0
// Surface size in pixels; -1 until surfaceChanged reports it.
private var width: Int = -1
private var height: Int = -1
// Extras of the launching intent or of the latest onNewIntent.
private var extras: Bundle? = null
// [Activity] overrides.
/** Initializes the Android [View] and sets up callbacks. */
override fun onCreate(savedInstanceState: Bundle?) {
super.onCreate(savedInstanceState)
Log.i(TAG, "MainActivity::onCreate()")
requestWindowFeature(Window.FEATURE_NO_TITLE)
setContentView(R.layout.main)
val surface: SurfaceView? = findViewById(R.id.surfaceView)
if (surface == null) throw Exception("Could not create SurfaceView. Aborting...")
surface.visibility = View.VISIBLE
// Surface lifecycle events are delivered to this activity's SurfaceHolder.Callback methods.
surface.holder.addCallback(this)
surfaceView = surface
extras = intent?.extras
}
/** Stores the new intent's extras and restarts the game with them. */
override fun onNewIntent(intent: Intent?) {
super.onNewIntent(intent)
Log.i(TAG, "MainActivity::onNewIntent()")
extras = intent?.extras
startGame()
}
// [SurfaceHolder.Callback] overrides.
/** Starts a render thread for the new surface; it draws nothing until a game is assigned. */
override fun surfaceCreated(holder: SurfaceHolder) {
Log.i(TAG, "MainActivity::surfaceCreated()")
renderThread = RenderThread(surfaceHolder = holder, fps = fps).also { it.start() }
}
/** Records the surface size and (re)starts the game with it. */
override fun surfaceChanged(holder: SurfaceHolder, format: Int, width: Int, height: Int) {
Log.i(TAG, "MainActivity::surfaceChanged()")
this.width = width
this.height = height
startGame()
}
/**
 * Stops both worker threads and waits for them to exit.
 *
 * NOTE(review): GameLogicThread.finish() lets the current throw finish, so join() blocks the
 * calling thread (presumably the UI thread) until the ball leaves the screen.
 */
override fun surfaceDestroyed(holder: SurfaceHolder) {
Log.i(TAG, "MainActivity::surfaceDestroyed()")
renderThread?.finish()
renderThread?.join()
gameLogicThread?.finish()
gameLogicThread?.join()
}
/**
 * Builds a new [GameLogic] from [extras] and the surface size, and starts running it.
 *
 * Recognized extras (with defaults): "backgroundColor" ("BLACK"), "ballColor" ("WHITE"),
 * "ballRadius" (10.0f), "ballSpeed" (0.2f), "paddleColor" ("WHITE"), "paddleWidth" (80),
 * "paddleHeight" (10). Colors go through Color.parseColor, which rejects unknown color
 * strings (IllegalArgumentException per the Android docs).
 *
 * Does nothing until [surfaceChanged] has reported a positive size.
 */
private fun startGame() {
Log.i(TAG, "MainActivity::startGame()")
if (width <= 0 || height <= 0) {
Log.e(TAG, "MainActivity::startGame() - Width or height not initialized yet.")
return
}
val backgroundColor = Color.parseColor(extras?.getString("backgroundColor") ?: "BLACK")
val ballColor = Color.parseColor(extras?.getString("ballColor") ?: "WHITE")
val ballRadius = extras?.getFloat("ballRadius", 10.0f) ?: 10.0f
val ballSpeed = extras?.getFloat("ballSpeed", 0.2f) ?: 0.2f
val paddleColor = Color.parseColor(extras?.getString("paddleColor") ?: "WHITE")
val paddleWidth = extras?.getInt("paddleWidth", 80) ?: 80
val paddleHeight = extras?.getInt("paddleHeight", 10) ?: 10
Log.i(TAG, "MainActivity::startGame() - extras bundle: $extras")
val game =
GameLogic(
width = width,
height = height,
fps = fps,
background = Background(color = backgroundColor),
ball =
Ball(
maxX = width,
maxY = height,
color = ballColor,
radius = ballRadius,
speed = ballSpeed,
),
paddle =
Paddle(
color = paddleColor,
width = paddleWidth,
height = paddleHeight,
maxX = width,
// Center the paddle vertically on a point half its height above the bottom edge.
y = (height - paddleHeight / 2),
),
)
// Stop the previous game logic thread if it's running.
// NOTE(review): join() waits for the previous throw to complete and blocks this thread meanwhile.
gameLogicThread?.finish()
gameLogicThread?.join()
// Create and start the new GameLogicThread, passing the game instance.
gameLogicThread = GameLogicThread(game, TAG).also { it.start() }
// Pass the same game instance to the render thread.
renderThread?.game = game
surfaceView?.setOnTouchListener(
// Suppress lint's ClickableViewAccessibility warning: this listener consumes raw touch
// events to move the paddle and never calls View.performClick().
@SuppressWarnings("ClickableViewAccessibility")
View.OnTouchListener { _, motionEvent ->
game.handleTouch(motionEvent)
// Returning true marks the event as consumed.
true
}
)
}
companion object {
// Log tag shared by the activity and the game logic thread.
private const val TAG = "AndroidRLTask"
}
}
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/RenderThread.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch
import android.graphics.Canvas
import android.view.SurfaceHolder
import java.time.Duration
/** A thread that continuously renders the game logic onto a surface. */
class RenderThread(private val surfaceHolder: SurfaceHolder, private val fps: Int = 60) : Thread() {
  init {
    // fps <= 0 would produce an infinite or negative sleep in run(); fail fast instead.
    require(fps > 0) { "fps must be positive but was $fps" }
  }

  /**
   * Whether this thread should continuously run.
   *
   * [finish] writes this flag from another thread and [run] reads it. @Volatile guarantees that
   * the render loop observes the write.
   */
  @Volatile private var shouldRun: Boolean = true

  /** How long to sleep at each [run()] iteration. */
  private val sleepTime: Duration = Duration.ofMillis((1000.0 / fps).toLong())

  /**
   * The class responsible for issuing rendering commands to the canvas, or null to draw nothing.
   *
   * It is assigned from outside this thread while [run] reads it, hence @Volatile.
   */
  @Volatile var game: GameLogic? = null

  /**
   * Asks the render loop to exit after its current iteration.
   *
   * Notice that [shouldRun] cannot have a private getter with a public setter (please see
   * https://youtrack.jetbrains.com/issue/KT-3110 for details), hence this public function. Also
   * notice that we cannot call this function [stop()] since it would shadow [Thread.stop()].
   */
  public fun finish() {
    shouldRun = false
  }

  /** Continuously renders the [game] onto [surfaceHolder], about once per frame. */
  public override fun run() {
    while (shouldRun) {
      drawFrame()
      Thread.sleep(sleepTime.toMillis())
    }
  }

  /** Draws one frame of [game] if the surface is valid and its canvas can be locked. */
  private fun drawFrame() {
    if (surfaceHolder.surface?.isValid() != true) return
    // lockCanvas() returns null when the surface cannot be edited. Skip the frame instead of
    // failing on a null Canvas.
    val canvas: Canvas = surfaceHolder.lockCanvas() ?: return
    try {
      game?.render(canvas)
    } finally {
      // Always release the canvas, even if rendering throws, so the surface is not left locked.
      surfaceHolder.unlockCanvasAndPost(canvas)
    }
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/res/layout/main.xml
================================================
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/res/values/strings.xml
================================================
Catch
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/BUILD.bazel
================================================
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Sprites for the app.
load("@rules_kotlin//kotlin:android.bzl", "kt_android_library")
# Sprites are visible only to the Catch app and its tests (see the catch_packages group).
package(
default_visibility = ["//java/com/google/androidenv/catch:catch_packages"],
)
licenses(["notice"])
# Solid-color backdrop drawn behind all other sprites.
kt_android_library(
name = "Background",
srcs = ["Background.kt"],
deps = [":Sprite"],
)
# Falling ball; needs LineSegment and Point for its intersection test.
kt_android_library(
name = "Ball",
srcs = ["Ball.kt"],
deps = [
":LineSegment",
":Point",
":Sprite",
],
)
# Segment between two Points.
kt_android_library(
name = "LineSegment",
srcs = ["LineSegment.kt"],
deps = [":Point"],
)
# User-controlled paddle.
kt_android_library(
name = "Paddle",
srcs = ["Paddle.kt"],
deps = [
":Point",
":Sprite",
],
)
# Point in screen coordinates.
kt_android_library(
name = "Point",
srcs = ["Point.kt"],
)
# Base class for drawable game objects.
kt_android_library(
name = "Sprite",
srcs = ["Sprite.kt"],
)
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Background.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
import android.graphics.Canvas
import android.graphics.Color
/**
 * The static backdrop of the game.
 *
 * It fills the whole canvas with a single [color] (black unless specified), so it must be drawn
 * before every other sprite.
 */
open class Background(private val color: Int = Color.BLACK) : Sprite() {
  /** Floods [c] with [color]. */
  override fun draw(c: Canvas) = c.drawColor(color)
}
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Ball.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import java.time.Duration
import kotlin.math.ceil
import kotlin.math.sqrt
import kotlin.random.Random
/** Represents a ball that travels down in space with constant speed. */
open class Ball(
  private val maxX: Int,
  private val maxY: Int,
  private val color: Int = Color.WHITE,
  private val radius: Float = 10.0f,
  // `speed`'s unit is in pixels/ms.
  private val speed: Float = 1.0f,
  private val rand: Random = Random.Default,
) : Sprite() {
  // `x` and `y` represent the position of the center of this ball.
  //
  // Horizontal position, drawn from [0, maxX) by rand.nextInt(maxX). 0==left.
  private var x: Int = rand.nextInt(maxX)
  // Valid range [0, maxY]. 0==top, maxY==bottom.
  private var y: Int = ceil(radius).toInt()
  // Vertical movement (in pixels) computed by update() but not yet applied to the integer `y`.
  // Carrying it over keeps slow balls (speed * frame time < 1px) moving instead of truncating their
  // motion to 0 every frame.
  private var pendingDy: Float = 0.0f
  private val paint: Paint =
    Paint(Paint.ANTI_ALIAS_FLAG).apply {
      style = Paint.Style.FILL
      color = (this@Ball).color
    }

  /** Returns `true` if this ball intersects the given line [segment]. */
  fun intersects(segment: LineSegment): Boolean {
    /** A vector with two integer components. */
    data class Vector2D(val u: Int, val v: Int) {
      /**
       * Returns the dot product between two 2D vectors.
       *
       * It is computed in Double because the Int products overflow for realistic screen
       * coordinates once they are squared below.
       */
      fun dot(other: Vector2D): Double = u.toDouble() * other.u + v.toDouble() * other.v
    }

    /** Returns the vector representing [p] minus [q]. */
    fun pointDiff(p: Point, q: Point): Vector2D = Vector2D(p.x - q.x, p.y - q.y)

    val direction = pointDiff(segment.p1, segment.p0) // p0 -> p1.
    val centerToP = pointDiff(segment.p0, Point(x, y)) // Ball center -> p0.
    val r = radius.toDouble()
    // The `(centerToP + m * direction)` function models all the points in the line segment where
    // the independent variable `m` is a real number in [0,1]. Putting this function into the
    // formula for the circle (x ^ 2 + y ^ 2 = radius ^ 2) gives a quadratic equation
    // (am^2 + bm + c = 0) where:
    // [a] = direction · direction
    // [b] = 2 centerToP · direction
    // [c] = centerToP · centerToP - radius ^ 2
    val a = direction.dot(direction)
    val b = 2.0 * centerToP.dot(direction)
    val c = centerToP.dot(centerToP) - r * r
    // Degenerate segment (p0 == p1): it touches the ball iff that point lies within the radius.
    if (a == 0.0) return c <= 0.0
    val delta = b * b - 4.0 * a * c
    if (delta < 0.0)
      return false // No real roots means the (infinite) line does not intersect the ball.
    val d = sqrt(delta)
    val m1 = (-b - d) / (2.0 * a)
    val m2 = (-b + d) / (2.0 * a)
    // If a root is in [0,1], the line segment crosses the circumference.
    // If [m1] < 0 and [m2] > 1, the whole segment lies inside the circle without crossing the
    // circumference; we still consider that it touched the ball.
    return (m1 >= 0.0 && m1 <= 1.0) || (m2 >= 0.0 && m2 <= 1.0) || (m1 < 0.0 && m2 > 1.0)
  }

  /** Places the ball at the top of the screen at a random x-coordinate. */
  fun reset() {
    x = rand.nextInt(maxX)
    y = ceil(radius).toInt()
    pendingDy = 0.0f
  }

  /**
   * Moves the ball down by `speed * timeDelta` pixels.
   *
   * The sub-pixel remainder is carried over to the next call, so it is not lost to truncation.
   */
  open fun update(timeDelta: Duration) {
    val distance = speed * timeDelta.toMillis() + pendingDy
    val wholePixels = distance.toInt()
    pendingDy = distance - wholePixels
    y += wholePixels
  }

  /** Returns whether any part of the ball extends below [maxY] or above the top edge (0). */
  fun isOutOfBounds(): Boolean = y + radius > maxY || y - radius < 0

  /** Draws this ball in `c`. */
  override fun draw(c: Canvas) {
    c.drawCircle(x.toFloat(), y.toFloat(), radius, paint)
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/LineSegment.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
/**
 * A finite 2D line segment running between two endpoints.
 *
 * @property p0 the first endpoint.
 * @property p1 the second endpoint.
 */
data class LineSegment(
  val p0: Point,
  val p1: Point,
)
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Paddle.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.Rect
import kotlin.ranges.coerceIn
/** A paddle, drawn as a filled rectangle, used to hit/catch a falling ball. */
open class Paddle(
  private val color: Int = color_default(),
  // Width and height in pixels.
  private val width: Int = 80,
  private val height: Int = 10,
  // Largest value allowed for the paddle's horizontal center [x].
  private val maxX: Int = 100,
  // The vertical position of the center of this paddle in pixels.
  val y: Int = 100,
) : Sprite() {
  // Half-extents are computed once here instead of on every [draw] call.
  private val halfW = width / 2
  private val halfH = height / 2

  private val paint =
    Paint(Paint.ANTI_ALIAS_FLAG).also {
      it.style = Paint.Style.FILL
      it.color = color
    }

  // Horizontal center of the paddle. It starts in the middle; assignments are clamped to
  // [0, maxX] by the setter.
  var x: Int = maxX / 2
    set(newX) {
      field = newX.coerceIn(0, maxX)
    }

  /** Returns the (x, y) coordinates of the paddle's top-left corner. */
  fun topLeft(): Point {
    val left = x - halfW
    return Point(left, y - halfH)
  }

  /** Returns the (x, y) coordinates of the paddle's top-right corner. */
  fun topRight(): Point {
    val right = x + halfW
    return Point(right, y - halfH)
  }

  /** Shifts the paddle horizontally by [deltaX] pixels; the [x] setter clamps the result. */
  fun move(deltaX: Int) {
    x = x + deltaX
  }

  /** Draws the paddle on [c] as a filled rectangle centered at ([x], [y]). */
  override fun draw(c: Canvas) {
    c.drawRect(Rect(x - halfW, y - halfH, x + halfW, y + halfH), paint)
  }
}
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Point.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
/**
 * A cartesian point in 2D with integer coordinates.
 *
 * @property x the horizontal coordinate.
 * @property y the vertical coordinate.
 */
data class Point(
  val x: Int,
  val y: Int,
)
================================================
FILE: android_env/apps/java/com/google/androidenv/catch/sprite/Sprite.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
import android.graphics.Canvas
/** Base type for anything that can be drawn on the screen; subclasses override [draw]. */
open class Sprite {
  /** Draws this sprite on the canvas [c]. The base implementation draws nothing. */
  open fun draw(c: Canvas): Unit = Unit
}
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/AndroidManifest.xml
================================================
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/BUILD.bazel
================================================
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Tests for the Android version of the RL Catch game.
load("@rules_kotlin//kotlin:android.bzl", "kt_android_local_test")
load("@rules_kotlin//kotlin:core.bzl", "kt_kotlinc_options")
# Kotlin compiler options shared by every test target in this package; each target below
# references it through `kotlinc_opts = ":kt_kotlinc_options"`.
kt_kotlinc_options(
name = "kt_kotlinc_options",
jvm_target = "11", # Need to override default 1.8.
# Disables kotlinc's generated not-null assertions on parameters (-Xno-param-assertions).
# NOTE(review): presumably so Mockito stubs/matchers that pass null into Kotlin non-null
# parameters do not throw at runtime — confirm.
x_no_param_assertions = True,
)
# Local (JVM, Robolectric) test for GameLogic; also links the Background, Ball and Paddle sprites
# that the test constructs directly.
kt_android_local_test(
name = "GameLogicTest",
srcs = ["GameLogicTest.kt"],
kotlinc_opts = ":kt_kotlinc_options",
deps = [
"//java/com/google/androidenv/catch:GameLogic",
"//java/com/google/androidenv/catch/sprite:Background",
"//java/com/google/androidenv/catch/sprite:Ball",
"//java/com/google/androidenv/catch/sprite:Paddle",
"@maven//:androidx_test_ext_junit",
"@maven//:androidx_test_runner",
"@maven//:com_google_truth_truth",
"@maven//:org_mockito_kotlin_mockito_kotlin",
"@maven//:org_robolectric_robolectric",
"@robolectric//bazel:android-all",
],
)
# Local (JVM, Robolectric) test for GameLogicThread, which drives a mocked GameLogic.
kt_android_local_test(
name = "GameLogicThreadTest",
srcs = ["GameLogicThreadTest.kt"],
kotlinc_opts = ":kt_kotlinc_options",
deps = [
"//java/com/google/androidenv/catch:GameLogic",
"//java/com/google/androidenv/catch:GameLogicThread",
"@maven//:androidx_test_ext_junit",
"@maven//:com_google_truth_truth",
"@maven//:org_mockito_kotlin_mockito_kotlin",
"@maven//:org_robolectric_robolectric",
"@robolectric//bazel:android-all",
],
)
# Local (JVM, Robolectric) test for MainActivity. It uses the AndroidManifest.xml in this
# directory as the test manifest.
kt_android_local_test(
name = "MainActivityTest",
srcs = [
"MainActivityTest.kt",
],
kotlinc_opts = ":kt_kotlinc_options",
manifest = "AndroidManifest.xml",
deps = [
"//java/com/google/androidenv/catch:MainActivity",
"@maven//:androidx_test_ext_junit",
"@maven//:junit_junit",
"@maven//:org_robolectric_robolectric",
"@robolectric//bazel:android-all",
],
)
# Local (JVM, Robolectric) test for RenderThread, which renders a mocked GameLogic onto a mocked
# SurfaceHolder.
kt_android_local_test(
name = "RenderThreadTest",
srcs = ["RenderThreadTest.kt"],
kotlinc_opts = ":kt_kotlinc_options",
deps = [
"//java/com/google/androidenv/catch:GameLogic",
"//java/com/google/androidenv/catch:RenderThread",
"@maven//:androidx_test_ext_junit",
"@maven//:org_mockito_kotlin_mockito_kotlin",
"@maven//:org_mockito_mockito_core",
"@maven//:org_robolectric_robolectric",
"@robolectric//bazel:android-all",
],
)
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/GameLogicTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch
import android.graphics.Canvas
import androidx.test.core.view.MotionEventBuilder
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.androidenv.catch.sprite.Background
import com.google.androidenv.catch.sprite.Ball
import com.google.androidenv.catch.sprite.Paddle
import com.google.common.truth.Truth.assertThat
import java.time.Duration
import java.time.Instant
import kotlin.random.Random
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.kotlin.any
import org.mockito.kotlin.atLeast
import org.mockito.kotlin.atMost
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.spy
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
/** Tests for [GameLogic]: game outcomes, resets, timing and rendering. */
@RunWith(AndroidJUnit4::class)
class GameLogicTest {
  @Test
  fun run_ballIsMissed() {
    // Arrange.
    val width = 123
    val height = 33
    val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 37 }
    val game =
      GameLogic(
        rand = mockRandom,
        width = width,
        height = height,
        ball = Ball(maxX = width, maxY = height, radius = 5.0f, rand = mockRandom),
        paddle = Paddle(maxX = width, y = height, width = 3, height = 2),
      )
    game.reset()
    game.handleTouch(
      MotionEventBuilder.newBuilder().setPointer(/* x= */ 12.0f, /* y= */ 31.0f).build()
    )
    // Act.
    val outcome = game.run() // Ball falls at x==37, ev.x==12 so ball is missed.
    // Assert.
    assertThat(outcome).isFalse()
  }

  @Test
  fun run_ballIsCaught() {
    // Arrange.
    val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 53 }
    val game = GameLogic(rand = mockRandom, width = 321, height = 47)
    game.reset()
    game.handleTouch(
      MotionEventBuilder.newBuilder().setPointer(/* x= */ 53.0f, /* y= */ 43.0f).build()
    )
    // Act.
    val outcome = game.run() // Ball falls at x==53, ev.x==53 so ball is caught.
    // Assert.
    assertThat(outcome).isTrue()
  }

  @Test
  fun run_resetAllowsMultipleGamesToBePlayedWithASingleObjectAndDoesNotHang() {
    // Arrange.
    val mockRandom: Random = mock()
    val game = GameLogic(width = 101, height = 59, rand = mockRandom)
    // Act.
    repeat(17) {
      game.reset()
      val unused = game.run() // Ignore the outcome since we only care about run() terminating.
    }
    // Assert.
    // [rand.nextInt()] should be called once at construction and then 17 times for [reset()].
    verify(mockRandom, times(18)).nextInt(any())
  }

  @Test
  fun run_inASeparateThread() {
    // Arrange.
    val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 23 }
    val game = GameLogic(rand = mockRandom, width = 321, height = 89)
    game.reset()
    game.handleTouch(
      MotionEventBuilder.newBuilder().setPointer(/* x= */ 23.0f, /* y= */ 29.0f).build()
    )
    // The result of [GameLogic.run] is stored on the thread object itself and read back after
    // [join], which guarantees the write made on the other thread is visible here.
    class MyThread(val g: GameLogic) : Thread() {
      var outcome: Boolean = false

      public override fun run() {
        outcome = g.run()
      }
    }
    val someThread = MyThread(game)
    // Act.
    someThread.start() // Ball falls at x==23, ev.x==23 so ball is caught.
    someThread.join()
    // Assert.
    assertThat(someThread.outcome).isTrue()
  }

  @Test
  fun run_fpsLeadstoApproximatelyNumberOfElapsedTimeAndUpdateCalls() {
    // Arrange.
    val width = 123
    val height = 300
    val ball = spy(Ball(maxX = width, maxY = height, speed = 2.0f, radius = 1.0f))
    val game = GameLogic(fps = 100, width = width, height = height, ball = ball)
    game.reset()
    // Act.
    val start = Instant.now()
    val unused = game.run()
    val end = Instant.now()
    // Assert.
    val elapsed = Duration.between(start, end)
    // The ball should take around `height / speed = 150` milliseconds to reach the bottom. Due to
    // timing non-determinism, we accept values between 100 and 200.
    assertThat(elapsed.toMillis()).isAtLeast(100L)
    assertThat(elapsed.toMillis()).isAtMost(200L)
    // At fps==100, we expect [update()] to be called every `1000 / 100 = 10` milliseconds. We
    // expect [elapsed] to be around 150ms (checked above) which should be around `150 / 10 = 15`
    // calls, so to account for timing non-determinism we accept between 5 and 25 calls.
    verify(ball, atLeast(5)).update(any())
    verify(ball, atMost(25)).update(any())
  }

  @Test
  fun render_drawCanBeCalledMultipleTimesWithinASingleRun() {
    // Arrange.
    val width = 321
    val height = 89
    val mockCanvas: Canvas = mock()
    val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 23 }
    val background = spy(Background())
    val paddle = spy(Paddle())
    val ball = spy(Ball(maxX = width, maxY = height))
    val game =
      GameLogic(
        rand = mockRandom,
        width = width,
        height = height,
        background = background,
        ball = ball,
        paddle = paddle,
      )
    game.reset()
    game.handleTouch(
      MotionEventBuilder.newBuilder().setPointer(/* x= */ 23.0f, /* y= */ 29.0f).build()
    )
    class MyThread(val g: GameLogic) : Thread() {
      public override fun run() {
        val unused = g.run()
      }
    }
    val someThread = MyThread(game)
    // Act.
    someThread.start()
    repeat(11) { game.render(mockCanvas) }
    someThread.join()
    // Assert.
    verify(background, times(11)).draw(mockCanvas)
    verify(ball, times(11)).draw(mockCanvas)
    verify(paddle, times(11)).draw(mockCanvas)
  }
}
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/GameLogicThreadTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch
import android.util.Log
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.common.truth.Truth.assertThat
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.kotlin.atLeastOnce
import org.mockito.kotlin.mock
import org.mockito.kotlin.verify
import org.robolectric.junit.rules.ExpectedLogMessagesRule
/** Tests for [GameLogicThread]: it drives the game loop, logs outcomes and stops on finish(). */
@RunWith(AndroidJUnit4::class)
class GameLogicThreadTest {
  // Rule to assert log messages, taken as a reference from MainActivityTest.kt
  @get:Rule val expectedLogMessagesRule = ExpectedLogMessagesRule()
  private val mockGame: GameLogic = mock()
  private val testTag = "TestAndroidRLTask"

  @Test
  fun run_iteratesGameAndLogs() {
    // Arrange
    val gameLogicThread = GameLogicThread(mockGame, testTag)
    // Act
    gameLogicThread.start()
    Thread.sleep(100) // Allow time for the thread to execute at least once.
    gameLogicThread.finish()
    // Bounded wait: if finish() does not stop the loop, fail instead of hanging the test run.
    gameLogicThread.join(JOIN_TIMEOUT_MS)
    // Assert
    assertThat(gameLogicThread.isAlive).isFalse()
    // Verify that the game's core methods were called at least once.
    verify(mockGame, atLeastOnce()).reset()
    verify(mockGame, atLeastOnce()).run()
    // Expect the log message from the run() loop.
    // The mock 'game.run()' returns false by default.
    expectedLogMessagesRule.expectLogMessage(Log.INFO, testTag, "0 - false")
  }

  @Test
  fun finish_stopsTheThread() {
    // Arrange
    val gameLogicThread = GameLogicThread(mockGame, testTag)
    // Act
    gameLogicThread.start()
    // Let it run for a moment before stopping it.
    Thread.sleep(50)
    gameLogicThread.finish()
    // An unbounded join() only returns once the thread has died, which would make the assertion
    // below vacuous (and hang forever if finish() is broken). Bounding the wait makes the
    // assertion actually check that finish() stopped the thread.
    gameLogicThread.join(JOIN_TIMEOUT_MS)
    // Assert
    assertThat(gameLogicThread.isAlive).isFalse()
  }

  private companion object {
    // Upper bound, in milliseconds, for waiting on the game thread to terminate after finish().
    const val JOIN_TIMEOUT_MS = 5_000L
  }
}
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/MainActivityTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch
import android.content.Intent
import android.util.Log
import androidx.test.ext.junit.rules.ActivityScenarioRule
import androidx.test.ext.junit.runners.AndroidJUnit4
import java.lang.reflect.Method
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.junit.rules.ExpectedLogMessagesRule
/** Robolectric tests for [MainActivity]'s callbacks and the log messages they emit. */
@RunWith(AndroidJUnit4::class)
class MainActivityTest {
  @get:Rule(order = 0) val activityScenarioRule = ActivityScenarioRule(MainActivity::class.java)
  @get:Rule(order = 1) val expectedLogMessagesRule = ExpectedLogMessagesRule()

  @Before
  fun setUp() {
    // The scenario rule launches the activity for every test, so onCreate() always logs.
    expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::onCreate()")
  }

  @Test
  fun surfaceChanged_logsStartsGame() {
    activityScenarioRule.scenario.onActivity { activity ->
      // Arrange.
      // findViewById() is generic in T (a View subtype): without an explicit type argument Kotlin
      // cannot infer T, and `holder` is only available on SurfaceView.
      val surfaceView = activity.findViewById<android.view.SurfaceView>(R.id.surfaceView)
      val surfaceHolder = surfaceView.holder
      // Act - Trigger the surfaceChanged callback with positive width and height.
      activity.surfaceChanged(surfaceHolder, 0, 100, 200)
      // Assert.
      expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::surfaceChanged()")
      expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::startGame()")
    }
  }

  @Test
  fun onNewIntent_logsStartsGame_errorsOnUninitializedWidthOrHeight() {
    // Arrange.
    val newIntent = Intent()
    // Find the onNewIntent method using reflection
    val onNewIntentMethod: Method =
      MainActivity::class.java.getDeclaredMethod("onNewIntent", Intent::class.java)
    // Enable access to protected method
    onNewIntentMethod.isAccessible = true
    activityScenarioRule.scenario.onActivity { activity ->
      // Act - Invoke the onNewIntent method using reflection.
      onNewIntentMethod.invoke(activity, newIntent)
      // Assert.
      expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::onNewIntent()")
      expectedLogMessagesRule.expectLogMessage(Log.INFO, TAG, "MainActivity::startGame()")
      // In this test case where we don't call surfaceChanged(), default width and height
      // are -1 and should trigger this error to prevent Ball from initializing
      // with invalid negative values, since nextInt() expects a positive number.
      expectedLogMessagesRule.expectLogMessage(
        Log.ERROR,
        TAG,
        "MainActivity::startGame() - Width or height not initialized yet.",
      )
    }
  }

  companion object {
    private const val TAG = "AndroidRLTask"
  }
}
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/RenderThreadTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch
import android.graphics.Canvas
import android.view.Surface
import android.view.SurfaceHolder
import androidx.test.ext.junit.runners.AndroidJUnit4
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.verifyNoInteractions
import org.mockito.kotlin.any
import org.mockito.kotlin.atLeast
import org.mockito.kotlin.atMost
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.mock
import org.mockito.kotlin.verify
/** Tests for [RenderThread]: when it renders, how often, and that finish() stops it. */
@RunWith(AndroidJUnit4::class)
class RenderThreadTest {
  @Test
  fun run_finishBeforeStartResultsInNoRendering() {
    // Arrange.
    val surfaceHolder: SurfaceHolder = mock()
    val renderThread = RenderThread(surfaceHolder = surfaceHolder, fps = 1000)
    val game: GameLogic = mock()
    renderThread.game = game
    // Act.
    renderThread.finish()
    renderThread.start()
    // Wait for the thread to end so the assertions cover everything it did; without this the
    // verification could run before the thread was even scheduled and pass vacuously.
    renderThread.join(JOIN_TIMEOUT_MS)
    // Assert.
    verifyNoInteractions(game)
    verifyNoInteractions(surfaceHolder)
  }

  @Test
  fun run_startResultsInSomeRendering() {
    // Arrange.
    val canvas: Canvas = mock()
    val surface: Surface = mock() { on { isValid() } doReturn true }
    val surfaceHolder: SurfaceHolder =
      mock() {
        on { getSurface() } doReturn surface
        on { lockCanvas() } doReturn canvas
      }
    val renderThread = RenderThread(surfaceHolder = surfaceHolder, fps = 1000)
    val game: GameLogic = mock()
    renderThread.game = game
    // Act.
    renderThread.start()
    Thread.sleep(/* millis= */ 500) // Sleep for at least one loop iteration.
    renderThread.finish()
    // Make sure the thread is done before verifying, so it does not race with the mocks.
    renderThread.join(JOIN_TIMEOUT_MS)
    // Assert.
    verify(surfaceHolder, atLeast(1)).surface
    verify(surfaceHolder, atLeast(1)).lockCanvas()
    verify(surfaceHolder, atLeast(1)).unlockCanvasAndPost(any())
    verify(game, atLeast(1)).render(canvas)
  }

  @Test
  fun run_finishStopsRendering() {
    // Arrange.
    val canvas: Canvas = mock()
    val surface: Surface = mock() { on { isValid() } doReturn true }
    val surfaceHolder: SurfaceHolder =
      mock() {
        on { getSurface() } doReturn surface
        on { lockCanvas() } doReturn canvas
      }
    val renderThread = RenderThread(surfaceHolder = surfaceHolder, fps = 20)
    val game: GameLogic = mock()
    renderThread.game = game
    // Act.
    renderThread.start()
    Thread.sleep(/* millis= */ 500) // Sleep for around 10 iterations
    renderThread.finish()
    // Once joined, the thread has terminated, so nothing can render after this point.
    renderThread.join(JOIN_TIMEOUT_MS)
    // Assert.
    verify(surfaceHolder, atLeast(1)).surface
    verify(surfaceHolder, atLeast(1)).lockCanvas()
    verify(surfaceHolder, atLeast(1)).unlockCanvasAndPost(any())
    // We expect [game.render()] to be executed for around 500 / (1000 / 20 = 50) = 10 times. To
    // allow for some timing non-determinism we allow it to execute up to 15 times, but not more
    // than that since [renderThread.finish()] should stop the thread from calling it.
    verify(game, atLeast(1)).render(canvas)
    verify(game, atMost(15)).render(canvas)
  }

  @Test
  fun run_expectedFramesPerSecond() {
    // Arrange.
    val canvas: Canvas = mock()
    val surface: Surface = mock() { on { isValid() } doReturn true }
    val surfaceHolder: SurfaceHolder =
      mock() {
        on { getSurface() } doReturn surface
        on { lockCanvas() } doReturn canvas
      }
    val renderThread = RenderThread(surfaceHolder = surfaceHolder, fps = 5)
    val game: GameLogic = mock()
    renderThread.game = game
    // Act.
    renderThread.start()
    Thread.sleep(/* millis= */ 2000) // Sleep for around 10 loop iterations.
    renderThread.finish()
    // Make sure the thread is done before verifying, so it does not race with the mocks.
    renderThread.join(JOIN_TIMEOUT_MS)
    // Assert.
    verify(surfaceHolder, atLeast(1)).surface
    verify(surfaceHolder, atLeast(1)).lockCanvas()
    verify(surfaceHolder, atLeast(1)).unlockCanvasAndPost(any())
    // We expect [game.render()] to be called around 2000ms / 5fps = 10 times but to account for
    // timing non-determinism we allow ±4 iterations.
    verify(game, atLeast(6)).render(canvas)
    verify(game, atMost(14)).render(canvas)
  }

  private companion object {
    // Upper bound, in milliseconds, for waiting on the render thread to terminate.
    const val JOIN_TIMEOUT_MS = 5_000L
  }
}
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/BUILD.bazel
================================================
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Unit tests for Sprites in Catch.
load("@rules_kotlin//kotlin:android.bzl", "kt_android_local_test")
load("@rules_kotlin//kotlin:core.bzl", "kt_kotlinc_options")
# Kotlin compiler options shared by every sprite test target in this package; each target below
# references it through `kotlinc_opts = ":kt_kotlinc_options"`.
kt_kotlinc_options(
name = "kt_kotlinc_options",
jvm_target = "11", # Need to override default 1.8.
# Disables kotlinc's generated not-null assertions on parameters (-Xno-param-assertions).
# NOTE(review): presumably so Mockito stubs/matchers that pass null into Kotlin non-null
# parameters do not throw at runtime — confirm.
x_no_param_assertions = True,
)
# Parameterized (TestParameterInjector) test for the Background sprite.
# NOTE(review): unlike the other targets here, this one has no Robolectric deps or
# "robolectric" tag; it relies on Canvas being mocked — confirm it runs without android-all.
# NOTE(review): snakeyaml is presumably a runtime dependency of TestParameterInjector — confirm.
kt_android_local_test(
name = "BackgroundTest",
srcs = ["BackgroundTest.kt"],
kotlinc_opts = ":kt_kotlinc_options",
deps = [
"//java/com/google/androidenv/catch/sprite:Background",
"@maven//:com_google_guava_guava",
"@maven//:com_google_testparameterinjector_test_parameter_injector",
"@maven//:org_mockito_kotlin_mockito_kotlin",
"@maven//:org_yaml_snakeyaml",
],
)
# Robolectric test suite for the Ball sprite; it also builds LineSegment/Point values to test
# intersections.
kt_android_local_test(
name = "BallTest",
srcs = ["BallTest.kt"],
kotlinc_opts = ":kt_kotlinc_options",
tags = ["robolectric"],
deps = [
"//java/com/google/androidenv/catch/sprite:Ball",
"//java/com/google/androidenv/catch/sprite:LineSegment",
"//java/com/google/androidenv/catch/sprite:Point",
"@maven//:androidx_test_ext_junit",
"@maven//:com_google_guava_guava",
"@maven//:com_google_truth_truth",
"@maven//:org_mockito_kotlin_mockito_kotlin",
"@maven//:org_robolectric_robolectric",
"@robolectric//bazel:android-all",
],
)
# Robolectric test for the Paddle sprite (its corners are checked as Point values).
kt_android_local_test(
name = "PaddleTest",
srcs = ["PaddleTest.kt"],
kotlinc_opts = ":kt_kotlinc_options",
tags = ["robolectric"],
deps = [
"//java/com/google/androidenv/catch/sprite:Paddle",
"//java/com/google/androidenv/catch/sprite:Point",
"@maven//:androidx_test_ext_junit",
"@maven//:com_google_guava_guava",
"@maven//:com_google_truth_truth",
"@maven//:org_mockito_kotlin_mockito_kotlin",
"@maven//:org_robolectric_robolectric",
"@robolectric//bazel:android-all",
],
)
# Plain JVM test for the Sprite base class, using Mockito only (no Robolectric deps).
kt_android_local_test(
name = "SpriteTest",
srcs = ["SpriteTest.kt"],
kotlinc_opts = ":kt_kotlinc_options",
deps = [
"//java/com/google/androidenv/catch/sprite:Sprite",
"@maven//:org_mockito_kotlin_mockito_kotlin",
"@maven//:org_mockito_mockito_core",
],
)
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/BackgroundTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
import android.graphics.Canvas
import android.graphics.Color
import com.google.testing.junit.testparameterinjector.KotlinTestParameters.testValues
import com.google.testing.junit.testparameterinjector.TestParameter
import com.google.testing.junit.testparameterinjector.TestParameterInjector
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.kotlin.mock
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
/** Checks that [Background.draw] fills the canvas with the configured color. */
@RunWith(TestParameterInjector::class)
class BackgroundTest {
  @Test
  fun draw_defaultConstructorIsBlack() {
    // Arrange: a fresh canvas double and a background built with default arguments.
    val canvas = mock<Canvas>()
    val defaultBackground = Background()

    // Act.
    defaultBackground.draw(canvas)

    // Assert: the canvas is filled with black exactly once.
    verify(canvas, times(1)).drawColor(Color.BLACK)
  }

  @Test
  fun draw_customColors(
    @TestParameter colorInt: Int = testValues(0, 255, 13_579, 2_468, 12_384_173)
  ) {
    // Arrange: a background configured with the parameterized color.
    val canvas = mock<Canvas>()
    val coloredBackground = Background(color = colorInt)

    // Act.
    coloredBackground.draw(canvas)

    // Assert: that exact color is used to fill the canvas, once.
    verify(canvas, times(1)).drawColor(colorInt)
  }
}
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/BallTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.common.truth.Truth.assertThat
import java.time.Duration
import kotlin.random.Random
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Suite
import org.mockito.kotlin.any
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.doReturn
import org.mockito.kotlin.eq
import org.mockito.kotlin.mock
import org.mockito.kotlin.verify
import org.robolectric.ParameterizedRobolectricTestRunner
@RunWith(Suite::class)
@Suite.SuiteClasses(
BallTest.UpdateAndResetTests::class,
BallTest.ColorIntTest::class,
BallTest.CheckBoundsTest::class,
BallTest.IntersectsTest::class,
)
class BallTest {
@RunWith(AndroidJUnit4::class)
class UpdateAndResetTests() {
@Test
fun isOutOfBounds_initialState_isFalse() {
  // Arrange: pin the random x-coordinate to the horizontal middle of the screen.
  val middle: Random = mock { on { nextInt(any()) } doReturn 50 }
  val ball = Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = middle)
  // Act & Assert: a freshly created small ball starts fully on screen.
  assertThat(ball.isOutOfBounds()).isFalse()
}
@Test
fun isOutOfBounds_initialState_isTrueIfRadiusExceedsMaxY() {
  // Arrange: pin the random x-coordinate to the horizontal middle of the screen.
  val middle: Random = mock { on { nextInt(any()) } doReturn 50 }
  // The ball's diameter (22px) is taller than the whole screen (10px).
  val ball = Ball(maxX = 100, maxY = 10, radius = 11.0f, speed = 1.0f, rand = middle)
  // Act & Assert: it cannot fit vertically, so it is out of bounds from the start.
  assertThat(ball.isOutOfBounds()).isTrue()
}
@Test
fun isOutOfBounds_initialState_isFalseIfRadiusExceedsOnlyMaxX() {
  // Arrange: pin the random x-coordinate to 50.
  val middle: Random = mock { on { nextInt(any()) } doReturn 50 }
  // The ball is wider than the screen (maxX = 10) but fits vertically (maxY = 100).
  val ball = Ball(maxX = 10, maxY = 100, radius = 11.0f, speed = 1.0f, rand = middle)
  // Act & Assert: only the vertical extent matters for isOutOfBounds().
  assertThat(ball.isOutOfBounds()).isFalse()
}
@Test
fun update_zeroDurationDoesNotMove_withinBounds() {
// Arrange.
val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
// Act.
update(Duration.ofMillis(0)) // The ball should not move.
// Assert.
assertThat(isOutOfBounds()).isEqualTo(false) // It should still be within the bounds.
}
}
@Test
fun update_zeroDurationDoesNotMove_outOfBounds() {
// Arrange.
val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
update(Duration.ofMillis(110)) // Place the ball out of bounds.
assertThat(isOutOfBounds()).isEqualTo(true)
// Act.
update(Duration.ofMillis(0)) // The ball should not move.
// Assert.
assertThat(isOutOfBounds()).isEqualTo(true) // It should still be out of bounds.
}
}
@Test
fun update_negativeDurationsMovesUp() {
// Arrange.
val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
update(Duration.ofMillis(30)) // Move the ball down 30 pixels.
assertThat(isOutOfBounds()).isEqualTo(false)
// Act.
update(Duration.ofMillis(-50)) // Move the ball _up_ 50 pixels.
// Assert.
assertThat(isOutOfBounds()).isEqualTo(true) // Now it should be out-of-bounds.
}
}
@Test
fun update_singleThrow() {
// Ensures that a complete throw of a ball with radius==3.0f and maxY=100 behaves as expected.
// [isOutOfBounds()] should return [false] for the first (100-3.0f-3.0f)=94 [update()] calls,
// but [true] afterwards.
// Arrange.
val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
// Act.
repeat(94) {
update(Duration.ofMillis(1))
assertThat(isOutOfBounds()).isEqualTo(false)
}
update(Duration.ofMillis(1))
// Assert.
assertThat(isOutOfBounds()).isEqualTo(true)
}
}
@Test
fun intersects_afterUpdate() {
// Arrange.
val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
// Act & Assert.
with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0)))).isEqualTo(true)
update(Duration.ofMillis(1))
assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0)))).isEqualTo(false)
}
}
@Test
fun reset_intersectsInitialPositionShouldBeTrue() {
// Arrange.
val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
with(Ball(maxX = 100, maxY = 100, radius = 3.0f, speed = 1.0f, rand = mockRandom)) {
// Act.
assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0)))).isEqualTo(true)
update(Duration.ofMillis(1)) // Move the ball 1 pixels down.
assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0))))
.isEqualTo(false) // Segment is now outside of the ball.
reset() // Resetting should move the ball up again.
// Assert.
assertThat(intersects(LineSegment(Point(40, 0), Point(60, 0))))
.isEqualTo(true) // Segment is now inside of the ball.
}
}
@Test
fun reset_differentInitialXCoordinates() {
// Arrange.
val ball: Ball = Ball(maxX = 100, maxY = 100, radius = 3.0f)
// Act.
var pointInside: Boolean = false
var pointOutside: Boolean = false
while (!pointInside || !pointOutside) {
if (ball.intersects(LineSegment(Point(45, 0), Point(55, 0)))) {
pointInside = true
} else {
pointOutside = true
}
ball.reset() // Sample a new initial position for the ball.
}
// Assert.
// Eventually after many initial positions the ball should satisfy both conditions.
assertThat(pointInside).isEqualTo(true)
assertThat(pointOutside).isEqualTo(true)
}
}
@RunWith(ParameterizedRobolectricTestRunner::class)
class ColorIntTest(private val c: Int) {
@Test
fun draw_customBallColors() {
// Arrange.
val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 37 }
val mockCanvas: Canvas = mock()
val paintCaptor = argumentCaptor()
val ball: Ball = Ball(maxX = 50, maxY = 80, radius = 1.23f, color = c, rand = mockRandom)
// Act.
ball.draw(mockCanvas)
// Assert.
verify(mockCanvas).drawCircle(eq(37.0f), eq(2.0f), eq(1.23f), paintCaptor.capture())
with(paintCaptor.lastValue) {
assertThat(color).isEqualTo(c)
assertThat(style).isEqualTo(Paint.Style.FILL)
}
}
companion object {
@JvmStatic
@ParameterizedRobolectricTestRunner.Parameters(name = "color = {0}")
fun parameters() = listOf(0, 255, -1, 13579, 2468, 12384173, Color.WHITE, Color.BLUE)
}
}
@RunWith(ParameterizedRobolectricTestRunner::class)
class CheckBoundsTest(private val p: ParamPack) {
@Test
fun intersects_checkBounds() {
// Arrange.
val mockRandom: Random =
mock() { on { nextInt(any()) } doReturn p.maxX / 2 } // Horizontal middle.
// Act.
val ball: Ball = Ball(maxX = p.maxX, maxY = p.maxY, radius = p.radius, rand = mockRandom)
// Assert.
assertThat(ball.intersects(LineSegment(Point(p.x - 1, p.y), Point(p.x + 1, p.y))))
.isEqualTo(p.expected)
}
data class ParamPack(
val maxX: Int,
val maxY: Int,
val radius: Float,
val x: Int,
val y: Int,
val expected: Boolean,
)
companion object {
@JvmStatic
@ParameterizedRobolectricTestRunner.Parameters(name = "param = {0}")
fun parameters() =
listOf(
ParamPack(
maxX = 100,
maxY = 100,
radius = 10.0f,
x = 0,
y = 0,
expected = false,
), // Ball to the right of `x`.
ParamPack(
maxX = 100,
maxY = 100,
radius = 10.0f,
x = 39,
y = 0,
expected = false,
), // Ball to the right of `x`.
ParamPack(
maxX = 100,
maxY = 100,
radius = 10.0f,
x = 40,
y = 10,
expected = true,
), // Ball contains `x`.
ParamPack(
maxX = 100,
maxY = 100,
radius = 10.0f,
x = 50,
y = 0,
expected = true,
), // Ball contains `x`.
ParamPack(
maxX = 100,
maxY = 100,
radius = 10.0f,
x = 60,
y = 10,
expected = true,
), // Ball contains `x`.
ParamPack(
maxX = 100,
maxY = 100,
radius = 10.0f,
x = 61,
y = 0,
expected = false,
), // Ball to the left of `x`.
ParamPack(
maxX = 100,
maxY = 100,
radius = 10.0f,
x = 100,
y = 0,
expected = false,
), // Ball to the left of `x`.
ParamPack(
maxX = 100,
maxY = 100,
radius = 10.0f,
x = 50,
y = 21,
expected = false,
), // Ball above `y`.
)
}
}
@RunWith(ParameterizedRobolectricTestRunner::class)
class IntersectsTest(private val p: ParamPack) {
@Test
fun intersects_ballAtx50y10radius10() {
// Arrange.
val mockRandom: Random = mock() { on { nextInt(any()) } doReturn 50 } // Horizontal middle.
// Act.
val ball: Ball = Ball(maxX = 100, maxY = 100, radius = 10.0f, rand = mockRandom)
// Assert.
assertThat(ball.intersects(p.segment)).isEqualTo(p.expected)
}
data class ParamPack(val segment: LineSegment, val expected: Boolean)
companion object {
@JvmStatic
@ParameterizedRobolectricTestRunner.Parameters(name = "param = {0}")
fun parameters() =
listOf(
ParamPack(
segment = LineSegment(Point(50, 10), Point(80, 40)),
expected = true,
), // Segment that starts at the center of the ball so it should always intersect.
ParamPack(
segment = LineSegment(Point(49, 0), Point(51, 0)),
expected = true,
), // Tangential segment that touches the bottom of the ball.
ParamPack(
segment = LineSegment(Point(40, 5), Point(65, 7)),
expected = true,
), // Segment longer than diameter, touching the circumference twice.
ParamPack(
segment = LineSegment(Point(42, 2), Point(58, 1)),
expected = true,
), // Segment shorter than diameter, touching the circumference twice.
ParamPack(
segment = LineSegment(Point(44, 4), Point(54, 3)),
expected = true,
), // Segment shorter than diameter, fully inside the circle, not touching the
// circumference.
ParamPack(
segment = LineSegment(Point(35, 4), Point(54, 3)),
expected = true,
), // Segment that touches the circumference once "from the left".
ParamPack(
segment = LineSegment(Point(54, 7), Point(67, 13)),
expected = true,
), // Segment that touches the circumference once "from the right".
ParamPack(
segment = LineSegment(Point(36, 7), Point(45, 0)),
expected = false,
), // Segment "to the left of the ball". No intersection.
ParamPack(
segment = LineSegment(Point(58, -3), Point(60, 3)),
expected = false,
), // Segment "to the right of the ball". No intersection.
)
}
}
}
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/PaddleTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
import android.graphics.Canvas
import android.graphics.Color
import android.graphics.Paint
import android.graphics.Rect
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.common.truth.Truth.assertThat
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Suite
import org.mockito.kotlin.argumentCaptor
import org.mockito.kotlin.mock
import org.mockito.kotlin.verify
import org.robolectric.ParameterizedRobolectricTestRunner
@RunWith(Suite::class)
@Suite.SuiteClasses(
  PaddleTest.ConstructorTests::class,
  PaddleTest.MoveTests::class,
  PaddleTest.XSetterTests::class,
  PaddleTest.DrawTests::class,
)
class PaddleTest {

  /** Checks the initial state of a freshly constructed [Paddle]. */
  @RunWith(AndroidJUnit4::class)
  class ConstructorTests {

    @Test
    fun x_initialValueShouldBeAtCenter() {
      with(Paddle(maxX = 30)) { assertThat(x).isEqualTo(15) }
      with(Paddle(maxX = 31)) { assertThat(x).isEqualTo(15) }
    }

    @Test
    fun topLeft_correspondsToGivenValues() {
      with(Paddle(width = 10, height = 6, maxX = 40, y = 33)) {
        assertThat(topLeft()).isEqualTo(Point(x = 15, y = 30))
      }
    }

    @Test
    fun topRight_correspondsToGivenValues() {
      with(Paddle(width = 10, height = 6, maxX = 40, y = 33)) {
        assertThat(topRight()).isEqualTo(Point(x = 25, y = 30))
      }
    }
  }

  /** Checks relative movement via [Paddle.move], including clamping at the walls. */
  @RunWith(ParameterizedRobolectricTestRunner::class)
  class MoveTests(private val p: ParamPack) {

    @Test
    fun move_expectedDestination() {
      // Arrange.
      with(Paddle(maxX = 50)) {
        // Act.
        move(deltaX = p.displacement)
        // Assert.
        assertThat(x).isEqualTo(p.expectedX)
      }
    }

    data class ParamPack(val displacement: Int, val expectedX: Int)

    companion object {
      @JvmStatic
      @ParameterizedRobolectricTestRunner.Parameters(name = "param = {0}")
      fun parameters() =
        listOf(
          // Initial position is x==25.
          ParamPack(displacement = 10, expectedX = 35),
          ParamPack(displacement = -10, expectedX = 15),
          ParamPack(displacement = 0, expectedX = 25),
          // Going beyond the left and right walls should clamp the values to 0 and 50.
          ParamPack(displacement = -26, expectedX = 0),
          ParamPack(displacement = 26, expectedX = 50),
        )
    }
  }

  /** Checks absolute positioning via the `x` setter, including clamping at the walls. */
  @RunWith(ParameterizedRobolectricTestRunner::class)
  class XSetterTests(private val p: ParamPack) {

    @Test
    fun xSetter_expectedDestination() {
      // Arrange.
      with(Paddle(maxX = 50)) {
        // Act.
        x = p.target
        // Assert.
        assertThat(x).isEqualTo(p.expectedX)
      }
    }

    data class ParamPack(val target: Int, val expectedX: Int)

    companion object {
      @JvmStatic
      @ParameterizedRobolectricTestRunner.Parameters(name = "param = {0}")
      fun parameters() =
        listOf(
          // Initial position is x==25.
          ParamPack(target = 0, expectedX = 0),
          ParamPack(target = 15, expectedX = 15),
          ParamPack(target = 25, expectedX = 25),
          ParamPack(target = 35, expectedX = 35),
          ParamPack(target = 50, expectedX = 50),
          // Going beyond the left and right walls should clamp the values to 0 and 50.
          ParamPack(target = -1, expectedX = 0),
          ParamPack(target = 51, expectedX = 50),
        )
    }
  }

  /** Checks the rectangle and paint passed to [Canvas.drawRect] by [Paddle.draw]. */
  @RunWith(AndroidJUnit4::class)
  class DrawTests {

    @Test
    fun draw_initialPosition() {
      // Arrange.
      val mockCanvas: Canvas = mock()
      val rectCaptor = argumentCaptor<Rect>()
      val paintCaptor = argumentCaptor<Paint>()
      with(Paddle(color = Color.RED, width = 100, height = 20, maxX = 300, y = 400)) {
        // Act.
        draw(mockCanvas)
        // Assert.
        assertThat(x).isEqualTo(150)
        verify(mockCanvas).drawRect(rectCaptor.capture(), paintCaptor.capture())
        with(rectCaptor.lastValue) {
          assertThat(bottom).isEqualTo(400 + 10)
          assertThat(top).isEqualTo(400 - 10)
          assertThat(left).isEqualTo(150 - 50)
          assertThat(right).isEqualTo(150 + 50)
        }
        // The captured paint was previously never checked for the initial position.
        with(paintCaptor.lastValue) {
          assertThat(color).isEqualTo(Color.RED)
          assertThat(style).isEqualTo(Paint.Style.FILL)
        }
      }
    }

    @Test
    fun draw_afterMove() {
      // Arrange.
      val mockCanvas: Canvas = mock()
      val rectCaptor = argumentCaptor<Rect>()
      val paintCaptor = argumentCaptor<Paint>()
      with(Paddle(color = Color.RED, width = 100, height = 20, maxX = 300, y = 400)) {
        // Act.
        move(50)
        draw(mockCanvas)
        // Assert.
        assertThat(x).isEqualTo(200)
        verify(mockCanvas).drawRect(rectCaptor.capture(), paintCaptor.capture())
        with(rectCaptor.lastValue) {
          assertThat(bottom).isEqualTo(400 + 10)
          assertThat(top).isEqualTo(400 - 10)
          assertThat(left).isEqualTo(200 - 50)
          assertThat(right).isEqualTo(200 + 50)
        }
        with(paintCaptor.lastValue) {
          assertThat(color).isEqualTo(Color.RED)
          assertThat(style).isEqualTo(Paint.Style.FILL)
        }
      }
    }
  }
}
================================================
FILE: android_env/apps/javatests/com/google/androidenv/catch/sprite/SpriteTest.kt
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.androidenv.catch.sprite
import android.graphics.Canvas
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.mockito.Mockito.verifyNoInteractions
import org.mockito.kotlin.mock
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
/**
 * Sanity checks on the base [Sprite] API: the default `draw` leaves the canvas untouched, and a
 * mocked sprite receives the exact canvas it is handed.
 */
@RunWith(JUnit4::class)
class SpriteTest {

  @Test
  fun defaultImplementationDoesNothing() {
    val canvas = mock<Canvas>()

    Sprite().draw(canvas)

    // Nothing at all may be called on the canvas by the base implementation.
    verifyNoInteractions(canvas)
  }

  @Test
  fun draw_argumentsAreForwarded() {
    val sprite = mock<Sprite>()
    val canvas = mock<Canvas>()

    sprite.draw(canvas)

    // The very same canvas instance must arrive at draw(), exactly once.
    verify(sprite, times(1)).draw(canvas)
  }
}
================================================
FILE: android_env/components/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/components/action_fns.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions to convert actions between different components' formats."""
from absl import logging
from android_env.components import action_type as action_type_lib
from android_env.components import errors
from android_env.components import pixel_fns
from android_env.components.simulators import base_simulator
import numpy as np
def send_action_to_simulator(
    action: dict[str, np.ndarray],
    simulator: base_simulator.BaseSimulator,
    screen_width: int,
    screen_height: int,
    num_fingers: int,
) -> bool:
  """Forwards `action` to `simulator`, dispatching on `action["action_type"]`.

  TOUCH and LIFT actions are converted into per-finger pixel events and sent
  through `simulator.send_touch()`. KEYDOWN, KEYUP and KEYPRESS actions send
  the first element of `action["keycode"]` through `simulator.send_key()`.
  Any other action type is not forwarded. What the event does inside Android
  depends on the application that is currently running.

  Args:
    action: The action to forward. It must contain "action_type", plus the
      keys that the given action type needs.
    simulator: The simulator that will receive the action.
    screen_width: The width of the touchscreen in pixels.
    screen_height: The height of the touchscreen in pixels.
    num_fingers: The number of fingers used in this simulator.

  Returns:
    False if the simulator raised `errors.SendActionError`, True otherwise.

  Raises:
    KeyError: If `action` lacks a key required for its action type.
  """
  kinds = action_type_lib.ActionType
  try:
    requested = action['action_type']
    if requested == kinds.TOUCH or requested == kinds.LIFT:
      # Touch-like actions become a list of per-finger pixel events.
      simulator.send_touch(
          _prepare_touch_action(
              action, screen_width, screen_height, num_fingers
          )
      )
    elif requested == kinds.KEYDOWN:
      simulator.send_key(action['keycode'].item(0), event_type='keydown')
    elif requested == kinds.KEYUP:
      simulator.send_key(action['keycode'].item(0), event_type='keyup')
    elif requested == kinds.KEYPRESS:
      simulator.send_key(action['keycode'].item(0), event_type='keypress')
  except errors.SendActionError:
    logging.exception('Unable to execute action: %r', action)
    return False
  return True
def _prepare_touch_action(
    action: dict[str, np.ndarray],
    screen_width: int,
    screen_height: int,
    num_fingers: int,
) -> list[tuple[int, int, bool, int]]:
  """Converts a (possibly multi-finger) touch action into pixel events.

  Each finger's float-valued 'touch_position' is mapped to pixel coordinates
  with `pixel_fns.touch_position_to_pixel_position()`. Its 'action_type' is
  reduced to a flag that is true only for TOUCH. The result can be passed
  directly to the simulator's `send_touch()`.

  Args:
    action: An action containing 'action_type' and 'touch_position', plus
      'action_type_<i>' / 'touch_position_<i>' for fingers 2..num_fingers.
    screen_width: The width of the touchscreen in pixels.
    screen_height: The height of the touchscreen in pixels.
    num_fingers: The number of fingers described by `action`.

  Returns:
    A list with one `(x, y, is_touch, finger_index)` tuple per finger, in
    finger order, starting at index 0.
  """
  width_height = (screen_width, screen_height)
  events = []
  for finger_index, finger in enumerate(
      _split_touch_action(action, num_fingers)
  ):
    pressed = finger['action_type'] == action_type_lib.ActionType.TOUCH
    pixels = pixel_fns.touch_position_to_pixel_position(
        finger['touch_position'], width_height=width_height
    )
    events.append((pixels[0], pixels[1], pressed, finger_index))
  return events
def _split_touch_action(
action: dict[str, np.ndarray], num_fingers: int
) -> list[dict[str, np.ndarray]]:
"""Splits a multitouch action into a list of single-touch actions."""
single_touch_actions = [{
'action_type': action['action_type'],
'touch_position': action['touch_position'],
}]
for i in range(2, num_fingers + 1):
single_touch_actions.append({
'action_type': action[f'action_type_{i}'],
'touch_position': action[f'touch_position_{i}'],
})
return single_touch_actions
def lift_all_fingers_action(num_fingers: int) -> dict[str, np.ndarray]:
  """Builds an action that lifts every finger.

  Args:
    num_fingers: How many fingers the action describes. The first finger is
      always present, even when this is smaller than 1.

  Returns:
    A dict with 'action_type' / 'touch_position' for the first finger and
    'action_type_<i>' / 'touch_position_<i>' for fingers 2..num_fingers.
    Every action type is LIFT and every position is [0, 0].
  """
  lift = action_type_lib.ActionType.LIFT
  action = {
      'action_type': np.array(lift),
      'touch_position': np.array([0, 0]),
  }
  for finger in range(2, num_fingers + 1):
    action[f'action_type_{finger}'] = np.array(lift)
    action[f'touch_position_{finger}'] = np.array([0, 0])
  return action
================================================
FILE: android_env/components/action_fns_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import action_fns
from android_env.components import action_type as action_type_lib
from android_env.components import errors
from android_env.components.simulators import base_simulator
import numpy as np
class ActionFnsTest(parameterized.TestCase):
  """Tests for `action_fns.send_action_to_simulator` and friends."""

  def test_send_action_to_simulator_missing_action_type(self):
    """A `KeyError` should be raised if the action is missing "action_type"."""
    # Arrange.
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    action = {'some_key': np.array(123, np.int32)}

    # Act & Assert.
    self.assertRaises(
        KeyError,
        action_fns.send_action_to_simulator,
        action,
        simulator,
        800,
        600,
        1,
    )

  def test_send_action_to_simulator_sendactionerror(self):
    """Returns `False` if the simulator raises a SendActionError."""
    # Arrange.
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    simulator.send_touch.side_effect = errors.SendActionError('oops!')
    action = {
        'action_type': action_type_lib.ActionType.TOUCH,
        'touch_position': np.array([0.3, 0.5], np.float32),
    }

    # Act.
    output = action_fns.send_action_to_simulator(
        action,
        simulator,
        800,
        600,
        1,
    )

    # Assert.
    self.assertFalse(output)
    simulator.send_touch.assert_called_once()

  def test_send_action_to_simulator_touch_success_one_finger(self):
    """Returns `True` with a proper 1-finger touch action."""
    # Arrange.
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    action = {
        'action_type': action_type_lib.ActionType.TOUCH,
        'touch_position': np.array([0.2, 0.5], np.float32),
    }

    # Act.
    output = action_fns.send_action_to_simulator(
        action,
        simulator,
        800,
        600,
        1,
    )

    # Assert.
    self.assertTrue(output)
    simulator.send_touch.assert_called_once_with(
        [(np.int32(160), np.int32(300), True, 0)]
    )

  def test_send_action_to_simulator_touch_success_multiple_finger(self):
    """Returns `True` with a proper 3-finger touch action."""
    # Arrange.
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    action = {
        'action_type': action_type_lib.ActionType.TOUCH,
        'touch_position': np.array([0.2, 0.5], np.float32),
        'action_type_2': action_type_lib.ActionType.LIFT,
        'touch_position_2': np.array([0.1, 0.2], np.float32),
        'action_type_3': action_type_lib.ActionType.TOUCH,
        'touch_position_3': np.array([0.5, 0.2], np.float32),
    }

    # Act.
    output = action_fns.send_action_to_simulator(
        action,
        simulator,
        800,
        600,
        3,
    )

    # Assert.
    self.assertTrue(output)
    simulator.send_touch.assert_called_once_with([
        (np.int32(160), np.int32(300), True, 0),
        (np.int32(80), np.int32(120), False, 1),
        (np.int32(400), np.int32(120), True, 2),
    ])

  def test_send_action_to_simulator_keydown_success(self):
    """Returns `True` with a proper keydown action."""
    # Arrange.
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    action = {
        'action_type': action_type_lib.ActionType.KEYDOWN,
        'keycode': np.array([21], np.int32),
    }

    # Act.
    output = action_fns.send_action_to_simulator(
        action,
        simulator,
        800,
        600,
        1,
    )

    # Assert.
    self.assertTrue(output)
    simulator.send_key.assert_called_once_with(21, event_type='keydown')

  def test_send_action_to_simulator_keyup_success(self):
    """Returns `True` with a proper keyup action."""
    # Arrange.
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    action = {
        'action_type': action_type_lib.ActionType.KEYUP,
        'keycode': np.array([42], np.int32),
    }

    # Act.
    output = action_fns.send_action_to_simulator(
        action,
        simulator,
        800,
        600,
        1,
    )

    # Assert.
    self.assertTrue(output)
    simulator.send_key.assert_called_once_with(42, event_type='keyup')

  def test_send_action_to_simulator_keypress_success(self):
    """Returns `True` with a proper keypress action."""
    # Arrange.
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    action = {
        'action_type': action_type_lib.ActionType.KEYPRESS,
        'keycode': np.array([96], np.int32),
    }

    # Act.
    output = action_fns.send_action_to_simulator(
        action,
        simulator,
        800,
        600,
        1,
    )

    # Assert.
    self.assertTrue(output)
    simulator.send_key.assert_called_once_with(96, event_type='keypress')

  def test_send_action_to_simulator_repeat_is_noop(self):
    """A REPEAT action is not forwarded to the simulator but returns `True`."""
    # Arrange.
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    action = {'action_type': action_type_lib.ActionType.REPEAT}

    # Act.
    output = action_fns.send_action_to_simulator(
        action,
        simulator,
        800,
        600,
        1,
    )

    # Assert.
    self.assertTrue(output)
    simulator.send_touch.assert_not_called()
    simulator.send_key.assert_not_called()

  @parameterized.named_parameters(
      (
          'one_finger',
          1,
          {
              'action_type': np.array(action_type_lib.ActionType.LIFT),
              'touch_position': np.array([0, 0]),
          },
      ),
      (
          'two_fingers',
          2,
          {
              'action_type': np.array(action_type_lib.ActionType.LIFT),
              'touch_position': np.array([0, 0]),
              'action_type_2': np.array(action_type_lib.ActionType.LIFT),
              'touch_position_2': np.array([0, 0]),
          },
      ),
      (
          'three_fingers',
          3,
          {
              'action_type': np.array(action_type_lib.ActionType.LIFT),
              'touch_position': np.array([0, 0]),
              'action_type_2': np.array(action_type_lib.ActionType.LIFT),
              'touch_position_2': np.array([0, 0]),
              'action_type_3': np.array(action_type_lib.ActionType.LIFT),
              'touch_position_3': np.array([0, 0]),
          },
      ),
  )
  def test_lift_all_fingers_action(
      self, num_fingers: int, expected_action: dict[str, np.ndarray]
  ):
    """Returns exactly the expected keys, each with the expected value."""
    output = action_fns.lift_all_fingers_action(num_fingers)

    # Guard against extra fingers, which the per-key loop alone would miss.
    self.assertCountEqual(output.keys(), expected_action.keys())
    for k, v in expected_action.items():
      np.testing.assert_array_equal(v, output[k])
# Run the absltest suite when this file is executed directly.
if __name__ == '__main__':
  absltest.main()
================================================
FILE: android_env/components/action_type.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The different kinds of actions that AndroidEnv supports.
The native action space of AndroidEnv consists of a tuple consisting of
- A position (x, y) ∈ [0, 1] x [0, 1], determining the location of the action on
the screen, and
- A discrete value, indicating the action type, which is in this file.
See https://arxiv.org/abs/2105.13231, section 2.2 for details.
"""
import enum
@enum.unique
class ActionType(enum.IntEnum):
  """Integer values to describe each supported action in AndroidEnv.

  Note for KEY* types:
  - Only meaningful if connected to a _physical_ keyboard, _not_ a virtual
    keyboard.
  - Added afterwards, so they did not appear in the paper.

  Attributes:
    TOUCH: Touching the screen at a location.
    LIFT: Lifting the (imaginary) pointer from the screen at a location.
    REPEAT: Repeating the last chosen action.
    KEYDOWN: Sending a key down event.
    KEYUP: Sending a key up event.
    KEYPRESS: Sending a key down event, immediately followed by a key up event.
  """

  TOUCH = 0
  LIFT = 1
  REPEAT = 2
  KEYDOWN = 3
  KEYUP = 4
  KEYPRESS = 5
================================================
FILE: android_env/components/adb_call_parser.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Processes adb_pb2.AdbRequest commands."""
import os
import re
import subprocess
import sys
import tempfile
from absl import logging
from android_env.components import adb_controller as adb_control
from android_env.proto import adb_pb2
# Maps each supported `AdbRequest.PressButton.Button` value to the name of the
# Android keycode it corresponds to.
#
# See https://developer.android.com/reference/android/view/KeyEvent for the
# full list of keycode names.
#
# Only the buttons listed below are currently accepted.
_BUTTON_TO_KEYCODE = {
    adb_pb2.AdbRequest.PressButton.Button.HOME: 'KEYCODE_HOME',
    adb_pb2.AdbRequest.PressButton.Button.BACK: 'KEYCODE_BACK',
    adb_pb2.AdbRequest.PressButton.Button.ENTER: 'KEYCODE_ENTER',
}
class AdbCallParser:
"""Parses AdbRequest messages and executes corresponding adb commands."""
def __init__(self, adb_controller: adb_control.AdbController):
  """Stores the controller and builds the command-name to handler table.

  Args:
    adb_controller: The controller used to run the adb commands that
      requests are translated into.
  """
  self._adb_controller = adb_controller
  # Keys are the names of the `AdbRequest.command` oneof fields. `parse()`
  # looks handlers up by the result of `WhichOneof('command')`.
  self._handlers = dict(
      install_apk=self._install_apk,
      start_activity=self._start_activity,
      force_stop=self._force_stop,
      tap=self._tap,
      press_button=self._press_button,
      start_screen_pinning=self._start_screen_pinning,
      send_broadcast=self._send_broadcast,
      uninstall_package=self._handle_uninstall_package,
      get_current_activity=self._get_current_activity,
      get_orientation=self._get_orientation,
      push=self._push,
      pull=self._pull,
      input_text=self._input_text,
      settings=self._handle_settings,
      generic=self._handle_generic,
      package_manager=self._handle_package_manager,
      dumpsys=self._handle_dumpsys,
  )
def _execute_command(
    self, command_args: list[str], timeout: float | None
) -> tuple[adb_pb2.AdbResponse, bytes]:
  """Runs an adb command and translates failures into a response status.

  Args:
    command_args: Arguments forwarded to `execute_command()` of the adb
      controller.
    timeout: Timeout in seconds, or None.

  Returns:
    A tuple `(response, output)`. `response.status` is OK on success,
    ADB_ERROR if the controller raised `subprocess.CalledProcessError`, or
    TIMEOUT if it raised `subprocess.TimeoutExpired`. `output` is whatever
    the controller returned on success, and `b''` on failure.
  """
  response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK)
  command_output = b''
  try:
    command_output = self._adb_controller.execute_command(
        command_args, timeout=timeout)
  except subprocess.CalledProcessError as adb_error:
    # A failed command must never be reported as OK. Previously an error with
    # no captured output (stdout is None) left the status untouched.
    response.status = adb_pb2.AdbResponse.Status.ADB_ERROR
    if adb_error.stdout is not None:
      response.error_message = adb_error.stdout
    else:
      response.error_message = str(adb_error)
  except subprocess.TimeoutExpired:
    response.status = adb_pb2.AdbResponse.Status.TIMEOUT
    response.error_message = 'Timeout'
  return response, command_output
def parse(self, request: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
  """Executes `request` and returns an appropriate response.

  Args:
    request: The request to execute. Exactly one field of its `command`
      oneof is expected to be set.

  Returns:
    The handler's response. If no command is set, or the set command has no
    registered handler, the status is UNKNOWN_COMMAND. If `timeout_sec` is
    negative, the status is FAILED_PRECONDITION.
  """
  response = adb_pb2.AdbResponse(status=adb_pb2.AdbResponse.Status.OK)
  command_type = request.WhichOneof('command')
  logging.debug('AdbRequest command type: %s', command_type)
  if command_type is None:
    response.status = adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND
    response.error_message = 'AdbRequest.command is None.'
    return response

  if request.timeout_sec < 0:
    response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
    response.error_message = ('AdbRequest.timeout_sec cannot be negative. '
                              f'Got: {request.timeout_sec}')
    return response

  # A oneof field without a registered handler (e.g. a newly added command)
  # is reported as an unknown command instead of crashing with KeyError.
  handler = self._handlers.get(command_type)
  if handler is None:
    response.status = adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND
    response.error_message = (
        f'No handler registered for AdbRequest.command: {command_type}')
    return response

  # A zero timeout_sec (the proto default) means "no timeout".
  timeout: float | None = request.timeout_sec or None
  return handler(request, timeout)
def _force_stop(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Force-stops the application named in `request.force_stop`.

  Args:
    request: The request whose `force_stop.package_name` selects the app.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION if the package name is empty. Otherwise, the
    response produced by running `am force-stop` on the device.
  """
  package_name = request.force_stop.package_name
  if package_name:
    outcome, _ = self._execute_command(
        ['shell', 'am', 'force-stop', package_name], timeout)
    return outcome
  return adb_pb2.AdbResponse(
      status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
      error_message='`force_stop.package_name` cannot be empty.',
  )
def _fetch_current_task_id(
self, full_activity_name: str, timeout: float | None = None
) -> int:
"""Returns the task ID of the given `full_activity_name`.
Args:
full_activity_name: The full name of the activity whose corresponding
task id we are looking for.
timeout: Optional time limit in seconds.
Returns:
task_id: An integer corresponding to the specified activity.
"""
stack = self._adb_controller.execute_command(
['shell', 'am', 'stack', 'list'], timeout=timeout)
lines = stack.decode('utf-8').splitlines()
regex = re.compile(
r'^\ *taskId=(?P[0-9]*): (?P[^\s]*) .*visible=true'
r'.*topActivity=ComponentInfo{(?P[^\s]*)}$')
for line in lines:
match = regex.search(line)
if match is None:
continue
current_task_id_str = match.group('id')
base_activity = match.group('base_activity')
top_activity = match.group('top_activity')
# If neither of the matched activities equals the activity we are
# looking for, we discard their task id and continue the search.
if full_activity_name not in {base_activity, top_activity}:
logging.info('Full activity %s was not found in current line %s',
full_activity_name, line)
continue
# Otherwise return the integer task id.
try:
return int(current_task_id_str)
except ValueError:
logging.info('Failed to parse task ID [%r].', current_task_id_str)
# At this point if we could not find a task ID, there's nothing we can do.
logging.error('Could not find current activity in stack list: %r', lines)
return -1
def _start_screen_pinning(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Locks the screen to the task hosting the requested activity.

  Args:
    request: The request whose `start_screen_pinning.full_activity` names the
      activity to pin.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION for an empty activity, INTERNAL_ERROR when no task ID
    is found, otherwise the response of `adb shell am task lock TASK_ID`.
  """
  status = adb_pb2.AdbResponse.Status
  activity = request.start_screen_pinning.full_activity
  if not activity:
    return adb_pb2.AdbResponse(
        status=status.FAILED_PRECONDITION,
        error_message='`start_screen_pinning.full_activity` cannot be empty.')

  task_id = self._fetch_current_task_id(activity, timeout)
  if task_id == -1:  # Sentinel for "not found".
    return adb_pb2.AdbResponse(
        status=status.INTERNAL_ERROR,
        error_message=f'Could not find task ID for activity [{activity}]')

  response, _ = self._execute_command(
      ['shell', 'am', 'task', 'lock', str(task_id)], timeout=timeout)
  return response
def _send_broadcast(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Sends a broadcast intent via `adb shell am broadcast`.

  Args:
    request: The request whose `send_broadcast` field holds the intent
      `action` (required) and an optional `component`.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION if `send_broadcast.action` is empty, otherwise the
    response of the `am broadcast` command.
  """
  send_broadcast = request.send_broadcast
  if not send_broadcast.action:
    return adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
        error_message='`send_broadcast.action` cannot be empty.')
  args = ['shell', 'am', 'broadcast', '-a', send_broadcast.action]
  if send_broadcast.component:
    # Only pass `-n` when a component was requested.
    args += ['-n', send_broadcast.component]
  response, _ = self._execute_command(args, timeout=timeout)
  return response
def _install_apk(
self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
"""Installs an app given its local path in the filesystem.
Args:
request: The external request with an install_apk field.
Contains information for the .apk installation.
timeout: Optional time limit in seconds.
Returns:
An AdbResponse.
"""
install_apk = request.install_apk
response = adb_pb2.AdbResponse()
location_type = install_apk.WhichOneof('location')
logging.info('location_type: %s', location_type)
match location_type:
case 'filesystem':
fpath = install_apk.filesystem.path
if not os.path.exists(fpath):
response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
response.error_message = f'Could not find local_apk_path: {fpath}'
return response
response, _ = self._execute_command(
['install', '-r', '-t', '-g', fpath], timeout=timeout
)
case 'blob':
# `delete_on_close` was only added in Python 3.12 so we add a switch
# here to still support previous Python versions.
if sys.version_info >= (3, 12):
kwargs = {'suffix': '.apk', 'delete_on_close': False}
else:
kwargs = {'suffix': '.apk'}
with tempfile.NamedTemporaryFile(**kwargs) as f:
fpath = f.name
f.write(install_apk.blob.contents)
response, _ = self._execute_command(
['install', '-r', '-t', '-g', fpath], timeout=timeout
)
case _:
response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
response.error_message = (
f'Unsupported `install_apk.location` type: {location_type}'
)
return response
return response
def _start_activity(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Starts an activity with `adb shell am start -W -n ACTIVITY`.

  `start_activity.extra_args` are appended verbatim to the `am start` command.
  Useful `am start` options include:
    -D: enable debugging
    -W: wait for launch to complete (always passed by this method)
    --start-profiler FILE: start profiler and send results to FILE
    -P FILE: like above, but profiling stops when app goes idle
    -R COUNT: repeat the activity launch COUNT times. Prior to each repeat,
      the top activity will be finished.
    -S: force stop the target app before starting the activity (passed when
      `start_activity.force_stop` is set)
    --opengl-trace: enable tracing of OpenGL functions

  Args:
    request: The request with information on what activity to start.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION if `start_activity.full_activity` is empty;
    INTERNAL_ERROR if the command output mentions "Error"; otherwise the
    command's response with `start_activity.full_activity` and
    `start_activity.output` filled in.
  """
  params = request.start_activity
  activity = params.full_activity
  if not activity:
    return adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
        error_message='`start_activity.full_activity` cannot be empty.')

  # Without force stop, an empty-string placeholder stays in the `-S` slot.
  command = ['shell', 'am', 'start', '-S' if params.force_stop else '',
             '-W', '-n', activity]
  command.extend(params.extra_args or [])
  response, command_output = self._execute_command(command, timeout=timeout)

  # Errors are detected in the printable repr of the raw output bytes, where
  # newlines are escaped, so a single `.*` can span every output line.
  output_repr = str(command_output)
  if re.compile(r'.*Error.*', re.VERBOSE).match(output_repr):
    return adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.INTERNAL_ERROR,
        error_message=f'start_activity failed with error: {output_repr}')

  response.start_activity.full_activity = activity
  response.start_activity.output = command_output
  return response
def _press_button(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Sends the keyevent that corresponds to `request.press_button.button`.

  Args:
    request: The request whose `press_button.button` selects the key.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION for buttons missing from `_BUTTON_TO_KEYCODE`,
    otherwise the `input keyevent` response with `press_button.output` set.
  """
  button = request.press_button.button
  try:
    keycode = _BUTTON_TO_KEYCODE[button]
  except KeyError:
    return adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
        error_message=('PressButton.button must be one of '
                       f'[{_BUTTON_TO_KEYCODE.keys()}]. '
                       f'Got: {button}. Please see `adb.proto`.'))
  response, output = self._execute_command(
      ['shell', 'input', 'keyevent', keycode], timeout=timeout)
  response.press_button.output = output
  return response
def _handle_uninstall_package(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Uninstalls a package, but only if it is currently installed.

  The installed packages are listed first through `_handle_package_manager`.
  `adb uninstall` is issued only when `package_name` is among them.

  Args:
    request: The specification of what to uninstall.
    timeout: Optional time limit in seconds. It applies to both the package
      listing and the uninstall command.

  Returns:
    FAILED_PRECONDITION if `package_name` is empty; the listing's non-OK
    status if the packages could not be listed; the `adb uninstall` response
    if the package is installed; otherwise a response that only carries an
    explanatory `error_message`.
  """
  package_name = request.uninstall_package.package_name
  response = adb_pb2.AdbResponse()
  # Every UninstallPackage should have a package_name.
  if not package_name:
    response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
    response.error_message = (
        '`uninstall_package.package_name` cannot be empty.')
    return response
  # Get list of installed packages and issue an uninstall only if it's
  # already installed.
  package_response = self._handle_package_manager(
      adb_pb2.AdbRequest(
          package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
              list=adb_pb2.AdbRequest.PackageManagerRequest.List(
                  packages=adb_pb2.AdbRequest.PackageManagerRequest.List
                  .Packages()))),
      timeout=timeout)
  if package_response.status != adb_pb2.AdbResponse.Status.OK:
    # Without a listing, "not installed" cannot be concluded; surface the
    # listing failure instead.
    response.status = package_response.status
    response.error_message = (
        'Could not list installed packages: '
        f'{package_response.error_message}')
    return response
  if package_name in package_response.package_manager.list.items:
    response, _ = self._execute_command(['uninstall', package_name], timeout)
  else:
    msg = (f'Cannot uninstall {package_name} since it is not installed.')
    logging.warning(msg)
    response.error_message = msg
  return response
def _get_current_activity(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Returns the activity shown in the visible task.

  Args:
    request: The request with the `.get_current_activity` field set. This is
      unused, but it's in the signature so that all calls are uniform.
    timeout: Optional time limit in seconds.

  Returns:
    An AdbResponse whose `get_current_activity.full_activity` holds the text
    inside the last `{...}` on the visible-task line. The failing status is
    returned if the listing command fails, and INTERNAL_ERROR (with the full
    `am stack list` output) if no activity can be extracted.
  """
  del request  # Unused.
  response, visible_task = self._execute_command(
      ['shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'],
      timeout=timeout)
  if response.status != adb_pb2.AdbResponse.Status.OK:
    return response

  def report_failure(message_prefix: str) -> adb_pb2.AdbResponse:
    # Re-run the unfiltered listing so that the error carries full context.
    _, am_stack_list = self._execute_command(['shell', 'am', 'stack', 'list'],
                                             timeout=timeout)
    response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
    response.error_message = f'{message_prefix}{am_stack_list}'
    return response

  if not visible_task:
    return report_failure('Empty visible_task. `am stack list`: ')

  text = visible_task.decode('utf-8')
  if sys.platform == 'win32':
    # Keep only the first visible top activity, re-wrapped for the
    # extraction below.
    found = re.findall(
        r'visible=true topActivity=ComponentInfo{(.+?)}', text)
    text = 'ComponentInfo{' + found[0] + '}' if found else ''

  matches = re.search(r'.*\{(.*)\}', text)
  if matches is None:
    return report_failure(
        'Could not extract current activity. Will return nothing. '
        '`am stack list`: ')
  response.get_current_activity.full_activity = matches.group(1)
  return response
def _get_orientation(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Fetches current device orientation from `dumpsys input`.

  Args:
    request: The request with the `.get_orientation` field set. This is
      unused, but it's in the signature so that all calls are uniform.
    timeout: Optional time limit in seconds.

  Returns:
    AdbResponse whose `get_orientation.orientation` holds the rotation index
    (0-3) of the first input device with a non-negative PhysicalWidth, or an
    INTERNAL_ERROR response if it cannot be determined.
  """
  del request  # Unused.
  logging.info('Getting orientation...')
  response = self._handle_dumpsys(
      adb_pb2.AdbRequest(
          dumpsys=adb_pb2.AdbRequest.DumpsysRequest(service='input')),
      timeout=timeout)
  output = response.dumpsys.output
  if not output:
    logging.error('Empty dumpsys output.')
    response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
    response.error_message = 'Failed to execute `dumpsys input`'
    return response

  skip_next = False
  for line in output.decode('utf-8').split('\n'):
    # There may be multiple devices in output. An invalid device can be
    # identified by negative PhysicalWidth.
    physical_width = re.match(r'\s+PhysicalWidth:\s+(-?\d+)px', line)
    if physical_width:
      skip_next = int(physical_width.group(1)) < 0
    # Depending on the device type, the orientation could take these forms:
    #   SurfaceOrientation: 0
    #   InputDeviceOrientation: Rotation0
    orientation_match = re.match(
        r'\s+(?:SurfaceOrientation|InputDeviceOrientation):\s+\D*(\d+)', line
    )
    if orientation_match is None or skip_next:
      continue
    value = int(orientation_match.group(1))
    # Degree-style values (e.g. `Rotation90`) are mapped to the 0-3 index.
    # Keeping only the last digit would misread 90/180/270 as 0.
    orientation = value // 90 if value >= 90 else value
    logging.info('Done getting orientation: %r', orientation)
    response.get_orientation.orientation = orientation
    return response

  response.status = adb_pb2.AdbResponse.Status.INTERNAL_ERROR
  response.error_message = (
      'Could not find SurfaceOrientation/InputDeviceOrientation in dumpsys '
      'output'
  )
  return response
def _push(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Uploads `request.push.content` to `request.push.path` on the device.

  The content is written to a local temporary file, which is sent with
  `adb push` and then deleted, even if the push raises.

  Args:
    request: The request with the contents to push to the device.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION if `push.path` is empty, otherwise the response of
    the `adb push` command.
  """
  path = request.push.path
  if not path:
    return adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
        error_message='Push.path is empty.')
  # Create temporary file with `push` contents. Leaving the `with` block
  # closes (and flushes) it before adb reads it.
  with tempfile.NamedTemporaryFile(delete=False) as f:
    fname = f.name
    f.write(request.push.content)
  try:
    # Issue `adb push` command to upload file.
    logging.info('Uploading %r to %r.', fname, path)
    response, _ = self._execute_command(['push', fname, path], timeout=timeout)
  finally:
    # `delete=False` means cleanup is ours; do it even if the push raised.
    os.remove(fname)
  return response
def _pull(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Downloads file content from the device.

  The file is pulled into a local temporary file whose bytes are copied into
  the response. The temporary file is removed even if the pull raises.

  Args:
    request: The request with the information on what to get from the device.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION if `pull.path` is empty, otherwise the `adb pull`
    response with `pull.content` set to the downloaded bytes.
  """
  path = request.pull.path
  if not path:
    return adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
        error_message='Pull.path is empty.')
  # Reserve a temporary file name. The local handle is closed before adb
  # writes to that path.
  with tempfile.NamedTemporaryFile(delete=False) as f:
    fname = f.name
  try:
    # Issue `adb pull` command to copy it to the temporary file.
    logging.debug('Downloading %r to %r.', path, fname)
    response, _ = self._execute_command(['pull', path, fname],
                                        timeout=timeout)
    # Read the content of the file.
    with open(fname, 'rb') as f:
      response.pull.content = f.read()
  finally:
    # `delete=False` means cleanup is ours; do it even on failure.
    os.remove(fname)
  return response
def _input_text(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Types `request.input_text.text` via `adb shell input text`.

  Args:
    request: The request whose `input_text.text` holds the text to type.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION if the text is empty, otherwise the command response.
  """
  text = request.input_text.text
  if text:
    response, _ = self._execute_command(
        ['shell', 'input', 'text', text], timeout=timeout)
    return response
  return adb_pb2.AdbResponse(
      status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
      error_message='InputText.text is empty.')
def _tap(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Taps the screen at (`request.tap.x`, `request.tap.y`).

  Args:
    request: The request with information on where to tap the screen.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION for a negative coordinate, otherwise the response of
    `adb shell input tap X Y`.
  """
  tap = request.tap
  # Zero is a valid coordinate (first row/column); only negatives are invalid.
  if min(tap.x, tap.y) < 0:
    return adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
        error_message=f'Tap coordinates must be non-negative. Got: {tap}.')
  command = ['shell', 'input', 'tap', str(tap.x), str(tap.y)]
  response, _ = self._execute_command(command, timeout=timeout)
  return response
def _handle_settings(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Runs an `adb shell settings` sub-command described by `request.settings`.

  Supported verbs: get, put, delete_key (run as `settings delete`), reset and
  list. The lower-cased name of the namespace enum is the namespace argument.

  Args:
    request: The specification of what to do with settings.
    timeout: Optional time limit in seconds.

  Returns:
    An AdbResponse whose `settings.output` holds the command output, or a
    FAILED_PRECONDITION response if the namespace, the verb or a required
    argument is missing.
  """
  settings = request.settings
  settings_pb = adb_pb2.AdbRequest.SettingsRequest

  def precondition_failure(message: str) -> adb_pb2.AdbResponse:
    return adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
        error_message=message)

  # Every SettingsRequest should have a namespace.
  if settings.name_space == settings_pb.Namespace.UNKNOWN:
    return precondition_failure(
        f'Unknown SettingsRequest.name_space. Got: {settings}.')
  namespace = settings_pb.Namespace.Name(settings.name_space).lower()

  verb = settings.WhichOneof('verb')
  if verb == 'get':
    if not settings.get.key:
      return precondition_failure(
          f'Empty SettingsRequest.get.key. Got: {settings}.')
    args = ['get', namespace, settings.get.key]
  elif verb == 'put':
    put = settings.put
    if not put.key or not put.value:
      return precondition_failure(
          f'Empty SettingsRequest.put key or value. Got: {settings}.')
    args = ['put', namespace, put.key, put.value]
  elif verb == 'delete_key':
    if not settings.delete_key.key:
      return precondition_failure(
          f'Empty SettingsRequest.delete_key.key. Got: {settings}.')
    args = ['delete', namespace, settings.delete_key.key]
  elif verb == 'reset':
    reset = settings.reset
    # At least one of `package_name` or `mode` should be given.
    if (not reset.package_name
        and reset.mode == settings_pb.Reset.Mode.UNKNOWN):
      return precondition_failure(
          'At least one of SettingsRequest.reset package_name or mode'
          f' should be given. Got: {settings}.')
    mode = settings_pb.Reset.Mode.Name(reset.mode).lower()
    # A package name takes precedence over the reset mode.
    args = ['reset', namespace, reset.package_name or mode]
  elif verb == 'list':
    args = ['list', namespace]
  else:
    return precondition_failure(
        f'Unknown SettingsRequest.verb. Got: {settings}.')

  response, command_output = self._execute_command(
      ['shell', 'settings'] + args, timeout=timeout)
  response.settings.output = command_output
  return response
def _handle_generic(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Runs the raw adb arguments given in `request.generic.args`.

  Args:
    request: The request with the `.generic` field set; its `args` are passed
      to adb unchanged.
    timeout: Optional time limit in seconds.

  Returns:
    The command's AdbResponse with `generic.output` set to its output.
  """
  adb_args = list(request.generic.args)
  response, response.generic.output = self._execute_command(adb_args, timeout)
  return response
def _handle_package_manager(
self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
"""Handles PackageManagerRequest messages.
Args:
request: The request with the `.package_manager` field set containing the
sub-commands to issue to `adb pm`.
timeout: Optional time limit in seconds.
Returns:
An AdbResponse.
"""
request = request.package_manager
response = adb_pb2.AdbResponse()
match request.WhichOneof('verb'):
case 'list':
what = request.list.WhichOneof('what')
response, output = self._execute_command(
['shell', 'pm', 'list', what], timeout=timeout
)
if output:
items = output.decode('utf-8').split()
# Remove prefix for each item.
prefix = {
'features': 'feature:',
'libraries': 'library:',
'packages': 'package:',
}[what]
items = [x[len(prefix) :] for x in items if x.startswith(prefix)]
response.package_manager.list.items.extend(items)
response.package_manager.output = output
case 'clear':
package_name = request.clear.package_name
if not package_name:
response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
response.error_message = (
f'Empty PackageManagerRequest.clear.package_name. Got: {request}.'
)
return response
args = ['shell', 'pm', 'clear', package_name]
if request.clear.user_id:
args.insert(3, '-f')
args.insert(4, request.clear.user_id)
response, response.package_manager.output = self._execute_command(
args, timeout=timeout
)
case 'grant':
grant = request.grant
if not grant.package_name:
response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
response.error_message = '`grant.package_name` cannot be empty.'
return response
if not grant.permissions:
response.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
response.error_message = '`grant.permissions` cannot be empty.'
return response
for permission in grant.permissions:
logging.info('Granting permission: %r', permission)
response, response.package_manager.output = self._execute_command(
['shell', 'pm', 'grant', grant.package_name, permission],
timeout=timeout,
)
return response
def _handle_dumpsys(
    self, request: adb_pb2.AdbRequest, timeout: float | None = None
) -> adb_pb2.AdbResponse:
  """Builds and runs an `adb shell dumpsys` command from `request.dumpsys`.

  Args:
    request: The request with the `.dumpsys` field set containing
      sub-commands to `adb dumpsys` shell command.
    timeout: Optional time limit in seconds.

  Returns:
    FAILED_PRECONDITION for negative timeouts or incompatible options,
    otherwise the command response with `dumpsys.output` set.
  """
  dumpsys = request.dumpsys

  def precondition_failure(message: str) -> adb_pb2.AdbResponse:
    failure = adb_pb2.AdbResponse()
    failure.status = adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
    failure.error_message = message
    return failure

  # Validate everything before assembling the command line.
  if dumpsys.timeout_sec < 0 or dumpsys.timeout_ms < 0:
    return precondition_failure(
        'DumpsysRequest.timeout_{sec, ms} should be non-negative. '
        f'Got: {dumpsys}.')
  # `-l` cannot be combined with service, args or skip_services.
  if dumpsys.list_only and (
      dumpsys.service or dumpsys.args or dumpsys.skip_services):
    return precondition_failure(
        'DumpsysRequest.list_only cannot be combined with other options. '
        f'Got: {dumpsys}.')
  if dumpsys.skip_services and dumpsys.service:
    return precondition_failure(
        'DumpsysRequest.skip_services cannot be combined with `service`. '
        f'Got: {dumpsys}.')

  priority_level = adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel
  cmd = ['shell', 'dumpsys']
  if dumpsys.list_only:
    cmd.append('-l')
  # Seconds take precedence over milliseconds when both are set.
  if dumpsys.timeout_sec > 0:
    cmd += ['-t', str(dumpsys.timeout_sec)]
  elif dumpsys.timeout_ms > 0:
    cmd += ['-T', str(dumpsys.timeout_ms)]
  if dumpsys.priority != priority_level.UNSET:
    cmd += ['--priority', priority_level.Name(dumpsys.priority)]
  if dumpsys.skip_services:
    cmd += ['--skip', ','.join(dumpsys.skip_services)]
  if dumpsys.service:
    cmd.append(dumpsys.service)
  if dumpsys.args:
    cmd.extend(dumpsys.args)
  if dumpsys.proto:
    cmd.append('--proto')

  response, response.dumpsys.output = self._execute_command(
      cmd, timeout=timeout)
  return response
================================================
FILE: android_env/components/adb_call_parser_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import builtins
import os
import subprocess
import sys
import tempfile
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import adb_call_parser
from android_env.components import adb_controller
from android_env.proto import adb_pb2
class AdbCallParserTest(parameterized.TestCase):
def test_unknown_command(self):
  """An AdbRequest with no command set yields UNKNOWN_COMMAND."""
  parser = adb_call_parser.AdbCallParser(
      mock.create_autospec(adb_controller.AdbController))

  response = parser.parse(adb_pb2.AdbRequest())

  self.assertEqual(
      response.status, adb_pb2.AdbResponse.Status.UNKNOWN_COMMAND)
def test_invalid_timeout(self):
  """A negative AdbRequest.timeout_sec is rejected as FAILED_PRECONDITION."""
  parser = adb_call_parser.AdbCallParser(
      mock.create_autospec(adb_controller.AdbController))
  request = adb_pb2.AdbRequest(timeout_sec=-5)
  request.tap.x = 123

  response = parser.parse(request)

  self.assertEqual(
      response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
@mock.patch.object(os.path, 'exists', autospec=True)
def test_install_apk_file_not_found(self, mock_exists):
  """Installing from a missing local path fails without invoking adb."""
  mock_exists.return_value = False
  adb = mock.create_autospec(adb_controller.AdbController)
  request = adb_pb2.AdbRequest()
  request.install_apk.filesystem.path = '/my/home/game.apk'

  response = adb_call_parser.AdbCallParser(adb).parse(request)

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
  self.assertNotEmpty(response.error_message)
  adb.execute_command.assert_not_called()
@mock.patch.object(os.path, 'exists', autospec=True)
def test_install_apk_successful(self, mock_exists):
  """An existing local APK is installed with `adb install -r -t -g`."""
  mock_exists.return_value = True
  adb = mock.create_autospec(adb_controller.AdbController)
  apk_path = '/my/home/game.apk'
  request = adb_pb2.AdbRequest()
  request.install_apk.filesystem.path = apk_path

  response = adb_call_parser.AdbCallParser(adb).parse(request)

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  adb.execute_command.assert_called_once_with(
      ['install', '-r', '-t', '-g', apk_path], None)
@mock.patch.object(tempfile, 'NamedTemporaryFile', autospec=True)
def test_install_apk_from_blob(self, mock_tempfile):
  """Should succeed installing APK from blob."""
  adb = mock.create_autospec(adb_controller.AdbController)
  parser = adb_call_parser.AdbCallParser(adb)
  request = adb_pb2.AdbRequest()
  blob_content = b'A fake blob content'
  request.install_apk.blob.contents = blob_content
  mock_file = mock_tempfile.return_value.__enter__.return_value
  mock_file.name = '/my/home/test.apk'
  mock_file.write.return_value = None

  response = parser.parse(request)

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  adb.execute_command.assert_called_once_with(
      ['install', '-r', '-t', '-g', '/my/home/test.apk'], None
  )
  # Same version switch as the parser: `delete_on_close` exists since 3.12.
  expected_tempfile_kwargs = (
      {'suffix': '.apk', 'delete_on_close': False}
      if sys.version_info >= (3, 12)
      else {'suffix': '.apk'}
  )
  mock_tempfile.assert_called_once_with(**expected_tempfile_kwargs)
  mock_file.write.assert_called_once_with(blob_content)
  # The blob must be written before the temporary file's context exits. Only
  # these calls and their relative order are checked, so unrelated extra
  # calls on the file handle do not make this test fail.
  # pytype: disable=attribute-error
  tempfile_calls = mock_tempfile.mock_calls
  write_call = mock.call().__enter__().write(blob_content)
  exit_call = mock.call().__exit__(None, None, None)
  # pytype: enable=attribute-error
  self.assertIn(write_call, tempfile_calls)
  self.assertIn(exit_call, tempfile_calls)
  self.assertLess(
      tempfile_calls.index(write_call), tempfile_calls.index(exit_call)
  )
def test_start_activity_empty_full_activity(self):
  """start_activity without `full_activity` is a precondition failure."""
  parser = adb_call_parser.AdbCallParser(
      mock.create_autospec(adb_controller.AdbController))
  request = adb_pb2.AdbRequest()
  request.start_activity.extra_args.append('blah')

  response = parser.parse(request)

  self.assertEqual(
      response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
  self.assertNotEmpty(response.error_message)
def test_start_activity_successful(self):
  """With `force_stop=True` the `am start` command includes `-S`."""
  adb = mock.create_autospec(adb_controller.AdbController)
  adb.execute_command.return_value = (
      b'Stopping: my.project.SplashActivity\n'
      b'Starting: Intent { cmp=my.project.SplashActivity }\n')
  request = adb_pb2.AdbRequest()
  start_activity = request.start_activity
  start_activity.full_activity = 'my.project.SplashActivity'
  start_activity.extra_args.extend(['blah'])
  start_activity.force_stop = True

  response = adb_call_parser.AdbCallParser(adb).parse(request)

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  expected_command = [
      'shell', 'am', 'start', '-S', '-W', '-n',
      'my.project.SplashActivity', 'blah'
  ]
  adb.execute_command.assert_has_calls(
      [mock.call(expected_command, timeout=None)])
def test_start_activity_successful_no_force_stop(self):
  """With `force_stop=False` an empty placeholder replaces `-S`."""
  adb = mock.create_autospec(adb_controller.AdbController)
  adb.execute_command.return_value = (
      b'Stopping: my.project.SplashActivity\n'
      b'Starting: Intent { cmp=my.project.SplashActivity }\n')
  request = adb_pb2.AdbRequest()
  start_activity = request.start_activity
  start_activity.full_activity = 'my.project.SplashActivity'
  start_activity.extra_args.extend(['blah'])
  start_activity.force_stop = False

  response = adb_call_parser.AdbCallParser(adb).parse(request)

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  expected_command = [
      'shell', 'am', 'start', '', '-W', '-n', 'my.project.SplashActivity',
      'blah'
  ]
  adb.execute_command.assert_has_calls(
      [mock.call(expected_command, timeout=None)])
def test_start_activity_error(self):
  """An `Error` line in the `am start` output yields INTERNAL_ERROR."""
  command_output = (b'Stopping: my.project.SplashActivity\n'
                    b'Starting: Intent { cmp=my.project.SplashActivity }\n'
                    b'Error: Activity not started, unknown error code 101\n')
  adb = mock.create_autospec(adb_controller.AdbController)
  adb.execute_command.return_value = command_output
  request = adb_pb2.AdbRequest()
  request.start_activity.full_activity = 'my.project.SplashActivity'
  request.start_activity.extra_args.append('blah')

  response = adb_call_parser.AdbCallParser(adb).parse(request)

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.INTERNAL_ERROR)
  self.assertEqual(
      response.error_message,
      'start_activity failed with error: ' + str(command_output))
def test_force_stop(self):
  """force_stop issues `am force-stop` for the requested package."""
  adb = mock.create_autospec(adb_controller.AdbController)
  request = adb_pb2.AdbRequest()
  request.force_stop.package_name = 'my.project'

  response = adb_call_parser.AdbCallParser(adb).parse(request)

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  adb.execute_command.assert_called_once_with(
      ['shell', 'am', 'force-stop', 'my.project'], None)
def test_grant_permissions_empty_package_name(self):
  """Granting permissions without a package name is rejected."""
  parser = adb_call_parser.AdbCallParser(
      mock.create_autospec(adb_controller.AdbController))
  request = adb_pb2.AdbRequest()
  request.package_manager.grant.permissions.extend(['perm1', 'perm2'])

  response = parser.parse(request)

  self.assertEqual(
      response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
  self.assertNotEmpty(response.error_message)
def test_grant_permissions_empty_permissions(self):
  """Granting with an empty permission list is rejected."""
  parser = adb_call_parser.AdbCallParser(
      mock.create_autospec(adb_controller.AdbController))
  request = adb_pb2.AdbRequest()
  request.package_manager.grant.package_name = 'my.project'

  response = parser.parse(request)

  self.assertEqual(
      response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
  self.assertNotEmpty(response.error_message)
def test_grant_permissions_successful(self):
  """Each permission is granted with its own `pm grant` call, in order."""
  adb = mock.create_autospec(adb_controller.AdbController)
  adb.execute_command.return_value = b'whatever'
  request = adb_pb2.AdbRequest()
  grant = request.package_manager.grant
  grant.package_name = 'my.project'
  grant.permissions.extend(['perm1', 'perm2'])

  response = adb_call_parser.AdbCallParser(adb).parse(request)

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  adb.execute_command.assert_has_calls([
      mock.call(['shell', 'pm', 'grant', 'my.project', permission], None)
      for permission in ('perm1', 'perm2')
  ])
def test_press_button_invalid_button(self):
  """A button value without a keycode mapping is rejected."""
  parser = adb_call_parser.AdbCallParser(
      mock.create_autospec(adb_controller.AdbController))
  request = adb_pb2.AdbRequest()
  request.press_button.button = 99999

  response = parser.parse(request)

  self.assertEqual(
      response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
  self.assertNotEmpty(response.error_message)
def test_press_button_successful(self):
  """HOME, BACK and ENTER are sent as the matching `input keyevent` codes."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b''
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  buttons = adb_pb2.AdbRequest.PressButton.Button
  button_to_keycode = (
      (buttons.HOME, 'KEYCODE_HOME'),
      (buttons.BACK, 'KEYCODE_BACK'),
      (buttons.ENTER, 'KEYCODE_ENTER'),
  )
  for button, keycode in button_to_keycode:
    adb_request = adb_pb2.AdbRequest()
    adb_request.press_button.button = button
    adb_response = call_parser.parse(adb_request)
    self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
    self.assertEmpty(adb_response.error_message)
    # `assert_called_with` inspects only the latest call, i.e. this button's.
    mock_adb.execute_command.assert_called_with(
        ['shell', 'input', 'keyevent', keycode], None)
def test_start_screen_pinning_package_not_found(self):
  """Pinning fails when `am stack list` does not mention the activity."""
  stack_listing = (
      b' taskId=12345: my.project.AnotherActivity visible=true'
      b' topActivity=ComponentInfo{my.project.AnotherActivity}')
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = stack_listing
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_request = adb_pb2.AdbRequest()
  adb_request.start_screen_pinning.full_activity = 'my.project.AmazingActivity'
  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.INTERNAL_ERROR,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  # Only the lookup runs; no `am task lock` is attempted.
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'am', 'stack', 'list'], None)
def test_start_screen_pinning_successful(self):
  """The task id of the matching activity is passed to `am task lock`."""
  stack_listing = (
      b' taskId=12345: my.project.AmazingActivity visible=true'
      b' topActivity=ComponentInfo{my.project.AmazingActivity}')
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = stack_listing
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_request = adb_pb2.AdbRequest()
  adb_request.start_screen_pinning.full_activity = 'my.project.AmazingActivity'
  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  lookup_call = mock.call(['shell', 'am', 'stack', 'list'], None)
  lock_call = mock.call(['shell', 'am', 'task', 'lock', '12345'], None)
  mock_adb.execute_command.assert_has_calls([lookup_call, lock_call])
def test_start_screen_pinning_base_activity(self):
  """A task can be pinned by naming its base (non-top) activity."""
  stack_listing = (
      b' taskId=12345: my.project.MainActivity visible=true'
      b' topActivity=ComponentInfo{my.project.TopActivity}')
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = stack_listing
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_request = adb_pb2.AdbRequest()
  adb_request.start_screen_pinning.full_activity = 'my.project.MainActivity'
  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  lookup_call = mock.call(['shell', 'am', 'stack', 'list'], None)
  lock_call = mock.call(['shell', 'am', 'task', 'lock', '12345'], None)
  mock_adb.execute_command.assert_has_calls([lookup_call, lock_call])
def test_start_screen_pinning_top_activity(self):
  """A task can also be pinned by naming its `topActivity`."""
  stack_listing = (
      b' taskId=12345: my.project.MainActivity visible=true'
      b' topActivity=ComponentInfo{my.project.TopActivity}')
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = stack_listing
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_request = adb_pb2.AdbRequest()
  adb_request.start_screen_pinning.full_activity = 'my.project.TopActivity'
  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  lookup_call = mock.call(['shell', 'am', 'stack', 'list'], None)
  lock_call = mock.call(['shell', 'am', 'task', 'lock', '12345'], None)
  mock_adb.execute_command.assert_has_calls([lookup_call, lock_call])
def test_send_broadcast_empty_action(self):
  """A broadcast request without an action is rejected."""
  adb = mock.create_autospec(adb_controller.AdbController)
  parser = adb_call_parser.AdbCallParser(adb)
  request = adb_pb2.AdbRequest(
      send_broadcast=adb_pb2.AdbRequest.SendBroadcast())
  response = parser.parse(request)
  self.assertEqual(response.status,
                   adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
  self.assertNotEmpty(response.error_message)
  # Validation failures must not issue an `am broadcast`.
  adb.execute_command.assert_not_called()
def test_send_broadcast_successful(self):
  """A broadcast that only specifies an action is accepted."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest()
  adb_request.send_broadcast.action = 'SOME-ACTION'

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
def test_send_broadcast_with_component_successful(self):
  """A broadcast carrying both an action and a component is accepted."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest()
  broadcast = adb_request.send_broadcast
  broadcast.action = 'SOME-ACTION'
  broadcast.component = 'SOME-COMPONENT'

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
def test_uninstall_package_empty_package_name(self):
  """Uninstalling with an empty package name is rejected."""
  adb = mock.create_autospec(adb_controller.AdbController)
  parser = adb_call_parser.AdbCallParser(adb)
  request = adb_pb2.AdbRequest()
  request.uninstall_package.package_name = ''
  response = parser.parse(request)
  self.assertEqual(response.status,
                   adb_pb2.AdbResponse.Status.FAILED_PRECONDITION)
  self.assertNotEmpty(response.error_message)
  # Neither the package listing nor an uninstall should be attempted.
  adb.execute_command.assert_not_called()
def test_uninstall_package_successful(self):
  """Uninstalling a package reported by adb as installed succeeds."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  # Canned adb output that lists the package being uninstalled.
  mock_adb.execute_command.return_value = b'package:my.package'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest()
  adb_request.uninstall_package.package_name = 'my.package'

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
def test_get_current_activity_no_visible_task(self):
  """With no adb output, both stack queries run and an error is returned."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = None
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.INTERNAL_ERROR,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  grep_visible = [
      'shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'
  ]
  full_listing = ['shell', 'am', 'stack', 'list']
  mock_adb.execute_command.assert_has_calls(
      [mock.call(grep_visible, None), mock.call(full_listing, None)])
def test_get_orientation_empty_dumpsys(self):
  """Empty `dumpsys input` output yields an INTERNAL_ERROR."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b''
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      get_orientation=adb_pb2.AdbRequest.GetOrientationRequest())

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.INTERNAL_ERROR,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'input'], None)
def test_get_orientation_invalid_device_no_surface_orientation(self):
  """Output without any orientation line yields an INTERNAL_ERROR."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b' PhysicalWidth: -123px'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      get_orientation=adb_pb2.AdbRequest.GetOrientationRequest())

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.INTERNAL_ERROR,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'input'], None)
@parameterized.named_parameters(
    ('rotation_0', b""" SurfaceOrientation: 0""", 0),
    ('rotation_90', b""" SurfaceOrientation: 1""", 1),
    ('rotation_180', b""" SurfaceOrientation: 2""", 2),
    ('rotation_270', b""" SurfaceOrientation: 3""", 3),
    ('rotation_0_new', b""" InputDeviceOrientation: 0""", 0),
    ('rotation_90_new', b""" InputDeviceOrientation: 1""", 1),
    ('rotation_180_new', b""" InputDeviceOrientation: 2""", 2),
    ('rotation_270_new', b""" InputDeviceOrientation: 3""", 3),
)
def test_get_orientation_success(
    self, orientation: bytes, expected_orientation: int
):
  """Both the legacy and the newer orientation keys are parsed from dumpsys."""
  # The orientation line is surrounded by unrelated keys to check that the
  # parser picks it out of a larger dump.
  dumpsys_output = (
      b"""SomeRandomKey: 12345\n""" + orientation + b"""
MoreRandomStuff: awesome_value
"""
  )
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = dumpsys_output
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      get_orientation=adb_pb2.AdbRequest.GetOrientationRequest())

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  self.assertEqual(expected_orientation,
                   adb_response.get_orientation.orientation)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'input'], None)
def test_get_current_activity_no_matches(self):
  """Output without a `{...}` activity produces an INTERNAL_ERROR."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())
  grep_visible = [
      'shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'
  ]
  full_listing = ['shell', 'am', 'stack', 'list']
  for platform in ['win32', 'linux']:
    # NOTE(review): with `autospec=True` this patch replaces `sys.platform`
    # with a mock object (`return_value` only matters if it is called), so
    # `sys.platform` never equals `platform` here. Presumably
    # `mock.patch.object(sys, 'platform', platform)` was intended; verify the
    # parser's win32 branch before changing it.
    with mock.patch.object(
        sys, 'platform', autospec=True, return_value=platform):
      adb_response = call_parser.parse(adb_request)
      self.assertEqual(adb_pb2.AdbResponse.Status.INTERNAL_ERROR,
                       adb_response.status)
      self.assertNotEmpty(adb_response.error_message)
      mock_adb.execute_command.assert_has_calls(
          [mock.call(grep_visible, None), mock.call(full_listing, None)])
def test_get_current_activity_successful(self):
  """The activity name is extracted from between curly braces."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'{MyAwesomeActivity}'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())
  grep_visible = [
      'shell', 'am', 'stack', 'list', '|', 'grep', '-E', 'visible=true'
  ]
  for platform in ['win32', 'linux']:
    # NOTE(review): `autospec=True` swaps `sys.platform` for a mock rather
    # than the string `platform`, so it never equals 'win32' inside this
    # block. Confirm the intended semantics before changing the patch.
    with mock.patch.object(
        sys, 'platform', autospec=True, return_value=platform):
      adb_response = call_parser.parse(adb_request)
      self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
      self.assertEmpty(adb_response.error_message)
      # `execute_command` is called once per loop iteration; check the last.
      mock_adb.execute_command.assert_called_with(grep_visible, None)
  self.assertEqual('MyAwesomeActivity',
                   adb_response.get_current_activity.full_activity)
def test_push_no_path(self):
  """A push without a destination path is rejected before calling adb."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  push_request = adb_pb2.AdbRequest.Push(content=b'Has content but no path')

  adb_response = call_parser.parse(adb_pb2.AdbRequest(push=push_request))

  self.assertEqual(adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_not_called()
def test_push_successful(self):
  """A push issues one three-argument `push` command with `timeout=None`."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  push_request = adb_pb2.AdbRequest.Push(
      content=b'My text.', path='/sdcard/my_file.txt')

  adb_response = call_parser.parse(adb_pb2.AdbRequest(push=push_request))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once()
  positional, keyword = mock_adb.execute_command.call_args
  self.assertLen(positional, 1)
  command = positional[0]
  self.assertLen(command, 3)
  # command[1] is not checked here (presumably a host-side source path).
  self.assertEqual('push', command[0])
  self.assertEqual('/sdcard/my_file.txt', command[2])
  self.assertIn('timeout', keyword)
  self.assertIsNone(keyword['timeout'])
def test_pull_no_path(self):
  """A pull without a source path is rejected before calling adb."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_response = call_parser.parse(
      adb_pb2.AdbRequest(pull=adb_pb2.AdbRequest.Pull()))

  self.assertEqual(adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_not_called()
@mock.patch.object(builtins, 'open', autospec=True)
def test_pull_successful(self, mock_open):
  """A pull runs `adb pull` and returns the file content read back."""
  secret_content = b'S3cR3t. dO nOt TeLl ANYONE'
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  # Route the context manager's `__enter__` back through `mock_open`, so a
  # `with open(...) as f` target is `mock_open.return_value`, whose `read`
  # returns the canned content.
  mock_open.return_value.__enter__ = mock_open
  mock_open.return_value.read.return_value = secret_content
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  pull_request = adb_pb2.AdbRequest.Pull(path='/sdcard/my_file.txt')

  adb_response = call_parser.parse(adb_pb2.AdbRequest(pull=pull_request))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  self.assertEqual(secret_content, adb_response.pull.content)
  mock_adb.execute_command.assert_called_once()
  positional, keyword = mock_adb.execute_command.call_args
  self.assertLen(positional, 1)
  command = positional[0]
  self.assertLen(command, 3)
  self.assertEqual('pull', command[0])
  self.assertEqual('/sdcard/my_file.txt', command[1])
  self.assertIn('timeout', keyword)
  self.assertIsNone(keyword['timeout'])
def test_input_text_no_text(self):
  """An `InputText` request with no text is rejected before calling adb."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_response = call_parser.parse(
      adb_pb2.AdbRequest(input_text=adb_pb2.AdbRequest.InputText()))

  self.assertEqual(adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_not_called()
def test_input_text_successful(self):
  """The text is forwarded verbatim to `input text`."""
  typed_text = 'The Greatest Text of All Time'
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  text_request = adb_pb2.AdbRequest.InputText(text=typed_text)

  adb_response = call_parser.parse(adb_pb2.AdbRequest(input_text=text_request))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'input', 'text', typed_text], None)
@parameterized.named_parameters(
    ('negative_x_and_negative_y',
     adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=-1, y=-1))),
    ('negative_x',
     adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=-1, y=123))),
    ('negative_y',
     adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=456, y=-1))),
)
def test_tap_failed(self, request: adb_pb2.AdbRequest):
  """Taps with any negative coordinate are rejected before calling adb."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_response = call_parser.parse(request)

  self.assertEqual(adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_not_called()
def test_tap_successful(self):
  """Tap coordinates are stringified into an `input tap x y` command."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  tap = adb_pb2.AdbRequest.Tap(x=135, y=246)

  adb_response = call_parser.parse(adb_pb2.AdbRequest(tap=tap))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'input', 'tap', '135', '246'], None)
@parameterized.named_parameters(
    ('empty_request', adb_pb2.AdbRequest.SettingsRequest()),
    ('no_namespace',
     adb_pb2.AdbRequest.SettingsRequest(
         get=adb_pb2.AdbRequest.SettingsRequest.Get(key='my_key'))),
    ('get_no_key',
     adb_pb2.AdbRequest.SettingsRequest(
         name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
         get=adb_pb2.AdbRequest.SettingsRequest.Get())),
    ('put_no_key',
     adb_pb2.AdbRequest.SettingsRequest(
         name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
         put=adb_pb2.AdbRequest.SettingsRequest.Put())),
    ('put_no_value',
     adb_pb2.AdbRequest.SettingsRequest(
         name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
         put=adb_pb2.AdbRequest.SettingsRequest.Put(key='another_key'))),
    ('delete_no_key',
     adb_pb2.AdbRequest.SettingsRequest(
         name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
         delete_key=adb_pb2.AdbRequest.SettingsRequest.Delete())),
    ('reset_no_package_name_and_no_mode',
     adb_pb2.AdbRequest.SettingsRequest(
         name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
         reset=adb_pb2.AdbRequest.SettingsRequest.Reset())),
)
def test_settings_failures(self, request):
  """Incomplete `SettingsRequest`s are rejected before calling adb."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  # Wrap the parameterized settings payload rather than shadowing `request`.
  adb_request = adb_pb2.AdbRequest(settings=request)

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_not_called()
def test_settings_success_get(self):
  """A SYSTEM `get` maps to `settings get system <key>`; output is echoed."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'here it is!'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  settings = adb_pb2.AdbRequest.SettingsRequest(
      name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
      get=adb_pb2.AdbRequest.SettingsRequest.Get(key='some_key'))

  adb_response = call_parser.parse(adb_pb2.AdbRequest(settings=settings))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  self.assertEqual(b'here it is!', adb_response.settings.output)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'settings', 'get', 'system', 'some_key'], None)
def test_settings_success_put(self):
  """A SECURE `put` maps to `settings put secure <key> <value>`."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'Done for ya!'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  settings = adb_pb2.AdbRequest.SettingsRequest(
      name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE,
      put=adb_pb2.AdbRequest.SettingsRequest.Put(key='key1', value='val2'))

  adb_response = call_parser.parse(adb_pb2.AdbRequest(settings=settings))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  self.assertEqual(b'Done for ya!', adb_response.settings.output)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'settings', 'put', 'secure', 'key1', 'val2'], None)
def test_settings_success_delete(self):
  """A GLOBAL `delete_key` maps to `settings delete global <key>`."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'Key deleted.'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  settings = adb_pb2.AdbRequest.SettingsRequest(
      name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
      delete_key=adb_pb2.AdbRequest.SettingsRequest.Delete(key='useless_key'))

  adb_response = call_parser.parse(adb_pb2.AdbRequest(settings=settings))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  self.assertEqual(b'Key deleted.', adb_response.settings.output)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'settings', 'delete', 'global', 'useless_key'], None)
@parameterized.named_parameters(
    ('mode_untrusted_defaults',
     adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_DEFAULTS, '',
     'untrusted_defaults'),
    ('mode_untrusted_clear',
     adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_CLEAR, '',
     'untrusted_clear'),
    ('mode_trusted_defaults',
     adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.TRUSTED_DEFAULTS, '',
     'trusted_defaults'),
    # If `package_name` is given, it takes precedence over `mode`.
    ('mode_unknown_package_given',
     adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNKNOWN, 'great.package',
     'great.package'),
    ('mode_untrusted_defaults_package_given',
     adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_DEFAULTS,
     'great.package', 'great.package'),
    ('mode_untrusted_clear_package_given',
     adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.UNTRUSTED_CLEAR,
     'great.package', 'great.package'),
    ('mode_trusted_defaults_package_given',
     adb_pb2.AdbRequest.SettingsRequest.Reset.Mode.TRUSTED_DEFAULTS,
     'great.package', 'great.package'),
)
def test_settings_success_reset(self, mode, package_name, expected_arg):
  """`reset` passes either the package name or the lowercase mode name."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'Pkg reset.'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  reset = adb_pb2.AdbRequest.SettingsRequest.Reset(
      package_name=package_name, mode=mode)
  settings = adb_pb2.AdbRequest.SettingsRequest(
      name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
      reset=reset)

  adb_response = call_parser.parse(adb_pb2.AdbRequest(settings=settings))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  self.assertEqual(b'Pkg reset.', adb_response.settings.output)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'settings', 'reset', 'global', expected_arg], None)
def test_settings_success_list(self):
  """A SYSTEM `list` maps to `settings list system`; output is echoed."""
  listing = b'volume_ring=5\nvolume_system=7'
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = listing
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  settings = adb_pb2.AdbRequest.SettingsRequest(
      name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
      list=adb_pb2.AdbRequest.SettingsRequest.List())

  adb_response = call_parser.parse(adb_pb2.AdbRequest(settings=settings))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  self.assertEqual(listing, adb_response.settings.output)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'settings', 'list', 'system'], None)
def test_generic_command(self):
  """Generic args are passed through unchanged and the output is returned."""
  command_args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action']
  canned_output = b'generic_output'
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = canned_output
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      generic=adb_pb2.AdbRequest.GenericRequest(args=command_args))

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  self.assertEqual(canned_output, adb_response.generic.output)
  mock_adb.execute_command.assert_called_once_with(command_args, None)
def test_generic_command_adb_error(self):
  """A `CalledProcessError` becomes ADB_ERROR carrying the process output."""
  command_args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action']
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.side_effect = subprocess.CalledProcessError(
      cmd='cmd', output='adb_error', returncode=-1)
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      generic=adb_pb2.AdbRequest.GenericRequest(args=command_args))

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.ADB_ERROR, adb_response.status)
  self.assertEqual('adb_error', adb_response.error_message)
  self.assertEmpty(adb_response.generic.output)
  mock_adb.execute_command.assert_called_once_with(command_args, None)
def test_generic_command_timeout(self):
  """A `TimeoutExpired` becomes TIMEOUT with the message 'Timeout'."""
  command_args = ['shell', 'am', 'broadcast', '-n', 'receiver', '-a', 'action']
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.side_effect = subprocess.TimeoutExpired(
      cmd='cmd', timeout=10)
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  adb_request = adb_pb2.AdbRequest(
      generic=adb_pb2.AdbRequest.GenericRequest(args=command_args))

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(adb_pb2.AdbResponse.Status.TIMEOUT, adb_response.status)
  self.assertEqual('Timeout', adb_response.error_message)
  self.assertEmpty(adb_response.generic.output)
  mock_adb.execute_command.assert_called_once_with(command_args, None)
@parameterized.named_parameters(
    ('features',
     adb_pb2.AdbRequest(
         package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
             list=adb_pb2.AdbRequest.PackageManagerRequest.List(
                 features=adb_pb2.AdbRequest.PackageManagerRequest.List
                 .Features())))),
    ('libraries',
     adb_pb2.AdbRequest(
         package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
             list=adb_pb2.AdbRequest.PackageManagerRequest.List(
                 libraries=adb_pb2.AdbRequest.PackageManagerRequest.List
                 .Libraries())))),
    ('packages',
     adb_pb2.AdbRequest(
         package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
             list=adb_pb2.AdbRequest.PackageManagerRequest.List(
                 packages=adb_pb2.AdbRequest.PackageManagerRequest.List
                 .Packages())))),
)
def test_package_manager_list_bad_output(self, request):
  """Unparseable `pm list` output yields no items but is still echoed back."""
  adb = mock.create_autospec(adb_controller.AdbController)
  adb.execute_command.return_value = b"""Something irrelevant."""
  parser = adb_call_parser.AdbCallParser(adb)
  response = parser.parse(request)
  # Previously this line *assigned* to `response.package_manager.output`,
  # which overwrote the value instead of checking it.
  self.assertEqual(response.package_manager.output,
                   b"""Something irrelevant.""")
  self.assertEmpty(response.package_manager.list.items)
  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  adb.execute_command.assert_called_once()
def test_package_manager_list_features(self):
  """`pm list features` lines are stripped of their `feature:` prefix."""
  output = b"""
feature:android.hardware.audio.output
feature:android.hardware.bluetooth
feature:android.hardware.camera
feature:android.hardware.fingerprint
feature:android.software.autofill
feature:android.software.backup
feature:android.software.webview
"""
  expected_items = [
      'android.hardware.audio.output',
      'android.hardware.bluetooth',
      'android.hardware.camera',
      'android.hardware.fingerprint',
      'android.software.autofill',
      'android.software.backup',
      'android.software.webview',
  ]
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = output
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  pm_list = adb_pb2.AdbRequest.PackageManagerRequest.List(
      features=adb_pb2.AdbRequest.PackageManagerRequest.List.Features())
  adb_request = adb_pb2.AdbRequest(
      package_manager=adb_pb2.AdbRequest.PackageManagerRequest(list=pm_list))

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(output, adb_response.package_manager.output)
  self.assertEqual(adb_response.package_manager.list.items, expected_items)
  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'pm', 'list', 'features'], None)
def test_package_manager_list_libraries(self):
  """`pm list libraries` lines are stripped of their `library:` prefix."""
  output = b"""
library:android.ext.shared
library:android.hidl.base-V1.0-java
library:android.hidl.manager-V1.0-java
library:android.net.ipsec.ike
library:android.test.base
library:android.test.mock
library:android.test.runner
library:androidx.window.sidecar
library:com.android.future.usb.accessory
library:com.android.location.provider
library:com.android.media.remotedisplay
library:com.android.mediadrm.signer
library:com.android.nfc_extras
library:com.google.android.gms
library:com.google.android.trichromelibrary
library:javax.obex
library:org.apache.http.legacy
"""
  expected_items = [
      'android.ext.shared',
      'android.hidl.base-V1.0-java',
      'android.hidl.manager-V1.0-java',
      'android.net.ipsec.ike',
      'android.test.base',
      'android.test.mock',
      'android.test.runner',
      'androidx.window.sidecar',
      'com.android.future.usb.accessory',
      'com.android.location.provider',
      'com.android.media.remotedisplay',
      'com.android.mediadrm.signer',
      'com.android.nfc_extras',
      'com.google.android.gms',
      'com.google.android.trichromelibrary',
      'javax.obex',
      'org.apache.http.legacy',
  ]
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = output
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  pm_list = adb_pb2.AdbRequest.PackageManagerRequest.List(
      libraries=adb_pb2.AdbRequest.PackageManagerRequest.List.Libraries())
  adb_request = adb_pb2.AdbRequest(
      package_manager=adb_pb2.AdbRequest.PackageManagerRequest(list=pm_list))

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(output, adb_response.package_manager.output)
  self.assertEqual(adb_response.package_manager.list.items, expected_items)
  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'pm', 'list', 'libraries'], None)
def test_package_manager_list_packages(self):
  """`pm list packages` lines are stripped of their `package:` prefix."""
  output = b"""
package:com.android.phone
package:com.awesome.company
package:com.another.great.thingie
"""
  expected_items = [
      'com.android.phone',
      'com.awesome.company',
      'com.another.great.thingie',
  ]
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = output
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  pm_list = adb_pb2.AdbRequest.PackageManagerRequest.List(
      packages=adb_pb2.AdbRequest.PackageManagerRequest.List.Packages())
  adb_request = adb_pb2.AdbRequest(
      package_manager=adb_pb2.AdbRequest.PackageManagerRequest(list=pm_list))

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(output, adb_response.package_manager.output)
  self.assertEqual(adb_response.package_manager.list.items, expected_items)
  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'pm', 'list', 'packages'], None)
def test_package_manager_clear_no_package_name(self):
  """`pm clear` with an empty package name is rejected before calling adb."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b"""Something irrelevant."""
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  clear = adb_pb2.AdbRequest.PackageManagerRequest.Clear(package_name='')
  adb_request = adb_pb2.AdbRequest(
      package_manager=adb_pb2.AdbRequest.PackageManagerRequest(clear=clear))

  adb_response = call_parser.parse(adb_request)

  self.assertEmpty(adb_response.package_manager.output)
  self.assertEqual(adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_not_called()
def test_package_manager_clear_successful_no_user_id(self):
  """Without a user id, `pm clear <package>` is issued and output echoed."""
  canned_output = b"""Some successful message."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = canned_output
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  clear = adb_pb2.AdbRequest.PackageManagerRequest.Clear(
      package_name='my.package')
  adb_request = adb_pb2.AdbRequest(
      package_manager=adb_pb2.AdbRequest.PackageManagerRequest(clear=clear))

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(canned_output, adb_response.package_manager.output)
  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'pm', 'clear', 'my.package'], None)
def test_package_manager_clear_successful_with_user_id(self):
  """With a user id, `pm clear -f <user_id> <package>` is issued."""
  canned_output = b"""Some successful message."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = canned_output
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  clear = adb_pb2.AdbRequest.PackageManagerRequest.Clear(
      package_name='my.package', user_id='mrawesome')
  adb_request = adb_pb2.AdbRequest(
      package_manager=adb_pb2.AdbRequest.PackageManagerRequest(clear=clear))

  adb_response = call_parser.parse(adb_request)

  self.assertEqual(canned_output, adb_response.package_manager.output)
  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'pm', 'clear', '-f', 'mrawesome', 'my.package'], None)
def test_dumpsys_empty_request(self):
  """An empty `DumpsysRequest` is valid and runs a bare `dumpsys`."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_response = call_parser.parse(
      adb_pb2.AdbRequest(dumpsys=adb_pb2.AdbRequest.DumpsysRequest()))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys'], timeout=None)
@parameterized.named_parameters(
    ('negative_timeout_sec',
     adb_pb2.AdbRequest(
         dumpsys=adb_pb2.AdbRequest.DumpsysRequest(timeout_sec=-1))),
    ('negative_timeout_ms',
     adb_pb2.AdbRequest(
         dumpsys=adb_pb2.AdbRequest.DumpsysRequest(timeout_ms=-2))),
)
def test_dumpsys_negative_timeouts(self, request):
  """Negative `timeout_sec`/`timeout_ms` values are rejected up front."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  adb_response = call_parser.parse(request)

  self.assertEqual(adb_pb2.AdbResponse.Status.FAILED_PRECONDITION,
                   adb_response.status)
  self.assertNotEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_not_called()
@parameterized.named_parameters(
    ('both_timeouts_zero', 0, 0, ['shell', 'dumpsys']),
    ('sec_takes_precedence_zero', 123, 0, ['shell', 'dumpsys', '-t', '123']),
    ('sec_takes_precedence', 123, 456, ['shell', 'dumpsys', '-t', '123']),
    ('ms_if_no_sec', 0, 456, ['shell', 'dumpsys', '-T', '456']),
)
def test_dumpsys_timeout_successful(self, timeout_sec, timeout_ms, expected):
  """`-t` (seconds) wins over `-T` (milliseconds); zero means unset."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  dumpsys = adb_pb2.AdbRequest.DumpsysRequest(
      timeout_sec=timeout_sec, timeout_ms=timeout_ms)

  adb_response = call_parser.parse(adb_pb2.AdbRequest(dumpsys=dumpsys))

  self.assertEqual(adb_pb2.AdbResponse.Status.OK, adb_response.status)
  self.assertEmpty(adb_response.error_message)
  mock_adb.execute_command.assert_called_once_with(expected, timeout=None)
@parameterized.named_parameters(
    (
        'priority_undefined',
        adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.UNSET,
        ['shell', 'dumpsys'],
    ),
    (
        'priority_normal',
        adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.NORMAL,
        ['shell', 'dumpsys', '--priority', 'NORMAL'],
    ),
    (
        'priority_high',
        adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.HIGH,
        ['shell', 'dumpsys', '--priority', 'HIGH'],
    ),
    (
        'priority_critical',
        adb_pb2.AdbRequest.DumpsysRequest.PriorityLevel.CRITICAL,
        ['shell', 'dumpsys', '--priority', 'CRITICAL'],
    ),
)
def test_dumpsys_priority_timeout_successful(self, priority, expected):
  """A set `priority` adds `--priority <LEVEL>`; UNSET adds nothing."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  dumpsys_request = adb_pb2.AdbRequest.DumpsysRequest(priority=priority)
  response = call_parser.parse(adb_pb2.AdbRequest(dumpsys=dumpsys_request))

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  mock_adb.execute_command.assert_called_once_with(expected, timeout=None)
@parameterized.named_parameters(
    (
        'window_service',
        adb_pb2.AdbRequest.DumpsysRequest(list_only=True, service='window'),
    ),
    (
        'arbitrary_args',
        adb_pb2.AdbRequest.DumpsysRequest(
            list_only=True, args=['myoption', 'anotheroption']
        ),
    ),
    (
        'skip_usb',
        adb_pb2.AdbRequest.DumpsysRequest(
            list_only=True, skip_services=['usb']
        ),
    ),
)
def test_dumpsys_list_only_cannot_be_combined(
    self, dumpsys_request: adb_pb2.AdbRequest.DumpsysRequest
):
  """`list_only=True` is rejected alongside `service`, `args` or `skip_services`."""
  # Arrange.
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)
  wrapped_request = adb_pb2.AdbRequest(dumpsys=dumpsys_request)

  # Act.
  response = call_parser.parse(wrapped_request)

  # Assert.
  self.assertEqual(
      response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
  )
  self.assertNotEmpty(response.error_message)
  mock_adb.execute_command.assert_not_called()
def test_dumpsys_list_only_success(self):
  """`list_only=True` on its own runs `dumpsys -l`."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  list_request = adb_pb2.AdbRequest.DumpsysRequest(list_only=True)
  response = call_parser.parse(adb_pb2.AdbRequest(dumpsys=list_request))

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', '-l'], timeout=None
  )
def test_dumpsys_skip_services_cannot_combine_with_service(self):
  """`DumpsysRequest.skip_services` cannot be combined with `.service`."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  conflicting = adb_pb2.AdbRequest.DumpsysRequest(
      service='wifi', skip_services=['window', 'usb']
  )
  response = call_parser.parse(adb_pb2.AdbRequest(dumpsys=conflicting))

  self.assertEqual(
      response.status, adb_pb2.AdbResponse.Status.FAILED_PRECONDITION
  )
  self.assertNotEmpty(response.error_message)
  mock_adb.execute_command.assert_not_called()
def test_dumpsys_skip_services(self):
  """`skip_services` are joined with commas and passed via `--skip`."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  skip_request = adb_pb2.AdbRequest.DumpsysRequest(
      skip_services=['window', 'usb']
  )
  response = call_parser.parse(adb_pb2.AdbRequest(dumpsys=skip_request))

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', '--skip', 'window,usb'], timeout=None
  )
def test_dumpsys_single_service(self):
  """A lone `service` is appended as the final dumpsys argument."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  service_request = adb_pb2.AdbRequest.DumpsysRequest(service='window')
  response = call_parser.parse(adb_pb2.AdbRequest(dumpsys=service_request))

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'window'], timeout=None
  )
def test_dumpsys_single_service_with_args(self):
  """`args` follow the service name in the order given."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'whatever'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  service_request = adb_pb2.AdbRequest.DumpsysRequest(
      service='window', args=['arg1', 'arg2']
  )
  response = call_parser.parse(adb_pb2.AdbRequest(dumpsys=service_request))

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'window', 'arg1', 'arg2'], timeout=None
  )
def test_dumpsys_single_service_with_proto(self):
  """`proto=True` appends `--proto` after the service name."""
  mock_adb = mock.create_autospec(adb_controller.AdbController)
  mock_adb.execute_command.return_value = b'some binary output'
  call_parser = adb_call_parser.AdbCallParser(mock_adb)

  proto_request = adb_pb2.AdbRequest.DumpsysRequest(
      service='window', proto=True
  )
  response = call_parser.parse(adb_pb2.AdbRequest(dumpsys=proto_request))

  self.assertEqual(response.status, adb_pb2.AdbResponse.Status.OK)
  self.assertEmpty(response.error_message)
  mock_adb.execute_command.assert_called_once_with(
      ['shell', 'dumpsys', 'window', '--proto'], timeout=None
  )
if __name__ == '__main__':
  # Hand control to absltest's runner when this file is executed directly.
  absltest.main()
================================================
FILE: android_env/components/adb_controller.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class to manage and control an external ADB process."""
import os
import subprocess
import time
from absl import logging
from android_env.components import config_classes
from android_env.components import errors
class AdbController:
  """Manages communication with adb."""

  def __init__(self, config: config_classes.AdbControllerConfig):
    """Instantiates an AdbController object.

    Unless `config.use_adb_server_port_from_os_env` is set, this removes
    `ANDROID_HOME` and `ANDROID_ADB_SERVER_PORT` from the process-wide
    `os.environ`, not only from the environment handed to adb subprocesses.

    Args:
      config: Settings for the adb binary path, server port, target device,
        default timeout and whether to take the server port from the OS env.
    """
    self._config = config
    logging.info('config: %r', self._config)
    if not self._config.use_adb_server_port_from_os_env:
      # Unset problematic environment variables. ADB commands will fail if these
      # are set. They are normally exported by AndroidStudio.
      for var_name in ('ANDROID_HOME', 'ANDROID_ADB_SERVER_PORT'):
        if var_name in os.environ:
          logging.info('Removing %s from os.environ', var_name)
          del os.environ[var_name]

    # Snapshot of the environment passed to every adb subprocess. `$HOME` is
    # expanded explicitly since it may itself reference other variables.
    self._os_env_vars = dict(os.environ)
    self._os_env_vars['HOME'] = os.path.expandvars(
        self._os_env_vars.get('HOME', '')
    )
    # Log only variable names: values may contain credentials or tokens.
    logging.info('adb subprocess env var names: %r', sorted(self._os_env_vars))

  def command_prefix(self, include_device_name: bool = True) -> list[str]:
    """The command for instantiating an adb client to this server.

    Args:
      include_device_name: Whether to append `-s <device_name>`.

    Returns:
      `[adb_path]`, followed by `-P <port>` unless the port comes from the OS
      environment, followed by `-s <device_name>` if requested.
    """
    if self._config.use_adb_server_port_from_os_env:
      # When using the adb server port set from the OS environment, we don't
      # need to pass the port explicitly.
      adb_port_args = []
    else:
      # When using the adb server port set from the config, we need to pass the
      # port explicitly.
      adb_port_args = ['-P', str(self._config.adb_server_port)]
    command_prefix = [
        self._config.adb_path,
        *adb_port_args,
    ]
    if include_device_name:
      command_prefix.extend(['-s', self._config.device_name])
    return command_prefix

  def init_server(self, timeout: float | None = None):
    """Initialize the ADB server daemon on the given port.

    This function should be called immediately after initializing the first
    adb_controller, and before launching the simulator.

    Args:
      timeout: A timeout to use for this operation. If not set, the config's
        `default_timeout` will be used.
    """
    # Make an initial device-independent call to ADB to start the daemon.
    self.execute_command(['devices'], timeout, device_specific=False)
    time.sleep(0.2)

  def _restart_server(self, timeout: float | None = None):
    """Kills and restarts the adb server.

    Args:
      timeout: A timeout to use for each adb call. If not set, the config's
        `default_timeout` will be used.
    """
    logging.info('Restarting adb server.')
    self.execute_command(
        ['kill-server'], timeout=timeout, device_specific=False
    )
    time.sleep(0.2)
    cmd_output = self.execute_command(
        ['start-server'], timeout=timeout, device_specific=False
    )
    # Output is only logged; undecodable bytes must not abort the restart.
    logging.info(
        'start-server output: %r', cmd_output.decode('utf-8', errors='replace')
    )
    time.sleep(2.0)
    self.execute_command(['devices'], timeout=timeout, device_specific=False)
    time.sleep(0.2)

  def execute_command(
      self,
      args: list[str],
      timeout: float | None = None,
      device_specific: bool = True,
  ) -> bytes:
    """Executes an adb command.

    The command is tried up to twice. If a device-specific command fails on
    the first try, the adb server is restarted before the second try.

    Args:
      args: A list of strings representing each adb argument. For example:
        ['install', '/my/app.apk']
      timeout: A timeout to use for this operation. If not set, the config's
        `default_timeout` will be used.
      device_specific: Whether the call is device-specific or independent.

    Returns:
      The output of running such command as a binary string (stderr is
      merged into stdout).

    Raises:
      errors.AdbControllerError: If every attempt fails or times out.
    """
    timeout = self._config.default_timeout if timeout is None else timeout
    command = self.command_prefix(include_device_name=device_specific) + args
    command_str = 'adb ' + ' '.join(command[1:])

    n_tries = 2
    latest_error = None
    for i in range(n_tries):
      try:
        logging.info('Executing ADB command: [%s]', command_str)
        cmd_output = subprocess.check_output(
            command,
            stderr=subprocess.STDOUT,
            timeout=timeout,
            env=self._os_env_vars,
        )
        logging.debug('ADB command output: %s', cmd_output)
        return cmd_output
      except (subprocess.CalledProcessError, subprocess.TimeoutExpired) as e:
        logging.exception(
            'Failed to execute ADB command (try %d of %d): [%s]',
            i + 1,
            n_tries,
            command_str,
        )
        if e.stdout is not None:
          logging.error('**stdout**:')
          for line in e.stdout.splitlines():
            logging.error('    %s', line)
        if e.stderr is not None:
          logging.error('**stderr**:')
          for line in e.stderr.splitlines():
            logging.error('    %s', line)
        latest_error = e
        # Restarting only helps device-specific calls, and only if another
        # attempt follows.
        if device_specific and i < n_tries - 1:
          self._restart_server(timeout=timeout)
    raise errors.AdbControllerError(
        f'Error executing adb command: [{command_str}]\n'
        f'Caused by: {latest_error}'
    ) from latest_error
================================================
FILE: android_env/components/adb_controller_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import time
from unittest import mock
from absl.testing import absltest
from android_env.components import adb_controller as adb_controller_lib
from android_env.components import config_classes
from android_env.components import errors
# Timeout (in seconds, forwarded to `subprocess.check_output`) used by default
# in the tests below. Set to a small value to avoid hanging on a failed test.
_TIMEOUT = 2
class AdbControllerTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    # Restore os.environ after each test. Both these tests and
    # `AdbController.__init__` modify the process-wide environment, and the
    # changes would otherwise leak into later tests.
    self.enter_context(mock.patch.dict(os.environ))
    # Set env vars.
    os.environ['MY_ENV_VAR'] = '/some/path/'
    os.environ['HOME'] = '$MY_ENV_VAR'
    # Snapshot that each test copies to build its expected subprocess env.
    self._env_before = os.environ.copy()

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_init_server(self, mock_sleep, mock_check_output):
    """We expect an `adb devices` call when initializing the server."""
    # Arrange.
    adb_controller = adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            use_adb_server_port_from_os_env=True,
        )
    )

    # Act.
    adb_controller.init_server(timeout=_TIMEOUT)

    # Assert.
    expected_env = dict(self._env_before)
    expected_env['HOME'] = '/some/path/'
    mock_check_output.assert_called_once_with(
        ['my_adb', 'devices'],
        stderr=subprocess.STDOUT,
        timeout=_TIMEOUT,
        env=expected_env,
    )
    mock_sleep.assert_called_once()

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_init_server_with_adb_server_port_from_os_env(
      self, mock_sleep, mock_check_output
  ):
    """Use OS env vars if `use_adb_server_port_from_os_env` is True."""
    # Arrange.
    # Set the ADB server port to 1234 in the OS environment.
    os.environ['ANDROID_ADB_SERVER_PORT'] = '1234'
    os.environ['ANDROID_HOME'] = '/some/path/to/android'
    adb_controller = adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            adb_server_port=9999,
            use_adb_server_port_from_os_env=True,
        )
    )

    # Act.
    adb_controller.init_server(timeout=_TIMEOUT)

    # Assert.
    expected_env = dict(self._env_before)
    expected_env['HOME'] = '/some/path/'
    expected_env['ANDROID_HOME'] = '/some/path/to/android'
    expected_env['ANDROID_ADB_SERVER_PORT'] = '1234'
    mock_check_output.assert_called_once_with(
        ['my_adb', 'devices'],
        stderr=subprocess.STDOUT,
        timeout=_TIMEOUT,
        env=expected_env,
    )
    mock_sleep.assert_called_once()

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_restart_server(self, mock_sleep, mock_check_output):
    """When an adb command fails, we expect the server to be restarted."""
    # Arrange.
    mock_check_output.side_effect = [
        subprocess.CalledProcessError(returncode=1, cmd='blah'),
    ] + ['fake_output'.encode('utf-8')] * 4
    adb_controller = adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            use_adb_server_port_from_os_env=True,
        )
    )

    # Act.
    adb_controller.execute_command(['my_command'], timeout=_TIMEOUT)

    # Assert.
    expected_env = dict(self._env_before)
    expected_env['HOME'] = '/some/path/'
    mock_check_output.assert_has_calls([
        mock.call(
            ['my_adb', '-s', 'awesome_device', 'my_command'],
            stderr=subprocess.STDOUT,
            timeout=_TIMEOUT,
            env=expected_env,
        ),
        mock.call(
            ['my_adb', 'kill-server'],
            stderr=subprocess.STDOUT,
            timeout=_TIMEOUT,
            env=expected_env,
        ),
        mock.call(
            ['my_adb', 'start-server'],
            stderr=subprocess.STDOUT,
            timeout=_TIMEOUT,
            env=expected_env,
        ),
        mock.call(
            ['my_adb', 'devices'],
            stderr=subprocess.STDOUT,
            timeout=_TIMEOUT,
            env=expected_env,
        ),
        mock.call(
            ['my_adb', '-s', 'awesome_device', 'my_command'],
            stderr=subprocess.STDOUT,
            timeout=_TIMEOUT,
            env=expected_env,
        ),
    ])
    mock_sleep.assert_has_calls(
        [mock.call(0.2), mock.call(2.0), mock.call(0.2)]
    )

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_invalid_command(self, mock_sleep, mock_check_output):
    """A command that still fails after one server restart raises an error."""
    # Arrange.
    restart_sequence = ['fake_output'.encode('utf-8')] * 3
    mock_check_output.side_effect = (
        [
            subprocess.CalledProcessError(returncode=1, cmd='blah'),
        ]
        + restart_sequence
        + [subprocess.CalledProcessError(returncode=1, cmd='blah')]
        # Don't restart if last call fails.
    )
    adb_controller = adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            use_adb_server_port_from_os_env=True,
        )
    )

    # Act.
    with self.assertRaises(errors.AdbControllerError):
      adb_controller.execute_command(['my_command'], timeout=_TIMEOUT)

    # Assert.
    expected_env = dict(self._env_before)
    expected_env['HOME'] = '/some/path/'
    mock_check_output.assert_has_calls(
        [
            mock.call(
                ['my_adb', '-s', 'awesome_device', 'my_command'],
                stderr=subprocess.STDOUT,
                timeout=_TIMEOUT,
                env=expected_env,
            ),
            mock.call(
                ['my_adb', 'kill-server'],
                stderr=subprocess.STDOUT,
                timeout=_TIMEOUT,
                env=expected_env,
            ),
            mock.call(
                ['my_adb', 'start-server'],
                stderr=subprocess.STDOUT,
                timeout=_TIMEOUT,
                env=expected_env,
            ),
            mock.call(
                ['my_adb', 'devices'],
                stderr=subprocess.STDOUT,
                timeout=_TIMEOUT,
                env=expected_env,
            ),
            mock.call(
                ['my_adb', '-s', 'awesome_device', 'my_command'],
                stderr=subprocess.STDOUT,
                timeout=_TIMEOUT,
                env=expected_env,
            ),
        ],
        any_order=False,
    )
    mock_sleep.assert_has_calls(
        [mock.call(0.2), mock.call(2.0), mock.call(0.2)]
    )

  @mock.patch.object(subprocess, 'check_output', autospec=True)
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_avoid_infinite_recursion(self, mock_sleep, mock_check_output):
    """Raise an error if the command fails even after restarts."""
    del mock_sleep
    mock_check_output.side_effect = subprocess.CalledProcessError(
        returncode=1, cmd='blah'
    )
    adb_controller = adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            use_adb_server_port_from_os_env=True,
        )
    )
    self.assertRaises(
        errors.AdbControllerError,
        adb_controller.execute_command,
        ['my_command'],
        timeout=_TIMEOUT,
    )
class AdbControllerInitTest(absltest.TestCase):
  """Tests how `AdbController.__init__` treats adb-related env variables."""

  def setUp(self):
    super().setUp()
    # These tests set or delete ANDROID_* variables in the process-wide
    # environment; restore it afterwards so other tests are unaffected.
    self.enter_context(mock.patch.dict(os.environ))

  def test_deletes_problem_env_vars(self):
    os.environ['ANDROID_HOME'] = '/usr/local/Android/Sdk'
    os.environ['ANDROID_ADB_SERVER_PORT'] = '1337'

    adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            adb_server_port=9999,
            default_timeout=_TIMEOUT,
        )
    )

    self.assertNotIn('ANDROID_HOME', os.environ)
    self.assertNotIn('ANDROID_ADB_SERVER_PORT', os.environ)

  def test_use_adb_server_port_from_os_env_retains_os_env_vars(self):
    os.environ['ANDROID_HOME'] = '/usr/local/Android/Sdk'
    os.environ['ANDROID_ADB_SERVER_PORT'] = '1337'

    adb_controller_lib.AdbController(
        config_classes.AdbControllerConfig(
            adb_path='my_adb',
            device_name='awesome_device',
            adb_server_port=9999,
            default_timeout=_TIMEOUT,
            use_adb_server_port_from_os_env=True,
        )
    )

    self.assertIn('ANDROID_ADB_SERVER_PORT', os.environ)
    self.assertEqual(os.environ['ANDROID_ADB_SERVER_PORT'], '1337')
    self.assertIn('ANDROID_HOME', os.environ)
    self.assertEqual(os.environ['ANDROID_HOME'], '/usr/local/Android/Sdk')
if __name__ == '__main__':
  # Hand control to absltest's runner when this file is executed directly.
  absltest.main()
================================================
FILE: android_env/components/adb_log_stream.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Class for a stream of logs output by a locally running emulator."""
import subprocess
from absl import logging
from android_env.components import log_stream
# adb arguments for the long-lived logcat process; appended after the adb
# command prefix and before the log filters. `-v epoch` selects logcat's
# epoch-timestamp output format.
_LOGCAT_COMMAND = ['logcat', '-v', 'epoch']
class AdbLogStream(log_stream.LogStream):
  """Manages adb logcat process for a locally running emulator."""

  # How long `stop_stream()` waits for the killed logcat process to exit.
  _KILL_WAIT_TIMEOUT_SEC = 5.0

  def __init__(self, adb_command_prefix: list[str], verbose: bool = False):
    """Initializes the stream.

    Args:
      adb_command_prefix: adb invocation (binary plus flags) that the logcat
        arguments are appended to.
      verbose: Forwarded to the `log_stream.LogStream` base class.
    """
    super().__init__(verbose=verbose)
    self._adb_command_prefix = adb_command_prefix
    # The long-lived `adb logcat` process; None until `_get_stream_output()`.
    self._adb_subprocess = None

  def _get_stream_output(self):
    """Clears logcat buffers, spawns `adb logcat` and returns its stdout."""
    # Before spawning a long-lived process, we issue `logcat -b all -c` to clear
    # all buffers to avoid interference from previous runs.
    clear_buffer_output = subprocess.check_output(
        self._adb_command_prefix + ['logcat', '-b', 'all', '-c'],
        stderr=subprocess.STDOUT,
        timeout=100)
    logging.info('clear_buffer_output: %r', clear_buffer_output)
    # `self._filters` comes from the base class (presumably populated by
    # `set_log_filters()`).
    cmd = self._adb_command_prefix + _LOGCAT_COMMAND + self._filters
    self._adb_subprocess = subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        bufsize=1,
        universal_newlines=True)
    return self._adb_subprocess.stdout

  def stop_stream(self):
    """Kills the logcat process and waits for it to exit."""
    # `getattr` keeps this safe even if `__init__` was bypassed.
    adb_subprocess = getattr(self, '_adb_subprocess', None)
    if adb_subprocess is None:
      logging.error('`stop_stream()` called before `get_stream_output()`. '
                    'This violates the `LogStream` API.')
      return
    adb_subprocess.kill()
    # Reap the killed process so it does not linger as a zombie.
    try:
      adb_subprocess.wait(timeout=self._KILL_WAIT_TIMEOUT_SEC)
    except subprocess.TimeoutExpired:
      logging.warning(
          'adb logcat process did not exit within %s seconds of being killed.',
          self._KILL_WAIT_TIMEOUT_SEC,
      )
================================================
FILE: android_env/components/adb_log_stream_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for adb_log_stream."""
import subprocess
from unittest import mock
from absl.testing import absltest
from android_env.components import adb_log_stream
class FakeAdbSubprocess:
  """Minimal stand-in for the `subprocess.Popen` object AdbLogStream spawns."""

  @property
  def stdout(self):
    # A fresh list of 100 synthetic lines, `line_0` .. `line_99`, per access.
    return list(map('line_{}'.format, range(100)))

  def kill(self):
    """Does nothing: there is no real process to terminate."""
class AdbLogStreamTest(absltest.TestCase):
  """Tests for `adb_log_stream.AdbLogStream`."""

  @mock.patch.object(subprocess, 'check_output', return_value=b'')
  @mock.patch.object(subprocess, 'Popen', return_value=FakeAdbSubprocess())
  def test_get_stream_output(self, mock_popen, unused_mock_check_output):
    """Lines come back in order and logcat is spawned with the filters."""
    logcat_stream = adb_log_stream.AdbLogStream(adb_command_prefix=['foo'])
    logcat_stream.set_log_filters(['bar'])

    for index, output_line in enumerate(logcat_stream.get_stream_output()):
      self.assertEqual(output_line, f'line_{index}')

    mock_popen.assert_called_with(
        ['foo', 'logcat', '-v', 'epoch', 'bar', '*:S'],
        stderr=subprocess.STDOUT,
        stdout=subprocess.PIPE,
        bufsize=1,
        universal_newlines=True,
    )

  def test_stop_stream_before_get_stream_output(self):
    """Calling `stop_stream()` before `get_stream_output()` should not crash."""
    # Arrange.
    logcat_stream = adb_log_stream.AdbLogStream(adb_command_prefix=['foo'])

    # Act.
    logcat_stream.stop_stream()

    # Assert.
    # The test passes as long as the call above does not raise.
if __name__ == '__main__':
  # Hand control to absltest's runner when this file is executed directly.
  absltest.main()
================================================
FILE: android_env/components/app_screen_checker.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Determines if the current app screen matches an expected app screen."""
from collections.abc import Callable, Sequence
import enum
import re
import time
from typing import Self
from absl import logging
from android_env.components import adb_call_parser as adb_call_parser_lib
from android_env.components import errors
from android_env.proto import adb_pb2
from android_env.proto import task_pb2
class _DumpsysNode:
"""A node in a dumpsys tree."""
def __init__(self, data: str):
self._children = []
self._data = data
@property
def data(self) -> str:
return self._data
@property
def children(self) -> list[Self]:
return self._children
def find_child(
self, predicate: Callable[[Self], bool], max_levels: int = 0
) -> Self | None:
"""Returns the first direct child that matches `predicate`, None otherwise.
Args:
predicate: Function-like that accepts a _DumpsysNode and returns boolean.
max_levels: Maximum number of levels down the tree to search for a child.
If non-positive, only direct children will be searched for.
Returns:
A _DumpsysNode or None.
"""
if not self.children:
return None
try:
return next(x for x in self.children if predicate(x))
except StopIteration:
logging.info('Failed to find child. max_levels: %i.', max_levels)
# Search children.
if max_levels:
for child in self.children:
child_result = child.find_child(predicate, max_levels - 1)
if child_result is not None:
return child_result
return None
def __repr__(self):
return self._data
def print_tree(self, indent: int = 2):
"""Prints this tree in logging.info()."""
logging.info(' ' * indent + self.data)
for c in self.children:
c.print_tree(indent + 2)
def build_tree_from_dumpsys_output(dumpsys_output: str) -> _DumpsysNode:
  """Constructs a tree from a dumpsys string output.

  Each non-empty line becomes one node. Trailing spaces/carriage returns and
  leading spaces are stripped from its data. A line is attached as a child of
  the closest preceding line with strictly smaller indentation, or of the
  synthetic root if there is none. This means siblings are recognized by equal
  indentation, whatever the indentation step is.

  Args:
    dumpsys_output: Verbatim output from adb dumpsys. The expected format is a
      list where each line is a node and the indentation (leading spaces)
      marks the relationship with its parent or siblings.

  Returns:
    The root of the tree, a _DumpsysNode whose data is '___root___'.
  """
  root = _DumpsysNode('___root___')  # The root of all nodes.
  # Chain of (node, indent) pairs from the root to the most recent node. The
  # root's indent of -1 is below any real line's, so it is never popped.
  ancestors: list[tuple[_DumpsysNode, int]] = [(root, -1)]
  for raw_line in dumpsys_output.split('\n'):
    line = raw_line.rstrip(' \r')
    if not line:  # Skip empty lines.
      continue
    data = line.lstrip(' ')
    indent = len(line) - len(data)  # Number of indent spaces.
    # Close every block whose indentation is not smaller than this line's.
    while ancestors[-1][1] >= indent:
      ancestors.pop()
    new_node = _DumpsysNode(data)
    ancestors[-1][0].children.append(new_node)
    ancestors.append((new_node, indent))
  return root
def matches_path(
    dumpsys_activity_output: str,
    expected_view_hierarchy_path: Sequence[re.Pattern[str]],
    max_levels: int = 0,
) -> bool:
  """Returns True if the current dumpsys output matches the expected path.

  Args:
    dumpsys_activity_output: The output of running `dumpsys activity ...`.
    expected_view_hierarchy_path: [regex] A list of regular expressions to be
      tested at each level of the tree, starting with the children of the
      'View Hierarchy' node. Each regex is applied with `re.Pattern.match`,
      i.e. anchored at the start of the node's text.
    max_levels: How many levels to search from root for View Hierarchy.

  Returns:
    True if the dumpsys tree contains one path that matches all regexes.
  """
  root = build_tree_from_dumpsys_output(dumpsys_activity_output)

  # Find the View Hierarchy.
  view_hierarchy = root.find_child(
      lambda x: x.data.startswith('View Hierarchy'), max_levels)
  if view_hierarchy is None:
    logging.error(
        'view_hierarchy is None. Dumpsys activity output: %s',
        str(dumpsys_activity_output))
    # `print_tree()` logs the tree itself and returns None, so it is called
    # on its own instead of being passed as a logging argument.
    root.print_tree()
    return False

  current_node = view_hierarchy
  for i, regex in enumerate(expected_view_hierarchy_path):
    # `expr=regex` binds the current pattern (avoids late-binding closures).
    child = current_node.find_child(
        lambda node, expr=regex: expr.match(node.data) is not None)
    if child is None:
      logging.error('Mismatched regex (%i, %s). current_node: %s', i,
                    regex.pattern, current_node)
      logging.error('Dumpsys activity output: %s', str(dumpsys_activity_output))
      root.print_tree()
      return False
    current_node = child
  return True
class AppScreenChecker:
  """Checks whether the device currently shows an expected app screen."""

  class Outcome(enum.IntEnum):
    """Possible return values from checking the current app screen."""

    # The current app screen matches the expected app screen.
    SUCCESS = 0
    # There's no activity to check.
    EMPTY_EXPECTED_ACTIVITY = 1
    # We were unable to determine the current activity.
    FAILED_ACTIVITY_EXTRACTION = 2
    # The current activity does not match the expected activity.
    UNEXPECTED_ACTIVITY = 3
    # The current view hierarchy does not match the expected view hierarchy.
    UNEXPECTED_VIEW_HIERARCHY = 4

  def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser,
               expected_app_screen: task_pb2.AppScreen):
    """Stores the parser and precompiles the expected view-hierarchy regexes.

    Args:
      adb_call_parser: Used to issue the adb requests for each check.
      expected_app_screen: Screen to look for; its `activity` and its
        `view_hierarchy_path` regex strings are read.
    """
    self._adb_call_parser = adb_call_parser
    self._expected_app_screen = expected_app_screen
    self._expected_activity = expected_app_screen.activity
    self._expected_view_hierarchy_path = list(
        map(re.compile, expected_app_screen.view_hierarchy_path)
    )

  # Return type is AppScreenChecker.Outcome, but pytype doesn't understand that.
  def matches_current_app_screen(self) -> enum.IntEnum:
    """Determines whether the current app screen matches `expected_app_screen`."""
    outcome = AppScreenChecker.Outcome
    if not self._expected_activity:
      return outcome.EMPTY_EXPECTED_ACTIVITY

    # Step 1: the foreground activity must be the expected one.
    activity_request = adb_pb2.AdbRequest(
        get_current_activity=adb_pb2.AdbRequest.GetCurrentActivity())
    activity_response = self._adb_call_parser.parse(activity_request)
    if activity_response.status != adb_pb2.AdbResponse.OK:
      return outcome.FAILED_ACTIVITY_EXTRACTION

    current_activity = activity_response.get_current_activity.full_activity
    if current_activity != self._expected_activity:
      logging.error('current_activity: %s, expected_activity: %s',
                    current_activity, self._expected_activity)
      return outcome.UNEXPECTED_ACTIVITY

    # Step 2 (optional): the view hierarchy must contain the expected path.
    if not self._expected_view_hierarchy_path:
      return outcome.SUCCESS

    # The package name is the part of the full activity name before '/'.
    package_name = self._expected_activity.split('/')[0]
    dumpsys_request = adb_pb2.AdbRequest(
        dumpsys=adb_pb2.AdbRequest.DumpsysRequest(
            service='activity', args=[package_name, package_name]))
    dumpsys_response = self._adb_call_parser.parse(dumpsys_request)
    if dumpsys_response.status != adb_pb2.AdbResponse.OK:
      return outcome.FAILED_ACTIVITY_EXTRACTION

    # Empty dumpsys output is not treated as a mismatch.
    raw_output = dumpsys_response.dumpsys.output
    if raw_output and not matches_path(
        raw_output.decode('utf-8'),
        self._expected_view_hierarchy_path,
        max_levels=3):
      return outcome.UNEXPECTED_VIEW_HIERARCHY
    return outcome.SUCCESS

  def wait_for_app_screen(self, timeout_sec: float) -> float:
    """Polls every 0.1s until `self._expected_app_screen` is the current screen.

    Args:
      timeout_sec: Maximum total time to wait for the screen to pop up.

    Returns:
      The total amount of time in seconds spent waiting for the screen to pop
      up.

    Raises:
      errors.WaitForAppScreenError if the screen does not pop up within
      `timeout_sec`.
    """
    logging.info('Waiting for app screen...')
    began_at = time.time()
    while (time.time() - began_at) < timeout_sec:
      check_result = self.matches_current_app_screen()
      if check_result == AppScreenChecker.Outcome.SUCCESS:
        waited = time.time() - began_at
        logging.info('Successfully waited for app screen in %r seconds: [%r]',
                     waited, self._expected_app_screen)
        return waited
      time.sleep(0.1)

    waited = time.time() - began_at
    logging.error('Failed to wait for app screen in %r seconds: [%r].',
                  waited, self._expected_app_screen)
    raise errors.WaitForAppScreenError()
================================================
FILE: android_env/components/app_screen_checker_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.components.app_screen_checker."""
import re
from unittest import mock
from absl.testing import absltest
from android_env.components import adb_call_parser
from android_env.components import app_screen_checker
from android_env.components import errors
from android_env.proto import adb_pb2
from android_env.proto import task_pb2
def _flatten_tree(
    tree: app_screen_checker._DumpsysNode, flat_tree: list[str], indent: int = 2
):
  """Appends a pre-order, indented rendering of `tree` to `flat_tree`.

  Every node contributes one string: its `data` prefixed with spaces. The root
  gets `indent` spaces and each deeper level adds two more.
  """
  pending = [(tree, indent)]
  while pending:
    node, node_indent = pending.pop()
    flat_tree.append(' ' * node_indent + node.data)
    # Push children in reverse so that they are popped (and emitted) in their
    # original order, matching a recursive pre-order walk.
    children = list(node.children)
    pending.extend((child, node_indent + 2) for child in children[::-1])
class AppScreenCheckerTest(absltest.TestCase):
# Tests for dumpsys tree construction, view-hierarchy path matching and
# `AppScreenChecker.wait_for_app_screen`.
# Ensures that build_tree_from_dumpsys_output produces a node whose flat
# representation matches our expectation from an arbitrary hierarchy. The
# returned tree is rooted at a synthetic `___root___` node.
def test_build_tree_from_dumpsys_output(self):
dumpsys_output = """
Queen Elizabeth II
Charles
William
George
Charlotte
Louis
Harry
Archie
Anne
Peter
Savannah
Isla
Zara
Mia
Lena
Andrew
Beatrice
Eugenie
Edward
Louise
James
"""
tree = app_screen_checker.build_tree_from_dumpsys_output(dumpsys_output)
flat_tree = []
_flatten_tree(tree, flat_tree, indent=2)
self.assertEqual(flat_tree, [
' ___root___',
' Queen Elizabeth II',
' Charles',
' William',
' George',
' Charlotte',
' Louis',
' Harry',
' Archie',
' Anne',
' Peter',
' Savannah',
' Isla',
' Zara',
' Mia',
' Lena',
' Andrew',
' Beatrice',
' Eugenie',
' Edward',
' Louise',
' James',
])
# Ensures that build_tree_from_dumpsys_output also handles a forest: when the
# dump contains several top-level trees (Tree1 and Tree2), they are all
# flattened under the single synthetic `___root___` node.
def test_build_forest_from_dumpsys_output(self):
dumpsys_output = """
Tree1
Branch1
Leaf1
Leaf2
Branch2
Leaf3
Leaf4
Leaf5
Tree2
Branch3
Leaf6
Leaf7
Branch4
Leaf8
Leaf9
Leaf10
Leaf11
"""
tree = app_screen_checker.build_tree_from_dumpsys_output(dumpsys_output)
flat_tree = []
_flatten_tree(tree, flat_tree, indent=2)
self.assertEqual(flat_tree, [
' ___root___',
' Tree1',
' Branch1',
' Leaf1',
' Leaf2',
' Branch2',
' Leaf3',
' Leaf4',
' Leaf5',
' Tree2',
' Branch3',
' Leaf6',
' Leaf7',
' Branch4',
' Leaf8',
' Leaf9',
' Leaf10',
' Leaf11',
])
# Ensures that `matches_path` reports no match when the dump has no
# "View Hierarchy" node, even though nodes matching the regexes (A, B) exist.
# NOTE(review): presumably matching only starts below a "View Hierarchy" node;
# confirm against `matches_path`.
def test_no_view_hierarchy_matches_path(self):
dumpsys_output = """
TASK
ACTIVITY
Missing View Hierarchy
A
B
C
D
E
F
"""
expected_path = ['^A$', 'B$']
expected_view_hierarchy_path = [
re.compile(regex) for regex in expected_path
]
self.assertFalse(
app_screen_checker.matches_path(dumpsys_output,
expected_view_hierarchy_path))
# Ensures that `matches_path` finds a chain of regexes below the
# "View Hierarchy" node, and rejects a chain whose last element ("Kenji") does
# not appear in the dump.
def test_matches_path(self):
dumpsys_output = """
TASK
ACTIVITY
Some node we don't care
Blah
View Hierarchy
Hirohito
Akihito
Naruhito
Aiko
Fumihito
Mako
Kako
Hisahito
Masahito
"""
expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$']
expected_view_hierarchy_path = [
re.compile(regex) for regex in expected_path
]
self.assertTrue(
app_screen_checker.matches_path(
dumpsys_output, expected_view_hierarchy_path, max_levels=2))
# Also check that the following path does not match anything in the tree.
expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kenji$']
expected_view_hierarchy_path = [
re.compile(regex) for regex in expected_path
]
self.assertFalse(
app_screen_checker.matches_path(dumpsys_output,
expected_view_hierarchy_path))
# Same path as in `test_matches_path`, but the "View Hierarchy" node sits
# under an extra intermediate node. It is found with `max_levels=3` and not
# with `max_levels=2`.
def test_matches_path_one_level_deep(self):
dumpsys_output = """
TASK
ACTIVITY
Some node we don't care
Blah
Some intermediate node
View Hierarchy
Hirohito
Akihito
Naruhito
Aiko
Fumihito
Mako
Kako
Hisahito
Masahito
"""
expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$']
expected_view_hierarchy_path = [
re.compile(regex) for regex in expected_path
]
self.assertTrue(
app_screen_checker.matches_path(
dumpsys_output, expected_view_hierarchy_path, max_levels=3))
# Also check that the view hierarchy is not found when searching only grand
# children of TASK.
expected_path = ['^Hirohito$', 'Akihito$', 'Fumihito$', 'Kako$']
expected_view_hierarchy_path = [
re.compile(regex) for regex in expected_path
]
self.assertFalse(
app_screen_checker.matches_path(
dumpsys_output, expected_view_hierarchy_path, max_levels=2))
def test_wait_for_app_screen_zero_timeout(self):
"""Ensures that an exception is raised once the timeout has elapsed."""
app_screen = task_pb2.AppScreen(activity='whatever.MyActivity')
call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
screen_checker = app_screen_checker.AppScreenChecker(
adb_call_parser=call_parser,
expected_app_screen=app_screen)
# With a zero timeout, the method should never be able to wait for the
# screen to pop up and an exception should be raised.
self.assertRaises(
errors.WaitForAppScreenError,
screen_checker.wait_for_app_screen,
timeout_sec=0.0)
def test_wait_for_app_screen_successful(self):
"""Ensures that with the right conditions, the app screen should pop up."""
app_screen = task_pb2.AppScreen(activity='my.favorite.AwesomeActivity')
call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
# Every `parse()` call reports the expected activity as the current one, so
# the screen is expected to match well within the timeout.
call_parser.parse.return_value = adb_pb2.AdbResponse(
status=adb_pb2.AdbResponse.Status.OK,
get_current_activity=adb_pb2.AdbResponse.GetCurrentActivityResponse(
full_activity='my.favorite.AwesomeActivity'))
screen_checker = app_screen_checker.AppScreenChecker(
call_parser, app_screen)
timeout = 1.0
wait_time = screen_checker.wait_for_app_screen(timeout_sec=timeout)
# The call should not generate an exception and the return value should be
# less than the timeout given.
self.assertLess(wait_time, timeout)
# Run the tests when this file is executed directly.
if __name__ == '__main__':
absltest.main()
================================================
FILE: android_env/components/config_classes.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataclass definitions used for instantiating AndroidEnv components."""
import dataclasses
import enum
@dataclasses.dataclass
class AdbControllerConfig:
  """Settings for instantiating an `AdbController` instance."""

  # Filesystem path to the `adb` binary.
  # NOTE: This must be a full path and must not contain environment variables
  # or user folder shorthands (e.g. `~/some/path/to/adb`) since they will not be
  # expanded internally by AndroidEnv. The default value below itself contains
  # `~`, so it is not usable as-is: it only hints at the usual SDK layout.
  # Set this field to the absolute path of your `adb` binary.
  adb_path: str = '~/Android/Sdk/platform-tools/adb'
  # Port for adb server.
  adb_server_port: int = 5037
  # Default timeout in seconds for internal commands.
  default_timeout: float = 120.0
  # Name of the device to communicate with.
  device_name: str = ''
  # Whether to use adb server port set in OS Environment variables.
  # When True, adb will use the ANDROID_ADB_SERVER_PORT OS environment variable
  # for selecting its server port if available (or the default 5037 if not set).
  # When False, the ANDROID_ADB_SERVER_PORT OS environment var will be unset
  # and adb_server_port will be used and supplied as an argument to all adb
  # commands.
  use_adb_server_port_from_os_env: bool = False
@dataclasses.dataclass
class DeviceSettingsConfig:
  """Config class for DeviceSettings.

  Attributes:
    show_touches: Whether to show circles on the screen indicating touch
      position.
    show_pointer_location: Whether to show blue lines on the screen indicating
      touch position.
    show_status_bar: Whether or not to show the status (top) bar.
    show_navigation_bar: Whether or not to show the navigation (bottom) bar.
  """

  show_touches: bool = True
  show_pointer_location: bool = True
  show_status_bar: bool = False
  show_navigation_bar: bool = False
@dataclasses.dataclass
class CoordinatorConfig:
  """Config class for Coordinator.

  Attributes:
    num_fingers: Number of virtual "fingers" of the agent.
    enable_key_events: Whether to enable keyboard key events.
    periodic_restart_time_min: Time between periodic restarts in minutes. If
      > 0, will trigger a simulator restart at the beginning of the next
      episode once the time has been reached.
    device_settings: General Android settings.
  """

  num_fingers: int = 1
  enable_key_events: bool = False
  periodic_restart_time_min: float = 0.0
  device_settings: DeviceSettingsConfig = dataclasses.field(
      default_factory=DeviceSettingsConfig
  )
@dataclasses.dataclass
class SimulatorConfig:
  """Base class for all simulator configs.

  Attributes:
    verbose_logs: If true, the log stream of the simulator will be verbose.
    interaction_rate_sec: How often to (asynchronously) grab the screenshot
      from the simulator. If <= 0, stepping the environment blocks on fetching
      the screenshot (the environment is synchronous).
  """

  verbose_logs: bool = False
  interaction_rate_sec: float = 0.0
@enum.unique
class GPUMode(enum.Enum):
  """Emulator GPU Mode.

  The string values are what `EmulatorLauncherConfig.gpu_mode` holds; its
  default is `GPUMode.SWANGLE_INDIRECT.value`.
  """

  HOST = 'host'
  SWANGLE_INDIRECT = 'swangle_indirect'
  SWIFTSHADER_INDIRECT = 'swiftshader_indirect'
@dataclasses.dataclass
class EmulatorLauncherConfig:
  """Config class for EmulatorLauncher.

  NOTE: If `adb_port`, `emulator_console_port` and `grpc_port` are defined
  (i.e. not all equal to 0), it is assumed that the emulator they point to
  exists already and EmulatorLauncher will be skipped.

  Attributes:
    emulator_path: Filesystem path to the `emulator` binary.
    android_sdk_root: Filesystem path to the Android SDK root.
    avd_name: Name of the AVD.
    android_avd_home: Local directory for AVDs.
    snapshot_name: Name of the snapshot to load.
    kvm_device: Path to the KVM device.
    tmp_dir: Path to directory which will hold temporary files.
    gpu_mode: GPU mode override (defaults to
      `GPUMode.SWANGLE_INDIRECT.value`). Please see
      https://developer.android.com/studio/run/emulator-acceleration#accel-graphics.
    run_headless: Whether to run in headless mode (i.e. without a graphical
      window).
    restrict_network: Whether to restrict network access. If True, will
      disable networking on the device. This option is only available for
      emulator version > 31.3.9 (June 2022).
    show_perf_stats: Whether to set `SHOW_PERF_STATS=1` when launching the
      emulator to display performance and memory statistics.
    adb_port: ADB port for the Android device.
    emulator_console_port: Port for telnet communication with the emulator.
    grpc_port: Port for gRPC communication with the emulator.
  """

  emulator_path: str = '~/Android/Sdk/emulator/emulator'
  android_sdk_root: str = '~/Android/Sdk'
  avd_name: str = ''
  android_avd_home: str = '~/.android/avd'
  snapshot_name: str = ''
  kvm_device: str = '/dev/kvm'
  tmp_dir: str = '/tmp/android_env/simulator/'
  gpu_mode: str = GPUMode.SWANGLE_INDIRECT.value
  run_headless: bool = True
  restrict_network: bool = False
  show_perf_stats: bool = False
  adb_port: int = 0
  emulator_console_port: int = 0
  grpc_port: int = 0
@dataclasses.dataclass
class EmulatorConfig(SimulatorConfig):
  """Config class for EmulatorSimulator.

  Inherits `verbose_logs` and `interaction_rate_sec` from `SimulatorConfig`.

  Attributes:
    emulator_launcher: Configuration for launching the Android Emulator.
    adb_controller: Configuration for talking to adb.
    logfile_path: Path to file which holds emulator logs. If not provided, it
      will be determined by the EmulatorLauncher.
    launch_n_times_without_reboot: The number of times to try launching the
      emulator before rebooting (reboot on the n+1-st try).
    launch_n_times_without_reinstall: The number of times to try launching the
      emulator before reinstalling (reinstall on the n+1-st try).
  """

  emulator_launcher: EmulatorLauncherConfig = dataclasses.field(
      default_factory=EmulatorLauncherConfig
  )
  adb_controller: AdbControllerConfig = dataclasses.field(
      default_factory=AdbControllerConfig
  )
  logfile_path: str = ''
  launch_n_times_without_reboot: int = 1
  launch_n_times_without_reinstall: int = 2
@dataclasses.dataclass
class FakeSimulatorConfig(SimulatorConfig):
  """Config class for FakeSimulator.

  Attributes:
    screen_dimensions: The dimensions in pixels of the device screen (HxW).
  """

  screen_dimensions: tuple[int, int] = (0, 0)
@dataclasses.dataclass
class TaskManagerConfig:
  """Config class for TaskManager.

  Attributes:
    max_bad_states: If max_bad_states episodes finish in a bad state in a row,
      restart the simulation.
    dumpsys_check_frequency: The frequency to check for the current activity
      and view hierarchy. The unit is raw observation (i.e. each call to
      AndroidEnv.step()).
    max_failed_current_activity: The maximum number of tries for extracting
      the current activity before forcing the episode to restart.
    extras_max_buffer_size: The maximum number of extras elements to store. If
      this number is exceeded, elements are dropped in the order they were
      received.
  """

  max_bad_states: int = 3
  dumpsys_check_frequency: int = 150
  max_failed_current_activity: int = 10
  extras_max_buffer_size: int = 100
@dataclasses.dataclass
class TaskConfig:
  """Base config class for loading tasks.

  Attributes:
    tmp_dir: The directory for temporary task-related resources.
  """

  tmp_dir: str = ''
@dataclasses.dataclass
class FilesystemTaskConfig(TaskConfig):
  """Config for protobuf files stored in the local filesystem.

  Attributes:
    path: Filesystem path to `.binarypb` or `.textproto` protobuf Task.
  """

  path: str = ''
@dataclasses.dataclass
class AndroidEnvConfig:
  """Config class for AndroidEnv.

  Bundles the configs of the main components.

  Attributes:
    task: Config for loading the task.
    task_manager: Config for the TaskManager.
    coordinator: Config for the Coordinator.
    simulator: Config for the simulator; defaults to an `EmulatorConfig`.
  """

  task: TaskConfig = dataclasses.field(default_factory=TaskConfig)
  task_manager: TaskManagerConfig = dataclasses.field(
      default_factory=TaskManagerConfig
  )
  coordinator: CoordinatorConfig = dataclasses.field(
      default_factory=CoordinatorConfig
  )
  simulator: SimulatorConfig = dataclasses.field(default_factory=EmulatorConfig)
================================================
FILE: android_env/components/coordinator.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Coordinator handles interaction between internal components of AndroidEnv."""
import copy
import time
from typing import Any
from absl import logging
from android_env.components import action_fns
from android_env.components import adb_call_parser
from android_env.components import config_classes
from android_env.components import device_settings as device_settings_lib
from android_env.components import errors
from android_env.components import specs
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators import base_simulator
from android_env.proto import adb_pb2
import dm_env
import numpy as np
class Coordinator:
  """Handles interaction between internal components of AndroidEnv."""

  def __init__(
      self,
      simulator: base_simulator.BaseSimulator,
      task_manager: task_manager_lib.TaskManager,
      device_settings: device_settings_lib.DeviceSettings,
      config: config_classes.CoordinatorConfig | None = None,
  ):
    """Handles communication between AndroidEnv and its components.

    The simulator is launched (via `_launch_simulator()`) during construction.

    Args:
      simulator: A BaseSimulator instance.
      task_manager: The TaskManager, responsible for coordinating RL tasks.
      device_settings: The DeviceSettings instance used to apply
        `config.device_settings` and to query the screen size and orientation.
      config: Settings to customize this Coordinator. If `None`, a default
        `CoordinatorConfig` is used.

    Raises:
      errors.TooManyRestartsError: If the simulator could not be launched
        within the allowed number of attempts.
    """
    self._simulator = simulator
    self._task_manager = task_manager
    self._config = config or config_classes.CoordinatorConfig()
    self._device_settings = device_settings
    # Assigned by `_launch_simulator()` below.
    self._adb_call_parser: adb_call_parser.AdbCallParser = None

    # Initialize stats.
    self._stats = {
        'relaunch_count': 0,
        'relaunch_count_periodic': 0,
        'relaunch_count_setup_steps': 0,
        'relaunch_count_reset_steps': 0,
        'relaunch_count_simulator_launch': 0,
        'relaunch_count_simulator_reset': 0,
        'relaunch_count_execute_action': 0,
        'relaunch_count_fetch_observation': 0,
        'relaunch_count_update_settings': 0,
        'failed_task_updates': 0,
    }

    # Initialize counters.
    self._simulator_healthy = False
    self._latest_observation_time = 0
    self._simulator_start_time = None

    logging.info('Starting the simulator...')
    self._launch_simulator()

  def action_spec(self) -> dict[str, dm_env.specs.Array]:
    """Returns the action spec for the configured fingers and key events."""
    return specs.base_action_spec(
        num_fingers=self._config.num_fingers,
        enable_key_events=self._config.enable_key_events,
    )

  def observation_spec(self) -> dict[str, dm_env.specs.Array]:
    """Returns the observation spec for the current screen dimensions."""
    return specs.base_observation_spec(
        height=self._device_settings.screen_height(),
        width=self._device_settings.screen_width(),
    )

  def _should_periodic_relaunch(self) -> bool:
    """Checks if it is time to restart the simulator.

    If a periodic restart time was specified, the Coordinator will re-launch
    the simulator at regular time intervals. This helps to make sure that the
    simulator is not in a stale state even if the environment has been running
    for a significant amount of time.

    Increments the `relaunch_count_periodic` stat when returning True.

    Returns:
      Boolean indicating if it is time to restart the simulator.
    """
    if self._config.periodic_restart_time_min and self._simulator_start_time:
      sim_alive_time = (time.time() - self._simulator_start_time) / 60.0
      logging.info('Simulator has been running for %f mins', sim_alive_time)
      if sim_alive_time > self._config.periodic_restart_time_min:
        logging.info('Maximum alive time reached. Restarting simulator.')
        self._stats['relaunch_count_periodic'] += 1
        return True
    return False

  def _launch_simulator(self, max_retries: int = 3):
    """Launches the simulator.

    Sets up the simulator and other task-related settings. Each attempt stops
    the task manager, launches the simulator, applies the device settings,
    starts the task manager and runs the task setup. A failure while applying
    device settings (`errors.AdbControllerError`) or during task setup
    (`errors.StepCommandError`) triggers another attempt.

    Args:
      max_retries: Maximum number of launch attempts before raising an error.

    Raises:
      errors.TooManyRestartsError: If all attempts failed. The error from the
        last failed attempt, if any, is chained as its `__cause__`.
    """
    self._simulator_healthy = False

    # Attempt to restart the system a given number of times.
    num_tries = 1
    latest_error = None
    while True:
      if num_tries > max_retries:
        raise errors.TooManyRestartsError(
            'Maximum number of restart attempts reached.'
        ) from latest_error
      logging.info('Simulator launch attempt %d of %d', num_tries, max_retries)

      self._task_manager.stop()

      # Launch the simulator.
      self._simulator.launch()
      self._simulator_start_time = time.time()

      # From here on, the simulator is assumed to be up and running.
      self._adb_call_parser = self._create_adb_call_parser()
      try:
        self._device_settings.update(self._config.device_settings)
      except errors.AdbControllerError as e:
        logging.exception('device_settings.update() failed.')
        self._stats['relaunch_count_update_settings'] += 1
        # Must be the local `latest_error` so that it is chained as the cause
        # of `TooManyRestartsError` above.
        latest_error = e
        num_tries += 1
        continue

      # Start the task.
      self._task_manager.start(
          adb_call_parser_factory=self._create_adb_call_parser,
          log_stream=self._simulator.create_log_stream(),
      )
      try:
        self._task_manager.setup_task()
      except errors.StepCommandError as error:
        logging.exception('Failed to set up the task. Restarting simulator.')
        self._stats['relaunch_count_setup_steps'] += 1
        latest_error = error
        num_tries += 1
        continue

      # Restart was successful.
      self._simulator_healthy = True
      self._stats['relaunch_count'] += 1
      break

  def _create_adb_call_parser(self):
    """Creates a new AdbCallParser instance."""
    return adb_call_parser.AdbCallParser(
        adb_controller=self._simulator.create_adb_controller()
    )

  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
    """Executes `call` through the current AdbCallParser."""
    return self._adb_call_parser.parse(call)

  def rl_reset(self) -> dm_env.TimeStep:
    """Resets the RL episode."""

    # Relaunch the simulator if necessary.
    if not self._simulator_healthy or self._should_periodic_relaunch():
      self._launch_simulator()

    # Reset counters. NOTE: none of the stat keys initialized in `__init__`
    # currently start with 'episode'.
    self._latest_observation_time = 0
    for key in self._stats:
      if key.startswith('episode'):
        self._stats[key] = 0.0

    # Execute a lift action before resetting the task.
    if not action_fns.send_action_to_simulator(
        action_fns.lift_all_fingers_action(self._config.num_fingers),
        self._simulator,
        self._device_settings.screen_width(),
        self._device_settings.screen_height(),
        self._config.num_fingers,
    ):
      self._stats['relaunch_count_execute_action'] += 1
      self._simulator_healthy = False

    # Reset the task.
    self._task_manager.reset_task()
    self._device_settings.get_orientation()

    # Get data from the simulator.
    simulator_signals = self._gather_simulator_signals()

    return self._task_manager.rl_reset(simulator_signals)

  def rl_step(self, agent_action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Executes the selected action and returns a timestep.

    Episode resets go through `rl_reset()` instead.

    Args:
      agent_action: Selected action to perform on the simulated Android device.

    Returns:
      An RL timestep. If sending the action or fetching the observation failed,
      the simulator is marked unhealthy and a truncation with `reward=0.0` and
      `observation=None` is returned without consulting the task manager.
    """
    if not action_fns.send_action_to_simulator(
        agent_action,
        self._simulator,
        self._device_settings.screen_width(),
        self._device_settings.screen_height(),
        self._config.num_fingers,
    ):
      self._stats['relaunch_count_execute_action'] += 1
      self._simulator_healthy = False

    # Get data from the simulator.
    try:
      simulator_signals = self._gather_simulator_signals()
    except errors.ReadObservationError:
      logging.exception('Unable to fetch observation. Restarting simulator.')
      self._stats['relaunch_count_fetch_observation'] += 1
      self._simulator_healthy = False

    # `simulator_signals` is only unbound on the error path above, which also
    # marks the simulator as unhealthy and therefore returns here.
    if not self._simulator_healthy:
      return dm_env.truncation(reward=0.0, observation=None)

    return self._task_manager.rl_step(simulator_signals)

  def _gather_simulator_signals(self) -> dict[str, np.ndarray]:
    """Gathers data from various sources to assemble the RL observation.

    Returns:
      A dict with the screenshot ('pixels'), the device orientation
      ('orientation') and 'timedelta': the microseconds elapsed since the
      previous observation, or 0 for the first observation after a reset.
    """
    # Get current timestamp and update the delta.
    now = time.time()
    timestamp_delta = (
        0
        if self._latest_observation_time == 0
        else (now - self._latest_observation_time) * 1e6
    )
    self._latest_observation_time = now

    return {
        'pixels': self._simulator.get_screenshot(),
        'orientation': self._device_settings.get_orientation(),
        'timedelta': np.array(timestamp_delta, dtype=np.int64),
    }

  def __del__(self):
    self.close()

  def stats(self) -> dict[str, Any]:
    """Returns a deep copy of the internal statistics."""
    return copy.deepcopy(self._stats)

  def close(self):
    """Cleans up the state of this Coordinator.

    Failures are logged and ignored so that cleanup always runs to completion.
    The `hasattr` guards cover attributes that were never set (e.g. when
    `__init__` did not run).
    """
    if hasattr(self, '_task_manager'):
      try:
        self._task_manager.stop()
      except:  # pylint: disable=bare-except
        logging.exception('Failed to stop task manager. Continuing.')
    if hasattr(self, '_simulator'):
      try:
        self._simulator.close()
      except:  # pylint: disable=bare-except
        logging.exception('Failed to close simulator. Continuing.')
================================================
FILE: android_env/components/coordinator_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.components.coordinator."""
import tempfile
import time
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import action_type
from android_env.components import adb_call_parser
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import device_settings as device_settings_lib
from android_env.components import errors
from android_env.components import task_manager
from android_env.components.simulators import base_simulator
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
from android_env.proto import task_pb2
import dm_env
import numpy as np
class CoordinatorTest(parameterized.TestCase):
  """Tests for `coordinator_lib.Coordinator` with mocked components."""

  def setUp(self):
    super().setUp()
    self.addCleanup(mock.patch.stopall)  # Disable previous patches.
    self._simulator = mock.create_autospec(base_simulator.BaseSimulator)
    self._random_screenshot = np.random.randint(
        low=0, high=255, size=(800, 600, 3), dtype=np.uint8)
    self._simulator.get_screenshot.return_value = self._random_screenshot
    self._task_manager = mock.create_autospec(task_manager.TaskManager)
    self._adb_call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
    # Any `adb_call_parser.AdbCallParser(...)` construction returns this mock.
    self.enter_context(
        mock.patch.object(
            adb_call_parser,
            'AdbCallParser',
            autospec=True,
            return_value=self._adb_call_parser))
    self._coordinator = coordinator_lib.Coordinator(
        simulator=self._simulator,
        task_manager=self._task_manager,
        device_settings=device_settings_lib.DeviceSettings(self._simulator),
    )

  def tearDown(self):
    super().tearDown()
    self._coordinator.close()

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_relaunch_simulator(self, unused_mock_sleep):
    relaunch_count = self._coordinator.stats()['relaunch_count']
    self._coordinator._launch_simulator()
    self.assertEqual(self._coordinator.stats()['relaunch_count'],
                     relaunch_count + 1)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_reset(self, unused_mock_sleep):
    """'relaunch_count_execute_action' should be zero if there's no error."""
    self._coordinator.rl_reset()
    stats = self._coordinator.stats()
    self.assertIn('relaunch_count_execute_action', stats)
    self.assertEqual(stats['relaunch_count_execute_action'], 0)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_reset_error_sending_action(self, unused_mock_sleep):
    """'relaunch_count_execute_action' should be positive if there's an error."""
    self._simulator.send_touch.side_effect = errors.SendActionError()
    self._coordinator.rl_reset()
    stats = self._coordinator.stats()
    self.assertIn('relaunch_count_execute_action', stats)
    self.assertEqual(stats['relaunch_count_execute_action'], 1)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_lift_all_fingers(self, unused_mock_sleep):
    self._coordinator = coordinator_lib.Coordinator(
        simulator=self._simulator,
        task_manager=self._task_manager,
        device_settings=device_settings_lib.DeviceSettings(self._simulator),
        config=config_classes.CoordinatorConfig(num_fingers=3),
    )
    self._coordinator.rl_reset()
    expected_actions = [
        # (x, y, is_down, identifier).
        (0, 0, False, 0),
        (0, 0, False, 1),
        (0, 0, False, 2),
    ]
    actual_actions = self._simulator.send_touch.call_args[0][0]
    for actual, expected in zip(actual_actions, expected_actions):
      np.testing.assert_array_equal(actual, expected)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_process_action(self, unused_mock_sleep):

    def fake_rl_step(simulator_signals):
      return dm_env.transition(
          reward=10.0,
          observation={
              'pixels': simulator_signals['pixels'],
              'orientation': simulator_signals['orientation'],
              'timedelta': simulator_signals['timedelta'],
              'extras': {
                  'extra': [0.0]
              }
          })

    self._task_manager.rl_step.side_effect = fake_rl_step
    timestep = self._coordinator.rl_step(
        agent_action={
            'action_type': np.array(action_type.ActionType.LIFT),
            'touch_position': np.array([0.5, 0.5]),
        })
    obs = timestep.observation
    self.assertEqual(obs['pixels'].shape, (800, 600, 3))
    np.testing.assert_equal(obs['orientation'],
                            np.array([0, 0, 0, 0], dtype=np.uint8))
    self.assertEqual(timestep.reward, 10.0)
    self.assertEqual(obs['extras'], {'extra': [0.0]})
    self.assertFalse(timestep.last())

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_process_action_error(self, unused_mock_sleep):
    self._simulator.get_screenshot.side_effect = errors.ReadObservationError()
    timestep = self._coordinator.rl_step(
        agent_action={
            'action_type': np.array(action_type.ActionType.LIFT),
            'touch_position': np.array([0.5, 0.5]),
        })
    self.assertIsNone(timestep.observation)
    self.assertEqual(timestep.reward, 0.0)
    self.assertTrue(timestep.last())
    # A failed observation marks the simulator unhealthy, so the Coordinator
    # returns a truncation itself without consulting the task manager.
    self._task_manager.rl_step.assert_not_called()

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_action_touch(self, unused_mock_sleep):

    def fake_rl_step(simulator_signals):
      return dm_env.transition(
          reward=123.0,
          observation={
              'pixels': simulator_signals['pixels'],
              'orientation': simulator_signals['orientation'],
              'timedelta': simulator_signals['timedelta'],
              'extras': {
                  'extra': [0.0]
              }
          })

    self._task_manager.rl_step.side_effect = fake_rl_step
    timestep = self._coordinator.rl_step(
        agent_action={
            'action_type': np.array(action_type.ActionType.TOUCH),
            'touch_position': np.array([0.5, 0.5])
        })
    self.assertEqual(timestep.reward, 123.0)
    np.testing.assert_equal(timestep.observation['pixels'],
                            self._random_screenshot)
    self._simulator.send_touch.assert_called_once_with([(300, 400, True, 0)])

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_multitouch_action(self, unused_mock_sleep):
    self._coordinator = coordinator_lib.Coordinator(
        simulator=self._simulator,
        task_manager=self._task_manager,
        device_settings=device_settings_lib.DeviceSettings(self._simulator),
        config=config_classes.CoordinatorConfig(num_fingers=3),
    )

    def fake_rl_step(simulator_signals):
      return dm_env.transition(
          reward=456.0,
          observation={
              'pixels': simulator_signals['pixels'],
              'orientation': simulator_signals['orientation'],
              'timedelta': simulator_signals['timedelta'],
              'extras': {
                  'extra': [0.0]
              }
          })

    self._task_manager.rl_step.side_effect = fake_rl_step
    action = {
        'action_type': np.array([action_type.ActionType.TOUCH]),
        'touch_position': np.array([0.25, 0.75]),
        'action_type_2': np.array([action_type.ActionType.TOUCH]),
        'touch_position_2': np.array([0.75, 0.25]),
        'action_type_3': np.array([action_type.ActionType.LIFT]),
        'touch_position_3': np.array([0.5, 0.5]),
    }
    timestep = self._coordinator.rl_step(action)
    self._simulator.send_touch.assert_called_once_with([(150, 600, True, 0),
                                                        (450, 200, True, 1),
                                                        (300, 400, False, 2)])
    self.assertEqual(timestep.reward, 456.0)
    np.testing.assert_equal(timestep.observation['pixels'],
                            self._random_screenshot)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_action_repeat(self, unused_mock_sleep):

    def fake_rl_step(simulator_signals):
      return dm_env.transition(
          reward=10.0,
          observation={
              'pixels': simulator_signals['pixels'],
              'orientation': simulator_signals['orientation'],
              'timedelta': simulator_signals['timedelta'],
              'extras': {
                  'extra': [0.0]
              }
          })

    self._task_manager.rl_step.side_effect = fake_rl_step
    timestep = self._coordinator.rl_step(
        {'action_type': np.array(action_type.ActionType.REPEAT)})
    self._simulator.send_touch.assert_not_called()
    np.testing.assert_equal(timestep.observation['pixels'],
                            self._random_screenshot)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_action_error(self, unused_mock_sleep):
    self._simulator.send_touch.side_effect = errors.SendActionError
    timestep = self._coordinator.rl_step({
        'action_type': np.array(action_type.ActionType.TOUCH),
        'touch_position': np.array([0.3, 0.8])
    })
    self.assertIsNone(timestep.observation)
    # A failed action marks the simulator unhealthy, so the Coordinator
    # returns a truncation itself without consulting the task manager.
    self._task_manager.rl_step.assert_not_called()

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_max_restarts_setup_steps(self, unused_mock_sleep):
    init_fn_call = self._task_manager.setup_task.call_count
    self._task_manager.setup_task.side_effect = errors.StepCommandError
    self.assertRaises(errors.TooManyRestartsError,
                      self._coordinator._launch_simulator)
    # The method was called three more times when attempting to relaunch.
    self.assertEqual(init_fn_call + 3,
                     self._task_manager.setup_task.call_count)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_execute_adb_call(self, unused_mock_sleep):
    call = adb_pb2.AdbRequest(
        force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah'))
    expected_response = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    self._adb_call_parser.parse.side_effect = [expected_response]

    response = self._coordinator.execute_adb_call(call)

    self.assertEqual(response, expected_response)
    self._adb_call_parser.parse.assert_called_with(call)
# Run the tests when this file is executed directly.
if __name__ == '__main__':
absltest.main()
================================================
FILE: android_env/components/device_settings.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Sets and gets some global settings on an Android device."""
from typing import Final
from unittest import mock
from absl import logging
from android_env.components import adb_call_parser
from android_env.components import config_classes
from android_env.components.simulators import base_simulator
from android_env.proto import adb_pb2
import numpy as np
# Sentinel standing in for a not-yet-created `AdbCallParser`.
#
# `DeviceSettings` builds its parser lazily. Typing the attribute as
# `AdbCallParser | None` would make pytype demand a `None` check (or a
# suppression) at every call site, even though the public API always creates
# the real parser before any of those sites is reached. Instead, the attribute
# starts out as this autospec'd object of the right type. `DeviceSettings`
# compares it by identity and swaps in a real parser before issuing requests.
_PLACEHOLDER_ADB_CALL_PARSER: Final[adb_call_parser.AdbCallParser] = (
    mock.create_autospec(adb_call_parser.AdbCallParser)
)
class DeviceSettings:
  """An abstraction for general properties and settings of an Android device."""

  def __init__(self, simulator: base_simulator.BaseSimulator):
    """Creates the settings object without contacting the device.

    The ADB call parser is created lazily, on the first call to `update()` or
    `get_orientation()`.

    Args:
      simulator: Provides screenshots and the ADB controller used to read and
        write device settings.
    """
    self._simulator = simulator
    # Replaced by a real parser on first use; see `_ensure_adb_call_parser()`.
    self._adb_call_parser = _PLACEHOLDER_ADB_CALL_PARSER
    # The size of the device screen in pixels.
    self._screen_width: int = 0
    self._screen_height: int = 0
    # The device orientation as a one-hot vector. All zeros means it has not
    # been fetched successfully yet.
    self._orientation = np.zeros(4, dtype=np.uint8)

  def update(self, config: config_classes.DeviceSettingsConfig) -> None:
    """Sets the configuration of the device according to `config`.

    Also refreshes the cached screen size from a fresh screenshot.

    Args:
      config: Which touch/pointer indicators and system bars to show.
    """
    self._ensure_adb_call_parser()
    self._update_screen_size()
    self._set_show_touches(config.show_touches)
    self._set_show_pointer_location(config.show_pointer_location)
    self._set_status_navigation_bars(
        config.show_navigation_bar, config.show_status_bar
    )

  def screen_width(self) -> int:
    """The screen width in pixels. Only valid after `update()` is called."""
    return self._screen_width

  def screen_height(self) -> int:
    """The screen height in pixels. Only valid after `update()` is called."""
    return self._screen_height

  def get_orientation(self) -> np.ndarray:
    """Returns the device orientation. Please see specs.py for details.

    The device is queried until a valid orientation is obtained. After that,
    the cached value is returned without querying again.

    Returns:
      A one-hot `uint8` vector of length 4, or all zeros if the orientation
      could not be determined. The array is a copy, so callers may modify it.
    """
    self._ensure_adb_call_parser()
    self._update_orientation()
    # `_update_orientation()` inspects the cached array to decide whether to
    # re-query the device, so never hand out the cached object itself.
    return self._orientation.copy()

  def _ensure_adb_call_parser(self) -> None:
    """Replaces the placeholder with a real `AdbCallParser` on first use."""
    if self._adb_call_parser is _PLACEHOLDER_ADB_CALL_PARSER:
      self._adb_call_parser = adb_call_parser.AdbCallParser(
          adb_controller=self._simulator.create_adb_controller()
      )

  def _update_screen_size(self) -> None:
    """Sets the screen size from a screenshot ignoring the color channel."""
    screenshot = self._simulator.get_screenshot()
    self._screen_height = screenshot.shape[0]
    self._screen_width = screenshot.shape[1]

  def _put_setting(self, name_space: int, key: str, value: str) -> None:
    """Writes `key=value` into the settings namespace `name_space` via ADB.

    A non-OK response is logged as a warning instead of being raised, so the
    remaining settings are still applied.

    Args:
      name_space: An `adb_pb2.AdbRequest.SettingsRequest.Namespace` value.
      key: The setting's name.
      value: The value to store, as a string.
    """
    response = self._adb_call_parser.parse(
        adb_pb2.AdbRequest(
            settings=adb_pb2.AdbRequest.SettingsRequest(
                name_space=name_space,
                put=adb_pb2.AdbRequest.SettingsRequest.Put(
                    key=key, value=value
                ),
            )
        )
    )
    if response.status != adb_pb2.AdbResponse.Status.OK:
      logging.warning('Failed to set %r to %r: %r', key, value, response)

  def _set_show_touches(self, show: bool) -> None:
    """Whether to display circles indicating the touch position."""
    self._put_setting(
        adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
        'show_touches',
        '1' if show else '0',
    )

  def _set_show_pointer_location(self, show: bool) -> None:
    """Whether to display blue lines on the screen indicating touch position."""
    self._put_setting(
        adb_pb2.AdbRequest.SettingsRequest.Namespace.SYSTEM,
        'pointer_location',
        '1' if show else '0',
    )

  def _set_status_navigation_bars(
      self, show_navigation: bool, show_status: bool
  ) -> None:
    """Whether to display the status (top) and navigation (bottom) bars."""
    # Each `immersive.*` value names the bar(s) to hide; 'null*' hides none.
    if show_navigation and show_status:
      policy_control_value = 'null*'
    elif show_navigation:
      policy_control_value = 'immersive.status=*'
    elif show_status:
      policy_control_value = 'immersive.navigation=*'
    else:
      policy_control_value = 'immersive.full=*'

    self._put_setting(
        adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
        'policy_control',
        policy_control_value,
    )

  def _update_orientation(self) -> None:
    """Fetches the device orientation unless a valid one is already cached."""
    # Skip fetching the orientation if we already have it (any nonzero entry).
    if self._orientation.any():
      return

    orientation_response = self._adb_call_parser.parse(
        adb_pb2.AdbRequest(
            get_orientation=adb_pb2.AdbRequest.GetOrientationRequest()
        )
    )
    if orientation_response.status != adb_pb2.AdbResponse.Status.OK:
      logging.error('Got bad orientation response: %r', orientation_response)
      return

    orientation = orientation_response.get_orientation.orientation
    if orientation not in {0, 1, 2, 3}:
      logging.error('Got bad orientation: %r', orientation)
      return

    # Transform into one-hot format.
    orientation_onehot = np.zeros([4], dtype=np.uint8)
    orientation_onehot[orientation] = 1
    self._orientation = orientation_onehot
================================================
FILE: android_env/components/device_settings_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import config_classes
from android_env.components import device_settings as device_settings_lib
from android_env.components.simulators import base_simulator
import numpy as np
class DeviceSettingsTest(parameterized.TestCase):

  def _make_simulator(self, adb_output: bytes | None = None):
    """Returns an autospec'd simulator and the ADB controller it hands out.

    Args:
      adb_output: If not None, the value the controller's `execute_command()`
        returns.
    """
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    adb_controller = simulator.create_adb_controller.return_value
    if adb_output is not None:
      adb_controller.execute_command.return_value = adb_output
    return simulator, adb_controller

  def test_screen_size_before_update(self):
    """The screen size should be 0x0 without calling `update()`."""
    simulator, _ = self._make_simulator()
    device_settings = device_settings_lib.DeviceSettings(simulator)

    self.assertEqual(device_settings.screen_height(), 0)
    self.assertEqual(device_settings.screen_width(), 0)

  def test_screen_size_after_update(self):
    """The screen size should be set after calling `update()`."""
    simulator, _ = self._make_simulator(adb_output=b'')
    simulator.get_screenshot.return_value = np.random.randint(
        low=0, high=255, size=(123, 456, 3), dtype=np.uint8
    )
    device_settings = device_settings_lib.DeviceSettings(simulator)

    device_settings.update(config_classes.DeviceSettingsConfig())

    self.assertEqual(device_settings.screen_height(), 123)
    self.assertEqual(device_settings.screen_width(), 456)

  @parameterized.named_parameters(
      (
          'show_touches',
          config_classes.DeviceSettingsConfig(show_touches=True),
          mock.call(
              ['shell', 'settings', 'put', 'system', 'show_touches', '1'],
              timeout=None,
          ),
      ),
      (
          'show_touches_false',
          config_classes.DeviceSettingsConfig(show_touches=False),
          mock.call(
              ['shell', 'settings', 'put', 'system', 'show_touches', '0'],
              timeout=None,
          ),
      ),
      (
          'show_pointer_location',
          config_classes.DeviceSettingsConfig(show_pointer_location=True),
          mock.call(
              ['shell', 'settings', 'put', 'system', 'pointer_location', '1'],
              timeout=None,
          ),
      ),
      (
          'show_pointer_location_false',
          config_classes.DeviceSettingsConfig(show_pointer_location=False),
          mock.call(
              ['shell', 'settings', 'put', 'system', 'pointer_location', '0'],
              timeout=None,
          ),
      ),
      (
          'show_navigation_and_status',
          config_classes.DeviceSettingsConfig(
              show_navigation_bar=True, show_status_bar=True
          ),
          mock.call(
              ['shell', 'settings', 'put', 'global', 'policy_control', 'null*'],
              timeout=None,
          ),
      ),
      (
          'show_navigation_and_no_status',
          config_classes.DeviceSettingsConfig(
              show_navigation_bar=True, show_status_bar=False
          ),
          mock.call(
              ['shell', 'settings', 'put', 'global', 'policy_control',
               'immersive.status=*'],
              timeout=None,
          ),
      ),
      (
          'show_no_navigation_and_status',
          config_classes.DeviceSettingsConfig(
              show_navigation_bar=False, show_status_bar=True
          ),
          mock.call(
              ['shell', 'settings', 'put', 'global', 'policy_control',
               'immersive.navigation=*'],
              timeout=None,
          ),
      ),
      (
          'show_no_navigation_and_no_status',
          config_classes.DeviceSettingsConfig(
              show_navigation_bar=False, show_status_bar=False
          ),
          mock.call(
              ['shell', 'settings', 'put', 'global', 'policy_control',
               'immersive.full=*'],
              timeout=None,
          ),
      ),
  )
  def test_update(
      self, settings: config_classes.DeviceSettingsConfig, expected_call
  ):
    """We expect the right call for each setting."""
    simulator, adb_controller = self._make_simulator(adb_output=b'')

    device_settings_lib.DeviceSettings(simulator).update(settings)

    adb_controller.execute_command.assert_has_calls(
        [expected_call], any_order=True
    )

  def test_get_orientation_bad_response(self):
    """The orientation should be unset if the underlying response is bad."""
    simulator, _ = self._make_simulator(adb_output=b'')
    device_settings = device_settings_lib.DeviceSettings(simulator)

    np.testing.assert_array_equal(
        device_settings.get_orientation(), np.zeros(4)
    )

  def test_get_orientation_bad_orientation(self):
    """The orientation should be unset if the underlying orientation is bad."""
    simulator, _ = self._make_simulator(
        adb_output=b' InputDeviceOrientation: 9'
    )
    device_settings = device_settings_lib.DeviceSettings(simulator)

    np.testing.assert_array_equal(
        device_settings.get_orientation(), np.zeros(4)
    )

  def test_get_orientation_success(self):
    """Checks that the orientation comes back as expected."""
    simulator, _ = self._make_simulator(
        adb_output=b' InputDeviceOrientation: 3'
    )
    device_settings = device_settings_lib.DeviceSettings(simulator)

    first = device_settings.get_orientation()
    # Asking again without any device change must give the same answer.
    second = device_settings.get_orientation()

    np.testing.assert_array_equal(first, np.array([0, 0, 0, 1]))
    np.testing.assert_array_equal(first, second)
if __name__ == '__main__':
  # Run this module's test cases when executed directly.
  absltest.main()
================================================
FILE: android_env/components/dumpsys_thread.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A ThreadFunction that runs and parses adb dumpsys."""
import concurrent.futures
from absl import logging
from android_env.components import app_screen_checker as app_screen_checker_lib
# Shorthand for the outcomes returned by
# `AppScreenChecker.matches_current_app_screen()`.
_Outcome = app_screen_checker_lib.AppScreenChecker.Outcome
class DumpsysThread:
  """A thread that checks if the user is in the expected app screen."""

  def __init__(
      self,
      app_screen_checker: app_screen_checker_lib.AppScreenChecker,
      check_frequency: int = 10,
      max_failed_current_activity: int = 10,
  ):
    """Initializes the dumpsys reader thread.

    Checking whether the user is in the expected screen dictated by
    `app_screen_checker` is too expensive to do on every AndroidEnv::step().
    The check therefore runs only on every `check_frequency`-th call, in a
    worker thread.

    Args:
      app_screen_checker: The class that actually determines if the current
        screen matches the expected screen.
      check_frequency: Integer. Only every `check_frequency`-th call to
        `check_user_exited()` does any work; the others return False
        immediately. Values <= 0 disable checking entirely.
      max_failed_current_activity: Integer. We try to fetch the current activity
        but sometimes it fails. Once it has failed
        `max_failed_current_activity` consecutive times, we declare that the
        user has exited `expected_activity`.
    """
    self._app_screen_checker = app_screen_checker
    self._main_loop_counter = 0
    self._check_frequency = check_frequency
    self._max_failed_activity_extraction = max_failed_current_activity
    self._num_failed_activity_extraction = 0
    # The most recently started check whose result has not been consumed yet.
    self._latest_check: concurrent.futures.Future | None = None

  def check_user_exited(self, timeout: float | None = None) -> bool:
    """Returns True if the user is not in the expected screen.

    Args:
      timeout: An optional time in seconds to block waiting for the result of
        the (expensive) checking operation. If falsy (None or 0) and the
        check is still running, returns `False` without waiting.

    Returns:
      Whether the user of the Android device has exited the expected screen
      determined by `AppScreenChecker` given at __init__().

    Raises:
      concurrent.futures.TimeoutError: If `timeout` elapses before the pending
        check finishes. The check is kept, and its result is collected by a
        later call.
      Exception: Whatever the check itself raised. The failed check is then
        discarded, so the next due call starts a fresh one instead of
        re-raising the same error forever.
    """
    # Update and check loop_counter against check_frequency.
    self._main_loop_counter += 1
    if (self._check_frequency <= 0 or
        self._main_loop_counter < self._check_frequency):
      return False
    self._main_loop_counter = 0

    # If the latest check is None, perform a check and return.
    if self._latest_check is None:
      # NOTE(review): leaving this `with` block calls
      # `executor.shutdown(wait=True)`, which waits for `_check_impl()` to
      # finish, so the check is effectively synchronous for the caller.
      # `test_skipped` in dumpsys_thread_test.py relies on that. Making it
      # truly asynchronous needs a long-lived executor plus a test update.
      with concurrent.futures.ThreadPoolExecutor() as executor:
        self._latest_check = executor.submit(self._check_impl)
      return False

    check = self._latest_check
    # If there's a check in flight and we may not wait, try again later.
    if not timeout and not check.done():
      return False
    try:
      return check.result(timeout=timeout)
    finally:
      # Forget the check once it has completed, whether it returned or raised.
      # A check that merely timed out is still pending and is kept.
      if check.done():
        self._latest_check = None

  def _check_impl(self) -> bool:
    """Runs one (expensive) screen check; returns whether the user exited."""
    outcome = self._app_screen_checker.matches_current_app_screen()

    # We were unable to determine the current activity. Tolerate a limited
    # number of consecutive failures before declaring that the user exited.
    if outcome == _Outcome.FAILED_ACTIVITY_EXTRACTION:
      self._num_failed_activity_extraction += 1
      logging.info('self._num_failed_activity_extraction: %s',
                   self._num_failed_activity_extraction)
      if (self._num_failed_activity_extraction >=
          self._max_failed_activity_extraction):
        logging.error('Maximum number of failed activity extraction reached.')
        self._num_failed_activity_extraction = 0
        return True
      return False

    self._num_failed_activity_extraction = 0
    # The player has exited the app (UNEXPECTED_ACTIVITY) or the main game
    # (UNEXPECTED_VIEW_HIERARCHY): terminate the episode. Every other outcome,
    # including SUCCESS and EMPTY_EXPECTED_ACTIVITY, means "not exited".
    return outcome in (
        _Outcome.UNEXPECTED_ACTIVITY,
        _Outcome.UNEXPECTED_VIEW_HIERARCHY,
    )
================================================
FILE: android_env/components/dumpsys_thread_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.components.dumpsys_thread."""
from unittest import mock
from absl.testing import absltest
from android_env.components import app_screen_checker as screen_checker
from android_env.components import dumpsys_thread
class DumpsysThreadTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self._app_screen_checker = mock.create_autospec(
        screen_checker.AppScreenChecker)

  def _dumpsys(self, check_frequency, outcome=None):
    """Builds a `DumpsysThread` around the mocked screen checker.

    Args:
      check_frequency: Forwarded to `DumpsysThread`.
      outcome: If not None, what the mocked checker reports on every call.
    """
    if outcome is not None:
      self._app_screen_checker.matches_current_app_screen.return_value = (
          outcome)
    return dumpsys_thread.DumpsysThread(
        app_screen_checker=self._app_screen_checker,
        check_frequency=check_frequency)

  def test_unexpected_activity(self):
    dumpsys = self._dumpsys(
        check_frequency=1,
        outcome=screen_checker.AppScreenChecker.Outcome.UNEXPECTED_ACTIVITY)
    # The first call only kicks off the check, so it reports False.
    self.assertFalse(dumpsys.check_user_exited(timeout=1.0))
    # The second call collects that check's result.
    self.assertTrue(dumpsys.check_user_exited(timeout=1.0))

  def test_unexpected_view_hierarchy(self):
    dumpsys = self._dumpsys(
        check_frequency=1,
        outcome=(
            screen_checker.AppScreenChecker.Outcome.UNEXPECTED_VIEW_HIERARCHY))
    self.assertFalse(dumpsys.check_user_exited(timeout=1.0))
    self.assertTrue(dumpsys.check_user_exited(timeout=1.0))

  def test_success(self):
    dumpsys = self._dumpsys(
        check_frequency=1,
        outcome=screen_checker.AppScreenChecker.Outcome.SUCCESS)
    self.assertFalse(dumpsys.check_user_exited(timeout=1.0))
    self.assertFalse(dumpsys.check_user_exited(timeout=1.0))

  def test_skipped(self):
    outcomes = screen_checker.AppScreenChecker.Outcome
    dumpsys = self._dumpsys(check_frequency=5)
    self._app_screen_checker.matches_current_app_screen.side_effect = [
        outcomes.SUCCESS,
        outcomes.FAILED_ACTIVITY_EXTRACTION,
    ]

    for _ in range(17):
      self.assertFalse(dumpsys.check_user_exited(timeout=1.0))

    # Timeline of the 17 calls:
    #   1-4:   early exit because of `check_frequency`.
    #   5:     starts a check (1st call to matches_current_app_screen()).
    #   6-9:   early exit.
    #   10:    collects the SUCCESS result of that check.
    #   11-14: early exit.
    #   15:    starts another check (2nd call to matches_current_app_screen()).
    #   16-17: early exit.
    self.assertEqual(
        self._app_screen_checker.matches_current_app_screen.call_count, 2)
if __name__ == '__main__':
  # Run this module's test cases when executed directly.
  absltest.main()
================================================
FILE: android_env/components/errors.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definitions of exceptions used by AndroidEnv."""
class AndroidEnvError(Exception):
  """Root of every error that AndroidEnv knowingly raises.

  Each subclass overrides `ERROR_CODE` with its own integer, which lets
  `from_code()` rebuild the matching error from that number.
  """

  # Unique per error class; subclasses must pick a different value.
  ERROR_CODE: int = 0


class ReadObservationError(AndroidEnvError):
  """Raised when no observation could be read from the simulator."""

  ERROR_CODE = 1


class CoordinatorError(AndroidEnvError):
  """Raised by the Coordinator."""

  ERROR_CODE = 2


class TooManyRestartsError(CoordinatorError):
  """Raised once the restart count goes beyond _MAX_RESTART_TRIES."""

  ERROR_CODE = 3


class AdbControllerError(AndroidEnvError):
  """Raised by ADBController."""

  ERROR_CODE = 4


class SimulatorError(AndroidEnvError):
  """Raised by a simulator."""

  ERROR_CODE = 5


class SendActionError(AndroidEnvError):
  """Raised when an action could not be delivered."""

  ERROR_CODE = 6


class StepCommandError(AndroidEnvError):
  """Raised when the setup step interpreter fails to process a command."""

  ERROR_CODE = 7


class WaitForAppScreenError(StepCommandError):
  """Raised when the wait_for_app_screen success condition is not met."""

  ERROR_CODE = 8


class CheckInstallError(StepCommandError):
  """Raised when the check_install success condition is not met."""

  ERROR_CODE = 9
def from_code(code: int, msg: str = '') -> AndroidEnvError | None:
  """Returns an AndroidEnvError instance from the given arguments.

  Args:
    code: The `ERROR_CODE` of one of the errors defined in this module.
    msg: The message the error is constructed with.

  Returns:
    An instance of the error class whose `ERROR_CODE` equals `code`, or None
    if no known error uses that code.
  """
  known_errors = (
      AndroidEnvError,
      ReadObservationError,
      CoordinatorError,
      TooManyRestartsError,
      AdbControllerError,
      SimulatorError,
      SendActionError,
      StepCommandError,
      WaitForAppScreenError,
      CheckInstallError,
  )
  # Key the table by each class's own `ERROR_CODE` so the mapping can never
  # drift from the codes declared on the classes.
  code_to_error = {error.ERROR_CODE: error for error in known_errors}
  error_class = code_to_error.get(code)
  if error_class is None:
    return None
  return error_class(msg)
================================================
FILE: android_env/components/errors_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for errors.py."""
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import errors
class ErrorsTest(parameterized.TestCase):

  @parameterized.parameters(
      (errors.ReadObservationError, 1),
      (errors.CoordinatorError, 2),
      (errors.TooManyRestartsError, 3),
      (errors.AdbControllerError, 4),
      (errors.SimulatorError, 5),
      (errors.SendActionError, 6),
      (errors.StepCommandError, 7),
      (errors.WaitForAppScreenError, 8),
      (errors.CheckInstallError, 9),
  )
  def test_error_codes(self, error, expected_error_code):
    with self.assertRaises(error) as context:
      raise error()
    self.assertEqual(context.exception.ERROR_CODE, expected_error_code)

  def test_error_codes_unique(self):
    error_codes = set()
    errors_list = (
        errors.AndroidEnvError,
        errors.ReadObservationError,
        errors.CoordinatorError,
        errors.TooManyRestartsError,
        errors.AdbControllerError,
        errors.SimulatorError,
        errors.SendActionError,
        errors.StepCommandError,
        errors.WaitForAppScreenError,
        errors.CheckInstallError,
    )
    for error in errors_list:
      self.assertNotIn(error.ERROR_CODE, error_codes)
      error_codes.add(error.ERROR_CODE)

  @parameterized.parameters([
      errors.ReadObservationError(),
      errors.CoordinatorError(),
      errors.TooManyRestartsError(),
      errors.AdbControllerError(),
      errors.SimulatorError(),
      errors.SendActionError(),
      errors.StepCommandError(),
      errors.WaitForAppScreenError(),
      errors.CheckInstallError(),
  ])
  def test_all_errors_are_androidenv_errors(self, error):
    self.assertIsInstance(error, errors.AndroidEnvError)

  @parameterized.named_parameters([
      ('less_than_zero', -1),
      # The largest `ERROR_CODE` is currently `CheckInstallError == 9`, so this
      # is the smallest unsupported positive code.
      ('greater_than_all_errors', 9 + 1),
      ('less_than_zero_float', -3.14159265),
      ('greater_than_all_errors_float', 123.456),
  ])
  def test_from_code_unsupported_code(self, code: float):
    """Unsupported error codes should make `from_code` return `None`."""
    self.assertIsNone(errors.from_code(code))

  @parameterized.parameters([
      (-1, None, 'No such error code.'),
      (0, errors.AndroidEnvError, 'hello'),
      (0, errors.AndroidEnvError, ''),
      (1, errors.ReadObservationError, 'Could not read obs.'),
      (2, errors.CoordinatorError, 'Some error'),
      (3, errors.TooManyRestartsError, 'Too many already...'),
      (4, errors.AdbControllerError, 'Some adb error...'),
      (5, errors.SimulatorError, 'Simulator is not coping.'),
      (6, errors.SendActionError, 'Could not send action.'),
      (7, errors.StepCommandError, 'Some issue setting up the task.'),
      (8, errors.WaitForAppScreenError, 'Waited for too long!'),
      (9, errors.CheckInstallError, 'App did not install correctly.'),
  ])
  def test_from_code(
      self,
      code: int,
      expected_class: type[errors.AndroidEnvError] | None,
      msg: str,
  ):
    """`from_code` should produce consistent outputs for known errors."""
    error = errors.from_code(code, msg)
    if expected_class is None:
      self.assertIsNone(error)
      return
    # `assertIsInstance` also fails if `from_code` wrongly returned None, which
    # the previous `if error is not None:` guard silently accepted.
    self.assertIsInstance(error, expected_class)
    self.assertEqual(error.ERROR_CODE, code)
    self.assertEqual(str(error), msg)
if __name__ == '__main__':
  # Run this module's test cases when executed directly.
  absltest.main()
================================================
FILE: android_env/components/log_stream.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract class for handling a stream of logs from a simulator."""
import abc
from collections.abc import Generator, Sequence
import threading
from absl import logging
class LogStream(metaclass=abc.ABCMeta):
  """Manages the stream of logs output by a simulator.

  Concrete subclasses produce the raw lines in `_get_stream_output()` and shut
  the source down in `stop_stream()`. This base class adds optional per-line
  logging and a pause/resume gate. The gate starts closed, so nothing is
  yielded until `resume_stream()` is called.
  """

  def __init__(self, verbose: bool = False):
    self._verbose = verbose
    # Filter specs for subclasses to honor; populated by `set_log_filters()`.
    self._filters = []
    # While this event is clear, lines are read from the source but dropped.
    self._should_stream = threading.Event()

  def get_stream_output(self) -> Generator[str, None, None]:
    """Starts log process and returns the stream of logs.

    Lines read while the stream is paused are discarded, not buffered.
    """
    for line in self._get_stream_output():
      if self._verbose:
        logging.info('line: %r', line)
      if not self._should_stream.is_set():
        continue
      yield line

  @abc.abstractmethod
  def _get_stream_output(self):
    """Starts log process and returns the stream of logs."""

  @abc.abstractmethod
  def stop_stream(self) -> None:
    """Terminates the log stream.

    NOTE: This should only be called _after_ `get_stream_output()`.
    """

  def pause_stream(self) -> None:
    """Closes the gate: no lines are yielded until the stream is resumed."""
    logging.info('Pausing LogStream.')
    self._should_stream.clear()

  def resume_stream(self) -> None:
    """Opens the gate so `get_stream_output()` yields lines again."""
    logging.info('Resuming LogStream.')
    self._should_stream.set()

  def set_log_filters(self, log_filters: Sequence[str]):
    """Sets the filters for the log stream.

    A trailing `'*:S'` is always appended after the given filters.
    """
    self._filters = [*log_filters, '*:S']
================================================
FILE: android_env/components/log_stream_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for log_stream."""
from absl.testing import absltest
from android_env.components import log_stream
class FakeLogStream(log_stream.LogStream):
  """An in-memory `LogStream` that emits six canned lines.

  When `'<filter_name>:V'` is among the active filters, only lines containing
  `filter_name` are emitted; otherwise all six are.
  """

  def __init__(self, filter_name: str):
    super().__init__()
    self._filter_name = filter_name

  def _get_stream_output(self):
    """Yields the canned lines, honoring the verbose filter if it is set."""
    canned_lines = (
        f'{self._filter_name} fake_line_1',
        'fake_line_2',
        f'{self._filter_name} fake_line_3',
        f'{self._filter_name} fake_line_4',
        'fake_line_5',
        'fake_line_6',
    )
    for canned_line in canned_lines:
      filtering = f'{self._filter_name}:V' in self._filters
      if filtering and self._filter_name not in canned_line:
        continue
      yield canned_line

  def stop_stream(self):
    """Nothing to stop: the canned lines need no external process."""
class LogStreamTest(absltest.TestCase):
  # Every test compares complete lists. The previous `zip()`-based loops
  # stopped at the shorter sequence, so a stream that yielded too few (or no)
  # lines still passed.

  def _all_lines(self, filter_name: str) -> list[str]:
    """Returns every line `FakeLogStream` emits when no filter is active."""
    return [
        f'{filter_name} fake_line_1',
        'fake_line_2',
        f'{filter_name} fake_line_3',
        f'{filter_name} fake_line_4',
        'fake_line_5',
        'fake_line_6',
    ]

  def test_get_stream_output(self):
    filter_name = 'AndroidRLTask'
    stream = FakeLogStream(filter_name=filter_name)
    stream.resume_stream()

    self.assertEqual(
        list(stream.get_stream_output()), self._all_lines(filter_name))

  def test_set_log_filters(self):
    filter_name = 'AndroidRLTask'
    stream = FakeLogStream(filter_name=filter_name)
    stream.set_log_filters([f'{filter_name}:V'])
    stream.resume_stream()

    expected_lines = [
        f'{filter_name} fake_line_1',
        f'{filter_name} fake_line_3',
        f'{filter_name} fake_line_4',
    ]
    self.assertEqual(list(stream.get_stream_output()), expected_lines)

  def test_pause_resume_stream(self):
    filter_name = 'AndroidRLTask'
    expected_lines = self._all_lines(filter_name)
    stream = FakeLogStream(filter_name=filter_name)
    stream.resume_stream()
    self.assertEqual(list(stream.get_stream_output()), expected_lines)

    # If the stream is paused, we expect no lines to be yielded.
    stream.pause_stream()
    self.assertEmpty(list(stream.get_stream_output()))

    # If the stream is resumed, we expect to see all lines yielded.
    stream.resume_stream()
    self.assertEqual(list(stream.get_stream_output()), expected_lines)
if __name__ == '__main__':
  # Run this module's test cases when executed directly.
  absltest.main()
================================================
FILE: android_env/components/logcat_thread.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class that launches a thread to read Android log outputs."""
from collections.abc import Callable
import dataclasses
import re
import threading
from absl import logging
from android_env.components import log_stream as log_stream_lib
@dataclasses.dataclass
class EventListener:
  """Pairs a log-line pattern with the callback to run when it fires.

  Attributes:
    regexp: Compiled pattern identifying the event.
    handler_fn: Callback invoked with the pattern and the resulting match
      object.
  """

  regexp: re.Pattern[str]
  handler_fn: Callable[[re.Pattern[str], re.Match[str]], None]
class LogcatThread:
  """Reads log entries in a separate thread.

  Each non-empty line yielded by the log stream is parsed as a logcat line.
  Its message is matched against every registered event regexp, and the
  handlers of each matching regexp are called from the reader thread.
  """

  # pylint: disable=g-line-too-long
  # Format is: "TIME_SEC PID TID PRIORITY TAG: MESSAGE"
  #
  # Example:
  # ' 1553110400.424 5583 5658 D NostalgicRacer: com.google.example.games.nostalgicracer.views.renderers.OpenGLRenderDriver@912fb8.onSurfaceChanged 480x320'
  # pylint: enable=g-line-too-long
  _LOGLINE_RE = re.compile(
      r"""
      ^                                   # Beginning of the line.
      [ ]+(?P<timestamp>[0-9]+\.[0-9]+)   # Spaces and a float.
      [ ]+(?P<pid>[0-9]+)                 # Spaces and an int.
      [ ]+(?P<tid>[0-9]+)                 # Spaces and an int.
      [ ]+(?P<priority>.)                 # Spaces and any single character.
      [ ]+(?P<tag>[^:]*):                 # Spaces and any char that's not ':'.
      [ ](?P<message>.*)$                 # The actual log message.
      """,
      re.VERBOSE,
  )

  def __init__(self, log_stream: log_stream_lib.LogStream):
    """Initializes this LogcatThread and starts reading `log_stream`.

    Please see https://developer.android.com/studio/command-line/logcat for more
    info on `logcat`.

    Args:
      log_stream: Stream of logs from simulator.
    """
    self._log_stream = log_stream
    # Maps each event regexp to the handlers registered for it. It is mutated
    # by callers of add/remove_event_listener() while the reader thread
    # iterates over it, so every access goes through `_listeners_lock`.
    self._listeners: dict[
        re.Pattern[str],
        list[Callable[[re.Pattern[str], re.Match[str]], None]],
    ] = {}
    self._listeners_lock = threading.Lock()
    # Cleared while handlers run for a line; set again once they all returned.
    self._line_ready = threading.Event()
    self._line_ready.set()
    self._should_stop = threading.Event()
    self._thread = self._start_thread()

  def _start_thread(self) -> threading.Thread:
    """Starts and returns a daemon thread running `_process_logs()`."""
    thread = threading.Thread(target=self._process_logs, daemon=True)
    thread.start()
    return thread

  def add_event_listener(self, event_listener: EventListener) -> None:
    """Registers `event_listener.handler_fn` for `event_listener.regexp`."""
    with self._listeners_lock:
      self._listeners.setdefault(event_listener.regexp, []).append(
          event_listener.handler_fn
      )

  def remove_event_listener(self, event_listener: EventListener) -> None:
    """Unregisters `event_listener.handler_fn` from `event_listener.regexp`.

    Logs an error and returns if the regexp or the handler is not registered.
    """
    event_regexp = event_listener.regexp
    with self._listeners_lock:
      handlers = self._listeners.get(event_regexp)
      if handlers is None:
        logging.error('Event: %r is not registered.', event_regexp)
        return
      try:
        handlers.remove(event_listener.handler_fn)
      except ValueError:
        logging.error(
            'Handler %r is not registered for event: %r.',
            event_listener.handler_fn,
            event_regexp,
        )
        return
      if not handlers:
        # Drop regexps without handlers so they are no longer matched.
        del self._listeners[event_regexp]

  def line_ready(self) -> threading.Event:
    """Indicates whether all listeners have been notified for a given line."""
    return self._line_ready

  def pause(self):
    """Pauses the underlying log stream."""
    self._log_stream.pause_stream()

  def resume(self):
    """Resume or restart the thread if it's dead after resetting environment."""
    if not self._thread.is_alive():
      self._should_stop.clear()
      self._thread = self._start_thread()
    self._log_stream.resume_stream()

  def kill(self):
    """Stops the reader thread and the log stream (waits up to 3 seconds)."""
    self._should_stop.set()
    self._log_stream.stop_stream()
    self._thread.join(timeout=3.0)

  def _process_logs(self) -> None:
    """A loop that runs until `self._should_stop` is set()."""
    for line in self._log_stream.get_stream_output():
      if self._should_stop.is_set():
        break
      if not line:  # Skip empty lines.
        continue
      matches = self._LOGLINE_RE.match(line)
      if not matches:
        continue

      # Make sure that values are not read until all listeners are notified.
      self._line_ready.clear()
      try:
        # We're currently only consuming `message`, but we may use the other
        # fields in the future.
        content = matches.group('message')
        # Snapshot under the lock so concurrent (un)registrations cannot
        # mutate the dict mid-iteration; handlers then run without the lock so
        # they may themselves add or remove listeners.
        with self._listeners_lock:
          snapshot = [
              (ev, list(handlers)) for ev, handlers in self._listeners.items()
          ]
        for ev, handlers in snapshot:
          ev_matches = ev.match(content)
          if ev_matches:
            for handler in handlers:  # Notify listeners.
              handler(ev, ev_matches)
      finally:
        # Always release waiters, even if a handler raised.
        self._line_ready.set()  # Allow consumers to read values.
================================================
FILE: android_env/components/logcat_thread_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.components.logcat_thread."""
import re
import threading
from absl.testing import absltest
from android_env.components import log_stream
from android_env.components import logcat_thread
from android_env.proto import task_pb2
class FakeStream:
  """This class simulates the logs coming from ADB.

  Values pushed with `send_value()` are yielded in FIFO order when iterating
  over the stream. Iteration blocks (without busy-waiting) while no value is
  pending, and ends as soon as `kill()` has been called.
  """

  def __init__(self):
    self._values = []
    self._kill = False
    self._lock = threading.Lock()
    # Notified whenever a value is added or the stream is killed so that
    # consumers can sleep instead of spinning.
    self._cond = threading.Condition(self._lock)

  def send_value(self, value):
    with self._cond:
      self._values.append(value)
      self._cond.notify_all()

  def has_next_value(self):
    with self._lock:
      return bool(self._values)

  def kill(self):
    with self._cond:
      self._kill = True
      self._cond.notify_all()

  def reset(self):
    with self._lock:
      self._kill = False

  def __iter__(self):
    while True:
      with self._cond:
        while not self._values and not self._kill:
          self._cond.wait()
        if self._kill:
          return
        next_value = self._values.pop(0)
      # Yield outside the lock so producers are never blocked by consumers.
      yield next_value
def make_stdout(data):
  """Builds a well-formed logcat line whose message part is `data`."""
  template = ' 1553110400.424 5583 5658 D Tag: %s'
  return template % data
class FakeLogStream(log_stream.LogStream):
  """A LogStream whose raw output is a FakeStream that tests feed directly."""

  def __init__(self):
    super().__init__(verbose=False)
    self.stream_is_alive = True
    self.logs = FakeStream()

  def _get_stream_output(self):
    """Hands the wrapped FakeStream to LogStream as its raw output."""
    return self.logs

  def stop_stream(self):
    """Marks the stream as dead and kills the wrapped FakeStream."""
    self.stream_is_alive = False
    self.logs.kill()

  def reset(self):
    """Revives the stream after a call to `stop_stream()`."""
    self.stream_is_alive = True
    self.logs.reset()
class LogcatThreadTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.fake_log_stream = FakeLogStream()

  def tearDown(self):
    self.fake_log_stream.stop_stream()
    super().tearDown()

  def test_set_filters(self):
    log_parsing_config = task_pb2.LogParsingConfig(filters=['AndroidRLTask:V'])
    self.fake_log_stream.set_log_filters(log_parsing_config.filters)
    _ = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    expected_filters = ['AndroidRLTask:V', '*:S']
    self.assertEqual(expected_filters, self.fake_log_stream._filters)

  def test_kill(self):
    logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    self.assertTrue(self.fake_log_stream.stream_is_alive)
    logcat.kill()
    self.assertFalse(self.fake_log_stream.stream_is_alive)

  def test_listeners(self):
    """Ensures that we can wait for a specific message without polling."""
    logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    # Start yielding lines from LogStream.
    logcat.resume()

    # Set up a listener that modifies an arbitrary state.
    some_state = threading.Event()

    def my_handler(event: re.Pattern[str], match: re.Match[str]):
      del event, match
      some_state.set()

    # Create a desired event and hook up the listener.
    my_event = re.compile('Hello world')
    listener = logcat_thread.EventListener(my_event, my_handler)
    logcat.add_event_listener(listener)

    # Neither a line that is not in logcat format nor a well-formed line with
    # a different message should trigger the listener. Wait briefly so that
    # the reader thread actually gets a chance to process both lines.
    self.fake_log_stream.logs.send_value('Hi there!')
    self.fake_log_stream.logs.send_value(make_stdout('Hi there!'))
    self.assertFalse(some_state.wait(timeout=0.2))

    self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
    self.assertTrue(some_state.wait(timeout=1.0))

    # Sending the same message again should trigger the listener again.
    some_state.clear()
    self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
    self.assertTrue(some_state.wait(timeout=1.0))

    # After removing the listener, it should not be called anymore.
    some_state.clear()
    logcat.remove_event_listener(listener)
    self.fake_log_stream.logs.send_value(make_stdout('Hello world'))
    self.assertFalse(some_state.wait(timeout=1.0))

  def test_resume_does_not_recreate_alive_thread(self):
    logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    thread_before = logcat._thread
    self.assertTrue(thread_before.is_alive())
    logcat.resume()
    thread_after = logcat._thread
    self.assertTrue(thread_after.is_alive())
    self.assertIs(thread_before, thread_after)

  def test_resume_recreates_thread(self):
    logcat = logcat_thread.LogcatThread(log_stream=self.fake_log_stream)
    self.assertTrue(logcat._thread.is_alive())
    logcat.kill()
    self.assertFalse(logcat._thread.is_alive())
    self.assertTrue(logcat._should_stop.is_set())
    self.fake_log_stream.reset()
    logcat.resume()
    self.assertTrue(logcat._thread.is_alive())
    self.assertFalse(logcat._should_stop.is_set())
    self.assertTrue(logcat._thread.daemon)
if __name__ == '__main__':
  # Runs this module's tests when the file is executed directly.
  absltest.main()
================================================
FILE: android_env/components/pixel_fns.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for AndroidEnv."""
from collections.abc import Sequence
from dm_env import specs
import numpy as np
def touch_position_to_pixel_position(
    touch_position: np.ndarray,
    width_height: Sequence[int],
) -> tuple[int, int]:
  """Maps touch position in [0,1] to the corresponding pixel on the screen.

  Args:
    touch_position: Normalized coordinates, expected in [0, 1] per axis.
    width_height: Screen size in pixels, one entry per axis of
      `touch_position`.

  Returns:
    The pixel coordinates as plain Python ints, clamped per axis to
    [0, size - 1] so that the result is always a valid pixel index.
  """
  sizes = np.asarray(width_height)
  touch_pixels = (np.asarray(touch_position) * sizes).astype(np.int32)
  # A position of exactly 1.0 lands one pixel past the edge, and slightly
  # out-of-range inputs (e.g. negative) would land off-screen: clamp both ends.
  clamped = np.clip(touch_pixels, 0, sizes - 1)
  return tuple(int(v) for v in clamped)
def transpose_pixels(frame: np.ndarray) -> np.ndarray:
  """Swaps the first two axes of `frame`.

  Converts an image from shape (H, W, C) to (W, H, C) and vice-versa. Any
  trailing axes are left untouched, so single-channel (H, W) frames are
  supported as well and become (W, H).

  Args:
    frame: Array with at least two axes.

  Returns:
    A view of `frame` with its first two axes exchanged.
  """
  return np.swapaxes(frame, 0, 1)
def orient_pixels(frame: np.ndarray, orientation: int) -> np.ndarray:
  """Rotates the screen pixels in `frame` to match `orientation`.

  Args:
    frame: Image whose first two axes are rotated.
    orientation: 0 (PORTRAIT_90), 1 (LANDSCAPE_90), 2 (PORTRAIT_180) or
      3 (LANDSCAPE_270).

  Returns:
    `frame` itself for orientation 0, otherwise the output of `np.rot90` over
    axes (0, 1).

  Raises:
    ValueError: If `orientation` is not one of 0, 1, 2 or 3.
  """
  if orientation == 0:  # PORTRAIT_90: nothing to do.
    return frame
  if orientation == 1:  # LANDSCAPE_90
    quarter_turns = 3
  elif orientation == 2:  # PORTRAIT_180
    quarter_turns = 2
  elif orientation == 3:  # LANDSCAPE_270
    quarter_turns = 1
  else:
    raise ValueError(
        'Orientation must be an integer in [0, 3] but is %r' % orientation
    )
  return np.rot90(frame, k=quarter_turns, axes=(0, 1))
def convert_int_to_float(data: np.ndarray, data_spec: specs.Array):
  """Converts an array of int values to floats between 0 and 1.

  Values are rescaled linearly so that the spec's minimum maps to 0.0 and its
  maximum to 1.0. For a `specs.BoundedArray` the spec's own bounds are used;
  otherwise the full range of the spec's integer dtype is used.

  Args:
    data: Integer-typed array to convert.
    data_spec: Spec describing `data`'s value range.

  Returns:
    `data` rescaled and cast to float32.

  Raises:
    TypeError: If `data` does not have an integer dtype.
  """
  if not np.issubdtype(data.dtype, np.integer):
    raise TypeError(f'{data.dtype} is not an integer type')
  if isinstance(data_spec, specs.BoundedArray):
    value_min = data_spec.minimum
    value_max = data_spec.maximum
  else:
    # We use the int type to figure out the boundaries.
    iinfo = np.iinfo(data_spec.dtype)
    value_min = iinfo.min
    value_max = iinfo.max
  # Do all arithmetic in float64: subtracting in the integer dtype can overflow
  # (e.g. int8 data 127 - (-128) wraps around to -1), and so can max - min.
  value_min = np.asarray(value_min, dtype=np.float64)
  value_max = np.asarray(value_max, dtype=np.float64)
  scaled = (data.astype(np.float64) - value_min) / (value_max - value_min)
  return np.float32(scaled)
================================================
FILE: android_env/components/pixel_fns_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for pixel_fns."""
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import pixel_fns
from dm_env import specs
import numpy as np
class UtilsTest(parameterized.TestCase):

  @parameterized.parameters(
      ([0.5, 0.5], [320, 480], (160, 240)),
      ([0.25, 0.75], [320, 480], (80, 360)),
      ([0.0, 0.0], [320, 480], (0, 0)),
      ([1.0, 1.0], [320, 480], (319, 479)),
  )
  def test_touch_position_to_pixel_position(
      self, touch_pos, width_height, pixel_pos):
    self.assertEqual(
        pixel_fns.touch_position_to_pixel_position(
            np.array(touch_pos), width_height
        ),
        pixel_pos,
    )

  def test_transpose_pixels(self):
    image = np.reshape(np.array(range(12)), (3, 2, 2))
    expected = [[[0, 1], [4, 5], [8, 9]], [[2, 3], [6, 7], [10, 11]]]
    transposed = pixel_fns.transpose_pixels(image)
    self.assertEqual(transposed.shape, (2, 3, 2))
    # assert_array_equal also checks shapes and reports mismatching elements.
    np.testing.assert_array_equal(transposed, expected)

  def test_orient_pixels(self):
    image = np.reshape(np.array(range(12)), (3, 2, 2))

    expected_90 = [[[8, 9], [4, 5], [0, 1]], [[10, 11], [6, 7], [2, 3]]]
    rot_90 = 1  # LANDSCAPE_90
    rotated = pixel_fns.orient_pixels(image, rot_90)
    self.assertEqual(rotated.shape, (2, 3, 2))
    np.testing.assert_array_equal(rotated, expected_90)

    expected_180 = [[[10, 11], [8, 9]], [[6, 7], [4, 5]], [[2, 3], [0, 1]]]
    rot_180 = 2  # PORTRAIT_180
    rotated = pixel_fns.orient_pixels(image, rot_180)
    self.assertEqual(rotated.shape, (3, 2, 2))
    np.testing.assert_array_equal(rotated, expected_180)

    expected_270 = [[[2, 3], [6, 7], [10, 11]], [[0, 1], [4, 5], [8, 9]]]
    rot_270 = 3  # LANDSCAPE_270
    rotated = pixel_fns.orient_pixels(image, rot_270)
    self.assertEqual(rotated.shape, (2, 3, 2))
    np.testing.assert_array_equal(rotated, expected_270)

    rot_0 = 0  # PORTRAIT_0
    rotated = pixel_fns.orient_pixels(image, rot_0)
    self.assertEqual(rotated.shape, (3, 2, 2))
    np.testing.assert_array_equal(rotated, image)

  def test_convert_int_to_float_bounded_array(self):
    spec = specs.BoundedArray(
        shape=(4,),
        dtype=np.int32,
        minimum=[0, 1, 10, -2],
        maximum=[5, 5, 20, 2],
        name='bounded_array')
    data = np.array([2, 2, 10, 0], dtype=np.int32)
    float_data = pixel_fns.convert_int_to_float(data, spec)
    np.testing.assert_equal(
        np.array([2.0 / 5.0, 1.0 / 4.0, 0.0, 0.5], dtype=np.float32), float_data
    )

  def test_convert_int_to_float_bounded_array_broadcast(self):
    spec = specs.BoundedArray(
        shape=(3,), dtype=np.int16, minimum=2, maximum=4, name='bounded_array')
    data = np.array([2, 3, 4], dtype=np.int16)
    float_data = pixel_fns.convert_int_to_float(data, spec)
    np.testing.assert_equal(
        np.array([0.0, 0.5, 1.0], dtype=np.float32), float_data)

  def test_convert_int_to_float_no_bounds(self):
    spec = specs.Array(
        shape=(3,),
        dtype=np.int8,  # int8 implies min=-128, max=127
        name='bounded_array')
    data = np.array([-128, 0, 127], dtype=np.int16)
    float_data = pixel_fns.convert_int_to_float(data, spec)
    np.testing.assert_equal(
        np.array([0.0, 128. / 255., 1.0], dtype=np.float32), float_data)
if __name__ == '__main__':
  # Runs this module's tests when the file is executed directly.
  absltest.main()
================================================
FILE: android_env/components/setup_step_interpreter.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A component that parses and processes SetupSteps."""
from collections.abc import Sequence
import copy
import time
from typing import Any
from absl import logging
from android_env.components import adb_call_parser as adb_call_parser_lib
from android_env.components import app_screen_checker
from android_env.components import errors
from android_env.proto import adb_pb2
from android_env.proto import task_pb2
class SetupStepInterpreter:
"""An interpreter for SetupSteps."""
def __init__(self, adb_call_parser: adb_call_parser_lib.AdbCallParser):
"""Initializes this interpreter.
Args:
adb_call_parser: An object to communicate with Android via ADB.
"""
self._adb_call_parser = adb_call_parser
self._stats = {
'error_count_adb_request': 0,
'error_count_wait_for_app_screen': 0,
'error_count_check_install': 0,
'error_count_wait_for_message': 0,
'total_time_waiting_for_app_screen': 0
}
def stats(self) -> dict[str, Any]:
return copy.deepcopy(self._stats)
def interpret(self, setup_steps: Sequence[task_pb2.SetupStep]) -> None:
"""Returns True if parsing and processing `setup_steps` is successful."""
if setup_steps:
logging.info('Executing setup steps: %s', setup_steps)
for step in setup_steps:
self._process_step_command(step)
logging.info('Done executing setup steps.')
def _process_step_command(self, step_cmd: task_pb2.SetupStep) -> None:
"""Processes a single step command from a reset or extra setup."""
if not step_cmd:
logging.info('Empty step_cmd')
return
logging.info('Executing step_cmd: %r', step_cmd)
step_type = step_cmd.WhichOneof('step')
success_condition = step_cmd.success_condition
success_check = success_condition.WhichOneof('check')
assert step_type or success_check, (
'At least one of step and success_condition must be defined.')
num_tries = 0
max_retries = max(success_condition.num_retries, 3)
latest_error = None
while num_tries < max_retries:
num_tries += 1
try:
unused_adb_response = self._execute_step_cmd(step_cmd, step_type)
time.sleep(0.5)
self._check_success(success_check, success_condition)
return
except NotImplementedError:
logging.exception('Not implemented error! Skipping this step command.')
return
except errors.AdbControllerError as error:
latest_error = error
self._stats['error_count_adb_request'] += 1
logging.exception('ADB call [%r] has failed. Try %d of %d.',
step_cmd.adb_request, num_tries, max_retries)
except errors.WaitForAppScreenError as error:
latest_error = error
self._stats['error_count_wait_for_app_screen'] += 1
logging.exception('Failed to wait for app screen. Try %d of %d.',
num_tries, max_retries)
except errors.CheckInstallError as error:
latest_error = error
self._stats['error_count_check_install'] += 1
logging.exception('Package [%r] not installed. Try %d of %d.',
success_condition.check_install.package_name,
num_tries, max_retries)
raise errors.StepCommandError(
f'Step failed: [{step_cmd}]') from latest_error
def _execute_step_cmd(
self, step_cmd: task_pb2.SetupStep, step_type: str | None
) -> adb_pb2.AdbResponse | None:
"""Executes a step command of given type."""
match step_type:
case None:
return None
case 'sleep':
time.sleep(step_cmd.sleep.time_sec)
return None
case 'adb_request':
response = self._adb_call_parser.parse(step_cmd.adb_request)
if response.status != adb_pb2.AdbResponse.Status.OK:
raise errors.AdbControllerError(
f'Failed to execute AdbRequest [{step_cmd.adb_request}].\n'
f'Status: {response.status}\n'
f'Error: {response.error_message}'
)
return response
case _:
raise NotImplementedError(f'No step command of type [{step_type}].')
def _check_success(
self,
success_check: str | None,
success_condition: task_pb2.SuccessCondition,
) -> None:
"""Checks whether the given success condition was met."""
match success_check:
case None:
return None
case 'wait_for_app_screen':
wait_for_app_screen = success_condition.wait_for_app_screen
screen_checker = app_screen_checker.AppScreenChecker(
adb_call_parser=self._adb_call_parser,
expected_app_screen=wait_for_app_screen.app_screen,
)
wait_time = screen_checker.wait_for_app_screen(
timeout_sec=wait_for_app_screen.timeout_sec
)
self._stats['total_time_waiting_for_app_screen'] += wait_time
case 'check_install':
self._check_install(success_condition.check_install)
case _:
raise NotImplementedError(f'No success check called [{success_check}].')
def _check_install(self, check_install: task_pb2.CheckInstall) -> None:
"""Checks that the given package is installed."""
package = check_install.package_name
logging.info('Checking if package is installed: [%r]', package)
request = adb_pb2.AdbRequest(
package_manager=adb_pb2.AdbRequest.PackageManagerRequest(
list=adb_pb2.AdbRequest.PackageManagerRequest.List(
packages=adb_pb2.AdbRequest.PackageManagerRequest.List.Packages(
))))
start_time = time.time()
while time.time() - start_time < check_install.timeout_sec:
response = self._adb_call_parser.parse(request)
if package in response.package_manager.list.items:
logging.info('Done confirming that package is installed.')
return
time.sleep(0.1)
logging.error('Package not found.')
raise errors.CheckInstallError()
================================================
FILE: android_env/components/setup_step_interpreter_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.components.setup_step_interpreter."""
from unittest import mock
from absl.testing import absltest
from android_env.components import adb_call_parser
from android_env.components import errors
from android_env.components import setup_step_interpreter
from android_env.proto import adb_pb2
from android_env.proto import task_pb2
from google.protobuf import text_format
def _to_proto(proto_class, text):
  """Parses `text` (text-format protobuf) into a new `proto_class` message."""
  message = proto_class()
  return text_format.Parse(text, message)
class SetupStepInterpreterTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self._parser = mock.create_autospec(
        adb_call_parser.AdbCallParser, instance=True)

  def test_empty_setup_steps(self):
    """Simple test where nothing should break, and nothing should be done.

    The test simply expects this test to not crash.
    """
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([])

  def test_none_setup_steps(self):
    """Simple test where nothing should break, and nothing should be done.

    The test simply expects this test to not crash.
    """
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    # Empty setup steps should be ignored.
    interpreter.interpret([])

  def test_invalid_setup_step(self):
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    # A step with neither a command nor a success condition is rejected.
    with self.assertRaises(AssertionError):
      interpreter.interpret([task_pb2.SetupStep()])

  def test_adb_install_apk_filesystem(self):
    self._parser.parse.return_value = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([
        _to_proto(
            task_pb2.SetupStep, """
adb_request: {
  install_apk: {
    filesystem: {
      path: "/my/favorite/dir/my_apk.apk"
    }
  }
}""")
    ])
    self._parser.parse.assert_called_once_with(
        adb_pb2.AdbRequest(
            install_apk=adb_pb2.AdbRequest.InstallApk(
                filesystem=adb_pb2.AdbRequest.InstallApk.Filesystem(
                    path='/my/favorite/dir/my_apk.apk'))))

  def test_adb_force_stop(self):
    self._parser.parse.return_value = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([
        _to_proto(
            task_pb2.SetupStep, """
adb_request: { force_stop: { package_name: "my.app.Activity" } }""")
    ])
    self._parser.parse.assert_called_once_with(
        adb_pb2.AdbRequest(
            force_stop=adb_pb2.AdbRequest.ForceStop(
                package_name='my.app.Activity')))

  def test_adb_start_activity(self):
    self._parser.parse.return_value = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([
        _to_proto(
            task_pb2.SetupStep, """
adb_request: {
  start_activity: {
    full_activity: "my.app.Activity"
    extra_args: "arg1"
    extra_args: "arg2"
  }
}""")
    ])
    self._parser.parse.assert_called_once_with(
        adb_pb2.AdbRequest(
            start_activity=adb_pb2.AdbRequest.StartActivity(
                full_activity='my.app.Activity', extra_args=['arg1', 'arg2'])))

  def test_adb_single_tap(self):
    self._parser.parse.return_value = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([
        _to_proto(task_pb2.SetupStep, """
adb_request: {
  tap: {
    x: 321
    y: 654
  }
}""")
    ])
    self._parser.parse.assert_called_once_with(
        adb_pb2.AdbRequest(tap=adb_pb2.AdbRequest.Tap(x=321, y=654)))

  def test_adb_press_button(self):
    self._parser.parse.return_value = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([
        _to_proto(task_pb2.SetupStep,
                  """ adb_request: { press_button: { button: HOME } }""")
    ])
    self._parser.parse.assert_called_once_with(
        adb_pb2.AdbRequest(
            press_button=adb_pb2.AdbRequest.PressButton(
                button=adb_pb2.AdbRequest.PressButton.Button.HOME)))

    self._parser.reset_mock()
    interpreter.interpret([
        _to_proto(task_pb2.SetupStep,
                  """ adb_request: { press_button: { button: BACK } }""")
    ])
    self._parser.parse.assert_called_once_with(
        adb_pb2.AdbRequest(
            press_button=adb_pb2.AdbRequest.PressButton(
                button=adb_pb2.AdbRequest.PressButton.Button.BACK)))

  def test_adb_start_screen_pinning(self):
    self._parser.parse.return_value = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([
        _to_proto(
            task_pb2.SetupStep, """
adb_request: {
  start_screen_pinning: {
    full_activity: "my.app.HighlanderApp"  # "There can be only one".
  }
}""")
    ])
    self._parser.parse.assert_called_once_with(
        adb_pb2.AdbRequest(
            start_screen_pinning=adb_pb2.AdbRequest.StartScreenPinning(
                full_activity='my.app.HighlanderApp')))

  @mock.patch('time.sleep')
  def test_time_sleep(self, mock_sleep):
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret(
        [_to_proto(task_pb2.SetupStep, """sleep: { time_sec: 0.875 }""")])
    # One sleep for the step itself plus the 0.5s settle time after it.
    self.assertEqual(mock_sleep.call_count, 2)
    mock_sleep.assert_has_calls([mock.call(0.875), mock.call(0.5)])

  @mock.patch('time.sleep')
  def test_wait_for_app_screen_empty_activity(self, unused_mock_sleep):
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    with self.assertRaises(errors.StepCommandError):
      interpreter.interpret([
          _to_proto(task_pb2.SetupStep,
                    """success_condition: {wait_for_app_screen: { }}""")
      ])

  @mock.patch('time.sleep')
  def test_check_install_not_installed(self, unused_mock_sleep):
    self._parser.parse.return_value = adb_pb2.AdbResponse(
        package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
            list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[
                'com.some.package',
                'not.what.you.are.looking.for',
            ])))
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    with self.assertRaises(errors.StepCommandError):
      interpreter.interpret([
          _to_proto(
              task_pb2.SetupStep, """
success_condition: {
  check_install: {
    package_name: "faz"
    timeout_sec: 0.0001
  }
}
""")
      ])

  def test_check_install_installed(self):
    self._parser.parse.return_value = adb_pb2.AdbResponse(
        package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
            list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[
                'com.some.package',
                'baz',
            ])))
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    # The test checks that this command raises no AssertionError.
    interpreter.interpret([
        _to_proto(
            task_pb2.SetupStep, """
success_condition: {
  check_install: {
    package_name: "baz"
    timeout_sec: 0.0001
  }
}""")
    ])

  def test_num_retries_failure(self):
    self._parser.parse.side_effect = [
        adb_pb2.AdbResponse(
            package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
                list=adb_pb2.AdbResponse.PackageManagerResponse.List(
                    items=[]))),
    ] * 3
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    with self.assertRaises(errors.StepCommandError):
      interpreter.interpret([
          _to_proto(
              task_pb2.SetupStep, """
success_condition: {
  check_install: {
    package_name: "faz"
    timeout_sec: 0.0001
  }
  num_retries: 3
}""")
      ])
    # `num_retries: 3` allows three attempts in total, and each attempt queries
    # the package list once, so we expect exactly 3 calls.
    self.assertEqual(self._parser.parse.call_count, 3)

  @mock.patch('time.sleep')
  def test_num_retries_success(self, unused_mock_sleep):
    self._parser.parse.side_effect = [
        adb_pb2.AdbResponse(
            package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
                list=adb_pb2.AdbResponse.PackageManagerResponse.List(
                    items=[]))),
        adb_pb2.AdbResponse(
            package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
                list=adb_pb2.AdbResponse.PackageManagerResponse.List(
                    items=[]))),
        adb_pb2.AdbResponse(
            package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
                list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[
                    'com.some.package',
                    'bar',
                ]))),
        adb_pb2.AdbResponse(
            package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
                list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[])))
    ]
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([
        _to_proto(
            task_pb2.SetupStep, """
success_condition: {
  check_install: {
    package_name: "bar"
    timeout_sec: 0.0001
  }
  num_retries: 5
}""")
    ])
    # The check should succeed on the third try.
    self.assertEqual(self._parser.parse.call_count, 3)

  def test_retry_step(self):
    self._parser.parse.side_effect = [
        adb_pb2.AdbResponse(
            package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
                list=adb_pb2.AdbResponse.PackageManagerResponse.List(
                    items=[]))),
        adb_pb2.AdbResponse(
            package_manager=adb_pb2.AdbResponse.PackageManagerResponse(
                list=adb_pb2.AdbResponse.PackageManagerResponse.List(items=[
                    'com.some.package',
                    'bar',
                ]))),
    ]
    interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=self._parser)
    interpreter.interpret([
        _to_proto(
            task_pb2.SetupStep, """
success_condition: {
  check_install: {
    package_name: "bar"
    timeout_sec: 0.0001
  }
  num_retries: 2
}""")
    ])
    # We expect the check to fail once and succeed on the second pass.
    self.assertEqual(self._parser.parse.call_count, 2)
if __name__ == '__main__':
  # Runs this module's tests when the file is executed directly.
  absltest.main()
================================================
FILE: android_env/components/simulators/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/components/simulators/base_simulator.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A base class for talking to different types of Android simulators."""
import abc
from collections.abc import Callable
import threading
import time
from absl import logging
from android_env.components import adb_controller
from android_env.components import config_classes
from android_env.components import errors
from android_env.components import log_stream
from android_env.proto import state_pb2
import numpy as np
class BaseSimulator(metaclass=abc.ABCMeta):
  """An interface for communicating with an Android simulator."""

  def __init__(self, config: config_classes.SimulatorConfig):
    """Instantiates a BaseSimulator object.

    The simulator may be an emulator, virtual machine or even a physical device.
    Each simulator has its own AdbController that is used for internal
    bookkeeping.

    Args:
      config: Settings for this simulator.
    """
    self._config = config
    # Background screenshot thread. Only created by `launch()` when
    # `config.interaction_rate_sec > 0` (async mode); None otherwise.
    self._interaction_thread: InteractionThread | None = None
    # An increasing number that tracks the attempt at launching the simulator.
    self._num_launch_attempts: int = 0

  def get_logs(self) -> str:
    """Returns logs recorded by the simulator (if provided)."""
    return 'No simulator logs provided.'

  @abc.abstractmethod
  def adb_device_name(self) -> str:
    """Returns the device name that the adb client will connect to."""

  @abc.abstractmethod
  def create_adb_controller(self) -> adb_controller.AdbController:
    """Returns an ADB controller which can communicate with this simulator."""

  @abc.abstractmethod
  def create_log_stream(self) -> log_stream.LogStream:
    """Creates a stream of logs from the simulator."""

  def launch(self) -> None:
    """Starts the simulator.

    Any interaction thread left over from a previous launch is stopped before
    relaunching. A new one is started afterwards when async mode is enabled
    (`interaction_rate_sec > 0`).

    Raises:
      errors.SimulatorError: If the platform-specific launch fails. The
        simulator logs are emitted at ERROR level before raising.
    """
    # Stop the screenshot thread (if any) so it does not keep polling a
    # simulator that is being relaunched. Its reference is cleared, so a
    # failed relaunch cannot leave stale screenshots being served.
    self._stop_interaction_thread()

    self._num_launch_attempts += 1
    try:
      self._launch_impl()
    except Exception as error:
      for line in self.get_logs().splitlines():
        logging.error(line)
      raise errors.SimulatorError(
          'Exception caught in simulator. Please see the simulator logs '
          'above for more details.'
      ) from error

    # Start interaction thread.
    if self._config.interaction_rate_sec > 0:
      self._interaction_thread = InteractionThread(
          self._get_screenshot_impl, self._config.interaction_rate_sec
      )
      self._interaction_thread.start()

  @abc.abstractmethod
  def _launch_impl(self) -> None:
    """Platform specific launch implementation."""

  @abc.abstractmethod
  def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
    """Sends a touch event to be executed on the simulator.

    Args:
      touches: A list of touch events. Each element in the list corresponds to a
          single touch event. Each touch event tuple should have:
          0 x: The horizontal coordinate of this event.
          1 y: The vertical coordinate of this event.
          2 is_down: Whether the finger is touching or not the screen.
          3 identifier: Identifies a particular finger in a multitouch event.
    """

  @abc.abstractmethod
  def send_key(self, keycode: np.int32, event_type: str) -> None:
    """Sends a keyboard event.

    Args:
      keycode: Represents a specific keyboard key. This is platform and
        simulator-specific.
      event_type: Type of key event to be sent.
    """

  def load_state(
      self, request: state_pb2.LoadStateRequest
  ) -> state_pb2.LoadStateResponse:
    """Loads a state.

    Args:
      request: A `LoadStateRequest` containing any parameters necessary to
        specify how/what state to load.

    Returns:
      A `LoadStateResponse` containing the status, error message (if
      applicable), and any other relevant information.
    """
    raise NotImplementedError('This simulator does not support load_state()')

  def save_state(
      self, request: state_pb2.SaveStateRequest
  ) -> state_pb2.SaveStateResponse:
    """Saves a state.

    Args:
      request: A `SaveStateRequest` containing any parameters necessary to
        specify how/what state to save.

    Returns:
      A `SaveStateResponse` containing the status, error message (if
      applicable), and any other relevant information.
    """
    raise NotImplementedError('This simulator does not support save_state()')

  def get_screenshot(self) -> np.ndarray:
    """Returns pixels representing the current screenshot of the simulator.

    In async mode (`interaction_rate_sec > 0`) this returns the latest frame
    cached by the interaction thread. Otherwise the screenshot is taken
    synchronously.
    """
    if self._config.interaction_rate_sec > 0:
      assert self._interaction_thread is not None
      return self._interaction_thread.screenshot()  # Async mode.
    else:
      return self._get_screenshot_impl()  # Sync mode.

  @abc.abstractmethod
  def _get_screenshot_impl(self) -> np.ndarray:
    """Actual implementation of `get_screenshot()`.

    The output numpy array should have shape [height, width, num_channels] and
    can be loaded into PIL using Image.fromarray(img, mode='RGB') and be saved
    as a PNG file using my_pil.save('/tmp/my_screenshot.png', 'PNG').
    """

  def close(self) -> None:
    """Frees up resources allocated by this object.

    Safe to call more than once.
    """
    self._stop_interaction_thread()

  def _stop_interaction_thread(self) -> None:
    """Stops and joins the interaction thread, if any, then forgets it."""
    if self._interaction_thread is not None:
      self._interaction_thread.stop()
      self._interaction_thread.join()
      self._interaction_thread = None
class InteractionThread(threading.Thread):
  """A thread that gets screenshot in the background.

  The thread calls `get_screenshot_fn` about once every `interaction_rate_sec`
  seconds and caches the result. Callers read the cached frame through
  `screenshot()` without blocking on the simulator.
  """

  def __init__(
      self,
      get_screenshot_fn: Callable[[], np.ndarray],
      interaction_rate_sec: float,
  ):
    """Initializes the thread and synchronously takes a first screenshot.

    Args:
      get_screenshot_fn: Zero-argument callable returning the current screen.
      interaction_rate_sec: Target period, in seconds, between the starts of
        two consecutive calls to `get_screenshot_fn`.
    """
    super().__init__()
    self._get_screenshot_fn = get_screenshot_fn
    self._interaction_rate_sec = interaction_rate_sec
    self._should_stop = threading.Event()
    # Exception raised by `get_screenshot_fn` inside `run()`, if any. It is
    # re-raised by `screenshot()` so a dead thread cannot silently keep
    # serving a stale frame.
    self._error: Exception | None = None
    # Taken eagerly so `screenshot()` has a value before `start()` is called.
    self._screenshot = self._get_screenshot_fn()

  def run(self):
    """Refreshes the cached screenshot until `stop()` is called or it fails."""
    while not self._should_stop.is_set():
      iteration_start = time.monotonic()
      try:
        self._screenshot = self._get_screenshot_fn()
      except Exception as e:  # pylint: disable=broad-exception-caught
        logging.exception('InteractionThread failed to get a screenshot.')
        self._error = e
        break
      # Measure from the start of this iteration (with a monotonic clock) so
      # every period lasts `interaction_rate_sec` whatever the fetch cost.
      elapsed = time.monotonic() - iteration_start
      sleep_time = self._interaction_rate_sec - elapsed
      if sleep_time > 0.0:
        # Waiting on the event instead of `time.sleep` lets `stop()` wake the
        # thread immediately rather than after a full period.
        self._should_stop.wait(sleep_time)
    logging.info('InteractionThread.run() finished.')

  def stop(self):
    """Asks the thread to exit; call `join()` afterwards to wait for it."""
    logging.info('Stopping InteractionThread.')
    self._should_stop.set()

  def screenshot(self) -> np.ndarray:
    """Returns the most recently cached screenshot.

    Raises:
      Exception: Whatever `get_screenshot_fn` raised in the background thread,
        if a refresh failed.
    """
    if self._error is not None:
      raise self._error
    return self._screenshot
================================================
FILE: android_env/components/simulators/base_simulator_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import itertools
import time
from unittest import mock
from absl.testing import absltest
from android_env.components import config_classes
from android_env.components import errors
# fake_simulator.FakeSimulator inherits from BaseSimulator, so there's no need
# to import it here explicitly.
from android_env.components.simulators import base_simulator
from android_env.components.simulators.fake import fake_simulator
import numpy as np
class BaseSimulatorTest(absltest.TestCase):
  """Tests the shared `BaseSimulator` logic through `FakeSimulator`."""

  def test_launch(self):
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480))
    )
    # Free the simulator's resources even if the test fails midway.
    self.addCleanup(simulator.close)
    # The simulator should launch and not crash.
    simulator.launch()

  def test_launch_close(self):
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig()
    )
    # The simulator should launch and not crash.
    simulator.launch()
    # Closing the simulator should also not crash.
    simulator.close()

  def test_get_screenshot(self):
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(screen_dimensions=(640, 480))
    )
    self.addCleanup(simulator.close)
    # The simulator should launch and not crash.
    simulator.launch()

    screenshot = simulator.get_screenshot()
    np.testing.assert_equal(screenshot.shape, [640, 480, 3])

  def test_print_logs_on_exception(self):
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig()
    )
    with mock.patch.object(
        simulator, 'get_logs'
    ) as mock_get_logs, mock.patch.object(
        simulator, '_launch_impl', autospec=True
    ) as mock_launch:
      mock_launch.side_effect = ValueError('Oh no!')
      self.assertRaises(errors.SimulatorError, simulator.launch)
      mock_get_logs.assert_called_once()

  def test_get_screenshot_error_async(self):
    """An exception in the underlying interaction thread should bubble up."""

    # Arrange.
    mock_interaction_thread = mock.create_autospec(
        base_simulator.InteractionThread
    )
    mock_interaction_thread.screenshot.side_effect = (
        errors.ReadObservationError()
    )
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(interaction_rate_sec=0.5)
    )
    # Cleanup (runs even if an assertion below fails).
    self.addCleanup(simulator.close)
    with mock.patch.object(
        base_simulator,
        'InteractionThread',
        autospec=True,
        return_value=mock_interaction_thread,
    ):
      simulator.launch()

    # Act & Assert.
    self.assertRaises(errors.ReadObservationError, simulator.get_screenshot)

  def test_get_screenshot_faster_than_screenshot_impl(self):
    """Return same screenshot when step is faster than the interaction rate."""

    # Arrange.
    slow_rate = 0.5
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(interaction_rate_sec=slow_rate)
    )
    # Cleanup: stop the background thread even if an assertion fails, so a
    # failing test cannot leave a non-daemon thread running.
    self.addCleanup(simulator.close)

    # Act.
    with mock.patch.object(
        simulator, '_get_screenshot_impl', autospec=True
    ) as mock_get_screenshot_impl:
      mock_get_screenshot_impl.side_effect = (
          np.array(i, ndmin=3) for i in itertools.count(0, 1)
      )
      simulator.launch()
      # The interaction thread refreshes the screenshot as soon as it starts.
      # Let that first refresh land so both reads below fall inside the same
      # `slow_rate` window instead of racing with it.
      time.sleep(0.1)
      # Get two screenshots one after the other without pausing.
      screenshot1 = simulator.get_screenshot()
      screenshot2 = simulator.get_screenshot()

    # Assert.
    self.assertAlmostEqual(screenshot1[0][0][0], screenshot2[0][0][0])

  def test_get_screenshot_slower_than_screenshot_impl(self):
    """Return different screenshots when step slower than the interaction rate."""

    # Arrange.
    fast_rate = 0.01
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(interaction_rate_sec=fast_rate)
    )
    self.addCleanup(simulator.close)

    # Act.
    with mock.patch.object(
        simulator, '_get_screenshot_impl', autospec=True
    ) as mock_get_screenshot_impl:
      mock_get_screenshot_impl.side_effect = (
          np.array(i, ndmin=3) for i in itertools.count(0, 1)
      )
      simulator.launch()
      # Sleep for 500ms between two screenshots.
      screenshot1 = simulator.get_screenshot()
      time.sleep(0.5)
      screenshot2 = simulator.get_screenshot()

    # Assert.
    self.assertNotEqual(screenshot1[0][0][0], screenshot2[0][0][0])

  def test_interaction_thread_closes_upon_relaunch(self):
    """Async interaction should kill the InteractionThread when relaunching."""

    # Arrange.
    simulator = fake_simulator.FakeSimulator(
        config_classes.FakeSimulatorConfig(interaction_rate_sec=0.01)
    )
    self.addCleanup(simulator.close)
    mock_interaction_thread = mock.create_autospec(
        base_simulator.InteractionThread
    )

    # Act & Assert.
    with mock.patch.object(
        base_simulator,
        'InteractionThread',
        autospec=True,
        return_value=mock_interaction_thread,
    ):
      simulator.launch()
      mock_interaction_thread.stop.assert_not_called()
      mock_interaction_thread.join.assert_not_called()
      simulator.launch()
      mock_interaction_thread.stop.assert_called_once()
      mock_interaction_thread.join.assert_called_once()


if __name__ == '__main__':
  absltest.main()
================================================
FILE: android_env/components/simulators/emulator/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/components/simulators/emulator/emulator_launcher.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Prepares and launches an emulator process."""
import glob
import os
import subprocess
import tempfile
from absl import logging
from android_env.components import config_classes
class EmulatorLauncher:
  """Handles launching an emulator."""

  def __init__(
      self,
      config: config_classes.EmulatorLauncherConfig,
      adb_controller_config: config_classes.AdbControllerConfig,
  ):
    """Prepares a launcher; the emulator is started by `launch_emulator_process`.

    Creates a private temporary directory under `config.tmp_dir` that holds
    the emulator's output log.

    Args:
      config: Settings controlling how the emulator process is spawned.
      adb_controller_config: ADB settings; `adb_path` and `adb_server_port`
        are forwarded to the emulator.
    """
    self._config = config
    self._adb_controller_config = adb_controller_config
    self._emulator = None
    self._emulator_output = None
    self._is_closed = False

    # Create directory for tmp files.
    # Note: this will be deleted once EmulatorLauncher instance is cleaned up.
    os.makedirs(config.tmp_dir, exist_ok=True)
    self._local_tmp_dir_handle = tempfile.TemporaryDirectory(
        dir=config.tmp_dir, prefix='simulator_instance_'
    )
    self._local_tmp_dir = self._local_tmp_dir_handle.name
    self._logfile_path = os.path.join(self._local_tmp_dir, 'emulator_output')
    logging.info('Simulator local_tmp_dir: %s', self._local_tmp_dir)

  def logfile_path(self) -> str:
    """Returns the path of the file receiving the emulator's stdout/stderr."""
    return self._logfile_path

  def _build_env_vars(self) -> dict[str, str]:
    """Returns the emulator's environment: `os.environ` plus emulator overrides."""
    # Shared libraries and the Qt config are looked up in directories that sit
    # next to the emulator binary. `dirname` works for any binary name, unlike
    # slicing off a fixed-length 'emulator' suffix.
    emulator_dir = os.path.dirname(self._config.emulator_path)
    base_lib_dir = os.path.join(emulator_dir, 'lib64', '')
    ld_library_path = ':'.join([
        base_lib_dir + 'x11/',
        base_lib_dir + 'qt/lib/',
        base_lib_dir + 'gles_swiftshader/',
        base_lib_dir,
    ])
    extra_env_vars = {
        'ANDROID_HOME': '',
        'ANDROID_SDK_ROOT': self._config.android_sdk_root,
        'ANDROID_AVD_HOME': self._config.android_avd_home,
        'ANDROID_EMULATOR_KVM_DEVICE': self._config.kvm_device,
        'ANDROID_ADB_SERVER_PORT': str(
            self._adb_controller_config.adb_server_port
        ),
        'LD_LIBRARY_PATH': ld_library_path,
        'QT_XKB_CONFIG_ROOT': os.path.join(emulator_dir, 'qt_config', ''),
        'ANDROID_EMU_ENABLE_CRASH_REPORTING': '1',
        'SHOW_PERF_STATS': str(1 if self._config.show_perf_stats else 0),
    }
    logging.info('extra_env_vars: %s',
                 ' '.join(f'{k}={v}' for k, v in extra_env_vars.items()))
    env_vars = dict(os.environ)
    env_vars.update(extra_env_vars)
    return env_vars

  def _build_command(self) -> list[str]:
    """Returns the emulator command line derived from the config."""
    grpc_port = (
        ['-grpc', str(self._config.grpc_port)]
        if self._config.grpc_port >= 0
        else []
    )
    run_headless = (
        ['-no-skin', '-no-window'] if self._config.run_headless else []
    )
    ports = [
        '-ports',
        '%s,%s' % (self._config.emulator_console_port, self._config.adb_port),
    ]
    if self._config.snapshot_name:
      snapshot = [
          '-snapshot',
          self._config.snapshot_name,
          '-feature',
          'AllowSnapshotMigration,MigratableSnapshotSave',
      ]
    else:
      snapshot = ['-no-snapshot']
    network_args = (
        [
            '-network-user-mode-options',
            'restrict=y',
            '-wifi-user-mode-options',
            'restrict=y',
        ]
        if self._config.restrict_network
        else []
    )
    return (
        [
            self._config.emulator_path,
            '-adb-path',
            self._adb_controller_config.adb_path,
            '-gpu',
            self._config.gpu_mode,
            '-no-audio',
            '-show-kernel',
            '-verbose',
            '-avd',
            self._config.avd_name,
        ]
        + grpc_port
        + run_headless
        + ports
        + snapshot
        + network_args
    )

  def launch_emulator_process(self) -> None:
    """Spawns the emulator process, redirecting its output to the logfile.

    Raises:
      Any error from opening the logfile or spawning the process (e.g.
      `OSError` if the emulator binary cannot be executed). The logfile handle
      is closed before the error propagates.
    """
    logging.info('Booting new emulator: %s', self._config.emulator_path)
    env_vars = self._build_env_vars()
    command = self._build_command()
    logging.info('Emulator launch command: %s', ' '.join(command))

    # Release the handle left by a previous launch, if any, before reopening.
    self._close_output()
    self._emulator_output = open(self._logfile_path, 'wb')
    try:
      self._emulator = subprocess.Popen(
          command,
          env=env_vars,
          stdout=self._emulator_output,
          stderr=self._emulator_output)
    except Exception:
      # Don't leak the logfile handle if the process could not be spawned.
      self._close_output()
      raise

  def _close_output(self) -> None:
    """Closes the logfile handle given to the emulator process, if open."""
    if self._emulator_output is not None:
      self._emulator_output.close()
      self._emulator_output = None

  def confirm_shutdown(self) -> None:
    """Waits for the emulator process to exit, killing it after 30 seconds.

    Also closes the logfile handle. Does nothing if no process was spawned.
    """
    if self._emulator is not None:
      logging.info('Checking if emulator process has finished...')
      try:
        self._emulator.wait(timeout=30.0)
      except subprocess.TimeoutExpired:
        logging.exception(
            'The emulator process did not finish after 30s. '
            'returncode: %s. Will now try to kill() it.',
            self._emulator.returncode)
        self._emulator.kill()
        # Reap the killed process so it does not linger as a zombie.
        self._emulator.wait()
      self._emulator = None
      logging.info('The emulator process has finished.')
    self._close_output()

  def close(self):
    """Shuts down the emulator process and removes the temporary directory.

    Calling this more than once is a no-op.
    """
    if self._is_closed:
      return
    try:
      # Wait for the process first, so it is no longer writing into the
      # temporary directory when that directory is removed.
      self.confirm_shutdown()
    finally:
      self._local_tmp_dir_handle.cleanup()
      self._is_closed = True

  def __del__(self):
    # `__init__` may have raised before the temporary directory was created.
    if hasattr(self, '_local_tmp_dir_handle'):
      self.close()
================================================
FILE: android_env/components/simulators/emulator/emulator_launcher_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.components.emulator_launcher."""
import builtins
import os
import subprocess
import tempfile
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import config_classes
from android_env.components.simulators.emulator import emulator_launcher
class EmulatorLauncherTest(parameterized.TestCase):
  """Checks the command line and environment used to spawn the emulator."""

  def setUp(self):
    super().setUp()
    self._emulator_path = 'fake/path/emulator'
    self._adb_path = 'fake/path/adb'
    self._adb_port = 5554
    self._adb_server_port = 1234
    self._emulator_console_port = 5555
    self._avd_name = 'my_avd_name'

    self._expected_command = [
        self._emulator_path,
        '-adb-path',
        'fake/path/adb',
        '-gpu',
        'swangle_indirect',
        '-no-audio',
        '-show-kernel',
        '-verbose',
        '-avd',
        self._avd_name,
    ]
    self._headless = ['-no-skin', '-no-window']
    self._ports = ['-ports', f'{self._emulator_console_port},{self._adb_port}']
    self._snapshot = ['-no-snapshot']

    base_lib_dir = self._emulator_path[:-8] + 'lib64/'
    ld_library_path = ':'.join([
        base_lib_dir + 'x11/', base_lib_dir + 'qt/lib/',
        base_lib_dir + 'gles_swiftshader/', base_lib_dir
    ])

    # Instantiate the config to extract default values.
    config = config_classes.EmulatorLauncherConfig()
    self._expected_env_vars = {
        'ANDROID_HOME': '',
        'ANDROID_SDK_ROOT': config.android_sdk_root,
        'ANDROID_AVD_HOME': config.android_avd_home,
        'ANDROID_EMULATOR_KVM_DEVICE': '/dev/kvm',
        'ANDROID_ADB_SERVER_PORT': '1234',
        'LD_LIBRARY_PATH': ld_library_path,
        'QT_XKB_CONFIG_ROOT': str(self._emulator_path[:-8] + 'qt_config/'),
        'ANDROID_EMU_ENABLE_CRASH_REPORTING': '1',
    }

  def _create_launcher(
      self, **config_overrides
  ) -> emulator_launcher.EmulatorLauncher:
    """Returns a launcher built from the shared test settings.

    Args:
      **config_overrides: Extra `EmulatorLauncherConfig` fields for one test.
    """
    config = config_classes.EmulatorLauncherConfig(
        adb_port=self._adb_port,
        emulator_console_port=self._emulator_console_port,
        emulator_path=self._emulator_path,
        avd_name=self._avd_name,
        **config_overrides,
    )
    adb_controller_config = config_classes.AdbControllerConfig(
        adb_path=self._adb_path,
        adb_server_port=self._adb_server_port,
    )
    return emulator_launcher.EmulatorLauncher(
        config=config, adb_controller_config=adb_controller_config
    )

  def _assert_launch_command(
      self,
      launcher: emulator_launcher.EmulatorLauncher,
      show_perf_stats: bool,
      expected_args: list[str],
  ) -> None:
    """Launches with `Popen`/`open` mocked and checks the spawn arguments."""
    expected_env_vars = dict(self._expected_env_vars)
    expected_env_vars['SHOW_PERF_STATS'] = '1' if show_perf_stats else '0'
    with mock.patch.object(
        subprocess, 'Popen', autospec=True
    ) as emulator_init, mock.patch.object(builtins, 'open', autospec=True) as f:
      f.return_value.__enter__ = f()
      launcher.launch_emulator_process()

      emulator_init.assert_called_once_with(
          args=expected_args,
          env=expected_env_vars,
          stdout=f(),
          stderr=f(),
      )

  @parameterized.named_parameters([
      ('hide_perf_stats', False),
      ('show_perf_stats', True),
  ])
  @mock.patch.object(os, 'makedirs')
  @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
  @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
  def test_launch(
      self,
      show_perf_stats: bool,
      mock_tmp_dir,
      unused_os_environ,
      unused_os_makedirs,
  ):
    # `name` is a plain attribute of TemporaryDirectory, not a method.
    mock_tmp_dir.return_value.name = 'local_tmp_dir'
    launcher = self._create_launcher(
        grpc_port=-1, show_perf_stats=show_perf_stats
    )
    self._assert_launch_command(
        launcher,
        show_perf_stats,
        self._expected_command + self._headless + self._ports + self._snapshot,
    )

  @parameterized.named_parameters([
      ('hide_perf_stats', False),
      ('show_perf_stats', True),
  ])
  @mock.patch.object(os, 'makedirs')
  @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
  @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
  def test_grpc_port(
      self,
      show_perf_stats: bool,
      mock_tmp_dir,
      unused_os_environ,
      unused_os_makedirs,
  ):
    mock_tmp_dir.return_value.name = 'local_tmp_dir'
    launcher = self._create_launcher(
        grpc_port=8554, show_perf_stats=show_perf_stats
    )
    self._assert_launch_command(
        launcher,
        show_perf_stats,
        self._expected_command
        + ['-grpc', '8554']
        + self._headless
        + self._ports
        + self._snapshot,
    )

  @parameterized.named_parameters([
      ('hide_perf_stats', False),
      ('show_perf_stats', True),
  ])
  @mock.patch.object(os, 'makedirs')
  @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
  @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
  def test_snapshot(
      self,
      show_perf_stats: bool,
      mock_tmp_dir,
      unused_os_environ,
      unused_os_makedirs,
  ):
    mock_tmp_dir.return_value.name = 'local_tmp_dir'
    launcher = self._create_launcher(
        grpc_port=-1,
        snapshot_name='my_snapshot',
        show_perf_stats=show_perf_stats,
    )
    expected_snapshot = [
        '-snapshot', 'my_snapshot', '-feature',
        'AllowSnapshotMigration,MigratableSnapshotSave'
    ]
    self._assert_launch_command(
        launcher,
        show_perf_stats,
        self._expected_command
        + self._headless
        + self._ports
        + expected_snapshot,
    )

  @parameterized.named_parameters([
      ('hide_perf_stats', False),
      ('show_perf_stats', True),
  ])
  @mock.patch.object(os, 'makedirs')
  @mock.patch.object(os, 'environ', autospec=True, return_value=dict())
  @mock.patch.object(tempfile, 'TemporaryDirectory', instance=True)
  def test_network_restrict(
      self,
      show_perf_stats: bool,
      mock_tmp_dir,
      unused_os_environ,
      unused_os_makedirs,
  ):
    mock_tmp_dir.return_value.name = 'local_tmp_dir'
    launcher = self._create_launcher(
        grpc_port=-1,
        restrict_network=True,
        show_perf_stats=show_perf_stats,
    )
    expected_snapshot = ['-no-snapshot']
    expected_network_restrict = [
        '-network-user-mode-options', 'restrict=y', '-wifi-user-mode-options',
        'restrict=y'
    ]
    self._assert_launch_command(
        launcher,
        show_perf_stats,
        self._expected_command
        + self._headless
        + self._ports
        + expected_snapshot
        + expected_network_restrict,
    )


if __name__ == '__main__':
  absltest.main()
================================================
FILE: android_env/components/simulators/emulator/emulator_simulator.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A class that manages an Android Emulator."""
import functools
import os
import time
from typing import Any
from absl import logging
from android_env.components import adb_controller
from android_env.components import adb_log_stream
from android_env.components import config_classes
from android_env.components import errors
from android_env.components import log_stream
from android_env.components.simulators import base_simulator
from android_env.components.simulators.emulator import emulator_launcher
from android_env.proto import state_pb2
import grpc
import numpy as np
import portpicker
from android_env.proto import emulator_controller_pb2
from android_env.proto import emulator_controller_pb2_grpc
from android_env.proto import snapshot_service_pb2
from android_env.proto import snapshot_service_pb2_grpc
from google.protobuf import empty_pb2
# Snapshot name used by `load_state()`/`save_state()` when the request's
# `args` do not name one.
_DEFAULT_SNAPSHOT_NAME: str = 'default_snapshot'
def _is_existing_emulator_provided(
    launcher_config: config_classes.EmulatorLauncherConfig,
) -> bool:
  """Tells whether every port needed to reach a running emulator is set.

  Args:
    launcher_config: Launcher settings whose adb, console and gRPC ports are
      inspected.

  Returns:
    True only if the adb port, console port and gRPC port are all truthy.
  """
  if not launcher_config.adb_port:
    return False
  if not launcher_config.emulator_console_port:
    return False
  return bool(launcher_config.grpc_port)
def _pick_adb_port() -> int:
  """Tries to pick a port in the recommended range 5555-5585.

  Odd ports in that range are probed in increasing order and the first free
  one wins. If none is free, a random unused port is returned instead. More
  info: https://developer.android.com/studio/command-line/adb#howadbworks.

  Returns:
    port: an available port for adb.
  """
  candidates = range(5555, 5587, 2)
  first_free = next(
      (port for port in candidates if portpicker.is_port_free(port)), None
  )
  if first_free is None:
    return portpicker.pick_unused_port()
  return first_free
def _pick_emulator_grpc_port() -> int:
  """Tries to pick the recommended port for grpc.

  Port 8554 is preferred. When it is taken, a random unused port is returned.
  More info:
  https://android.googlesource.com/platform/external/qemu/+/emu-master-dev/android/android-grpc/docs/.

  Returns:
    port: an available port for emulator grpc.
  """
  preferred_port = 8554
  if not portpicker.is_port_free(preferred_port):
    return portpicker.pick_unused_port()
  return preferred_port
# Subclass of `errors.SimulatorError`, so handlers for that base class also
# catch it.
class EmulatorBootError(errors.SimulatorError):
"""Raised when an emulator failed to boot."""
# NOTE: `EmulatorSimulator._launch_impl()` catches this when raised by
# `_confirm_booted()` and only logs it.
class EmulatorCrashError(errors.SimulatorError):
"""Raised when the emulator crashed."""
class EmulatorSimulator(base_simulator.BaseSimulator):
"""Controls an Android Emulator."""
def __init__(self, config: config_classes.EmulatorConfig):
"""Instantiates an EmulatorSimulator.
Picks free ports (unless the adb, console and gRPC ports of an existing
emulator were all provided), validates the relaunch thresholds, initializes
this simulator's own ADB server and, when no existing emulator was
provided, creates an `EmulatorLauncher`. The emulator process itself is not
started here; see `_launch_impl()`.
Args:
config: Settings for the emulator, its launcher and its ADB controller.
Note that the picked ports and `adb_controller.device_name` are
written back into this object, so the caller's config is modified.
Raises:
ValueError: If `launch_n_times_without_reboot` is greater than
`launch_n_times_without_reinstall`.
"""
super().__init__(config)
# Also stored by super(); reassigned here, presumably so the attribute is
# typed as `EmulatorConfig` rather than the base config type.
self._config = config
# If adb_port, console_port and grpc_port are all already provided,
# we assume the emulator already exists and there's no need to launch.
if _is_existing_emulator_provided(self._config.emulator_launcher):
self._existing_emulator_provided = True
# NOTE(review): `%r` already quotes the name, so the message shows nested
# quotes.
logging.info('Connecting to existing emulator "%r"',
self.adb_device_name())
else:
self._existing_emulator_provided = False
self._config.emulator_launcher.adb_port = _pick_adb_port()
self._config.emulator_launcher.emulator_console_port = (
portpicker.pick_unused_port()
)
self._config.emulator_launcher.grpc_port = _pick_emulator_grpc_port()
self._channel = None
# gRPC stubs; assigned from `_connect_to_emulator()` in `_launch_impl()`.
self._emulator_stub: emulator_controller_pb2_grpc.EmulatorControllerStub | None = (
None
)
self._snapshot_stub = None
# Set the image format to RGBA. The width and height of the returned
# screenshots will use the device's width and height.
self._image_format = emulator_controller_pb2.ImageFormat(
format=emulator_controller_pb2.ImageFormat.ImgFormat.RGBA8888)
# Validate the relaunch escalation thresholds used by `_launch_impl()`.
if (
self._config.launch_n_times_without_reboot
> self._config.launch_n_times_without_reinstall
):
raise ValueError(
'Number of launch attempts before reboot'
f' ({self._config.launch_n_times_without_reboot}) should not be'
' greater than number of launch attempts before reinstall'
f' ({self._config.launch_n_times_without_reinstall})'
)
# Initialize own ADB controller.
# The device name is derived from the (possibly just picked) adb port.
self._config.adb_controller.device_name = self.adb_device_name()
self._adb_controller = self.create_adb_controller()
self._adb_controller.init_server()
logging.info(
'Initialized simulator with ADB server port %r.',
self._config.adb_controller.adb_server_port,
)
# If necessary, create EmulatorLauncher.
# An existing emulator gets no launcher; its logs are only available when
# `config.logfile_path` is set.
if self._existing_emulator_provided:
self._logfile_path = self._config.logfile_path or None
self._launcher = None
else:
logging.info(
'emulator_launcher config: %r', self._config.emulator_launcher
)
self._launcher = emulator_launcher.EmulatorLauncher(
config=self._config.emulator_launcher,
adb_controller_config=self._config.adb_controller,
)
self._logfile_path = (
self._config.logfile_path or self._launcher.logfile_path()
)
def _reconnect_on_grpc_error(func):
"""Decorator function for reconnecting to emulator upon grpc errors."""
def wrapper(self, *args, **kwargs):
try:
return func(self, *args, **kwargs)
except grpc.RpcError:
logging.exception('RpcError caught. Reconnecting to emulator...')
self._emulator_stub, self._snapshot_stub = self._connect_to_emulator(
self._config.emulator_launcher.grpc_port
)
return func(self, *args, **kwargs)
return wrapper
def get_logs(self) -> str:
  """Returns logs recorded by the emulator.

  The logfile is decoded as UTF-8, with undecodable bytes replaced by U+FFFD.
  `BaseSimulator.launch()` calls `get_logs()` while handling a launch
  failure. A `UnicodeDecodeError` raised here would escape in place of the
  intended `SimulatorError`.

  Returns:
    The logfile contents, or a short message naming the path when no logfile
    is configured or it does not exist.
  """
  if self._logfile_path and os.path.exists(self._logfile_path):
    with open(self._logfile_path, 'rb') as f:
      return f.read().decode('utf-8', errors='replace')
  return f'Logfile does not exist: {self._logfile_path}.'
def adb_device_name(self) -> str:
  """Returns the adb serial of this emulator: 'emulator-<adb_port - 1>'."""
  console_port = self._config.emulator_launcher.adb_port - 1
  return f'emulator-{console_port}'
def create_adb_controller(self):
  """Builds an `AdbController` from this simulator's ADB settings."""
  controller_config = self._config.adb_controller
  return adb_controller.AdbController(controller_config)
def create_log_stream(self) -> log_stream.LogStream:
  """Returns an `AdbLogStream` that reads logs through this simulator's adb."""
  command_prefix = self._adb_controller.command_prefix()
  return adb_log_stream.AdbLogStream(
      adb_command_prefix=command_prefix, verbose=self._config.verbose_logs
  )
def _launch_impl(self) -> None:
"""Prepares an Android Emulator for RL interaction.
The behavior depends on `self._num_launch_attempts`'s value:
* <= self._config.launch_n_times_without_reboot -> Normal boot behavior.
* > self._config.launch_n_times_without_reboot but <=
self._config.launch_n_times_without_reinstall -> reboot (i.e. process
is killed and started again).
* > self._config.launch_n_times_without_reinstall -> reinstall (i.e.
process is killed, emulator files are deleted and the process started
again).
When connecting to an existing emulator (no `EmulatorLauncher`), no process
is started or stopped; only the gRPC connection and boot check run.
"""
logging.info('Attempt %r at launching the Android Emulator (%r)',
self._num_launch_attempts, self.adb_device_name())
# `_launcher` is None when an existing emulator was provided in the config.
if self._launcher is not None:
# If not the first time, then shutdown the emulator first.
# (`_emulator_stub` is None until the gRPC connection step below, or a
# reconnect, has assigned it.)
if (
self._emulator_stub is not None
and self._num_launch_attempts
> self._config.launch_n_times_without_reboot
):
self._shutdown_emulator()
# Subsequent attempts cause the emulator files to be reinstalled.
if (
self._num_launch_attempts
> self._config.launch_n_times_without_reinstall
):
logging.info('Closing emulator (%r)', self.adb_device_name())
# Closing the launcher also deletes its temporary directory (which holds
# the emulator output log); the new launcher creates a fresh one.
self._launcher.close()
self._launcher = emulator_launcher.EmulatorLauncher(
config=self._config.emulator_launcher,
adb_controller_config=self._config.adb_controller,
)
self._launcher.launch_emulator_process()
# Establish grpc connection to emulator process.
self._emulator_stub, self._snapshot_stub = self._connect_to_emulator(
self._config.emulator_launcher.grpc_port
)
# Confirm booted status.
# NOTE(review): an EmulatorCrashError is only logged here, so `launch()`
# still succeeds and 'Done booting' is logged; confirm this is intended.
try:
self._confirm_booted()
except EmulatorCrashError:
logging.exception('Failed to confirm booted status of emulator.')
logging.info('Done booting the Android Emulator.')
def load_state(
    self, request: state_pb2.LoadStateRequest
) -> state_pb2.LoadStateResponse:
  """Restores an emulator snapshot.

  Args:
    request: The `LoadStateRequest`. Its `args` may contain the key
      'snapshot_name' naming the snapshot to restore; if the key is absent,
      `_DEFAULT_SNAPSHOT_NAME` is used.

  Returns:
    A `LoadStateResponse` whose status is:
    * `OK` if the snapshot was restored;
    * `NOT_FOUND` if the emulator lists no snapshot with that name;
    * `ERROR` if restoring failed, with `error_message` set to the emulator's
      error bytes decoded as UTF-8.
  """
  assert self._snapshot_stub is not None
  name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME)
  all_snapshots = self._snapshot_stub.ListSnapshots(
      snapshot_service_pb2.SnapshotFilter(
          statusFilter=snapshot_service_pb2.SnapshotFilter.LoadStatus.All
      )
  )
  exists = any(s.snapshot_id == name for s in all_snapshots.snapshots)
  if not exists:
    return state_pb2.LoadStateResponse(
        status=state_pb2.LoadStateResponse.Status.NOT_FOUND
    )

  result = self._snapshot_stub.LoadSnapshot(
      snapshot_service_pb2.SnapshotPackage(snapshot_id=name)
  )
  if not result.success:
    return state_pb2.LoadStateResponse(
        status=state_pb2.LoadStateResponse.Status.ERROR,
        error_message=result.err.decode('utf-8'),
    )
  return state_pb2.LoadStateResponse(
      status=state_pb2.LoadStateResponse.Status.OK
  )
def save_state(
    self, request: state_pb2.SaveStateRequest
) -> state_pb2.SaveStateResponse:
  """Saves the current emulator state as a snapshot.

  Args:
    request: The `SaveStateRequest`. Its `args` may contain the key
      'snapshot_name' naming the snapshot to write; if the key is absent,
      `_DEFAULT_SNAPSHOT_NAME` is used.

  Returns:
    A `SaveStateResponse` whose status is `OK` on success, or `ERROR` with
    `error_message` set to the emulator's error bytes decoded as UTF-8.
  """
  assert self._snapshot_stub is not None
  name = request.args.get('snapshot_name', _DEFAULT_SNAPSHOT_NAME)
  result = self._snapshot_stub.SaveSnapshot(
      snapshot_service_pb2.SnapshotPackage(snapshot_id=name)
  )
  if not result.success:
    return state_pb2.SaveStateResponse(
        status=state_pb2.SaveStateResponse.Status.ERROR,
        error_message=result.err.decode('utf-8'),
    )
  return state_pb2.SaveStateResponse(
      status=state_pb2.SaveStateResponse.Status.OK
  )
def _connect_to_emulator(
    self,
    grpc_port: int,
    timeout_sec: int = 100,
) -> tuple[
    emulator_controller_pb2_grpc.EmulatorControllerStub,
    snapshot_service_pb2_grpc.SnapshotServiceStub,
]:
  """Opens a gRPC channel to the emulator and returns stubs bound to it.

  On success the new channel is stored in `self._channel`. A channel that was
  stored previously is replaced, not closed, by this method.

  Args:
    grpc_port: Port on localhost where the emulator's gRPC server listens.
    timeout_sec: How long to wait for the channel to become ready.

  Returns:
    A tuple `(emulator_controller_stub, snapshot_stub)` using the channel.

  Raises:
    EmulatorBootError: If the channel does not become ready in time.
  """
  logging.info('Creating gRPC channel to the emulator on port %r', grpc_port)
  port = f'localhost:{grpc_port}'
  # -1 lifts gRPC's message size limits (presumably so that large payloads
  # such as screenshots fit).
  options = [('grpc.max_send_message_length', -1),
             ('grpc.max_receive_message_length', -1)]
  creds = grpc.local_channel_credentials()
  channel = None
  try:
    channel = grpc.secure_channel(port, creds, options=options)
    self._channel = channel
    grpc.channel_ready_future(channel).result(timeout=timeout_sec)
  except (grpc.RpcError, grpc.FutureTimeoutError) as grpc_error:
    logging.exception('Failed to connect to the emulator.')
    if channel is not None:
      # The channel never became ready; release it instead of leaving an
      # unusable channel in `self._channel` to be overwritten (and leaked)
      # by the next connection attempt.
      channel.close()
      self._channel = None
    raise EmulatorBootError(
        'Failed to connect to the emulator.') from grpc_error
  logging.info('Added gRPC channel for the Emulator on port %s', port)
  emulator_controller_stub = (
      emulator_controller_pb2_grpc.EmulatorControllerStub(self._channel)
  )
  snapshot_stub = snapshot_service_pb2_grpc.SnapshotServiceStub(self._channel)
  return emulator_controller_stub, snapshot_stub
@_reconnect_on_grpc_error
def _confirm_booted(self, startup_wait_time_sec: int = 300):
  """Polls the emulator status until it reports that booting has finished.

  The status is polled with a 5 second pause between attempts. Once booted,
  the emulator logs are written out at INFO level.

  Args:
    startup_wait_time_sec: How long to keep polling before giving up.

  Raises:
    EmulatorCrashError: If the emulator has not booted within
      `startup_wait_time_sec` seconds.
  """
  assert (
      self._emulator_stub is not None
  ), 'Emulator stub has not been initialized yet.'
  start_time = time.time()
  deadline = start_time + startup_wait_time_sec
  booted = False
  while not booted and time.time() < deadline:
    status = self._emulator_stub.getStatus(empty_pb2.Empty())
    logging.info('Waiting for emulator (%r) to start... (%rms)',
                 self.adb_device_name(), status.uptime)
    booted = bool(status.booted)
    if not booted:
      time.sleep(5.0)
  elapsed_time = time.time() - start_time
  if not booted:
    raise EmulatorCrashError(
        f'The emulator failed to boot after {startup_wait_time_sec} seconds')

  logging.info('Done booting the emulator (in %f seconds).', elapsed_time)
  logging.info('********** Emulator logs **********')
  for log_line in self.get_logs().splitlines():
    logging.info(log_line)
  logging.info('******* End of emulator logs *******')
  logging.info('See the full emulator logs at %r', self._logfile_path)
@_reconnect_on_grpc_error
def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
  """Sends one (possibly multi-finger) touch event to the emulator.

  Args:
    touches: One tuple per finger, laid out as `(x, y, is_down, identifier)`:
      * x, y: Horizontal and vertical coordinates of the touch.
      * is_down: Whether the finger touches the screen; sent as pressure 1
        (True) or 0 (False).
      * identifier: Distinguishes fingers within a multitouch event.
  """
  assert (
      self._emulator_stub is not None
  ), 'Emulator stub has not been initialized yet.'
  fingers = [
      emulator_controller_pb2.Touch(
          x=touch[0],
          y=touch[1],
          pressure=int(touch[2]),
          identifier=touch[3],
      )
      for touch in touches
  ]
  event = emulator_controller_pb2.TouchEvent(touches=fingers)
  self._emulator_stub.sendTouch(event)
@_reconnect_on_grpc_error
def send_key(self, keycode: np.int32, event_type: str) -> None:
  """Sends a single keyboard event to the emulator.

  Args:
    keycode: Key code in XKB format (see `emulator_controller_pb2` for
      details); converted to a Python int before sending.
    event_type: Name of a `KeyboardEvent.KeyEventType` value, e.g.
      'keydown', 'keyup' or 'keypress'.

  Raises:
    ValueError: If `event_type` is not a valid `KeyEventType` name.
  """
  key_event_type = emulator_controller_pb2.KeyboardEvent.KeyEventType
  valid_event_types = key_event_type.keys()
  if event_type not in valid_event_types:
    raise ValueError(
        f'Event type must be one of {valid_event_types} but is {event_type}.')
  assert (
      self._emulator_stub is not None
  ), 'Emulator stub has not been initialized yet.'
  keyboard_event = emulator_controller_pb2.KeyboardEvent(
      codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
      eventType=key_event_type.Value(event_type),
      keyCode=int(keycode),
  )
  self._emulator_stub.sendKey(keyboard_event)
@_reconnect_on_grpc_error
def _get_screenshot_impl(self) -> np.ndarray:
  """Fetches the current screen from the emulator as a 3-channel array.

  Returns:
    A `(height, width, 3)` uint8 array taken from the emulator's 4-channel
    image by dropping the last channel (presumably alpha — depends on
    `self._image_format`). The result is a view over the received bytes.
  """
  assert (
      self._emulator_stub is not None
  ), 'Emulator stub has not been initialized yet.'
  assert self._image_format, 'ImageFormat has not been initialized yet.'
  screenshot = self._emulator_stub.getScreenshot(self._image_format)
  height = screenshot.format.height
  width = screenshot.format.width
  pixels = np.frombuffer(
      screenshot.image, dtype='uint8', count=height * width * 4
  ).reshape((height, width, 4))
  return pixels[:, :, :3]
@_reconnect_on_grpc_error
def _shutdown_emulator(self):
  """Asks the emulator VM to shut down and waits for the launcher to confirm.

  If there is no emulator stub yet, this only logs and returns.
  """
  if self._emulator_stub is None:
    logging.info('Emulator (%r) is not up.', self.adb_device_name())
    return

  assert self._launcher is not None, 'Launcher is already down.'
  logging.info('Shutting down the emulator (%r)...', self.adb_device_name())
  shutdown_request = emulator_controller_pb2.VmRunState(
      state=emulator_controller_pb2.VmRunState.RunState.SHUTDOWN
  )
  self._emulator_stub.setVmState(shutdown_request)
  self._launcher.confirm_shutdown()
def close(self):
  """Shuts down the emulator and releases the resources it holds.

  Shuts down the emulator and closes its launcher (if any), drops the gRPC
  stubs and closes the gRPC channel. The base-class cleanup runs exactly once,
  even if the emulator teardown raises.
  """
  try:
    if self._launcher is not None:
      self._shutdown_emulator()
      logging.info('Closing emulator (%r)', self.adb_device_name())
      self._launcher.close()
    self._emulator_stub = None
    self._snapshot_stub = None
    if self._channel is not None:
      self._channel.close()
      # Forget the closed channel so a repeated `close()` won't close it again.
      self._channel = None
  finally:
    # Previously `super().close()` was called both first and last; call it
    # once, but still guarantee it runs when teardown above fails.
    super().close()
================================================
FILE: android_env/components/simulators/emulator/emulator_simulator_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.components.emulator_simulator."""
import builtins
import os
import time
from unittest import mock
from absl.testing import absltest
from android_env.components import adb_call_parser
from android_env.components import adb_controller
from android_env.components import config_classes
from android_env.components.simulators.emulator import emulator_launcher
from android_env.components.simulators.emulator import emulator_simulator
from android_env.proto import state_pb2
import grpc
from PIL import Image
import portpicker
from android_env.proto import emulator_controller_pb2
from android_env.proto import emulator_controller_pb2_grpc
from android_env.proto import snapshot_service_pb2
class EmulatorSimulatorTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.addCleanup(mock.patch.stopall)  # Disable previous patches.
    self._adb_controller = mock.create_autospec(adb_controller.AdbController)
    self._adb_call_parser = mock.create_autospec(adb_call_parser.AdbCallParser)
    self._launcher = mock.create_autospec(emulator_launcher.EmulatorLauncher)
    self._launcher.logfile_path.return_value = 'logfile_path'
    self._emulator_stub = mock.create_autospec(
        emulator_controller_pb2_grpc.EmulatorControllerStub)
    self._grpc_channel = mock.create_autospec(grpc.Channel)
    mock.patch.object(
        grpc.aio, 'secure_channel', return_value=self._grpc_channel).start()
    mock.patch.object(
        grpc, 'secure_channel', return_value=self._grpc_channel).start()
    mock.patch.object(
        grpc, 'local_channel_credentials',
        return_value=self._grpc_channel).start()
    self._mock_future = mock.create_autospec(grpc.Future)
    mock.patch.object(
        grpc, 'channel_ready_future', return_value=self._mock_future).start()
    mock.patch.object(time, 'time', return_value=12345).start()
    mock.patch.object(
        adb_controller, 'AdbController',
        return_value=self._adb_controller).start()
    mock.patch.object(
        adb_call_parser,
        'AdbCallParser',
        autospec=True,
        return_value=self._adb_call_parser).start()
    mock.patch.object(
        emulator_launcher, 'EmulatorLauncher',
        return_value=self._launcher).start()

  def _make_config(
      self,
      emulator_launcher_config: (
          config_classes.EmulatorLauncherConfig | None
      ) = None,
      **kwargs,
  ) -> config_classes.EmulatorConfig:
    """Returns the `EmulatorConfig` shared by the tests in this class.

    Args:
      emulator_launcher_config: Launcher config to use. If `None`, a config
        with `grpc_port=1234` and a fresh temporary directory is created.
      **kwargs: Additional `EmulatorConfig` fields, e.g. `logfile_path` or
        the launch-attempt limits.
    """
    if emulator_launcher_config is None:
      emulator_launcher_config = config_classes.EmulatorLauncherConfig(
          grpc_port=1234, tmp_dir=self.create_tempdir().full_path
      )
    return config_classes.EmulatorConfig(
        emulator_launcher=emulator_launcher_config,
        adb_controller=config_classes.AdbControllerConfig(
            adb_path='/my/adb',
            adb_server_port=5037,
        ),
        **kwargs,
    )

  def test_adb_device_name_not_empty(self):
    simulator = emulator_simulator.EmulatorSimulator(self._make_config())
    self.assertNotEmpty(simulator.adb_device_name())

  def test_logfile_path(self):
    """The log file's path should correspond to the one from the config."""
    config = self._make_config(logfile_path='fake/logfile/path')
    simulator = emulator_simulator.EmulatorSimulator(config)
    with mock.patch.object(
        os.path, 'exists', autospec=True, return_value=True
    ), mock.patch.object(builtins, 'open', autospec=True) as mock_open:
      mock_file = mock_open.return_value.__enter__.return_value
      mock_file.read.return_value = b'fake_logs'
      logs = simulator.get_logs()
      mock_open.assert_called_once_with('fake/logfile/path', 'rb')
      self.assertEqual(logs, 'fake_logs')

  @mock.patch.object(portpicker, 'is_port_free', return_value=True)
  def test_grpc_port(self, unused_mock_portpicker):
    launcher_config = config_classes.EmulatorLauncherConfig(
        tmp_dir=self.create_tempdir().full_path
    )
    # Building the simulator resolves the gRPC port in `launcher_config`.
    unused_simulator = emulator_simulator.EmulatorSimulator(
        self._make_config(launcher_config)
    )
    self.assertEqual(launcher_config.grpc_port, 8554)

  @mock.patch.object(portpicker, 'is_port_free', return_value=False)
  def test_grpc_port_unavailable(self, unused_mock_portpicker):
    launcher_config = config_classes.EmulatorLauncherConfig(
        tmp_dir=self.create_tempdir().full_path
    )
    # Building the simulator resolves the gRPC port in `launcher_config`.
    unused_simulator = emulator_simulator.EmulatorSimulator(
        self._make_config(launcher_config)
    )
    self.assertNotEqual(launcher_config.grpc_port, 8554)

  def test_launch_operation_order(self):
    """Makes sure that adb_controller is started before Emulator is launched."""
    # Arrange.
    call_order = []
    self._adb_controller.init_server.side_effect = lambda: call_order.append(
        'init_server'
    )
    self._launcher.launch_emulator_process.side_effect = (
        lambda: call_order.append('launch_emulator_process')
    )
    simulator = emulator_simulator.EmulatorSimulator(self._make_config())

    # Act.
    simulator.launch()  # The simulator should launch and not crash.

    # Assert.
    # The adb server should be initialized before launching the emulator.
    self.assertEqual(call_order, ['init_server', 'launch_emulator_process'])

  def test_close(self):
    simulator = emulator_simulator.EmulatorSimulator(self._make_config())
    # The simulator should launch and not crash.
    simulator.launch()
    # For whatever reason clients may want to close the EmulatorSimulator.
    # We just want to check that the simulator does not crash and/or leak
    # resources.
    simulator.close()

  def test_value_error_if_launch_attempt_params_incorrect(self):
    self.assertRaises(
        ValueError,
        emulator_simulator.EmulatorSimulator,
        config=self._make_config(
            launch_n_times_without_reboot=2,
            launch_n_times_without_reinstall=1,
        ),
    )

  def test_launch_attempt_reboot(self):
    simulator = emulator_simulator.EmulatorSimulator(
        self._make_config(
            launch_n_times_without_reboot=1,
            launch_n_times_without_reinstall=2,
        )
    )
    # The simulator should launch and not crash.
    simulator.launch()
    self._launcher.launch_emulator_process.assert_called_once()
    self._launcher.reset_mock()

    # Launch attempt 2.
    simulator.launch()
    self._launcher.confirm_shutdown.assert_called_once()
    self._launcher.close.assert_not_called()
    self._launcher.launch_emulator_process.assert_called_once()

  def test_launch_attempt_reinstall_after_zero_attempts(self):
    simulator = emulator_simulator.EmulatorSimulator(
        self._make_config(
            launch_n_times_without_reboot=0,
            launch_n_times_without_reinstall=0,
        )
    )
    # The simulator should not reboot or reinstall on its very first launch.
    simulator.launch()
    self._launcher.launch_emulator_process.assert_called_once()
    self._launcher.confirm_shutdown.assert_not_called()
    self._launcher.close.assert_not_called()

    # Every subsequent attempt should reboot and reinstall.
    self._launcher.reset_mock()
    simulator.launch()
    self._launcher.confirm_shutdown.assert_called_once()
    self._launcher.close.assert_called_once()  # Now this should `close()`.
    self._launcher.launch_emulator_process.assert_called_once()

  def test_launch_attempt_reinstall(self):
    simulator = emulator_simulator.EmulatorSimulator(
        self._make_config(
            launch_n_times_without_reboot=1,
            launch_n_times_without_reinstall=2,
        )
    )
    # The simulator should launch and not crash.
    simulator.launch()
    self._launcher.launch_emulator_process.assert_called_once()

    # Launch attempt 2.
    self._launcher.reset_mock()
    simulator.launch()
    self._launcher.confirm_shutdown.assert_called_once()
    self._launcher.close.assert_not_called()  # Reboots don't `close()`.
    self._launcher.launch_emulator_process.assert_called_once()

    # Launch attempt 3.
    self._launcher.reset_mock()
    simulator.launch()
    self._launcher.confirm_shutdown.assert_called_once()
    self._launcher.close.assert_called_once()  # Now this should `close()`.
    self._launcher.launch_emulator_process.assert_called_once()

  def test_get_screenshot(self):
    simulator = emulator_simulator.EmulatorSimulator(self._make_config())
    # The simulator should launch and not crash.
    simulator.launch()

    # PIL sizes are (width, height), matching the ImageFormat below.
    simulator._emulator_stub.getScreenshot = mock.MagicMock(
        return_value=emulator_controller_pb2.Image(
            format=emulator_controller_pb2.ImageFormat(width=5678, height=1234),
            image=Image.new('RGBA', (5678, 1234)).tobytes(),
            timestampUs=123))

    screenshot = simulator.get_screenshot()
    # The screenshot should have the same screen dimensions as reported by the
    # emulator's image format and it should have 3 channels (RGB).
    self.assertEqual(screenshot.shape, (1234, 5678, 3))

  def test_load_state(self):
    simulator = emulator_simulator.EmulatorSimulator(self._make_config())
    # The simulator should launch and not crash.
    simulator.launch()

    with mock.patch.object(simulator, '_snapshot_stub') as mock_snapshot_stub:
      snapshot_list = snapshot_service_pb2.SnapshotList()
      snapshot_list.snapshots.add(snapshot_id='snapshot_name_foo')
      snapshot_list.snapshots.add(snapshot_id='snapshot_name_bar')
      mock_snapshot_stub.ListSnapshots.return_value = snapshot_list
      mock_snapshot_stub.LoadSnapshot.return_value = (
          snapshot_service_pb2.SnapshotPackage(success=True)
      )
      load_response = simulator.load_state(
          request=state_pb2.LoadStateRequest(
              args={'snapshot_name': 'snapshot_name_foo'}
          )
      )
      self.assertEqual(
          load_response.status, state_pb2.LoadStateResponse.Status.OK
      )

      load_response = simulator.load_state(
          request=state_pb2.LoadStateRequest(
              args={'snapshot_name': 'snapshot_name_baz'}
          )
      )
      self.assertEqual(
          load_response.status, state_pb2.LoadStateResponse.Status.NOT_FOUND
      )

      mock_snapshot_stub.LoadSnapshot.return_value = (
          snapshot_service_pb2.SnapshotPackage(success=False, err=b'error')
      )
      load_response = simulator.load_state(
          request=state_pb2.LoadStateRequest(
              args={'snapshot_name': 'snapshot_name_bar'}
          )
      )
      self.assertEqual(
          load_response.status, state_pb2.LoadStateResponse.Status.ERROR
      )
      self.assertEqual(load_response.error_message, 'error')

  def test_save_state(self):
    simulator = emulator_simulator.EmulatorSimulator(self._make_config())
    # The simulator should launch and not crash.
    simulator.launch()

    with mock.patch.object(simulator, '_snapshot_stub') as mock_snapshot_stub:
      mock_snapshot_stub.SaveSnapshot.return_value = (
          snapshot_service_pb2.SnapshotPackage(success=True)
      )
      save_response = simulator.save_state(
          request=state_pb2.SaveStateRequest(
              args={'snapshot_name': 'snapshot_name_foo'}
          )
      )
      self.assertEqual(
          save_response.status, state_pb2.SaveStateResponse.Status.OK
      )

      mock_snapshot_stub.SaveSnapshot.return_value = (
          snapshot_service_pb2.SnapshotPackage(success=False, err=b'error')
      )
      save_response = simulator.save_state(
          request=state_pb2.SaveStateRequest(
              args={'snapshot_name': 'snapshot_name_bar'}
          )
      )
      self.assertEqual(
          save_response.status, state_pb2.SaveStateResponse.Status.ERROR
      )
      self.assertEqual(save_response.error_message, 'error')

  def test_send_touch(self):
    simulator = emulator_simulator.EmulatorSimulator(self._make_config())
    # The simulator should launch and not crash.
    simulator.launch()

    simulator._emulator_stub.sendTouch = mock.MagicMock(return_value=None)

    simulator.send_touch([(123, 456, True, 0), (135, 246, True, 1)])
    simulator.send_touch([(1, 2, True, 0), (3, 4, True, 1)])
    simulator.send_touch([(321, 654, False, 0), (531, 642, False, 1)])

    simulator._emulator_stub.sendTouch.assert_has_calls([
        mock.call(
            emulator_controller_pb2.TouchEvent(touches=[{
                'x': 123,
                'y': 456,
                'pressure': 1
            }, {
                'x': 135,
                'y': 246,
                'pressure': 1,
                'identifier': 1
            }])),
        mock.call(
            emulator_controller_pb2.TouchEvent(touches=[{
                'x': 1,
                'y': 2,
                'pressure': 1
            }, {
                'x': 3,
                'y': 4,
                'pressure': 1,
                'identifier': 1
            }])),
        mock.call(
            emulator_controller_pb2.TouchEvent(touches=[{
                'x': 321,
                'y': 654,
                'pressure': 0
            }, {
                'x': 531,
                'y': 642,
                'pressure': 0,
                'identifier': 1
            }])),
    ])

  def test_send_key(self):
    simulator = emulator_simulator.EmulatorSimulator(self._make_config())
    # The simulator should launch and not crash.
    simulator.launch()

    # Mock the RPC that `send_key()` actually uses, so the assertions below
    # check calls recorded on a dedicated mock.
    simulator._emulator_stub.sendKey = mock.MagicMock(return_value=None)

    simulator.send_key(123, 'keydown')
    simulator.send_key(321, 'keydown')
    simulator.send_key(321, 'keyup')
    simulator.send_key(123, 'keyup')
    simulator.send_key(321, 'keypress')
    simulator.send_key(123, 'keypress')

    simulator._emulator_stub.sendKey.assert_has_calls([
        mock.call(
            emulator_controller_pb2.KeyboardEvent(
                codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
                eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
                .keydown,
                keyCode=123,
            )),
        mock.call(
            emulator_controller_pb2.KeyboardEvent(
                codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
                eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
                .keydown,
                keyCode=321,
            )),
        mock.call(
            emulator_controller_pb2.KeyboardEvent(
                codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
                eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
                .keyup,
                keyCode=321,
            )),
        mock.call(
            emulator_controller_pb2.KeyboardEvent(
                codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
                eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
                .keyup,
                keyCode=123,
            )),
        mock.call(
            emulator_controller_pb2.KeyboardEvent(
                codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
                eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
                .keypress,
                keyCode=321,
            )),
        mock.call(
            emulator_controller_pb2.KeyboardEvent(
                codeType=emulator_controller_pb2.KeyboardEvent.KeyCodeType.XKB,
                eventType=emulator_controller_pb2.KeyboardEvent.KeyEventType
                .keypress,
                keyCode=123,
            ))
    ])
if __name__ == '__main__':
  # Run this test module's tests when it is executed directly.
  absltest.main()
================================================
FILE: android_env/components/simulators/fake/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/components/simulators/fake/fake_simulator.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fake Simulator for testing AndroidEnv infrastructure."""
import random
import threading
import time
from absl import logging
from android_env.components import adb_controller
from android_env.components import config_classes
from android_env.components import log_stream
from android_env.components.simulators import base_simulator
import numpy as np
class FakeStream:
  """Imitates the stream of log lines coming from ADB.

  Iterating over an instance yields randomly chosen log lines, pausing 0.1 s
  before each one, until `kill()` is called.
  """

  # Sampling weights for the entries of `self._values`, in the same order.
  _WEIGHTS = (0.49, 0.15, 0.15, 0.15, 0.01)

  def __init__(self):
    self._values = [
        '',
        self._make_stdout('reward: 0.5'),
        self._make_stdout('reward: 1.0'),
        self._make_stdout('extra: my_extra [1.0]'),
        self._make_stdout('episode end'),
    ]
    self._kill = False
    self._lock = threading.Lock()

  def _make_stdout(self, data):
    """Wraps `data` in a log line with a fixed timestamp/pid/tid/tag header."""
    return f' 1553110400.424 5583 5658 D Tag: {data}'

  def kill(self):
    """Makes any iteration stop before producing its next line."""
    self._kill = True

  def __iter__(self):
    while not self._kill:
      with self._lock:
        line = random.choices(self._values, weights=self._WEIGHTS, k=1)[0]
        time.sleep(0.1)
      yield line
class FakeLogStream(log_stream.LogStream):
  """A non-verbose `LogStream` whose lines come from an in-process FakeStream."""

  def __init__(self):
    super().__init__(verbose=False)
    self.stream = FakeStream()

  def stop_stream(self):
    """Kills the wrapped fake stream so that iterating over it ends."""
    fake_stream = self.stream
    fake_stream.kill()

  def _get_stream_output(self):
    """Returns the wrapped fake stream as the source of log lines."""
    return self.stream
class FakeAdbController(adb_controller.AdbController):
  """`AdbController` for FakeSimulator that answers with canned output."""

  def execute_command(
      self,
      args: list[str],
      timeout: float | None = None,
      device_specific: bool = True,
  ) -> bytes:
    """Returns canned output for the adb commands AndroidEnv issues.

    Args:
      args: adb arguments, e.g. `['shell', 'dumpsys', 'input']`.
      timeout: Ignored.
      device_specific: Ignored.

    Returns:
      Fake command output as bytes; `b'fake output'` for unknown commands.
    """
    del timeout, device_specific

    if args[:3] == ['shell', 'service', 'check']:
      # Report every queried service as available.
      output = f'Service {args[-1]}: found'.encode('utf-8')
    elif args == ['shell', 'dumpsys', 'input']:
      # Orientation query: always report orientation 0.
      output = b' SurfaceOrientation: 0'
    elif args[:4] == ['shell', 'am', 'stack', 'list']:
      # app_screen_checker: the fake task expects 'fake_activity'.
      output = (b'taskId=0 fake_activity visible=true '
                b'topActivity=ComponentInfo{fake_activity}')
    else:
      output = b'fake output'
    return output
class FakeSimulator(base_simulator.BaseSimulator):
  """Simulator that can replace EmulatorSimulator in AndroidEnv.

  It needs no emulator: launching does nothing, touch and key events are
  discarded, logs and adb output are canned, and screenshots are random noise.
  """

  def __init__(self, config: config_classes.FakeSimulatorConfig):
    """Creates the simulator.

    Args:
      config: Simulator config; `config.screen_dimensions` gives the leading
        dimensions of every screenshot.
    """
    super().__init__(config)
    self._screen_dimensions = np.array(config.screen_dimensions)
    logging.info('Created FakeSimulator.')

  def get_logs(self) -> str:
    """Returns a fixed placeholder log text."""
    return 'FakeSimulator: fake logs'

  def adb_device_name(self) -> str:
    """Returns a fixed fake device name."""
    return 'fake_simulator'

  def create_adb_controller(self):
    """Returns a FakeAdbController with a default config."""
    return FakeAdbController(config_classes.AdbControllerConfig())

  def create_log_stream(self) -> log_stream.LogStream:
    """Returns a FakeLogStream."""
    return FakeLogStream()

  def _launch_impl(self) -> None:
    """There is nothing to launch."""

  def send_touch(self, touches: list[tuple[int, int, bool, int]]) -> None:
    """Discards the touch events."""
    del touches

  def send_key(self, keycode: np.int32, event_type: str) -> None:
    """Discards the key event."""
    del keycode, event_type

  def _get_screenshot_impl(self) -> np.ndarray:
    """Returns random uint8 noise of shape `(*screen_dimensions, 3)`."""
    # `high` is exclusive: 256 is needed for the full uint8 range [0, 255];
    # with `high=255` the value 255 could never be produced.
    return np.random.randint(
        low=0, high=256, size=(*self._screen_dimensions, 3), dtype=np.uint8)
================================================
FILE: android_env/components/simulators/fake/fake_simulator_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for fake_simulator."""
import re
from absl.testing import absltest
from android_env.components import config_classes
from android_env.components.simulators.fake import fake_simulator
import numpy as np
class FakeSimulatorTest(absltest.TestCase):
def test_device_name(self):
simulator = fake_simulator.FakeSimulator(
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
)
self.assertEqual(simulator.adb_device_name(), 'fake_simulator')
def test_launch_close(self):
# The simulator should launch and not crash.
simulator = fake_simulator.FakeSimulator(
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
)
simulator.launch()
# Closing the simulator should also not crash.
simulator.close()
def test_get_screenshot(self):
simulator = fake_simulator.FakeSimulator(
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
)
simulator.launch()
screenshot = simulator.get_screenshot()
np.testing.assert_equal(screenshot.shape, [320, 480, 3])
np.testing.assert_equal(screenshot.dtype, np.uint8)
def test_log_stream(self):
simulator = fake_simulator.FakeSimulator(
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
)
simulator.launch()
log_stream = simulator.create_log_stream()
# Start yielding lines from LogStream.
log_stream.resume_stream()
lines = [
'',
' 1553110400.424 5583 5658 D Tag: reward: 0.5',
' 1553110400.424 5583 5658 D Tag: reward: 1.0',
' 1553110400.424 5583 5658 D Tag: extra: my_extra [1.0]',
' 1553110400.424 5583 5658 D Tag: episode end',
]
for i, line in enumerate(log_stream.get_stream_output()):
self.assertIn(line, lines)
if i > 10:
break
def test_adb_output(self):
simulator = fake_simulator.FakeSimulator(
config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
)
simulator.launch()
adb_controller = simulator.create_adb_controller()
line = adb_controller.execute_command(['shell', 'dumpsys', 'input'])
line = line.decode('utf-8')
matches = re.match(r'\s+SurfaceOrientation:\s+(\d)', line)
self.assertIsNotNone(matches)
orientation = matches.group(1)
self.assertEqual(orientation, '0')
line = adb_controller.execute_command(['shell', 'service', 'check', 'foo'])
line = line.decode('utf-8')
self.assertEqual(line, 'Service foo: found')
line = adb_controller.execute_command(['shell', 'am', 'stack', 'list'])
line = line.decode('utf-8')
self.assertEqual(line, 'taskId=0 fake_activity visible=true '
'topActivity=ComponentInfo{fake_activity}')
def test_send_touch(self):
  """send_touch() accepts a press followed by a release without raising."""
  config = config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
  sim = fake_simulator.FakeSimulator(config)
  sim.launch()
  # Smoke test only: nothing is asserted beyond the calls not crashing.
  for is_down in (True, False):
    sim.send_touch([(0, 1, is_down, 0)])
def test_send_key(self):
  """send_key() accepts keydown, keyup and keypress events without raising."""
  config = config_classes.FakeSimulatorConfig(screen_dimensions=(320, 480))
  sim = fake_simulator.FakeSimulator(config)
  sim.launch()
  # Smoke test only: nothing is asserted beyond the calls not crashing.
  for event_type in ('keydown', 'keyup', 'keypress'):
    sim.send_key(np.int32(123), event_type)
if __name__ == '__main__':
  # Hand control to absltest's runner when executed as a script.
  absltest.main()
================================================
FILE: android_env/components/specs.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base specs for AndroidEnv."""
from android_env.components import action_type
from android_env.proto import task_pb2
from dm_env import specs
import numpy as np
_PROTO_DTYPE_TO_NUMPY_DTYPE = {
task_pb2.ArraySpec.DataType.FLOAT: np.float32,
task_pb2.ArraySpec.DataType.DOUBLE: np.float64,
task_pb2.ArraySpec.DataType.INT8: np.int8,
task_pb2.ArraySpec.DataType.INT16: np.int16,
task_pb2.ArraySpec.DataType.INT32: np.int32,
task_pb2.ArraySpec.DataType.INT64: np.int64,
task_pb2.ArraySpec.DataType.UINT8: np.uint8,
task_pb2.ArraySpec.DataType.UINT16: np.uint16,
task_pb2.ArraySpec.DataType.UINT32: np.uint32,
task_pb2.ArraySpec.DataType.UINT64: np.uint64,
task_pb2.ArraySpec.DataType.BOOL: np.bool_,
task_pb2.ArraySpec.DataType.STRING_U1: np.dtype(('U1')),
task_pb2.ArraySpec.DataType.STRING_U16: np.dtype((' dict[str, specs.Array]:
"""Default action spec for AndroidEnv.
Args:
num_fingers: Number of virtual fingers of the agent.
enable_key_events: Whether keyboard key events are enabled.
Returns:
A dict of action specs, each item corresponding to a virtual finger.
action_type: An integer of type ActionType: TOUCH=0, LIFT=1, REPEAT=2
touch_position: Position [x, y] of the touch action, where x, y are float
values between 0.0 and 1.0 corresponding to the relative position on the
screen. IGNORED when (action_type != ActionType.TOUCH).
keycode: code representing the desired key press in XKB format. See the
emulator_controller_pb2 for details.
action_type_i: Action type for additional fingers (i>1).
touch_position_i: Touch position for additional fingers (i>1).
"""
num_actions = len(action_type.ActionType) if enable_key_events else 3
action_spec = {
'action_type':
specs.DiscreteArray(num_values=num_actions, name='action_type'),
'touch_position':
specs.BoundedArray(
shape=(2,),
dtype=np.float32,
minimum=[0.0, 0.0],
maximum=[1.0, 1.0],
name='touch_position'),
}
for i in range(2, num_fingers + 1):
action_spec.update({
f'action_type_{i}':
specs.DiscreteArray(
num_values=len(action_type.ActionType),
name=f'action_type_{i}'),
f'touch_position_{i}':
specs.BoundedArray(
shape=(2,),
dtype=np.float32,
minimum=[0.0, 0.0],
maximum=[1.0, 1.0],
name=f'touch_position_{i}'),
})
if enable_key_events:
action_spec['keycode'] = specs.DiscreteArray(
num_values=(1 << 16) - 1, name='keycode')
return action_spec
def base_observation_spec(height: int, width: int) -> dict[str, specs.Array]:
  """Default observation spec for AndroidEnv.

  Args:
    height: Height of the device screen in pixels.
    width: Width of the device screen in pixels.

  Returns:
    A dict mapping observation names to their specs:
      pixels: Spec for the RGB screenshot of the device. Has shape (H, W, 3),
        dtype uint8 and values bounded to [0, 255].
      timedelta: Spec for time delta since the last observation (in
        microseconds). Scalar int64. The first timestep immediately after
        reset() will have this value set to 0.
      orientation: Spec for the latest orientation in a one-hot
        representation (shape (4,), uint8, values in [0, 1]):
          [1, 0, 0, 0]: PORTRAIT (0 degrees)
          [0, 1, 0, 0]: LANDSCAPE (90 degrees clockwise)
          [0, 0, 1, 0]: PORTRAIT (180 degrees) ("upside down")
          [0, 0, 0, 1]: LANDSCAPE (270 degrees clockwise)
  """
  return {
      'pixels': specs.BoundedArray(
          shape=(height, width, 3),
          dtype=np.uint8,
          name='pixels',
          minimum=0,
          maximum=255,
      ),
      'timedelta': specs.Array(shape=(), dtype=np.int64, name='timedelta'),
      'orientation': specs.BoundedArray(
          # A plain tuple shape, consistent with every other spec in this
          # module (previously an ndarray was passed here).
          shape=(4,),
          dtype=np.uint8,
          name='orientation',
          minimum=0,
          maximum=1,
      ),
  }
================================================
FILE: android_env/components/specs_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for specs.py."""
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import specs
from android_env.proto import task_pb2
from dm_env import specs as dm_env_specs
import numpy as np
class SpecsTest(parameterized.TestCase):
  """Checks the shapes and dtypes produced by the base spec builders."""

  def _assert_all_are_arrays(self, spec_dict):
    # Every value returned by a spec builder must be a dm_env Array spec.
    for value in spec_dict.values():
      self.assertIsInstance(value, dm_env_specs.Array)

  def _assert_finger(self, action_spec, type_key, position_key):
    # A finger contributes a scalar int32 action type and a float32 (x, y).
    action_type_spec = action_spec[type_key]
    position_spec = action_spec[position_key]
    self.assertEqual(action_type_spec.shape, ())
    self.assertEqual(action_type_spec.dtype, np.int32)
    self.assertEqual(position_spec.shape, (2,))
    self.assertEqual(position_spec.dtype, np.float32)

  def test_base_action_spec(self):
    action_spec = specs.base_action_spec(num_fingers=1)
    self._assert_all_are_arrays(action_spec)
    self._assert_finger(action_spec, 'action_type', 'touch_position')

  def test_base_action_spec_with_key_events(self):
    action_spec = specs.base_action_spec(num_fingers=1, enable_key_events=True)
    self._assert_all_are_arrays(action_spec)
    self._assert_finger(action_spec, 'action_type', 'touch_position')
    keycode_spec = action_spec['keycode']
    self.assertEqual(keycode_spec.shape, ())
    self.assertEqual(keycode_spec.dtype, np.int32)

  def test_base_action_spec_multitouch(self):
    action_spec = specs.base_action_spec(num_fingers=3)
    self.assertLen(action_spec.keys(), 6)
    self._assert_all_are_arrays(action_spec)
    finger_keys = (
        ('action_type', 'touch_position'),
        ('action_type_2', 'touch_position_2'),
        ('action_type_3', 'touch_position_3'),
    )
    for type_key, position_key in finger_keys:
      self._assert_finger(action_spec, type_key, position_key)

  @parameterized.parameters(
      (480, 320),
      (100, 100),
      (1440, 1960),
  )
  def test_base_observation_spec(self, height, width):
    observation_spec = specs.base_observation_spec(height, width)
    self._assert_all_are_arrays(observation_spec)
    pixels = observation_spec['pixels']
    timedelta = observation_spec['timedelta']
    orientation = observation_spec['orientation']
    self.assertEqual(pixels.shape, (height, width, 3))
    self.assertEqual(pixels.dtype, np.uint8)
    self.assertEqual(timedelta.shape, ())
    self.assertEqual(timedelta.dtype, np.int64)
    self.assertEqual(orientation.shape, (4,))
    self.assertEqual(orientation.dtype, np.uint8)
if __name__ == '__main__':
  # Hand control to absltest's runner when executed as a script.
  absltest.main()
================================================
FILE: android_env/components/task_manager.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TaskManager handles all events and information related to the task."""
import ast
from collections.abc import Callable, Iterable, Sequence
import copy
import datetime
import itertools
import json
import re
import threading
import time
from typing import Any
from absl import logging
from android_env.components import adb_call_parser as adb_call_parser_lib
from android_env.components import app_screen_checker
from android_env.components import config_classes
from android_env.components import dumpsys_thread
from android_env.components import log_stream as log_stream_lib
from android_env.components import logcat_thread
from android_env.components import setup_step_interpreter
from android_env.proto import task_pb2
import dm_env
import numpy as np
class TaskManager:
  """Handles all events and information related to the task."""

  def __init__(
      self,
      task: task_pb2.Task,
      config: config_classes.TaskManagerConfig | None = None,
  ):
    """Controls task-relevant events and information.

    Args:
      task: A task proto defining the RL task.
      config: Configuration for this instance. A default
        `config_classes.TaskManagerConfig()` is used when omitted.
    """
    self._task = task
    self._config = config or config_classes.TaskManagerConfig()
    # Guards `_latest_values`, which is updated by the logcat event handlers
    # below and read/reset by the RL methods.
    self._lock = threading.Lock()
    # These are created by `start()`.
    self._logcat_thread = None
    self._dumpsys_thread = None
    self._setup_step_interpreter = None

    # Initialize stats.
    self._stats = {
        'episode_steps': 0,
        'reset_count_step_timeout': 0,
        'reset_count_user_exited': 0,
        'reset_count_episode_end': 0,
        'reset_count_max_duration_reached': 0,
        'restart_count_max_bad_states': 0,
        'task_updates': 0,
    }

    # Initialize internal state
    self._task_start_time = None
    self._bad_state_counter = 0
    self._is_bad_episode = False
    # Set to True by `_increment_bad_state()` once too many consecutive bad
    # states have been seen. Initialized here so the attribute always exists.
    self._should_restart = False
    self._latest_values = {
        'reward': 0.0,
        'score': 0.0,
        'extra': {},
        'episode_end': False,
    }

    logging.info('Task config: %s', self._task)

  def stats(self) -> dict[str, Any]:
    """Returns a dictionary of stats.

    The result is a deep copy of the internal stats, merged with the setup
    step interpreter's stats once that interpreter exists.
    This method is expected to be called after setup_task() has been called.
    """
    output = copy.deepcopy(self._stats)
    if self._setup_step_interpreter is not None:
      output.update(self._setup_step_interpreter.stats())
    return output

  def setup_task(self) -> None:
    """Performs one-off task setup.

    Requires `start()` to have been called first, since it uses the setup
    step interpreter created there.
    """
    self._setup_step_interpreter.interpret(self._task.setup_steps)

  def stop(self) -> None:
    """Suspends task processing.

    Stopping the logcat thread is attempted up to three times. Failures are
    logged and never raised (best effort).
    """
    n_tries = 3
    for i in range(n_tries):
      try:
        self._stop_logcat_thread()
        break
      except:  # pylint: disable=bare-except
        logging.exception(
            'Failed to stop logcat thread [%d/%d]. Continuing.', i + 1, n_tries
        )
        time.sleep(1)

  def start(
      self,
      adb_call_parser_factory: Callable[[], adb_call_parser_lib.AdbCallParser],
      log_stream: log_stream_lib.LogStream,
  ) -> None:
    """Starts task processing.

    Args:
      adb_call_parser_factory: Called twice: once to build the parser for the
        dumpsys thread and once for the setup step interpreter.
      log_stream: Log stream consumed by the logcat thread.
    """
    self._start_logcat_thread(log_stream=log_stream)
    self._logcat_thread.resume()
    self._start_dumpsys_thread(adb_call_parser_factory())
    self._start_setup_step_interpreter(adb_call_parser_factory())

  def reset_task(self) -> None:
    """Resets a task for a new run."""
    # Logcat processing is paused while the reset steps run. Resume it even
    # if a reset step raises, so the logcat thread is never left paused.
    self._logcat_thread.pause()
    try:
      self._setup_step_interpreter.interpret(self._task.reset_steps)
    finally:
      self._logcat_thread.resume()

    # Reset some other variables.
    # The bad state counter only survives a reset when the previous episode
    # was itself bad, so it counts *consecutive* bad episodes.
    if not self._is_bad_episode:
      self._bad_state_counter = 0
    self._is_bad_episode = False

    self._task_start_time = datetime.datetime.now()
    with self._lock:
      self._latest_values = {
          'reward': 0.0,
          'score': 0.0,
          'extra': {},
          'episode_end': False,
      }

  def rl_reset(self, observation: dict[str, Any]) -> dm_env.TimeStep:
    """Returns the FIRST timestep of a new episode.

    Resets the episode step counter and attaches any accumulated extras to
    `observation` (mutated in place under the 'extras' key).
    """
    self._stats['episode_steps'] = 0

    self._logcat_thread.line_ready().wait()
    with self._lock:
      extras = self._get_current_extras()

    observation['extras'] = extras

    return dm_env.TimeStep(
        step_type=dm_env.StepType.FIRST,
        reward=0.0,
        discount=0.0,
        observation=observation,
    )

  def rl_step(self, observation: dict[str, Any]) -> dm_env.TimeStep:
    """Performs one RL step.

    Collects the reward and extras accumulated since the last step, decides
    whether the episode continues, terminates or is truncated, and attaches
    the extras to `observation` (mutated in place under the 'extras' key).
    """
    self._stats['episode_steps'] += 1

    self._logcat_thread.line_ready().wait()
    with self._lock:
      reward = self._get_current_reward()
      extras = self._get_current_extras()
      transition_fn = self._determine_transition_fn()

    observation['extras'] = extras

    return transition_fn(reward=reward, observation=observation)

  def _get_current_reward(self) -> float:
    """Returns total reward accumulated since the last step.

    Must be called with `self._lock` held; resets the accumulator to 0.
    """
    reward = self._latest_values['reward']
    self._latest_values['reward'] = 0.0
    return reward

  def _get_current_extras(self) -> dict[str, Any]:
    """Returns task extras accumulated since the last step.

    Must be called with `self._lock` held. Each extra's buffered values are
    stacked into one array, and the buffer is then cleared.
    """
    extras = {}
    for name, values in self._latest_values['extra'].items():
      extras[name] = np.stack(values)
    self._latest_values['extra'] = {}
    return extras

  def _determine_transition_fn(self) -> Callable[..., dm_env.TimeStep]:
    """Determines the type of RL transition will be used."""

    # Check if user exited the task
    if self._dumpsys_thread.check_user_exited():
      self._increment_bad_state()
      self._stats['reset_count_user_exited'] += 1
      logging.warning('User exited the task. Truncating the episode.')
      logging.info('************* END OF EPISODE *************')
      return dm_env.truncation

    # Check if episode has ended
    if self._latest_values['episode_end']:
      self._stats['reset_count_episode_end'] += 1
      logging.info('End of episode from logcat! Ending episode.')
      return dm_env.termination

    # Check if step limit or time limit has been reached
    if self._task.max_episode_steps > 0:
      if self._stats['episode_steps'] > self._task.max_episode_steps:
        self._stats['reset_count_max_duration_reached'] += 1
        logging.info(
            'Maximum task duration (%r steps) reached. Truncating the episode.',
            self._task.max_episode_steps,
        )
        return dm_env.truncation

    if self._task.max_episode_sec > 0.0:
      task_duration = datetime.datetime.now() - self._task_start_time
      max_episode_sec = self._task.max_episode_sec
      # Keep fractional seconds: truncating with int() would turn e.g. a 0.5s
      # limit into 0s and end every episode on its first step.
      if task_duration > datetime.timedelta(seconds=max_episode_sec):
        self._stats['reset_count_max_duration_reached'] += 1
        logging.info(
            'Maximum task duration (%r sec) reached. Truncating the episode.',
            max_episode_sec,
        )
        return dm_env.truncation

    return dm_env.transition

  def _start_setup_step_interpreter(
      self, adb_call_parser: adb_call_parser_lib.AdbCallParser
  ):
    """Creates the interpreter used for setup and reset steps."""
    self._setup_step_interpreter = setup_step_interpreter.SetupStepInterpreter(
        adb_call_parser=adb_call_parser
    )

  def _start_logcat_thread(self, log_stream: log_stream_lib.LogStream):
    """Creates the logcat thread and registers all task event listeners."""
    log_stream.set_log_filters(list(self._task.log_parsing_config.filters))
    self._logcat_thread = logcat_thread.LogcatThread(log_stream=log_stream)

    for event_listener in self._logcat_listeners():
      self._logcat_thread.add_event_listener(event_listener)

  def _start_dumpsys_thread(
      self, adb_call_parser: adb_call_parser_lib.AdbCallParser
  ):
    """Creates the dumpsys checker for the task's expected app screen."""
    self._dumpsys_thread = dumpsys_thread.DumpsysThread(
        app_screen_checker=app_screen_checker.AppScreenChecker(
            adb_call_parser=adb_call_parser,
            expected_app_screen=self._task.expected_app_screen,
        ),
        check_frequency=self._config.dumpsys_check_frequency,
        max_failed_current_activity=self._config.max_failed_current_activity,
    )

  def _stop_logcat_thread(self):
    """Kills and discards the logcat thread, if one exists."""
    if self._logcat_thread is not None:
      self._logcat_thread.kill()
      self._logcat_thread = None

  def _increment_bad_state(self) -> None:
    """Increments the bad state counter.

    Bad states are errors that shouldn't happen and that trigger an
    episode reset. If enough bad states have been seen consecutively,
    we restart the simulation in the hope of returning the simulation
    to a good state.
    """
    logging.warning('Bad state detected.')
    if self._config.max_bad_states:
      self._is_bad_episode = True
      self._bad_state_counter += 1
      logging.warning('Bad state counter: %d.', self._bad_state_counter)
      if self._bad_state_counter >= self._config.max_bad_states:
        logging.error('Too many consecutive bad states. Restarting simulator.')
        self._stats['restart_count_max_bad_states'] += 1
        self._should_restart = True
    else:
      logging.warning('Max bad states not set, bad states will be ignored.')

  def _logcat_listeners(self) -> Iterable[logcat_thread.EventListener]:
    """Creates list of EventListeners for logcat thread."""

    # Empty regexps are replaced by 'a^' in the listeners below: it can never
    # match because 'a' would have to precede the start of the string.
    regexps = self._task.log_parsing_config.log_regexps
    return itertools.chain(
        self._reward_listeners(regexps),
        self._reward_event_listeners(regexps),
        self._score_listeners(regexps),
        self._episode_end_listeners(regexps),
        self._extras_listeners(regexps),
        self._json_extras_listeners(regexps),
    )

  def _reward_listeners(
      self, regexps: task_pb2.LogParsingConfig.LogRegexps
  ) -> Iterable[logcat_thread.EventListener]:
    """Creates an iterable of reward listeners."""

    def _reward_handler(event: re.Pattern[str], match: re.Match[str]):
      del event
      reward = float(match.group(1))
      with self._lock:
        self._latest_values['reward'] += reward

    for regexp in regexps.reward:
      yield logcat_thread.EventListener(
          regexp=re.compile(regexp or 'a^'), handler_fn=_reward_handler
      )

  def _reward_event_listeners(
      self, regexps: task_pb2.LogParsingConfig.LogRegexps
  ) -> Iterable[logcat_thread.EventListener]:
    """Creates an iterable of reward event listeners."""

    for reward_event in regexps.reward_event:

      # A factory binds `reward` per event, avoiding late-binding closures.
      def get_reward_event_handler(reward):
        def _reward_event_handler(event: re.Pattern[str], match: re.Match[str]):
          del event, match
          with self._lock:
            self._latest_values['reward'] += reward

        return _reward_event_handler

      yield logcat_thread.EventListener(
          regexp=re.compile(reward_event.event or 'a^'),
          handler_fn=get_reward_event_handler(reward_event.reward),
      )

  def _score_listeners(
      self, regexps: task_pb2.LogParsingConfig.LogRegexps
  ) -> Iterable[logcat_thread.EventListener]:
    """Creates an iterable of score listeners."""

    def _score_handler(event: re.Pattern[str], match: re.Match[str]):
      del event
      current_score = float(match.group(1))
      with self._lock:
        # The reward is the change in score since the last reported score.
        current_reward = current_score - self._latest_values['score']
        self._latest_values['score'] = current_score
        self._latest_values['reward'] += current_reward

    yield logcat_thread.EventListener(
        regexp=re.compile(regexps.score or 'a^'), handler_fn=_score_handler
    )

  def _episode_end_listeners(
      self, regexps: task_pb2.LogParsingConfig.LogRegexps
  ) -> Iterable[logcat_thread.EventListener]:
    """Creates an iterable of episode end listeners."""

    def _episode_end_handler(event: re.Pattern[str], match: re.Match[str]):
      del event, match
      with self._lock:
        self._latest_values['episode_end'] = True

    for regexp in regexps.episode_end:
      yield logcat_thread.EventListener(
          regexp=re.compile(regexp or 'a^'), handler_fn=_episode_end_handler
      )

  def _process_extra(self, extra_name: str, extra: Sequence[int | float]):
    """Buffers one value of extra `extra_name` until the next RL step.

    At most `extras_max_buffer_size` values are kept per name once a buffer
    exists; the oldest value is dropped to make room for a new one.
    """
    extra = np.array(extra)
    with self._lock:
      latest_extras = self._latest_values['extra']
      if extra_name in latest_extras:
        # If latest extra is not flushed, append.
        if (
            len(latest_extras[extra_name])
            >= self._config.extras_max_buffer_size
        ):
          latest_extras[extra_name].pop(0)
        latest_extras[extra_name].append(extra)
      else:
        latest_extras[extra_name] = [extra]

  def _extras_listeners(
      self, regexps: task_pb2.LogParsingConfig.LogRegexps
  ) -> Iterable[logcat_thread.EventListener]:
    """Creates an iterable of extras listeners.

    The regexps must define `name` and `extra` named groups.
    """

    def _extras_handler(event: re.Pattern[str], match: re.Match[str]):
      del event
      extra_name = match.group('name')
      extra = match.group('extra')
      if extra:
        try:
          extra = ast.literal_eval(extra)
        except (
            ValueError,
            TypeError,
            SyntaxError,
            MemoryError,
            RecursionError,
        ):
          logging.exception('Could not parse extra: %s', extra)
          # Don't try to process the extra as text; that would probably crash.
          return
      else:
        # No extra value provided for boolean extra. Setting value to True.
        extra = 1
      self._process_extra(extra_name, extra)

    for regexp in regexps.extra:
      yield logcat_thread.EventListener(
          regexp=re.compile(regexp or 'a^'), handler_fn=_extras_handler
      )

  def _json_extras_listeners(
      self, regexps: task_pb2.LogParsingConfig.LogRegexps
  ) -> Iterable[logcat_thread.EventListener]:
    """Creates an iterable of JSON extras listeners.

    The regexps must define a `json_extra` named group.
    """

    def _json_extras_handler(event: re.Pattern[str], match: re.Match[str]):
      del event
      extra_data = match.group('json_extra')
      try:
        extra = dict(json.loads(extra_data))
      except (ValueError, TypeError):
        # ValueError: malformed JSON. TypeError: valid JSON that `dict()`
        # cannot convert (e.g. a bare number) or a non-matching group (None).
        logging.error('JSON string could not be parsed: %s', extra_data)
        return
      for extra_name, extra_value in extra.items():
        self._process_extra(extra_name, extra_value)

    for regexp in regexps.json_extra:
      yield logcat_thread.EventListener(
          regexp=re.compile(regexp or 'a^'), handler_fn=_json_extras_handler
      )
================================================
FILE: android_env/components/task_manager_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.components.task_manager.py."""
import json
from unittest import mock
from absl.testing import absltest
from android_env.components import adb_call_parser as adb_call_parser_lib
from android_env.components import dumpsys_thread
from android_env.components import log_stream
from android_env.components import logcat_thread
from android_env.components import setup_step_interpreter
from android_env.components import task_manager
from android_env.proto import task_pb2
import numpy as np
class TaskManagerTest(absltest.TestCase):
def setUp(self):
  super().setUp()
  self.addCleanup(mock.patch.stopall)  # Disable previous patches.

  self._setup_step_interpreter = mock.create_autospec(
      setup_step_interpreter.SetupStepInterpreter)
  self._dumpsys_thread = mock.create_autospec(dumpsys_thread.DumpsysThread)
  self._logcat_thread = mock.create_autospec(logcat_thread.LogcatThread)
  self._log_stream = mock.create_autospec(log_stream.LogStream)

  # Make each component constructor return its autospec'd fake.
  replacements = (
      (setup_step_interpreter, 'SetupStepInterpreter',
       self._setup_step_interpreter),
      (dumpsys_thread, 'DumpsysThread', self._dumpsys_thread),
      (logcat_thread, 'LogcatThread', self._logcat_thread),
      (log_stream, 'LogStream', self._log_stream),
  )
  for module, attr_name, fake in replacements:
    mock.patch.object(module, attr_name, return_value=fake).start()
def test_start(self):
  """start() creates the logcat, dumpsys and setup-step components."""
  task_mgr = task_manager.TaskManager(task=task_pb2.Task())
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  components = (
      task_mgr._logcat_thread,
      task_mgr._dumpsys_thread,
      task_mgr._setup_step_interpreter,
  )
  for component in components:
    self.assertIsNotNone(component)
def test_setup_task(self):
  """setup_task() runs the setup steps through the interpreter exactly once."""
  task_mgr = task_manager.TaskManager(task=task_pb2.Task())
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  task_mgr.setup_task()
  self._setup_step_interpreter.interpret.assert_called_once()
def test_step_count(self):
  """episode_steps counts rl_step() calls and is zeroed by rl_reset()."""
  task_mgr = task_manager.TaskManager(task=task_pb2.Task())
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  task_mgr.rl_reset(observation={})
  self.assertEqual(task_mgr.stats()['episode_steps'], 0)
  for expected_steps in (1, 2):
    task_mgr.rl_step(observation={})
    self.assertEqual(task_mgr.stats()['episode_steps'], expected_steps)

  task_mgr.rl_reset(observation={})
  self.assertEqual(task_mgr.stats()['episode_steps'], 0)
def test_get_current_reward(self):
  """A reward line seen by logcat is reported in the next timestep."""

  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    # Instead of registering the listener, immediately feed it a reward line
    # if (and only if) its regexp matches one.
    regexp = event_listener.regexp
    match = regexp.match('Reward: 123.0')
    if match is not None:
      event_listener.handler_fn(regexp, match)

  task = task_pb2.Task()
  task.log_parsing_config.log_regexps.reward.extend(
      ['^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$'])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(observation={'pixels': np.array([1, 2, 3])})

  self.assertEqual(timestep.reward, 123.0)
  np.testing.assert_equal(timestep.observation['pixels'], np.array([1, 2, 3]))
def test_reward_event(self):
  """Reward events and numeric rewards are summed into one step reward."""

  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    # Instead of registering the listener, immediately feed it every sample
    # line its regexp matches.
    for log_line in ('foo_1', 'foo_2', 'Reward: 2.0'):
      match = event_listener.regexp.match(log_line)
      if match:
        event_listener.handler_fn(event_listener.regexp, match)

  task = task_pb2.Task()
  log_regexps = task.log_parsing_config.log_regexps
  log_regexps.reward_event.extend([
      task_pb2.LogParsingConfig.LogRegexps.RewardEvent(
          event='foo_1', reward=5.0),
      task_pb2.LogParsingConfig.LogRegexps.RewardEvent(
          event='foo_2', reward=-1.0),
  ])
  log_regexps.reward.extend(['^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$'])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(observation={'pixels': np.array([1, 2, 3])})

  # 5.0 (foo_1) - 1.0 (foo_2) + 2.0 (Reward line).
  self.assertEqual(timestep.reward, 6.0)
def test_get_current_reward_via_score(self):
  """Rewards are derived from differences between successive scores."""

  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    event = event_listener.regexp
    if event.match('score: 200.0') is None:  # Ignore non-score listeners.
      return
    # Scores are accumulated by their differences, so a subsequent lower score
    # means that the final reward decreases.
    for score_line in ('score: 200.0', 'score: 185'):
      event_listener.handler_fn(event, event.match(score_line))

  task = task_pb2.Task()
  task.log_parsing_config.log_regexps.score = (
      '^score: ([-+]?[0-9]*\\.?[0-9]*)$')
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(observation={'pixels': np.array([1, 2, 3])})

  # (200 - 0) + (185 - 200) == 185.
  self.assertEqual(timestep.reward, 185.0)
def test_get_current_extras(self):
  """Extras seen by logcat are stacked per name in the next timestep."""
  # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
  # right away.
  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    # Check that the event matches what's expected.
    event = event_listener.regexp
    match = event.match('extra: some_extra [1, 2]')
    if match is None:  # Ignore events that are not extras.
      return

    # Emit events.
    fn = event_listener.handler_fn
    fn(event, event.match('extra: an_extra [1, 2, 3]'))
    fn(event, event.match('extra: an_extra [4, 5, 6]'))
    fn(event, event.match('extra: another_extra 0.5'))
    fn(event, event.match('extra: multi_dimension_extra [[9,8,7],[6,5,4]]'))
    fn(event, event.match('extra: boolean_extra'))

  # Setup the task and trigger the listener.
  task = task_pb2.Task()
  # The extras handler reads the `name` and `extra` named groups, so the
  # regexp must define both (an unnamed `(?P...)` is not even a valid regexp).
  task.log_parsing_config.log_regexps.extra.extend([
      '^extra: (?P<name>[^ ]*)[ ]?(?P<extra>.*)$'
  ])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(
      observation={
          'pixels': np.array([1, 2, 3]),
      })

  # Check expectations.
  self.assertIn('extras', timestep.observation)
  extras = timestep.observation['extras']
  np.testing.assert_almost_equal([[1, 2, 3], [4, 5, 6]],
                                 extras.get('an_extra'))
  np.testing.assert_almost_equal([0.5], extras.get('another_extra'))
  np.testing.assert_almost_equal([[[9, 8, 7], [6, 5, 4]]],
                                 extras.get('multi_dimension_extra'))
  # An extra logged without a value is recorded as 1.
  np.testing.assert_equal([1], extras.get('boolean_extra'))
def test_get_current_extras_json_format(self):
  """JSON extras are split into per-key extras and stacked per name."""
  # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
  # right away.
  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    # Check that the event matches what's expected.
    event = event_listener.regexp
    match = event.match('json_extra: {}')
    if match is None:  # Ignore events that are not extras.
      return

    # Emit events.
    extra = {
        'extra_scalar': 0,
        'extra_list': [1, 2, 3, 4],
        'extra_dict': {
            'foo': 'bar'
        },
        'extra_string': 'a_string'
    }
    extra_update = {'extra_string': 'a_new_string', 'extra_float': 0.6}
    fn = event_listener.handler_fn
    fn(event, event.match(f'json_extra: {json.dumps(extra)}'))
    fn(event, event.match(f'json_extra: {json.dumps(extra_update)}'))

  # Setup the task and trigger the listener.
  task = task_pb2.Task()
  # The JSON extras handler reads the `json_extra` named group, so the regexp
  # must define it (an unnamed `(?P...)` is not even a valid regexp).
  task.log_parsing_config.log_regexps.json_extra.extend([
      '^json_extra: (?P<json_extra>.*)$'
  ])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(
      observation={
          'pixels': np.array([1, 2, 3]),
      })

  # Check expectations.
  self.assertIn('extras', timestep.observation)
  extras = timestep.observation['extras']
  expected_extra = {
      'extra_scalar': [0],
      'extra_list': [[1, 2, 3, 4]],
      'extra_dict': [{
          'foo': 'bar'
      }],
      'extra_string': ['a_string', 'a_new_string'],
      'extra_float': [0.6]
  }
  np.testing.assert_almost_equal(
      expected_extra.get('extra_scalar'), extras.get('extra_scalar'))
  np.testing.assert_almost_equal(
      expected_extra.get('extra_list'), extras.get('extra_list'))
  np.testing.assert_equal(
      expected_extra.get('extra_string'), extras.get('extra_string'))
  np.testing.assert_almost_equal(
      expected_extra.get('extra_float'), extras.get('extra_float'))
  np.testing.assert_equal(
      expected_extra.get('extra_dict'), extras.get('extra_dict'))
def test_get_current_extras_failed_to_parse(self):
  """Malformed extras are dropped without affecting well-formed ones."""
  # Replace `LogcatThread.add_event_listener` with one that simply calls `fn`
  # right away.
  def my_add_ev_listener(event_listener: logcat_thread.EventListener):
    # Check that the event matches what's expected.
    event = event_listener.regexp
    match = event.match('extra: some_extra [1, 2]')
    if match is None:  # Ignore events that are not extras.
      return

    # Emit events.
    fn = event_listener.handler_fn
    fn(event, event.match('extra: extra_with_malformed_1 [1]'))
    fn(event, event.match('extra: extra_with_malformed_1 [\'this is \\ bad]'))
    fn(event, event.match('extra: extra_with_malformed_1 [2]'))
    fn(event, event.match('extra: extra_with_malformed_2 [\'this is bad]'))
    fn(event, event.match('extra: extra_with_malformed_2 [1]'))
    fn(event, event.match('extra: extra_malformed_only [_very_bad_news]'))

  # Setup the task and trigger the listener.
  task = task_pb2.Task()
  # The extras handler reads the `name` and `extra` named groups, so the
  # regexp must define both (an unnamed `(?P...)` is not even a valid regexp).
  task.log_parsing_config.log_regexps.extra.extend([
      '^extra: (?P<name>[^ ]*)[ ]?(?P<extra>.*)$'
  ])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = my_add_ev_listener
  adb_call_parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: adb_call_parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(
      observation={
          'pixels': np.array([1, 2, 3]),
      })

  # Check expectations.
  self.assertIn('extras', timestep.observation)
  extras = timestep.observation['extras']
  np.testing.assert_almost_equal(extras.get('extra_with_malformed_1'),
                                 [[1], [2]])
  np.testing.assert_almost_equal(extras.get('extra_with_malformed_2'), [[1]])
  self.assertNotIn('extra_malformed_only', extras)
def test_multi_log_regexp(self):
  """With two reward regexps, the one matching the log line sets the reward."""

  def fake_add_event_listener(listener: logcat_thread.EventListener):
    # Feed one reward line to every registered listener immediately; only a
    # listener whose regexp matches it has its handler invoked.
    pattern = listener.regexp
    reward_match = pattern.match('Reward_2: 123.0')
    if reward_match is not None:
      listener.handler_fn(pattern, reward_match)

  reward_regexps = [
      '^[Rr]eward_1: ([-+]?[0-9]*\\.?[0-9]*)$',
      '^[Rr]eward_2: ([-+]?[0-9]*\\.?[0-9]*)$',
  ]
  task = task_pb2.Task()
  task.log_parsing_config.log_regexps.reward.extend(reward_regexps)
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = fake_add_event_listener
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(observation={'pixels': np.array([1, 2, 3])})

  self.assertEqual(timestep.reward, 123.0)
def test_multi_reward_regexp(self):
  """Rewards matched by different regexps within one step are summed."""

  def fake_add_event_listener(listener: logcat_thread.EventListener):
    # Feed both reward lines to each listener right away; a listener only
    # handles the lines that its own regexp matches.
    pattern = listener.regexp
    for line in ('Reward_1: 5.0', 'Reward_2: 10.0'):
      line_match = pattern.match(line)
      if line_match:
        listener.handler_fn(pattern, line_match)

  reward_regexps = [
      '^[Rr]eward_1: ([-+]?[0-9]*\\.?[0-9]*)$',
      '^[Rr]eward_2: ([-+]?[0-9]*\\.?[0-9]*)$',
  ]
  task = task_pb2.Task()
  task.log_parsing_config.log_regexps.reward.extend(reward_regexps)
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = fake_add_event_listener
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(observation={'pixels': np.array([1, 2, 3])})

  # 5.0 + 10.0
  self.assertEqual(timestep.reward, 15.0)
def test_determine_transition_fn(self):
  """An `episode_end` log line turns the next timestep into a LAST one."""

  def fake_add_event_listener(listener: logcat_thread.EventListener):
    # Immediately deliver the episode-end line; listeners whose regexp does
    # not match it are left untouched.
    pattern = listener.regexp
    if (done_match := pattern.match('I am done!')) is not None:
      listener.handler_fn(pattern, done_match)

  task = task_pb2.Task()
  task.log_parsing_config.log_regexps.episode_end.extend(['I am done!'])
  task_mgr = task_manager.TaskManager(task=task)
  self._logcat_thread.add_event_listener.side_effect = fake_add_event_listener
  parser = mock.create_autospec(adb_call_parser_lib.AdbCallParser)
  task_mgr.start(lambda: parser, log_stream=self._log_stream)
  task_mgr.setup_task()

  timestep = task_mgr.rl_step(observation={'pixels': np.array([1, 2, 3])})

  self.assertTrue(timestep.last())
if __name__ == '__main__':
  # Run all test cases in this module with absl's test runner.
  absltest.main()
================================================
FILE: android_env/env_interface.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract AndroidEnv interface.
AndroidEnv is a standard dm_env.Environment instance, but it also offers a few
extra methods that clients may use for extended functionality.
"""
import abc
from typing import Any
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
import dm_env
from dm_env import specs
import numpy as np
class AndroidEnvInterface(dm_env.Environment, metaclass=abc.ABCMeta):
  """Pure virtual interface for AndroidEnv implementations.

  Concrete subclasses must provide the core `dm_env.Environment` API (specs,
  `reset`, `step` and `close`). Every other method is an optional extension:
  most have inert defaults, while `load_state()` and `save_state()` raise
  `NotImplementedError` unless overridden.
  """

  # ---- Core dm_env.Environment API: must be overridden. ----

  @abc.abstractmethod
  def action_spec(self) -> dict[str, specs.Array]:
    """Returns a mapping from action component name to its spec."""

  @abc.abstractmethod
  def observation_spec(self) -> dict[str, specs.Array]:
    """Returns a mapping from observation component name to its spec."""

  @abc.abstractmethod
  def reset(self) -> dm_env.TimeStep:
    """Starts a new episode and returns its first `TimeStep`."""

  @abc.abstractmethod
  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Applies `action` to the environment and returns the resulting step."""

  @abc.abstractmethod
  def close(self) -> None:
    """Releases any resources held by the environment."""

  # ---- Optional AndroidEnv extensions. ----

  def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
    """Returns task-provided extra information (empty by default).

    Args:
      latest_only: Whether implementations should return only the most recent
        value of each extra. The default implementation ignores it.
    """
    del latest_only  # Unused by this default implementation.
    return dict()

  @property
  def raw_action(self) -> Any:
    """Returns the most recent action (None by default)."""
    return None

  @property
  def raw_observation(self) -> Any:
    """Returns the most recent observation (None by default)."""
    return None

  def stats(self) -> dict[str, Any]:
    """Returns implementation-specific statistics (empty by default)."""
    return dict()

  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
    """Runs `call`; the default implementation returns an empty response."""
    del call  # Unused by this default implementation.
    return adb_pb2.AdbResponse()

  def load_state(
      self, request: state_pb2.LoadStateRequest
  ) -> state_pb2.LoadStateResponse:
    """Loads a previously saved state.

    Args:
      request: Parameters describing which state to load and how.

    Returns:
      A `LoadStateResponse` carrying the status, an error message when
      applicable, and any other relevant information.

    Raises:
      NotImplementedError: Always, unless a subclass overrides this method.
    """
    del request  # Unused by this default implementation.
    raise NotImplementedError('This environment does not support loading state')

  def save_state(
      self, request: state_pb2.SaveStateRequest
  ) -> state_pb2.SaveStateResponse:
    """Saves the current state.

    Args:
      request: Parameters describing which state to save and how.

    Returns:
      A `SaveStateResponse` carrying the status, an error message when
      applicable, and any other relevant information.

    Raises:
      NotImplementedError: Always, unless a subclass overrides this method.
    """
    del request  # Unused by this default implementation.
    raise NotImplementedError('This environment does not support saving state')
================================================
FILE: android_env/environment.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Android environment implementation."""
from typing import Any
from absl import logging
from android_env import env_interface
from android_env.components import adb_call_parser
from android_env.components import coordinator as coordinator_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators import base_simulator
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
import dm_env
import numpy as np
class AndroidEnv(env_interface.AndroidEnvInterface):
  """An RL environment that interacts with Android apps."""

  def __init__(
      self,
      simulator: base_simulator.BaseSimulator,
      coordinator: coordinator_lib.Coordinator,
      task_manager: task_manager_lib.TaskManager,
  ):
    """Initializes the state of this AndroidEnv object.

    Args:
      simulator: Used directly by `load_state()`/`save_state()` and to create
        ADB controllers and log streams when the task manager is restarted.
      coordinator: Drives `reset()`/`step()` and provides the specs.
      task_manager: Queried by `stats()` and restarted by `load_state()`.
    """
    self._simulator = simulator
    self._coordinator = coordinator
    self._task_manager = task_manager
    self._latest_action = {}
    self._latest_observation = {}
    self._latest_extras = {}
    self._reset_next_step = True
    self._is_closed = False

    logging.info('Action spec: %s', self.action_spec())
    logging.info('Observation spec: %s', self.observation_spec())

  def __del__(self) -> None:
    # `close()` tolerates a partially-constructed instance, so this is safe
    # even when `__init__` did not run to completion.
    self.close()

  # Methods required by dm_env.Environment.

  def action_spec(self) -> dict[str, dm_env.specs.Array]:
    return self._coordinator.action_spec()

  def observation_spec(self) -> dict[str, dm_env.specs.Array]:
    return self._coordinator.observation_spec()

  def reset(self) -> dm_env.TimeStep:
    """Resets the environment for a new RL episode."""
    logging.info('Resetting AndroidEnv...')

    # Execute a reset. Timestep will be of type FIRST.
    timestep = self._process_timestep(self._coordinator.rl_reset())

    self._latest_action = {}
    self._reset_next_step = False

    logging.info('Done resetting AndroidEnv.')
    logging.info('************* NEW EPISODE *************')
    return timestep

  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Takes a step in the environment.

    If the previous step ended the episode, this performs a `reset()` instead
    and `action` is ignored.
    """
    # Check if it's time to reset the episode.
    if self._reset_next_step:
      return self.reset()

    # Execute selected action.
    timestep = self._process_timestep(self._coordinator.rl_step(action))

    self._latest_action = action.copy()
    if timestep.last():
      self._reset_next_step = True
      logging.info('************* END OF EPISODE *************')
    return timestep

  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    """Caches the observation/extras of `timestep`, or fills a missing one.

    Extras are removed from the observation and stored separately so they are
    only exposed through `task_extras()`. When the observation is None, a
    copy of the latest cached observation is substituted instead.

    Args:
      timestep: The timestep returned by the coordinator.

    Returns:
      `timestep`, or a copy of it with the observation replaced.
    """
    if timestep.observation is None:
      # If the observation is None, we return the latest observation again.
      return timestep._replace(observation=self._latest_observation.copy())
    # Default to no extras rather than raising when a coordinator omits them.
    self._latest_extras = timestep.observation.pop('extras', {})
    self._latest_observation = timestep.observation.copy()
    return timestep

  def close(self) -> None:
    """Cleans up running processes, threads and local files.

    Idempotent. Also safe on a partially-constructed instance (e.g. when
    reached from `__del__` after `__init__` failed early).
    """
    if getattr(self, '_is_closed', False):
      return
    logging.info('Cleaning up AndroidEnv...')
    if hasattr(self, '_coordinator'):
      self._coordinator.close()
    logging.info('Done cleaning up AndroidEnv.')
    self._is_closed = True

  # Extensions provided by AndroidEnv.

  def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
    """Returns copies of the latest task extras.

    Args:
      latest_only: If True, return only the last entry of each extra's
        values; otherwise return all of them.
    """
    task_extras = {}  # Build copies to avoid handing out internal objects.
    for name, values in self._latest_extras.items():
      values_copy = np.copy(values)
      task_extras[name] = values_copy[-1] if latest_only else values_copy
    return task_extras

  @property
  def raw_action(self):
    return self._latest_action.copy()

  @property
  def raw_observation(self):
    return self._latest_observation.copy()

  def stats(self) -> dict[str, Any]:
    """Returns coordinator stats merged with task manager stats."""
    coordinator_stats = self._coordinator.stats()
    task_manager_stats = self._task_manager.stats()
    return coordinator_stats | task_manager_stats

  def execute_adb_call(self, call: adb_pb2.AdbRequest) -> adb_pb2.AdbResponse:
    return self._coordinator.execute_adb_call(call)

  def load_state(
      self, request: state_pb2.LoadStateRequest
  ) -> state_pb2.LoadStateResponse:
    """Loads a state.

    The task manager is stopped while the simulator loads the state, then
    restarted against fresh ADB and log-stream connections.

    Args:
      request: A `LoadStateRequest` containing any parameters necessary to
        specify how/what state to load.

    Returns:
      A `LoadStateResponse` containing the status, error message (if
      applicable), and any other relevant information.
    """
    self._task_manager.stop()
    response = self._simulator.load_state(request)
    self._task_manager.start(
        adb_call_parser_factory=lambda: adb_call_parser.AdbCallParser(
            self._simulator.create_adb_controller()
        ),
        log_stream=self._simulator.create_log_stream(),
    )
    return response

  def save_state(
      self, request: state_pb2.SaveStateRequest
  ) -> state_pb2.SaveStateResponse:
    """Saves a state.

    Args:
      request: A `SaveStateRequest` containing any parameters necessary to
        specify how/what state to save.

    Returns:
      A `SaveStateResponse` containing the status, error message (if
      applicable), and any other relevant information.
    """
    return self._simulator.save_state(request)
================================================
FILE: android_env/environment_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for AndroidEnv."""
from unittest import mock
from absl.testing import absltest
from android_env import environment
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators import base_simulator
from android_env.components.simulators.fake import fake_simulator
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
import dm_env
import numpy as np
def _create_mock_coordinator() -> coordinator_lib.Coordinator:
  """Returns an autospec'd Coordinator mock that reports fixed specs.

  The pixel spec is 123x456x3, which matches `_create_fake_simulator()`.
  """
  action_spec = {
      'action_type': dm_env.specs.DiscreteArray(num_values=3),
      'touch_position': dm_env.specs.BoundedArray(
          shape=(2,), dtype=np.float32, minimum=0.0, maximum=1.0
      ),
  }
  observation_spec = {
      'pixels': dm_env.specs.Array(shape=(123, 456, 3), dtype=np.uint8),
      'timedelta': dm_env.specs.Array(shape=(), dtype=np.int64),
      'orientation': dm_env.specs.Array(shape=(4,), dtype=np.uint8),
  }
  fake_coordinator = mock.create_autospec(coordinator_lib.Coordinator)
  fake_coordinator.action_spec.return_value = action_spec
  fake_coordinator.observation_spec.return_value = observation_spec
  return fake_coordinator
def _create_fake_simulator() -> fake_simulator.FakeSimulator:
  """Returns a FakeSimulator with a 123x456 screen."""
  simulator_config = config_classes.FakeSimulatorConfig(
      screen_dimensions=(123, 456)
  )
  return fake_simulator.FakeSimulator(config=simulator_config)
class AndroidEnvTest(absltest.TestCase):
  """Unit tests for `environment.AndroidEnv`."""

  def test_specs(self):
    simulator = _create_fake_simulator()
    coordinator = _create_mock_coordinator()
    task_manager = mock.create_autospec(task_manager_lib.TaskManager)
    env = environment.AndroidEnv(
        simulator=simulator, coordinator=coordinator, task_manager=task_manager
    )

    # Check action spec.
    self.assertNotEmpty(env.action_spec())
    self.assertIn('action_type', env.action_spec())
    self.assertIsInstance(env.action_spec()['action_type'],
                          dm_env.specs.DiscreteArray)
    self.assertIn('touch_position', env.action_spec())
    self.assertIsInstance(env.action_spec()['touch_position'],
                          dm_env.specs.BoundedArray)

    # Check observation spec.
    self.assertNotEmpty(env.observation_spec())
    self.assertIn('pixels', env.observation_spec())
    self.assertIsInstance(env.observation_spec()['pixels'], dm_env.specs.Array)
    # The `pixels` entry in the observation spec should match the screen size of
    # the simulator with three color channels (RGB).
    self.assertEqual(env.observation_spec()['pixels'].shape, (123, 456, 3))
    self.assertIn('timedelta', env.observation_spec())
    self.assertIsInstance(env.observation_spec()['timedelta'],
                          dm_env.specs.Array)
    # The `timedelta` should be a scalar.
    self.assertEqual(env.observation_spec()['timedelta'].shape, ())
    self.assertIn('orientation', env.observation_spec())
    # The `orientation` should be a one-hot vector with four dimensions.
    self.assertIsInstance(env.observation_spec()['orientation'],
                          dm_env.specs.Array)
    self.assertEqual(env.observation_spec()['orientation'].shape, (4,))

  def test_reset_and_step(self):
    simulator = _create_fake_simulator()
    # The specs set by `_create_mock_coordinator()` are sufficient here.
    coordinator = _create_mock_coordinator()
    task_manager = mock.create_autospec(task_manager_lib.TaskManager)
    env = environment.AndroidEnv(
        simulator=simulator, coordinator=coordinator, task_manager=task_manager
    )
    coordinator.rl_reset.return_value = dm_env.TimeStep(
        step_type=dm_env.StepType.FIRST,
        reward=0.0,
        discount=0.0,
        observation={
            'pixels': np.random.rand(987, 654, 3),
            'timedelta': 123456,
            'orientation': np.array((1, 0, 0, 0)),
            'extras': {
                'click': np.array([[246]], dtype=np.int64)
            }
        },
    )

    ts = env.reset()
    self.assertIsInstance(ts, dm_env.TimeStep)
    # After a `reset()` the TimeStep should follow some expectations.
    self.assertTrue(ts.first())
    self.assertEqual(ts.reward, 0.0)
    self.assertEqual(ts.discount, 0.0)
    obs = ts.observation
    self.assertIn('pixels', obs)
    self.assertEqual(obs['pixels'].shape, (987, 654, 3))
    self.assertIn('timedelta', obs)
    self.assertEqual(obs['timedelta'], 123456)
    self.assertIn('orientation', obs)
    self.assertEqual(obs['orientation'].shape, (4,))
    np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))

    # Extras should also be provided. Compare arrays element-wise: assertEqual
    # on ndarrays only works by accident for single-element arrays.
    extras = env.task_extras()
    self.assertIn('click', extras)
    np.testing.assert_array_equal(
        extras['click'], np.array([246], dtype=np.int64)
    )

    coordinator.stats.return_value = {'my_measurement': 135}
    task_manager.stats.return_value = {'another_measurement': 79}

    # Step again in the environment and check expectations again.
    pixels = np.random.rand(987, 654, 3)
    latest_observation = {
        'pixels': pixels,
        'timedelta': 123456,
        'orientation': np.array((1, 0, 0, 0)),
        'extras': {
            'click': np.array([[246]], dtype=np.int64)
        }
    }
    coordinator.rl_step.return_value = dm_env.transition(
        reward=0.0,
        discount=0.0,
        observation=latest_observation,
    )
    ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
    self.assertIsInstance(ts, dm_env.TimeStep)
    # The StepType now should NOT be FIRST.
    self.assertFalse(ts.first())
    self.assertEqual(ts.reward, 0.0)
    self.assertEqual(ts.discount, 0.0)
    obs = ts.observation
    self.assertIn('pixels', obs)
    self.assertEqual(obs['pixels'].shape, (987, 654, 3))
    self.assertIn('timedelta', obs)
    self.assertEqual(obs['timedelta'], 123456)
    self.assertIn('orientation', obs)
    self.assertEqual(obs['orientation'].shape, (4,))
    np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))

    # Extras should still be provided.
    extras = env.task_extras()
    self.assertIn('click', extras)
    np.testing.assert_array_equal(
        extras['click'], np.array([246], dtype=np.int64)
    )

    # At this point these methods and properties should return something.
    self.assertNotEmpty(env.stats())
    self.assertNotEmpty(env.raw_observation)
    self.assertNotIn('extras', env.raw_observation)
    self.assertNotEmpty(env.raw_action)

    # If the observation is None, we want to return the latest observation.
    coordinator.rl_step.return_value = dm_env.truncation(
        reward=0.0,
        observation=None,
    )
    ts = env.step({'action_type': 1, 'touch_position': (10, 20)})
    self.assertIsInstance(ts, dm_env.TimeStep)
    # Assert the observation matches the latest observation.
    obs = ts.observation
    self.assertIn('pixels', obs)
    self.assertEqual(obs['pixels'].shape, (987, 654, 3))
    np.testing.assert_equal(obs['pixels'], pixels)
    self.assertIn('timedelta', obs)
    self.assertEqual(obs['timedelta'], 123456)
    self.assertIn('orientation', obs)
    self.assertEqual(obs['orientation'].shape, (4,))
    np.testing.assert_equal(obs['orientation'], (1, 0, 0, 0))

  def test_adb_call(self):
    simulator = _create_fake_simulator()
    coordinator = _create_mock_coordinator()
    task_manager = mock.create_autospec(task_manager_lib.TaskManager)
    env = environment.AndroidEnv(
        simulator=simulator, coordinator=coordinator, task_manager=task_manager
    )
    call = adb_pb2.AdbRequest(
        force_stop=adb_pb2.AdbRequest.ForceStop(package_name='blah'))
    expected_response = adb_pb2.AdbResponse(
        status=adb_pb2.AdbResponse.Status.OK)
    coordinator.execute_adb_call.return_value = expected_response

    response = env.execute_adb_call(call)

    self.assertEqual(response, expected_response)
    coordinator.execute_adb_call.assert_called_once_with(call)

  def test_load_state(self):
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    coordinator = _create_mock_coordinator()
    task_manager = mock.create_autospec(task_manager_lib.TaskManager)
    env = environment.AndroidEnv(
        simulator=simulator, coordinator=coordinator, task_manager=task_manager
    )
    expected_response = state_pb2.LoadStateResponse(
        status=state_pb2.LoadStateResponse.Status.OK
    )
    request = state_pb2.LoadStateRequest(args={'foo': 'bar'})
    simulator.load_state.return_value = expected_response

    response = env.load_state(request)

    self.assertEqual(response, expected_response)
    simulator.load_state.assert_called_once_with(request)
    # The task manager must be restarted around the state load.
    task_manager.stop.assert_called_once()
    task_manager.start.assert_called_once()

  def test_save_state(self):
    simulator = mock.create_autospec(base_simulator.BaseSimulator)
    coordinator = _create_mock_coordinator()
    task_manager = mock.create_autospec(task_manager_lib.TaskManager)
    env = environment.AndroidEnv(
        simulator=simulator, coordinator=coordinator, task_manager=task_manager
    )
    expected_response = state_pb2.SaveStateResponse(
        status=state_pb2.SaveStateResponse.Status.OK
    )
    request = state_pb2.SaveStateRequest(args={'foo': 'bar'})
    simulator.save_state.return_value = expected_response

    response = env.save_state(request)

    self.assertEqual(response, expected_response)
    simulator.save_state.assert_called_once_with(request)

  def test_double_close(self):
    simulator = _create_fake_simulator()
    coordinator = _create_mock_coordinator()
    task_manager = mock.create_autospec(task_manager_lib.TaskManager)
    env = environment.AndroidEnv(
        simulator=simulator, coordinator=coordinator, task_manager=task_manager
    )

    env.close()
    env.close()

    # The second `close()` must be a no-op.
    coordinator.close.assert_called_once()
if __name__ == '__main__':
  # Run all test cases in this module with absl's test runner.
  absltest.main()
================================================
FILE: android_env/loader.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Function for loading AndroidEnv."""
import os
from absl import logging
from android_env import environment
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import device_settings as device_settings_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators.emulator import emulator_simulator
from android_env.components.simulators.fake import fake_simulator
from android_env.proto import task_pb2
from google.protobuf import text_format
def _load_task(task_config: config_classes.TaskConfig) -> task_pb2.Task:
"""Returns the task according to `task_config`."""
task = task_pb2.Task()
match task_config:
case config_classes.FilesystemTaskConfig():
with open(task_config.path, 'r') as proto_file:
text_format.Parse(proto_file.read(), task)
case _:
logging.error('Unsupported TaskConfig: %r', task_config)
return task
def load(config: config_classes.AndroidEnvConfig) -> environment.AndroidEnv:
"""Loads an AndroidEnv instance."""
task = _load_task(config.task)
task_manager = task_manager_lib.TaskManager(task)
match config.simulator:
case config_classes.EmulatorConfig():
_process_emulator_launcher_config(config.simulator)
simulator = emulator_simulator.EmulatorSimulator(config=config.simulator)
case config_classes.FakeSimulatorConfig():
simulator = fake_simulator.FakeSimulator(config=config.simulator)
case _:
raise ValueError('Unsupported simulator config: {config.simulator}')
device_settings = device_settings_lib.DeviceSettings(simulator)
coordinator = coordinator_lib.Coordinator(
simulator, task_manager, device_settings
)
return environment.AndroidEnv(
simulator=simulator, coordinator=coordinator, task_manager=task_manager
)
def _process_emulator_launcher_config(
    emulator_config: config_classes.EmulatorConfig,
) -> None:
  """Expands `~` in the emulator and ADB paths of `emulator_config` in place.

  Touches `android_avd_home`, `android_sdk_root` and `emulator_path` on the
  launcher config, and `adb_path` on the ADB controller config.
  """
  launcher = emulator_config.emulator_launcher
  for field_name in ('android_avd_home', 'android_sdk_root', 'emulator_path'):
    setattr(
        launcher, field_name, os.path.expanduser(getattr(launcher, field_name))
    )

  adb_config = emulator_config.adb_controller
  adb_config.adb_path = os.path.expanduser(adb_config.adb_path)
================================================
FILE: android_env/loader_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for loader."""
import builtins
import os
from unittest import mock
from absl.testing import absltest
from android_env import env_interface
from android_env import loader
from android_env.components import config_classes
from android_env.components import coordinator as coordinator_lib
from android_env.components import device_settings as device_settings_lib
from android_env.components import task_manager as task_manager_lib
from android_env.components.simulators.emulator import emulator_simulator
from android_env.components.simulators.fake import fake_simulator
from android_env.proto import task_pb2
class LoaderTest(absltest.TestCase):
  """Tests for `loader.load()`."""

  @staticmethod
  def _stub_open(mock_open, contents: str) -> None:
    """Makes the patched `open` yield a handle whose `read()` gives `contents`.

    `__enter__` is pointed at the patched `open` itself, so entering the
    `with` block returns the same handle whose `read` is stubbed.
    """
    handle = mock_open.return_value
    handle.__enter__ = mock_open
    handle.read.return_value = contents

  @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
  @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True)
  @mock.patch.object(device_settings_lib, 'DeviceSettings', autospec=True)
  @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
  @mock.patch.object(builtins, 'open', autospec=True)
  def test_load_emulator(
      self,
      mock_open,
      mock_coordinator,
      mock_device_settings,
      mock_simulator_class,
      mock_task_manager,
  ):
    # Arrange.
    self._stub_open(mock_open, '')
    launcher_config = config_classes.EmulatorLauncherConfig(
        avd_name='my_avd',
        android_avd_home='~/.android/avd',
        android_sdk_root='~/Android/Sdk',
        emulator_path='~/Android/Sdk/emulator/emulator',
        run_headless=False,
    )
    adb_config = config_classes.AdbControllerConfig(
        adb_path='~/Android/Sdk/platform-tools/adb',
    )
    config = config_classes.AndroidEnvConfig(
        task=config_classes.FilesystemTaskConfig(path='some/path/'),
        simulator=config_classes.EmulatorConfig(
            emulator_launcher=launcher_config,
            adb_controller=adb_config,
        ),
    )

    # Act.
    env = loader.load(config)

    # Assert: all `~` paths are expanded before the simulator is built.
    expected_simulator_config = config_classes.EmulatorConfig(
        emulator_launcher=config_classes.EmulatorLauncherConfig(
            avd_name='my_avd',
            android_avd_home=os.path.expanduser('~/.android/avd'),
            android_sdk_root=os.path.expanduser('~/Android/Sdk'),
            emulator_path=os.path.expanduser(
                '~/Android/Sdk/emulator/emulator'
            ),
            run_headless=False,
            gpu_mode='swangle_indirect',
        ),
        adb_controller=config_classes.AdbControllerConfig(
            adb_path=os.path.expanduser('~/Android/Sdk/platform-tools/adb'),
            adb_server_port=5037,
        ),
    )
    self.assertIsInstance(env, env_interface.AndroidEnvInterface)
    mock_simulator_class.assert_called_with(config=expected_simulator_config)
    mock_coordinator.assert_called_with(
        mock_simulator_class.return_value,
        mock_task_manager.return_value,
        mock_device_settings.return_value,
    )

  @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
  @mock.patch.object(fake_simulator, 'FakeSimulator', autospec=True)
  @mock.patch.object(device_settings_lib, 'DeviceSettings', autospec=True)
  @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
  @mock.patch.object(builtins, 'open', autospec=True)
  def test_load_fake_simulator(
      self,
      mock_open,
      mock_coordinator,
      mock_device_settings,
      mock_simulator_class,
      mock_task_manager,
  ):
    # Arrange.
    self._stub_open(mock_open, '')
    config = config_classes.AndroidEnvConfig(
        task=config_classes.FilesystemTaskConfig(path='some/path/'),
        simulator=config_classes.FakeSimulatorConfig(
            screen_dimensions=(1234, 5678)
        ),
    )

    # Act.
    env = loader.load(config)

    # Assert.
    self.assertIsInstance(env, env_interface.AndroidEnvInterface)
    mock_simulator_class.assert_called_with(
        config=config_classes.FakeSimulatorConfig(
            screen_dimensions=(1234, 5678)
        )
    )
    mock_coordinator.assert_called_with(
        mock_simulator_class.return_value,
        mock_task_manager.return_value,
        mock_device_settings.return_value,
    )

  @mock.patch.object(task_manager_lib, 'TaskManager', autospec=True)
  @mock.patch.object(emulator_simulator, 'EmulatorSimulator', autospec=True)
  @mock.patch.object(coordinator_lib, 'Coordinator', autospec=True)
  @mock.patch.object(builtins, 'open', autospec=True)
  def test_task(
      self, mock_open, mock_coordinator, mock_simulator, mock_task_manager
  ):
    # Arrange.
    del mock_coordinator, mock_simulator
    self._stub_open(mock_open, r'''
id: "fake_task"
name: "Fake Task"
description: "Task for testing loader."
max_episode_sec: 0
''')
    config = config_classes.AndroidEnvConfig(
        task=config_classes.FilesystemTaskConfig(path='some/path/'),
        simulator=config_classes.EmulatorConfig(
            emulator_launcher=config_classes.EmulatorLauncherConfig(
                avd_name='my_avd'
            ),
            adb_controller=config_classes.AdbControllerConfig(
                adb_path='~/Android/Sdk/platform-tools/adb',
            ),
        ),
    )

    # Act.
    env = loader.load(config)

    # Assert: the TaskManager receives the proto parsed from the file.
    expected_task = task_pb2.Task(
        id='fake_task',
        name='Fake Task',
        description='Task for testing loader.',
        max_episode_sec=0,
    )
    mock_task_manager.assert_called_with(expected_task)
    self.assertIsInstance(env, env_interface.AndroidEnvInterface)
if __name__ == '__main__':
  # Hand control to absltest, which parses flags and runs this module's tests.
  absltest.main()
================================================
FILE: android_env/proto/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/proto/a11y/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/proto/a11y/a11y.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
import "android_env/proto/a11y/android_accessibility_forest.proto";
option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";
// A service to send Accessibility information to a remote server.
//
// The client is assumed to be running inside an Android device (e.g. emulator
// or real device) while the server is assumed to be running outside (e.g. in a
// Python process).
service A11yService {
  // Sends a forest of Accessibility trees to a server.
  rpc SendForest(AndroidAccessibilityForest) returns (ForestResponse) {}

  // Sends an a11y event to a server.
  rpc SendEvent(EventRequest) returns (EventResponse) {}

  // Long-lived bidirectional communication between the client and the server.
  // The client streams `ClientToServer` messages (events or forests) and the
  // server streams `ServerToClient` messages (e.g. requests for a forest).
  rpc Bidi(stream ClientToServer) returns (stream ServerToClient) {}
}
// TODO(b/334952387): Remove `ForestResponse`, `EventRequest` and
// `EventResponse` once bidi communication is in place.

// Reply to `A11yService.SendForest`.
message ForestResponse {
  // The error, if any (presumably left empty on success).
  string error = 1;
}
// An Accessibility event, sent via `A11yService.SendEvent` or as the `event`
// payload of `ClientToServer`.
message EventRequest {
  // A single event as a dictionary of string keys to string values.
  // (The key/value types were missing, which is not valid proto syntax.)
  map<string, string> event = 1;
}
// Reply to `A11yService.SendEvent`.
message EventResponse {
  // The error, if any (presumably left empty on success).
  string error = 1;
}
// The message sent from the Android device to the server running outside of the
// device over the `A11yService.Bidi` stream.
message ClientToServer {
  // At most one payload is set per message.
  oneof payload {
    // An Accessibility event.
    EventRequest event = 1;
    // A forest of Accessibility trees (one tree per on-screen window).
    AndroidAccessibilityForest forest = 2;
  }
}
// The message sent from the server running outside of the device to the Android
// device over the `A11yService.Bidi` stream.
message ServerToClient {
  // A request to obtain the Accessibility forest. It carries no parameters.
  message GetA11yForest {}

  // At most one payload is set per message.
  oneof payload {
    GetA11yForest get_forest = 1;
  }
}
================================================
FILE: android_env/proto/a11y/android_accessibility_action.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";
// An Android Accessibility Action.
// Next index: 3 (the next unused field number).
message AndroidAccessibilityAction {
  // Required ID that uniquely identifies the action for this node.
  // Can be one of the standard action IDs listed in the documentation.
  // https://developer.android.com/reference/android/view/accessibility/AccessibilityNodeInfo.AccessibilityAction
  int32 id = 1;

  // Optional label describing what the action is.
  string label = 2;
}
================================================
FILE: android_env/proto/a11y/android_accessibility_forest.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
import "android_env/proto/a11y/android_accessibility_window_info.proto";
option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";
// A forest of Android accessibility trees. Each tree belongs to a single
// window.
// Next index: 2
message AndroidAccessibilityForest {
  // All of the windows present on screen. Each window carries its own tree.
  repeated AndroidAccessibilityWindowInfo windows = 1;
}
================================================
FILE: android_env/proto/a11y/android_accessibility_node_info.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
import "android_env/proto/a11y/android_accessibility_action.proto";
import "android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto";
import "android_env/proto/a11y/rect.proto";
option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";
// An Android AccessibilityNodeInfo.
// https://developer.android.com/reference/android/view/accessibility/AccessibilityNodeInfo
// Next index: 32
message AndroidAccessibilityNodeInfo {
  // Unique monotonically-increasing ID.
  int32 unique_id = 1;
  // The bounds of this node within the device's screen.
  ProtoRect bounds_in_screen = 2;
  // The name of the View class that created this node.
  string class_name = 3;
  // The content description of the node.
  string content_description = 4;
  // The hint text of the node.
  string hint_text = 5;
  // The name of the package this node comes from.
  string package_name = 6;
  // The text of this node.
  string text = 7;
  // The start index of the text selection.
  int64 text_selection_start = 8;
  // The end index of the text selection.
  int64 text_selection_end = 9;
  // The view ID resource name of the node.
  string view_id_resource_name = 10;
  // The ID of the window this node belongs to.
  int32 window_id = 11;
  // If true, this node can be checked.
  bool is_checkable = 12;
  // If true, this node is currently checked.
  bool is_checked = 13;
  // If true, this node (probably) responds to being clicked.
  bool is_clickable = 14;
  // If true, this node's text can be edited by the user.
  bool is_editable = 15;
  // If true, this node is enabled (e.g., if it is a button).
  bool is_enabled = 16;
  // If true, this node can be focused (e.g., a text input).
  bool is_focusable = 17;
  // If true, this node is currently focused.
  bool is_focused = 18;
  // If true, this node (probably) responds to being long pressed.
  bool is_long_clickable = 19;
  // If true, this node is a password input.
  bool is_password = 20;
  // If true, this node can be scrolled.
  bool is_scrollable = 21;
  // If true, this node is currently selected.
  bool is_selected = 22;
  // If true, this node is (probably) visible to the user.
  bool is_visible_to_user = 23;
  // List of actions that can be performed on this node.
  repeated AndroidAccessibilityAction actions = 24;
  // Ordered list of child IDs (i.e., unique_id).
  repeated int32 child_ids = 25 [packed = true];
  // List of clickable spans present in the node's text or content description.
  repeated AndroidAccessibilityNodeInfoClickableSpan clickable_spans = 26;
  // The depth of this node in the accessibility tree.
  int32 depth = 27;
  // Unique ID of the node that this node is declaring itself to be labeled by.
  int32 labeled_by_id = 28;
  // Unique ID of the node that this node is declaring itself to be a label
  // for.
  int32 label_for_id = 29;
  // The drawing order for the node.
  int32 drawing_order = 30;
  // The tooltip text of the node.
  string tooltip_text = 31;
}
================================================
FILE: android_env/proto/a11y/android_accessibility_node_info_clickable_span.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";
// A single clickable span found in the accessibility node's text.
// Next index: 6
message AndroidAccessibilityNodeInfoClickableSpan {
  // The source of the span (so the client can find the correct spannable string
  // in the node).
  // Next index: 3
  enum SpanSource {
    UNKNOWN_TYPE = 0;         // Catch-all type for forward compatibility.
    TEXT = 1;                 // The span is from node#getText.
    CONTENT_DESCRIPTION = 2;  // The span is from node#getContentDescription.
  }

  // The text of the span (a substring of the spannable string).
  string text = 1;
  // The URL attached to the span if specified.
  string url = 2;
  // The source of the span.
  SpanSource source = 3;
  // The index of the first character of the span in the spannable string.
  // The (exclusive) end of the span is `start + text.length()`.
  int32 start = 4;
  // The unique_id from the corresponding AndroidAccessibilityNodeInfo.
  int32 node_id = 5;
}
================================================
FILE: android_env/proto/a11y/android_accessibility_tree.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
import "android_env/proto/a11y/android_accessibility_node_info.proto";
option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";
// A tree (actually a graph) of Android accessibility nodes.
// Next index: 3
// NOTE(review): only field 1 is declared, so field 2 was presumably used and
// later removed. Consider adding `reserved 2;` so it is never reused — confirm.
message AndroidAccessibilityTree {
  // All of the nodes in the graph. The root node is the node whose ID is 0.
  // Edges are expressed through each node's `child_ids`.
  repeated AndroidAccessibilityNodeInfo nodes = 1;
}
================================================
FILE: android_env/proto/a11y/android_accessibility_window_info.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
import "android_env/proto/a11y/android_accessibility_tree.proto";
import "android_env/proto/a11y/rect.proto";
option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";
// An Android AccessibilityWindowInfo.
// https://developer.android.com/reference/android/view/accessibility/AccessibilityWindowInfo
// Next index: 12
message AndroidAccessibilityWindowInfo {
  // Type of the window.
  // Next index: 8
  enum WindowType {
    // The window type is an unknown value.
    UNKNOWN_TYPE = 0;
    // A standard application window.
    TYPE_APPLICATION = 1;
    // An IME window (e.g. GBoard).
    TYPE_INPUT_METHOD = 2;
    // A system window (e.g., a notification).
    TYPE_SYSTEM = 3;
    // An accessibility overlay.
    TYPE_ACCESSIBILITY_OVERLAY = 4;
    // A system window used to divide the screen in split-screen mode. This type
    // of window is present only in split-screen mode.
    TYPE_SPLIT_SCREEN_DIVIDER = 5;
    // Used to show the UI for window-based magnification.
    TYPE_MAGNIFICATION_OVERLAY = 6;
    // System window that has the function to control an associated window.
    TYPE_WINDOW_CONTROL = 7;
  }

  // Bounds of this window in the device's screen.
  ProtoRect bounds_in_screen = 1;
  // A unique ID identifying the display in which this window is shown.
  int32 display_id = 2;
  // Unique ID as defined by the Android platform.
  int32 id = 3;
  // Z-index of the window. Windows with a greater z-index appear in front of
  // those with a lesser z-index.
  int32 layer = 4;
  // The title of the window, if set.
  string title = 5;
  // The type of the window.
  WindowType window_type = 6;
  // If true, the window is currently accessibility-focused.
  bool is_accessibility_focused = 7;
  // If true, the window is currently active.
  bool is_active = 8;
  // If true, the window is currently focused.
  bool is_focused = 9;
  // If true, the window is in Picture in Picture mode.
  bool is_in_picture_in_picture_mode = 10;
  // The associated accessibility tree for this window.
  AndroidAccessibilityTree tree = 11;
}
================================================
FILE: android_env/proto/a11y/rect.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
option java_multiple_files = true;
option java_package = "com.google.androidenv.accessibilityforwarder";
// Proto representation of Android Rect.
// https://developer.android.com/reference/android/graphics/Rect
// Next index: 5
message ProtoRect {
  // Edge coordinates, mirroring the public fields of `android.graphics.Rect`.
  int32 left = 1;
  int32 top = 2;
  int32 right = 3;
  int32 bottom = 4;
}
================================================
FILE: android_env/proto/adb.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
// A single `adb` command to execute, selected via the `command` oneof below.
message AdbRequest {
  // Installs an APK into the simulator.
  message InstallApk {
    // A location in the filesystem.
    message Filesystem {
      string path = 1;
    }

    // A byte sequence of a single APK file.
    message Blob {
      // The serialized file as bytes.
      bytes contents = 1;
    }

    // Where the APK is read from.
    oneof location {
      Filesystem filesystem = 2;
      Blob blob = 6;
    }
  }

  message StartActivity {
    // The activity to start. Presumably in the same `package/package.Activity`
    // format as `AdbResponse.GetCurrentActivityResponse` — confirm.
    string full_activity = 1;
    // Extra arguments for the start command.
    repeated string extra_args = 2;
    // Whether to stop the current app before starting the activity.
    // Notice that if this option is `true`, the activity probably needs the
    // `android:launchMode="singleTop"` attribute in its `AndroidManifest.xml`,
    // otherwise intents may not be received by `onNewIntent()`. Please see more
    // info on `android:launchMode` at
    // https://developer.android.com/guide/topics/manifest/activity-element.
    bool force_stop = 3;
  }

  message SendBroadcast {
    // Action to send during the broadcast event.
    string action = 1;
    // Specify the component name with package name prefix to create an explicit
    // intent, such as com.example.app/.ExampleActivity (see -n specification at
    // https://developer.android.com/tools/adb#IntentSpec).
    string component = 2;
  }

  message UninstallPackage {
    string package_name = 1;
  }

  message ForceStop {
    string package_name = 1;
  }

  message Tap {
    // NOTE: These are absolute coordinates in the range of the screen
    // resolution. They are NOT floats in [0,1].
    // Precondition: `x` and `y` must be non-negative.
    int32 x = 1;
    int32 y = 2;
  }

  message PressButton {
    // The buttons that can be pressed.
    enum Button {
      HOME = 0;
      BACK = 1;
      ENTER = 2;
    }
    Button button = 1;
  }

  // Pins the given activity to the screen.
  // This essentially locks the user into a single app mode (aka "Kiosk mode").
  message StartScreenPinning {
    string full_activity = 1;
  }

  // Returns the full activity name that is currently opened to the user.
  // If successful, a GetCurrentActivityResponse is returned.
  message GetCurrentActivity {}

  // Returns the orientation of the device.
  message GetOrientationRequest {}

  // Performs `adb push`.
  // Please see https://developer.android.com/studio/command-line/adb#copyfiles.
  //
  // Notice that a source path for the file is not sent; raw bytes are sent in
  // `content` instead. The `content` can of course be read from a real file,
  // but this keeps Task definitions as hermetic as possible, without depending
  // on the environment from where they're run.
  message Push {
    // The contents of the file.
    bytes content = 1;
    // Destination path _inside_ Android. E.g. /sdcard/my_file.txt.
    string path = 2;
  }

  // Performs `adb pull`.
  // Please see https://developer.android.com/studio/command-line/adb#copyfiles.
  //
  // Notice that a local destination for the copied file is not sent, as raw
  // bytes are returned instead (please see PullResponse). Obviously, these
  // bytes can be written to disk by the caller of this command.
  message Pull {
    // Path _inside_ Android. E.g. /sdcard/my_file.txt.
    string path = 1;
  }

  // Inserts text into the current text field (if any).
  // Essentially `adb shell input text <text>`.
  message InputText {
    string text = 1;
  }

  // Issues an `adb shell settings` command.
  message SettingsRequest {
    // Each request has an associated namespace.
    enum Namespace {
      UNKNOWN = 0;
      SYSTEM = 1;
      SECURE = 2;
      GLOBAL = 3;
    }

    // Retrieves the current value for `key`.
    message Get {
      string key = 1;
    }

    // Changes the value of `key` to `value`.
    message Put {
      string key = 1;
      string value = 2;
    }

    // Deletes the entry for `key`.
    message Delete {
      string key = 1;
    }

    // Resets the global/secure table for a package with the given mode.
    message Reset {
      enum Mode {
        UNKNOWN = 0;
        UNTRUSTED_DEFAULTS = 1;
        UNTRUSTED_CLEAR = 2;
        TRUSTED_DEFAULTS = 3;
      }
      string package_name = 1;
      Mode mode = 2;
    }

    // Prints all defined keys in the given namespace.
    message List {}

    // The part of the system where this command will take place.
    // NOTE: We avoid the identifier `namespace` because it's a keyword in C++.
    Namespace name_space = 1;

    // The subcommand to issue to `adb settings`.
    // NOTE: We avoid the identifiers `delete` and `del` because they're
    // keywords in C++ and Python respectively.
    oneof verb {
      Get get = 2;
      Put put = 3;
      Delete delete_key = 4;
      Reset reset = 5;
      List list = 6;
    }
  }

  // Generic ADB command. Use this for commands that are not
  // explicitly implemented.
  // Calls `adb [args...]`.
  message GenericRequest {
    repeated string args = 1;
  }

  message PackageManagerRequest {
    message List {
      // Lists all features of the system.
      message Features {}
      // Lists all system libraries.
      message Libraries {}
      // Lists all packages; optionally only those whose name contains the text
      // in `filter`.
      message Packages {
        string filter = 1;
        // Extra options that control the output. Please see `pm help` for
        // details.
        repeated string options = 2;
      }
      oneof what {
        Features features = 1;
        Libraries libraries = 2;
        Packages packages = 3;
      }
    }

    // Deletes all data associated with a package.
    message Clear {
      // The name of the package whose data will be deleted.
      string package_name = 1;
      // Optional USER_ID.
      string user_id = 2;
    }

    message Grant {
      string package_name = 1;
      // Possible values listed at
      // https://developer.android.com/reference/android/Manifest.permission
      // To query an app's required permissions, use the following adb command:
      // > adb shell dumpsys package <package_name>
      // The output will contain things like
      // android.permission.WRITE_SECURE_SETTINGS
      repeated string permissions = 2;
    }

    // The subcommand to issue to `pm`.
    oneof verb {
      List list = 1;
      Clear clear = 2;
      Grant grant = 3;
    }
  }

  // For executing `dumpsys` commands.
  message DumpsysRequest {
    enum PriorityLevel {
      UNSET = 0;
      NORMAL = 1;
      HIGH = 2;
      CRITICAL = 3;
    }

    // The service to dump. If empty, all services will be dumped.
    string service = 1;

    // Optional arguments to pass to the specific service dump.
    repeated string args = 2;

    // Lists services, does not dump them.
    // This effectively disables dumping information about any particular
    // service.
    bool list_only = 3;

    // Timeouts natively supported by `dumpsys`.
    int32 timeout_sec = 4;
    int32 timeout_ms = 5;

    // Whether to dump the process ID instead of the usual dump.
    bool pid = 6;

    // Whether dumps will be in proto format. Only works for services that
    // support dumping data in proto format.
    bool proto = 7;

    // Filters services based on specified priority.
    PriorityLevel priority = 8;

    // Excludes services from the dump.
    repeated string skip_services = 9;
  }

  // The command to execute. Field numbers are not contiguous; missing ones are
  // presumably retired and should not be reused.
  oneof command {
    InstallApk install_apk = 1;
    StartActivity start_activity = 2;
    ForceStop force_stop = 3;
    Tap tap = 6;
    PressButton press_button = 7;
    StartScreenPinning start_screen_pinning = 10;
    UninstallPackage uninstall_package = 16;
    GetCurrentActivity get_current_activity = 17;
    GetOrientationRequest get_orientation = 24;
    Push push = 18;
    Pull pull = 19;
    InputText input_text = 20;
    SettingsRequest settings = 21;
    GenericRequest generic = 22;
    PackageManagerRequest package_manager = 23;
    DumpsysRequest dumpsys = 26;
    SendBroadcast send_broadcast = 25;
  }

  // Optional (soft) deadline in seconds for completing this command.
  // Expected to be >0. If ==0 (the default), it's ignored.
  // Notice that not all commands accept timeouts, but because it's such a
  // common parameter, we include it here instead of in each separate command.
  float timeout_sec = 100;
}
// The result of executing an `AdbRequest`.
message AdbResponse {
  enum Status {
    // Reserved value for unset statuses.
    UNDEFINED = 0;
    // Returned when everything goes well.
    OK = 1;
    // Returned when handling unknown AdbRequest commands.
    UNKNOWN_COMMAND = 2;
    // Returned when an argument does not respect a precondition.
    FAILED_PRECONDITION = 3;
    // Returned when something internal did not work as expected.
    INTERNAL_ERROR = 4;
    // Returned when the adb command failed.
    ADB_ERROR = 5;
    // Returned when the adb command timed out.
    TIMEOUT = 6;
  }

  // Outcome of the request.
  Status status = 1;

  // `error_message` is only populated in case of errors.
  string error_message = 2;

  // General stats that components may optionally report, as string-keyed
  // numeric values. (The key/value types were missing, which is not valid
  // proto syntax.)
  map<string, float> stats = 3;

  // Response for GetCurrentActivity requests.
  message GetCurrentActivityResponse {
    // The format of the output is `package/package.ActivityName`, for example:
    // "com.example.vokram/com.example.vokram.MainActivity"
    string full_activity = 1;
  }

  // Response for GetOrientationRequests.
  message GetOrientationResponse {
    // Possible values are {0, 1, 2, 3} corresponding to {0, 90, 180, 270}
    // degrees respectively.
    // Please see https://developer.android.com/reference/android/view/Surface.
    int32 orientation = 1;
  }

  // Response for StartActivity requests.
  message StartActivityResponse {
    // The activity that was actually started. On a failed request, this will be
    // empty.
    string full_activity = 1;
    // Presumably the raw output of the underlying start command, if any.
    bytes output = 2;
  }

  // Response for PressButton requests.
  message PressButtonResponse {
    // The output, if any, by `adb` after sending a key press.
    // This is intentionally left as `bytes` instead of `string` so that content
    // other than `UTF-8` can be transmitted.
    bytes output = 1;
  }

  // Response for Push requests.
  message PushResponse {}

  // Response for Pull requests.
  message PullResponse {
    // The contents of the file.
    // This is intentionally left as `bytes` instead of `string` so that content
    // other than `UTF-8` can be transmitted.
    bytes content = 1;
  }

  // Response for InputText requests.
  message InputTextResponse {}

  // Response for SettingsRequests.
  message SettingsResponse {
    // The output, if any, of the `adb shell settings` command.
    bytes output = 1;
  }

  // Response for GenericRequests.
  message GenericResponse {
    // The output, if any, of the generic adb command.
    bytes output = 1;
  }

  // Response for PackageManagerRequests.
  message PackageManagerResponse {
    // The output, if any, of the `adb shell pm` command.
    bytes output = 1;

    message List {
      // A list of items. The actual content depends on the request, but it
      // could be things like features, libraries or package names.
      repeated string items = 1;
    }

    oneof verb {
      List list = 2;
    }
  }

  // Response for DumpsysRequests.
  message DumpsysResponse {
    // The output, if any, of the `dumpsys` command.
    bytes output = 1;
  }

  // The command-specific payload, if the command produces one.
  oneof payload {
    GetCurrentActivityResponse get_current_activity = 10;
    StartActivityResponse start_activity = 11;
    PressButtonResponse press_button = 12;
    PushResponse push = 13;
    PullResponse pull = 14;
    InputTextResponse input_text = 15;
    SettingsResponse settings = 16;
    GenericResponse generic = 17;
    PackageManagerResponse package_manager = 18;
    GetOrientationResponse get_orientation = 19;
    DumpsysResponse dumpsys = 21;
  }
}
================================================
FILE: android_env/proto/emulator_controller.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (C) 2018 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Note that if you add/remove methods in this file you must update
// the metrics sql as well ./android/scripts/gen-grpc-sql.py
//
// Please group deleted methods in a block including the date (MM/DD/YY)
// it was removed. This enables us to easily keep metrics around after removal
//
// list of deleted methods
// rpc iWasDeleted (03/12/12)
// ...
// LINT: LEGACY_NAMES
syntax = "proto3";
package android.emulation.control;
import "google/protobuf/empty.proto";
option java_multiple_files = true;
option java_package = "com.android.emulator.control";
option objc_class_prefix = "AEC";
// An EmulatorController service lets you control the emulator.
// Note that this is currently an experimental feature, and that the
// service definition might change without notice. Use at your own risk!
//
// We use the following rough conventions:
//
// streamXXX --> streams values XXX (usually for emulator lifetime). Values
// are updated as soon as they become available.
// getXXX --> gets a single value XXX
// setXXX --> sets a single value XXX and does not return state; these
// usually have an observable lasting side effect.
// sendXXX --> send a single event XXX, possibly returning state information.
// android usually responds to these events.
service EmulatorController {
// Set the sensor data
rpc streamSensor(SensorValue) returns (stream SensorValue) {}
// Get the sensor data
rpc getSensor(SensorValue) returns (SensorValue) {}
// Stream the sensor data
rpc setSensor(SensorValue) returns (google.protobuf.Empty) {}
// Set the physical model, this is likely the one you are
// looking for when you wish to modify the device state.
rpc setPhysicalModel(PhysicalModelValue) returns (google.protobuf.Empty) {}
// Get the physical model
rpc getPhysicalModel(PhysicalModelValue) returns (PhysicalModelValue) {}
// Stream the physical model
rpc streamPhysicalModel(PhysicalModelValue)
returns (stream PhysicalModelValue) {}
// Atomically set the current primary clipboard data.
rpc setClipboard(ClipData) returns (google.protobuf.Empty) {}
// Atomically get the current primary clipboard data.
rpc getClipboard(google.protobuf.Empty) returns (ClipData) {}
// Streams the current data on the clipboard. This will immediately produce
// a result with the current state of the clipboard after which the stream
// will block and wait until a new clip event is available from the guest.
// Calling the setClipboard method above will not result in generating a clip
// event. It is possible to lose clipboard events if the clipboard updates
// very rapidly.
rpc streamClipboard(google.protobuf.Empty) returns (stream ClipData) {}
// Set the battery to the given state.
rpc setBattery(BatteryState) returns (google.protobuf.Empty) {}
// Get the battery to the given state.
rpc getBattery(google.protobuf.Empty) returns (BatteryState) {}
// Set the state of the gps, gps support will only work
// properly if:
//
// - no location ui is active. That is the emulator
// is launched in headless mode (-no-window) or the location
// ui is disabled (-no-location-ui).
// - the passiveUpdate is set to false. Setting this to false
// will disable/break the LocationUI.
//
// Keep in mind that android usually only samples the gps at 1 hz.
rpc setGps(GpsState) returns (google.protobuf.Empty) {}
// Gets the latest gps state as delivered by the setGps call, or location ui
// if active.
//
// Note: this is not necessarily the actual gps coordinate visible at the
// time, due to gps sample frequency (usually 1hz).
rpc getGps(google.protobuf.Empty) returns (GpsState) {}
// Simulate a touch event on the finger print sensor.
rpc sendFingerprint(Fingerprint) returns (google.protobuf.Empty) {}
// Send a keyboard event. Translating the event.
rpc sendKey(KeyboardEvent) returns (google.protobuf.Empty) {}
// Send touch events. Note that mouse events can be simulated by touch events.
rpc sendTouch(TouchEvent) returns (google.protobuf.Empty) {}
// Send mouse events.
rpc sendMouse(MouseEvent) returns (google.protobuf.Empty) {}
// Make a phone call.
rpc sendPhone(PhoneCall) returns (PhoneResponse) {}
// Sends an sms message to the emulator.
rpc sendSms(SmsMessage) returns (PhoneResponse) {}
// Retrieve the status of the emulator. This will contain general
// hardware information, and whether the device has booted or not.
rpc getStatus(google.protobuf.Empty) returns (EmulatorStatus) {}
// Gets an individual screenshot in the desired format.
//
// The image will be scaled to the desired ImageFormat, while maintaining
// the aspect ratio. The returned image will never exceed the provided width
// and height. Not setting the width or height (i.e. they are 0) will result
// in using the device width and height.
//
// The resulting image will be properly oriented and can be displayed
// directly without post processing. For example, if the device has a
// 1080x1920 screen and is in landscape mode and called with no width or
// height parameter, it will return an 1920x1080 image.
//
// This method will return an empty image if the display is not visible.
rpc getScreenshot(ImageFormat) returns (Image) {}
// Streams a series of screenshots in the desired format.
// A new frame will be delivered whenever the device produces a new frame.
// (Beware that this can produce a significant amount of data, and that
// certain translations are (png transform) can be costly).
//
// If the requested display is not visible it will send a single empty image
// and wait start producing images once the display becomes active, again
// producing a single empty image when the display becomes inactive.
rpc streamScreenshot(ImageFormat) returns (stream Image) {}
// Streams a series of audio packets in the desired format.
// A new frame will be delivered whenever the emulated device
// produces a new audio frame. You can expect packets to be
// delivered in intervals of 20-30ms.
//
// Be aware that this can block when the emulator does not
// produce any audio whatsoever!
rpc streamAudio(AudioFormat) returns (stream AudioPacket) {}
// Injects a series of audio packets to the android microphone.
// A new frame will be delivered whenever the emulated device
// requests a new audio frame. Audio is usually delivered at a rate
// that the emulator is requesting frames. Audio will be stored in a
// temporary buffer that can hold 500ms of audio.
//
// Note: Currently the emulator will downsample to 16khz.
//
// - INVALID_ARGUMENT (code 3) The sampling rate was too high
// - INVALID_ARGUMENT (code 3) The audio packet was too large to handle.
// - FAILED_PRECONDITION (code 9) If there was a microphone registered
// already.
rpc injectAudio(stream AudioPacket) returns (google.protobuf.Empty) {}
// Returns the last 128Kb of logcat output from the emulator
// Note that parsed logcat messages are only available after L (Api >23).
// it is possible that the logcat buffer gets overwritten, or falls behind.
rpc getLogcat(LogMessage) returns (LogMessage) {}
// Streams the logcat output from the emulator. The first call
// can retrieve up to 128Kb. This call will not return.
// Note that parsed logcat messages are only available after L (Api >23)
// it is possible that the logcat buffer gets overwritten, or falls behind.
rpc streamLogcat(LogMessage) returns (stream LogMessage) {}
// Transition the virtual machine to the desired state. Note that
// some states are only observable. For example you cannot transition
// to the error state.
rpc setVmState(VmRunState) returns (google.protobuf.Empty) {}
// Gets the state of the virtual machine.
rpc getVmState(google.protobuf.Empty) returns (VmRunState) {}
// Atomically changes the current multi-display configuration.
// After this call the given display configurations will be activated. You
// can only update secondary displays. Displays with id 0 will be ignored.
//
// This call can result in the removal or addition of secondary displays, the
// final display state can be observed by the returned configuration.
//
// The following gRPC error codes can be returned:
// - FAILED_PRECONDITION (code 9) if the AVD does not support a configurable
// secondary display.
// - INVALID_ARGUMENT (code 3) if:
// - The same display id is defined multiple times.
// - The display configurations are outside valid ranges
// (see DisplayConfiguration)
// - INTERNAL (code 13) if there was an internal emulator failure.
rpc setDisplayConfigurations(DisplayConfigurations)
returns (DisplayConfigurations) {}
// Returns all currently valid logical displays.
// The gRPC error code FAILED_PRECONDITION (code 9) is returned if the AVD
// does not support a configurable secondary display.
rpc getDisplayConfigurations(google.protobuf.Empty)
returns (DisplayConfigurations) {}
// Notifies client of the following changes:
//
// - Virtual scene camera status change.
// - Display configuration changes from extended ui. This will only be fired
// if the user makes modifications the extended displays through the
// extended control tab.
//
// Note that this method will send the initial virtual scene state
// immediately.
rpc streamNotification(google.protobuf.Empty) returns (stream Notification) {}
// RotationRadian is relative to the camera's current orientation.
rpc rotateVirtualSceneCamera(RotationRadian) returns (google.protobuf.Empty) {
}
// Velocity is absolute
rpc setVirtualSceneCameraVelocity(Velocity) returns (google.protobuf.Empty) {}
// Set foldable posture
rpc setPosture(Posture) returns (google.protobuf.Empty) {}
}
// A Run State that describes the state of the Virtual Machine.
// Returned by getVmState and accepted by setVmState.
message VmRunState {
enum RunState {
// The emulator is in an unknown state. You cannot transition to this state.
UNKNOWN = 0;
// Guest is actively running. You can transition to this state from the
// paused state.
RUNNING = 1;
// Guest is paused to load a snapshot. You cannot transition to this state.
RESTORE_VM = 2;
// Guest has been paused. Transitioning to this state will pause the
// emulator; the guest will not be consuming any cpu cycles.
PAUSED = 3;
// Guest is paused to take or export a snapshot. You cannot
// transition to this state.
SAVE_VM = 4;
// System shutdown, note that it is similar to power off. It tries to set
// the system status and notify guest. The system is likely going to
// disappear soon and do proper cleanup of resources, possibly taking
// a snapshot. This is the same behavior as closing the emulator by clicking
// the X (close) in the user interface.
SHUTDOWN = 5;
// Immediately terminate the emulator. No resource cleanup will take place.
// There is a good chance to corrupt the system.
TERMINATE = 7;
// Will cause the emulator to reset. This is not a state you can observe.
RESET = 9;
// Guest experienced some error state, you cannot transition to this state.
INTERNAL_ERROR = 10;
}
// The observed state (getVmState) or the requested state (setVmState).
RunState state = 1;
}
// A list of float values; used as the value payload of PhysicalModelValue
// and SensorValue.
message ParameterValue {
repeated float data = 1 [packed = true];
}
// A single physical model parameter (selected by `target`) together with its
// value. Used by setPhysicalModel, getPhysicalModel and streamPhysicalModel.
message PhysicalModelValue {
enum State {
OK = 0;
NO_SERVICE = -3; // qemud service is not available/initiated.
DISABLED = -2; // Sensor is disabled.
UNKNOWN = -1; // Unknown sensor (should not happen)
}
// Details on the sensors documentation can be found here:
// https://developer.android.com/reference/android/hardware/Sensor.html#TYPE_
// The types must follow the order defined in
// "external/qemu/android/hw-sensors.h"
enum PhysicalType {
POSITION = 0;
// All values are angles in degrees.
// values = [x,y,z]
ROTATION = 1;
MAGNETIC_FIELD = 2;
// Temperature in °C
TEMPERATURE = 3;
// Proximity sensor distance measured in centimeters
PROXIMITY = 4;
// Ambient light level in SI lux units
LIGHT = 5;
// Atmospheric pressure in hPa (millibar)
PRESSURE = 6;
// Relative ambient air humidity in percent
HUMIDITY = 7;
VELOCITY = 8;
AMBIENT_MOTION = 9;
// Describing a hinge angle sensor in degrees.
HINGE_ANGLE0 = 10;
HINGE_ANGLE1 = 11;
HINGE_ANGLE2 = 12;
ROLLABLE0 = 13;
ROLLABLE1 = 14;
ROLLABLE2 = 15;
}
// The physical parameter this message refers to.
PhysicalType target = 1;
// [Output Only]
State status = 2;
// Value interpretation depends on sensor, will contain at most 3 values.
ParameterValue value = 3;
}
// A single sensor value. Used by setSensor, getSensor and streamSensor.
message SensorValue {
enum State {
OK = 0;
NO_SERVICE = -3; // qemud service is not available/initiated.
DISABLED = -2; // Sensor is disabled.
UNKNOWN = -1; // Unknown sensor (should not happen)
}
// These are the various sensors that can be available in an emulated
// device.
enum SensorType {
// Measures the acceleration force in m/s2 that is applied to a device
// on all three physical axes (x, y, and z), including the force of
// gravity.
ACCELERATION = 0;
// Measures a device's rate of rotation in rad/s around each of the
// three physical axes (x, y, and z).
GYROSCOPE = 1;
// Measures the ambient geomagnetic field for all three physical axes
// (x, y, z) in μT.
MAGNETIC_FIELD = 2;
// Measures degrees of rotation that a device makes around all three
// physical axes (x, y, z)
ORIENTATION = 3;
// Measures the temperature of the device in degrees Celsius (°C).
TEMPERATURE = 4;
// Measures the proximity of an object in cm relative to the view screen
// of a device. This sensor is typically used to determine whether a
// handset is being held up to a person's ear.
PROXIMITY = 5;
// Measures the ambient light level (illumination) in lx.
LIGHT = 6;
// Measures the ambient air pressure in hPa or mbar.
PRESSURE = 7;
// Measures the relative ambient humidity in percent (%).
HUMIDITY = 8;
MAGNETIC_FIELD_UNCALIBRATED = 9;
GYROSCOPE_UNCALIBRATED = 10;
}
// Type of sensor
SensorType target = 1;
// [Output Only]
State status = 2;
// Value interpretation depends on sensor enum, will contain at most 3
// values.
ParameterValue value = 3;
}
// A chunk of logcat output. Used as both the request and the response of
// getLogcat and streamLogcat.
message LogMessage {
// [Output Only] The contents of the log output.
string contents = 1;
// The starting byte position of the output that was returned. This
// should match the start parameter sent with the request. If the log
// output exceeds the size of the buffer, older output will be
// overwritten by newer content and the start values will be mismatched.
int64 start = 2;
// [Output Only] The position of the next byte of content from the log
// output. Use this value in the next request as the start
// parameter.
int64 next = 3;
// Set the sort of response you are interested in.
// If the type is "Parsed" the entries field will contain the parsed
// results, otherwise the contents field will be set.
LogType sort = 4;
// [Output Only] The parsed logcat entries so far. Only set if sort is
// set to Parsed
repeated LogcatEntry entries = 5;
enum LogType {
Text = 0;
Parsed = 1;
}
}
// A parsed logcat entry.
message LogcatEntry {
// The possible log levels.
enum LogLevel {
UNKNOWN = 0;
DEFAULT = 1;
VERBOSE = 2;
DEBUG = 3;
INFO = 4;
WARN = 5;
ERR = 6;
FATAL = 7;
SILENT = 8;
}
// A Unix timestamp in milliseconds (The number of milliseconds that
// have elapsed since January 1, 1970 (midnight UTC/GMT), not counting
// leap seconds)
uint64 timestamp = 1;
// Process id.
uint32 pid = 2;
// Thread id.
uint32 tid = 3;
// Log level of the entry.
LogLevel level = 4;
// Log tag of the entry.
string tag = 5;
// The log message text.
string msg = 6;
}
// Information about the hypervisor that is currently in use.
message VmConfiguration {
enum VmHypervisorType {
// An unknown hypervisor
UNKNOWN = 0;
// No hypervisor is in use. This usually means that the guest is
// running on a different CPU than the host, or you are using a
// platform where no hypervisor is available.
NONE = 1;
// The Kernel based Virtual Machine
// (https://www.linux-kvm.org/page/Main_Page)
KVM = 2;
// Intel® Hardware Accelerated Execution Manager (Intel® HAXM)
// https://github.com/intel/haxm
HAXM = 3;
// Hypervisor Framework.
// https://developer.apple.com/documentation/hypervisor
HVF = 4;
// Windows Hypervisor Platform
// https://docs.microsoft.com/en-us/virtualization/api/
WHPX = 5;
GVM = 6;
}
// The hypervisor in use.
VmHypervisorType hypervisorType = 1;
// Number of CPU cores configured for the VM.
int32 numberOfCpuCores = 2;
// Size of the VM's RAM, in bytes.
int64 ramSizeBytes = 3;
}
// Representation of a clipped data object on the clipboard.
// Used by setClipboard, getClipboard and streamClipboard.
message ClipData {
// UTF-8 Encoded text.
string text = 1;
}
// The Touch interface represents a single contact point on a
// touch-sensitive device. The contact point is commonly a finger or stylus
// and the device may be a touchscreen or trackpad.
message Touch {
// The horizontal coordinate. This is the physical location on the
// screen. For example 0 indicates the leftmost coordinate.
int32 x = 1;
// The vertical coordinate. This is the physical location on the screen.
// For example 0 indicates the topmost coordinate.
int32 y = 2;
// The identifier is an arbitrary non-negative integer that is used to
// identify and track each tool independently when multiple tools are
// active. For example, when multiple fingers are touching the device,
// each finger should be assigned a distinct tracking id that is used as
// long as the finger remains in contact. Tracking ids may be reused
// when their associated tools move out of range.
//
// The emulator currently supports up to 10 concurrent touch events. The
// identifier can be any unique value and will be mapped to the next
// available internal identifier.
int32 identifier = 3;
// Reports the physical pressure applied to the tip of the tool or the
// signal strength of the touch contact.
//
// The values reported must be non-zero when the tool is touching the
// device and zero otherwise to indicate that the touch event is
// completed.
//
// Make sure to deliver a pressure of 0 for the given identifier when
// the touch event is completed, otherwise the touch identifier will not
// be unregistered!
int32 pressure = 4;
// Optionally reports the cross-sectional area of the touch contact, or
// the length of the longer dimension of the touch contact.
int32 touch_major = 5;
// Optionally reports the length of the shorter dimension of the touch
// contact. This axis will be ignored if touch_major is reporting an
// area measurement greater than 0.
int32 touch_minor = 6;
enum EventExpiration {
// The system will use the default time of 120s to track
// the touch event with the given identifier. If no update happens
// within this timeframe the identifier is considered expired
// and can be made available for re-use. This means that a touch event
// with pressure 0 for this identifier will be sent to the emulator.
EVENT_EXPIRATION_UNSPECIFIED = 0;
// Never expire the given slot. You must *ALWAYS* close the identifier
// by sending a touch event with 0 pressure.
NEVER_EXPIRE = 1;
}
// How long the emulator tracks this touch identifier without updates.
EventExpiration expiration = 7;
}
// A TouchEvent contains a list of Touch objects that are in contact with
// the touch surface.
//
// Touch events are delivered in sequence as specified in the touchList.
//
// TouchEvents are delivered to the emulated devices using ["Protocol
// B"](https://www.kernel.org/doc/Documentation/input/multi-touch-protocol.txt)
message TouchEvent {
// The list of Touch objects, note that these do not need to be unique
repeated Touch touches = 1;
// The display device where the touch event occurred.
// Omitting or using the value 0 indicates the main display.
//
// Touch events cannot be sent to displays other than 0, due to
// https://issuetracker.google.com/issues/150699691
int32 display = 2;
}
// The MouseEvent interface represents events that occur due to the user
// interacting with a pointing device (such as a mouse).
message MouseEvent {
// The horizontal coordinate. This is the physical location on the
// screen. For example 0 indicates the leftmost coordinate.
int32 x = 1;
// The vertical coordinate. This is the physical location on the screen.
// For example 0 indicates the topmost coordinate.
int32 y = 2;
// Indicates which buttons are pressed.
// 0: No button was pressed
// 1: Primary button (left)
// 2: Secondary button (right)
int32 buttons = 3;
// The display device where the mouse event occurred.
// Omitting or using the value 0 indicates the main display.
int32 display = 4;
}
// KeyboardEvent objects describe a user interaction with the keyboard; each
// event describes a single interaction between the user and a key (or
// combination of a key with modifier keys) on the keyboard.
// This follows the pattern as set by
// (javascript)[https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent]
//
// Note: that only keyCode, key, or text can be set and that the semantics
// will slightly vary.
message KeyboardEvent {
// Code types that the emulator can receive. Note that the emulator
// will do its best to translate the code to an evdev value that
// will be sent to the emulator. This translation is based on
// the chromium translation tables. See
// (this)[https://android.googlesource.com/platform/external/qemu/+/refs/heads/emu-master-dev/android/android-grpc/android/emulation/control/keyboard/keycode_converter_data.inc]
// for details on the translation.
enum KeyCodeType {
Usb = 0;
Evdev = 1;
XKB = 2;
Win = 3;
Mac = 4;
}
enum KeyEventType {
// Indicates that this keyevent should be sent to the emulator
// as a key down event. Meaning that the key event will be
// translated to an EvDev event type and bit 11 (0x400) will be
// set before it is sent to the emulator.
keydown = 0;
// Indicates that the keyevent should be sent to the emulator
// as a key up event. Meaning that the key event will be
// translated to an EvDev event type and
// sent to the emulator.
keyup = 1;
// Indicates that the keyevent will be sent to the emulator
// as a key down event and immediately followed by a keyup event.
keypress = 2;
}
// Type of keycode contained in the keyCode field.
KeyCodeType codeType = 1;
// The type of keyboard event that should be sent to the emulator
KeyEventType eventType = 2;
// This property represents a physical key on the keyboard (as opposed
// to the character generated by pressing the key). In other words, this
// property is a value which isn't altered by keyboard layout or the
// state of the modifier keys. This value will be interpreted by the
// emulator depending on the KeyCodeType. The incoming key code will be
// translated to an evdev code type and sent to the emulator.
// The values in key and text will be ignored.
int32 keyCode = 3;
// The value of the key pressed by the user, taking into consideration
// the state of modifier keys such as Shift as well as the keyboard
// locale and layout. This follows the w3c standard used in browsers.
// You can find an accurate description of valid values
// [here](https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values)
//
// Note that some keys can result in multiple evdev events that are
// delivered to the emulator. for example the Key "A" will result in a
// sequence:
// ["Shift", "a"] -> [0x2a, 0x1e] whereas "a" results in ["a"] -> [0x1e].
//
// Not all documented keys are understood by android, and only printable
// ASCII [32-127) characters are properly translated.
//
// Keep in mind that there are a set of key values that result in android
// specific behavior
// [see](https://developer.mozilla.org/en-US/docs/Web/API/KeyboardEvent/key/Key_Values#Phone_keys):
//
// - "AppSwitch": Behaves as the "Overview" button in android.
// - "GoBack": The Back button.
// - "GoHome": The Home button, which takes the user to the phone's main
// screen (usually an application launcher).
// - "Power": The Power button.
string key = 4;
// Series of utf8 encoded characters to send to the emulator. An attempt
// will be made to translate every character to an EvDev event type and
// send it to the emulator as a keypress event. The values in keyCode,
// eventType, codeType and key will be ignored.
//
// Note that most printable ASCII characters (range [32-127) can be sent
// individually with the "key" param. Do not expect arbitrary UTF symbols to
// arrive in the emulator (most will be ignored).
//
// Note that it is possible to overrun the keyboard buffer by slamming this
// endpoint with large quantities of text (>1kb). The clipboard api is better
// suited for transferring large quantities of text.
string text = 5;
}
// A touch event on the fingerprint sensor, see sendFingerprint.
message Fingerprint {
// True when the fingerprint is touched.
bool isTouching = 1;
// The identifier of the registered fingerprint.
int32 touchId = 2;
}
// The gps state, as used by setGps and getGps.
message GpsState {
// Setting this to false will disable auto updating from the LocationUI,
// otherwise the location UI will override the location at a frequency of 1hz.
//
// - This is unused if the emulator is launched with -no-window, or when the
// location ui is disabled.
// - This will BREAK the location ui experience if it is set to false. For
// example routing will no longer function.
bool passiveUpdate = 1;
// The latitude, in degrees.
double latitude = 2;
// The longitude, in degrees.
double longitude = 3;
// The speed if it is available, in meters/second over ground
double speed = 4;
// The horizontal direction of travel of this device, which is not
// related to the device orientation. It is guaranteed to be in the
// range [0.0, 360.0] if the device has a bearing. 0=North, 90=East,
// 180=South, etc..
double bearing = 5;
// The altitude if available, in meters above the WGS 84 reference
// ellipsoid.
double altitude = 6;
// The number of satellites used to derive the fix
int32 satellites = 7;
}
// The battery state of the emulated device, as used by setBattery and
// getBattery.
message BatteryState {
enum BatteryStatus {
UNKNOWN = 0;
CHARGING = 1;
DISCHARGING = 2;
NOT_CHARGING = 3;
FULL = 4;
}
enum BatteryCharger {
NONE = 0;
AC = 1;
USB = 2;
WIRELESS = 3;
}
enum BatteryHealth {
GOOD = 0;
FAILED = 1;
DEAD = 2;
OVERVOLTAGE = 3;
OVERHEATED = 4;
}
bool hasBattery = 1;
bool isPresent = 2;
BatteryCharger charger = 3;
// NOTE(review): unit/range of chargeLevel is not stated here (presumably a
// percentage) — confirm against the emulator implementation.
int32 chargeLevel = 4;
BatteryHealth health = 5;
BatteryStatus status = 6;
}
// An ImageTransport allows for specifying a side channel for
// delivering image frames versus using the standard bytes array that is
// returned with the gRPC request.
message ImageTransport {
enum TransportChannel {
// Return full frames over the gRPC transport
TRANSPORT_CHANNEL_UNSPECIFIED = 0;
// Write images to a file/shared memory handle.
MMAP = 1;
}
// The desired transport channel used for delivering image frames. Only
// relevant when streaming screenshots.
TransportChannel channel = 1;
// Handle used for writing image frames if transport is mmap. The client sets
// and owns this handle. It can be either a shm region, or a mmap. A mmap
// should be a url that starts with `file:///`
// Note: the mmap can result in tearing.
string handle = 2;
}
// Describes the visible display area when the screen is folded.
// The aspect ratio (width/height) will be different from the one
// where the device is unfolded.
message FoldedDisplay {
uint32 width = 1;
uint32 height = 2;
// It is possible for the screen to be folded in different ways depending
// on which surface is shown to the user. So xOffset and yOffset indicate
// the top left corner of the folded screen within the original unfolded
// screen.
uint32 xOffset = 3;
uint32 yOffset = 4;
}
// The requested format of a screenshot (getScreenshot / streamScreenshot),
// and the actual format of a returned Image (Image.format).
message ImageFormat {
enum ImgFormat {
// Portable Network Graphics format
// (https://en.wikipedia.org/wiki/Portable_Network_Graphics)
PNG = 0;
// Three-channel RGB color model supplemented with a fourth alpha
// channel. https://en.wikipedia.org/wiki/RGBA_color_model
// Each pixel consists of 4 bytes.
RGBA8888 = 1;
// Three-channel RGB color model, each pixel consists of 3 bytes
RGB888 = 2;
}
// The (desired) format of the resulting bytes.
ImgFormat format = 1;
// [Output Only] The rotation of the image. The image will be rotated
// based upon the coarse grained orientation of the device.
Rotation rotation = 2;
// The (desired) width of the image. When passed as input
// the image will be scaled to match the given
// width, while maintaining the aspect ratio of the device.
// The returned image will never exceed the given width, but can be less.
// Omitting this value (or passing in 0) will result in no scaling,
// and the width of the actual device will be used.
uint32 width = 3;
// The (desired) height of the image. When passed as input
// the image will be scaled to match the given
// height, while maintaining the aspect ratio of the device.
// The returned image will never exceed the given height, but can be less.
// Omitting this value (or passing in 0) will result in no scaling,
// and the height of the actual device will be used.
uint32 height = 4;
// The (desired) display id of the device. Setting this to 0 (or omitting)
// indicates the main display.
uint32 display = 5;
// Set this if you wish to use a different transport channel to deliver image
// frames.
ImageTransport transport = 6;
// [Output Only] Display configuration when screen is folded. The value is the
// original configuration before scaling.
FoldedDisplay foldedDisplay = 7;
}
// A single screenshot, as returned by getScreenshot and streamScreenshot.
message Image {
// The actual format (including width and height) of this image.
ImageFormat format = 1;
uint32 width = 2 [deprecated = true]; // width is contained in format.
uint32 height = 3 [deprecated = true]; // height is contained in format.
// The organization of the pixels in the image buffer is from left to
// right and bottom up. This will be empty if an alternative image transport
// is requested in the image format. In that case the side channel should
// be used to obtain the image data.
bytes image = 4;
// [Output Only] Monotonically increasing sequence number in a stream of
// screenshots. The first screenshot will have a sequence of 0. A single
// screenshot will always have a sequence number of 0. The sequence is not
// necessarily contiguous, and can be used to detect how many frames were
// dropped. An example sequence could be: [0, 3, 5, 7, 9, 11].
uint32 seq = 5;
// [Output Only] Unix timestamp in microseconds when the emulator estimates
// the frame was generated. The timestamp is before the actual frame is
// copied and transformed. This can be used to calculate variance between
// frame production time, and frame depiction time.
uint64 timestampUs = 6;
}
// The rotation state of the device (see ImageFormat.rotation).
message Rotation {
enum SkinRotation {
PORTRAIT = 0; // 0 degrees
LANDSCAPE = 1; // 90 degrees
REVERSE_PORTRAIT = 2; // -180 degrees
REVERSE_LANDSCAPE = 3; // -90 degrees
}
// The rotation of the device, derived from the sensor state
// of the emulator. The derivation reflects how android observes
// the rotation state.
SkinRotation rotation = 1;
// Specifies the angle of rotation, in degrees [-180, 180]
double xAxis = 2;
double yAxis = 3;
double zAxis = 4;
}
// A telephony operation to perform, see sendPhone.
message PhoneCall {
enum Operation {
InitCall = 0;
AcceptCall = 1;
RejectCallExplicit = 2;
RejectCallBusy = 3;
DisconnectCall = 4;
PlaceCallOnHold = 5;
TakeCallOffHold = 6;
}
// The operation to perform.
Operation operation = 1;
// The telephone number the operation applies to.
string number = 2;
}
// The result of a sendPhone or sendSms call.
message PhoneResponse {
enum Response {
OK = 0;
BadOperation = 1; // Enum out of range
BadNumber = 2; // Mal-formed telephone number
InvalidAction = 3; // E.g., disconnect when no call is in progress
ActionFailed = 4; // Internal error
RadioOff = 5; // Radio power off
}
Response response = 1;
}
// A single key/value pair.
message Entry {
string key = 1;
string value = 2;
}
// A list of key/value pairs (e.g. EmulatorStatus.hardwareConfig).
message EntryList {
repeated Entry entry = 1;
}
// General emulator status, as returned by getStatus.
message EmulatorStatus {
// The emulator version string.
string version = 1;
// The time the emulator has been active, in milliseconds.
uint64 uptime = 2;
// True if the device has completed booting.
// For P and later this information will be accurate,
// for older images we rely on adb.
bool booted = 3;
// The current vm configuration
VmConfiguration vmConfig = 4;
// The hardware configuration of the running emulator as
// key value pairs.
EntryList hardwareConfig = 5;
}
// The format of audio data, as requested by streamAudio and carried in each
// AudioPacket.
message AudioFormat {
enum SampleFormat {
AUD_FMT_U8 = 0; // Unsigned 8 bit
AUD_FMT_S16 = 1; // Signed 16 bit (little endian)
}
enum Channels {
Mono = 0;
Stereo = 1;
}
// Sampling rate to use, defaulting to 44100 if this is not set.
// Note, that android devices typically will not use a sampling
// rate higher than 48kHz. See https://developer.android.com/ndk/guides/audio.
uint64 samplingRate = 1;
Channels channels = 2;
SampleFormat format = 3;
}
// A packet of audio samples, produced by streamAudio and consumed by
// injectAudio.
message AudioPacket {
// The format of the samples in `audio`.
AudioFormat format = 1;
// Unix epoch in us when this frame was captured.
uint64 timestamp = 2;
// Contains a sample in the given audio format.
bytes audio = 3;
}
// An sms message to deliver to the emulator, see sendSms.
message SmsMessage {
// The source address where this message came from.
//
// The address should be a valid GSM-formatted address as specified by
// 3GPP 23.040 Sec 9.1.2.5.
//
// For example: +3106225412 or (650) 555-1221
string srcAddress = 1;
// A utf8 encoded text message that should be delivered.
string text = 2;
}
// A DisplayConfiguration describes a primary or secondary
// display available to the emulator. The screen aspect ratio
// cannot be longer (or wider) than 21:9 (or 9:21). Screen sizes
// larger than 4K will be rejected.
//
// Common configurations (w x h) are:
// - 480p (480x720) 142 dpi
// - 720p (720x1280) 213 dpi
// - 1080p (1080x1920) 320 dpi
// - 4K (2160x3840) 320 dpi
// - 4K (2160x3840) 640 dpi (upscaled)
//
// The behavior of the virtual display depends on the flags that are set in
// the `flags` field. By default, virtual displays are created to be private,
// non-presentation and non-secure.
message DisplayConfiguration {
// This is the set of known Android flags and their respective values.
// The values are powers of two, so they can be combined (bitwise OR) to
// construct, or tested to deconstruct, the `flags` field below.
enum DisplayFlags {
DISPLAYFLAGS_UNSPECIFIED = 0;
// When this flag is set, the virtual display is public.
// A public virtual display behaves just like most any other display
// that is connected to the system such as an external or wireless
// display. Applications can open windows on the display and the system
// may mirror the contents of other displays onto it. See:
// https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_PUBLIC
VIRTUAL_DISPLAY_FLAG_PUBLIC = 1;
// When this flag is set, the virtual display is registered as a
// presentation display in the presentation display category.
// Applications may automatically project their content to presentation
// displays to provide richer second screen experiences. See:
// https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_PRESENTATION
VIRTUAL_DISPLAY_FLAG_PRESENTATION = 2;
// When this flag is set, the virtual display is considered secure as
// defined by the Display#FLAG_SECURE display flag. The caller promises
// to take reasonable measures, such as over-the-air encryption, to
// prevent the contents of the display from being intercepted or
// recorded on a persistent medium.
// See:
// https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_SECURE
VIRTUAL_DISPLAY_FLAG_SECURE = 4;
// This flag is used in conjunction with VIRTUAL_DISPLAY_FLAG_PUBLIC.
// Ordinarily public virtual displays will automatically mirror the
// content of the default display if they have no windows of their own.
// When this flag is specified, the virtual display will only ever show
// its own content and will be blanked instead if it has no windows. See:
// https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY
VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY = 8;
// Allows content to be mirrored on private displays when no content is
// being shown.
// This flag is mutually exclusive with
// VIRTUAL_DISPLAY_FLAG_OWN_CONTENT_ONLY. If both flags are specified
// then the own-content only behavior will be applied.
// See:
// https://developer.android.com/reference/android/hardware/display/DisplayManager#VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR
VIRTUAL_DISPLAY_FLAG_AUTO_MIRROR = 16;
}
// The width of the display, restricted to:
// 320 * (dpi / 160) <= width
uint32 width = 1;
// The height of the display, restricted to:
// 320 * (dpi / 160) <= height
uint32 height = 2;
// The pixel density (dpi).
// See https://developer.android.com/training/multiscreen/screendensities
// for details. This value should be in the range [120, 640].
uint32 dpi = 3;
// A combination of virtual display flags. These flags can be constructed
// by combining the DisplayFlags enum values described above.
//
// The behavior of the virtual display depends on the flags. By default
// virtual displays are created to be private, non-presentation and
// non-secure.
uint32 flags = 4;
// The id of the display.
// The primary (default) display has the display ID of 0.
// A secondary display has a non-zero display ID.
//
// The id can be used to get or stream a screenshot.
uint32 display = 5;
}
// A collection of display configurations.
message DisplayConfigurations {
repeated DisplayConfiguration displays = 1;
}
// A notification carrying a single event type.
message Notification {
enum EventType {
VIRTUAL_SCENE_CAMERA_INACTIVE = 0;
VIRTUAL_SCENE_CAMERA_ACTIVE = 1;
// Fired when the display configurations have been changed through the
// extended UI. This event is not fired when the displays are changed
// through the console or the gRPC endpoint.
DISPLAY_CONFIGURATIONS_CHANGED_UI = 2;
// Keep adding more for other event types.
}
EventType event = 1;
}
// A rotation around each of the x, y and z axes, in radians.
message RotationRadian {
float x = 1; // x axis is horizontal and orthogonal to the view direction.
float y = 2; // y axis points up and is perpendicular to the floor.
float z = 3; // z axis is the view direction and is set to 0.0 in
// rotateVirtualSceneCamera call.
}
// A velocity along each of the x, y and z axes.
message Velocity {
float x = 1; // x axis is horizontal and orthogonal to the view direction.
float y = 2; // y axis points up and is perpendicular to the floor.
float z = 3; // z axis is the view direction
}
// The device posture.
// The values must follow the definition in "external/qemu/android/hw-sensors.h".
message Posture {
enum PostureValue {
POSTURE_UNKNOWN = 0;
POSTURE_CLOSED = 1;
POSTURE_HALF_OPENED = 2;
POSTURE_OPENED = 3;
POSTURE_FLIPPED = 4;
POSTURE_TENT = 5;
// NOTE(review): presumably a sentinel (count of postures) rather than a
// real posture; confirm against hw-sensors.h.
POSTURE_MAX = 6;
}
PostureValue value = 3;
}
================================================
FILE: android_env/proto/snapshot.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (C) 2018 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto2";
// This file must be synchronized between
// Emulator (branch aosp/emu-master-dev):
// external/qemu/android/android-emu/android/snapshot/proto/snapshot.proto
//
// Android Studio (branch goog/studio-master-dev):
// tools/adt/idea/android/src/com/android/emulator/snapshot.proto
//
// If you modify one, please modify the other.
package emulator_snapshot;
option java_package = "com.android.emulator.snapshot";
// A disk image used by the emulator, as recorded in a snapshot (see
// Snapshot.images).
message Image {
// The kind of disk image.
enum Type {
IMAGE_TYPE_UNKNOWN = 0;
IMAGE_TYPE_KERNEL = 1;
IMAGE_TYPE_KERNEL_RANCHU = 2;
IMAGE_TYPE_SYSTEM = 3;
IMAGE_TYPE_SYSTEM_COPY = 4;
IMAGE_TYPE_DATA = 5;
IMAGE_TYPE_DATA_COPY = 6;
IMAGE_TYPE_RAMDISK = 7;
IMAGE_TYPE_SDCARD = 8;
IMAGE_TYPE_CACHE = 9;
IMAGE_TYPE_VENDOR = 10;
IMAGE_TYPE_ENCRYPTION_KEY = 11;
}
optional Type type = 1;
optional string path = 2;
optional bool present = 3;
// NOTE(review): units of `size` and `modification_time` are not documented
// here; confirm against the emulator sources.
optional int64 size = 4;
optional int64 modification_time = 5;
}
// Properties of the host machine needed to load a snapshot (see
// Snapshot.host).
message Host {
optional string gpu_driver = 4;
optional int32 hypervisor = 5;
}
// Emulator configuration needed for a snapshot (see Snapshot.config).
message Config {
// Features are int32, not enums here to make sure we don't have to update
// one more protobuf definition with every single new feature flag, even
// when the code doesn't really care about the actual meaning for them,
// only for the values.
repeated int32 enabled_features = 1;
// This holds the renderer; int32 for the same reason as |enabled_features|.
optional int32 selected_renderer = 2;
optional int32 cpu_core_count = 3;
optional int64 ram_size_bytes = 4;
}
// Statistics recorded for a single snapshot save (see Snapshot.save_stats).
message SaveStats {
// Type of save:
// 0: non-incremental
// 1: incremental
optional uint32 incremental = 1;
// Time taken to save.
optional uint64 duration = 2;
// How many changed bytes in RAM.
optional uint64 ram_changed_bytes = 3;
}
// Description of an emulator snapshot: the images, host and configuration it
// depends on, plus bookkeeping about how it was created and used.
message Snapshot {
// Update every time when introducing some breaking changes that make the
// previous loading code break when trying to load the new snapshot.
// NOTE: if the old code is fine with just skipping the new fields or not
// getting the meaning of new values, |version| should remain
// unchanged.
optional int32 version = 1;
// Purely informative: when this snapshot was created, as a Unix timestamp.
optional int64 creation_time = 2;
// List of mounted disk images used during the snapshot creation.
repeated Image images = 3;
// Description of the host machine properties needed to load this snapshot.
optional Host host = 4;
// Description of the emulator configuration needed for this snapshot.
// NOTE: try not to duplicate the configuration that's already in
// hardware-qemu.ini; only add what's either not there or what
// could've been overridden during process initialization.
optional Config config = 5;
// Set if the snapshot failed to load during the last attempt.
// Code is up to the application to define, with 0 meaning 'not failed' just
// in case.
optional int64 failed_to_load_reason_code = 7;
// Set if the data image is mounted.
// User builds and userdebug builds mount the data partition at different
// times, but it should be done before boot finishes, so this field is very
// likely to be true.
// We snapshot it here just in case someday we support snapshots during
// booting.
optional bool guest_data_partition_mounted = 8;
// Emulator rotation angle, in right angles (e.g. 1 is 90 degrees, 2 is 180
// etc).
optional int32 rotation = 9;
// Number of invalid loads / crashes that happened under this snapshot.
optional int32 invalid_loads = 10;
// Number of successful loads.
optional int32 successful_loads = 11;
// The name given to the snapshot by the user. Independent of the
// file name.
optional string logical_name = 12;
// The file name of this snapshot's parent. The parent is the
// snapshot that was loaded into the AVD prior to this snapshot
// being taken.
optional string parent = 13;
// Arbitrary description added by the user.
optional string description = 14;
// Record of save stats.
repeated SaveStats save_stats = 15;
// Folded state.
optional bool folded = 16;
// Emulator boot parameters.
repeated string launch_parameters = 17;
// Emulator build ID.
optional string emulator_build_id = 18;
// System image build ID.
optional string system_image_build_id = 19;
}
================================================
FILE: android_env/proto/snapshot_service.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright (C) 2018 The Android Open Source Project
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Note that if you add/remove methods in this file you must update
// the metrics sql as well by running ./android/scripts/gen-grpc-sql.py
//
// Please group deleted methods in a block including the date (MM/DD/YY)
// it was removed. This enables us to easily keep metrics around after removal
//
// list of deleted methods
// rpc iWasDeleted (03/12/12)
// ...
syntax = "proto3";
package android.emulation.control;
import "android_env/proto/snapshot.proto";
option java_multiple_files = true;
option java_package = "com.android.emulator.control";
option objc_class_prefix = "AEC";
// The SnapshotService enables you to list, insert, store, and retrieve
// snapshots.
//
// Currently there are two types of snapshots:
//
// - Local (default): These are snapshots that are created locally. They are
// stored internally inside qcow2 files and are very efficient. These are
// the snapshots usually created by interacting with the UI.
//
// - Remote: These are snapshots that have been exported at a certain point.
// An exported snapshot is normalized (completely self contained) and
// can be imported into an emulator with a similar hardware configuration.
//
// Currently the emulator has limited support for importing snapshots:
// - Once an imported snapshot has been loaded into an emulator it is no longer
// possible to create new snapshots.
// - The hardware configuration of the emulator you are pushing a snapshot to
// must match (or be very similar to) the one you pulled the snapshot from.
//
// For example, do not expect to be able to restore a snapshot created on an
// Intel CPU on an AMD CPU.
service SnapshotService {
// Lists all the snapshots, filtered by the given query, that are stored
// locally for the currently running AVD. This includes all the snapshots that
// were imported (pushed) into this emulator.
//
// Returns a list of snapshot_id's and associated details that describe
// the hardware configuration, logical name, etc. of the snapshot.
rpc ListSnapshots(SnapshotFilter) returns (SnapshotList) {}
// Pulls down the snapshot stored inside the AVD as a tar.gz/tar stream.
// This will normalize the snapshot; all relevant data to push a snapshot
// into a similar emulator will be placed inside the tar file.
//
// Pulling down a snapshot will pause the emulator until the snapshots
// are rebased and ready for exporting. Once the snapshot is rebased
// the emulator will continue and downloading should commence.
//
// Note that pulling a .gz stream is slow.
//
// You must provide the snapshot_id and (desired) format.
//
// If SnapshotPackage.path is set, the gRPC service will directly write the
// exported snapshot to SnapshotPackage.path without streaming, which is
// usually significantly faster. It requires the emulator to have direct
// access to SnapshotPackage.path, which usually means it can only be used
// when pulling from a local emulator.
rpc PullSnapshot(SnapshotPackage) returns (stream SnapshotPackage) {}
// Pushes a tar.gz stream containing the snapshot. The tar file should
// be a snapshot that was exported through PullSnapshot in the past.
// The emulator will try to import the snapshot. The hardware configuration
// of the current emulator should match the one used for pulling.
//
// A detailed description of the snapshot (emulator_snapshot.Snapshot)
// is stored in the snapshot.pb file inside the tar.
//
// You must provide the snapshot_id and format in the first message.
// Will return success and a possible error message when a failure occurs.
//
// If SnapshotPackage.path is set, the gRPC service will directly unzip the
// exported snapshot from SnapshotPackage.path without streaming, which is
// usually significantly faster. It requires the emulator to have direct
// access to SnapshotPackage.path, which usually means it can only be used
// when pushing to a local emulator.
rpc PushSnapshot(stream SnapshotPackage) returns (SnapshotPackage) {}
// Loads the given snapshot inside the emulator and activates it.
// The device will be in the state as it was when the snapshot was created.
//
// You will no longer be able to call Save if this was an imported
// snapshot that was pushed into this emulator.
//
// You must provide the snapshot_id to indicate which snapshot to load.
// Will return success and a possible error message when a failure occurs.
rpc LoadSnapshot(SnapshotPackage) returns (SnapshotPackage) {}
// Creates a snapshot of the current state of the emulator.
// You can only save a snapshot if you never activated (Load) an imported
// snapshot (Push).
//
// For example:
// - PushSnapshot("some_snap.tar.gz");
// - LoadSnapshot("some_snap");
// - SaveSnapshot("same_newer_snap"); // <--- Will currently fail.
//
// You can provide the snapshot_id to indicate the name used for storing.
// Will return success and a possible error message when a failure occurs.
rpc SaveSnapshot(SnapshotPackage) returns (SnapshotPackage) {}
// Deletes the snapshot with the given snapshot_id from the AVD.
//
// You must provide the snapshot_id to indicate which snapshot to delete.
// Will return success and a possible error message when a failure occurs.
rpc DeleteSnapshot(SnapshotPackage) returns (SnapshotPackage) {}
// Tracks the given process for automated snapshot creation in case of
// assert failures.
//
// Will return success and a possible error message when a failure occurs.
// The snapshot_id field will contain the name of the snapshot that
// will be created. The pid field will contain the process id that is
// being tracked.
rpc TrackProcess(IceboxTarget) returns (IceboxTarget) {}
}
// Sets options for SnapshotService. Used for both request and response
// messages.
message SnapshotPackage {
// Encoding of `payload` (or of the file at `path`).
enum Format {
TARGZ = 0;
TAR = 1;
DIRECTORY = 2;
}
// The identifier of the snapshot, only required for request messages. For
// streaming services, only used in the first stream message of a gRPC call
// (ignored in subsequent stream messages of the same call).
string snapshot_id = 1;
// A stream of bytes. Encoded as a tar (possibly gzipped) file depending on
// the value of format.
bytes payload = 2;
// [response only] status fields, usually indicates end of transmission.
bool success = 3;
bytes err = 4;
// [request only] Format of the payload. Only used in request messages. For
// streaming services, only used in the first stream message of a gRPC call
// (ignored in subsequent stream messages of the same call).
Format format = 5;
// [request only] Path to the snapshot package file. Only used in request
// messages.
//
// When set in a request, the PullSnapshot/PushSnapshot operation will
// directly write/read the exported snapshot in path without streaming, which
// is usually significantly faster. It requires the emulator to have direct
// access to path, which usually means it can only be used with a local
// emulator.
string path = 6;
}
// A snapshot filter can be used to filter the results produced by
// ListSnapshots.
message SnapshotFilter {
enum LoadStatus {
// Only return compatible snapshots.
CompatibleOnly = 0;
// Return all snapshots.
All = 1;
}
// Filter snapshots by load status.
LoadStatus statusFilter = 1;
}
// Provides detailed information regarding the snapshot.
message SnapshotDetails {
enum LoadStatus {
// The emulator believes that the snapshot is compatible with the emulator
// that provided this information. The emulator will attempt to load this
// snapshot when requested.
//
// A snapshot is usually compatible when the following statements are true:
// - The snapshot was taken by the current emulator version, i.e.
// emulator_build_id in the details field matches the build_id of the
// emulator that provided this information.
//
// - The snapshot was taken on the current running machine, and no hardware
// changes have taken place between taking and loading the snapshot.
//
// - The AVD configuration has not changed between when this snapshot was
// taken and when the snapshot was loaded.
//
// - The system images on which the AVD is based have not changed.
Compatible = 0;
// The emulator will not allow loading of the snapshot, as it deems the
// snapshot to be incompatible. Loading of snapshots can be forced by
// launching the emulator with the feature "AllowSnapshotMigration" enabled.
Incompatible = 1;
// This snapshot was successfully loaded in the emulator, and was used as
// the starting point of the current running emulator. The following holds:
//
// - A loaded snapshot is a compatible snapshot.
// - There is at most one snapshot_id that is in the "Loaded" state.
Loaded = 2;
}
// The id of this snapshot. Use this id to load/delete/pull the
// snapshot.
string snapshot_id = 1;
// Detailed information about this snapshot. This contains a detailed
// hardware description of the snapshot. These details are the same
// as the "snapshot.pb" file found in an exported snapshot.
// See the imported snapshot.proto for a detailed description of the
// available fields.
emulator_snapshot.Snapshot details = 2;
// Provides information about the ability to restore this snapshot.
LoadStatus status = 3;
// The size of the folder that stores required information to load a snapshot.
uint64 size = 4;
}
// A list of snapshot details, as returned by ListSnapshots.
message SnapshotList {
repeated SnapshotDetails snapshots = 1;
}
// Request and response of SnapshotService.TrackProcess: identifies the
// process Icebox should track and reports the outcome.
message IceboxTarget {
// This is the process id to attach to. If this value is not set (0),
// the process name will be used instead.
int64 pid = 1;
// The process name to attach to, if any. If this is not set, the pid will
// be used. This is usually the application name of your application under
// test that is passed in to the am instrument command. It is likely
// what you will find in your AndroidManifest.xml.
string package_name = 2;
// The name of the snapshot that Icebox will create if a snapshot is
// generated.
string snapshot_id = 3;
// [Output Only] True if Icebox failed to track the given target.
bool failed = 4;
// [Output Only] Detailed error message that might provide more information.
string err = 5;
// Maximum number of snapshots the emulator can take during one Icebox run.
// Set to -1 for an unlimited number of snapshots.
int32 max_snapshot_number = 6;
}
// list of deleted methods:
//
================================================
FILE: android_env/proto/state.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
option java_multiple_files = true;
// Request to save the current state.
message SaveStateRequest {
// Arguments for the save operation, as string key/value pairs.
// (A bare `map` without key/value types is not valid proto3 syntax.)
map<string, string> args = 1;
}
// Request to load a previously saved state.
message LoadStateRequest {
// Arguments for the load operation, as string key/value pairs.
map<string, string> args = 1;
}
// Result of a SaveStateRequest.
message SaveStateResponse {
enum Status {
// Reserved value for unset statuses.
UNDEFINED = 0;
// Returned when everything goes well.
OK = 1;
// Returned when something internal did not work as expected.
ERROR = 2;
}
Status status = 1;
// `error_message` is only populated in case of errors.
string error_message = 2;
// Any additional info returned during the request; e.g., file paths or sizes.
map<string, string> additional_info = 3;
}
// Result of a LoadStateRequest.
message LoadStateResponse {
enum Status {
// Reserved value for unset statuses.
UNDEFINED = 0;
// Returned when everything goes well.
OK = 1;
// Returned when there is no state to load.
NOT_FOUND = 2;
// Returned when something internal did not work as expected.
ERROR = 3;
}
Status status = 1;
// `error_message` is only populated in case of errors.
string error_message = 2;
// Any additional info returned during the request; e.g., file paths or sizes.
map<string, string> additional_info = 3;
}
================================================
FILE: android_env/proto/task.proto
================================================
// Copyright 2026 DeepMind Technologies Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
syntax = "proto3";
package android_env;
import "android_env/proto/adb.proto";
// An AppScreen identifies a unique configuration that we can observe on the
// screen of a device.
message AppScreen {
// Fully-qualified name of the activity.
string activity = 1;
// A list of regexes to match at each level of the current view hierarchy,
// from the root downwards.
// The environment uses this list to determine whether the agent has "exited"
// the current task.
// Example: [
// "^DecorView@.*\[MainActivity\]$",
// "^android.widget.LinearLayout\{.*\}$",
// "^android.widget.FrameLayout\{.*android\:id\/content\}",
// "^android.widget.RelativeLayout\{.*\}",
// "^android.widget.FrameLayout\{.*app\:id\/fragment_holder\}",
// "^android.widget.RelativeLayout\{.*\}",
// "^com.google.example.games.nostalgicracer.views.RaceView3D\{.*app\:id\/gameplay_screen_3d\}",
// ],
repeated string view_hierarchy_path = 2;
}
// Waits for `app_screen` to be the current app screen shown to the user.
message WaitForAppScreen {
// The app screen to wait for.
AppScreen app_screen = 1;
// Maximum time in seconds to wait for the activity to become the current one.
float timeout_sec = 2;
}
// Checks whether `package_name` is installed, waiting up to `timeout_sec`.
message CheckInstall {
// Name of the package to check.
string package_name = 1;
// Maximum time in seconds to wait.
float timeout_sec = 2;
}
// Pauses for `time_sec` seconds.
message Sleep {
float time_sec = 1;
}
// A check used to decide whether a setup step succeeded.
message SuccessCondition {
// Number of retries allowed.
// NOTE(review): presumably retries of the step when `check` fails; confirm
// against the setup-step runner.
int32 num_retries = 1;
// The check to perform; exactly one may be set.
oneof check {
WaitForAppScreen wait_for_app_screen = 2;
CheckInstall check_install = 3;
}
}
// A single setup or reset step: either an ADB request or a sleep, with an
// optional success condition.
message SetupStep {
SuccessCondition success_condition = 1;
// The action to perform; exactly one may be set.
oneof step {
AdbRequest adb_request = 2;
Sleep sleep = 3;
}
}
// A specification of structured observations.
// Analogous to dm_env.specs.Array().
message ArraySpec {
// An identifier for this ArraySpec.
string name = 1;
// The shape of the multi-dimensional values associated with this ArraySpec.
repeated int32 shape = 2;
// Element data types; the STRING_U<N> values appear to denote fixed-length
// unicode strings of at most N characters.
enum DataType {
INVALID_DATA_TYPE = 0;
FLOAT = 1;
DOUBLE = 2;
INT8 = 3;
INT16 = 4;
INT32 = 5;
INT64 = 6;
UINT8 = 7;
UINT16 = 8;
UINT32 = 9;
UINT64 = 10;
BOOL = 11;
STRING_U1 = 12;
STRING_U16 = 13;
STRING_U25 = 14;
STRING_U250 = 15;
STRING = 16; // String without max length
OBJECT = 17;
}
// Data type of elements we expect to see in an array of this spec.
DataType dtype = 3;
}
// Configuration for extracting RL signals (score, rewards, episode end,
// extras) from logcat messages.
message LogParsingConfig {
// `filters` are tags used by the app's logging system so that we can
// identify them in logcat's output. It's the first argument to logging calls
// such as Log.e("ActivityManager", "My message").
// Example: "ActivityManager"
repeated string filters = 1;
// Regular expressions that define how we can extract RL information such as
// score, extras and episode end from raw logcat messages.
message LogRegexps {
// Regexp expected to match:
// ...a floating point value which gets accumulated over time.
// A delta in 'score' corresponds to the reward.
string score = 1;
// Regexp expected to match:
// ...a floating point value directly forwarded by the environment.
repeated string reward = 2;
// Regexp expected to match:
// ...a signal marking the end of an episode.
repeated string episode_end = 3;
// Regexp expected to match:
// ...a string representing pairs of extra names and values.
repeated string extra = 4;
// Regexp expected to match:
// ...a dict of extra names and values in JSON format.
repeated string json_extra = 5;
// Attaches rewards to arbitrary log messages, for example:
// {event: "coin_collected" reward: 2.3}
// {event: "car_crashed" reward: -1.4}
message RewardEvent {
// If `event` is matched, the environment will give `reward`.
string event = 1;
// Numerical value to give as reward if `event` is matched.
float reward = 2;
}
repeated RewardEvent reward_event = 6;
}
LogRegexps log_regexps = 2;
}
// Description of a reinforcement learning task to be solved by an agent.
message Task {
// A globally unique identifier for this task.
string id = 1;
// A human readable name for this task.
string name = 2;
// A description of the task.
string description = 3;
// Steps that set up the task.
// NOTE(review): presumably run once before the first episode; confirm.
repeated SetupStep setup_steps = 4;
// Steps that reset the task.
// NOTE(review): presumably run at the start of each episode; confirm.
repeated SetupStep reset_steps = 5;
// The app screen on which the task takes place. Its view_hierarchy_path is
// used to determine whether the agent has "exited" the task.
AppScreen expected_app_screen = 6;
// AndroidEnv resets the episode after `max_episode_sec` is passed since the
// last reset(). Recommended for time sensitive tasks (e.g. reactive games).
// Note that this is real time as measured by AndroidEnv and is independent of
// the speed of simulation of Android.
// If <= 0.0, this logic is disabled.
float max_episode_sec = 7;
// The maximum number of interactions in a single episode between the
// environment and an agent.
// This setting is appropriate for tasks that are not time-dependent or when
// the performance of the simulation varies dramatically between runs.
// If <= 0, this logic is disabled.
int32 max_episode_steps = 8;
// Defines parameters for parsing messages from logcat.
LogParsingConfig log_parsing_config = 9;
// NOTE: This field is deprecated and will be removed from this Task
// definition soon.
//
// (Optional): The task may also define extras to help the RL agent.
// An Extra in AndroidEnv is any information that apps may send to aid the
// understanding of the task. The type of information sent through this
// channel is usually something difficult to obtain from raw pixels and may
// include things such as:
//
// - The current board configuration (e.g. of a chess game or a tetris game)
// - The position of the avatar in a map
// - Events (e.g. whether a button was pressed or a checkpoint was achieved)
//
// Notice that these are entirely optional and may not be available at all.
// This specification ensures that only extras specified in the Task
// definition will be passed to the agent; everything else is excluded.
// The name of an extra must be unique across all extras.
repeated ArraySpec extras_spec = 10;
}
================================================
FILE: android_env/wrappers/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/wrappers/a11y/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: android_env/wrappers/a11y/a11y_events.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for accessing accessibility events."""
from collections.abc import Mapping
from typing import Any
from absl import logging
from android_env.proto.a11y import a11y_pb2
import numpy as np
from google.protobuf import any_pb2
# Key under which accessibility events are stored in task extras.
_A11Y_EVENT_KEY = 'full_event'
def package_events_to_task_extras(
    events: list[a11y_pb2.EventRequest],
) -> Mapping[str, np.ndarray]:
  """Packs accessibility events into a task-extras mapping.

  Args:
    events: Accessibility events to pack.

  Returns:
    An empty dict when `events` is empty. Otherwise a single-entry dict,
    keyed by the accessibility event key, whose value is
    `np.stack(events, axis=0)`, i.e. the events stacked along a new leading
    axis.
  """
  if events:
    stacked = np.stack(events, axis=0)
    return {_A11Y_EVENT_KEY: stacked}
  return {}
def extract_events_from_task_extras(
    task_extras: Mapping[str, Any] | None = None,
) -> list[Mapping[str, str]]:
  """Inspects `task_extras` and extracts all accessibility events it contains.

  Args:
    task_extras: Task extras forwarded by AndroidEnv. If this is None or has no
      'full_event' key, an empty list is returned. Otherwise 'full_event' must
      map to a one-dimensional numpy array whose elements are
      `a11y_pb2.EventRequest` messages, `any_pb2.Any` messages packing an
      `EventRequest`, or already-extracted dicts. Each event is a mapping such
      as {'event_type': 'TYPE_WINDOW_CONTENT_CHANGED', 'event_package_name':
      'com.google.android.deskclock', 'source_class_name':
      'android.widget.ImageView'}.

  Returns:
    A list with one dict per accessibility event, in array order.

  Raises:
    ValueError: if 'full_event' is not a one-dimensional numpy array.
    TypeError: if an element has an unsupported type, or is an `Any` that does
      not pack an `EventRequest`.
  """
  if task_extras is None or _A11Y_EVENT_KEY not in task_extras:
    return []
  if (
      not isinstance(task_extras[_A11Y_EVENT_KEY], np.ndarray)
      or task_extras[_A11Y_EVENT_KEY].ndim != 1
  ):
    raise ValueError(
        f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one'
        ' dimension.'
    )
  if task_extras[_A11Y_EVENT_KEY].size == 0:
    return []
  events = []
  for e in task_extras[_A11Y_EVENT_KEY]:
    if isinstance(e, a11y_pb2.EventRequest):
      events.append(dict(e.event))
    elif isinstance(e, dict):
      # Already-unpacked events are passed through, but this is unexpected.
      events.append(e)
      logging.warning(
          'The event should come only from the a11y_grpc_wrapper. '
          'Please verify that the upacking operation has not been '
          'called twice. See here for full task_extras: %s',
          task_extras,
      )
    elif isinstance(e, any_pb2.Any):
      ev = a11y_pb2.EventRequest()
      # `Unpack` does not modify `e`. It returns False and leaves `ev` empty
      # when `e` packs another message type; reject that instead of silently
      # emitting an empty event.
      if not e.Unpack(ev):
        raise TypeError(
            f'Expected an Any packing {ev.DESCRIPTOR.full_name}, got type_url'
            f' {e.type_url!r}. See here for full task_extras: {task_extras}.'
        )
      events.append(dict(ev.event))
    else:
      raise TypeError(
          f'Unexpected event type: {type(e)}. See here for full '
          f'task_extras: {task_extras}.'
      )
  return events
def keep_latest_event_only(task_extras: dict[str, Any] | None):
  """Removes all a11y events except the last one observed.

  `task_extras` is modified in place: its 'full_event' array is replaced by a
  one-element slice holding its last element. Nothing happens if
  `task_extras` is None, has no 'full_event' key, or holds an empty array.

  Args:
    task_extras: Task extras forwarded by AndroidEnv, or None.

  Raises:
    ValueError: if 'full_event' is not a one-dimensional numpy array.
  """
  if task_extras is None or _A11Y_EVENT_KEY not in task_extras:
    return
  payload = task_extras[_A11Y_EVENT_KEY]
  if not isinstance(payload, np.ndarray) or payload.ndim != 1:
    raise ValueError(
        f'{_A11Y_EVENT_KEY} task extra should be a numpy array with one'
        ' dimension.'
    )
  if payload.size == 0:
    # Return None like every other path (this path used to return `[]`).
    return
  task_extras[_A11Y_EVENT_KEY] = payload[-1:]
================================================
FILE: android_env/wrappers/a11y/a11y_events_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for a11y_events."""
from absl.testing import absltest
from absl.testing import parameterized
from android_env.proto.a11y import a11y_pb2
from android_env.wrappers.a11y import a11y_events
import numpy as np
from google.protobuf import any_pb2
def _event_request(d: dict[str, str]) -> a11y_pb2.EventRequest:
  """Builds an `EventRequest` whose `event` map holds the entries of `d`."""
  return a11y_pb2.EventRequest(event=d)
def _event_request_as_any(d: dict[str, str]) -> any_pb2.Any:
  """Builds an `EventRequest` from `d` and packs it into an `Any`."""
  packed = any_pb2.Any()
  packed.Pack(_event_request(d))
  return packed
class A11yEventsTest(parameterized.TestCase):
  """Tests for `a11y_events`."""

  @parameterized.parameters(
      dict(task_extras={}),
      dict(
          task_extras={'no_full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}]},
      ),
      dict(
          task_extras={'full_event': np.array([])},
      ),
  )
  def test_no_events_in_task_extras(self, task_extras):
    events = a11y_events.extract_events_from_task_extras(task_extras)
    self.assertEmpty(events)

  @parameterized.parameters(
      dict(
          task_extras={'full_event': [{'1': '1'}, {'2': '2'}]},
          expected_events=[{'1': '1'}, {'2': '2'}],
      ),
      dict(
          task_extras={'full_event': [{}]},
          expected_events=[{}],
      ),
      dict(
          task_extras={
              'full_event_wrong_key': [1, 2, 3],
              'full_event': [{'1': '1'}, {'2': '2'}, {'3': '3'}],
          },
          expected_events=[{'1': '1'}, {'2': '2'}, {'3': '3'}],
      ),
  )
  def test_task_extras(self, task_extras, expected_events):
    # Work on a copy: the parameter dict is shared with the decorator.
    task_extras = dict(task_extras)
    event_requests = [_event_request(e) for e in task_extras['full_event']]
    task_extras['full_event'] = np.stack(event_requests, axis=0)
    events = a11y_events.extract_events_from_task_extras(task_extras)
    # Compare the extracted events (not the expectations) to the expectations.
    self.assertEqual(events, expected_events)

  def test_events_key_has_dict_event_requrests(self):
    event_requests = [
        _event_request({'1': '1'}),
        {'2': '2'},
        _event_request({'3': '3'}),
    ]
    expected_events = [
        {'1': '1'},
        {'2': '2'},
        {'3': '3'},
    ]
    task_extras = {'full_event': np.stack(event_requests, axis=0)}
    events = a11y_events.extract_events_from_task_extras(task_extras)
    self.assertEqual(events, expected_events)

  def test_events_key_has__event_requrests_packed_as_any(self):
    event_requests = [
        _event_request_as_any({'1': '1'}),
        {'2': '2'},
        _event_request_as_any({'3': '3'}),
    ]
    expected_events = [
        {'1': '1'},
        {'2': '2'},
        {'3': '3'},
    ]
    task_extras = {'full_event': np.stack(event_requests, axis=0)}
    events = a11y_events.extract_events_from_task_extras(task_extras)
    self.assertEqual(events, expected_events)

  def test_events_key_has_non_event_requrests(self):
    event_requests = [
        _event_request({'1': '1'}),
        3,  # Neither an event nor a dict.
        _event_request({'3': '3'}),
    ]
    task_extras = {'full_event': np.stack(event_requests, axis=0)}
    with self.assertRaises(TypeError):
      _ = a11y_events.extract_events_from_task_extras(task_extras)

  @parameterized.parameters(
      dict(task_extras={}, expected_extras={}),
      dict(
          task_extras={
              'no_full_event': 42,
          },
          expected_extras={
              'no_full_event': 42,
          },
      ),
      dict(
          task_extras={'full_event': np.array([1, 2]), 'no_full_event': 43},
          expected_extras={'full_event': np.array([2]), 'no_full_event': 43},
      ),
      dict(
          task_extras={'full_event': np.array([1, 2, 3])},
          expected_extras={'full_event': np.array([3])},
      ),
      dict(
          task_extras={'full_event': np.array([]), 'no_full_event': 44},
          expected_extras={'full_event': np.array([]), 'no_full_event': 44},
      ),
  )
  def test_keep_latest_only(self, task_extras, expected_extras):
    # Work on a copy: `keep_latest_event_only` mutates its argument in place.
    task_extras = dict(task_extras)
    a11y_events.keep_latest_event_only(task_extras)
    self.assertEqual(len(task_extras), len(expected_extras))
    for k, v in task_extras.items():
      self.assertIn(k, expected_extras)
      if k == 'full_event':
        np.testing.assert_array_equal(v, expected_extras['full_event'])
      else:
        self.assertEqual(v, expected_extras[k])
if __name__ == '__main__':
  # Runs the tests in this module when it is executed directly.
  absltest.main()
================================================
FILE: android_env/wrappers/a11y/a11y_forests.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tools for accessing accessibility events."""
from collections.abc import Mapping
from typing import Any
from android_env.proto.a11y import android_accessibility_forest_pb2
import numpy as np
from google.protobuf import any_pb2
# Key under which accessibility forests are stored in `task_extras`.
_A11Y_FORESTS_KEY = 'accessibility_tree'
def package_forests_to_task_extras(
    forests: list[android_accessibility_forest_pb2.AndroidAccessibilityForest],
) -> Mapping[str, np.ndarray]:
  """Packs accessibility forests into a task-extras mapping.

  Args:
    forests: Forest protos to package.

  Returns:
    `{}` if `forests` is empty; otherwise a one-entry dict mapping the
    'accessibility_tree' key to `np.stack(forests, axis=0)`.
  """
  packaged = {}
  if forests:
    packaged[_A11Y_FORESTS_KEY] = np.stack(forests, axis=0)
  return packaged
def task_extras_has_forests(task_extras: Mapping[str, Any]) -> bool:
  """Checks that the task_extras has any a11y forest information.

  Args:
    task_extras: Task extras forwarded by AndroidEnv.

  Returns:
    True iff the 'accessibility_tree' entry exists and contains at least one
    element that is either an `AndroidAccessibilityForest` or an `Any` (the
    `Any` payload type is not inspected here).

  Raises:
    ValueError: if the 'accessibility_tree' entry is not a one-dimensional
      numpy array.
  """
  if _A11Y_FORESTS_KEY not in task_extras:
    return False
  payload = task_extras[_A11Y_FORESTS_KEY]
  is_1d_array = isinstance(payload, np.ndarray) and payload.ndim == 1
  if not is_1d_array:
    raise ValueError(
        f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one'
        f' dimension. payload: {payload}'
    )
  if payload.size == 0:
    return False
  # Forests may arrive either as raw protos or packed inside `Any` messages.
  forest_like = (
      any_pb2.Any,
      android_accessibility_forest_pb2.AndroidAccessibilityForest,
  )
  return any(isinstance(item, forest_like) for item in payload)
def convert_to_forest(
    forest: android_accessibility_forest_pb2.AndroidAccessibilityForest
    | any_pb2.Any
    | None,
) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None:
  """Takes an object and attempts to convert it to a forest.

  Args:
    forest: A forest proto, an `Any` expected to pack one, or any other object.

  Returns:
    `forest` itself if it already is an `AndroidAccessibilityForest`; the
    unpacked forest if it is an `Any` packing one; otherwise None (this
    includes None inputs and `Any` messages packing a different type).
  """
  if isinstance(
      forest, android_accessibility_forest_pb2.AndroidAccessibilityForest
  ):
    return forest
  if isinstance(forest, any_pb2.Any):
    output = android_accessibility_forest_pb2.AndroidAccessibilityForest()
    # `Unpack` does not modify `forest`. It returns False and leaves `output`
    # empty on a type mismatch; report "not a forest" instead of returning a
    # bogus empty forest.
    return output if forest.Unpack(output) else None
  return None
def extract_forests_from_task_extras(
    task_extras: Mapping[str, Any] | None = None,
) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]:
  """Inspects `task_extras` and extracts all accessibility forests it contains.

  Args:
    task_extras: Task extras forwarded by AndroidEnv. If this is None or
      `task_extras_has_forests()` reports no forest under the
      'accessibility_tree' key, an empty list is returned. Otherwise each
      element of that one-dimensional numpy array is passed through
      `convert_to_forest()`, and elements it cannot convert (it returns None)
      are skipped.

  Returns:
    The forests found, in array order.

  Raises:
    ValueError: if the 'accessibility_tree' entry is not a one-dimensional
      numpy array (raised by `task_extras_has_forests()`).
  """
  if task_extras is None or not task_extras_has_forests(task_extras):
    return []
  converted = (convert_to_forest(f) for f in task_extras[_A11Y_FORESTS_KEY])
  return [f for f in converted if f is not None]
def keep_latest_forest_only(task_extras: dict[str, Any] | None):
  """Removes all a11y forests except the last one observed.

  `task_extras` is modified in place: its 'accessibility_tree' array is
  replaced by a one-element slice holding its last element. Nothing happens if
  `task_extras` is None, has no 'accessibility_tree' key, or holds an empty
  array. None is accepted for consistency with `a11y_events`.

  Args:
    task_extras: Task extras forwarded by AndroidEnv, or None.

  Raises:
    ValueError: if 'accessibility_tree' is not a one-dimensional numpy array.
  """
  if task_extras is None or _A11Y_FORESTS_KEY not in task_extras:
    return
  payload = task_extras[_A11Y_FORESTS_KEY]
  if not isinstance(payload, np.ndarray) or payload.ndim != 1:
    raise ValueError(
        f'{_A11Y_FORESTS_KEY} task extra should be a numpy array with one'
        f' dimension. payload: {payload}'
    )
  if payload.size == 0:
    return
  task_extras[_A11Y_FORESTS_KEY] = payload[-1:]
================================================
FILE: android_env/wrappers/a11y/a11y_forests_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for a11y_forests."""
from absl.testing import absltest
from absl.testing import parameterized
from android_env.proto.a11y import android_accessibility_forest_pb2
from android_env.wrappers.a11y import a11y_forests
import numpy as np
from google.protobuf import any_pb2
def _pack_any(proto_message) -> any_pb2.Any:
  """Wraps `proto_message` in a `google.protobuf.Any`."""
  packed = any_pb2.Any()
  packed.Pack(proto_message)
  return packed
def _empty_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Returns a forest with no windows."""
  forest = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  return forest
def _one_empty_window_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Returns a forest holding a single window with no nodes."""
  result = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  _ = result.windows.add()
  return result
def _two_window_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Returns a forest with a one-node window followed by an empty window."""
  result = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  first_window_nodes = result.windows.add().tree.nodes
  first_window_nodes.add(
      class_name='foo', is_clickable=True, hint_text='Foo hint'
  )
  result.windows.add()  # Second window: no nodes.
  return result
class A11YForestsTest(parameterized.TestCase):
  """Tests for `a11y_forests`."""

  @parameterized.parameters(
      dict(task_extras={}, expected_forests=[], convert_to_np=[]),
      dict(
          task_extras={'accessibility_tree': []},
          convert_to_np=['accessibility_tree'],
          expected_forests=[],
      ),
      dict(
          task_extras={
              'not_accessibility_tree': [
                  _empty_forest(),
                  _one_empty_window_forest(),
                  _two_window_forest(),
              ],
          },
          convert_to_np=['not_accessibility_tree'],
          expected_forests=[],
      ),
      dict(
          task_extras={
              'accessibility_tree': [
                  _empty_forest(),
                  {'not_a_forest_key': 'nor_a_forest_value'},
                  _two_window_forest(),
              ]
          },
          convert_to_np=['accessibility_tree'],
          expected_forests=[_empty_forest(), _two_window_forest()],
      ),
      dict(
          task_extras={
              'accessibility_tree': [
                  {'not_a_forest_key': 'nor_a_forest_value'},
                  3,
                  4,
                  {'not_a_forest_key': _empty_forest()},
              ],
          },
          convert_to_np=['accessibility_tree'],
          expected_forests=[],
      ),
      dict(
          task_extras={
              'accessibility_tree_wrong_key': [1, 2, 3],
              'accessibility_tree': [
                  _empty_forest(),
                  None,
                  None,
                  _one_empty_window_forest(),
                  _two_window_forest(),
              ],
          },
          convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'],
          expected_forests=[
              _empty_forest(),
              _one_empty_window_forest(),
              _two_window_forest(),
          ],
      ),
      dict(
          task_extras={
              'accessibility_tree_wrong_key': [1, 2, 3],
              'accessibility_tree': [
                  None,
                  _pack_any(_empty_forest()),
                  _pack_any(_one_empty_window_forest()),
                  _pack_any(_two_window_forest()),
              ],
          },
          convert_to_np=['accessibility_tree', 'accessibility_tree_wrong_key'],
          expected_forests=[
              _empty_forest(),
              _one_empty_window_forest(),
              _two_window_forest(),
          ],
      ),
      dict(
          task_extras={
              'accessibility_tree': [
                  _pack_any(_empty_forest()),
                  {'not_a_forest_key': 'nor_a_forest_value'},
                  None,
                  _two_window_forest(),
                  None,
              ]
          },
          convert_to_np=['accessibility_tree'],
          expected_forests=[_empty_forest(), _two_window_forest()],
      ),
  )
  def test_task_extras(self, task_extras, expected_forests, convert_to_np):
    # Work on a copy: the parameter dict is shared with the decorator.
    task_extras = dict(task_extras)
    for k in convert_to_np:
      if task_extras[k]:
        task_extras[k] = np.stack(task_extras[k], axis=0)
      else:
        task_extras[k] = np.array([])
    forests = a11y_forests.extract_forests_from_task_extras(task_extras)
    self.assertEqual(forests, expected_forests)

  @parameterized.parameters(
      dict(task_extras={}, expected_extras={}),
      dict(
          task_extras={
              'no_accessibility_tree': 42,
          },
          expected_extras={
              'no_accessibility_tree': 42,
          },
      ),
      dict(
          task_extras={'accessibility_tree': []},
          expected_extras={'accessibility_tree': []},
      ),
      dict(
          task_extras={
              'accessibility_tree': [
                  _empty_forest(),
                  _one_empty_window_forest(),
              ],
              'no_accessibility_tree': 43,
          },
          expected_extras={
              'accessibility_tree': [_one_empty_window_forest()],
              'no_accessibility_tree': 43,
          },
      ),
      dict(
          task_extras={
              'accessibility_tree': [
                  _empty_forest(),
                  _one_empty_window_forest(),
                  _two_window_forest(),
              ]
          },
          expected_extras={'accessibility_tree': [_two_window_forest()]},
      ),
      dict(
          task_extras={
              'accessibility_tree': [],
              'no_accessibility_tree': 44,
          },
          expected_extras={
              'accessibility_tree': [],
              'no_accessibility_tree': 44,
          },
      ),
  )
  def test_keep_latest_only(self, task_extras, expected_extras):
    # Work on a copy: the parameter dict is shared with the decorator and
    # `keep_latest_forest_only` mutates its argument in place.
    task_extras = dict(task_extras)
    if 'accessibility_tree' in task_extras:
      tree = task_extras['accessibility_tree']
      task_extras['accessibility_tree'] = (
          np.stack(tree, axis=0) if tree else np.array([])
      )
    a11y_forests.keep_latest_forest_only(task_extras)
    self.assertSameElements(task_extras.keys(), expected_extras.keys())
    for k, v in task_extras.items():
      if k == 'accessibility_tree':
        self.assertEqual(list(v), expected_extras[k])
      else:
        self.assertEqual(v, expected_extras[k])
if __name__ == '__main__':
  # Runs the tests in this module when it is executed directly.
  absltest.main()
================================================
FILE: android_env/wrappers/a11y/a11y_servicer.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Accessibility Servicer implementation."""
import asyncio
from collections.abc import AsyncIterator, Generator, Iterable
import threading
from absl import logging
from android_env.proto.a11y import a11y_pb2
from android_env.proto.a11y import a11y_pb2_grpc
from android_env.proto.a11y import android_accessibility_forest_pb2
import grpc
class A11yServicer(a11y_pb2_grpc.A11yServiceServicer):
"""Services the A11yService requests."""
def __init__(self, latest_forest_only: bool = False):
self._received_forests: list[
android_accessibility_forest_pb2.AndroidAccessibilityForest
] = []
self._received_events: list[a11y_pb2.EventRequest] = []
self._lock_forests = threading.Lock()
self._lock_events = threading.Lock()
self._latest_forest_only = latest_forest_only
self._paused = True
# A11y Forest bookkeeping.
self._get_forest = asyncio.Event() # Whether to request a forest.
self._forest_ready = asyncio.Event() # Whether the forest is ready.
self._latest_forest: (
android_accessibility_forest_pb2.AndroidAccessibilityForest | None
) = None
def SendForest(
self,
request: android_accessibility_forest_pb2.AndroidAccessibilityForest,
context: grpc.ServicerContext,
) -> a11y_pb2.ForestResponse:
self._process_forest(request)
return a11y_pb2.ForestResponse()
def SendEvent(
self,
request: a11y_pb2.EventRequest,
context: grpc.ServicerContext,
) -> a11y_pb2.EventResponse:
self._process_event(request)
return a11y_pb2.EventResponse()
async def Bidi(
self,
request_iterator: AsyncIterator[a11y_pb2.ClientToServer],
context: grpc.aio.ServicerContext,
) -> AsyncIterator[a11y_pb2.ServerToClient]:
"""Processes incoming ClientToServer requests."""
logging.info('Starting A11yServicer.Bidi()')
# Send a dummy message to unblock clients in their loop.
yield a11y_pb2.ServerToClient()
# This block defines two coroutines:
#
# * `read_client_requests()`
# * `check_forest()`
#
# They cooperate with each other and both populate a queue `q` which is
# consumed in a loop below, which actually yields requests which are sent to
# the client. The processing finishes when the clients "closes" the
# connection, which causes `read_client_requests()` to put a special value,
# `STOP_ITERATION`, in the queue.
# Queue for communicating from coroutines to `Bidi()`.
q = asyncio.Queue()
should_run = True
async def read_client_requests():
"""Coroutine for reading client requests."""
nonlocal should_run
async for request in request_iterator:
field_name = request.WhichOneof('payload')
match field_name:
case 'event':
self._process_event(request.event)
case 'forest':
self._latest_forest = request.forest
self._forest_ready.set()
self._get_forest.clear() # Reset the `Event`.
case _:
logging.error('Unknown field %r', field_name)
await q.put(a11y_pb2.ServerToClient())
# Send a special value to stop processing this `Bidi` connection.
await q.put('STOP_ITERATION')
should_run = False
async def check_forest():
"""Coroutine for sending "get forest" requests."""
nonlocal should_run
while should_run:
await self._get_forest.wait()
await q.put(a11y_pb2.ServerToClient(get_forest={}))
tasks = asyncio.gather(read_client_requests(), check_forest())
while should_run:
v = await q.get()
if v == 'STOP_ITERATION':
break
else:
yield v
await tasks
logging.info('Finishing A11yServicer.Bidi()')
async def get_forest(
self,
) -> android_accessibility_forest_pb2.AndroidAccessibilityForest | None:
"""Issues a request to get the a11y forest from the client."""
self._get_forest.set() # Unblocks coroutine to send a request.
await self._forest_ready.wait() # Wait for forest to be ready.
self._forest_ready.clear() # Reset the `Event`.
return self._latest_forest
def gather_forests(
self,
) -> list[android_accessibility_forest_pb2.AndroidAccessibilityForest]:
forests = []
with self._lock_forests:
forests = self._received_forests
self._received_forests = []
return forests
def gather_events(self) -> list[a11y_pb2.EventRequest]:
events = []
with self._lock_events:
events = self._received_events
self._received_events = []
return events
def pause_and_clear(self) -> None:
"""Temporarily stop receiving events/forests and clear the queue.
Used when resetting the environment; in this case:
- all events/forests that have been received since last timestep are things
that happened in the last episode after its `LAST` timestep (so we should
ignore them, done by clearing the lists).
- we're about to receive a bunch of events/forests just as a result of
resetting the environment. We don't want to count these either; thus we
temporarily stop receiving new ones.
"""
self._paused = True
with self._lock_forests:
self._received_forests = []
with self._lock_events:
self._received_events = []
def resume(self) -> None:
"""Start receiving events/forests (e.g., after a reset)."""
self._paused = False
def _process_event(self, event: a11y_pb2.EventRequest) -> None:
"""Adds the given event to the internal buffer of events."""
if not self._paused:
with self._lock_events:
self._received_events.append(event)
def _process_forest(
self, forest: android_accessibility_forest_pb2.AndroidAccessibilityForest
) -> None:
"""Adds the given forest to the internal buffer of forests."""
if not self._paused:
with self._lock_forests:
if self._latest_forest_only:
self._received_forests = [forest]
else:
self._received_forests.append(forest)
================================================
FILE: android_env/wrappers/a11y/a11y_servicer_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for a11y_servicer."""
import asyncio
from collections.abc import AsyncIterator, Iterable
from typing import TypeVar
from unittest import IsolatedAsyncioTestCase, mock
from absl.testing import absltest
from absl.testing import parameterized
from android_env.proto.a11y import a11y_pb2
from android_env.proto.a11y import android_accessibility_forest_pb2
from android_env.wrappers.a11y import a11y_servicer
import grpc
# Element type of the iterables wrapped by `_aiter`.
_T = TypeVar('_T')
async def _aiter(xs: Iterable[_T]) -> AsyncIterator[_T]:
  """Exposes a plain iterable as an async iterator, item by item."""
  for item in xs:
    yield item
def one_window_one_node_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Returns a forest with one window that holds a single clickable node."""
  result = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  result.windows.add().tree.nodes.add(
      class_name='foo', is_clickable=True, hint_text='Foo hint'
  )
  return result
def one_window_two_nodes_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Returns a forest with one window holding two 'bar' nodes."""
  result = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  nodes = result.windows.add().tree.nodes
  for clickable, hint in ((True, 'Bar hint'), (False, 'Bar hint 2')):
    nodes.add(class_name='bar', is_clickable=clickable, hint_text=hint)
  return result
def empty_dict() -> dict[str, str]:
  """Returns a fresh, empty event payload."""
  return dict()
def single_item_dict_with_special_chars() -> dict[str, str]:
  """Returns a one-entry payload whose value contains CR, TAB and LF."""
  return dict(foo='bar\r\t\nbaz')
class A11yServicerTest(parameterized.TestCase, IsolatedAsyncioTestCase):
  """Tests for `a11y_servicer.A11yServicer`."""

  def _mock_context(self):
    """Returns an autospec'd mock of a gRPC servicer context."""
    return mock.create_autospec(grpc.ServicerContext, instance=True)

  def _resumed_servicer(self, **kwargs) -> a11y_servicer.A11yServicer:
    """Builds a servicer with `kwargs` and resumes it."""
    servicer = a11y_servicer.A11yServicer(**kwargs)
    servicer.resume()
    return servicer

  def test_servicer_sendforest(self):
    context = self._mock_context()
    servicer = self._resumed_servicer()
    for sent in (one_window_one_node_forest(), one_window_two_nodes_forest()):
      self.assertEqual(servicer.SendForest(sent, context).error, '')
    received = servicer.gather_forests()
    self.assertLen(received, 2)
    self.assertEqual(received[0], one_window_one_node_forest())
    self.assertEqual(received[1], one_window_two_nodes_forest())

  async def test_servicer_bidi_forests(self):
    """Checks that the bidirectional interface accepts forests."""
    # Arrange.
    context = self._mock_context()
    requests = [
        a11y_pb2.ClientToServer(
            event=a11y_pb2.EventRequest(
                event=single_item_dict_with_special_chars()
            )
        ),
        a11y_pb2.ClientToServer(forest=one_window_two_nodes_forest()),
    ]

    # Act.
    servicer = self._resumed_servicer()
    responses = [r async for r in servicer.Bidi(_aiter(requests), context)]
    forest = await servicer.get_forest()

    # Assert.
    self.assertEqual(responses[0], a11y_pb2.ServerToClient())
    self.assertEqual(responses[1], a11y_pb2.ServerToClient())
    self.assertIsNotNone(forest)
    self.assertEqual(forest, one_window_two_nodes_forest())

  def test_servicer_sendforest_latest_only(self):
    context = self._mock_context()
    servicer = self._resumed_servicer(latest_forest_only=True)
    for sent in (one_window_one_node_forest(), one_window_two_nodes_forest()):
      self.assertEqual(servicer.SendForest(sent, context).error, '')
    received = servicer.gather_forests()
    self.assertLen(received, 1)
    self.assertEqual(received[0], one_window_two_nodes_forest())

  def test_servicer_sendevent(self):
    context = self._mock_context()
    servicer = self._resumed_servicer()
    payloads = (empty_dict(), single_item_dict_with_special_chars())
    for payload in payloads:
      reply = servicer.SendEvent(a11y_pb2.EventRequest(event=payload), context)
      self.assertEqual(reply.error, '')
    received = servicer.gather_events()
    self.assertLen(received, 2)
    for request, payload in zip(received, payloads):
      self.assertEqual(request.event, payload)

  async def test_servicer_bidi_events(self):
    """Checks that the bidirectional interface accepts events."""
    # Arrange.
    context = self._mock_context()
    payloads = (empty_dict(), single_item_dict_with_special_chars())
    requests = [
        a11y_pb2.ClientToServer(event=a11y_pb2.EventRequest(event=payload))
        for payload in payloads
    ]

    # Act.
    servicer = self._resumed_servicer()
    responses = [r async for r in servicer.Bidi(_aiter(requests), context)]
    received = servicer.gather_events()

    # Assert.
    self.assertEqual(responses[0], a11y_pb2.ServerToClient())
    self.assertEqual(responses[1], a11y_pb2.ServerToClient())
    self.assertLen(received, 2)
    for request, payload in zip(received, payloads):
      self.assertEqual(request.event, payload)

  def test_servicer_pause_and_clear_pauses(self):
    context = self._mock_context()
    servicer = self._resumed_servicer()
    servicer.pause_and_clear()
    event_reply = servicer.SendEvent(
        a11y_pb2.EventRequest(event=empty_dict()), context
    )
    self.assertEqual(event_reply.error, '')
    forest_reply = servicer.SendForest(one_window_one_node_forest(), context)
    self.assertEqual(forest_reply.error, '')
    self.assertEmpty(servicer.gather_events())
    self.assertEmpty(servicer.gather_forests())

  def test_servicer_pause_and_clear_clears(self):
    context = self._mock_context()
    servicer = self._resumed_servicer()
    event_reply = servicer.SendEvent(
        a11y_pb2.EventRequest(event=empty_dict()), context
    )
    self.assertEqual(event_reply.error, '')
    forest_reply = servicer.SendForest(one_window_one_node_forest(), context)
    self.assertEqual(forest_reply.error, '')
    servicer.pause_and_clear()
    self.assertEmpty(servicer.gather_events())
    self.assertEmpty(servicer.gather_forests())
if __name__ == '__main__':
  # Runs the tests in this module when it is executed directly.
  absltest.main()
================================================
FILE: android_env/wrappers/a11y_grpc_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wraps AndroidEnv to retrieve accessibility messages from gRPC."""
from concurrent import futures
import time
from typing import Any
import urllib.request
from absl import logging
from android_env import env_interface
from android_env.components import action_type as android_action_type_lib
from android_env.proto import adb_pb2
from android_env.proto.a11y import a11y_pb2_grpc
from android_env.wrappers import base_wrapper
from android_env.wrappers.a11y import a11y_events
from android_env.wrappers.a11y import a11y_forests
from android_env.wrappers.a11y import a11y_servicer
import dm_env
import grpc
import numpy as np
import portpicker
def _get_accessibility_forwarder_apk() -> bytes:
logging.info('Downloading accessibility forwarder apk....')
with urllib.request.urlopen(
'https://storage.googleapis.com/android_env-tasks/2024.05.13-accessibility_forwarder.apk'
) as response:
return response.read()
class EnableNetworkingError(ValueError):
  """Error type signalling a failure to enable networking on the device.

  NOTE(review): presumably raised by `A11yGrpcWrapper`'s reconnection logic;
  confirm against its usages.
  """
class A11yGrpcWrapper(base_wrapper.BaseWrapper):
  """Wrapper which receives A11y events and forests over gRPC.

  A11y forest protobufs and event dicts are sent from the Android emulator via
  gRPC from the `AccessibilityForwarder` (for use in developing reward
  functions, etc). This wrapper constructs a server which receives these
  messages and channels them into `task_extras`.

  The downside of forwarding this information through gRPC is that no messages
  will be sent if networking is turned off (e.g., if the AVD is in airplane
  mode). To mitigate this problem, the `AccessibilityForwarder` logs an error
  message if it fails to contact the server. This wrapper watches the
  `exception` entry of the wrapped env's task extras for such errors, and
  attempts (in another thread, to not block environment transitions) to
  reconnect the AVD to the network. If a fixed number of attempts fails to fix
  the problem, this wrapper raises an `EnableNetworkingError`.

  This wrapper is implemented to be robust to multiple upstream callers of
  `task_extras`, and to ensure they each receive the same extras at every
  timestep. Thus, the logic is the following:

  * New a11y events/forests are fetched during `reset` and `step`, *not* during
    `task_extras()` calls.
  * If no one has called `task_extras()` since the last `step` or `reset`, the
    extras are accumulated (so that no extras are missed because someone called
    `step()` twice without calling `task_extras()`).
  * If someone *has* called `task_extras()` since last step, the newly fetched
    extras replace the old extras.
  """

  # Component of the AccessibilityForwarder's flags broadcast receiver; shared
  # by every broadcast this wrapper sends to the app.
  _FLAGS_BROADCAST_RECEIVER = (
      'com.google.androidenv.accessibilityforwarder/'
      'com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver'
  )

  def __init__(
      self,
      env: env_interface.AndroidEnvInterface,
      disable_other_network_traffic: bool = False,
      install_a11y_forwarding: bool = False,
      start_a11y_service: bool = True,
      enable_a11y_tree_info: bool = False,
      add_latest_a11y_info_to_obs: bool = False,
      a11y_info_timeout: float | None = None,
      max_enable_networking_attempts: int = 10,
      latest_a11y_info_only: bool = False,
      grpc_server_ip: str = '10.0.2.2',
  ):
    """Initializes wrapper.

    Args:
      env: Environment to wrap.
      disable_other_network_traffic: When True, all network traffic, other than
        the connection to the servicer, is disabled. NOTE: This requires root
        access on the device (i.e. it uses the `su` command). An
        `AdbControllerError` exception will be raised if the underlying command
        fails.
      install_a11y_forwarding: If True, the wrapper handles the installation of
        all packages required for the servicer to collect a11y information.
      start_a11y_service: If True, starts the a11y forwarding services. NOTE:
        The packages must be installed beforehand, e.g., using the
        install_a11y_forwarding flag.
      enable_a11y_tree_info: When False, this wrapper collects only a11y events
        and not a11y tree.
      add_latest_a11y_info_to_obs: When True, the latest observed a11y forest is
        added to the observation under the key 'a11y_forest'.
      a11y_info_timeout: When larger than zero, after every `reset()` and
        `step()` the wrapper sleeps this many seconds and then takes an extra
        REPEAT step on the wrapped env, whose observation replaces the
        original one, so that the latest a11y info has time to arrive. This
        happens regardless of `add_latest_a11y_info_to_obs`. It requires an
        `action_type` entry in the wrapped env's action spec. A per-step
        `wait_time` entry in the action overrides this value.
      max_enable_networking_attempts: When the a11y gRPC service fails to
        provide a11y information, we attempt this many times to re-enable the
        networking. If all these attempts fail, fetching task_extras will raise
        an EnableNetworkingError.
      latest_a11y_info_only: When True, the a11y servicer is setup to save only
        the latest tree it has received from the Android app.
      grpc_server_ip: The IP address of the gRPC server which will be
        broadcasted to the AccessibilityForwarder app where it should log the
        a11y info. By default, this is set to the IP address of the AVD's host
        machine which is 10.0.2.2: See
        https://developer.android.com/studio/run/emulator-networking#networkaddresses.

    Raises:
      ValueError: If installing the APK or enabling tree logs fails, or if
        `a11y_info_timeout` > 0 but the wrapped env's action spec has no
        `action_type` entry.
      RuntimeError: If the accessibility service could not be started.
    """
    self._env = env
    self._grpc_server_ip = grpc_server_ip
    if install_a11y_forwarding:
      self._install_a11y_forwarding_apk()
      time.sleep(10.0)
    if start_a11y_service:
      self._start_a11y_services()
      time.sleep(3.0)
    if enable_a11y_tree_info:
      self._enable_a11y_tree_logs()
    self._relaunch_count = 0

    # gRPC server receiving forests/events from the AccessibilityForwarder.
    self._server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    self._servicer = a11y_servicer.A11yServicer(
        latest_forest_only=latest_a11y_info_only
    )
    a11y_pb2_grpc.add_A11yServiceServicer_to_server(
        self._servicer, self._server
    )
    server_credentials = grpc.local_server_credentials()
    self._port = portpicker.pick_unused_port()
    logging.info('Using port %s', self._port)
    uri_address = f'[::]:{self._port}'
    self._server.add_secure_port(uri_address, server_credentials)
    logging.info('Starting server')
    self._server.start()
    logging.info('Server now running.')

    self._max_enable_networking_attempts = max_enable_networking_attempts
    self._reset_enable_networking_attempts()
    self._disable_other_network_traffic = disable_other_network_traffic
    # See the class docstring for the accumulate/replace semantics.
    self._should_accumulate = False
    # None until the first reset(); task_extras() relies on this.
    self._accumulated_extras = None
    self._add_latest_a11y_info_to_obs = add_latest_a11y_info_to_obs
    self._a11y_info_timeout = a11y_info_timeout
    self._parent_action_spec = self._env.action_spec()
    if self._a11y_info_timeout is not None and self._a11y_info_timeout > 0.0:
      # The timeout path issues REPEAT actions, so `action_type` must exist.
      if 'action_type' not in self._parent_action_spec:
        raise ValueError(
            'action_type not in the parent action spec: '
            f'{self._parent_action_spec}. This is a strong requirement when '
            f'a11y_info_timeout = {a11y_info_timeout} > 0'
        )

  def _start_a11y_services(self) -> None:
    """Starts the accessibility forwarder services.

    This writes the secure `enabled_accessibility_services` setting with a
    single `put`, i.e. it replaces the whole value of that setting.

    Raises:
      RuntimeError: If accessibility service is not started.
    """
    start_service_request = adb_pb2.AdbRequest(
        settings=adb_pb2.AdbRequest.SettingsRequest(
            name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.SECURE,
            put=adb_pb2.AdbRequest.SettingsRequest.Put(
                key='enabled_accessibility_services',
                value=(
                    'com.google.androidenv.accessibilityforwarder/com.google.'
                    'androidenv.accessibilityforwarder.AccessibilityForwarder'
                ),
            ),
        )
    )
    start_service_response = self._env.execute_adb_call(start_service_request)
    if start_service_response.status != adb_pb2.AdbResponse.Status.OK:
      raise RuntimeError(
          'Could not start accessibility forwarder '
          'service: '
          f'{start_service_response}.'
      )

  def _install_a11y_forwarding_apk(self) -> None:
    """Downloads and installs the AccessibilityForwarder APK on the device.

    Raises:
      ValueError: If the installation fails.
    """
    a11y_fwd_apk = _get_accessibility_forwarder_apk()
    # Install and setup the Accessibility Forwarder.
    install_request = adb_pb2.AdbRequest(
        install_apk=adb_pb2.AdbRequest.InstallApk(
            blob=adb_pb2.AdbRequest.InstallApk.Blob(contents=a11y_fwd_apk),
        )
    )
    install_response = self._env.execute_adb_call(install_request)
    if install_response.status != adb_pb2.AdbResponse.Status.OK:
      raise ValueError(
          f'Could not install accessibility_forwarder.apk: {install_response}.'
      )

  def _enable_a11y_tree_logs(self) -> None:
    """Broadcasts the flag asking the forwarder to also send a11y trees.

    Raises:
      ValueError: If the broadcast fails.
    """
    enable_tree_logs_request = adb_pb2.AdbRequest(
        send_broadcast=adb_pb2.AdbRequest.SendBroadcast(
            action=(
                'accessibility_forwarder.intent.action.'
                'ENABLE_ACCESSIBILITY_TREE_LOGS'
            ),
            component=self._FLAGS_BROADCAST_RECEIVER,
        )
    )
    enable_tree_logs_response = self._env.execute_adb_call(
        enable_tree_logs_request
    )
    if enable_tree_logs_response.status != adb_pb2.AdbResponse.Status.OK:
      raise ValueError(
          'Could not enable accessibility tree logging: '
          f'{enable_tree_logs_response}.'
      )

  def _reset_enable_networking_attempts(self) -> None:
    """Restores the full networking-attempt budget and forgets past failures."""
    self._enable_networking_attempts_left = self._max_enable_networking_attempts
    self._enabling_networking_future = None
    self._a11y_exception = None

  def get_port(self) -> int:
    """Returns the local port the gRPC server listens on."""
    return self._port

  def close(self) -> None:
    """Stops the gRPC server and closes the wrapped env."""
    self._server.stop(None)
    logging.info('gRPC server stopped')
    self._env.close()

  def attempt_enable_networking(self) -> None:
    """Attempts to turn on networking within the Android device.

    Attempt to turn on the networking in the Android device, by:
    - turning off airplane mode;
    - turning on the wifi connection.

    `_fetch_task_extras` runs this in a background thread.
    """
    self.execute_adb_call(
        adb_pb2.AdbRequest(
            settings=adb_pb2.AdbRequest.SettingsRequest(
                name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
                put=adb_pb2.AdbRequest.SettingsRequest.Put(
                    key='airplane_mode_on', value='0'
                ),
            )
        )
    )
    time.sleep(1.0)
    self.execute_adb_call(
        adb_pb2.AdbRequest(
            generic=adb_pb2.AdbRequest.GenericRequest(
                args=[
                    'shell',
                    'svc',
                    'wifi',
                    'enable',
                ]
            )
        )
    )
    time.sleep(1.0)

  def _configure_grpc(self) -> None:
    """Configure networking and set the gRPC ip and port on AVD or device."""
    if self._disable_other_network_traffic:
      # Allow outgoing TCP to the gRPC server first...
      self.execute_adb_call(
          adb_pb2.AdbRequest(
              generic=adb_pb2.AdbRequest.GenericRequest(
                  args=[
                      'shell',
                      'su',
                      '0',
                      'iptables',
                      '-A',
                      'OUTPUT',
                      '-p',
                      'tcp',
                      '-d',
                      self._grpc_server_ip,
                      '--dport',
                      str(self._port),
                      '-j',
                      'ACCEPT',
                  ]
              )
          )
      )
      time.sleep(3.0)
      # ...then drop every other outgoing packet (rules are appended in order).
      self.execute_adb_call(
          adb_pb2.AdbRequest(
              generic=adb_pb2.AdbRequest.GenericRequest(
                  args=[
                      'shell',
                      'su',
                      '0',
                      'iptables',
                      '-A',
                      'OUTPUT',
                      '-j',
                      'DROP',
                  ]
              )
          )
      )
      time.sleep(3.0)
    # Presumably keeps any configured proxy from intercepting the gRPC
    # connection; TODO confirm.
    self.execute_adb_call(
        adb_pb2.AdbRequest(
            settings=adb_pb2.AdbRequest.SettingsRequest(
                name_space=adb_pb2.AdbRequest.SettingsRequest.Namespace.GLOBAL,
                put=adb_pb2.AdbRequest.SettingsRequest.Put(
                    key='no_proxy', value=f'{self._grpc_server_ip}:{self._port}'
                ),
            )
        )
    )
    self.attempt_enable_networking()
    # The port/host intent extras are embedded in the action string.
    # NOTE(review): this relies on the broadcast action being forwarded
    # unquoted to the shell command; confirm against adb_call_parser.
    self.execute_adb_call(
        adb_pb2.AdbRequest(
            send_broadcast=adb_pb2.AdbRequest.SendBroadcast(
                action=(
                    'accessibility_forwarder.intent.action.SET_GRPC --ei'
                    f' "port" {self._port} --es "host" {self._grpc_server_ip}'
                ),
                component=self._FLAGS_BROADCAST_RECEIVER,
            )
        )
    )

  def _accumulate_and_return_a11y_info(
      self, timer: float | None = None, get_env_observation: bool = True
  ) -> dict[str, Any]:
    """Accumulates and returns the latest a11y tree info and observation.

    Args:
      timer: If larger than 0, the system will wait this long for a11y info to
        accumulate before it returns a value.
      get_env_observation: If True, a REPEAT action is sent to the wrapped env
        and the resulting observation is included in the returned dict. If
        False, only the a11y forest is returned.

    Returns:
      A new dict with the latest a11y forest (or None if none was received)
      under key 'a11y_forest'. All other fields hold the observation, if
      requested.
    """
    timer = timer or 0.0
    if timer > 0.0:
      time.sleep(timer)
    if get_env_observation:
      # NOTE(review): only the observation of this extra REPEAT step is used;
      # its reward and step type are dropped by the callers. Confirm this is
      # acceptable for the wrapped env.
      new_ts = self._env.step({
          'action_type': np.array(
              android_action_type_lib.ActionType.REPEAT,
              dtype=self._parent_action_spec['action_type'].dtype,
          ),
      })
      # Copy so adding 'a11y_forest' does not mutate the wrapped env's dict.
      observation = dict(new_ts.observation)
    else:
      observation = {}
    extras = self.accumulate_new_extras()
    forests = a11y_forests.extract_forests_from_task_extras(extras)
    observation['a11y_forest'] = forests[-1] if forests else None
    return observation

  def _fetch_task_extras_and_update_observation(
      self, observation: dict[str, Any], timeout: float = 0.0
  ) -> dict[str, Any]:
    """Fetches new a11y info and merges it into an observation.

    Args:
      observation: The observation returned by the wrapped env.
      timeout: If larger than 0, `observation` is discarded and replaced by the
        observation of an extra REPEAT step taken after sleeping `timeout`
        seconds.

    Returns:
      The observation; it includes the latest a11y forest under 'a11y_forest'
      when `add_latest_a11y_info_to_obs` is set.
    """
    if timeout > 0.0:
      observation = self._accumulate_and_return_a11y_info(
          timeout, get_env_observation=True
      )
      if not self._add_latest_a11y_info_to_obs:
        observation.pop('a11y_forest')
    else:
      new_obs = self._accumulate_and_return_a11y_info(get_env_observation=False)
      if self._add_latest_a11y_info_to_obs:
        # Merge into a new dict rather than mutating the wrapped env's one.
        observation = {**observation, **new_obs}
    return observation

  def reset(self) -> dm_env.TimeStep:
    """Resets the wrapped env and returns its first timestep with a11y info.

    gRPC is (re)configured on the device whenever the wrapped env reports a
    higher `relaunch_count` than previously seen.
    """
    self._reset_enable_networking_attempts()
    # Discard a11y info buffered by the servicer while the env resets.
    self._servicer.pause_and_clear()
    timestep = self._env.reset()
    self._servicer.resume()
    if self._env.stats()['relaunch_count'] > self._relaunch_count:
      self._configure_grpc()
      self._relaunch_count = self._env.stats()['relaunch_count']
    self._accumulated_extras = {}
    timeout = self._a11y_info_timeout or 0.0
    new_observation = self._fetch_task_extras_and_update_observation(
        timestep.observation, timeout
    )
    return timestep._replace(observation=new_observation)

  def step(self, action: Any) -> dm_env.TimeStep:
    """Steps the wrapped env and attaches freshly fetched a11y info.

    Args:
      action: The action dict. An optional `wait_time` entry (in seconds)
        overrides `a11y_info_timeout` for this step; it is removed before the
        action is forwarded to the wrapped env. The caller's dict is not
        modified.

    Returns:
      The wrapped env's timestep with an updated observation.
    """
    # Work on a shallow copy so the caller's action keeps its `wait_time`.
    action = dict(action)
    timeout = float(action.pop('wait_time', self._a11y_info_timeout or 0.0))
    timestep = self._env.step(action)
    new_observation = self._fetch_task_extras_and_update_observation(
        timestep.observation, timeout=timeout
    )
    return timestep._replace(observation=new_observation)

  def accumulate_new_extras(self) -> dict[str, Any]:
    """Fetches new task extras and merges them into the accumulated extras.

    If `task_extras()` was called since the last accumulation, the accumulated
    extras are replaced by the new ones. Otherwise each new array is
    concatenated along axis 0 onto the existing one for the same key.

    Returns:
      The accumulated extras.
    """
    new_extras = self._fetch_task_extras()
    if self._should_accumulate:
      for key, value in new_extras.items():
        if key in self._accumulated_extras:
          self._accumulated_extras[key] = np.concatenate(
              (self._accumulated_extras[key], value), axis=0
          )
        else:
          self._accumulated_extras[key] = value
    else:
      self._accumulated_extras = new_extras
      self._should_accumulate = True
    return self._accumulated_extras

  def _fetch_task_extras(self) -> dict[str, Any]:
    """Fetches task_extras from the wrapped env and the a11y servicer.

    NOTE: If you want to access the latest a11y information, please use
    `accumulate_new_extras()` instead. This function has the side effect
    of clearing the content from the servicer, hence all the a11y info returned
    here won't be accumulated.

    Returns:
      A dict with the corresponding task_extras.

    Raises:
      EnableNetworkingError: after a fixed number of attempts to revive the a11y
        services by re-enabling the network connection.
    """
    base_extras = self._env.task_extras(latest_only=False).copy()

    # If the previous networking attempt is done, consume it.
    future = self._enabling_networking_future
    if future is not None and future.done():
      self._enabling_networking_future = None
      self._enable_networking_attempts_left -= 1
      logging.info('Finished enabling networking.')
      error = future.exception()
      if error is not None:
        # Surface failures of the background attempt instead of dropping them.
        logging.warning('Attempt to enable networking failed: %r', error)

    if (
        self._enabling_networking_future is None
        and 'exception' in base_extras
        and base_extras['exception'].shape[0]
    ):
      self._a11y_exception = base_extras['exception']
      logging.warning(
          'AccessibilityForwarder logged exceptions: %s', self._a11y_exception
      )
      if self._enable_networking_attempts_left > 0:
        logging.warning(
            'Attempting to enable networking. %s attempts left.',
            self._enable_networking_attempts_left - 1,
        )
        # Run in the background so env transitions are not blocked.
        executor = futures.ThreadPoolExecutor(max_workers=1)
        self._enabling_networking_future = executor.submit(
            self.attempt_enable_networking
        )
        # Release the executor without blocking: the submitted attempt still
        # runs to completion, after which the worker thread exits.
        executor.shutdown(wait=False)
      else:
        raise EnableNetworkingError(
            'A11y service failed multiple times with'
            f' exception.{self._a11y_exception}.'
        )

    # Any received a11y data proves the connection works again.
    forests = self._servicer.gather_forests()
    if forests:
      base_extras.update(a11y_forests.package_forests_to_task_extras(forests))
      self._reset_enable_networking_attempts()
    events = self._servicer.gather_events()
    if events:
      base_extras.update(a11y_events.package_events_to_task_extras(events))
      self._reset_enable_networking_attempts()
    return base_extras

  def task_extras(self, latest_only: bool = False) -> dict[str, Any]:
    """Returns the accumulated task extras, including a11y events/forests.

    Calling this marks the current extras as consumed. The next fetch in
    `step()`/`reset()` replaces them instead of appending to them. Repeated
    calls without an intervening fetch return the same extras.

    Args:
      latest_only: If True, only the latest a11y event and forest are kept.

    Returns:
      A shallow copy of the accumulated extras.

    Raises:
      RuntimeError: If `reset()` has not been called yet.
    """
    if self._accumulated_extras is None:
      raise RuntimeError('You must call .reset() before calling .task_extras()')
    self._should_accumulate = False
    extras = self._accumulated_extras.copy()
    if latest_only:
      a11y_events.keep_latest_event_only(extras)
      a11y_forests.keep_latest_forest_only(extras)
    return extras
================================================
FILE: android_env/wrappers/a11y_grpc_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for a11y_grpc_wrapper."""
import time
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from android_env import env_interface
from android_env.proto import adb_pb2
from android_env.proto.a11y import a11y_pb2
from android_env.proto.a11y import a11y_pb2_grpc
from android_env.proto.a11y import android_accessibility_forest_pb2
from android_env.wrappers import a11y_grpc_wrapper
import dm_env
import grpc
import numpy as np
def empty_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Builds a forest that contains no windows at all."""
  forest_cls = android_accessibility_forest_pb2.AndroidAccessibilityForest
  return forest_cls()
def one_empty_window_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Builds a forest holding a single window that has no nodes."""
  result = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  result.windows.add()
  return result
def one_window_one_node_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Builds a forest with one window holding one clickable 'foo' node."""
  result = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  result.windows.add().tree.nodes.add(
      class_name='foo', is_clickable=True, hint_text='Foo hint'
  )
  return result
def one_window_two_nodes_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Builds a forest with one window holding two 'bar' nodes."""
  result = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  tree_nodes = result.windows.add().tree.nodes
  for clickable, hint in ((True, 'Bar hint'), (False, 'Bar hint 2')):
    tree_nodes.add(class_name='bar', is_clickable=clickable, hint_text=hint)
  return result
def three_windows_forest() -> (
    android_accessibility_forest_pb2.AndroidAccessibilityForest
):
  """Builds a forest with an empty, a one-node and a two-node window."""
  result = android_accessibility_forest_pb2.AndroidAccessibilityForest()
  # One entry per window: a tuple of (class_name, is_clickable) node specs.
  windows_spec = (
      (),
      (('foo', True),),
      (('baz', True), ('foobar', False)),
  )
  for node_specs in windows_spec:
    new_window = result.windows.add()
    for name, clickable in node_specs:
      new_window.tree.nodes.add(
          class_name=name, is_clickable=clickable, hint_text='hint'
      )
  return result
def empty_dict() -> dict[str, str]:
  """Returns a fresh dict with no entries."""
  return dict()
def single_item_dict() -> dict[str, str]:
  """Returns a fresh dict mapping 'foo' to 'bar'."""
  return dict(foo='bar')
def several_long_items_dict() -> dict[str, str]:
  """Returns a dict with two keys whose values are long repeated strings."""
  repeats = 100
  return dict(
      first_key='Lorem ipsum ' * repeats,
      second_key='the beginning is the end is' * repeats,
  )
def single_item_dict_with_special_chars() -> dict[str, str]:
return {'foo': 'bar\r\t\nbaz'}
def _ok_response():
  """Returns an `AdbResponse` whose status is OK."""
  ok_status = adb_pb2.AdbResponse.Status.OK
  return adb_pb2.AdbResponse(status=ok_status)
class A11yGrpcWrapperTest(parameterized.TestCase):
def test_server(self):
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.task_extras.return_value = {}
base_env.stats.return_value = {'relaunch_count': 0}
base_env.execute_adb_call.return_value = _ok_response()
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
wrapped_env.reset()
channel_creds = grpc.local_channel_credentials()
with grpc.secure_channel(
f'[::]:{wrapped_env.get_port()}', channel_creds
) as channel:
grpc.channel_ready_future(channel).result()
stub = a11y_pb2_grpc.A11yServiceStub(channel)
stub.SendForest(one_window_one_node_forest())
stub.SendForest(one_window_two_nodes_forest())
wrapped_env.step({})
extras = wrapped_env.task_extras(latest_only=False)
self.assertIn('accessibility_tree', extras)
self.assertEqual(extras['accessibility_tree'].shape[0], 2)
# tests of fetch_task_extras:
# exception occurs (ensure attempt to enable networking) and recovers
# exception occurs and enable networking doesn't help
# exception occurs twice but with a forest sent between
@parameterized.named_parameters(
('no_events_or_forests', [], []),
(
'no_events',
[],
[one_window_one_node_forest(), one_window_two_nodes_forest()],
),
('no_forests', [empty_dict(), single_item_dict()], []),
(
'events_and_forests',
[empty_dict(), single_item_dict()],
[one_window_one_node_forest(), one_window_two_nodes_forest()],
),
)
@mock.patch.object(time, 'sleep', autospec=True)
@mock.patch.object(
a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_fetch_task_extras(
self,
received_events,
received_forests,
mock_server,
mock_add_servicer,
mock_sleep,
):
del mock_server, mock_add_servicer, mock_sleep
mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.task_extras.return_value = {
'foo': np.array(['bar', 'baz'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
}
base_env.stats.return_value = {'relaunch_count': 0}
base_env.execute_adb_call.return_value = _ok_response()
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
wrapped_env.reset()
for forest in received_forests:
wrapped_env._servicer.SendForest(forest, mock_context)
for event in received_events:
wrapped_env._servicer.SendEvent(
a11y_pb2.EventRequest(event=event), mock_context
)
with mock.patch.object(
wrapped_env, 'attempt_enable_networking'
) as mock_attempt_enable_networking:
extras = wrapped_env._fetch_task_extras()
mock_attempt_enable_networking.assert_not_called()
self.assertIn('foo', extras)
np.testing.assert_array_equal(extras['foo'], ['bar', 'baz'])
self.assertIn('some_key', extras)
np.testing.assert_array_equal(extras['some_key'], ['some_value'])
if received_events:
self.assertIn('full_event', extras)
self.assertLen(extras['full_event'], len(received_events))
for i, event in enumerate(received_events):
event = a11y_pb2.EventRequest(event=event)
np.testing.assert_array_equal(extras['full_event'][i], event)
else:
self.assertNotIn('full_event', extras)
if received_forests:
self.assertIn('accessibility_tree', extras)
self.assertLen(extras['accessibility_tree'], len(received_forests))
for i, forest in enumerate(received_forests):
np.testing.assert_array_equal(extras['accessibility_tree'][i], forest)
else:
self.assertNotIn('accessibility_tree', extras)
@mock.patch.object(time, 'sleep', autospec=True)
@mock.patch.object(
a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_fetch_task_extras_enable_networking(
self,
mock_server,
mock_add_servicer,
mock_sleep,
):
del mock_server, mock_add_servicer, mock_sleep
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.task_extras.return_value = {
'foo': np.array(['bar'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
'exception': np.array(['fake exception'], dtype='U'),
}
base_env.stats.return_value = {'relaunch_count': 0}
base_env.execute_adb_call.return_value = _ok_response()
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
with mock.patch.object(
wrapped_env, 'attempt_enable_networking'
) as mock_attempt_enable_networking:
extras = wrapped_env._fetch_task_extras()
self.assertNotIn('accessibility_tree', extras)
self.assertNotIn('full_event', extras)
future = wrapped_env._enabling_networking_future
if future is not None:
future.result()
mock_attempt_enable_networking.assert_called_once()
@mock.patch.object(time, 'sleep', autospec=True)
@mock.patch.object(
a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_fetch_task_extras_enable_networking_twice(
self,
mock_server,
mock_add_servicer,
mock_sleep,
):
del mock_server, mock_add_servicer, mock_sleep
mock_context = mock.create_autospec(grpc.ServicerContext, instance=True)
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.task_extras.return_value = {
'foo': np.array(['bar'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
}
base_env.stats.return_value = {'relaunch_count': 0}
base_env.execute_adb_call.return_value = _ok_response()
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
wrapped_env.reset()
base_env.task_extras.return_value = {
'foo': np.array(['bar'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
'exception': np.array(['fake exception'], dtype='U'),
}
with mock.patch.object(
wrapped_env, 'attempt_enable_networking'
) as mock_attempt_enable_networking:
extras = wrapped_env._fetch_task_extras()
self.assertNotIn('accessibility_tree', extras)
self.assertNotIn('full_event', extras)
future = wrapped_env._enabling_networking_future
if future is not None:
future.result()
mock_attempt_enable_networking.assert_called_once()
# Fixed networking; send a forest so the wrapper knows it worked.
wrapped_env._servicer.SendForest(one_window_one_node_forest(), mock_context)
base_env.task_extras.return_value = {
'foo': np.array(['bar'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
}
extras = wrapped_env._fetch_task_extras()
self.assertIn('accessibility_tree', extras)
self.assertEqual(extras['accessibility_tree'].shape[0], 1)
self.assertEqual(
extras['accessibility_tree'][0], one_window_one_node_forest()
)
base_env.task_extras.return_value = {
'foo': np.array(['bar'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
'exception': np.array(['fake exception'], dtype='U'),
}
with mock.patch.object(
wrapped_env, 'attempt_enable_networking'
) as mock_attempt_enable_networking:
extras = wrapped_env._fetch_task_extras()
self.assertNotIn('accessibility_tree', extras)
self.assertNotIn('full_event', extras)
future = wrapped_env._enabling_networking_future
if future is not None:
future.result()
mock_attempt_enable_networking.assert_called_once()
@mock.patch.object(time, 'sleep', autospec=True)
@mock.patch.object(
a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_task_extras_raises_a11y_info_exception(
self, mock_sleep, mock_add_servicer, mock_server
):
del mock_server, mock_add_servicer, mock_sleep
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.task_extras.return_value = {
'foo': np.array(['bar'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
}
base_env.stats.return_value = {'relaunch_count': 0}
base_env.execute_adb_call.return_value = _ok_response()
base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
base_env.step.return_value = dm_env.transition(
observation={'dummy': 42}, reward=0.0
)
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
base_env,
add_latest_a11y_info_to_obs=True,
max_enable_networking_attempts=1,
)
wrapped_env.reset()
base_env.task_extras.return_value = {
'foo': np.array(['bar'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
'exception': np.array(['fake exception'], dtype='U'),
}
with mock.patch.object(
wrapped_env, 'attempt_enable_networking'
) as mock_attempt_enable_networking:
extras = wrapped_env._fetch_task_extras()
self.assertNotIn('accessibility_tree', extras)
self.assertNotIn('full_event', extras)
# Wait for the the attempt to finish.
future = wrapped_env._enabling_networking_future
if future is not None:
future.result()
mock_attempt_enable_networking.assert_called_once()
# The _fetch_task_extras() call inside the next step will force a restart
self.assertRaises(
a11y_grpc_wrapper.EnableNetworkingError, wrapped_env.step, {}
)
@mock.patch.object(
a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_configure_grpc(
self,
mock_server,
mock_add_servicer,
):
del mock_server, mock_add_servicer
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.task_extras.return_value = {
'foo': np.array(['bar'], dtype='U'),
'some_key': np.array(['some_value'], dtype='U'),
}
base_env.stats.return_value = {'relaunch_count': 1}
base_env.execute_adb_call.return_value = _ok_response()
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
with mock.patch.object(
wrapped_env, '_configure_grpc'
) as mock_configure_grpc:
wrapped_env.reset()
mock_configure_grpc.assert_called_once()
@mock.patch.object(
a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_task_extras_raises_before_reset(
self, unused_mock_server, unused_mock_add_servicer
):
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.stats.return_value = {'relaunch_count': 0}
base_env.execute_adb_call.return_value = _ok_response()
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(base_env)
with self.assertRaisesRegex(
RuntimeError,
r'You must call \.reset\(\) before calling \.task_extras\(\)',
):
wrapped_env.task_extras(latest_only=False)
@mock.patch.object(
a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_extras_accumulate_between_steps(
self, mock_server, mock_add_servicer
):
del mock_server, mock_add_servicer
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.stats.return_value = {'relaunch_count': 0}
base_env.execute_adb_call.return_value = _ok_response()
base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
base_env.step.return_value = dm_env.transition(
observation={'dummy': 42}, reward=0.0
)
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
base_env, add_latest_a11y_info_to_obs=True
)
with mock.patch.object(wrapped_env, '_fetch_task_extras'):
wrapped_env._fetch_task_extras.return_value = {
'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
}
timestep = wrapped_env.reset()
self.assertIn('a11y_forest', timestep.observation)
self.assertEqual(timestep.observation['a11y_forest'], empty_forest())
wrapped_env._fetch_task_extras.return_value = {
'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
'accessibility_tree': np.array(
one_window_two_nodes_forest(), ndmin=1, dtype=object
),
}
timestep = wrapped_env.step({})
self.assertIn('a11y_forest', timestep.observation)
self.assertEqual(
timestep.observation['a11y_forest'], one_window_two_nodes_forest()
)
timestep = wrapped_env.step({})
self.assertIn('a11y_forest', timestep.observation)
self.assertEqual(
timestep.observation['a11y_forest'], one_window_two_nodes_forest()
)
wrapped_env._fetch_task_extras.return_value = {
'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
}
timestep = wrapped_env.step({})
self.assertIn('a11y_forest', timestep.observation)
self.assertEqual(
timestep.observation['a11y_forest'], one_window_two_nodes_forest()
)
expected_task_extras = {
'full_event': np.array(
[
single_item_dict(),
empty_dict(),
empty_dict(),
single_item_dict(),
],
dtype=object,
),
'accessibility_tree': np.array(
[
empty_forest(),
one_window_two_nodes_forest(),
one_window_two_nodes_forest(),
],
dtype=object,
),
}
expected_task_extras_latest = {
'full_event': np.array([single_item_dict()], dtype=object),
'accessibility_tree': np.array(
[one_window_two_nodes_forest()], dtype=object
),
}
task_extras = wrapped_env.task_extras(latest_only=False)
np.testing.assert_equal(
task_extras['full_event'], expected_task_extras['full_event']
)
np.testing.assert_equal(
task_extras['accessibility_tree'],
expected_task_extras['accessibility_tree'],
)
task_extras = wrapped_env.task_extras(latest_only=True)
np.testing.assert_equal(
task_extras['full_event'], expected_task_extras_latest['full_event']
)
np.testing.assert_equal(
task_extras['accessibility_tree'],
expected_task_extras_latest['accessibility_tree'],
)
@mock.patch.object(
a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_a11y_info_disabled(
self,
unused_mock_server,
unused_mock_add_servicer,
):
base_env = mock.create_autospec(
env_interface.AndroidEnvInterface, instance=True
)
base_env.action_spec.return_value = {
'action_type': dm_env.specs.Array(shape=(), dtype=np.int32)
}
base_env.stats.return_value = {'relaunch_count': 0}
base_env.execute_adb_call.return_value = _ok_response()
base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
base_env.step.return_value = dm_env.transition(
observation={'dummy': 42}, reward=0.0
)
wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
base_env, add_latest_a11y_info_to_obs=False, a11y_info_timeout=1.0
)
with mock.patch.object(wrapped_env, '_fetch_task_extras'):
wrapped_env._fetch_task_extras.return_value = {
'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
}
timestep = wrapped_env.reset()
self.assertNotIn('a11y_forest', timestep.observation)
timestep = wrapped_env.step({})
self.assertNotIn('a11y_forest', timestep.observation)
@mock.patch.object(
    a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_a11y_info_with_timer_info_present(
    self,
    unused_mock_server,
    unused_mock_add_servicer,
):
  """On reset, the fetched accessibility forest is exposed as `a11y_forest`."""
  env = mock.create_autospec(env_interface.AndroidEnvInterface, instance=True)
  env.action_spec.return_value = {
      'action_type': dm_env.specs.Array(shape=(), dtype=np.int32)
  }
  env.stats.return_value = {'relaunch_count': 0}
  env.execute_adb_call.return_value = _ok_response()
  env.reset.return_value = dm_env.restart(observation={'dummy': 42})
  env.step.return_value = dm_env.transition(
      observation={'dummy': 42}, reward=0.0
  )

  wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
      env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0
  )
  with mock.patch.object(wrapped_env, '_fetch_task_extras') as mock_fetch:
    # Exactly one result is queued; a second fetch would raise StopIteration.
    mock_fetch.side_effect = [{
        'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
    }]
    observation = wrapped_env.reset().observation
    self.assertIn('a11y_forest', observation)
    self.assertEqual(observation['a11y_forest'], empty_forest())
@mock.patch.object(
    a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
@mock.patch.object(time, 'sleep', autospec=True)
def test_a11y_info_with_timer_task_extra_returned(
    self, unused_mock_sleep, unused_mock_server, unused_mock_add_servicer
):
  """A forest fetched during reset is exposed as `a11y_forest`.

  Stacked `mock.patch` decorators inject their mocks bottom-up, so the
  innermost patch (`time.sleep`) is the first mock argument. The parameter
  names follow that order.
  """
  base_env = mock.create_autospec(
      env_interface.AndroidEnvInterface, instance=True
  )
  base_env.action_spec.return_value = {
      'action_type': dm_env.specs.Array(shape=(), dtype=np.int32)
  }
  base_env.stats.return_value = {'relaunch_count': 0}
  base_env.execute_adb_call.return_value = _ok_response()
  base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
  base_env.step.return_value = dm_env.transition(
      observation={'dummy': 42}, reward=0.0
  )
  wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
      base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=1.0
  )
  with mock.patch.object(wrapped_env, '_fetch_task_extras') as mock_fetch:
    mock_fetch.side_effect = [
        {
            'accessibility_tree': np.array(
                empty_forest(), ndmin=1, dtype=object
            ),
        },
    ]
    timestep = wrapped_env.reset()
    self.assertIn('a11y_forest', timestep.observation)
    self.assertEqual(timestep.observation['a11y_forest'], empty_forest())
@mock.patch.object(
    a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
@mock.patch.object(time, 'sleep', autospec=True)
def test_a11y_info_with_timer_from_action(
    self, mock_sleep, unused_mock_server, unused_mock_add_servicer
):
  """A `wait_time` in the action triggers a sleep before a11y info is read.

  Stacked `mock.patch` decorators inject their mocks bottom-up, so the
  innermost patch (`time.sleep`) is the FIRST mock argument. Previously the
  `time.sleep` mock was bound to a parameter named `unused_mock_server`.
  `mock_sleep` actually received the servicer-registration mock, so the
  sleep assertion was checking the wrong mock.
  """
  base_env = mock.create_autospec(
      env_interface.AndroidEnvInterface, instance=True
  )
  base_env.action_spec.return_value = {
      'action_type': dm_env.specs.Array(shape=(), dtype=np.int32)
  }
  base_env.stats.return_value = {'relaunch_count': 0}
  base_env.execute_adb_call.return_value = _ok_response()
  base_env.reset.return_value = dm_env.restart(observation={'dummy': 42})
  base_env.step.return_value = dm_env.transition(
      observation={'dummy': 42}, reward=0.0
  )
  wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
      base_env, add_latest_a11y_info_to_obs=True, a11y_info_timeout=0.0
  )
  # Only count sleeps caused by `step`, not any done during construction.
  mock_sleep.reset_mock()
  with mock.patch.object(wrapped_env, '_fetch_task_extras') as mock_fetch:
    mock_fetch.side_effect = [
        {
            'accessibility_tree': np.array(
                empty_forest(), ndmin=1, dtype=object
            ),
        },
    ]
    timestep = wrapped_env.step(action={'wait_time': 1.0})
    self.assertIn('a11y_forest', timestep.observation)
    mock_sleep.assert_called_once()
    self.assertEqual(timestep.observation['a11y_forest'], empty_forest())
@mock.patch.object(
    a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_task_extras_same_between_calls(self, mock_server, mock_add_servicer):
  """Querying task extras twice without stepping returns the same values."""
  del mock_server, mock_add_servicer
  env = mock.create_autospec(env_interface.AndroidEnvInterface, instance=True)
  env.stats.return_value = {'relaunch_count': 0}
  env.execute_adb_call.return_value = _ok_response()
  wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(env)

  def check_extras(expected):
    extras = wrapped_env.task_extras(latest_only=False)
    for key in ('full_event', 'accessibility_tree'):
      np.testing.assert_equal(extras[key], expected[key])

  after_reset = {
      'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
      'accessibility_tree': np.array(empty_forest(), ndmin=1, dtype=object),
  }
  with mock.patch.object(wrapped_env, '_fetch_task_extras') as mock_fetch:
    mock_fetch.return_value = after_reset
    wrapped_env.reset()
    check_extras(after_reset)
    check_extras(after_reset)

  after_step = {
      'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
      'accessibility_tree': np.array(
          one_window_two_nodes_forest(), ndmin=1, dtype=object
      ),
  }
  with mock.patch.object(wrapped_env, '_fetch_task_extras') as mock_fetch:
    mock_fetch.return_value = after_step
    wrapped_env.step({})
    check_extras(after_step)
    check_extras(after_step)
@mock.patch.object(
    a11y_pb2_grpc, 'add_A11yServiceServicer_to_server', autospec=True
)
@mock.patch.object(grpc, 'server', autospec=True)
def test_task_extras_clear_if_called_between_step(
    self, mock_server, mock_add_servicer
):
  """After each reset/step, task extras reflect only the latest fetch."""
  del mock_server, mock_add_servicer
  env = mock.create_autospec(env_interface.AndroidEnvInterface, instance=True)
  env.stats.return_value = {'relaunch_count': 0}
  env.execute_adb_call.return_value = _ok_response()
  wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(env)

  def advance_and_check(fetch_mock, advance, expected):
    # Queue `expected` as the next fetch result, advance the env, then read.
    fetch_mock.return_value = expected
    advance()
    extras = wrapped_env.task_extras(latest_only=False)
    for key in ('full_event', 'accessibility_tree'):
      np.testing.assert_equal(extras[key], expected[key])

  with mock.patch.object(wrapped_env, '_fetch_task_extras') as mock_fetch:
    advance_and_check(
        mock_fetch,
        wrapped_env.reset,
        {
            'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
            'accessibility_tree': np.array(
                empty_forest(), ndmin=1, dtype=object
            ),
        },
    )
    advance_and_check(
        mock_fetch,
        lambda: wrapped_env.step({}),
        {
            'full_event': np.array(single_item_dict(), ndmin=1, dtype=object),
            'accessibility_tree': np.array(
                empty_forest(), ndmin=1, dtype=object
            ),
        },
    )
    advance_and_check(
        mock_fetch,
        lambda: wrapped_env.step({}),
        {
            'full_event': np.array(empty_dict(), ndmin=1, dtype=object),
            'accessibility_tree': np.array(
                one_window_two_nodes_forest(), ndmin=1, dtype=object
            ),
        },
    )
@parameterized.named_parameters(
    ('none_true', False, False, False, 0),
    ('only_install', True, False, False, 1),
    ('only_start', False, True, False, 1),
    ('only_enable_a11y_tree', False, False, True, 1),
    ('install_and_start_no_a11y_tree', True, True, False, 2),
    ('install_and_a11y_tree', True, False, True, 2),
    ('start_and_a11y_tree', False, True, True, 2),
    ('all_true', True, True, True, 3),
)
@mock.patch.object(time, 'sleep', autospec=True)
def test_apk_install_and_start(
    self,
    install_a11y_forwarding: bool,
    start_a11y_service: bool,
    enable_a11y_tree_logs: bool,
    expected_adb_calls: int,
    unused_mock_sleep,
):
  """Each enabled setup step issues exactly one ADB call."""
  env = mock.create_autospec(env_interface.AndroidEnvInterface, instance=True)
  # One OK response per enabled step: install, start service, enable tree.
  enabled_steps = (
      install_a11y_forwarding,
      start_a11y_service,
      enable_a11y_tree_logs,
  )
  env.execute_adb_call.side_effect = [
      _ok_response() for enabled in enabled_steps if enabled
  ]
  _ = a11y_grpc_wrapper.A11yGrpcWrapper(
      env,
      install_a11y_forwarding=install_a11y_forwarding,
      start_a11y_service=start_a11y_service,
      enable_a11y_tree_info=enable_a11y_tree_logs,
  )
  self.assertEqual(env.execute_adb_call.call_count, expected_adb_calls)
@mock.patch.object(time, 'sleep', autospec=True)
def test_component_and_start(self, unused_mock_sleep):
  """The final setup broadcast targets the forwarder's flags receiver."""
  env = mock.create_autospec(env_interface.AndroidEnvInterface, instance=True)
  # Responses for: install, start service, enable tree request.
  env.execute_adb_call.side_effect = [_ok_response() for _ in range(3)]
  _ = a11y_grpc_wrapper.A11yGrpcWrapper(
      env,
      install_a11y_forwarding=True,
      start_a11y_service=True,
      enable_a11y_tree_info=True,
  )
  # `call_args` is an (args, kwargs) pair for the most recent call, and the
  # AdbRequest is `execute_adb_call`'s only positional argument.
  latest_request = env.execute_adb_call.call_args[0][0]
  self.assertEqual(
      latest_request.send_broadcast.component,
      'com.google.androidenv.accessibilityforwarder/com.google.androidenv.accessibilityforwarder.FlagsBroadcastReceiver',
  )
def test_broadcast_sent_default_grpc_server_ip(self):
  """Without an explicit IP, the SET_GRPC broadcast uses host 10.0.2.2."""
  env = mock.create_autospec(env_interface.AndroidEnvInterface, instance=True)
  env.execute_adb_call.return_value = _ok_response()
  env.stats.return_value = {'relaunch_count': 1}
  wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
      env=env,
      disable_other_network_traffic=False,
      install_a11y_forwarding=False,
      start_a11y_service=False,
      enable_a11y_tree_info=False,
  )
  wrapped_env.reset()

  # Inspect the most recent ADB request issued by the wrapper.
  broadcast = env.execute_adb_call.call_args[0][0].send_broadcast.action
  self.assertStartsWith(
      broadcast, 'accessibility_forwarder.intent.action.SET_GRPC'
  )
  self.assertIn('--es "host" 10.0.2.2', broadcast)
# Every case is a 1-tuple; 'localhost' was previously a bare string, which
# only worked because absl passes non-iterable-or-string params as one arg.
@parameterized.parameters(('127.0.0.1',), ('1.2.3.4',), ('localhost',))
def test_broadcast_sent_custom_grpc_server_ip(self, grpc_server_ip):
  """Tests that the broadcast sets the custom grpc server ip."""
  base_env = mock.create_autospec(
      env_interface.AndroidEnvInterface, instance=True
  )
  base_env.execute_adb_call.return_value = _ok_response()
  base_env.stats.return_value = {'relaunch_count': 1}
  wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
      env=base_env,
      disable_other_network_traffic=False,
      install_a11y_forwarding=False,
      start_a11y_service=False,
      enable_a11y_tree_info=False,
      grpc_server_ip=grpc_server_ip,
  )
  wrapped_env.reset()
  # The most recent ADB request issued by the wrapper.
  broadcast = base_env.execute_adb_call.call_args[0][0].send_broadcast.action
  self.assertStartsWith(
      broadcast,
      'accessibility_forwarder.intent.action.SET_GRPC',
  )
  self.assertIn(f'--es "host" {grpc_server_ip}', broadcast)
def test_broadcast_sent_port(self):
  """The SET_GRPC broadcast carries the port the wrapper reports."""
  env = mock.create_autospec(env_interface.AndroidEnvInterface, instance=True)
  env.execute_adb_call.return_value = _ok_response()
  env.stats.return_value = {'relaunch_count': 1}
  wrapped_env = a11y_grpc_wrapper.A11yGrpcWrapper(
      env=env,
      disable_other_network_traffic=False,
      install_a11y_forwarding=False,
      start_a11y_service=False,
      enable_a11y_tree_info=False,
  )
  wrapped_env.reset()

  broadcast = env.execute_adb_call.call_args[0][0].send_broadcast.action
  self.assertStartsWith(
      broadcast, 'accessibility_forwarder.intent.action.SET_GRPC'
  )
  expected_port_flag = f'--ei "port" {wrapped_env.get_port()}'
  self.assertIn(expected_port_flag, broadcast)
if __name__ == '__main__':
  # Run this module's tests when it is executed as a script.
  absltest.main()
================================================
FILE: android_env/wrappers/base_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base class for AndroidEnv wrappers."""
from typing import Any
from absl import logging
from android_env import env_interface
from android_env.proto import adb_pb2
from android_env.proto import state_pb2
import dm_env
from dm_env import specs
import numpy as np
class BaseWrapper(env_interface.AndroidEnvInterface):
  """AndroidEnv wrapper.

  Forwards every call to the wrapped environment `self._env`. Subclasses
  customize behavior by overriding the `_reset_state`, `_process_action`,
  `_process_timestep` and `_wrapper_stats` hooks. Attributes not found on the
  wrapper are looked up on the wrapped environment.
  """

  def __init__(self, env: env_interface.AndroidEnvInterface) -> None:
    self._env = env
    logging.info('Wrapping with %s', self.__class__.__name__)

  def reset(self) -> dm_env.TimeStep:
    """Clears wrapper state, resets the wrapped env and post-processes it."""
    self._reset_state()
    return self._process_timestep(self._env.reset())

  def step(self, action: Any) -> dm_env.TimeStep:
    """Converts `action`, steps the wrapped env and post-processes the result."""
    action = self._process_action(action)
    return self._process_timestep(self._env.step(action))

  def task_extras(self, latest_only: bool = True) -> dict[str, np.ndarray]:
    """Returns the wrapped env's task extras unchanged."""
    return self._env.task_extras(latest_only=latest_only)

  def _reset_state(self):
    """Hook run at the start of `reset()`; a no-op by default."""

  def _process_action(self, action: Any) -> Any:
    """Hook mapping this wrapper's action to the wrapped env's; identity."""
    return action

  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    """Hook applied to every timestep from the wrapped env; identity."""
    return timestep

  def observation_spec(self) -> dict[str, specs.Array]:
    return self._env.observation_spec()

  def action_spec(self) -> dict[str, specs.Array]:
    return self._env.action_spec()

  def reward_spec(self) -> specs.Array:
    return self._env.reward_spec()

  def discount_spec(self) -> specs.Array:
    return self._env.discount_spec()

  def _wrapper_stats(self) -> dict[str, Any]:
    """Add wrapper specific logging here."""
    return {}

  def stats(self) -> dict[str, Any]:
    """Returns the wrapped env's stats merged with this wrapper's stats.

    Entries from `_wrapper_stats()` override same-named entries from the
    wrapped env. A new dict is built so that the dict returned by the wrapped
    env (which may be its own internal state) is never mutated.
    """
    info = dict(self._env.stats())
    info.update(self._wrapper_stats())
    return info

  def load_state(
      self, request: state_pb2.LoadStateRequest
  ) -> state_pb2.LoadStateResponse:
    """Loads a state."""
    return self._env.load_state(request)

  def save_state(
      self, request: state_pb2.SaveStateRequest
  ) -> state_pb2.SaveStateResponse:
    """Saves a state.

    Args:
      request: A `SaveStateRequest` containing any parameters necessary to
        specify how/what state to save.

    Returns:
      A `SaveStateResponse` containing the status, error message (if
      applicable), and any other relevant information.
    """
    return self._env.save_state(request)

  def execute_adb_call(
      self, adb_call: adb_pb2.AdbRequest
  ) -> adb_pb2.AdbResponse:
    return self._env.execute_adb_call(adb_call)

  @property
  def raw_action(self) -> Any:
    return self._env.raw_action

  @property
  def raw_observation(self) -> Any:
    return self._env.raw_observation

  @property
  def raw_env(self) -> env_interface.AndroidEnvInterface:
    """Recursively unwrap until we reach the true 'raw' env."""
    wrapped = self._env
    if hasattr(wrapped, 'raw_env'):
      return wrapped.raw_env
    return wrapped

  def __getattr__(self, attr) -> Any:
    """Delegate attribute access to underlying environment.

    Python only calls this when normal attribute lookup on the wrapper fails.

    Raises:
      AttributeError: if neither the wrapper nor the wrapped env has `attr`.
    """
    # If `_env` itself is missing (e.g. `copy.copy` or unpickling create the
    # object without running `__init__`), evaluating `self._env` below would
    # re-enter this method forever and end in RecursionError.
    if attr == '_env':
      raise AttributeError(attr)
    return getattr(self._env, attr)

  def close(self) -> None:
    self._env.close()
================================================
FILE: android_env/wrappers/base_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.wrappers.base_wrapper."""
from unittest import mock
from absl import logging
from absl.testing import absltest
from android_env import env_interface
from android_env.proto import state_pb2
from android_env.wrappers import base_wrapper
class BaseWrapperTest(absltest.TestCase):

  @mock.patch.object(logging, 'info')
  def test_base_function_forwarding(self, mock_info):
    base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    wrapped_env = base_wrapper.BaseWrapper(base_env)
    mock_info.assert_called_with('Wrapping with %s', 'BaseWrapper')

    fake_ts = 'fake_ts'
    base_env.reset.return_value = fake_ts
    self.assertEqual(fake_ts, wrapped_env.reset())
    base_env.reset.assert_called_once()

    fake_action = 'fake_action'
    base_env.step.return_value = fake_ts
    self.assertEqual(fake_ts, wrapped_env.step(fake_action))
    base_env.step.assert_called_once_with(fake_action)

    fake_extras = 'fake_task_extras'
    base_env.task_extras.return_value = fake_extras
    self.assertEqual(fake_extras, wrapped_env.task_extras(latest_only=True))
    base_env.task_extras.assert_called_once_with(latest_only=True)

    fake_obs_spec = 'fake_obs_spec'
    base_env.observation_spec.return_value = fake_obs_spec
    self.assertEqual(fake_obs_spec, wrapped_env.observation_spec())
    base_env.observation_spec.assert_called_once()

    fake_action_spec = 'fake_action_spec'
    base_env.action_spec.return_value = fake_action_spec
    self.assertEqual(fake_action_spec, wrapped_env.action_spec())
    base_env.action_spec.assert_called_once()

    fake_reward_spec = 'fake_reward_spec'
    base_env.reward_spec.return_value = fake_reward_spec
    self.assertEqual(fake_reward_spec, wrapped_env.reward_spec())
    base_env.reward_spec.assert_called_once()

    fake_discount_spec = 'fake_discount_spec'
    base_env.discount_spec.return_value = fake_discount_spec
    self.assertEqual(fake_discount_spec, wrapped_env.discount_spec())
    base_env.discount_spec.assert_called_once()

    fake_adb_request = 'fake_adb_request'
    fake_adb_response = 'fake_adb_response'
    base_env.execute_adb_call.return_value = fake_adb_response
    self.assertEqual(
        fake_adb_response, wrapped_env.execute_adb_call(fake_adb_request))
    base_env.execute_adb_call.assert_called_once_with(fake_adb_request)

    fake_raw_action = 'fake_raw_action'
    type(base_env).raw_action = mock.PropertyMock(return_value=fake_raw_action)
    self.assertEqual(fake_raw_action, wrapped_env.raw_action)

    fake_raw_observation = 'fake_raw_observation'
    type(base_env).raw_observation = mock.PropertyMock(
        return_value=fake_raw_observation)
    self.assertEqual(fake_raw_observation, wrapped_env.raw_observation)

    load_request = state_pb2.LoadStateRequest(args={})
    expected_response = state_pb2.LoadStateResponse(
        status=state_pb2.LoadStateResponse.Status.OK
    )
    base_env.load_state.return_value = expected_response
    self.assertEqual(wrapped_env.load_state(load_request), expected_response)
    base_env.load_state.assert_called_once_with(load_request)

    save_request = state_pb2.SaveStateRequest(args={})
    expected_response = state_pb2.SaveStateResponse(
        status=state_pb2.SaveStateResponse.Status.OK
    )
    base_env.save_state.return_value = expected_response
    self.assertEqual(wrapped_env.save_state(save_request), expected_response)
    base_env.save_state.assert_called_once_with(save_request)

    wrapped_env.close()
    base_env.close.assert_called_once()

    # AndroidEnvInterface has no `some_random_function`. Looking it up through
    # the wrapper's `__getattr__` delegation must raise AttributeError. The
    # old check only exercised the autospec mock, never the wrapper.
    with self.assertRaises(AttributeError):
      _ = wrapped_env.some_random_function

  def test_multiple_wrappers(self):
    base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    wrapped_env_1 = base_wrapper.BaseWrapper(base_env)
    wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1)
    wrapped_env_2.close()
    base_env.close.assert_called_once()

  def test_raw_env(self):
    base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    wrapped_env_1 = base_wrapper.BaseWrapper(base_env)
    wrapped_env_2 = base_wrapper.BaseWrapper(wrapped_env_1)
    self.assertEqual(base_env, wrapped_env_2.raw_env)

  def test_stats(self):
    base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    wrapped_env = base_wrapper.BaseWrapper(base_env)
    base_stats = {'base': 'stats'}
    base_env.stats.return_value = base_stats
    self.assertEqual(base_stats, wrapped_env.stats())

  @mock.patch.object(logging, 'info')
  def test_wrapped_stats(self, mock_info):
    base_env = mock.create_autospec(env_interface.AndroidEnvInterface)

    class LoggingWrapper1(base_wrapper.BaseWrapper):

      def _wrapper_stats(self):
        return {
            'wrapper1': 'stats',
            'shared': 1,
        }

    class LoggingWrapper2(base_wrapper.BaseWrapper):

      def _wrapper_stats(self):
        return {
            'wrapper2': 'stats',
            'shared': 2,
        }

    wrapped_env = LoggingWrapper2(LoggingWrapper1(base_env))
    mock_info.assert_has_calls([
        mock.call('Wrapping with %s', 'LoggingWrapper1'),
        mock.call('Wrapping with %s', 'LoggingWrapper2'),
    ])

    base_stats = {'base': 'stats'}
    base_env.stats.return_value = base_stats
    # The outermost wrapper's value wins for keys shared across wrappers.
    expected_stats = {
        'base': 'stats',
        'wrapper1': 'stats',
        'wrapper2': 'stats',
        'shared': 2,
    }
    self.assertEqual(expected_stats, wrapped_env.stats())
if __name__ == '__main__':
  # Run this module's tests when it is executed as a script.
  absltest.main()
================================================
FILE: android_env/wrappers/discrete_action_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wraps the AndroidEnv environment to provide discrete actions."""
from collections.abc import Sequence
from typing import cast
from android_env import env_interface
from android_env.components import action_type
from android_env.wrappers import base_wrapper
import dm_env
from dm_env import specs
import numpy as np
# Gaussian touch noise (measured in grid cells) is clipped to +/- this value.
# Touches aim at the cell center (offset 0.5), so the jittered touch stays
# strictly inside its cell.
_NOISE_CLIP_VALUE = 0.4999
class DiscreteActionWrapper(base_wrapper.BaseWrapper):
  """AndroidEnv with discrete actions.

  The screen is split into an `action_grid` of (height, width) cells. Each
  discrete `action_id` is mapped to an `action_type` and to the
  `touch_position` at the (optionally jittered) center of one cell.
  """

  def __init__(
      self,
      env: env_interface.AndroidEnvInterface,
      action_grid: Sequence[int] = (10, 10),
      redundant_actions: bool = True,
      noise: float = 0.1,
  ) -> None:
    """Initializes the wrapper.

    Args:
      env: Environment to wrap. Its action spec must have exactly the keys
        `action_type` (scalar) and `touch_position` (shape (2,)).
      action_grid: Number of grid cells as (height, width).
      redundant_actions: If True, every action type is combined with every
        grid cell. If False, only TOUCH uses the grid, and each remaining
        action type gets a single action id.
      noise: Standard deviation of the Gaussian jitter added to a cell
        center, in units of cells (clipped to `_NOISE_CLIP_VALUE`).
    """
    super().__init__(env)
    self._parent_action_spec = self._env.action_spec()
    self._assert_base_env()
    self._action_grid = action_grid  # [height, width]
    # Plain Python int so that `num_actions` matches its `int` annotation.
    self._grid_size = int(np.prod(self._action_grid))
    action_types = cast(
        specs.DiscreteArray, self._parent_action_spec['action_type']
    )
    self._num_action_types = action_types.num_values
    self._redundant_actions = redundant_actions
    self._noise = noise

  def _assert_base_env(self) -> None:
    """Checks that the wrapped env has the right action spec format."""
    assert len(self._parent_action_spec) == 2
    assert not self._parent_action_spec['action_type'].shape
    assert self._parent_action_spec['touch_position'].shape == (2,)

  @property
  def num_actions(self) -> int:
    """Number of discrete actions."""
    if self._redundant_actions:
      return self._grid_size * self._num_action_types
    # One id per grid cell for TOUCH, plus one id per other action type.
    return self._grid_size + self._num_action_types - 1

  def step(self, action: dict[str, int]) -> dm_env.TimeStep:
    """Take a step in the base environment."""
    # BaseWrapper.step applies `_process_action` and also runs the
    # `_process_timestep` hook, which a direct `self._env.step` would skip.
    return super().step(action)

  def _process_action(self, action: dict[str, int]) -> dict[str, np.ndarray]:
    """Transforms action so that it agrees with AndroidEnv's action spec."""
    action_id = action['action_id']
    return {
        'action_type': np.array(
            self._get_action_type(action_id),
            dtype=self._parent_action_spec['action_type'].dtype,
        ),
        'touch_position': np.array(
            self._get_touch_position(action_id),
            dtype=self._parent_action_spec['touch_position'].dtype,
        ),
    }

  def _get_action_type(self, action_id: int) -> action_type.ActionType:
    """Compute action type corresponding to the given action_id.

    When `self._redundant_actions` == True the `grid_size` is "broadcast" over
    all the possible actions so you end up with `grid_size` discrete actions
    of type 0, `grid_size` discrete actions of type 1, etc. for all action
    types.

    When `self._redundant_actions` == False the first `grid_size` actions are
    reserved for "touch". Each remaining action type is added (NOT
    multiplied), one id apiece, in enum order. With the usual three types,
    id `grid_size` is LIFT and id `grid_size + 1` is REPEAT.

    Args:
      action_id: A discrete action in `[0, num_actions)`.

    Returns:
      action_type: The action_type of the action.

    Raises:
      AssertionError: if `action_id` is outside `[0, num_actions)`.
    """
    assert 0 <= action_id < self.num_actions, (
        f'action_id {action_id} out of range [0, {self.num_actions})')
    if self._redundant_actions:
      return action_id // self._grid_size
    if action_id < self._grid_size:
      return action_type.ActionType.TOUCH
    # Non-touch types follow TOUCH (value 0) in enum order.
    return action_type.ActionType(action_id - self._grid_size + 1)

  def _get_touch_position(self, action_id: int) -> Sequence[float]:
    """Compute the position corresponding to the given action_id.

    Note: in the touch_position (x, y) of an action, x corresponds to the
    horizontal axis (width), and y corresponds to the vertical axis (height)
    of the screen. BUT, the screen has dimensions (height, width), i.e. the
    first coordinate corresponds to y, and the second coordinate corresponds
    to x. Pay attention to this mismatch in the calculations below.

    Args:
      action_id: A discrete action.

    Returns:
      touch_position: The [0,1]x[0,1] coordinate of the action, projected onto
        the parent spec's [minimum, maximum] range.
    """
    position_idx = action_id % self._grid_size

    x_pos_grid = position_idx % self._action_grid[1]  # WIDTH
    y_pos_grid = position_idx // self._action_grid[1]  # HEIGHT

    noise_x = np.random.normal(loc=0.0, scale=self._noise)
    noise_y = np.random.normal(loc=0.0, scale=self._noise)

    # Noise is clipped so that the action will strictly stay in the cell.
    noise_x = max(min(noise_x, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE)
    noise_y = max(min(noise_y, _NOISE_CLIP_VALUE), -_NOISE_CLIP_VALUE)

    x_pos = (x_pos_grid + 0.5 + noise_x) / self._action_grid[1]  # WIDTH
    y_pos = (y_pos_grid + 0.5 + noise_y) / self._action_grid[0]  # HEIGHT

    # Project action space to action_spec ranges. For the default case of
    # minimum = [0, 0] and maximum = [1, 1], this will not do anything.
    touch_spec = cast(
        specs.BoundedArray, self._parent_action_spec['touch_position']
    )
    x_min, y_min = touch_spec.minimum
    x_max, y_max = touch_spec.maximum

    x_pos = x_min + x_pos * (x_max - x_min)
    y_pos = y_min + y_pos * (y_max - y_min)

    return [x_pos, y_pos]

  def action_spec(self) -> dict[str, specs.Array]:
    """Action spec of the wrapped environment."""
    return {
        'action_id':
            specs.DiscreteArray(
                num_values=self.num_actions,
                name='action_id')
    }
================================================
FILE: android_env/wrappers/discrete_action_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.wrappers.discrete_action_wrapper."""
from unittest import mock
from absl.testing import absltest
from android_env import env_interface
from android_env.components import action_type as action_type_lib
from android_env.wrappers import discrete_action_wrapper
from dm_env import specs
import numpy as np
# Short module-level alias for the action-type enum used in the assertions.
ActionType = action_type_lib.ActionType
def _make_array_spec(shape, dtype, name):
  """Builds a 1-D BoundedArray spec whose entries range over [0, shape[0]-1]."""
  assert len(shape) == 1
  upper_bound = shape[0] - 1  # The spec's maximum is inclusive.
  return specs.BoundedArray(
      shape=shape,
      dtype=dtype,
      name=name,
      minimum=np.zeros(shape),
      maximum=upper_bound * np.ones(shape),
  )
def _valid_shape(action):
assert len(action) == 2, action
assert not action['action_type'].shape, (
'action: %r, shape: %r' %
(action['action_type'], action['action_type'].shape))
assert action['touch_position'].shape == (
2,), ('action: %r, shape: %r' %
(action['touch_position'], action['touch_position'].shape))
def _valid_types(action, types):
for a, t in zip(action.values(), types):
assert a.dtype == t, '%r is not of dtype %r' % (a, t)
class DiscreteActionWrapperTest(absltest.TestCase):
def setUp(self):
  super().setUp()
  # The fake base env supports only TOUCH, LIFT and REPEAT.
  self._num_action_types = 3
  action_type_spec = specs.DiscreteArray(
      num_values=self._num_action_types, name='action_type')
  touch_position_spec = _make_array_spec(
      shape=(2,), dtype=np.float32, name='touch_position')
  self._base_action_spec = {
      'action_type': action_type_spec,
      'touch_position': touch_position_spec,
  }
  self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
  self.base_env.action_spec.return_value = self._base_action_spec
def test_num_actions(self):
  """With redundant actions, every action type is paired with every cell."""
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env, action_grid=(3, 3), redundant_actions=True)
  # 27 = 3 * 3 * 3 (H * W * self._num_action_types).
  self.assertEqual(27, wrapped_env.num_actions)
def test_num_actions_non_redundant(self):
  # Without redundant actions, the non-touch action types add to the number
  # of grid cells instead of multiplying it.
  non_redundant_env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env, action_grid=(3, 3), redundant_actions=False)
  # 11 = 3 * 3 + 2, i.e. H * W + (self._num_action_types - 1).
  self.assertEqual(11, non_redundant_env.num_actions)
def test_reset(self):
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env, redundant_actions=True)
  self.base_env.reset.return_value = 'ts'
  timestep = wrapped_env.reset()
  self.base_env.reset.assert_called_once()
  self.assertEqual('ts', timestep)
def test_step_no_noise(self):
  height, width = 4, 3
  grid_size = height * width
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env,
      action_grid=(height, width),
      noise=0.0,
      redundant_actions=True)
  self.assertEqual(grid_size * self._num_action_types,
                   wrapped_env.num_actions)

  # Without noise, a touch lands on a cell center. For the first or last
  # row/column, that center is half a cell away from the screen edge.
  half_cell_y = 1. / float(height) / 2.
  half_cell_x = 1. / float(width) / 2.
  tolerance = 0.0001

  def check_action(action, expected_type, lower_x, lower_y):
    _valid_shape(action)
    _valid_types(action, [np.int32, np.float32])
    self.assertEqual(expected_type, action['action_type'])
    expected_y = half_cell_y if lower_y else 1 - half_cell_y
    self.assertAlmostEqual(
        expected_y, action['touch_position'][1], delta=tolerance)
    expected_x = half_cell_x if lower_x else 1 - half_cell_x
    self.assertAlmostEqual(
        expected_x, action['touch_position'][0], delta=tolerance)

  # (cell index, lower_x, lower_y) for the four corners of the grid.
  corners = (
      (0, True, True),
      (width - 1, False, True),
      (grid_size - width, True, False),
      (grid_size - 1, False, False),
  )
  self.base_env.step.return_value = 'ts'
  # Each action type owns a consecutive block of `grid_size` action ids.
  for expected_type in range(self._num_action_types):
    for cell, lower_x, lower_y in corners:
      timestep = wrapped_env.step(
          {'action_id': expected_type * grid_size + cell})
      check_action(self.base_env.step.call_args[0][0], expected_type,
                   lower_x, lower_y)
      self.assertEqual('ts', timestep)
def test_step_redundant_actions_invalid_action_id(self):
  """An action id past the end of the redundant action space is rejected."""
  env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env, action_grid=(4, 3), noise=0.0, redundant_actions=True)
  # 36 is one past the largest id (35) exercised by test_step_no_noise.
  with self.assertRaises(AssertionError):
    env.step({'action_id': 36})
def test_step_no_noise_no_redudant_actions(self):
  """Without redundancy, only touches use the grid; other types get one id."""
  rows, cols = 4, 3
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env,
      action_grid=(rows, cols),
      noise=0.0,
      redundant_actions=False,
  )
  # One id per grid cell for touches plus one id per remaining action type.
  self.assertEqual(rows * cols + (self._num_action_types - 1),
                   wrapped_env.num_actions)

  half_row = 1. / float(rows) / 2.
  half_col = 1. / float(cols) / 2.
  tolerance = 0.0001

  def make_checker(action_type, low_x, low_y):
    """Returns a callable validating one forwarded action."""
    want_y = half_row if low_y else 1 - half_row
    want_x = half_col if low_x else 1 - half_col

    def checker(action):
      _valid_shape(action)
      _valid_types(action, [np.int32, np.float32])
      self.assertEqual(action_type, action['action_type'])
      # Touch coordinates are only checked for TOUCH actions.
      is_touch = action['action_type'] == ActionType.TOUCH
      if is_touch:
        self.assertAlmostEqual(
            want_y, action['touch_position'][1], delta=tolerance)
        self.assertAlmostEqual(
            want_x, action['touch_position'][0], delta=tolerance)
      return True

    return checker

  action_tests = {
      # Touch actions at the four corner cells of the 4x3 grid.
      0: make_checker(0, low_x=True, low_y=True),
      2: make_checker(0, low_x=False, low_y=True),
      9: make_checker(0, low_x=True, low_y=False),
      11: make_checker(0, low_x=False, low_y=False),
      # Ids from rows * cols (12) onwards map to the non-touch action types;
      # their touch coordinates are not checked.
      12: make_checker(1, low_x=False, low_y=False),
      13: make_checker(2, low_x=False, low_y=False),
  }

  fake_timestep = 'ts'
  self.base_env.step.return_value = fake_timestep
  for action_id, checker in action_tests.items():
    ts = wrapped_env.step({'action_id': action_id})
    checker(self.base_env.step.call_args[0][0])
    self.assertEqual(fake_timestep, ts)
def test_step_no_redundant_actions_invalid_action_id(self):
  """An action id past the end of the non-redundant action space is rejected."""
  env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env, action_grid=(4, 3), noise=0.0, redundant_actions=False)
  # 14 is one past the largest id (13) exercised by
  # test_step_no_noise_no_redudant_actions for this grid.
  with self.assertRaises(AssertionError):
    env.step({'action_id': 14})
def test_step_with_noise(self):
  """With noise, touches still stay inside the band of their corner cell."""
  rows, cols = 4, 3
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env, action_grid=(rows, cols), noise=1.0)
  self.assertEqual(rows * cols * self._num_action_types,
                   wrapped_env.num_actions)
  # Size of one grid cell along each normalized axis.
  row_step = 1. / float(rows)
  col_step = 1. / float(cols)

  def make_checker(action_type, low_x, low_y):
    """Returns a callable checking a noisy touch stays in its corner band."""

    def checker(action):
      _valid_shape(action)
      _valid_types(action, [np.int32, np.float32])
      self.assertEqual(action_type, action['action_type'])
      touch_y = action['touch_position'][1]
      if low_y:
        self.assertGreater(row_step, touch_y)
      else:
        self.assertLess(1 - row_step, touch_y)
      touch_x = action['touch_position'][0]
      if low_x:
        self.assertGreater(col_step, touch_x)
      else:
        self.assertLess(1 - col_step, touch_x)
      return True

    return checker

  # The four corner cells as (cell offset, low_x, low_y), repeated for each
  # action type's contiguous block of rows * cols ids.
  cells = rows * cols
  corners = (
      (0, True, True),
      (cols - 1, False, True),
      ((rows - 1) * cols, True, False),
      (cells - 1, False, False),
  )
  action_tests = {
      action_type * cells + offset: make_checker(action_type, low_x, low_y)
      for action_type in range(3)
      for offset, low_x, low_y in corners
  }

  fake_timestep = 'ts'
  self.base_env.step.return_value = fake_timestep
  for action_id, checker in action_tests.items():
    ts = wrapped_env.step({'action_id': action_id})
    checker(self.base_env.step.call_args[0][0])
    self.assertEqual(fake_timestep, ts)
def test_parent_spec_type(self):
  """Forwarded actions are validated against the parent's float64 position."""
  touch_spec_f64 = _make_array_spec(
      shape=(2,), dtype=np.float64, name='touch_position')
  parent_action_spec = {
      'action_type': specs.DiscreteArray(
          num_values=self._num_action_types, name='action_type'),
      'touch_position': touch_spec_f64,
  }
  parent_env = mock.create_autospec(env_interface.AndroidEnvInterface)
  parent_env.action_spec.return_value = parent_action_spec
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
      parent_env, noise=0.0)
  parent_env.step.return_value = 'ts'

  ts = wrapped_env.step({'action_id': 1})

  # The action handed to the parent is checked against int32/float64 types.
  _valid_types(parent_env.step.call_args[0][0], [np.int32, np.float64])
  self.assertEqual('ts', ts)
def test_observation_spec(self):
  """observation_spec() is forwarded to the parent env unchanged."""
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
  expected_spec = 'fake_obs_spec'
  self.base_env.observation_spec.return_value = expected_spec
  spec = wrapped_env.observation_spec()
  self.base_env.observation_spec.assert_called_once()
  self.assertEqual(expected_spec, spec)
def test_action_spec(self):
  """With redundancy, every (cell, action type) pair gets its own id."""
  rows, cols = 4, 5
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env, action_grid=(rows, cols), redundant_actions=True)
  num_ids = rows * cols * self._num_action_types
  expected = {
      'action_id': specs.DiscreteArray(num_values=num_ids, name='action_type'),
  }
  self.assertEqual(expected, wrapped_env.action_spec())
def test_action_spec_non_redundant(self):
  """Without redundancy, touches use the grid and other types one id each."""
  rows, cols = 4, 5
  wrapped_env = discrete_action_wrapper.DiscreteActionWrapper(
      self.base_env, action_grid=(rows, cols), redundant_actions=False)
  num_ids = rows * cols + (self._num_action_types - 1)
  expected = {
      'action_id': specs.DiscreteArray(num_values=num_ids, name='action_type'),
  }
  self.assertEqual(expected, wrapped_env.action_spec())
def test_assert_base_env_action_spec_too_short(self):
  """A parent action spec lacking touch_position is rejected."""
  action_type_only = {
      'action_type': specs.DiscreteArray(
          num_values=self._num_action_types, name='action_type'),
  }
  self.base_env.action_spec.return_value = action_type_only
  with self.assertRaises(AssertionError):
    discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
def test_assert_base_env_action_spec_too_long(self):
  """A parent action spec with an extra entry is rejected."""
  action_spec = {
      'action_type': specs.DiscreteArray(
          num_values=self._num_action_types, name='action_type'),
      'touch_position': _make_array_spec(
          shape=(2,), dtype=np.float32, name='touch_position'),
  }
  # One entry beyond the expected action_type/touch_position pair.
  action_spec['too_long'] = _make_array_spec(
      shape=(1,), dtype=np.float32, name='too_long')
  self.base_env.action_spec.return_value = action_spec
  with self.assertRaises(AssertionError):
    discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
def test_assert_base_env_action_spec_wrong_shapes(self):
  """Correct keys with wrong spec kinds/shapes are rejected."""
  # action_type is a (2,) float array and touch_position has one element.
  bad_action_type = _make_array_spec(
      shape=(2,), dtype=np.float32, name='action_type')
  bad_touch_position = _make_array_spec(
      shape=(1,), dtype=np.float32, name='touch_position')
  self.base_env.action_spec.return_value = {
      'action_type': bad_action_type,
      'touch_position': bad_touch_position,
  }
  with self.assertRaises(AssertionError):
    discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
def test_assert_base_env_ok(self):
  """A well-formed parent action spec is accepted without raising."""
  action_type_spec = specs.DiscreteArray(
      num_values=self._num_action_types, name='action_type')
  touch_spec = _make_array_spec(
      shape=(2,), dtype=np.float32, name='touch_position')
  self.base_env.action_spec.return_value = {
      'action_type': action_type_spec,
      'touch_position': touch_spec,
  }
  discrete_action_wrapper.DiscreteActionWrapper(self.base_env)
if __name__ == '__main__':
  # Run the test cases defined in this module via absltest.
  absltest.main()
================================================
FILE: android_env/wrappers/flat_interface_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wraps the AndroidEnv environment to make its interface flat."""
from typing import Any, cast
from android_env import env_interface
from android_env.wrappers import base_wrapper
import dm_env
from dm_env import specs
import numpy as np
# Channel indices kept for colour (non-grayscale) observations.
RGB_CHANNELS = (0, 1, 2)


def _extract_screen_pixels(obs: np.ndarray) -> np.ndarray:
  """Get only screen pixels by removing previous action layer.

  The last axis of `obs` is the channel axis. Exactly two channels means
  grayscale plus action layer; otherwise the channels listed in
  `RGB_CHANNELS` are kept.

  Args:
    obs: Array whose last axis holds the screen channels and action layer.

  Returns:
    An array with the same leading axes and either a single grayscale channel
    or the `RGB_CHANNELS` channels.
  """
  channel_count = obs.shape[-1]
  if channel_count != 2:
    return obs[..., RGB_CHANNELS]
  # Grayscale: keep channel 0, then restore the trailing channel axis.
  return obs[..., 0][..., np.newaxis]
def _get_no_action_observation_spec(
    obs_spec: specs.BoundedArray,
) -> specs.BoundedArray:
  """Create an observation spec without the action layer.

  The last (channel) dimension shrinks by one. Non-scalar bounds are sliced
  with `_extract_screen_pixels` so they keep matching the new shape; scalar
  bounds are left as they are.

  Args:
    obs_spec: Pixel observation spec whose last axis ends with the action
      layer channel.

  Returns:
    A copy of `obs_spec` (via `replace`) with the action layer removed from
    its shape and bounds.
  """
  shape = list(obs_spec.shape)
  # Operate on the last axis, exactly like `_extract_screen_pixels`, instead
  # of assuming the spec is rank 3 (`shape[2]`).
  shape[-1] -= 1

  def _strip_action_layer(bound):
    # Scalar bounds apply to every element, so they need no slicing.
    if np.isscalar(bound) or np.ndim(bound) == 0:
      return bound
    return _extract_screen_pixels(bound)

  return obs_spec.replace(
      shape=tuple(shape),
      minimum=_strip_action_layer(obs_spec.minimum),
      maximum=_strip_action_layer(obs_spec.maximum),
  )
class FlatInterfaceWrapper(base_wrapper.BaseWrapper):
  """Simple interface for AndroidEnv.

  Removes the structure from observations and actions, keeping only the pixel
  observations. The wrapped environment's action spec must be a dict with a
  single bounded entry (e.g. a discrete `action_id`); that entry is exposed
  directly, making it easier to use with conventional discrete agents. This
  wrapper expects a discretized action space.
  """

  def __init__(
      self,
      env: env_interface.AndroidEnvInterface,
      flat_actions: bool = True,
      flat_observations: bool = True,
      keep_action_layer: bool = True,
  ) -> None:
    """Initializes the wrapper.

    Args:
      env: Environment to wrap. Its action spec must be a dict with exactly
        one `specs.BoundedArray` entry.
      flat_actions: If True, `step()` takes the bare action value and wraps
        it as `{action_name: action}`, and `action_spec()` returns that single
        entry's spec.
      flat_observations: If True, timesteps carry only `observation['pixels']`
        and `observation_spec()` returns only the pixels spec.
      keep_action_layer: Only used when `flat_observations` is True. If False,
        the action layer channel is stripped from the pixels and their spec.

    Raises:
      AssertionError: If `env`'s action spec does not have the expected form.
    """
    super().__init__(env)
    self._flat_actions = flat_actions
    self._flat_observations = flat_observations
    self._keep_action_layer = keep_action_layer
    # Validate before indexing, so a malformed (e.g. empty) action spec raises
    # the intended AssertionError rather than an IndexError/TypeError.
    self._assert_base_env()
    self._action_name = next(iter(self._env.action_spec()))

  def _assert_base_env(self) -> None:
    """Checks the wrapped env exposes exactly one bounded action entry."""
    base_action_spec = self._env.action_spec()
    assert isinstance(base_action_spec, dict), base_action_spec
    assert len(base_action_spec) == 1, base_action_spec
    (single_spec,) = base_action_spec.values()
    assert isinstance(single_spec, specs.BoundedArray), single_spec

  def _process_action(
      self, action: int | np.ndarray | dict[str, Any]
  ) -> int | np.ndarray | dict[str, Any]:
    """Wraps a flat action into the single-key dict the parent expects."""
    if not self._flat_actions:
      return action
    return {self._action_name: action}

  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    """Replaces the observation dict with its (optionally stripped) pixels."""
    if not self._flat_observations:
      return timestep
    step_type, reward, discount, observation = timestep
    # Keep only the pixels.
    pixels = observation['pixels']
    if not self._keep_action_layer:
      pixels = _extract_screen_pixels(pixels)
    return dm_env.TimeStep(
        step_type=step_type,
        reward=reward,
        discount=discount,
        observation=pixels,
    )

  def reset(self) -> dm_env.TimeStep:
    return self._process_timestep(self._env.reset())

  def step(self, action: int) -> dm_env.TimeStep:
    return self._process_timestep(self._env.step(self._process_action(action)))

  def observation_spec(self) -> specs.Array | dict[str, specs.Array]:  # pytype: disable=signature-mismatch  # overriding-return-type-checks
    """Returns the pixels spec when flattening, else the parent's full spec."""
    parent_spec = self._env.observation_spec()
    if not self._flat_observations:
      return parent_spec
    pixels_spec = cast(specs.BoundedArray, parent_spec['pixels'])
    if not self._keep_action_layer:
      return _get_no_action_observation_spec(pixels_spec)
    return pixels_spec

  def action_spec(self) -> specs.BoundedArray | dict[str, specs.Array]:  # pytype: disable=signature-mismatch  # overriding-return-type-checks
    """Returns the single action entry's spec when flattening actions."""
    parent_spec = self._env.action_spec()
    if not self._flat_actions:
      return parent_spec
    return parent_spec[self._action_name]  # pytype: disable=bad-return-type
================================================
FILE: android_env/wrappers/flat_interface_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.wrappers.flat_interface_wrapper."""
from typing import cast
from unittest import mock
from absl.testing import absltest
from android_env.wrappers import flat_interface_wrapper
import dm_env
from dm_env import specs
import numpy as np
def _make_array_spec(shape, dtype=np.float32, name=None, maximum=3, minimum=0):
  """Builds a BoundedArray whose every element shares the same bounds."""
  unit = np.ones(shape)
  return specs.BoundedArray(
      shape=shape,
      dtype=dtype,
      name=name,
      maximum=unit * maximum,
      minimum=unit * minimum,
  )
def _make_timestep(observation):
  """Wraps `observation` in a TimeStep with placeholder string fields."""
  placeholders = dict(
      step_type='fake_step_type',
      reward='fake_reward',
      discount='fake_discount',
  )
  return dm_env.TimeStep(observation=observation, **placeholders)
class FlatInterfaceWrapperTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.action_shape = (1,)
    self.base_action_spec: dict[str, specs.DiscreteArray] = {
        'action_id': specs.DiscreteArray(name='action_id', num_values=4)
    }
    self.int_obs_shape = (3, 4, 2)
    self.float_obs_shape = (2,)
    self.base_observation_spec = {
        'pixels': _make_array_spec(
            shape=self.int_obs_shape, dtype=np.uint8, name='pixels'),
        'obs1': _make_array_spec(
            shape=self.float_obs_shape, dtype=np.float32, name='obs1'),
    }

    # Expected outputs when the action layer is kept.
    self.expected_observation_spec = _make_array_spec(
        shape=self.int_obs_shape, dtype=np.uint8, name='pixels')
    self.image_obs = np.ones(self.int_obs_shape, dtype=np.uint8)
    self.expected_timestep = _make_timestep(self.image_obs)

    # Expected outputs when the last (action layer) channel is dropped.
    stripped_shape = (3, 4, 1)
    self.expected_observation_spec_no_action_layer = _make_array_spec(
        shape=stripped_shape, dtype=np.uint8, name='pixels')
    self.expected_timestep_no_action_layer = _make_timestep(
        np.ones(stripped_shape, dtype=np.uint8))

    # Mocked parent environment.
    self.other_obs = np.ones(self.float_obs_shape, dtype=np.float32)
    self.base_timestep = _make_timestep(
        {'pixels': self.image_obs, 'obs1': self.other_obs})
    self.base_env = mock.create_autospec(dm_env.Environment)
    self.base_env.action_spec.return_value = self.base_action_spec
    self.base_env.observation_spec.return_value = self.base_observation_spec
    self.base_env.reset.return_value = self.base_timestep
    self.base_env.step.return_value = self.base_timestep

  def _assert_forwarded_action(self, expected_action):
    """Checks the parent env was stepped with {'action_id': expected_action}."""
    forwarded = self.base_env.step.call_args[0][0]
    self.assertIsInstance(forwarded, dict)
    self.assertIsInstance(forwarded['action_id'], int)
    self.assertEqual(forwarded['action_id'], expected_action)

  def test_reset(self):
    env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
    ts = env.reset()
    self.base_env.reset.assert_called_once()
    self.assertEqual(self.expected_timestep, ts)

  def test_reset_no_action_layer(self):
    env = flat_interface_wrapper.FlatInterfaceWrapper(
        self.base_env, keep_action_layer=False)
    ts = env.reset()
    self.base_env.reset.assert_called_once()
    expected_obs = self.expected_timestep_no_action_layer.observation
    self.assertEqual(expected_obs.tolist(), ts.observation.tolist())

  def test_step(self):
    env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
    action = 2
    ts = env.step(action)
    self._assert_forwarded_action(action)
    self.assertEqual(self.expected_timestep, ts)

  def test_step_no_action_layer(self):
    env = flat_interface_wrapper.FlatInterfaceWrapper(
        self.base_env, keep_action_layer=False)
    action = 2
    ts = env.step(action)
    self._assert_forwarded_action(action)
    expected_obs = self.expected_timestep_no_action_layer.observation
    self.assertEqual(expected_obs.tolist(), ts.observation.tolist())

  def test_observation_spec(self):
    env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
    spec = env.observation_spec()
    self.base_env.observation_spec.assert_called_once()
    self.assertEqual(self.expected_observation_spec, spec)

  def test_observation_spec_no_action_layer(self):
    env = flat_interface_wrapper.FlatInterfaceWrapper(
        self.base_env, keep_action_layer=False)
    spec = env.observation_spec()
    self.base_env.observation_spec.assert_called_once()
    self.assertEqual(self.expected_observation_spec_no_action_layer, spec)

  def test_action_spec(self):
    env = flat_interface_wrapper.FlatInterfaceWrapper(self.base_env)
    spec = cast(specs.BoundedArray, env.action_spec())
    parent_spec = self.base_action_spec['action_id']
    self.assertEqual(parent_spec.name, spec.name)
    self.assertEqual((), spec.shape)
    self.assertEqual(np.int32, spec.dtype)
    self.assertEqual(0, spec.minimum)

  def test_bad_action_spec_structured_action(self):
    # Two action entries cannot be flattened into a single action.
    bad_base_env = mock.create_autospec(dm_env.Environment)
    bad_base_env.action_spec.return_value = {
        'action_id': _make_array_spec((1,)),
        'too_many': _make_array_spec((1,)),
    }
    with self.assertRaises(AssertionError):
      flat_interface_wrapper.FlatInterfaceWrapper(bad_base_env)
if __name__ == '__main__':
  # Run the test cases defined in this module via absltest.
  absltest.main()
================================================
FILE: android_env/wrappers/float_pixels_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converts pixel observation to from int to float32 between 0.0 and 1.0."""
from android_env import env_interface
from android_env.components import pixel_fns
from android_env.wrappers import base_wrapper
import dm_env
from dm_env import specs
import numpy as np
class FloatPixelsWrapper(base_wrapper.BaseWrapper):
  """Wraps AndroidEnv to expose integer pixel observations as float32.

  When the parent's `pixels` spec has an integer dtype, pixels are converted
  with `pixel_fns.convert_int_to_float` and the spec becomes a float32
  `BoundedArray` with bounds [0.0, 1.0]. Otherwise everything passes through
  unchanged. The parent's observations and specs are never mutated.
  """

  def __init__(self, env: env_interface.AndroidEnvInterface) -> None:
    super().__init__(env)
    self._input_spec = self._env.observation_spec()['pixels']
    # Only integer pixel dtypes need converting; float inputs pass through.
    self._should_convert_int_to_float = np.issubdtype(self._input_spec.dtype,
                                                      np.integer)

  def _process_observation(
      self, observation: dict[str, np.ndarray]
  ) -> dict[str, np.ndarray]:
    """Returns `observation` with float pixels, leaving the input untouched."""
    if not self._should_convert_int_to_float:
      return observation
    float_pixels = pixel_fns.convert_int_to_float(
        observation['pixels'], self._input_spec
    )
    # Work on a shallow copy: assigning into the parent's dict would alter an
    # observation the parent env (or its caller) may still hold.
    processed = observation.copy()
    processed['pixels'] = float_pixels
    return processed

  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    step_type, reward, discount, observation = timestep
    return dm_env.TimeStep(
        step_type=step_type,
        reward=reward,
        discount=discount,
        observation=self._process_observation(observation))

  def reset(self) -> dm_env.TimeStep:
    return self._process_timestep(self._env.reset())

  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    return self._process_timestep(self._env.step(action))

  def observation_spec(self) -> dict[str, specs.Array]:
    """Returns the parent spec, with a float32 [0, 1] pixels entry if needed."""
    parent_spec = self._env.observation_spec()
    if not self._should_convert_int_to_float:
      return parent_spec
    pixels_spec = parent_spec['pixels']
    # Build a new dict instead of assigning into the parent's spec, which may
    # be a shared or cached object.
    observation_spec = dict(parent_spec)
    observation_spec['pixels'] = specs.BoundedArray(
        shape=pixels_spec.shape,
        dtype=np.float32,
        minimum=0.0,
        maximum=1.0,
        name=pixels_spec.name)
    return observation_spec
================================================
FILE: android_env/wrappers/float_pixels_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.wrappers.float_pixels_wrapper."""
from unittest import mock
from absl.testing import absltest
from android_env.wrappers import float_pixels_wrapper
import dm_env
from dm_env import specs
import numpy as np
def _make_array_spec(shape, dtype=np.float32, name=None):
  """Returns an unbounded `specs.Array` with the given shape, dtype and name."""
  spec_kwargs = {'shape': shape, 'dtype': dtype, 'name': name}
  return specs.Array(**spec_kwargs)
def _make_bounded_array_spec(
    shape, dtype=np.float32, name=None, maximum=1.0, minimum=0.0):
  """Returns a `specs.BoundedArray`; bounds default to [0.0, 1.0]."""
  bounds = {'minimum': minimum, 'maximum': maximum}
  return specs.BoundedArray(shape=shape, dtype=dtype, name=name, **bounds)
def _simple_timestep(obs_shape, obs_type):
  """Returns a MID TimeStep whose observation is a 1-tuple of ones."""
  ones = np.ones(shape=obs_shape, dtype=obs_type)
  return dm_env.TimeStep(
      observation=(ones,),
      step_type=dm_env.StepType.MID,
      reward=3.14,
      discount=0.9,
  )
class FloatPixelsWrapperTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self.pixels_shape = (3, 4)
    uint8_pixel_spec = _make_array_spec(
        shape=self.pixels_shape, dtype=np.uint8, name='pixels')
    self.other_obs_spec = _make_array_spec(
        shape=(1,), dtype=np.float32, name='other_obs')
    self.base_env = mock.create_autospec(dm_env.Environment)
    self.base_env.observation_spec.return_value = {
        'pixels': uint8_pixel_spec,
        'other_obs': self.other_obs_spec,
    }
    self.base_timestep = dm_env.TimeStep(
        step_type=dm_env.StepType.MID,
        reward=3.14,
        discount=0.9,
        observation={
            'pixels': np.ones(shape=self.pixels_shape, dtype=np.uint8),
            'other_obs': [42.2],
        })
    self.base_env.step.return_value = self.base_timestep
    self.base_env.reset.return_value = self.base_timestep

  def _assert_converted(self, ts):
    """Checks non-pixel fields pass through and pixels become float32."""
    self.assertEqual(self.base_timestep.step_type, ts.step_type)
    self.assertEqual(self.base_timestep.reward, ts.reward)
    self.assertEqual(self.base_timestep.discount, ts.discount)
    self.assertEqual(self.base_timestep.observation['other_obs'],
                     ts.observation['other_obs'])
    # The original pixel values are uint8 ones, so each becomes 1/255.
    expected_pixels = np.ones(
        self.pixels_shape, dtype=np.float32) * (1. / 255.)
    np.testing.assert_equal(expected_pixels, ts.observation['pixels'])

  def test_float_pixels_wrapper_spec(self):
    # The helper's defaults already give float32 with bounds [0.0, 1.0].
    expected_pixel_spec = _make_bounded_array_spec(
        shape=self.pixels_shape, name='pixels')
    wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env)
    self.assertLen(wrapped_env.observation_spec(), 2)
    self.assertEqual(expected_pixel_spec,
                     wrapped_env.observation_spec()['pixels'])
    self.assertEqual(self.other_obs_spec,
                     wrapped_env.observation_spec()['other_obs'])

  def test_float_pixels_wrapper_step(self):
    wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env)
    self._assert_converted(
        wrapped_env.step({'fake_action': np.array([1, 2, 3])}))

  def test_float_pixels_wrapper_reset(self):
    wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(self.base_env)
    self._assert_converted(wrapped_env.reset())

  def test_float_pixels_wrapper_already_float(self):
    float_pixel_spec = _make_array_spec(
        shape=self.pixels_shape, dtype=np.float64, name='pixels')
    base_env = mock.create_autospec(dm_env.Environment)
    base_env.observation_spec.return_value = {
        'pixels': float_pixel_spec,
        'other_obs': self.other_obs_spec,
    }
    wrapped_env = float_pixels_wrapper.FloatPixelsWrapper(base_env)
    # Float pixels need no conversion, so the spec passes through unchanged.
    self.assertEqual(base_env.observation_spec(),
                     wrapped_env.observation_spec())
    # Likewise, the timestep content is left as-is.
    fake_timestep = ('step_type', 'reward', 'discount', 'obs')
    base_env.step.return_value = fake_timestep
    ts = wrapped_env.step({'fake_action': np.array([1, 2, 3])})
    self.assertEqual(fake_timestep, ts)
if __name__ == '__main__':
  # Run the test cases defined in this module via absltest.
  absltest.main()
================================================
FILE: android_env/wrappers/gym_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wraps the AndroidEnv to expose an OpenAI Gym interface."""
from typing import Any
from android_env.wrappers import base_wrapper
import dm_env
from dm_env import specs
import gym
from gym import spaces
import numpy as np
class GymInterfaceWrapper(gym.Env):
  """AndroidEnv with OpenAI Gym interface.

  Converts the wrapped environment's dm_env specs into Gym spaces and exposes
  the classic Gym `reset()`, `step()` -> (observation, reward, done, info)
  and `render()` API.
  """

  def __init__(self, env: dm_env.Environment):
    self._env = env
    self.spec = None
    self.action_space = self._spec_to_space(self._env.action_spec())
    self.observation_space = self._spec_to_space(self._env.observation_spec())
    self.metadata = {'render.modes': ['rgb_array']}
    # Most recent observation, read by `render()`.
    self._latest_observation = None

  def _spec_to_space(self, spec: specs.Array) -> spaces.Space:
    """Converts dm_env specs to OpenAI Gym spaces.

    Args:
      spec: A dm_env array spec, or a list/tuple/dict nesting of such specs.

    Returns:
      The equivalent Gym space: lists and tuples become `spaces.Tuple`, dicts
      become `spaces.Dict`, and array specs become `spaces.Box`.

    Raises:
      ValueError: If `spec` (or a nested entry) has an unsupported type.
    """
    if isinstance(spec, (list, tuple)):
      return spaces.Tuple([self._spec_to_space(s) for s in spec])
    if isinstance(spec, dict):
      return spaces.Dict(
          {name: self._spec_to_space(s) for name, s in spec.items()}
      )
    # Checked before BoundedArray: dm_env's DiscreteArray is a BoundedArray
    # subclass, and its bounds are derived from `num_values`.
    if isinstance(spec, specs.DiscreteArray):
      return spaces.Box(
          shape=(),
          dtype=spec.dtype,
          low=0,
          high=spec.num_values - 1)
    if isinstance(spec, specs.BoundedArray):
      return spaces.Box(
          shape=spec.shape,
          dtype=spec.dtype,
          low=spec.minimum,
          high=spec.maximum)
    if isinstance(spec, specs.Array):
      # Unbounded specs: uint8 is limited by its dtype range, anything else
      # is left unbounded.
      if spec.dtype == np.uint8:
        low, high = 0, 255
      else:
        low, high = -np.inf, np.inf
      return spaces.Box(shape=spec.shape, dtype=spec.dtype, low=low, high=high)
    raise ValueError(f'Unknown type for specs: {spec}')

  def render(self, mode='rgb_array'):
    """Renders the environment.

    Args:
      mode: Only 'rgb_array' is supported.

    Returns:
      The latest observation's 'pixels' entry, or None if there is no
      observation yet.

    Raises:
      ValueError: If `mode` is not 'rgb_array'.
    """
    if mode != 'rgb_array':
      raise ValueError('Only supported render mode is rgb_array.')
    if self._latest_observation is None:
      return None
    return self._latest_observation['pixels']

  def reset(self) -> np.ndarray:
    """Resets the wrapped environment and returns its first observation."""
    # Clear first so a failing reset never leaves a stale frame for render().
    self._latest_observation = None
    timestep = self._env.reset()
    # Remember the initial frame so render() works before the first step().
    self._latest_observation = timestep.observation
    return timestep.observation

  def step(self, action: dict[str, int]) -> tuple[Any, ...]:
    """Take a step in the base environment.

    Returns:
      An (observation, reward, done, info) tuple where `done` is True on the
      LAST step and `info` holds the timestep's discount.
    """
    timestep = self._env.step(action)
    observation = timestep.observation
    self._latest_observation = observation
    done = timestep.step_type == dm_env.StepType.LAST
    info = {'discount': timestep.discount}
    return observation, timestep.reward, done, info
================================================
FILE: android_env/wrappers/gym_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.wrappers.gym_wrapper."""
from unittest import mock
from absl.testing import absltest
from android_env import env_interface
from android_env.wrappers import gym_wrapper
import dm_env
from dm_env import specs
from gym import spaces
import numpy as np
class GymInterfaceWrapperTest(absltest.TestCase):

  def setUp(self):
    super().setUp()
    self._base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    touch_spec = specs.BoundedArray(
        shape=(2,),
        dtype=np.float32,
        minimum=[0.0, 0.0],
        maximum=[1.0, 1.0],
        name='touch_position')
    self._base_env.action_spec.return_value = {
        'action_type': specs.DiscreteArray(num_values=3, name='action_type'),
        'touch_position': touch_spec,
    }
    self._base_env.observation_spec.return_value = {
        'pixels': specs.Array(
            shape=(480, 320, 3), dtype=np.uint8, name='pixels'),
        'timedelta': specs.Array(shape=(), dtype=np.int64, name='timedelta'),
        'orientation': specs.Array(
            shape=np.array([4]), dtype=np.uint8, name='orientation'),
    }
    self._wrapped_env = gym_wrapper.GymInterfaceWrapper(self._base_env)
    self._fake_ts = dm_env.TimeStep(
        step_type=dm_env.StepType.MID,
        observation={'pixels': np.ones(shape=(2, 3))},
        reward=10.0,
        discount=1.0)

  def _assert_fake_pixels(self, array):
    """Asserts `array` equals the fake 2x3 all-ones pixel observation."""
    self.assertTrue(np.array_equal(array, np.ones(shape=(2, 3))))

  def test_render(self):
    self._base_env.step.return_value = self._fake_ts
    self._wrapped_env.step(action=np.zeros(shape=(1,)))
    self._assert_fake_pixels(self._wrapped_env.render(mode='rgb_array'))

  def test_render_error(self):
    with self.assertRaises(ValueError):
      self._wrapped_env.render(mode='human')

  def test_reset(self):
    self._base_env.reset.return_value = self._fake_ts._replace(
        step_type=dm_env.StepType.FIRST)
    obs = self._wrapped_env.reset()
    self._base_env.reset.assert_called_once()
    self._assert_fake_pixels(obs['pixels'])

  def test_step(self):
    self._base_env.step.return_value = self._fake_ts
    obs, _, _, _ = self._wrapped_env.step(action=np.zeros(shape=(1,)))
    self._base_env.step.assert_called_once()
    self._assert_fake_pixels(obs['pixels'])

  def test_spec_to_space(self):
    float_spec = specs.Array(shape=(2, 3), dtype=np.float32)
    bounded_spec = specs.BoundedArray(
        shape=(), dtype=np.float32, minimum=4, maximum=5)
    discrete_spec = specs.DiscreteArray(num_values=4)
    cases = [
        (float_spec,
         spaces.Box(low=-np.inf, high=np.inf, shape=float_spec.shape,
                    dtype=float_spec.dtype)),
        (bounded_spec,
         spaces.Box(low=4, high=5, shape=bounded_spec.shape,
                    dtype=bounded_spec.dtype)),
        (discrete_spec,
         spaces.Box(low=0, high=3, shape=(), dtype=np.int32)),
    ]
    for spec, expected_space in cases:
      space = self._wrapped_env._spec_to_space(spec)
      self.assertEqual(space, expected_space)
if __name__ == '__main__':
  # Run the test cases defined in this module via absltest.
  absltest.main()
================================================
FILE: android_env/wrappers/image_rescale_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wraps the AndroidEnv environment to rescale the observations."""
from collections.abc import Sequence
from typing import cast
from android_env import env_interface
from android_env.wrappers import base_wrapper
import dm_env
from dm_env import specs
import numpy as np
from PIL import Image
# Per-channel weights used to collapse an RGB pixel into a single grayscale
# value: `ImageRescaleWrapper` takes a dot product of the pixels with this
# array over the channel axis. The three weights sum to 1.0.
#
# These weights follow the ITU-R 709 specification, which is good for computer
# displays and HDTV.
# NOTE(review): this comment originally cited
# https://pillow.readthedocs.io/en/3.2.x/reference/Image.html#PIL.Image.Image.convert
# as the source. That page appears to describe the ITU-R 601-2 luma transform
# used by `convert('L')` rather than these 709 weights. Confirm which
# reference was intended.
RGB_TO_GRAYSCALE_COEFFICIENTS = [0.2126, 0.7152, 0.0722]
class ImageRescaleWrapper(base_wrapper.BaseWrapper):
  """AndroidEnv with rescaled observations.

  Only the `pixels` entry of each observation is transformed; every other entry
  is passed through untouched. The first two (height, width) dimensions of
  `pixels` are scaled by `zoom_factors` and truncated to integers. The channel
  dimension keeps its size, or becomes 1 when `grayscale=True`.
  """

  def __init__(
      self,
      env: env_interface.AndroidEnvInterface,
      zoom_factors: Sequence[float] | None = (0.5, 0.5),
      grayscale: bool = False,
  ) -> None:
    """Initializes the wrapper.

    Args:
      env: The environment to wrap. Its observation spec must have a `pixels`
        entry whose last dimension (channels) is 1 or 3.
      zoom_factors: Scale factors for the first two dimensions of `pixels`.
        `None` is equivalent to `(1.0, 1.0)`, i.e. no rescaling.
      grayscale: If True, 3-channel pixels are collapsed into a single channel
        using `RGB_TO_GRAYSCALE_COEFFICIENTS`. Single-channel pixels are
        already grayscale and are only resized.
    """
    super().__init__(env)
    assert 'pixels' in self._env.observation_spec()
    assert self._env.observation_spec()['pixels'].shape[-1] in [
        1,
        3,
    ], 'Number of pixel channels should be 1 or 3.'
    self._grayscale = grayscale
    if zoom_factors is None:
      zoom_factors = (1.0, 1.0)
    # Only the two spatial dimensions are zoomed. The channel dimension is
    # "zoomed" by 1.0 so that it keeps its size.
    self._zoom_factors = tuple(zoom_factors) + (1.0,)

  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    """Returns `timestep` with its `pixels` observation rescaled."""
    observation = timestep.observation
    # Shallow copy so the wrapped env's observation dict is left untouched.
    processed_observation = observation.copy()
    processed_observation['pixels'] = self._process_pixels(
        observation['pixels']
    )
    return timestep._replace(observation=processed_observation)

  def _process_pixels(self, raw_observation: np.ndarray) -> np.ndarray:
    """Rescales (and optionally grayscales) one pixel observation.

    Args:
      raw_observation: Pixels of shape (height, width, channels), where
        channels is 1 or 3 (as enforced on the spec in `__init__`).

    Returns:
      The resized pixels, with shape (new_height, new_width, channels), or
      (new_height, new_width, 1) in grayscale mode.
    """
    # PIL's `resize()` takes (width, height), hence the reversal.
    new_shape = np.array(
        self._zoom_factors[0:2] * np.array(raw_observation.shape[0:2]),
        dtype=np.int32,
    )[::-1]
    if raw_observation.shape[-1] == 1:
      # Single-channel input is already grayscale. Drop the trailing channel
      # axis: the dot product with the 3 RGB coefficients would fail on it, and
      # PIL cannot build an image from an (H, W, 1) array.
      # `_resize_image_array` adds the channel axis back.
      image = raw_observation[..., 0]
    elif self._grayscale:
      # Squash the RGB channels into a single layer.
      image = np.dot(raw_observation, RGB_TO_GRAYSCALE_COEFFICIENTS)
    else:
      image = raw_observation
    return self._resize_image_array(image, new_shape)

  def _resize_image_array(
      self, grayscale_or_rbg_array: np.ndarray, new_shape: np.ndarray
  ) -> np.ndarray:
    """Resizes a 2-D (grayscale) or 3-channel (RGB) array with PIL.

    Args:
      grayscale_or_rbg_array: Array of shape (H, W) or (H, W, 3). It is cast to
        uint8 before resizing.
      new_shape: 1-D array holding two ints, in PIL's (width, height) order.

    Returns:
      A uint8 array with a channel axis; 2-D results get a trailing channel
      axis of size 1.
    """
    assert new_shape.ndim == 1
    assert len(new_shape) == 2
    resized_array = np.array(
        Image.fromarray(grayscale_or_rbg_array.astype('uint8')).resize(
            tuple(new_shape)
        )
    )
    if resized_array.ndim == 2:
      return np.expand_dims(resized_array, axis=-1)
    return resized_array

  def reset(self) -> dm_env.TimeStep:
    """Resets the wrapped env and rescales the first observation."""
    timestep = self._env.reset()
    return self._process_timestep(timestep)

  def step(self, action) -> dm_env.TimeStep:
    """Steps the wrapped env with `action` and rescales the observation."""
    timestep = self._env.step(action)
    return self._process_timestep(timestep)

  def observation_spec(self) -> dict[str, specs.Array]:
    """Returns the wrapped spec with `pixels` resized to match the output."""
    parent_spec = self._env.observation_spec().copy()
    parent_pixels = cast(specs.BoundedArray, parent_spec['pixels'])
    # Same truncation as in `_process_pixels`, so spec and data agree.
    out_shape = np.multiply(parent_pixels.shape, self._zoom_factors).astype(
        np.int32
    )
    if self._grayscale:
      # In grayscale mode the output shape is (height, width, 1).
      out_shape[-1] = 1
    # NOTE(review): `_resize_image_array` always emits uint8, while this spec
    # keeps the parent dtype. Confirm the parent pixels are uint8.
    parent_spec['pixels'] = specs.BoundedArray(
        shape=out_shape,
        dtype=parent_pixels.dtype,
        name=parent_pixels.name,
        minimum=parent_pixels.minimum,
        maximum=parent_pixels.maximum,
    )
    return parent_spec
================================================
FILE: android_env/wrappers/image_rescale_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.wrappers.image_rescale_wrapper."""
from typing import Any
from unittest import mock
from absl.testing import absltest
from android_env import env_interface
from android_env.wrappers import image_rescale_wrapper
import dm_env
from dm_env import specs
import numpy as np
def _simple_spec():
  """Returns a (300, 300, 3) uint8 `pixels` spec bounded to [0, 255]."""
  num_rows, num_cols, num_channels = 300, 300, 3
  return specs.BoundedArray(
      name='pixels',
      shape=np.array([num_rows, num_cols, num_channels]),
      dtype=np.uint8,
      minimum=0,
      maximum=255,
  )
def _simple_timestep():
  """Returns a MID timestep whose `pixels` observation is all ones."""
  pixels = np.ones(shape=[300, 300, 3])
  return dm_env.TimeStep(
      observation={'pixels': pixels},
      discount=0.9,
      reward=3.14,
      step_type=dm_env.StepType.MID,
  )
class ImageRescaleWrapperTest(absltest.TestCase):
  """Checks the pixel shapes produced by `ImageRescaleWrapper`."""

  def _assert_rescaled_shapes(self, expected_shape, **wrapper_kwargs):
    """Wraps a fake env and checks spec, reset() and step() pixel shapes."""
    fake_timestep = _simple_timestep()
    fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
    fake_env.reset.return_value = fake_timestep
    fake_env.step.return_value = fake_timestep
    wrapper = image_rescale_wrapper.ImageRescaleWrapper(
        fake_env, **wrapper_kwargs)
    self.assertIsNotNone(wrapper)
    self.assertEqual(
        wrapper.observation_spec()['pixels'].shape, expected_shape)
    reset_pixels = wrapper.reset().observation['pixels']
    self.assertEqual(reset_pixels.shape, expected_shape)
    step_pixels = wrapper.step(action='fake_action').observation['pixels']
    self.assertEqual(step_pixels.shape, expected_shape)

  def test_100x50_grayscale(self):
    self._assert_rescaled_shapes(
        (100, 50, 1), zoom_factors=(1.0 / 3, 1.0 / 6.0), grayscale=True)

  def test_150x60_full_channels(self):
    self._assert_rescaled_shapes(
        (150, 60, 3), zoom_factors=(1.0 / 2.0, 1.0 / 5.0))

  def test_list_zoom_factor(self):
    self._assert_rescaled_shapes((150, 60, 3), zoom_factors=[0.5, 0.2])
if __name__ == '__main__':
  # Run this module's tests through absl's test runner.
  absltest.main()
================================================
FILE: android_env/wrappers/last_action_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Extends Android observation with the latest action taken."""
from typing import cast
from android_env import env_interface
from android_env.components import action_type
from android_env.components import pixel_fns
from android_env.wrappers import base_wrapper
import dm_env
from dm_env import specs
import numpy as np
class LastActionWrapper(base_wrapper.BaseWrapper):
  """Extends Android observations with information about the last action taken.

  The position of the last action is denoted by a single white pixel (with a
  value of 255) in a channel of all black pixels (with a value of 0).

  As this wrapper makes use of temporarily stored information about the
  last action taken, it is important to apply on the environment side rather
  than the agent side. Recommended not to apply before an ImageRescaleWrapper,
  to avoid distortion of the single pixel denoting the action position.
  """

  def __init__(self,
               env: env_interface.AndroidEnvInterface,
               concat_to_pixels: bool = True) -> None:
    """Initializes the internal state of this wrapper.

    Args:
      env: the environment to wrap.
      concat_to_pixels: If True, will add a channel to the pixel observation.
        If False, will pass the action as an extra observation.
    """
    super().__init__(env)
    self._concat_to_pixels = concat_to_pixels
    # (height, width) of the pixel observation; the action layer matches it.
    self._screen_dimensions = self._env.observation_spec()['pixels'].shape[:2]

  def _process_timestep(self, timestep: dm_env.TimeStep) -> dm_env.TimeStep:
    """Returns `timestep` with its observation extended by the action layer."""
    # `_process_observation` works on its own shallow copy, so the wrapped
    # env's observation dict is not mutated and no extra copy is needed here.
    processed_observation = self._process_observation(timestep.observation)
    return timestep._replace(observation=processed_observation)

  def _process_observation(
      self, observation: dict[str, np.ndarray]
  ) -> dict[str, np.ndarray]:
    """Extends observation with last_action data.

    Args:
      observation: The wrapped env's observation; it is not modified.

    Returns:
      A shallow copy of `observation`. Either its `pixels` gain one extra
      trailing channel, or a new `last_action` entry is added, depending on
      `concat_to_pixels`.
    """
    processed_observation = observation.copy()
    last_action_layer = self._get_last_action_layer(observation['pixels'])
    if self._concat_to_pixels:
      # `np.dstack` allocates a new array, so the source pixels are untouched.
      processed_observation['pixels'] = np.dstack(
          (observation['pixels'], last_action_layer))
    else:
      processed_observation['last_action'] = last_action_layer
    return processed_observation

  def _get_last_action_layer(self, pixels: np.ndarray) -> np.ndarray:
    """Builds a single-channel layer marking the position of the last touch.

    Args:
      pixels: The current pixel observation; only its dtype is used.

    Returns:
      An array of shape `self._screen_dimensions` and dtype `pixels.dtype`.
      It is all zeros, except for a value of 255 at the touched pixel when the
      wrapped env's last raw action was a TOUCH.
    """
    last_action = self._env.raw_action
    last_action_layer = np.zeros(self._screen_dimensions, dtype=pixels.dtype)
    if ('action_type' in last_action and
        last_action['action_type'] == action_type.ActionType.TOUCH):
      touch_position = last_action['touch_position']
      # The helper expects (width, height), i.e. the reversed screen dims.
      x, y = pixel_fns.touch_position_to_pixel_position(
          touch_position, width_height=self._screen_dimensions[::-1]
      )
      last_action_layer[y, x] = 255
    return last_action_layer

  def reset(self) -> dm_env.TimeStep:
    """Resets the wrapped env and extends the first observation."""
    timestep = self._env.reset()
    return self._process_timestep(timestep)

  def step(self, action) -> dm_env.TimeStep:
    """Steps the wrapped env and extends the resulting observation."""
    timestep = self._env.step(action)
    return self._process_timestep(timestep)

  def observation_spec(self) -> dict[str, specs.Array]:
    """Returns the wrapped spec extended to describe the last-action layer."""
    parent_spec = self._env.observation_spec().copy()
    parent_pixels = cast(specs.BoundedArray, parent_spec['pixels'])
    shape = parent_pixels.shape
    if self._concat_to_pixels:
      # One extra trailing channel carries the action layer.
      parent_spec['pixels'] = specs.BoundedArray(
          shape=(shape[0], shape[1], shape[2] + 1),
          dtype=parent_pixels.dtype,
          name=parent_pixels.name,
          minimum=parent_pixels.minimum,
          maximum=parent_pixels.maximum)
    else:
      parent_spec.update({
          'last_action':
              specs.BoundedArray(
                  shape=(shape[0], shape[1]),
                  dtype=parent_pixels.dtype,
                  name='last_action',
                  minimum=parent_pixels.minimum,
                  maximum=parent_pixels.maximum)
      })
    return parent_spec
================================================
FILE: android_env/wrappers/last_action_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for android_env.wrappers.last_action_wrapper."""
from typing import Any
from unittest import mock
from absl.testing import absltest
from android_env import env_interface
from android_env.components import action_type
from android_env.wrappers import last_action_wrapper
import dm_env
from dm_env import specs
import numpy as np
def _simple_spec():
  """Returns a (120, 80, 3) uint8 `pixels` spec bounded to [0, 255]."""
  return specs.BoundedArray(
      name='pixels',
      dtype=np.uint8,
      shape=np.array([120, 80, 3]),
      maximum=255,
      minimum=0,
  )
def _simple_timestep():
  """Returns a MID timestep with an all-ones (120, 80, 3) `pixels` array."""
  observation = {'pixels': np.ones(shape=[120, 80, 3])}
  return dm_env.TimeStep(
      step_type=dm_env.StepType.MID,
      reward=3.14,
      discount=0.9,
      observation=observation,
  )
class LastActionWrapperTest(absltest.TestCase):
  """Tests for `last_action_wrapper.LastActionWrapper`."""

  def _make_fake_env(self):
    """Returns an autospec'd env whose reset()/step() yield a fixed timestep."""
    fake_timestep = _simple_timestep()
    fake_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    fake_env.observation_spec.return_value = {'pixels': _simple_spec()}
    fake_env.reset.return_value = fake_timestep
    fake_env.step.return_value = fake_timestep
    return fake_env

  def _step_with(self, fake_env, wrapper, action):
    """Exposes `action` as `fake_env.raw_action`, then steps `wrapper`."""
    type(fake_env).raw_action = mock.PropertyMock(return_value=action)
    return wrapper.step(action=action)

  def _assert_marked_at(self, layer, expected_yx):
    """Asserts `layer` has exactly one 255 pixel, at (row, col) `expected_yx`."""
    self.assertEqual(np.sum(layer), 255)
    y, x = np.where(layer == 255)
    self.assertEqual((y.item(), x.item()), expected_yx)

  def test_concat_to_pixels(self):
    fake_env = self._make_fake_env()
    wrapper = last_action_wrapper.LastActionWrapper(
        fake_env, concat_to_pixels=True)
    self.assertIsNotNone(wrapper)
    self.assertEqual(wrapper.observation_spec()['pixels'].shape, (120, 80, 4))

    # Right after reset() the extra channel is empty.
    reset_image = wrapper.reset().observation['pixels']
    self.assertEqual(reset_image.shape, (120, 80, 4))
    self.assertEqual(np.sum(reset_image[:, :, -1]), 0)

    # touch_position is (W x H) while the image is (H x W).
    touch = {
        'action_type': action_type.ActionType.TOUCH,
        'touch_position': np.array([0.25, 0.75], dtype=np.float32),
    }
    step_image = self._step_with(fake_env, wrapper, touch).observation['pixels']
    self.assertEqual(step_image.shape, (120, 80, 4))
    self._assert_marked_at(step_image[:, :, -1], (90, 20))

    # A LIFT leaves the extra channel empty.
    lift = {
        'action_type': action_type.ActionType.LIFT,
        'touch_position': np.array([0.25, 0.75], dtype=np.float32),
    }
    step_image = self._step_with(fake_env, wrapper, lift).observation['pixels']
    self.assertEqual(step_image.shape, (120, 80, 4))
    self.assertEqual(np.sum(step_image[:, :, -1]), 0)

    # A touch on the bottom edge is marked on the last row.
    bottom_touch = {
        'action_type': action_type.ActionType.TOUCH,
        'touch_position': np.array([0.25, 1.0], dtype=np.float32),
    }
    step_image = self._step_with(
        fake_env, wrapper, bottom_touch).observation['pixels']
    self.assertEqual(step_image.shape, (120, 80, 4))
    self._assert_marked_at(step_image[:, :, -1], (119, 20))

  def test_no_concat_to_pixels(self):
    fake_env = self._make_fake_env()
    wrapper = last_action_wrapper.LastActionWrapper(
        fake_env, concat_to_pixels=False)
    self.assertIsNotNone(wrapper)
    self.assertEqual(wrapper.observation_spec()['pixels'].shape, (120, 80, 3))
    self.assertEqual(wrapper.observation_spec()['last_action'].shape, (120, 80))

    # Right after reset() the separate last_action layer is empty.
    reset_observation = wrapper.reset().observation
    self.assertEqual(reset_observation['pixels'].shape, (120, 80, 3))
    self.assertEqual(np.sum(reset_observation['last_action']), 0)

    touch = {
        'action_type': action_type.ActionType.TOUCH,
        'touch_position': np.array([0.25, 0.75], dtype=np.float32),
    }
    observation = self._step_with(fake_env, wrapper, touch).observation
    self.assertEqual(observation['pixels'].shape, (120, 80, 3))
    self._assert_marked_at(observation['last_action'], (90, 20))

    lift = {
        'action_type': action_type.ActionType.LIFT,
        'touch_position': np.array([0.25, 0.75], dtype=np.float32),
    }
    observation = self._step_with(fake_env, wrapper, lift).observation
    self.assertEqual(observation['pixels'].shape, (120, 80, 3))
    self.assertEqual(np.sum(observation['last_action']), 0)

    # A touch on the right edge is marked on the last column.
    right_touch = {
        'action_type': action_type.ActionType.TOUCH,
        'touch_position': np.array([1.0, 0.75], dtype=np.float32),
    }
    observation = self._step_with(fake_env, wrapper, right_touch).observation
    self.assertEqual(observation['pixels'].shape, (120, 80, 3))
    self._assert_marked_at(observation['last_action'], (90, 79))
if __name__ == '__main__':
  # Run this module's tests through absl's test runner.
  absltest.main()
================================================
FILE: android_env/wrappers/rate_limit_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Limits interactions with the environment to a given rate."""
import enum
import time
from android_env import env_interface
from android_env.components import action_type
from android_env.wrappers import base_wrapper
import dm_env
import numpy as np
class RateLimitWrapper(base_wrapper.BaseWrapper):
  """Limits interactions with the environment to a given rate."""

  class SleepType(enum.IntEnum):
    """Determines how the wrapper interacts with the underlying environment."""

    # The wrapper sleeps before calling `step()` on the underlying environment.
    BEFORE = 0

    # The wrapper sleeps after calling `step()` on the underlying environment.
    AFTER = 1

    # The wrapper first calls `step()`, obtaining a TimeStep whose observation
    # is discarded (its reward is still accumulated). It then sleeps, and then
    # it calls `step(REPEAT)` to obtain a TimeStep that's as fresh as possible.
    #
    # Note that for both BEFORE and AFTER_WITH_REPEAT, the _total_ amount of
    # time inside this wrapper may go beyond the rate specified in `rate`
    # because the sleep does not account for the time taken by step().
    AFTER_WITH_REPEAT = 2

  def __init__(self,
               env: env_interface.AndroidEnvInterface,
               rate: float,
               sleep_type: SleepType = SleepType.AFTER_WITH_REPEAT):
    """Initializes this wrapper.

    Args:
      env: The underlying environment to which this wrapper is applied.
      rate: The desired rate in Hz to interact with the environment. If <=0.0,
        this wrapper will be disabled.
      sleep_type: This determines how the wrapper will interact with the
        underlying AndroidEnv environment.
    """
    super().__init__(env)
    self._assert_base_env()
    # Wall-clock time of the last reset()/step(); None until the first reset.
    self._last_step_time = None
    self._max_wait = 1.0 / rate if rate > 0.0 else 0.0
    self._sleep_type = sleep_type

  def _assert_base_env(self):
    """Checks that the wrapped env has the right action spec format."""
    parent_action_spec = self._env.action_spec()
    assert len(parent_action_spec) == 2
    assert not parent_action_spec['action_type'].shape
    assert parent_action_spec['touch_position'].shape == (2,)

  def reset(self) -> dm_env.TimeStep:
    """Resets the wrapped env and restarts the rate-limiting clock."""
    timestep = self._env.reset()
    self._last_step_time = time.time()
    return timestep

  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Takes a step while maintaining a steady interaction rate.

    Args:
      action: The agent's action. It is not modified by this method.

    Returns:
      The wrapped env's timestep. In AFTER_WITH_REPEAT mode (unless the first
      step ended the episode) this is the timestep of the follow-up REPEAT
      step, with the rewards of both steps summed.
    """
    # If max_wait is non-positive, the wrapper has no effect.
    if self._max_wait <= 0.0:
      return self._env.step(action)

    if self._sleep_type == RateLimitWrapper.SleepType.BEFORE:
      self._wait()

    timestep = self._env.step(action)
    if timestep.last():
      return timestep

    if self._sleep_type == RateLimitWrapper.SleepType.AFTER_WITH_REPEAT:
      # Build the REPEAT action on a copy so the caller's dict is not
      # overwritten in place.
      repeat_action = dict(action)
      for k in action:
        if k.startswith('action_type'):
          repeat_action[k] = np.array(
              action_type.ActionType.REPEAT, dtype=np.uint8)
      self._wait()
      first_reward = timestep.reward or 0.0
      timestep = self._env.step(repeat_action)
      second_reward = timestep.reward or 0.0
      # Accumulate rewards over the two steps taken.
      timestep = timestep._replace(reward=first_reward + second_reward)
    elif self._sleep_type == RateLimitWrapper.SleepType.AFTER:
      self._wait()

    self._last_step_time = time.time()
    return timestep

  def _wait(self) -> None:
    """Sleeps for whatever remains of `_max_wait` since the last step."""
    if self._max_wait > 0.0 and self._last_step_time is not None:
      time_since_step = time.time() - self._last_step_time
      sec_to_wait = self._max_wait - time_since_step
      if sec_to_wait > 0.0:
        time.sleep(sec_to_wait)
================================================
FILE: android_env/wrappers/rate_limit_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for rate_limit_wrapper."""
import time
from typing import Any, Protocol
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from android_env import env_interface
from android_env.components import action_type
from android_env.wrappers import rate_limit_wrapper
import dm_env
from dm_env import specs
import numpy as np
def _get_base_env():
  """Returns an autospec'd env exposing an action_type/touch_position spec."""
  action_type_spec = specs.DiscreteArray(
      num_values=len(action_type.ActionType), name='action_type')
  touch_position_spec = specs.BoundedArray(
      shape=(2,),
      dtype=np.float32,
      minimum=[0.0, 0.0],
      maximum=[1.0, 1.0],
      name='touch_position')
  env = mock.create_autospec(env_interface.AndroidEnvInterface)
  env.action_spec.return_value = {
      'action_type': action_type_spec,
      'touch_position': touch_position_spec,
  }
  return env
class _FnWithTimestamps(Protocol):
  """A function with `timestamp` and `timestamps` attributes.

  Serves as the return type of `_with_timestamp`, so the tests can attach
  these attributes to plain callback functions.
  """
  # Time of the most recent call; written by the `_sleep_fn` callbacks.
  timestamp: float
  # Times of all calls so far; appended to by the `_step_fn` callbacks.
  timestamps: list[float]
def _with_timestamp(fn: Any) -> _FnWithTimestamps:
  """Identity decorator that re-types `fn` as a `_FnWithTimestamps`.

  It has no runtime effect; it only lets type checkers accept the
  `timestamp`/`timestamps` attributes that tests assign on callbacks.
  """
  typed_fn: _FnWithTimestamps = fn
  return typed_fn
class RateLimitWrapperTest(parameterized.TestCase):
  """Tests for `rate_limit_wrapper.RateLimitWrapper`."""

  @parameterized.named_parameters(
      ('zero_rate', 0),
      ('negative_rate', -50),
  )
  @mock.patch.object(time, 'sleep', autospec=True)
  def test_disabled(self, rate, mock_sleep):
    """With a non-positive rate, this wrapper should do nothing."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(env, rate=rate)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()
    _ = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })
    mock_sleep.assert_not_called()
    # When the wrapper is disabled, base step should only be called once.
    env.step.assert_called_once()

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled(self, mock_sleep):
    """When enabled, the wrapper should sleep for a period in [0, 1/rate]."""
    env = _get_base_env()
    env.step.return_value = dm_env.transition(reward=None, observation=None)
    wrapper = rate_limit_wrapper.RateLimitWrapper(env, rate=1/33.33)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    # Step for 100 steps.
    for _ in range(100):
      _ = wrapper.step({
          'action_type':
              np.array(action_type.ActionType.LIFT, dtype=np.uint8),
          'touch_position':
              np.array([0.123, 0.456])
      })

    # Check that there are 100 calls and that they're all within [0, 1/rate].
    self.assertLen(mock_sleep.call_args_list, 100)
    for call in mock_sleep.call_args_list:
      args, unused_kwargs = call
      sleep_time = args[0]
      self.assertBetween(sleep_time, 0.0, 33.33)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled_sleep_type_before(self, mock_sleep):
    """When sleep_type==BEFORE, sleep should come before step()."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(
        env,
        rate=1/33.33,
        sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType.BEFORE)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    @_with_timestamp
    def _sleep_fn(sleep_time):
      _sleep_fn.timestamp = time.time()
      self.assertBetween(sleep_time, 0.0, 33.33)

    mock_sleep.side_effect = _sleep_fn

    # Typed via `_with_timestamp` like the other callbacks, since it carries a
    # `timestamps` attribute.
    @_with_timestamp
    def _step_fn(action):
      self.assertEqual(
          action['action_type'],
          np.array(action_type.ActionType.LIFT, dtype=np.uint8))
      _step_fn.timestamps.append(time.time())
      return dm_env.transition(reward=None, observation=None)

    _step_fn.timestamps = []
    env.step.side_effect = _step_fn
    _ = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })

    self.assertLen(_step_fn.timestamps, 1)
    # We expect sleep to have been executed BEFORE a single `step()`.
    self.assertGreaterEqual(_step_fn.timestamps[0], _sleep_fn.timestamp)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled_sleep_type_after(self, mock_sleep):
    """When sleep_type==AFTER, sleep should come after step()."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(
        env,
        rate=1/33.33,
        sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType.AFTER)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    @_with_timestamp
    def _sleep_fn(sleep_time):
      _sleep_fn.timestamp = time.time()
      self.assertBetween(sleep_time, 0.0, 33.33)

    mock_sleep.side_effect = _sleep_fn

    # Typed via `_with_timestamp` like the other callbacks, since it carries a
    # `timestamps` attribute.
    @_with_timestamp
    def _step_fn(action):
      self.assertEqual(
          action['action_type'],
          np.array(action_type.ActionType.LIFT, dtype=np.uint8))
      _step_fn.timestamps.append(time.time())
      return dm_env.transition(reward=None, observation=None)

    _step_fn.timestamps = []
    env.step.side_effect = _step_fn
    _ = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })

    # We expect sleep to have been executed AFTER a single `step()`.
    self.assertLen(_step_fn.timestamps, 1)
    self.assertLessEqual(_step_fn.timestamps[0], _sleep_fn.timestamp)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled_sleep_type_after_with_repeat(self, mock_sleep):
    """When sleep_type==AFTER_WITH_REPEAT, sleep should be between 2 steps()."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(
        env,
        rate=1/33.33,
        sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType
        .AFTER_WITH_REPEAT)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    @_with_timestamp
    def _sleep_fn(sleep_time):
      _sleep_fn.timestamp = time.time()
      self.assertBetween(sleep_time, 0.0, 33.33)

    mock_sleep.side_effect = _sleep_fn

    @_with_timestamp
    def _step_fn(action):
      # On even calls the action should be the actual agent action, but on odd
      # calls they should be REPEATs.
      if len(_step_fn.timestamps) % 2 == 0:
        self.assertEqual(
            action['action_type'],
            np.array(action_type.ActionType.LIFT, dtype=np.uint8))
      else:
        self.assertEqual(
            action['action_type'],
            np.array(action_type.ActionType.REPEAT, dtype=np.uint8))
      _step_fn.timestamps.append(time.time())
      return dm_env.transition(reward=1.0, observation=None)

    _step_fn.timestamps = []
    env.step.side_effect = _step_fn
    timestep = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })

    # When the wrapper is enabled, base step should be called twice.
    self.assertEqual(env.step.call_count, 2)
    # `step()` should be called twice: before `sleep()` and after it.
    self.assertLen(_step_fn.timestamps, 2)
    self.assertGreaterEqual(_sleep_fn.timestamp, _step_fn.timestamps[0])
    self.assertLessEqual(_sleep_fn.timestamp, _step_fn.timestamps[1])
    # Rewards should accumulate over the two step() calls
    self.assertEqual(timestep.reward, 2.0)

  @mock.patch.object(time, 'sleep', autospec=True)
  def test_enabled_sleep_type_after_with_repeat_last(self, mock_sleep):
    """If the first step is a LAST, second step should not be taken."""
    env = _get_base_env()
    wrapper = rate_limit_wrapper.RateLimitWrapper(
        env,
        rate=1/33.33,
        sleep_type=rate_limit_wrapper.RateLimitWrapper.SleepType
        .AFTER_WITH_REPEAT)
    _ = wrapper.reset()
    mock_sleep.assert_not_called()  # It should never sleep during reset().

    env.step.return_value = dm_env.termination(reward=None, observation=None)

    _ = wrapper.step({
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.uint8),
        'touch_position': np.array([0.123, 0.456])
    })

    # Second step call should be skipped.
    env.step.assert_called_once()
    mock_sleep.assert_not_called()
if __name__ == '__main__':
  # Run this module's tests through absl's test runner.
  absltest.main()
================================================
FILE: android_env/wrappers/tap_action_wrapper.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wraps the AndroidEnv environment to provide tap actions of a given duration."""
from collections.abc import Sequence
from typing import Any
from android_env import env_interface
from android_env.components import action_type
from android_env.wrappers import base_wrapper
import dm_env
import numpy as np
class TapActionWrapper(base_wrapper.BaseWrapper):
  """AndroidEnv with tap actions.

  Each agent step is expanded into `num_frames` steps of the wrapped env, then
  one final step:
    * a LIFT if the agent's action was a TOUCH, or
    * one more repetition of the agent's action otherwise.

  With `touch_only=True`, the agent's action type must be 0, and every step
  becomes `num_frames` TOUCHes followed by a LIFT.
  """

  def __init__(self,
               env: env_interface.AndroidEnvInterface,
               num_frames: int = 5,
               touch_only: bool = False) -> None:
    """Initializes the wrapper.

    Args:
      env: The environment to wrap. Its action spec must have `action_type`.
      num_frames: Number of wrapped-env steps for which the agent's action is
        held before the final step.
      touch_only: If True, exposes a single-valued `action_type` and turns
        every step into a tap.
    """
    super().__init__(env)
    assert 'action_type' in env.action_spec()
    self._touch_only = touch_only
    self._num_frames = num_frames
    self._env_steps = 0

  def stats(self) -> dict[str, Any]:
    """Returns a dictionary of metrics logged by the environment."""
    logs = self._env.stats()
    logs.update({'env_steps': self._env_steps})
    return logs

  def _process_action(
      self, action: dict[str, np.ndarray]
  ) -> Sequence[dict[str, np.ndarray]]:
    """Expands one agent action into the wrapped-env actions to execute."""
    # The synthesized actions are sent to the wrapped env, so cast them to
    # *its* action_type dtype. In touch_only mode `self.action_spec()` is this
    # wrapper's own spec, not the wrapped env's.
    action_type_dtype = self._env.action_spec()['action_type'].dtype

    def _with_action_type(value) -> dict[str, np.ndarray]:
      # Copy of `action` whose action_type is replaced by `value`.
      new_action = action.copy()
      new_action['action_type'] = np.array(value).astype(action_type_dtype)
      return new_action

    if self._touch_only:
      assert action['action_type'] == 0
      actions = [_with_action_type(action_type.ActionType.TOUCH)
                ] * self._num_frames
      actions.append(_with_action_type(action_type.ActionType.LIFT))
    elif action['action_type'] == action_type.ActionType.TOUCH:
      actions = [action] * self._num_frames
      actions.append(_with_action_type(action_type.ActionType.LIFT))
    else:
      actions = [action] * (self._num_frames + 1)
    return actions

  def step(self, action: dict[str, np.ndarray]) -> dm_env.TimeStep:
    """Takes a step in the environment.

    Runs the expanded actions through the wrapped env, summing truthy rewards,
    and stops early once a LAST step is returned.

    Returns:
      A TimeStep whose step type, discount and observation come from the last
      wrapped-env step executed, and whose reward is the accumulated total.
    """
    # NOTE: this counts the full expansion even if the episode ends early.
    self._env_steps += self._num_frames + 1
    actions = self._process_action(action)
    total_reward = 0.0
    for sub_action in actions:
      step_type, reward, discount, observation = self._env.step(sub_action)
      if reward:
        total_reward += reward
      if step_type == dm_env.StepType.LAST:
        break
    return dm_env.TimeStep(
        step_type=step_type,
        reward=total_reward,
        discount=discount,
        observation=observation)

  def action_spec(self) -> dict[str, dm_env.specs.Array]:
    """Returns the wrapped spec, or a single-valued action_type if touch_only."""
    if self._touch_only:
      return {
          'action_type':
              dm_env.specs.DiscreteArray(num_values=1, name='action_type'),
          'touch_position':
              self._env.action_spec()['touch_position'],
      }
    else:
      return self._env.action_spec()
================================================
FILE: android_env/wrappers/tap_action_wrapper_test.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tap_action_wrapper."""
from unittest import mock
from absl.testing import absltest
from android_env import env_interface
from android_env.components import action_type
from android_env.wrappers import tap_action_wrapper
import dm_env
from dm_env import specs
import numpy as np
def _make_array_spec(shape, dtype, name):
  """Builds a `BoundedArray` spec whose values all lie in [0, 1].

  Args:
    shape: Shape of the spec.
    dtype: Dtype of the spec.
    name: Name of the spec.

  Returns:
    A `specs.BoundedArray` with all-zero minimum and all-one maximum.
  """
  lower_bound = np.zeros(shape)
  upper_bound = np.ones(shape)  # The maximum of a BoundedArray is inclusive.
  return specs.BoundedArray(
      shape=shape,
      dtype=dtype,
      minimum=lower_bound,
      maximum=upper_bound,
      name=name,
  )
class TapActionWrapperTest(absltest.TestCase):
  """Tests for `tap_action_wrapper.TapActionWrapper`."""

  def setUp(self):
    super().setUp()
    self._base_action_spec = {
        'action_type': specs.DiscreteArray(
            num_values=3, name='action_type'),
        'touch_position': _make_array_spec(
            shape=(2,), dtype=np.float32, name='touch_position'),
    }
    self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface)
    self.base_env.action_spec.return_value = self._base_action_spec

  def test_process_action_repeat(self):
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=3)
    action = {
        'action_type': np.array(action_type.ActionType.REPEAT, dtype=np.int32),
        'touch_position': np.array([0.5, 0.5], dtype=np.float32),
    }
    actions = wrapped_env._process_action(action)
    self.assertLen(actions, wrapped_env._num_frames + 1)
    self.assertEqual(action, actions[-1])

  def test_process_action_lift(self):
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=3)
    action = {
        'action_type': np.array(action_type.ActionType.LIFT, dtype=np.int32),
        'touch_position': np.array([0.5, 0.5], dtype=np.float32),
    }
    actions = wrapped_env._process_action(action)
    self.assertLen(actions, wrapped_env._num_frames + 1)
    self.assertEqual(action, actions[-1])

  def test_process_action_touch(self):
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=3)
    action = {
        'action_type': np.array(action_type.ActionType.TOUCH, dtype=np.int32),
        'touch_position': np.array([0.5, 0.5], dtype=np.float32),
    }
    actions = wrapped_env._process_action(action)
    self.assertLen(actions, wrapped_env._num_frames + 1)
    self.assertEqual(
        actions[-1]['action_type'], np.array(action_type.ActionType.LIFT)
    )

  def test_process_action_touch_only(self):
    """In touch-only mode, action type 0 becomes TOUCH frames plus a LIFT."""
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=3, touch_only=True)
    action = {
        'action_type': np.array(0, dtype=np.int32),
        'touch_position': np.array([0.5, 0.5], dtype=np.float32),
    }
    actions = wrapped_env._process_action(action)
    self.assertLen(actions, wrapped_env._num_frames + 1)
    for touch_action in actions[:-1]:
      self.assertEqual(
          touch_action['action_type'], action_type.ActionType.TOUCH)
      np.testing.assert_array_equal(
          touch_action['touch_position'], action['touch_position'])
    self.assertEqual(actions[-1]['action_type'], action_type.ActionType.LIFT)
    np.testing.assert_array_equal(
        actions[-1]['touch_position'], action['touch_position'])

  def test_reset(self):
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=5)
    fake_timestep = 'ts'
    self.base_env.reset.return_value = fake_timestep
    ts = wrapped_env.reset()
    self.base_env.reset.assert_called_once()
    self.assertEqual(fake_timestep, ts)

  def test_step(self):
    # Arrange.
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=5)
    fake_timestep = dm_env.TimeStep(
        step_type='fake_type',
        reward=0.0,
        discount=1.0,
        observation='fake_obs')
    self.base_env.step.return_value = fake_timestep
    self.base_env.stats.return_value = {}
    # Act.
    ts = wrapped_env.step({
        'action_type': np.array(action_type.ActionType.REPEAT, dtype=np.int32),
        'touch_position': np.array([0.5, 0.5], dtype=np.float32),
    })
    stats = wrapped_env.stats()
    # Assert.
    self.assertEqual(wrapped_env._num_frames+1, self.base_env.step.call_count)
    self.assertIsInstance(ts, dm_env.TimeStep)
    self.assertIsInstance(stats, dict)
    self.assertIn('env_steps', stats)
    self.assertEqual(stats['env_steps'], 6)

  def test_step_sums_rewards(self):
    """Rewards of all inner steps are accumulated into one reward."""
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=5)
    self.base_env.step.return_value = dm_env.TimeStep(
        step_type=dm_env.StepType.MID,
        reward=1.0,
        discount=1.0,
        observation='obs')
    ts = wrapped_env.step({
        'action_type': np.array(action_type.ActionType.REPEAT, dtype=np.int32),
        'touch_position': np.array([0.5, 0.5], dtype=np.float32),
    })
    self.assertEqual(ts.reward, 6.0)
    self.assertEqual(ts.step_type, dm_env.StepType.MID)

  def test_step_stops_on_last(self):
    """Stepping stops as soon as the wrapped env ends the episode."""
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=5)
    mid = dm_env.TimeStep(
        step_type=dm_env.StepType.MID,
        reward=1.0,
        discount=1.0,
        observation='mid_obs')
    last = dm_env.TimeStep(
        step_type=dm_env.StepType.LAST,
        reward=2.0,
        discount=0.0,
        observation='last_obs')
    self.base_env.step.side_effect = [mid, mid, last, mid, mid, mid]
    ts = wrapped_env.step({
        'action_type': np.array(action_type.ActionType.REPEAT, dtype=np.int32),
        'touch_position': np.array([0.5, 0.5], dtype=np.float32),
    })
    self.assertEqual(self.base_env.step.call_count, 3)
    self.assertEqual(ts.step_type, dm_env.StepType.LAST)
    self.assertEqual(ts.reward, 4.0)
    self.assertEqual(ts.discount, 0.0)
    self.assertEqual(ts.observation, 'last_obs')

  def test_observation_spec(self):
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=5)
    fake_obs_spec = 'fake_obs_spec'
    self.base_env.observation_spec.return_value = fake_obs_spec
    observation_spec = wrapped_env.observation_spec()
    self.base_env.observation_spec.assert_called_once()
    self.assertEqual(fake_obs_spec, observation_spec)

  def test_action_spec(self):
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=5)
    self.base_env.action_spec.return_value = self._base_action_spec
    action_spec = wrapped_env.action_spec()
    self.base_env.action_spec.assert_called()
    self.assertEqual(self.base_env.action_spec(),
                     action_spec)

  def test_action_spec_touch_only(self):
    """Touch-only mode exposes a single action type and the original position."""
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=5, touch_only=True)
    action_spec = wrapped_env.action_spec()
    self.assertEqual(set(action_spec), {'action_type', 'touch_position'})
    self.assertEqual(action_spec['action_type'].num_values, 1)
    self.assertIs(action_spec['touch_position'],
                  self._base_action_spec['touch_position'])

  def test_stats(self):
    """Checks that returned stats have expected properties."""
    # Arrange.
    self.base_env.stats.return_value = {
        'some_key': 12345,
        'another_key': 5.4321,
    }
    wrapped_env = tap_action_wrapper.TapActionWrapper(
        self.base_env, num_frames=5
    )
    # Act.
    stats = wrapped_env.stats()
    # Assert.
    self.assertIsInstance(stats, dict)
    # Original entries should still be present.
    self.assertIn('some_key', stats)
    self.assertEqual(stats['some_key'], 12345)
    self.assertIn('another_key', stats)
    self.assertEqual(stats['another_key'], 5.4321)
    # TapActionWrapper inserts its own `env_steps`.
    self.assertIn('env_steps', stats)
    self.assertEqual(stats['env_steps'], 0)
if __name__ == '__main__':
  # Entry point when this test module is executed directly.
  absltest.main()
================================================
FILE: docs/emulator_guide.md
================================================
# AndroidEnv - Emulator Setup Guide
In this document we provide a step-by-step guide for creating a virtual Android
device with Android Studio. After creating an AVD
([Android Virtual Device](https://developer.android.com/studio/run/managing-avds))
you will be able to connect it to an AndroidEnv instance and you're ready to go.
To get started, you will need to download
[Android Studio](https://developer.android.com/studio) - an IDE widely used by
Android developers.
## Install an SDK Platform Package
Android Studio comes with the Android Software Development Toolkit (SDK) which,
among others, allows you to install different versions of Android. Click on
**Tools** > **SDK Manager** and select the SDK version that you would like to
use.

We recommend that you set the `Android SDK Location` to be in your home
directory (for example, on Linux the default one is `~/Android/Sdk`, while on
macOS - `~/Library/Android/sdk`). You can always find the SDK location in
Android Studio under **Preferences** > **Appearance & Behavior** > **System
Settings** > **Android SDKs** > _Android SDK Location_.
If you set a custom `Android SDK Location`, make note of it - you will
need it for connecting AndroidEnv to your AVD.

## Create an AVD
Now it is time to create a virtual device (AVD). Go to **Tools** > **AVD
Manager**.

In the pop-up window you will find an option to **Create Virtual Device**.

Configure the virtual device. You can select the model or choose from more
advanced settings (refer to the
[Android docs](https://developer.android.com/studio/run/managing-avds) for
step-by-step instructions).

Name your AVD and take note of this value. It will be necessary for connecting
AndroidEnv to this virtual device.

Once you are done, you will see the new AVD show up in the **AVD Manager**.
Click on **View details** to inspect some of its properties.

Take note of the `AVD Path`. This value will be necessary for connecting
AndroidEnv to this device. We recommend that you set this to be your home
directory (for instance, on Linux or macOS it may be `~/.android/avd`).

## Ready to use
With SDK and AVD both set up, you are now ready to use this emulated device with
AndroidEnv. Don't forget to take note of the following three values: your AVD
name, the AVD path, and the SDK path. For example, on Linux they may be:
```
--avd_name=my_avd
--avd_package_path=~/.android/avd
--android_sdk_root=~/Android/Sdk
```
Next, once you have set up the AVD, follow the
[Task steps](instructions.md#the-task) in the
[Running the environment guide](instructions.md) to finish setting up
AndroidEnv.
However, if you want to just interact with the newly created device, click on
the run button next to your AVD in the **AVD Manager** in Android Studio (this
step is optional).

You will see an emulator window pop up. You can interact with it by clicking on
the screen.

There are many other features in Android Studio that let you customize your
device. For example, you can create custom images with pre-installed
applications or configured settings.
================================================
FILE: docs/environment.md
================================================
# AndroidEnv - Environment features
AndroidEnv is a complex environment that, while offering an almost endless range
of possibilities for RL research and investigation, poses multiple kinds of
challenges simultaneously. In this document we outline AndroidEnv's main
features that render it such a unique learning environment.
## Real-time environment
AndroidEnv is built on top of an emulated Android device, allowing the agent to
communicate with the emulator through touch actions. Android emulators are
created independently from our environment implementation and simulate real
Android devices in the most realistic manner. This simulation runs real-time,
independently of the agent, meaning that the simulation will not wait for agent
input between frames. This aspect of the environment renders the RL setup
similar to a robotics problem, where the challenges of real-time interaction and
consequent noise in observations have to be overcome. Please note there is
currently no straightforward way to slow down the simulation either.
## Action space
Perhaps one of the most interesting features of AndroidEnv is its large and
complex action interface. The raw action space of the environment consists of a
tuple `(x,y) in [0,1]x[0,1]` determining the location of the action on the
Android screen, and a discrete value `ActionType in {LIFT, TOUCH, REPEAT}`
indicating whether the agent wants to touch the screen at this chosen location
or not. This action space is uniform across all tasks/apps.
**Gestures.** The complexity of the interface arises from the fact that
individual raw actions on their own do not necessarily trigger a meaningful
change in the environment. Most Android applications are designed such that they
can be controlled/navigated through common touchscreen gestures such as pressing
buttons, swiping, scrolling, pinching, drag and drop etc. Each of these can be
thought of as particular sequences of raw actions: for example, *touching* the
screen at a particular location, then immediately *lifting* the imaginary finger
might be interpreted as a *press of a button*; while a sequence of *touches*
aligned in a vertical line might be interpreted as *scrolling*. We note that
AndroidEnv does not support multitouch actions at the moment, but it is a
possible feature to add.
It is important to point out that it is out of the environment's control to
determine how particular sequences of raw actions get interpreted by the Android
simulator - much like when humans interact with physical devices, a certain
gesture on the screen might be interpreted differently if it is performed at a
slightly different angle or speed.
Tap | Double Tap | Touch & Hold | Flick Left | Flick Right | Scroll (H) | Scroll (V) | Drag & Drop
-------------------------------------------------------- | ---------------------------------------------------------------------- | ------------------------------------------------------------------------ | ---------------------------------------------------------------------- | ------------------------------------------------------------------------ | ------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------- | -----------
 |  |  |  |  |  |  | 
**Wrappers.** It is possible to alter the raw action space of the environment by
applying [wrappers](#wrappers). For example one might discretize the action
space by splitting the screen up into a grid of a desired size; restrict the
ActionType to *touch* only; or fix certain gesture skills. We note here that
these wrappers, again, will not alter how the particular sequence of performed
raw actions gets interpreted by the Android simulator.
## Observation space
The observation space of AndroidEnv consists of three main components:
(`pixels`, `timedelta`, `orientation`), the most notable of these being
`pixels`. The original screen size will depend on the type of emulator used, but
given that it will correspond to real device screen sizes, this will usually be
quite large (of course, this can be scaled down, e.g. with wrappers). The
`timedelta` component captures the amount of time passed since the last
observation was fetched. The `orientation`, even though it does not affect the
layout of the RGB image in the observation, might carry relevant information for
the agent. For example, if there is text on the screen, it is important to know
how it is oriented. Again, a benefit of this observation space is that it is
uniform across all tasks. As mentioned above, observations often carry spatial
cues and are suggestive of the kind of actions/gestures that are meaningful to
perform in a given state.
## Task extras
On top of the default observations (`pixels`, `timedelta`, `orientation`), some
tasks might expose additional structured observations after each step. An
*extra* in AndroidEnv is any information that an app may send to aid the
understanding of the task. The type of information sent through this channel is
usually something difficult to obtain from raw pixels and may include meaningful
information such as:
* The current board configuration (e.g. of a chess game or of a tetris game)
in matrix or string form.
* The position of the avatar in a map.
* Events such as whether a button was pressed or whether a checkpoint was
achieved.
Note that these are entirely optional and may not be available at all.
To request extras from the environment, you can call `env.task_extras()` after
each `env.step()`, which will return a dictionary of all the extra observations
observed during the previous step (or an empty dict if there are none available).
For example:
```python
for _ in range(FLAGS.n_steps):
action = agent.select_action(timestep.observation)
timestep = env.step(action)
logging.info('observation: %s', timestep.observation)
logging.info('extra observations: %s', env.task_extras())
```
Please note however that the env might not return extras at every timestep, only
when something meaningful happened (e.g. only when a button was pressed, or when
the state of the board has changed).
When integrating your own APK as a new task for the environment, you can define
your own extras by following the instructions
[here](tasks_guide.md#log-messages-and-custom-apks).
## Wrappers
AndroidEnv's action- and observation spaces can be altered by applying suitable
wrappers. While custom wrappers can be built easily, we have provided a number
of useful wrappers that demonstrate their usage:
* `discrete_action_wrapper`: Discretizes the action space into an `nxk` grid.
* `flat_interface_wrapper`: Removes the dictionary structure from the
observation and action specs.
* `float_pixels_wrapper`: Projects the pixel RGB values from the integer range
`[0, 255]` to the float range `[0, 1]`.
* `image_rescale_wrapper`: Resizes the pixel observations by the selected
ratio.
* `gym_wrapper`: Changes the environment interface from
[dm_env](https://github.com/deepmind/dm_env) to
[OpenAI](https://gym.openai.com/) gym interface.
* `last_action_wrapper`: Extends the observation with a one-hot encoded
location of the previously taken action, in order to aid agents without
built-in memory.
## Internal structure of AndroidEnv
The chart below gives an overview of the internal workings of the system,
illustrating how different classes interact with each other and what their
individual roles are. See the source code for more details.

================================================
FILE: docs/example_tasks.md
================================================
# AndroidEnv - Available tasks
This page gives a detailed overview of the example tasks provided with
AndroidEnv. The purpose is to give researchers an idea of the different kinds of
challenges that AndroidEnv poses.
To use any of these tasks in your own experiments, click on **Download** to
download a ZIP file containing textprotos and the corresponding APKs. After
downloading, move the `.apk` and `.textproto` files to a directory of your
choice and take note of their path. This information is needed for
[running](instructions.md#create-the-env) an AndroidEnv instance with the given
task.
| App / Game | Interface | Time reactive | Multi-level | Rewards | Extras | Download |
| --------------------------------------------------- | -------------------- | -------------- | ----------------------------- | ---------- | ------------ | -------- |
| [Vokram (MDP)](#vokram) | Tapping (buttons) | No | Yes (4 variants) | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/mdp.tar.gz) |
| [Apple Flinger](#apple-flinger) | Drag & drop | No | Yes (6 variants) | Dense | No | [Download](https://storage.googleapis.com/android_env-tasks/apple_flinger.tar.gz) |
| [Blockinger](#blockinger) | Tapping (buttons) | Yes | No | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/blockinger.tar.gz) |
| [Catch](#catch) | Touch | Yes | No | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/catch_the_ball.tar.gz) |
| [Classic 2048](#classic-2048) | Swiping | No | No | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/classic_2048.tar.gz) |
| [Dodge](#dodge) | Tapping | Yes | No | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/dodge.tar.gz) |
| [DroidFish (Chess)](#droidfish) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/droidfish.tar.gz) |
| [FlappyDroid](#flappydroid) | Tapping | Yes | Yes (2 levels) | Dense | No | [Download](https://storage.googleapis.com/android_env-tasks/systemui_egg_land.tar.gz) |
| [FloodIt](#floodit) | Tapping (buttons) | No | Yes (4 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/floodit.tar.gz) |
| [Frozen Bubble](#frozen-bubble) | Dragging, tapping | No | No | Sparse | No | [Download](https://storage.googleapis.com/android_env-tasks/frozen_bubble.tar.gz) |
| [Memory Game](#memory-game) | Tapping | No | Yes (6 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/memory_game.tar.gz) |
| [Minesweeper](#minesweeper) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/minesweeper.tar.gz) |
| [Nostalgic Racer](#nostalgic-racer) | Touch | Yes | Yes (2 variants) | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/nostalgic_racer.tar.gz) |
| [Open Sudoku](#open-sudoku) | Tapping (buttons) | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/open_sudoku.tar.gz) |
| [Perfection](#perfection) | Drag & drop | No | Yes (3 game types) | Dense | Yes | [Download](https://storage.googleapis.com/android_env-tasks/perfection.tar.gz) |
| [Rocket Sleigh](#rocket-sleigh) | Tapping | Yes | No | Dense | No | [Download](https://storage.googleapis.com/android_env-tasks/rocket_sleigh.tar.gz) |
| [Pong](#pong) | Drag | Yes | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/pong.tar.gz) |
| [SGT Puzzles - Blackbox](#sgt-puzzles-blackbox) | Tapping | No | Yes (4 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Bridge](#sgt-puzzles-bridge) | Drag & drop | No | Yes (5 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Cube](#sgt-puzzles-cube) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Dominosa](#sgt-puzzles-dominosa) | Tapping | No | Yes (5 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Fifteen](#sgt-puzzles-fifteen) | Tapping | No | Yes (4 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Flip](#sgt-puzzles-flip) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Flood](#sgt-puzzles-flood) | Tapping | No | Yes (3 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Galaxies](#sgt-puzzles-galaxies) | Tapping | No | Yes (6 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Guess](#sgt-puzzles-guess) | Tapping | No | Yes (4 levels) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Inertia](#sgt-puzzles-inertia) | Tapping | No | Yes (2 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Light Up](#sgt-puzzles-light-up) | Tapping | No | Yes (5 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Loopy](#sgt-puzzles-loopy) | Tapping | No | Yes (3 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [SGT Puzzles - Net](#sgt-puzzles-net) | Tapping | No | Yes (5 sizes) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/sgtpuzzles.tar.gz) |
| [Shattered Pixel Dungeon](#shattered-pixel-dungeon) | Tapping | Yes | Yes (4 variants) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/shattered_pixel_dungeon.tar.gz) |
| [Simple Solitaire](#simple-solitaire) | Drag & drop | No | Yes (19 tasks) | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/simple_solitaire.tar.gz) |
| [Snake](#snake) | Tapping (buttons) | Yes | No | Sparse | Yes | [Download](https://storage.googleapis.com/android_env-tasks/aosp_samples_snake.tar.gz) |
| [Vector Pinball](#vector-pinball) | Tapping | Yes | Yes (5 variants) | Sparse | No | [Download](https://storage.googleapis.com/android_env-tasks/vector_pinball.tar.gz) |
## Vokram
Vokram is our in-house implementation of an Android app that displays a
Markov-Decision-Process (MDP) graph as buttons on the screen which the agent
must use to select its actions. The observation is simply the color of the
background, and the actions are the buttons themselves which are presented in
different colors.
* **mdp_0000**: This is a task that presents the agent with two colored, but
unlabeled buttons on the screen. Pressing one of the buttons gives the agent
a reward of `-1` and redraws the buttons on the screen. The other button
gives a reward of `+1` and terminates the episode. The color of the buttons
is the same throughout the episode. The sizes of the buttons are randomized
at each screen draw. Pressing anywhere else on the screen gives a reward of
zero. The task lasts up to 60 seconds, at which point the episode is
restarted. The underlying dynamics governing the buttons is a simple 2-state
2-action Markov Decision Process (MDP). The MDP is an intentionally simple
environment that can be used to debug agents.
* **mdp_0001**: This is similar to `mdp_0000` but it's even simpler. It
presents the agent with a single button which gives a reward of `+1` and
terminates the episode when pressed. This task can be used for example to
train agents to click buttons.
* **mdp_0002**: In this task there are two buttons, pressing either of which
will terminate the episode with a return of `+1`.
* **mdp_0003**: An equivalent of `mdp_0000` with rewards reversed: the episode
ends when the wrong button is clicked, and carries on with a new set of
buttons when the correct one is clicked.
Extras returned
* `actions`:
- Set of all buttons present, e.g. `['A', 'B']`.
- Returned when any button is pressed.
- Has `shape=[2], dtype=STRING_U1`.
* `clicks`:
- Character representing the button pressed.
- Returned when any button is pressed.
- Has `shape=[1], dtype=STRING_U1`.
* `buttons`:
- Coordinates of the top left and bottom right corners of each button, e.g.
`[[x_a_0, y_a_0, x_a_1, y_a_1], [x_b_0, y_b_0, x_b_1, y_b_1]]`.
- Returned when any button is pressed.
- Has `shape=[2, 4], dtype=INT32`.
**mdp_0000** | **mdp_0001** | **mdp_0002** | **mdp_0003**
------------------------------------------------ | ------------------------------------------------ | ------------------------------------------------ | ------------
 |  |  | 
## Apple Flinger
A clone of Angry Birds. Even though the game offers many levels, we currently
expose six levels. See the original github repo for more info:
https://gitlab.com/ar-/apple-flinger.
Extras returned
Returns no extras.
**apple_flinger_M_1_1** | **apple_flinger_M_1_2** | **apple_flinger_M_1_18**
---------------------------------------------------------------------- | ---------------------------------------------------------------------- | ------------------------
 |  | 
**apple_flinger_M_2_1** | **apple_flinger_M_2_2** | **apple_flinger_M_2_18**
----------------------------------------------------------------------- | ---------------------------------------------------------------------- | ------------------------
 |  | 
## Blockinger
This is a Tetris clone implemented with on-screen controls. See the original
github repo for more info: https://github.com/tasioleiva/Blockinger.git.
Extras returned
* `down_pressed`, `left_pressed`, `right_pressed`, `rotate_right_pressed`,
`drop_pressed`:
- Indicates that said button has been pressed.
- Returned when said button has been pressed.
- Has `shape=[1], dtype=INT32`.
* `current_board`:
- One-hot encoded state of the board.
- Has `shape=[18, 10], dtype=INT32`.
* `current_line`, `cleared_lines`:
- Index of the relevant line.
- Has `shape=[1], dtype=INT32`.
* `current_piece`, `next_piece`:
- Index representing the type of piece.
- Has `shape=[1], dtype=INT32`.

## Catch
Classic Catch game.
Extras returned
* `ball`:
- `x, y` coordinates of the ball.
- Returned every timestep.
- Has `shape=[2], dtype=INT32`.
* `paddle`:
- `x, y` coordinates of the paddle.
- Returned every timestep.
- Has `shape=[2], dtype=INT32`.
* `paddle_width`:
- Width of the paddle.
- Returned every timestep.
- Has `shape=[1], dtype=INT32`.
* `lives`:
- Number of lives left.
- Returned every timestep.
- Has `shape=[1], dtype=INT32`.

## Classic 2048
This is an Android implementation of a popular game in the 2010s. See the
original github repo for more info: https://github.com/tpcstld/2048.git.
Extras returned
* `grid`:
- State of the board.
- Returned when the board changes.
- Has `shape=[4, 4], dtype=INT32`.
* `direction`:
- Index representing the direction of the last swipe (between 0-3).
- Returned when the swipe prompted a board change.
- Has `shape=[1], dtype=INT32`.

## Dodge
Guide the ball from the red line to the green line without getting hit by the
floating dots.
Extras returned
* `lives`:
- Number of lives left.
- Returned when its value changes.
- Has `shape=[1], dtype=INT32`.
* `level`:
- Current level.
- Returned when its value changes.
- Has `shape=[1], dtype=INT32`.

## DroidFish
Standard chess game. You can choose whether to play as a specific player
(black/white), or have the player colour randomly assigned at the beginning of
each episode. The numbers 1, 10 and 100 indicate the level of difficulty. Take a
look at a few sample moves below to get an idea of roughly how well the bot
plays for each level of difficulty. You can see that the 1% and 10% bots often
make very obvious mistakes. See the original github repo for more info:
https://github.com/peterosterlund2/droidfish.
Extras returned
* `board`:
- State of the board, representing pieces by indices. No piece - 0
- White pieces - 1: king, 2: queen, 3: rook, 4: bishop, 5: knight, 6: pawn
- Black pieces - 7: king, 8: queen, 9: rook, 10: bishop, 11: knight, 12:
pawn
- Returned when the board changes.
- Has `shape=[8, 8], dtype=INT32`.
* `selection`:
- Coordinate of selected piece (between 0-64, -1 if selection is removed)
- Returned when a piece is selected (or unselected).
- Has `shape=[1], dtype=INT32`.
* `moved`:
- Coordinates "from" and "to" cells when a piece is moved (between 0-64)
- Returned when a piece is moved.
- Has `shape=[2], dtype=INT32`.
* `invalid`:
- Coordinates "from" and "to" cells of an invalid move attempt (between
0-64)
- Returned upon invalid move request
- Has `shape=[2], dtype=INT32`.
**droidfish_black_1** | **droidfish_black_10** | **droidfish_black_100**
------------------------------------------------------------------ | -------------------------------------------------------------------- | -----------------------
 |  | 
**droidfish_white_1** | **droidfish_white_10** | **droidfish_white_100**
------------------------------------------------------------------ | -------------------------------------------------------------------- | -----------------------
 |  | 
**droidfish_random_1** | **droidfish_random_10** | **droidfish_random_100**
-------------------------------------------------------------------- | ---------------------------------------------------------------------- | ------------------------
 |  | 
## FlappyDroid
A clone of the well-known game Flappy Bird.
Extras returned
Returns no extras.
**systemui_egg_land_default** | **systemui_egg_land_half_speed**
---------------------------------------------------------------------------------- | --------------------------------
 | 
## FloodIt
FloodIt is a game where the player needs to fill the board with a single color.
The dynamics of the game are driven by a few colorful buttons at the bottom of
the screen, which when pressed cause the currently active region to change its
color to the color of the pressed button. When this active region changes color
it absorbs neighboring squares that have the same color, thus expanding the
active region. The active region starts as a single square at the top-left
corner of the board. The game gives a single reward at the end of the game if
the player manages to fill the entire board with the same color within the
maximum number of steps, otherwise the reward is just zero.
This is a very hard-exploration game because the game does not give intermediate
rewards until the very end of the game, and the number of possible moves is
incredibly big.
You can find another implementation of this game in the task
[SGT Puzzles - Flood](#sgt-puzzles-flood).
Extras returned
* `board`:
- State of the board, representing colours by their indices.
- 0: purple, 1: blue, 2: green, 3: yellow, 4: red, 5: pink
- Returned when the board changes.
- Has `shape=[board_size, board_size], dtype=INT32`.
* `clicked`:
- Index of the colour clicked (between 0-5)
- Returned when said colour is clicked.
- Has `shape=[1], dtype=INT32`.
* `flipped`:
- The number of new cells that just got merged into the big blob (0 or
more)
- Returned when the board state changes
- Has `shape=[1], dtype=INT32`.
**floodit_easy** | **floodit_medium** | **floodit_hard**
-------------------------------------------------------- | ------------------------------------------------------------ | ----------------
 |  | 
### Task `mdp_flood_it`
Custom task created for pretraining agents to locate and press FloodIt buttons
on the screen.

## Frozen Bubble
Shoot the coloured bubbles in a direction of your choice. Groups of bubbles with
the same colour will drop. Remove all bubbles from the board before the time
runs out. See the original github repo for more info:
https://github.com/robinst/frozen-bubble-android.git.
Extras returned
Returns no extras.

## Memory Game
Classic memory game. Find the pairs of images. See the original github repo for
more info: https://github.com/sromku/memory-game/.
Extras returned
* `flip`:
- Index of the card flipped
- Returned when a card is clicked.
- Has `shape=[1], dtype=INT32`.
* `cards`:
- Number of cards still on the board.
- Returned upon finding a pair.
- Has `shape=[1], dtype=INT32`.
* `remained`:
- Number of cards remaining at the end of the episode.
- Returned when an episode is over.
- Has `shape=[1], dtype=INT32`.
* `stars`:
- Number of stars achieved at the end of the game.
- Returned upon finishing the game.
- Has `shape=[1], dtype=INT32`.
* `achieved`:
- Score obtained by the end of the episode.
- Returned when an episode is over.
- Has `shape=[1], dtype=INT32`.
**memory_game_animals_beginner** | **memory_game_animals_easy** | **memory_game_monsters_medium**
---------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------- | -------------------------------
 |  | 
**memory_game_monsters_hard** | **memory_game_emojis_hardest** | **memory_game_emojis_master**
---------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | -----------------------------
 |  | 
## Minesweeper
This is an Android implementation of a popular game on Desktop in the 1990s. See
the original github repo for more info: https://gitlab.com/ar-/apple-flinger.
Extras returned
* `hidden`:
- Number of hidden cells.
- Returned whenever the board changes.
- Has `shape=[1], dtype=INT32`.
* `revealed`:
- Number of revealed cells.
- Returned whenever the board changes.
- Has `shape=[1], dtype=INT32`.
* `bombs`:
- Number of bombs in the game.
- Returned whenever the board changes.
- Has `shape=[1], dtype=INT32`.
* `click`:
- Coordinates of the cell clicked (row, column).
- Returned whenever the board changes.
- Has `shape=[2], dtype=INT32`.
* `grid`:
- State of the board.
- -1 = hidden, -2 = marked, 9 = bomb, 0-8 = number of nearby bombs
- Returned whenever the board changes.
- Has `shape=[grid_height, grid_width], dtype=INT32`.
**minesweeper_easy** | **minesweeper_medium** | **minesweeper_hard**
---------------------------------------------------------------- | -------------------------------------------------------------------- | --------------------
 |  | 
## Nostalgic Racer
NostalgicRacer is a racing game that offers Atari-like graphics and controls.
The objective is to maximize the score which increases as the car moves forward
and by collecting coins and speed-ups.
Extras returned
Returns no extras.
### Task `nostalgic_racer`
The player can only control whether the car should move left, move right or stay
put by touching on the screen. If the touch is on the right pane the car moves
right, if the touch is on the left pane the car moves left and no touches leaves
the car in the same position. Pressing for too little time moves the car by
minuscule amounts, with effects similar to staying put.
### Task `nostalgic_racer_2d`
This is the same underlying game as NostalgicRacer with the same objective.
However, the interface is very different. The observation is given as a 2D view
from the top with no perspective and the touchscreen determines the position
that the car should move to (sideways).
**nostalgic_racer** | **nostalgic_racer_2d**
-------------------------------------------------------------- | ----------------------
 | 
## Open Sudoku
Classic Sudoku game with different levels of difficulty. The board is randomly
chosen from a set of 30 boards for each level. See the original github repo for more
info: https://github.com/ogarcia/opensudoku.
Extras returned
* `value`:
- Number pressed (between 1-9, 0 if the "delete" button is pressed).
- Returned upon clicking said button.
- Has `shape=[1], dtype=INT32`.
**open_sudoku_easy** | **open_sudoku_medium** | **open_sudoku_hard**
---------------------------------------------------------------- | -------------------------------------------------------------------- | --------------------
 |  | 
## Perfection
Drag each item to the target that has the same shape.
Extras returned
* `moving`:
- The ID of the piece being dragged on the screen or 0.
- Returned when its value changes.
- Has `shape=[1], dtype=INT32`.
* `todo`:
- Number of pieces yet to be moved to a hole.
- Returned when its value changes.
- Has `shape=[1], dtype=INT32`.
* `done`:
- Number of pieces correctly moved to a hole.
- Returned when its value changes.
- Has `shape=[1], dtype=INT32`.
**perfection_1_circle_static** | **perfection_1_cube_static** | **perfection_1_plus_static** | **perfection_1_triangle_static**
------------------------------------------------------------------------------------ | ------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------- | --------------------------------
 |  |  | 
**perfection_default** | **perfection_4_colors_square_static** | **perfection_4_pieces_static**
-------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------- | ------------------------------
 |  | 
## Rocket Sleigh
A Flappy Bird-like game where you have to collect Christmas presents while
avoiding trees. The sleigh is powered by a rocket that needs to recharge over
time after you use up its fuel.
Extras returned
Returns no extras.

## Pong
Classic Pong game.
Extras returned
* `ball`:
- The ball coordinates: [left, top, right, bottom].
- Returned when its value changes.
- Has `shape=[4], dtype=INT32`.
* `computer`:
- The computer paddle coordinates: [left, top, right, bottom].
- Returned when its value changes.
- Has `shape=[4], dtype=INT32`.
* `human`:
- The human paddle coordinates: [left, top, right, bottom].
- Returned when its value changes.
- Has `shape=[4], dtype=INT32`.
* `collision`:
- Indicates collision of paddle and ball: (0=no collision, 1=collision).
- Returned when its value changes.
- Has `shape=[1], dtype=INT32`.
* `state`:
- The current state of the game: (0=pause, 1=ready, 2=running, 3=lose,
4=win).
- Returned when its value changes.
- Has `shape=[1], dtype=INT32`.
**pong_easy** | **pong_default** | **pong_hard**
-------------------------------------------------- | -------------------------------------------------------- | -------------
 |  | 
## SGT Puzzles - Blackbox
There's an invisible laser beam originating from each of the cells at the edge
of the grid. There are also a given number of balls inside the grid, hidden from
the player whose aim is to guess where those balls are. The player can figure
out where those balls might be by looking at how they *deflect* the laser beams.
Clicking on an edge cell the player can reveal information about how that
particular laser beam travels. Click on a cell; if the cell reveals an `H`, it
means the straight laser beam leaving this cell hits a ball frontally. If the
cell reveals a *number* along with another cell with the same number, that means
the laser beam originating in the first cell ends up getting absorbed in the
corresponding pair cell. If the cell reveals an `R`, that means the laser beam
was *reflected*: either its origin and the cell it gets absorbed in are the
*same*, or the beam gets bent before entering the grid. See the description
below.
The balls affect the travel of the laser beam in the following way:
* If a laser beam hits it straight, it gets absorbed. This is denoted by the
letter `H`.
* If a laser beam hits *its corner*, the beam gets deflected by 90 degrees.
* If a laser beam hits a ball's corner right at the edge of the grid, i.e.
before it enters the grid, it is considered *reflected*.
* If a laser beam enters the same cell that it originally left, it is
considered *reflected* too.
Once the player has placed the given number of balls on the screen, a green dot
appears that allows the player to check if their solution was correct. See the
original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `balls`:
- The number of balls in the game arena.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `guesses`:
- The number of balls guessed by the agent.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `wrong`:
- 1 if the guesses are wrong, 0 otherwise.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `lasers`:
- The number of lasers in the grid.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `grid`:
- Representation of the grid cells:
- In the arena: `G`=guessed ball, `' '`=empty
- In the range: `[0-9]`=the number of lasers, `H`=beam hit, `R`=beam
reflected,
- `?` unknown, `' '` for the corners
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**blackbox_3x3_1_ball** | **blackbox_5x5_3_balls** | **blackbox_8x8_5_balls** | **blackbox_10x10_5_balls**
--------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- | ----------------------------------------------------------------------------------- | --------------------------
 |  |  | 
## SGT Puzzles - Bridge
Connect nodes on the board so that each number denotes the degree of the given
vertex. Edges are not allowed to cross each other and the graph has to be
connected. Edges have to be horizontal or vertical, and there can be at most two
parallel bridges between any pair of nodes. See the original github repo for
more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `islands`:
- Number of nodes on the board.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `grid`:
- Representation of the current state of the board.
- `[0-9]=island, ' '=empty`
- `'|'=vertical line, '"'=double vertical line, '!'=wrong vertical line`
- `'-'=horizontal line, '='=double horizontal line, '~'=wrong horizontal
line`
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**bridge_7x7_easy** | **bridge_7x7_medium** | **bridge_7x7_hard** | **bridge_10x10_medium** | **bridge_15x15_medium**
------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | -------------------------------------------------------------------------------------------- | -----------------------
 |  |  |  | 
## SGT Puzzles - Cube
There are six coloured squares that you have to collect with a moving cube. If
the cube rolls on top of a coloured cell, the colour gets attached to that side
of the cube; and if a coloured side of the cube rolls on an empty cell, the
colour is removed from the cube. The goal is to have all six sides of the cube
coloured. See the original github repo for more info:
https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `current`:
- Index of the current grid cell the cube is on.
- Returned whenever the cube moves.
- Has `shape=[1], dtype=INT32`.
* `previous`:
- Index of the previous grid cell the cube was on.
- Returned whenever the cube moves.
- Has `shape=[1], dtype=INT32`.
* `grid`:
- The grid state (0 = dark, 1 = blue)
- Returned whenever the cube moves.
- Has `shape=[grid_size, grid_size], dtype=INT32`.
* `face_count`:
- The number of dark faces on the cube.
- Returned whenever the cube moves.
- Has `shape=[1], dtype=INT32`.
* `face_colour_count`:
- The number of blue faces on the cube.
- Returned whenever the cube moves.
- Has `shape=[1], dtype=INT32`.
* `faces`:
- The cube faces (0 = dark, 1 = blue)
- Returned whenever the cube moves.
- Has `shape=[6], dtype=INT32`.
**cube_c3x3** | **cube_c4x4** | **cube_c8x8**
------------------------------------------------------------------------ | ------------------------------------------------------------------------ | -------------
 |  | 
## SGT Puzzles - Dominosa
Place 2x1 size dominoes on the board such that the full board is covered, making
sure that no two dominoes have the same pair of numbers on them. There needs to
be exactly one of each pair (0, 0), (0, 1), (0, 2), ..., (1, 1), (1, 2), etc. See the
original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `numbers`:
- Numbers as they appear in the grid.
- Returned whenever the grid changes.
- Has `shape=[height, width], dtype=INT32`.
* `grid`:
- Letters representing the dominoes currently placed on the board.
- 'R=right, L=left, T=top, B=bottom'
- Returned whenever the grid changes.
- Has `shape=[height, width], dtype=INT32`.
* `clash`:
- Represents clashes on the board (i.e. if two dominoes have the same
pair)
- '1=clash, 0=no clash'
- Returned whenever the grid changes.
- Has `shape=[height, width], dtype=INT32`.
**dominosa_1** | **dominosa_3** | **dominosa_3a** | **dominosa_6** | **dominosa_9**
-------------------------------------------------------------------------- | -------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | -------------------------------------------------------------------------- | --------------
 |  |  |  | 
## SGT Puzzles - Fifteen
Order the tiles in increasing order, starting from the top left corner. See the
original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `grid`:
- Current state of the grid.
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=INT32`.
* `empty`:
- Index of the single empty cell in the grid.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `movecount`:
- Number of moves made so far.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
**fifteen_2x2** | **fifteen_3x3** | **fifteen_4x4** | **fifteen_6x6**
---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | ---------------
 |  |  | 
## SGT Puzzles - Flip
Clicking on a cell will flip the colour of some of its neighbours, which are
determined by the symbol in the cell. The goal is to make all the cells have the
same colour. See the original github repo for more info:
https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `light`:
- The number of light cells.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `dark`:
- The number of dark cells.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `moves`:
- The number of moves made by the player.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `grid`:
- State of the board (0 = dark, 1 = light).
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=INT32`.
* `gridMatrix`:
- The grid matrix of square neighbours (-1 = outside, 1 = neighbour, 0 =
not neighbour)
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size, 3, 3], dtype=INT32`.
**flip_3x3c** | **flip_4x4c** | **flip_5x5r**
------------------------------------------------------------------------ | ------------------------------------------------------------------------ | -------------
 |  | 
## SGT Puzzles - Flood
FloodIt is a game where the player needs to fill the board with a single color.
The dynamics of the game are driven by colored areas of the board, which when
pressed cause the currently active region to change its color to the color of
the pressed button. When this active region changes color it absorbs neighboring
squares that have the same color, thus expanding the active region. The active
region starts as a single square at the top-left corner of the board. The game
gives a single reward at the end of the game if the player manages to fill the
entire board with the same color within the maximum number of steps; otherwise, the
reward is just zero. See the original github repo for more info:
https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `board`:
- State of the board, representing colours by their indices.
- 0: red, 1: yellow, 2: green, 3: blue, 4: orange, 5: purple,
- 6: brown, 7: light blue, 8: light green, 9: pink
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=INT32`.
**sgtpuzzles_flood_3x3_easy** | **sgtpuzzles_flood_12x12_medium** | **sgtpuzzles_flood_16x16_hard**
---------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | -------------------------------
 |  | 
## SGT Puzzles - Galaxies
Split the grid up into centrally symmetric areas. The centre of symmetry for
each area is denoted by a dot on the grid. See the original github repo for more
info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `dot`:
- Number of dots on the board.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `grid`:
- String representation of the board:
- `o`=dot, `' '`=empty, `+`, `-`, `|` = cell corners.
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**galaxies_3x3_normal** | **galaxies_5x5_normal** | **galaxies_7x7_normal**
--------------------------------------------------------------------------------- | --------------------------------------------------------------------------------- | -----------------------
 |  | 
**galaxies_7x7_unreasonable** | **galaxies_10x10_normal** | **galaxies_15x15_normal**
--------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- | -------------------------
 |  | 
## SGT Puzzles - Guess
The computer has thought of a sequence of colours that you have to guess. Fill
the top row with colours of your choice, and wait for the computer to give you
feedback about your sequence. It will show a black dot for each colour that is
placed in the correct position, and a white dot for each that is present in the
hidden sequence but not at the position you guessed. Try to figure out the
hidden sequence before you run out of guesses! See the original github repo for
more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `peg`:
- Indices representing the colours selected in the latest row.
- Returned after the row is completed and evaluated.
- Has `shape=[row_length], dtype=INT32`.
* `feedback`:
- Evaluation of the latest guess (0: incorrect, 1: correct place, 2:
correct colour)
- Returned after the row is completed and evaluated.
- Has `shape=[row_length], dtype=INT32`.
**guess_basic** | **guess_quick** | **guess_standard** | **guess_super**
---------------------------------------------------------------------------- | ----------------------------------------------------------------- | ----------------------------------------------------------------------- | ---------------
 |  |  | 
## SGT Puzzles - Inertia
Collect all the blue diamonds on the board without colliding with a bomb. You
can move the ball in the 8 main directions (including the diagonals). The ball
will keep on moving in that direction until it hits a wall, a bomb, a diamond or
a circle. Circles and diamonds have grip, i.e. they stop the ball from
continuing to move in the direction it was going towards. See the original
github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `gems`:
- Current number of gems still on the board.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `distancemoved`:
- Number of cells just moved.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `grid`:
- Symbols of grid cells (b=blank, g=gem, m=mine, s=stop, w=wall)
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=INT32`.
**inertia_5x5** | **inertia_10x10**
----------------------------------------------------------------- | -----------------
 | 
## SGT Puzzles - Light Up
You have a grid of squares. Some are empty (black) and some are *walls* (grey);
some of the walls are numbered. Your aim is to *light up* all the empty squares
by placing light bulbs in some of them. The numbers denote how many bulbs' light
hits these directly along a straight line of sight. Meanwhile, no two bulbs should
light up each other (i.e. no bulb may be in another bulb's straight line of
sight). See the original
github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `grid`:
- String representation of the board:
- '#' = blocked, ' ' = empty/black, 'L' = light, 'l' = illuminated,
- 'X' = impossible, 'number' = number of bulbs hitting this cell.
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**light_up_3x3_easy** | **light_up_5x5_easy** | **light_up_7x7_easy** | **light_up_10x10_tricky** | **light_up_14x14_easy**
---------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------- | -----------------------
 |  |  |  | 
## SGT Puzzles - Loopy
Draw a closed loop along the edges of the grid. A number in a cell denotes the
number of edges adjacent to that cell. The loop cannot intersect itself. See the
original github repo for more info: https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `grid`:
- String representation of the board:
- The grid lines and cells:
- `.` = dots (cell corners)
- `0-9` = number on cell face or ` ` for empty face
- `?` = unknown (default), `x` = no line, `-` or `|` = line, `~` or `/` = error
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=STRING_U1`.
**loopy_3x3_easy** | **loopy_5x5_easy** | **loopy_7x7_easy** | **loopy_7x7_normal** | **loopy_7x7_hard**
---------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------- | ------------------
 |  |  |  | 
## SGT Puzzles - Net
There are a number of light bulbs, wires and a single light source in the
middle. Connect the wires such that all bulbs are lit up, without loose ends or
loops in the wiring. You can rotate the tiles by clicking on them. If the task
name has a *w* suffix in it then it is allowed to connect wires on opposing
edges of the grid. See the original github repo for more info:
https://github.com/chrisboyle/sgtpuzzles.
Extras returned
* `active`:
- The number of active/completed cells.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `total`:
- The total number of cells
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `grid`:
- The grid cells represented by numbers.
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=INT32`.
* `gridCompleted`:
- The grid cells' active/completed status (0 = false, 1 = true).
- Returned whenever the grid changes.
- Has `shape=[grid_size, grid_size], dtype=INT32`.
**net_3x3** | **net_5x5** | **net_7x7w** | **net_9x9** | **net_11x11w**
-------------------------------------------------------------------- | -------------------------------------------------------------------- | ---------------------------------------------------------------------- | -------------------------------------------------------------------- | --------------
 |  |  |  | 
## Shattered Pixel Dungeon
Shattered Pixel Dungeon is a Roguelike RPG, with pixel art graphics and lots of
variety and replayability. Every game is unique, with four different playable
characters, randomized levels and enemies, and over 150 items to collect and
use. The game is simple to get into, but has lots of depth. Strategy is required
if you want to win! See the original github repo for more info:
https://github.com/00-Evan/shattered-pixel-dungeon.git.
Extras returned
* `action`:
- The action just completed by the hero.
- Returned whenever an action is taken.
- Has `shape=[1], dtype=STRING_U25`.
* `dst`:
- The destination of the action.
- Returned whenever an action is taken.
- Has `shape=[1], dtype=INT32`.
* `level`:
- The level reached.
- Returned whenever a new level is reached.
- Has `shape=[1], dtype=INT32`.
* `depth`:
- The depth of the level reached.
- Returned whenever a new floor is reached.
- Has `shape=[1], dtype=INT32`.
* `deepestFloor`:
- The deepest reached floor statistic.
- Returned whenever a new floor is reached.
- Has `shape=[1], dtype=INT32`.
* `gold`:
- The gold level reached.
- Returned whenever gold is acquired.
- Has `shape=[1], dtype=INT32`.
* `totalGold`:
- The gold collected statistic.
- Returned whenever gold is acquired.
- Has `shape=[1], dtype=INT32`.
* `addedGold`:
- The gold acquired at the most recent step.
- Returned whenever gold is acquired.
- Has `shape=[1], dtype=INT32`.
* `newlyVisited`:
- Number of new squares uncovered.
- Returned whenever new squares are uncovered.
- Has `shape=[1], dtype=INT32`.
* `damageDealt`:
- Damage dealt by hero.
- Returned whenever the hero deals damage.
- Has `shape=[1], dtype=INT32`.
* `damageTaken`:
- Damage taken by hero.
- Returned whenever the hero takes damage.
- Has `shape=[1], dtype=INT32`.
* `HP`:
- Hit Points (Health) of hero.
- Returned whenever its value changes.
- Has `shape=[1], dtype=INT32`.
* `exp`:
- Experience gained by hero.
- Returned whenever the hero gains experience points.
- Has `shape=[1], dtype=INT32`.
* `dew`:
- The amount of dew (an immediately consumed item) the user picks up.
- Returned whenever a dew is picked up.
- Has `shape=[1], dtype=INT32`.
* `heal`:
- Amount the hero is healed.
- Returned whenever the hero is healed.
- Has `shape=[1], dtype=INT32`.
* `itemPickup`:
- The name of an item that is picked up.
- Returned whenever an item is picked up.
- Has `shape=[1], dtype=STRING_U25`.
* `itemDrop`:
- The name of an item that is dropped.
- Returned whenever an item is dropped.
- Has `shape=[1], dtype=STRING_U25`.
* `destroy`:
- The name of the enemy that is killed.
- Returned whenever an enemy is killed.
- Has `shape=[1], dtype=STRING_U25`.
* `search`:
- Whether the user clicked the "Search" button.
- Returned whenever the button is clicked.
- Has `shape=[1], dtype=INT32`.
* `wait`:
- Whether the user clicked the "Wait" button.
- Returned whenever the button is clicked.
- Has `shape=[1], dtype=INT32`.
* `openedInventory`:
- Whether the user clicked the "Inventory" button.
- Returned whenever the button is clicked.
- Has `shape=[1], dtype=INT32`.
**huntress** | **mage** | **rogue** | **warrior**
------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | -------
 |  |  | 
## Simple Solitaire
This is an Android implementation of
[Solitaire](https://en.wikipedia.org/wiki/Solitaire) card games. We currently
support 19 variants listed here in alphabetical order. Note that the full
task_IDs take the form `simple_solitaire_aces_up`,
`simple_solitaire_calculation` etc. See the original github repo for more info:
https://github.com/TobiasBielefeld/Simple-Solitaire.git.
Extras returned
* `card`:
- The new visible card `[kind, suit]`:
- kind: `a=ace, k=king, q=queen, j=jack, x=10, 2-9=digit`.
- suit: `c=clubs, d=diamonds, h=hearts, s=spades`.
- Returned when a card is moved.
- Has `shape=[2], dtype=STRING_U1`.
* `stack_i`:
- A non-empty stack of visible cards `[kind, suit]`.
- `i` different extras (`stack_0`, `stack_1`, `...`), one corresponding to
each stack.
- Returned when a card is moved.
- Has `shape=[52, 2], dtype=STRING_U1`.
**aces_up** | **calculation** | **canfield** | **forty_eight** | **freecell**
-------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------- | ------------
 |  |  |  | 
**golf** | **grandfathers_clock** | **gypsy** | **klondike** | **maze**
-------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | --------
 |  |  |  | 
**mod3** | **napoleons_tomb** | **pyramid** | **simple_simon** | **spider**
-------------------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | ----------
 |  |  |  | 
**spiderette** | **tri_peaks** | **vegas** | **yukon**
-------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------ | ---------------------------------------------------------------------------- | ---------
 |  |  | 
## Snake
Classic Snake game.
Extras returned
* `move`:
- The desired direction of movement: `0`=left, `1`=up, `2`=down,
`3`=right.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `direction`:
- The direction of the snake: `1`=north, `2`=south, `3`=east, `4`=west.
- Returned whenever the grid changes.
- Has `shape=[1], dtype=INT32`.
* `grid`:
- The grid cells: `x`=border, `' '`=empty, `s`=snake, `a`=apple.
- Returned whenever the grid changes.
- Has `shape=[13, 19], dtype=STRING_U1`.

## Vector Pinball
A simple vector-based Pinball game with realistic physics. See the original
github repo for more info: https://github.com/dozingcat/Vector-Pinball.
Extras returned
Returns no extras.
**vector_pinball_table_1** | **vector_pinball_table_2** | **vector_pinball_table_3**
---------------------------------------------------------------------------- | ---------------------------------------------------------------------------- | --------------------------
 |  | 
**vector_pinball_table_4** | **vector_pinball_table_5**
---------------------------------------------------------------------------- | --------------------------
 | 
================================================
FILE: docs/instructions.md
================================================
# AndroidEnv - Running the environment
In order to create an AndroidEnv instance you will need to provide two main
components: a [simulator](#the-simulator) and a [task](#the-task). In the
following sections you will learn how you can create them.
### The simulator
First, you will need to provide your Android virtual device (AVD) that the
environment (and through it, the agent) can communicate with. While this can
also be a physical device, in most cases you will need a virtual emulated
device. There are many ways to emulate an AVD - in our examples, we will use
[Android Studio](https://developer.android.com/studio) to create one.
1. In Android Studio, create a virtual device by following this step-by-step
[guide](emulator_guide.md).
2. Follow the steps below to attach the AVD to your environment.
### The task and examples with games and other apps
A `task` is a particular definition of an RL problem that the agent will be
interacting with. A `task` may include critical RL information, such as what the
rewards are, when the episodes are supposed to terminate, and what the reset
procedures are that the environment should perform upon episode termination
(e.g. start or relaunch an app, clear cache, etc.). This information is packaged
into a `Task()` proto message, which gets passed to AndroidEnv.
* For ready-made example tasks provided with AndroidEnv, check out the
[Available tasks](example_tasks.md), featuring Vokram (with Markov Decision
Process (MDP)), Pong, DroidFish (a chess clone), Blockinger (a tetris
clone), and more.
* See the [Tasks guide](tasks_guide.md) for details on features and
capabilities of tasks, as well as how to create custom ones.
### Create the environment
After setting up the simulator and creating a task, you may find the
[`loader.load()`](https://github.com/deepmind/android_env/blob/main/android_env/loader.py)
function handy for creating an environment instance by providing relevant
arguments, such as:
* `task_path`: the path pointing to the `.textproto` file describing the
desired task.
* `avd_name`: the name of the AVD specified when you created it in Android
Studio.
* `android_avd_home` (Optional): the path to where the AVD is installed.
(default value: `~/.android/avd`).
* `android_sdk_root` (Optional): the root directory of the Android SDK.
(default value: `~/Android/Sdk`).
* `emulator_path` (Optional): the path to the emulator binary. (default:
`~/Android/Sdk/emulator/emulator`).
* `adb_path` (Optional): the path to the ADB
([Android Debug Bridge](https://developer.android.com/studio/command-line/adb)).
(default value: `~/Android/Sdk/platform-tools/adb`).
For the AVD name and path, in Android Studio, go to **Tools** > **AVD Manager**,
right click on your virtual device, and select **View Details**, where you will
find the `avd_name` and its path.
For the Android SDK location, in Android Studio, go to **Preferences** >
**Appearance & Behavior** > **System Settings** > **Android SDKs** and note the
_Android SDK Location_. In the SDK folder, you will find `/emulator/emulator` as
well as the ADB path (`/platform-tools/adb`).
Your example configuration may look like this, depending on how you set up your
emulator:
```python
from android_env import loader
env = loader.load(
avd_name='my_avd',
android_avd_home='/Users/username/.android/avd',
android_sdk_root='/Users/username/Library/Android/sdk',
emulator_path='/Users/username/Library/Android/sdk/emulator/emulator',
adb_path='/Users/username/Library/Android/sdk/platform-tools/adb',
task_path='/Users/username/android_env/my_tasks/my_task.textproto',
)
```
## Example RL agent scripts
The `examples` directory contains a few simple example agent setups, such as:
* [`run_random_agent.py`](https://github.com/deepmind/android_env/blob/main/examples/run_random_agent.py):
Runs a simple loop performing randomly selected actions in the environment.
* [`run_acme_agent.py`](https://github.com/deepmind/android_env/blob/main/examples/run_acme_agent.py):
Runs a training loop with an
[Acme](https://deepmind.com/research/publications/Acme) DQN agent,
implemented in the popular DeepMind RL framework. This requires installing
the [`acme`](https://github.com/deepmind/acme) dependency.
* [`run_human_agent.py`](https://github.com/deepmind/android_env/blob/main/examples/run_human_agent.py):
Creates a [`pygame`](https://www.pygame.org) instance that lets a human user
interact with the environment and observe environment mechanics, such as
rewards or task extras. You will need to install the [`pygame`](https://www.pygame.org) dependency.
For instance, here is how you can run
[`run_random_agent.py`](https://github.com/deepmind/android_env/blob/main/examples/run_random_agent.py)
in a folder where you have your APK file, such as
[Apple Flinger](https://github.com/deepmind/android_env/blob/main/docs/example_tasks.md#apple-flinger)
from
[Example tasks](https://github.com/deepmind/android_env/blob/main/docs/example_tasks.md).
(The downloaded TAR file contains the APK file and `.textproto` definitions.)
```shell
python3 run_random_agent.py \
--avd_name='my_avd' \
--android_avd_home=/Users/username/.android/avd \
--android_sdk_root=/Users/username/Library/Android/sdk \
--emulator_path=/Users/username/Library/Android/sdk/emulator/emulator \
--adb_path=/Users/username/Library/Android/sdk/platform-tools/adb \
--num_steps=1000 \
--task_path=/Users/username/apple_flinger_M_1_1.textproto
```
================================================
FILE: docs/tasks_guide.md
================================================
# AndroidEnv - Tasks
With AndroidEnv we provide a mechanism for easily defining RL tasks for the
agent to learn. This includes various types of information such as what app/game
it should train on, what rewards the environment returns, or the start state
distribution and the episode end criteria.
## Task structure
A *task* definition is captured in the form of a `Task()` proto message. These
are most easily created by writing a `.textproto` file, then parsing it into a
proto message. In this section you can find a detailed description about the
types of information that make up a task, and an example demonstrating exactly
how to put these into code.
The main types of information captured in these messages are:
* `id`: An ID used to identify the task.
* `setup_steps`: These are steps the environment will perform right after
launching the simulator. Possible steps include:
* `install_apk`: Installs an application from a specified path to the APK
file.
* `start_activity`: Launches the requested app/activity.
* `rotate`: Sets the orientation of the device (landscape/portrait).
* `reset_steps`: These are steps the environment will perform right at the
beginning of a new RL episode. Possible steps include:
* `force_stop`: Stops a given app.
* `start_activity`: Launches the requested app/activity.
* `start_screen_pinning`: Restricts the agent's interaction to a
particular activity through
[screen pinning](https://support.google.com/android/answer/9455138?hl=en),
meaning the agent will not be able to quit the given app.
* `clear_cache`: Clears the cache of a given app.
* `success_conditions`: For each success condition defined, the environment
will make sure that these conditions were met after finishing `setup_steps`
and `reset_steps`. They might include conditions such as:
* `check_install`: Makes sure that the requested app was successfully
installed.
* `wait_for_app_screen`: Waits until the requested app has been successfully
launched.
* `expected_app_screen`: If this value is set to a particular activity, the
environment will periodically check if the agent is still interacting with
said activity, making sure it has not accidentally quit the application we
want it to be training on.
* `max_episode_sec`: Puts a time limit on the episodes, triggering an episode
reset if the current episode has lasted too long.
* `max_episode_steps`: Puts a step limit on the episodes, triggering an
episode reset once the agent has reached the specified limit.
* `log_parsing_config`: If the environment is parsing logcat messages, this
field will determine what information it should listen for using regular
expressions.
* `filters`: The environment filters log messages for these labels which
signify that such messages were meant to be parsed by AndroidEnv.
* `log_regexps`: Once a log message was identified as relevant using the
filters, the environment parses its contents using these regular
expressions. For example, an application might be sending log messages
of the form `reward: 1.0`, then the task will capture this info using
the regexp `^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$`.
Here is what an example `.textproto` file might look like in practice:
```textproto
id: "classic_2048"
name: "Classic 2048 - Default"
description: "Slide numbered tiles on a grid to combine them to create a tile with the number 2048"
package_name: "com.tpcstld.twozerogame"
full_activity_name: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity"
# Perform these upon launching the environment
setup_steps: [
{
# Install the 2048 app
adb_call: {
install_apk: {
filesystem: {
path: path/to/classic_2048.apk
}
}
}
# Check if it was installed correctly
success_condition: {
check_install: {
package_name: "com.tpcstld.twozerogame"
timeout_sec: 10.0
}
}
},
# Orient the screen in portrait mode
{ adb_call: { rotate: { orientation: PORTRAIT_0 } } }
]
# Perform these upon episode resets
reset_steps: [
# Stop the 2048 app
{ adb_call: { force_stop: { package_name: "com.tpcstld.twozerogame" } } },
{ adb_call: { clear_cache: { package_name: "com.tpcstld.twozerogame" } } },
# Start the 2048 app
{
adb_call: {
start_activity: {
full_activity: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity"
extra_args: [
"--ez", '"RL_TASK_ENABLED"', '"true"',
"--es", '"RL_TASK_GAME_CONFIG"', '"{}"'
]
}
}
# Wait until the app has launched successfully
success_condition: {
wait_for_app_screen: {
app_screen: {
activity: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity"
view_hierarchy_path: [
]
}
timeout_sec: 10.0
}
num_retries: 10
}
},
# Make sure the agent cannot quit the 2048 app
{
adb_call: {
start_screen_pinning: {
full_activity: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity"
}
}
}
]
# Periodically check if the agent has accidentally quit the app
expected_app_screen: {
activity: "com.tpcstld.twozerogame/com.tpcstld.twozerogame.MainActivity"
view_hierarchy_path: []
}
max_episode_steps: 500
# Capture expected format of log messages
log_parsing_config: {
filters: ["AndroidRLTask:V"]
log_regexps: {
score: "^[Ss]core: ([-+]?[0-9]*\\.?[0-9]*)$"
reward: "^[Rr]eward: ([-+]?[0-9]*\\.?[0-9]*)$"
episode_end: "^episode[ _]end$"
extra: "^extra: (?P<name>[^ ]*)[ ]?(?P<extra>.*)$"
json_extra: "^json_extra: (?P<json_extra>.*)$"
}
}
```
## Log messages and custom APKs
You might have noticed that tasks often rely on log messages exposed by the
Android system, which AndroidEnv can intercept and translate into items such as
rewards, episode end signals or task extras.
One way to define rewards is by using
`log_parsing_config.LogRegexps.RewardEvent` messages in the task proto. These
consist of a regular expression and a numeric value indicating the intended
reward. If the regexp is matched in any of the lines of the logcat stream, the
agent will receive the given reward. It is also possible to have multiple of
these RewardEvents, allowing us to give rewards for different log messages. The
same applies for episode end signals: logcat messages that match the regexps
defined in `log_parsing_config.LogRegexps.episode_end` will trigger an episode
reset.
Of course, applications might not send suitable messages by default, so in order
to have access to such messages, we often add them to the apps' source code to
match our expectations. For example, in the case of the 2048 app, we find in the
game's source code the exact lines where the score is computed, and add a line
to log this value in the format that is expected by the textproto (or
conversely, make sure the textproto matches the format you specified here). For
example:
```java
// Make sure that LOG_FILTER matches 'filters' in the textproto
public static final String LOG_FILTER = "AndroidRLTask";
// Make sure that the corresponding part of 'log_regexps' will match this string
Log.i(LOG_FILTER, String.format(Locale.ENGLISH, "reward: %f", reward_value));
```
You can take a look at example APKs extended with log messages in the example
tasks (see the section below).
## Example tasks
Along with the environment implementation we provide a set of example task
definitions. These were chosen so that they would demonstrate the large variety
of different challenges (e.g. app navigation, puzzle games, time-reactive games,
adventure games, card games...) and corresponding interfaces (e.g. button
pressing, swiping, drag-and-drop...) available in AndroidEnv. You can find a
list and detailed description of each of these tasks in
[example_tasks.md](example_tasks.md).
================================================
FILE: examples/__init__.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
================================================
FILE: examples/run_acme_agent.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Acme DQN agent interacting with AndroidEnv."""
from absl import app
from absl import flags
from absl import logging
import acme
from acme import specs
from acme import wrappers as acme_wrappers
from acme.agents.tf import dqn
from acme.tf import networks
from android_env import loader
from android_env.components import config_classes
from android_env.wrappers import discrete_action_wrapper
from android_env.wrappers import float_pixels_wrapper
from android_env.wrappers import image_rescale_wrapper
# Simulator args
flags.DEFINE_string('avd_name', None, 'Name of AVD to use.')
flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.')
flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.')
flags.DEFINE_string('emulator_path',
                    '~/Android/Sdk/emulator/emulator', 'Path to emulator.')
flags.DEFINE_string('adb_path',
                    '~/Android/Sdk/platform-tools/adb', 'Path to ADB.')
# `main()` passes FLAGS.run_headless to the emulator launcher config, so the
# flag must be defined here; without this definition, reading it in `main()`
# fails at runtime.
flags.DEFINE_boolean('run_headless', True,
                     'Whether to run the emulator without a display window.')
# Environment args
flags.DEFINE_string('task_path', None, 'Path to task textproto file.')
# Experiment args
flags.DEFINE_integer('num_episodes', 100, 'Number of episodes.')
FLAGS = flags.FLAGS
def apply_wrappers(env):
  """Wraps `env` in the wrapper stack used by the Acme DQN agent.

  The wrappers are applied innermost-first:
    1. discretize touch actions onto a 10x10 grid,
    2. downscale pixel observations by a factor of 4 on each axis,
    3. convert pixels to floats,
    4. cast everything to single precision for Acme.

  Args:
    env: The environment returned by `loader.load`.

  Returns:
    The fully wrapped environment.
  """
  wrapper_chain = (
      lambda inner: discrete_action_wrapper.DiscreteActionWrapper(
          inner, action_grid=(10, 10)),
      lambda inner: image_rescale_wrapper.ImageRescaleWrapper(
          inner, zoom_factors=(0.25, 0.25)),
      lambda inner: float_pixels_wrapper.FloatPixelsWrapper(inner),
      lambda inner: acme_wrappers.SinglePrecisionWrapper(inner),
  )
  wrapped = env
  for wrap in wrapper_chain:
    wrapped = wrap(wrapped)
  return wrapped
def main(_):
  """Builds the environment from flags and trains a DQN agent on it."""
  # Assemble each piece of the configuration separately, in the same order the
  # nested constructor calls would have evaluated them.
  task_config = config_classes.FilesystemTaskConfig(path=FLAGS.task_path)
  launcher_config = config_classes.EmulatorLauncherConfig(
      emulator_path=FLAGS.emulator_path,
      android_sdk_root=FLAGS.android_sdk_root,
      android_avd_home=FLAGS.android_avd_home,
      avd_name=FLAGS.avd_name,
      run_headless=FLAGS.run_headless,
  )
  adb_config = config_classes.AdbControllerConfig(adb_path=FLAGS.adb_path)
  config = config_classes.AndroidEnvConfig(
      task=task_config,
      simulator=config_classes.EmulatorConfig(
          emulator_launcher=launcher_config,
          adb_controller=adb_config,
      ),
  )

  with loader.load(config) as raw_env:
    wrapped_env = apply_wrappers(raw_env)
    env_spec = specs.make_environment_spec(wrapped_env)
    q_network = networks.DQNAtariNetwork(
        num_actions=env_spec.actions.num_values)
    agent = dqn.DQN(
        environment_spec=env_spec,
        network=q_network,
        batch_size=10,
        samples_per_insert=2,
        min_replay_size=10,
    )
    acme.EnvironmentLoop(wrapped_env, agent).run(
        num_episodes=FLAGS.num_episodes)


if __name__ == '__main__':
  # Surface INFO-level logs on stderr.
  logging.set_verbosity('info')
  logging.set_stderrthreshold('info')
  flags.mark_flags_as_required(['task_path', 'avd_name'])
  app.run(main)
================================================
FILE: examples/run_human_agent.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loads an interactive session where a human acts on behalf of an agent."""
import time
from typing import Any
from absl import app
from absl import flags
from absl import logging
from android_env import loader
from android_env.components import action_type
from android_env.components import config_classes
from android_env.components import pixel_fns
import dm_env
import numpy as np
import pygame
# Simulator args.
flags.DEFINE_string('avd_name', None, 'Name of AVD to use.')
flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.')
flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.')
flags.DEFINE_string('emulator_path',
                    '~/Android/Sdk/emulator/emulator', 'Path to emulator.')
flags.DEFINE_string('adb_path',
                    '~/Android/Sdk/platform-tools/adb', 'Path to ADB.')
flags.DEFINE_boolean('run_headless', True, 'Optionally turn off display.')
# Environment args.
flags.DEFINE_string('task_path', None, 'Path to task textproto file.')
# Pygame args.
flags.DEFINE_list('screen_size', '480,720', 'Screen width, height in pixels.')
# Despite its name, this flag holds a frame *period*: the target number of
# seconds per frame of the game loop (1/30 s corresponds to 30 FPS). The flag
# name is kept for command-line compatibility; only the help text is corrected.
flags.DEFINE_float('frame_rate', 1.0/30.0,
                   'Target time per frame in seconds (i.e. 1 / FPS).')
FLAGS = flags.FLAGS
def _get_action_from_event(
    event: pygame.event.Event, screen: pygame.Surface, orientation: int
) -> dict[str, Any]:
  """Returns the current action by reading data from a pygame Event object.

  Only a left-button press is turned into a TOUCH. This matches
  `_get_action_from_mouse`, which samples the left button only. Every other
  mouse-button event yields a LIFT at the event position. That includes
  right/middle clicks and the mouse-wheel events that pygame also reports as
  button 4/5 presses, which previously produced spurious taps.

  Args:
    event: A MOUSEBUTTONDOWN or MOUSEBUTTONUP event (its `pos` is used).
    screen: The pygame display surface, used to normalize the position.
    orientation: Orientation index of the device (1 is LANDSCAPE_90).

  Returns:
    A dict with an int32 `action_type` and a normalized `touch_position`.
  """
  act_type = action_type.ActionType.LIFT
  # pygame numbers the left mouse button as 1.
  if event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
    act_type = action_type.ActionType.TOUCH
  return {
      'action_type':
          np.array(act_type, dtype=np.int32),
      'touch_position':
          _scale_position(event.pos, screen, orientation),
  }
def _get_action_from_mouse(
    screen: pygame.Surface, orientation: int
) -> dict[str, Any]:
  """Builds an action from the live mouse state.

  A held left button means TOUCH; otherwise the action is LIFT. In both cases
  the cursor position is normalized with `_scale_position`.
  """
  left_button_down = pygame.mouse.get_pressed()[0]
  if left_button_down:
    act_type = action_type.ActionType.TOUCH
  else:
    act_type = action_type.ActionType.LIFT
  encoded_type = np.array(act_type, dtype=np.int32)
  position = _scale_position(pygame.mouse.get_pos(), screen, orientation)
  return {'action_type': encoded_type, 'touch_position': position}
def _scale_position(position: np.ndarray, screen: pygame.Surface,
                    orientation: int) -> np.ndarray:
  """Converts a pixel position on the pygame screen to float coordinates.

  AndroidEnv expects touch positions as floats, so the pixel position is
  divided by the screen size. For LANDSCAPE_90 (orientation 1) the two axes
  are swapped and the new first coordinate is mirrored (x -> 1 - x).

  NOTE(review): only orientation 1 is remapped here, although the game loop
  also swaps the canvas size for LANDSCAPE_270 (3). Confirm whether 3 (and
  PORTRAIT_180) need their own mapping.
  """
  normalized = np.divide(position, screen.get_size(), dtype=np.float32)
  if orientation != 1:  # Not LANDSCAPE_90: pass through unchanged.
    return normalized
  swapped = normalized[::-1]
  swapped[0] = 1 - swapped[0]
  return swapped
def _accumulate_reward(
    timestep: dm_env.TimeStep,
    episode_return: float) -> float:
  """Adds the step's reward to the running episode return.

  Nonzero rewards are logged. The running total restarts at 0 on the first
  step of an episode and is logged on the last step.

  Args:
    timestep: The timestep returned by the environment.
    episode_return: The return accumulated so far in this episode.

  Returns:
    The updated episode return.
  """
  reward = timestep.reward
  if reward and reward != 0:
    logging.info('Reward: %s', reward)
    episode_return += reward
  if timestep.first():
    return 0
  if timestep.last():
    logging.info('Episode return: %s', episode_return)
  return episode_return
def _render_pygame_frame(surface: pygame.Surface, screen: pygame.Surface,
                         orientation: int, timestep: dm_env.TimeStep) -> None:
  """Draws the latest pixel observation to the window and flips the display.

  Args:
    surface: Off-screen surface sized like the (oriented) observation.
    screen: The visible pygame display surface.
    orientation: Orientation index forwarded to `pixel_fns.orient_pixels`.
    timestep: Timestep whose `observation['pixels']` is rendered.
  """
  # Keep the first three channels (RGB) of the (H x W x C) observation.
  rgb_hwc = timestep.observation['pixels'][:, :, :3]
  # pygame surfaces are indexed (W x H x C), hence the transpose.
  oriented_whc = pixel_fns.orient_pixels(
      pixel_fns.transpose_pixels(rgb_hwc), orientation)
  pygame.surfarray.blit_array(surface, oriented_whc)
  pygame.transform.smoothscale(surface, screen.get_size(), screen)
  pygame.display.flip()
def main(_):
  """Runs an interactive pygame session that drives AndroidEnv with the mouse.

  The session ends when Escape is held or the window is closed.
  """
  pygame.init()
  pygame.display.set_caption('android_human_agent')

  config = config_classes.AndroidEnvConfig(
      task=config_classes.FilesystemTaskConfig(path=FLAGS.task_path),
      simulator=config_classes.EmulatorConfig(
          emulator_launcher=config_classes.EmulatorLauncherConfig(
              emulator_path=FLAGS.emulator_path,
              android_sdk_root=FLAGS.android_sdk_root,
              android_avd_home=FLAGS.android_avd_home,
              avd_name=FLAGS.avd_name,
              run_headless=FLAGS.run_headless,
          ),
          adb_controller=config_classes.AdbControllerConfig(
              adb_path=FLAGS.adb_path
          ),
      ),
  )

  with loader.load(config) as env:
    # Reset environment.
    first_timestep = env.reset()
    orientation = np.argmax(first_timestep.observation['orientation'])

    # Create pygame canvas.
    screen_size = list(map(int, FLAGS.screen_size))  # (W x H)
    obs_shape = env.observation_spec()['pixels'].shape[:2]  # (H x W)
    if orientation == 1 or orientation == 3:  # LANDSCAPE_90 | LANDSCAPE_270
      screen_size = screen_size[::-1]
      obs_shape = obs_shape[::-1]
    screen = pygame.display.set_mode(screen_size)  # takes (W x H)
    surface = pygame.Surface(obs_shape[::-1])  # takes (W x H)

    # Start game loop.
    prev_frame = time.time()
    episode_return = 0
    while True:
      if pygame.key.get_pressed()[pygame.K_ESCAPE]:
        return

      all_events = pygame.event.get()
      for event in all_events:
        if event.type == pygame.QUIT:
          return

      # Filter event queue for mouse click events.
      mouse_click_events = [
          event for event in all_events
          if event.type in [pygame.MOUSEBUTTONDOWN, pygame.MOUSEBUTTONUP]
      ]

      # Process all mouse click events.
      for event in mouse_click_events:
        action = _get_action_from_event(event, screen, orientation)
        timestep = env.step(action)
        episode_return = _accumulate_reward(timestep, episode_return)
        _render_pygame_frame(surface, screen, orientation, timestep)

      # Sample the current position of the mouse either way.
      action = _get_action_from_mouse(screen, orientation)
      timestep = env.step(action)
      episode_return = _accumulate_reward(timestep, episode_return)
      _render_pygame_frame(surface, screen, orientation, timestep)

      # Limit framerate. Sleep for whatever remains of the frame period, then
      # stamp the frame *after* sleeping. Stamping with the pre-sleep time
      # would charge the sleep to the next frame, and every other frame would
      # skip the throttle entirely.
      frame_time = time.time() - prev_frame
      if frame_time < FLAGS.frame_rate:
        time.sleep(FLAGS.frame_rate - frame_time)
      prev_frame = time.time()


if __name__ == '__main__':
  logging.set_verbosity('info')
  logging.set_stderrthreshold('info')
  flags.mark_flags_as_required(['avd_name', 'task_path'])
  app.run(main)
================================================
FILE: examples/run_random_agent.py
================================================
# coding=utf-8
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Example script demonstrating usage of AndroidEnv."""
from absl import app
from absl import flags
from absl import logging
from android_env import loader
from android_env.components import config_classes
from dm_env import specs
import numpy as np
FLAGS = flags.FLAGS
# Simulator args.
flags.DEFINE_string('avd_name', None, 'Name of AVD to use.')
flags.DEFINE_string('android_avd_home', '~/.android/avd', 'Path to AVD.')
flags.DEFINE_string('android_sdk_root', '~/Android/Sdk', 'Path to SDK.')
flags.DEFINE_string('emulator_path',
                    '~/Android/Sdk/emulator/emulator', 'Path to emulator.')
flags.DEFINE_string('adb_path',
                    '~/Android/Sdk/platform-tools/adb', 'Path to ADB.')
# The old help text ("Whether to display the emulator window.") described the
# opposite of the flag: setting run_headless=True hides the window.
flags.DEFINE_bool('run_headless', False,
                  'Whether to run the emulator without a display window.')
# Environment args.
flags.DEFINE_string('task_path', None, 'Path to task textproto file.')
# Experiment args.
flags.DEFINE_integer('num_steps', 1000, 'Number of steps to take.')
def main(_):
  """Runs `FLAGS.num_steps` uniformly random actions against AndroidEnv."""
  task_config = config_classes.FilesystemTaskConfig(path=FLAGS.task_path)
  launcher_config = config_classes.EmulatorLauncherConfig(
      emulator_path=FLAGS.emulator_path,
      android_sdk_root=FLAGS.android_sdk_root,
      android_avd_home=FLAGS.android_avd_home,
      avd_name=FLAGS.avd_name,
      run_headless=FLAGS.run_headless,
  )
  adb_config = config_classes.AdbControllerConfig(adb_path=FLAGS.adb_path)
  config = config_classes.AndroidEnvConfig(
      task=task_config,
      simulator=config_classes.EmulatorConfig(
          emulator_launcher=launcher_config,
          adb_controller=adb_config,
      ),
  )

  with loader.load(config) as env:
    action_spec = env.action_spec()

    def get_random_action() -> dict[str, np.ndarray]:
      """Samples one value for every entry of the action spec.

      Discrete entries get a random integer in [0, num_values); all other
      entries get uniform floats in [0, 1) of the spec's shape and dtype.
      """
      return {
          key: (
              np.random.randint(low=0, high=leaf.num_values, dtype=leaf.dtype)
              if isinstance(leaf, specs.DiscreteArray)
              else np.random.random(size=leaf.shape).astype(leaf.dtype)
          )
          for key, leaf in action_spec.items()
      }

    env.reset()
    for step in range(FLAGS.num_steps):
      action = get_random_action()
      timestep = env.step(action=action)
      logging.info('Step %r, action: %r, reward: %r',
                   step, action, timestep.reward)


if __name__ == '__main__':
  # Surface INFO-level logs on stderr.
  logging.set_verbosity('info')
  logging.set_stderrthreshold('info')
  flags.mark_flags_as_required(['avd_name', 'task_path'])
  app.run(main)
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = [
"setuptools",
"wheel"
]
build-backend = "setuptools.build_meta"
[project]
name = "android-env"
version = "1.2.2"
description = "AndroidEnv environment and library for training agents."
authors = [{name = "DeepMind"}]
license = {file = "LICENSE"}
readme = {text = "Read the README at https://github.com/deepmind/android_env for more information.", content-type = "text/plain"}
keywords = ["Android", "OS", "reinforcement-learning"]
requires-python = ">=3.10"
dependencies = [
"absl-py>=0.1.0",
"dm_env",
"grpcio",
"numpy>=1.21",
"portpicker>=1.2.0",
"protobuf>=2.6",
"pygame",
]
[project.optional-dependencies]
testing = [
"gym",
"pillow",
"pytype",
]
acme = ["dm-acme"]
gym = ["gym"]
[project.urls]
repository = "https://github.com/deepmind/android_env"
deepmind = "https://www.deepmind.com/publications/androidenv-the-android-learning-environment"
arxiv = "https://arxiv.org/abs/2105.13231"
================================================
FILE: setup.py
================================================
# Copyright 2026 DeepMind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Simple package definition for using with `pip`."""
import importlib
import os
import setuptools
from setuptools import find_packages
from setuptools import setup
from setuptools.command.build_ext import build_ext
from setuptools.command.build_py import build_py
_ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
# Proto definitions whose Python bindings are generated at build time, in
# compilation order. Each entry is a path relative to `_ROOT_DIR`.
_ANDROID_ENV_PROTOS = tuple(
    f'android_env/proto/{stem}.proto'
    for stem in (
        'adb',
        'emulator_controller',
        'snapshot',
        'snapshot_service',
        'state',
        'task',
        'a11y/a11y',
        'a11y/android_accessibility_action',
        'a11y/android_accessibility_forest',
        'a11y/android_accessibility_node_info',
        'a11y/android_accessibility_node_info_clickable_span',
        'a11y/android_accessibility_tree',
        'a11y/android_accessibility_window_info',
        'a11y/rect',
    )
)
class _GenerateProtoFiles(setuptools.Command):
  """Command to generate protobuf bindings for AndroidEnv protos."""

  # setuptools reads the singular `description` attribute (e.g. for
  # `--help-commands`); the previous `descriptions` spelling was ignored.
  description = 'Generates Python protobuf bindings for AndroidEnv protos.'
  user_options = []

  def initialize_options(self):
    pass

  def finalize_options(self):
    pass

  def run(self):
    """Runs protoc on every file in `_ANDROID_ENV_PROTOS`.

    Raises:
      RuntimeError: if protoc fails for any proto file.
    """
    # Import grpc_tools here, after setuptools has installed setup_requires
    # dependencies.
    from grpc_tools import protoc  # pylint: disable=g-import-not-at-top
    # A bare `import importlib` does not guarantee that the
    # `importlib.resources` submodule is loaded, so import it explicitly.
    import importlib.resources  # pylint: disable=g-import-not-at-top

    with importlib.resources.as_file(
        importlib.resources.files('grpc_tools').joinpath('_proto')
    ) as path:
      # Compile while still inside the `with` block: `as_file` may return a
      # temporary copy that is removed when the context exits.
      grpc_protos_include = str(path)
      for proto_path in _ANDROID_ENV_PROTOS:
        proto_args = [
            'grpc_tools.protoc',
            '--proto_path={}'.format(grpc_protos_include),
            '--proto_path={}'.format(_ROOT_DIR),
            '--python_out={}'.format(_ROOT_DIR),
            '--pyi_out={}'.format(_ROOT_DIR),
            '--grpc_python_out={}'.format(_ROOT_DIR),
            os.path.join(_ROOT_DIR, proto_path),
        ]
        if protoc.main(proto_args) != 0:
          raise RuntimeError('ERROR: {}'.format(proto_args))
class _BuildExt(build_ext):
  """`build_ext` variant that regenerates protobuf bindings first."""

  def run(self):
    # Bindings must exist before the regular build_ext step runs.
    self.run_command('generate_protos')
    super().run()
class _BuildPy(build_py):
  """`build_py` variant that regenerates protobuf bindings first."""

  def run(self):
    # Bindings must exist before the Python sources are copied into the build.
    self.run_command('generate_protos')
    super().run()
setup(
    packages=find_packages(exclude=['examples']),
    # Copy protobuf source files. The a11y protos listed in
    # `_ANDROID_ENV_PROTOS` live in `proto/a11y/`, which `proto/*.proto`
    # alone does not match.
    package_data={'': ['proto/*.proto', 'proto/a11y/*.proto']},
    include_package_data=True,
    setup_requires=['grpcio-tools'],
    cmdclass={
        'build_ext': _BuildExt,
        'build_py': _BuildPy,
        'generate_protos': _GenerateProtoFiles,
    },
)