diff --git a/android_env/components/action_fns.py b/android_env/components/action_fns.py index 0a2e125..162cec0 100644 --- a/android_env/components/action_fns.py +++ b/android_env/components/action_fns.py @@ -88,31 +88,25 @@ def _prepare_touch_action( """ touch_events = [] - for i, finger_action in enumerate(_split_touch_action(action, num_fingers)): - is_touch = finger_action['action_type'] == action_type_lib.ActionType.TOUCH - touch_position = finger_action['touch_position'] + + # First finger + is_touch = action['action_type'] == action_type_lib.ActionType.TOUCH + touch_position = action['touch_position'] + touch_pixels = pixel_fns.touch_position_to_pixel_position( + touch_position, width_height=(screen_width, screen_height) + ) + touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, 0)) + + # Subsequent fingers + for i in range(2, num_fingers + 1): + is_touch = action[f'action_type_{i}'] == action_type_lib.ActionType.TOUCH + touch_position = action[f'touch_position_{i}'] touch_pixels = pixel_fns.touch_position_to_pixel_position( touch_position, width_height=(screen_width, screen_height) ) - touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i)) - return touch_events + touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i - 1)) - -def _split_touch_action( - action: dict[str, np.ndarray], num_fingers: int -) -> list[dict[str, np.ndarray]]: - """Splits a multitouch action into a list of single-touch actions.""" - - single_touch_actions = [{ - 'action_type': action['action_type'], - 'touch_position': action['touch_position'], - }] - for i in range(2, num_fingers + 1): - single_touch_actions.append({ - 'action_type': action[f'action_type_{i}'], - 'touch_position': action[f'touch_position_{i}'], - }) - return single_touch_actions + return touch_events def lift_all_fingers_action(num_fingers: int) -> dict[str, np.ndarray]: diff --git a/android_env/components/action_fns_test.py b/android_env/components/action_fns_test.py index 20b46b0..4f25ad7 100644 --- a/android_env/components/action_fns_test.py +++ b/android_env/components/action_fns_test.py @@ -13,13 +13,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import timeit from unittest import mock +from absl import flags from absl.testing import absltest from absl.testing import parameterized from android_env.components import action_fns from android_env.components import action_type as action_type_lib from android_env.components import errors +from android_env.components import pixel_fns from android_env.components.simulators import base_simulator import numpy as np @@ -183,5 +186,103 @@ def test_lift_all_fingers_action( np.testing.assert_array_equal(v, output[k]) +_RUN_BENCHMARKS = flags.DEFINE_bool( + 'run_benchmarks', False, 'Whether to run microbenchmarks.' +) + + +def _old_split_touch_action(action, num_fingers): + single_touch_actions = [{ + 'action_type': action['action_type'], + 'touch_position': action['touch_position'], + }] + for i in range(2, num_fingers + 1): + single_touch_actions.append({ + 'action_type': action[f'action_type_{i}'], + 'touch_position': action[f'touch_position_{i}'], + }) + return single_touch_actions + + +def _old_prepare_touch_action(action, screen_width, screen_height, num_fingers): + touch_events = [] + for i, finger_action in enumerate( + _old_split_touch_action(action, num_fingers) + ): + is_touch = finger_action['action_type'] == action_type_lib.ActionType.TOUCH + touch_position = finger_action['touch_position'] + touch_pixels = pixel_fns.touch_position_to_pixel_position( + touch_position, width_height=(screen_width, screen_height) + ) + touch_events.append((touch_pixels[0], touch_pixels[1], is_touch, i)) + return touch_events + + +_BENCHMARK_ACTION_1F = { + 'action_type': np.array(action_type_lib.ActionType.TOUCH), + 'touch_position': np.array([0.2, 0.5], np.float32), +} + +_BENCHMARK_ACTION_3F = { + 'action_type': np.array(action_type_lib.ActionType.TOUCH), + 'touch_position': np.array([0.2, 0.5], np.float32), + 'action_type_2': np.array(action_type_lib.ActionType.LIFT), + 'touch_position_2': np.array([0.1, 0.2], np.float32), + 'action_type_3': np.array(action_type_lib.ActionType.TOUCH), + 'touch_position_3': np.array([0.5, 0.2], np.float32), +} + + +class ActionFnsBenchmark(parameterized.TestCase): + + def test_prepare_touch_action(self): + if not _RUN_BENCHMARKS.value: + self.skipTest('Benchmark disabled') + + number = 100000 + + # 1 finger + t_old_1 = timeit.Timer( + '_old_prepare_touch_action(_BENCHMARK_ACTION_1F, 800, 600, 1)', + globals=globals(), + ) + res_old_1 = t_old_1.timeit(number=number) + print( + f'BenchmarkPrepareTouchAction_1f_Old {number}' + f' {res_old_1 / number * 1e9:.0f} ns/op' + ) + + t_new_1 = timeit.Timer( + 'action_fns._prepare_touch_action(_BENCHMARK_ACTION_1F, 800, 600, 1)', + globals=globals(), + ) + res_new_1 = t_new_1.timeit(number=number) + print( + f'BenchmarkPrepareTouchAction_1f_New {number}' + f' {res_new_1 / number * 1e9:.0f} ns/op' + ) + + # 3 fingers + t_old_3 = timeit.Timer( + '_old_prepare_touch_action(_BENCHMARK_ACTION_3F, 800, 600, 3)', + globals=globals(), + ) + res_old_3 = t_old_3.timeit(number=number) + print( + f'BenchmarkPrepareTouchAction_3f_Old {number}' + f' {res_old_3 / number * 1e9:.0f} ns/op' + ) + + t_new_3 = timeit.Timer( + 'action_fns._prepare_touch_action(_BENCHMARK_ACTION_3F, 800, 600, 3)', + globals=globals(), + ) + res_new_3 = t_new_3.timeit(number=number) + print( + f'BenchmarkPrepareTouchAction_3f_New {number}' + f' {res_new_3 / number * 1e9:.0f} ns/op' + ) + + if __name__ == '__main__': absltest.main()