diff --git a/android_env/wrappers/swipe_action_wrapper.py b/android_env/wrappers/swipe_action_wrapper.py new file mode 100644 index 0000000..1faefc1 --- /dev/null +++ b/android_env/wrappers/swipe_action_wrapper.py @@ -0,0 +1,146 @@ +# coding=utf-8 +# Copyright 2026 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Wraps the AndroidEnv environment to provide swipe actions.""" + +from collections.abc import Mapping, Sequence +from typing import Any, cast + +from android_env import env_interface +from android_env.components import action_type +from android_env.wrappers import base_wrapper +import dm_env +from dm_env import specs +import numpy as np + + +class SwipeActionWrapper(base_wrapper.BaseWrapper): + """AndroidEnv with swipe actions. + + Converts a single swipe action (start position, end position) into a sequence + of TOUCH steps with linearly interpolated positions, followed by a LIFT at the + end position. + """ + + def __init__( + self, + env: env_interface.AndroidEnvInterface, + *, + num_steps: int = 10, + ) -> None: + """Initializes the instance. + + Args: + env: The underlying environment. + num_steps: The number of TOUCH steps used to interpolate between the start + and end positions. + """ + super().__init__(env) + self._assert_base_env() + if num_steps < 1: + raise ValueError(f'num_steps must be >= 1, got {num_steps}.') + self._num_steps = num_steps + self._env_steps = 0 + self._touch_position_spec = cast( + specs.BoundedArray, self._env.action_spec()['touch_position'] + ) + self._action_type_dtype = self._env.action_spec()['action_type'].dtype + + def _assert_base_env(self) -> None: + parent_action_spec = self._env.action_spec() + assert len(parent_action_spec) == 2 + assert not parent_action_spec['action_type'].shape + assert parent_action_spec['touch_position'].shape == (2,) + + def stats(self) -> dict[str, Any]: + """Returns a dictionary of metrics logged by the environment.""" + logs = self._env.stats() + logs.update({'env_steps': self._env_steps}) + return logs + + def _process_action( + self, action: Mapping[str, np.ndarray] + ) -> Sequence[dict[str, np.ndarray]]: + start = np.asarray(action['start_position'], dtype=np.float32) + end = np.asarray(action['end_position'], dtype=np.float32) + touch_dtype = self._touch_position_spec.dtype + + alphas = np.linspace(0.0, 1.0, self._num_steps, dtype=np.float32) + positions = start + alphas[:, np.newaxis] * (end - start) + + actions = [] + for position in positions: + actions.append({ + 'action_type': np.array(action_type.ActionType.TOUCH).astype( + self._action_type_dtype + ), + 'touch_position': position.astype(touch_dtype), + }) + + actions.append({ + 'action_type': np.array(action_type.ActionType.LIFT).astype( + self._action_type_dtype + ), + 'touch_position': end.astype(touch_dtype), + }) + return actions + + def step(self, action: Mapping[str, np.ndarray]) -> dm_env.TimeStep: + """Takes a step in the environment.""" + actions = self._process_action(action) + total_reward = 0.0 + reward_discount = 1.0 + step_type = dm_env.StepType.MID + discount = None + observation = None + for sub_action in actions: + step_type, reward, discount, observation = self._env.step(sub_action) + self._env_steps += 1 + if reward is not None: + total_reward += reward_discount * reward + if discount is not None: + reward_discount *= discount + if step_type == dm_env.StepType.LAST: + return dm_env.TimeStep( + step_type=step_type, + reward=total_reward, + discount=discount, + observation=observation, + ) + return dm_env.TimeStep( + step_type=step_type, + reward=total_reward, + discount=discount, + observation=observation, + ) + + def action_spec(self) -> dict[str, specs.Array]: + touch_spec = self._touch_position_spec + return { + 'start_position': specs.BoundedArray( + shape=(2,), + dtype=touch_spec.dtype, + minimum=touch_spec.minimum, + maximum=touch_spec.maximum, + name='start_position', + ), + 'end_position': specs.BoundedArray( + shape=(2,), + dtype=touch_spec.dtype, + minimum=touch_spec.minimum, + maximum=touch_spec.maximum, + name='end_position', + ), + } diff --git a/android_env/wrappers/swipe_action_wrapper_test.py b/android_env/wrappers/swipe_action_wrapper_test.py new file mode 100644 index 0000000..abe18b0 --- /dev/null +++ b/android_env/wrappers/swipe_action_wrapper_test.py @@ -0,0 +1,246 @@ +# coding=utf-8 +# Copyright 2026 DeepMind Technologies Limited. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest import mock + +from absl.testing import absltest +from absl.testing import parameterized +from android_env import env_interface +from android_env.components import action_type +from android_env.wrappers import swipe_action_wrapper +import dm_env +from dm_env import specs +import numpy as np + + +def _make_array_spec(shape, dtype, name): + return specs.BoundedArray( + name=name, + shape=shape, + dtype=dtype, + minimum=np.zeros(shape), + maximum=np.ones(shape), + ) + + +class SwipeActionWrapperTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self._base_action_spec = { + 'action_type': specs.DiscreteArray(num_values=3, name='action_type'), + 'touch_position': _make_array_spec( + shape=(2,), dtype=np.float32, name='touch_position' + ), + } + self.base_env = mock.create_autospec(env_interface.AndroidEnvInterface) + self.base_env.action_spec.return_value = self._base_action_spec + + def test_process_action_interpolation(self): + num_steps = 5 + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=num_steps + ) + start = np.array([0.0, 0.0], dtype=np.float32) + end = np.array([1.0, 1.0], dtype=np.float32) + action = {'start_position': start, 'end_position': end} + + actions = wrapped_env._process_action(action) + self.assertLen(actions, num_steps + 1) + + expected_alphas = [0.0, 0.25, 0.5, 0.75, 1.0] + for i, alpha in enumerate(expected_alphas): + expected_position = start * (1.0 - alpha) + end * alpha + np.testing.assert_allclose( + actions[i]['touch_position'], expected_position, rtol=1e-6 + ) + self.assertEqual( + actions[i]['action_type'], action_type.ActionType.TOUCH + ) + + self.assertEqual(actions[-1]['action_type'], action_type.ActionType.LIFT) + np.testing.assert_allclose(actions[-1]['touch_position'], end, rtol=1e-6) + + def test_process_action_single_step(self): + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=1 + ) + start = np.array([0.2, 0.3], dtype=np.float32) + end = np.array([0.8, 0.9], dtype=np.float32) + actions = wrapped_env._process_action({ + 'start_position': start, + 'end_position': end, + }) + + self.assertLen(actions, 2) + np.testing.assert_allclose(actions[0]['touch_position'], start, rtol=1e-6) + self.assertEqual(actions[0]['action_type'], action_type.ActionType.TOUCH) + np.testing.assert_allclose(actions[1]['touch_position'], end, rtol=1e-6) + self.assertEqual(actions[1]['action_type'], action_type.ActionType.LIFT) + + def test_invalid_num_steps(self): + with self.assertRaisesRegex(ValueError, 'num_steps must be >= 1'): + swipe_action_wrapper.SwipeActionWrapper(self.base_env, num_steps=0) + + def test_reset(self): + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=5 + ) + fake_timestep = 'ts' + self.base_env.reset.return_value = fake_timestep + ts = wrapped_env.reset() + self.base_env.reset.assert_called_once() + self.assertEqual(fake_timestep, ts) + + def test_step(self): + num_steps = 5 + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=num_steps + ) + fake_timestep = dm_env.TimeStep( + step_type='fake_type', reward=0.0, discount=1.0, observation='fake_obs' + ) + self.base_env.step.return_value = fake_timestep + self.base_env.stats.return_value = {} + + ts = wrapped_env.step({ + 'start_position': np.array([0.0, 0.0], dtype=np.float32), + 'end_position': np.array([1.0, 1.0], dtype=np.float32), + }) + stats = wrapped_env.stats() + + self.assertEqual(num_steps + 1, self.base_env.step.call_count) + self.assertIsInstance(ts, dm_env.TimeStep) + self.assertIsInstance(stats, dict) + self.assertIn('env_steps', stats) + self.assertEqual(stats['env_steps'], num_steps + 1) + + def test_step_terminal(self): + num_steps = 5 + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=num_steps + ) + normal_timestep = dm_env.TimeStep( + step_type=dm_env.StepType.MID, + reward=0.0, + discount=1.0, + observation='fake_obs', + ) + terminal_timestep = dm_env.TimeStep( + step_type=dm_env.StepType.LAST, + reward=1.0, + discount=0.0, + observation='final_obs', + ) + self.base_env.step.side_effect = [ + normal_timestep, + normal_timestep, + terminal_timestep, + ] + self.base_env.stats.return_value = {} + + ts = wrapped_env.step({ + 'start_position': np.array([0.0, 0.5], dtype=np.float32), + 'end_position': np.array([1.0, 0.5], dtype=np.float32), + }) + + self.assertEqual(3, self.base_env.step.call_count) + self.assertIs(ts.step_type, dm_env.StepType.LAST) + self.assertEqual(ts.reward, 1.0) + + def test_step_accumulates_reward(self): + num_steps = 3 + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=num_steps + ) + self.base_env.step.side_effect = [ + dm_env.TimeStep(dm_env.StepType.MID, 0.1, 1.0, 'obs'), + dm_env.TimeStep(dm_env.StepType.MID, 0.2, 1.0, 'obs'), + dm_env.TimeStep(dm_env.StepType.MID, 0.3, 1.0, 'obs'), + dm_env.TimeStep(dm_env.StepType.MID, 0.4, 1.0, 'obs'), + ] + self.base_env.stats.return_value = {} + + ts = wrapped_env.step({ + 'start_position': np.array([0.0, 0.0], dtype=np.float32), + 'end_position': np.array([1.0, 0.0], dtype=np.float32), + }) + + self.assertAlmostEqual(ts.reward, 1.0) + + def test_step_accumulates_discounted_reward(self): + num_steps = 3 + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=num_steps + ) + self.base_env.step.side_effect = [ + dm_env.TimeStep(dm_env.StepType.MID, 1.0, 0.5, 'obs'), + dm_env.TimeStep(dm_env.StepType.MID, 1.0, 0.5, 'obs'), + dm_env.TimeStep(dm_env.StepType.MID, 1.0, 0.5, 'obs'), + dm_env.TimeStep(dm_env.StepType.MID, 1.0, 0.5, 'obs'), + ] + self.base_env.stats.return_value = {} + + ts = wrapped_env.step({ + 'start_position': np.array([0.0, 0.0], dtype=np.float32), + 'end_position': np.array([1.0, 0.0], dtype=np.float32), + }) + + # 1.0 + 0.5 + 0.25 + 0.125 = 1.875 + self.assertAlmostEqual(ts.reward, 1.875) + + def test_observation_spec(self): + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=5 + ) + fake_obs_spec = 'fake_obs_spec' + self.base_env.observation_spec.return_value = fake_obs_spec + observation_spec = wrapped_env.observation_spec() + self.base_env.observation_spec.assert_called_once() + self.assertEqual(fake_obs_spec, observation_spec) + + def test_action_spec(self): + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=5 + ) + action_spec = wrapped_env.action_spec() + self.assertCountEqual( + action_spec.keys(), ['start_position', 'end_position'] + ) + for key in ('start_position', 'end_position'): + spec = action_spec[key] + self.assertIsInstance(spec, specs.BoundedArray) + self.assertEqual(spec.shape, (2,)) + self.assertEqual(spec.dtype, np.float32) + + def test_stats(self): + self.base_env.stats.return_value = { + 'some_key': 12345, + 'another_key': 5.4321, + } + wrapped_env = swipe_action_wrapper.SwipeActionWrapper( + self.base_env, num_steps=5 + ) + + stats = wrapped_env.stats() + + self.assertIn('some_key', stats) + self.assertEqual(stats['some_key'], 12345) + self.assertIn('env_steps', stats) + self.assertEqual(stats['env_steps'], 0) + + +if __name__ == '__main__': + absltest.main()