Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 15 additions & 21 deletions android_env/components/action_fns.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,31 +88,25 @@ def _prepare_touch_action(
"""

touch_events = []
for i, finger_action in enumerate(_split_touch_action(action, num_fingers)):
is_touch = finger_action['action_type'] == action_type_lib.ActionType.TOUCH
touch_position = finger_action['touch_position']

# First finger
is_touch = action['action_type'] == action_type_lib.ActionType.TOUCH
touch_position = action['touch_position']
touch_pixels = pixel_fns.touch_position_to_pixel_position(
touch_position, width_height=(screen_width, screen_height)
)
touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, 0))

# Subsequent fingers
for i in range(2, num_fingers + 1):
is_touch = action[f'action_type_{i}'] == action_type_lib.ActionType.TOUCH
touch_position = action[f'touch_position_{i}']
touch_pixels = pixel_fns.touch_position_to_pixel_position(
touch_position, width_height=(screen_width, screen_height)
)
touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i))
return touch_events
touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i - 1))


def _split_touch_action(
action: dict[str, np.ndarray], num_fingers: int
) -> list[dict[str, np.ndarray]]:
"""Splits a multitouch action into a list of single-touch actions."""

single_touch_actions = [{
'action_type': action['action_type'],
'touch_position': action['touch_position'],
}]
for i in range(2, num_fingers + 1):
single_touch_actions.append({
'action_type': action[f'action_type_{i}'],
'touch_position': action[f'touch_position_{i}'],
})
return single_touch_actions
return touch_events


def lift_all_fingers_action(num_fingers: int) -> dict[str, np.ndarray]:
Expand Down
101 changes: 101 additions & 0 deletions android_env/components/action_fns_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import timeit
from unittest import mock

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from android_env.components import action_fns
from android_env.components import action_type as action_type_lib
from android_env.components import errors
from android_env.components import pixel_fns
from android_env.components.simulators import base_simulator
import numpy as np

Expand Down Expand Up @@ -183,5 +186,103 @@ def test_lift_all_fingers_action(
np.testing.assert_array_equal(v, output[k])


_RUN_BENCHMARKS = flags.DEFINE_bool(
'run_benchmarks', False, 'Whether to run microbenchmarks.'
)


def _old_split_touch_action(action, num_fingers):
single_touch_actions = [{
'action_type': action['action_type'],
'touch_position': action['touch_position'],
}]
for i in range(2, num_fingers + 1):
single_touch_actions.append({
'action_type': action[f'action_type_{i}'],
'touch_position': action[f'touch_position_{i}'],
})
return single_touch_actions


def _old_prepare_touch_action(action, screen_width, screen_height, num_fingers):
touch_events = []
for i, finger_action in enumerate(
_old_split_touch_action(action, num_fingers)
):
is_touch = finger_action['action_type'] == action_type_lib.ActionType.TOUCH
touch_position = finger_action['touch_position']
touch_pixels = pixel_fns.touch_position_to_pixel_position(
touch_position, width_height=(screen_width, screen_height)
)
touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i))
return touch_events


_BENCHMARK_ACTION_1F = {
'action_type': np.array(action_type_lib.ActionType.TOUCH),
'touch_position': np.array([0.2, 0.5], np.float32),
}

_BENCHMARK_ACTION_3F = {
'action_type': np.array(action_type_lib.ActionType.TOUCH),
'touch_position': np.array([0.2, 0.5], np.float32),
'action_type_2': np.array(action_type_lib.ActionType.LIFT),
'touch_position_2': np.array([0.1, 0.2], np.float32),
'action_type_3': np.array(action_type_lib.ActionType.TOUCH),
'touch_position_3': np.array([0.5, 0.2], np.float32),
}


class ActionFnsBenchmark(parameterized.TestCase):

def test_prepare_touch_action(self):
if not _RUN_BENCHMARKS.value:
self.skipTest('Benchmark disabled')

number = 100000

# 1 finger
t_old_1 = timeit.Timer(
'_old_prepare_touch_action(_BENCHMARK_ACTION_1F, 800, 600, 1)',
globals=globals(),
)
res_old_1 = t_old_1.timeit(number=number)
print(
f'BenchmarkPrepareTouchAction_1f_Old {number}'
f' {res_old_1 / number * 1e9:.0f} ns/op'
)

t_new_1 = timeit.Timer(
'action_fns._prepare_touch_action(_BENCHMARK_ACTION_1F, 800, 600, 1)',
globals=globals(),
)
res_new_1 = t_new_1.timeit(number=number)
print(
f'BenchmarkPrepareTouchAction_1f_New {number}'
f' {res_new_1 / number * 1e9:.0f} ns/op'
)

# 3 fingers
t_old_3 = timeit.Timer(
'_old_prepare_touch_action(_BENCHMARK_ACTION_3F, 800, 600, 3)',
globals=globals(),
)
res_old_3 = t_old_3.timeit(number=number)
print(
f'BenchmarkPrepareTouchAction_3f_Old {number}'
f' {res_old_3 / number * 1e9:.0f} ns/op'
)

t_new_3 = timeit.Timer(
'action_fns._prepare_touch_action(_BENCHMARK_ACTION_3F, 800, 600, 3)',
globals=globals(),
)
res_new_3 = t_new_3.timeit(number=number)
print(
f'BenchmarkPrepareTouchAction_3f_New {number}'
f' {res_new_3 / number * 1e9:.0f} ns/op'
)


if __name__ == '__main__':
absltest.main()
Loading