diff --git a/src/executorlib/standalone/hdf.py b/src/executorlib/standalone/hdf.py index dffc5d4b2..daac1e467 100644 --- a/src/executorlib/standalone/hdf.py +++ b/src/executorlib/standalone/hdf.py @@ -1,3 +1,4 @@ +import contextlib import os from concurrent.futures import Future from time import sleep @@ -34,10 +35,11 @@ def dump(file_name: Optional[str], data_dict: dict) -> None: with h5py.File(file_name_abs, "a") as fname: for data_key, data_value in data_dict.items(): if data_key in group_dict: - fname.create_dataset( - name="/" + group_dict[data_key], - data=np.void(cloudpickle.dumps(data_value)), - ) + with contextlib.suppress(ValueError): + fname.create_dataset( + name="/" + group_dict[data_key], + data=np.void(cloudpickle.dumps(data_value)), + ) def load(file_name: str) -> dict: diff --git a/src/executorlib/task_scheduler/file/shared.py b/src/executorlib/task_scheduler/file/shared.py index e18a1bcca..ae7b43b71 100644 --- a/src/executorlib/task_scheduler/file/shared.py +++ b/src/executorlib/task_scheduler/file/shared.py @@ -91,6 +91,7 @@ def execute_tasks_h5( process_dict: dict = {} cache_dir_dict: dict = {} file_name_dict: dict = {} + duplicate_dict: dict = {} while True: task_dict = None with contextlib.suppress(queue.Empty): @@ -101,6 +102,7 @@ def execute_tasks_h5( cancel_futures=task_dict.get("cancel_futures", False), memory_dict=memory_dict, process_dict=process_dict, + duplicate_dict=duplicate_dict, cache_dir_dict=cache_dir_dict, terminate_function=terminate_function, pysqa_config_directory=pysqa_config_directory, @@ -175,12 +177,17 @@ def execute_tasks_h5( process_dict[task_key] = queue_id memory_dict[task_key] = task_dict["future"] cache_dir_dict[task_key] = cache_directory + elif memory_dict[task_key] != task_dict["future"]: + if task_key not in duplicate_dict: + duplicate_dict[task_key] = [] + duplicate_dict[task_key].append(task_dict["future"]) future_queue.task_done() else: memory_dict = _refresh_memory_dict( memory_dict=memory_dict, cache_dir_dict=cache_dir_dict, process_dict=process_dict, + duplicate_dict=duplicate_dict, terminate_function=terminate_function, pysqa_config_directory=pysqa_config_directory, backend=backend, @@ -189,7 +196,10 @@ def execute_tasks_h5( def _check_task_output( - task_key: str, future_obj: Future, cache_directory: str + task_key: str, + future_obj: Future, + cache_directory: str, + duplicate_dict: Optional[dict] = None, ) -> Future: """ Check the output of a task and set the result of the future object if available. @@ -198,7 +208,7 @@ def _check_task_output( task_key (str): The key of the task. future_obj (Future): The future object associated with the task. cache_directory (str): The directory where the HDF5 files are stored. - + duplicate_dict (dict): The dictionary mapping task keys to their associated duplicate future objects. Returns: Future: The updated future object. @@ -207,11 +217,40 @@ def _check_task_output( if not os.path.exists(file_name): return future_obj exec_flag, no_error_flag, result = get_output(file_name=file_name) + _update_future( + future_obj=future_obj, + exec_flag=exec_flag, + no_error_flag=no_error_flag, + result=result, + ) + if duplicate_dict is not None and task_key in duplicate_dict: + for duplicate_future in duplicate_dict[task_key]: + _update_future( + future_obj=duplicate_future, + exec_flag=exec_flag, + no_error_flag=no_error_flag, + result=result, + ) + del duplicate_dict[task_key] + return future_obj + + +def _update_future( + future_obj: Future, exec_flag: bool, no_error_flag: bool, result: Any +) -> None: + """ + Update the future object with the result of the task execution. + + Args: + future_obj (Future): The future object to be updated. + exec_flag (bool): Flag indicating whether the task has been executed. + no_error_flag (bool): Flag indicating whether the task execution resulted in an error. + result (Any): The result of the task execution. + """ if exec_flag and no_error_flag: future_obj.set_result(result) elif exec_flag: future_obj.set_exception(result) - return future_obj def _convert_args_and_kwargs( @@ -281,6 +320,7 @@ def _refresh_memory_dict( memory_dict: dict, cache_dir_dict: dict, process_dict: dict, + duplicate_dict: Optional[dict] = None, terminate_function: Optional[Callable] = None, pysqa_config_directory: Optional[str] = None, backend: Optional[str] = None, @@ -293,6 +333,7 @@ def _refresh_memory_dict( memory_dict (dict): dictionary with task keys and future objects cache_dir_dict (dict): dictionary with task keys and cache directories process_dict (dict): dictionary with task keys and process reference. + duplicate_dict (dict): dictionary with task keys and duplicate future objects. terminate_function (callable): The function to terminate the tasks. pysqa_config_directory (str): path to the pysqa config directory (only for pysqa based backend). backend (str): name of the backend used to spawn tasks. @@ -315,6 +356,7 @@ def _refresh_memory_dict( task_key=key, future_obj=value, cache_directory=cache_dir_dict[key], + duplicate_dict=duplicate_dict, ) for key, value in memory_dict.items() if not value.done() @@ -400,6 +442,7 @@ def _shutdown_executor( memory_dict: dict, process_dict: dict, cache_dir_dict: dict, + duplicate_dict: Optional[dict] = None, terminate_function: Optional[Callable] = None, pysqa_config_directory: Optional[str] = None, backend: Optional[str] = None, @@ -421,6 +464,7 @@ def _shutdown_executor( cancel_futures (bool): Whether to cancel futures that have not yet started. memory_dict (dict): Mapping of task keys to their Future objects. process_dict (dict): Mapping of task keys to process handles or queue IDs. + duplicate_dict (dict): Mapping of task keys to lists of duplicate Future objects. cache_dir_dict (dict): Mapping of task keys to the cache directory for each task. terminate_function (Callable, optional): Function used to terminate running processes. pysqa_config_directory (str, optional): Path to the pysqa config directory. @@ -433,6 +477,7 @@ def _shutdown_executor( memory_dict=memory_dict, cache_dir_dict=cache_dir_dict, process_dict=process_dict, + duplicate_dict=duplicate_dict, terminate_function=terminate_function, pysqa_config_directory=pysqa_config_directory, backend=backend, @@ -447,6 +492,7 @@ def _shutdown_executor( memory_dict=memory_dict, cache_dir_dict=cache_dir_dict, process_dict=process_dict, + duplicate_dict=duplicate_dict, terminate_function=terminate_function, pysqa_config_directory=pysqa_config_directory, backend=backend, @@ -465,6 +511,7 @@ def _shutdown_executor( memory_dict=memory_dict, cache_dir_dict=cache_dir_dict, process_dict=process_dict, + duplicate_dict=duplicate_dict, terminate_function=terminate_function, pysqa_config_directory=pysqa_config_directory, backend=backend, diff --git a/tests/unit/executor/test_api.py b/tests/unit/executor/test_api.py index 160c79f69..5fadc6ebd 100644 --- a/tests/unit/executor/test_api.py +++ b/tests/unit/executor/test_api.py @@ -156,6 +156,17 @@ def test_executor_dependency_plot(self): self.assertEqual(len(nodes), 4) self.assertEqual(len(edges), 4) + def test_duplicate_futures(self): + with TestClusterExecutor(cache_directory="cache_dir") as exe: + cloudpickle_register(ind=1) + future_1 = exe.submit(add_with_sleep, 1, parameter_2=2) + future_2 = exe.submit(add_with_sleep, 1, parameter_2=2) + self.assertFalse(future_1.done()) + self.assertFalse(future_2.done()) + self.assertEqual(future_1.result(), 3) + self.assertEqual(future_2.result(), 3) + self.assertEqual(len(os.listdir("cache_dir")), 1) + def test_shutdown_wait_false_cancel_futures_false(self): exe = TestClusterExecutor(cache_directory="shutdown_1_dir") cloudpickle_register(ind=1) @@ -202,6 +213,7 @@ def test_shutdown_executor_function(self): cancel_futures=True, memory_dict=memory_dict, process_dict={}, + duplicate_dict=None, cache_dir_dict={"a": "cache_dir"}, terminate_function=None, pysqa_config_directory=None, diff --git a/tests/unit/executor/test_single_cache.py b/tests/unit/executor/test_single_cache.py index f04bd4dc1..b80619b6e 100644 --- a/tests/unit/executor/test_single_cache.py +++ b/tests/unit/executor/test_single_cache.py @@ -1,6 +1,7 @@ import os import shutil import unittest +from time import sleep from executorlib import SingleNodeExecutor, get_cache_data from executorlib.standalone.serialize import cloudpickle_register @@ -17,6 +18,11 @@ def get_error(a): raise ValueError(a) +def sum_with_wait(a, b): + sleep((a + b) / 10) + return a + b + + class AddClass: def __call__(self, a, b): return a+b @@ -69,6 +75,13 @@ def test_cache_key(self): sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst) ) + def test_cache_duplicate_function(self): + cache_directory = os.path.abspath("cache_duplicate") + with SingleNodeExecutor(hostname_localhost=True, cache_directory=cache_directory) as exe: + f1 = exe.submit(sum_with_wait, 1, 1) + f2 = exe.submit(sum_with_wait, 1, 1) + self.assertEqual(f1.result(), f2.result()) + def test_cache_error(self): cache_directory = os.path.abspath("cache_error") with SingleNodeExecutor(cache_directory=cache_directory) as exe: @@ -93,3 +106,4 @@ def test_cache_error_file(self): def tearDown(self): shutil.rmtree("executorlib_cache", ignore_errors=True) shutil.rmtree("cache_error", ignore_errors=True) + shutil.rmtree("cache_duplicate", ignore_errors=True) diff --git a/tests/unit/task_scheduler/file/test_backend.py b/tests/unit/task_scheduler/file/test_backend.py index b000a159c..bf189f7e3 100644 --- a/tests/unit/task_scheduler/file/test_backend.py +++ b/tests/unit/task_scheduler/file/test_backend.py @@ -50,7 +50,7 @@ def test_execute_function_mixed(self): backend_execute_task_in_file(file_name=file_name) future_obj = Future() _check_task_output( - task_key=task_key, future_obj=future_obj, cache_directory=cache_directory + task_key=task_key, future_obj=future_obj, cache_directory=cache_directory, ) self.assertTrue(future_obj.done()) self.assertEqual(future_obj.result(), 3) @@ -101,7 +101,7 @@ def test_execute_function_mixed_selector_convert(self): backend_execute_task_in_file(file_name=file_name_1) f1 = Future() _check_task_output( - task_key=task_key_1, future_obj=f1, cache_directory=cache_directory + task_key=task_key_1, future_obj=f1, cache_directory=cache_directory, ) task_key_2, data_dict = serialize_funct( fn=return_list, @@ -113,7 +113,7 @@ def test_execute_function_mixed_selector_convert(self): backend_execute_task_in_file(file_name=file_name_2) f2 = Future() _check_task_output( - task_key=task_key_2, future_obj=f2, cache_directory=cache_directory + task_key=task_key_2, future_obj=f2, cache_directory=cache_directory, ) fs1 = FutureSelector(future=f1, selector="a") fs2 = FutureSelector(future=f2, selector=1) @@ -143,7 +143,7 @@ def test_execute_function_args(self): backend_execute_task_in_file(file_name=file_name) future_obj = Future() _check_task_output( - task_key=task_key, future_obj=future_obj, cache_directory=cache_directory + task_key=task_key, future_obj=future_obj, cache_directory=cache_directory, ) self.assertTrue(future_obj.done()) self.assertEqual(future_obj.result(), 3) @@ -170,7 +170,7 @@ def test_execute_function_kwargs(self): backend_execute_task_in_file(file_name=file_name) future_obj = Future() _check_task_output( - task_key=task_key, future_obj=future_obj, cache_directory=cache_directory + task_key=task_key, future_obj=future_obj, cache_directory=cache_directory, ) self.assertTrue(future_obj.done()) self.assertEqual(future_obj.result(), 3) @@ -198,7 +198,7 @@ def test_execute_function_error(self): backend_execute_task_in_file(file_name=file_name) future_obj = Future() _check_task_output( - task_key=task_key, future_obj=future_obj, cache_directory=cache_directory + task_key=task_key, future_obj=future_obj, cache_directory=cache_directory, ) self.assertTrue(future_obj.done()) with self.assertRaises(ValueError):