Skip to content
10 changes: 6 additions & 4 deletions src/executorlib/standalone/hdf.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import contextlib
import os
from concurrent.futures import Future
from time import sleep
Expand Down Expand Up @@ -34,10 +35,11 @@ def dump(file_name: Optional[str], data_dict: dict) -> None:
with h5py.File(file_name_abs, "a") as fname:
for data_key, data_value in data_dict.items():
if data_key in group_dict:
fname.create_dataset(
name="/" + group_dict[data_key],
data=np.void(cloudpickle.dumps(data_value)),
)
with contextlib.suppress(ValueError):
fname.create_dataset(
name="/" + group_dict[data_key],
data=np.void(cloudpickle.dumps(data_value)),
)


def load(file_name: str) -> dict:
Expand Down
53 changes: 50 additions & 3 deletions src/executorlib/task_scheduler/file/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def execute_tasks_h5(
process_dict: dict = {}
cache_dir_dict: dict = {}
file_name_dict: dict = {}
duplicate_dict: dict = {}
while True:
task_dict = None
with contextlib.suppress(queue.Empty):
Expand All @@ -101,6 +102,7 @@ def execute_tasks_h5(
cancel_futures=task_dict.get("cancel_futures", False),
memory_dict=memory_dict,
process_dict=process_dict,
duplicate_dict=duplicate_dict,
cache_dir_dict=cache_dir_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
Expand Down Expand Up @@ -175,12 +177,17 @@ def execute_tasks_h5(
process_dict[task_key] = queue_id
memory_dict[task_key] = task_dict["future"]
cache_dir_dict[task_key] = cache_directory
elif memory_dict[task_key] != task_dict["future"]:
if task_key not in duplicate_dict:
duplicate_dict[task_key] = []
duplicate_dict[task_key].append(task_dict["future"])
future_queue.task_done()
else:
memory_dict = _refresh_memory_dict(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
duplicate_dict=duplicate_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
Expand All @@ -189,7 +196,10 @@ def execute_tasks_h5(


def _check_task_output(
task_key: str, future_obj: Future, cache_directory: str
task_key: str,
future_obj: Future,
cache_directory: str,
duplicate_dict: Optional[dict] = None,
) -> Future:
"""
Check the output of a task and set the result of the future object if available.
Expand All @@ -198,7 +208,7 @@ def _check_task_output(
task_key (str): The key of the task.
future_obj (Future): The future object associated with the task.
cache_directory (str): The directory where the HDF5 files are stored.

duplicate_dict (dict): The dictionary mapping task keys to their associated duplicate future objects.
Returns:
Future: The updated future object.

Expand All @@ -207,11 +217,40 @@ def _check_task_output(
if not os.path.exists(file_name):
return future_obj
exec_flag, no_error_flag, result = get_output(file_name=file_name)
_update_future(
future_obj=future_obj,
exec_flag=exec_flag,
no_error_flag=no_error_flag,
result=result,
)
if duplicate_dict is not None and task_key in duplicate_dict:
for duplicate_future in duplicate_dict[task_key]:
_update_future(
future_obj=duplicate_future,
exec_flag=exec_flag,
no_error_flag=no_error_flag,
result=result,
)
del duplicate_dict[task_key]
return future_obj


def _update_future(
    future_obj: Future, exec_flag: bool, no_error_flag: bool, result: Any
) -> Future:
    """
    Update the future object with the result of the task execution.

    The future is left untouched while the task has not been executed yet, and
    also when the future is already done (finished or cancelled), because
    calling ``set_result`` / ``set_exception`` on a done future raises
    ``concurrent.futures.InvalidStateError``.

    Args:
        future_obj (Future): The future object to be updated.
        exec_flag (bool): Flag indicating whether the task has been executed.
        no_error_flag (bool): Flag indicating whether the task finished without raising an error.
        result (Any): The return value of the task, or the exception object to attach to the
                      future when no_error_flag is False.

    Returns:
        Future: The same future object that was passed in.
    """
    # Nothing to publish yet, or the future was already resolved/cancelled elsewhere.
    if not exec_flag or future_obj.done():
        return future_obj
    if no_error_flag:
        future_obj.set_result(result)
    else:
        future_obj.set_exception(result)
    return future_obj


def _convert_args_and_kwargs(
Expand Down Expand Up @@ -281,6 +320,7 @@ def _refresh_memory_dict(
memory_dict: dict,
cache_dir_dict: dict,
process_dict: dict,
duplicate_dict: Optional[dict] = None,
terminate_function: Optional[Callable] = None,
pysqa_config_directory: Optional[str] = None,
backend: Optional[str] = None,
Expand All @@ -293,6 +333,7 @@ def _refresh_memory_dict(
memory_dict (dict): dictionary with task keys and future objects
cache_dir_dict (dict): dictionary with task keys and cache directories
process_dict (dict): dictionary with task keys and process reference.
duplicate_dict (dict): dictionary with task keys and duplicate future objects.
terminate_function (callable): The function to terminate the tasks.
pysqa_config_directory (str): path to the pysqa config directory (only for pysqa based backend).
backend (str): name of the backend used to spawn tasks.
Expand All @@ -315,6 +356,7 @@ def _refresh_memory_dict(
task_key=key,
future_obj=value,
cache_directory=cache_dir_dict[key],
duplicate_dict=duplicate_dict,
)
for key, value in memory_dict.items()
if not value.done()
Expand Down Expand Up @@ -400,6 +442,7 @@ def _shutdown_executor(
memory_dict: dict,
process_dict: dict,
cache_dir_dict: dict,
duplicate_dict: Optional[dict] = None,
terminate_function: Optional[Callable] = None,
pysqa_config_directory: Optional[str] = None,
backend: Optional[str] = None,
Expand All @@ -421,6 +464,7 @@ def _shutdown_executor(
cancel_futures (bool): Whether to cancel futures that have not yet started.
memory_dict (dict): Mapping of task keys to their Future objects.
process_dict (dict): Mapping of task keys to process handles or queue IDs.
duplicate_dict (dict): Mapping of task keys to lists of duplicate Future objects.
cache_dir_dict (dict): Mapping of task keys to the cache directory for each task.
terminate_function (Callable, optional): Function used to terminate running processes.
pysqa_config_directory (str, optional): Path to the pysqa config directory.
Expand All @@ -433,6 +477,7 @@ def _shutdown_executor(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
duplicate_dict=duplicate_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
Expand All @@ -447,6 +492,7 @@ def _shutdown_executor(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
duplicate_dict=duplicate_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
Expand All @@ -465,6 +511,7 @@ def _shutdown_executor(
memory_dict=memory_dict,
cache_dir_dict=cache_dir_dict,
process_dict=process_dict,
duplicate_dict=duplicate_dict,
terminate_function=terminate_function,
pysqa_config_directory=pysqa_config_directory,
backend=backend,
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/executor/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,17 @@ def test_executor_dependency_plot(self):
self.assertEqual(len(nodes), 4)
self.assertEqual(len(edges), 4)

def test_duplicate_futures(self):
"""Submit the same call twice and check both futures resolve to 3 with one cache entry."""
with TestClusterExecutor(cache_directory="cache_dir") as exe:
cloudpickle_register(ind=1)
# Two submissions with the identical function and identical arguments.
future_1 = exe.submit(add_with_sleep, 1, parameter_2=2)
future_2 = exe.submit(add_with_sleep, 1, parameter_2=2)
# Neither future is expected to be finished right after submission.
self.assertFalse(future_1.done())
self.assertFalse(future_2.done())
# Both futures must deliver the same computed value.
self.assertEqual(future_1.result(), 3)
self.assertEqual(future_2.result(), 3)
# The cache directory is expected to contain exactly one entry for both submissions.
self.assertEqual(len(os.listdir("cache_dir")), 1)

def test_shutdown_wait_false_cancel_futures_false(self):
exe = TestClusterExecutor(cache_directory="shutdown_1_dir")
cloudpickle_register(ind=1)
Expand Down Expand Up @@ -202,6 +213,7 @@ def test_shutdown_executor_function(self):
cancel_futures=True,
memory_dict=memory_dict,
process_dict={},
duplicate_dict=None,
cache_dir_dict={"a": "cache_dir"},
terminate_function=None,
pysqa_config_directory=None,
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/executor/test_single_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import shutil
import unittest
from time import sleep

from executorlib import SingleNodeExecutor, get_cache_data
from executorlib.standalone.serialize import cloudpickle_register
Expand All @@ -17,6 +18,11 @@ def get_error(a):
raise ValueError(a)


def sum_with_wait(a, b):
    """Return ``a + b`` after sleeping for one tenth of that sum in seconds."""
    total = a + b
    delay = total / 10
    sleep(delay)
    return total


class AddClass:
def __call__(self, a, b):
return a+b
Expand Down Expand Up @@ -69,6 +75,13 @@ def test_cache_key(self):
sum([sum(c["input_args"][0]) for c in cache_lst]), sum(result_lst)
)

def test_cache_duplicate_function(self):
"""Submit the same call twice with a cache directory and check both futures agree."""
# Absolute path of the cache directory used by this test.
cache_directory = os.path.abspath("cache_duplicate")
with SingleNodeExecutor(hostname_localhost=True, cache_directory=cache_directory) as exe:
# Two submissions with the identical function and identical arguments.
f1 = exe.submit(sum_with_wait, 1, 1)
f2 = exe.submit(sum_with_wait, 1, 1)
# Only equality of the two results is checked, not the concrete value.
self.assertEqual(f1.result(), f2.result())

def test_cache_error(self):
cache_directory = os.path.abspath("cache_error")
with SingleNodeExecutor(cache_directory=cache_directory) as exe:
Expand All @@ -93,3 +106,4 @@ def test_cache_error_file(self):
def tearDown(self):
    """Remove the cache directories used by the tests, ignoring removal errors."""
    for directory in ("executorlib_cache", "cache_error", "cache_duplicate"):
        shutil.rmtree(directory, ignore_errors=True)
12 changes: 6 additions & 6 deletions tests/unit/task_scheduler/file/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_execute_function_mixed(self):
backend_execute_task_in_file(file_name=file_name)
future_obj = Future()
_check_task_output(
task_key=task_key, future_obj=future_obj, cache_directory=cache_directory
task_key=task_key, future_obj=future_obj, cache_directory=cache_directory,
)
self.assertTrue(future_obj.done())
self.assertEqual(future_obj.result(), 3)
Expand Down Expand Up @@ -101,7 +101,7 @@ def test_execute_function_mixed_selector_convert(self):
backend_execute_task_in_file(file_name=file_name_1)
f1 = Future()
_check_task_output(
task_key=task_key_1, future_obj=f1, cache_directory=cache_directory
task_key=task_key_1, future_obj=f1, cache_directory=cache_directory,
)
task_key_2, data_dict = serialize_funct(
fn=return_list,
Expand All @@ -113,7 +113,7 @@ def test_execute_function_mixed_selector_convert(self):
backend_execute_task_in_file(file_name=file_name_2)
f2 = Future()
_check_task_output(
task_key=task_key_2, future_obj=f2, cache_directory=cache_directory
task_key=task_key_2, future_obj=f2, cache_directory=cache_directory,
)
fs1 = FutureSelector(future=f1, selector="a")
fs2 = FutureSelector(future=f2, selector=1)
Expand Down Expand Up @@ -143,7 +143,7 @@ def test_execute_function_args(self):
backend_execute_task_in_file(file_name=file_name)
future_obj = Future()
_check_task_output(
task_key=task_key, future_obj=future_obj, cache_directory=cache_directory
task_key=task_key, future_obj=future_obj, cache_directory=cache_directory,
)
self.assertTrue(future_obj.done())
self.assertEqual(future_obj.result(), 3)
Expand All @@ -170,7 +170,7 @@ def test_execute_function_kwargs(self):
backend_execute_task_in_file(file_name=file_name)
future_obj = Future()
_check_task_output(
task_key=task_key, future_obj=future_obj, cache_directory=cache_directory
task_key=task_key, future_obj=future_obj, cache_directory=cache_directory,
)
self.assertTrue(future_obj.done())
self.assertEqual(future_obj.result(), 3)
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_execute_function_error(self):
backend_execute_task_in_file(file_name=file_name)
future_obj = Future()
_check_task_output(
task_key=task_key, future_obj=future_obj, cache_directory=cache_directory
task_key=task_key, future_obj=future_obj, cache_directory=cache_directory,
)
self.assertTrue(future_obj.done())
with self.assertRaises(ValueError):
Expand Down
Loading