Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions temporalio/client/_nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,35 @@ async def start_operation(
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[OutputT]: ...

# Overload for temporal_operation methods
@overload
@abstractmethod
async def start_operation(
self,
operation: Callable[
[
NexusServiceType,
temporalio.nexus.TemporalNexusStartOperationContext,
temporalio.nexus.TemporalNexusClient,
InputT,
],
Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
summary: str | None = None,
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> NexusOperationHandle[OutputT]: ...

@abstractmethod
async def start_operation(
self,
Expand Down Expand Up @@ -804,6 +833,35 @@ async def execute_operation(
rpc_timeout: timedelta | None = None,
) -> OutputT: ...

# Overload for temporal_operation methods
@overload
@abstractmethod
async def execute_operation(
self,
operation: Callable[
[
NexusServiceType,
temporalio.nexus.TemporalNexusStartOperationContext,
temporalio.nexus.TemporalNexusClient,
InputT,
],
Awaitable[temporalio.nexus.TemporalOperationResult[OutputT]],
],
arg: InputT,
*,
id: str,
id_reuse_policy: temporalio.common.NexusOperationIDReusePolicy = temporalio.common.NexusOperationIDReusePolicy.ALLOW_DUPLICATE,
id_conflict_policy: temporalio.common.NexusOperationIDConflictPolicy = temporalio.common.NexusOperationIDConflictPolicy.FAIL,
schedule_to_close_timeout: timedelta | None = None,
schedule_to_start_timeout: timedelta | None = None,
start_to_close_timeout: timedelta | None = None,
search_attributes: temporalio.common.TypedSearchAttributes | None = None,
summary: str | None = None,
headers: Mapping[str, str] | None = None,
rpc_metadata: Mapping[str, str | bytes] = {},
rpc_timeout: timedelta | None = None,
) -> OutputT: ...

@abstractmethod
async def execute_operation(
self,
Expand Down
12 changes: 11 additions & 1 deletion temporalio/nexus/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,13 @@
See https://github.com/temporalio/sdk-python/tree/main#nexus
"""

from ._decorators import workflow_run_operation
from ._decorators import temporal_operation, workflow_run_operation
from ._operation_context import (
Info,
LoggerAdapter,
NexusCallback,
TemporalNexusCancelOperationContext,
TemporalNexusStartOperationContext,
WorkflowRunOperationContext,
client,
in_operation,
Expand All @@ -18,6 +20,8 @@
wait_for_worker_shutdown,
wait_for_worker_shutdown_sync,
)
from ._operation_handlers import TemporalNexusOperationHandler
from ._temporal_client import TemporalNexusClient, TemporalOperationResult
from ._token import WorkflowHandle

__all__ = (
Expand All @@ -26,6 +30,8 @@
"LoggerAdapter",
"NexusCallback",
"WorkflowRunOperationContext",
"TemporalNexusCancelOperationContext",
"TemporalNexusStartOperationContext",
"client",
"in_operation",
"info",
Expand All @@ -35,4 +41,8 @@
"wait_for_worker_shutdown",
"wait_for_worker_shutdown_sync",
"WorkflowHandle",
"TemporalNexusClient",
"TemporalNexusOperationHandler",
"TemporalOperationResult",
"temporal_operation",
)
201 changes: 185 additions & 16 deletions temporalio/nexus/_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from collections.abc import Awaitable, Callable
from typing import (
TypeVar,
overload,
)

Expand All @@ -13,26 +12,37 @@
StartOperationContext,
)

from ._operation_context import WorkflowRunOperationContext
from ._operation_handlers import WorkflowRunOperationHandler
from temporalio.nexus._temporal_client import (
TemporalNexusClient,
TemporalOperationResult,
)
from temporalio.types import NexusServiceType

from ._operation_context import (
TemporalNexusStartOperationContext,
WorkflowRunOperationContext,
)
from ._operation_handlers import (
WorkflowRunOperationHandler,
_TemporalNexusOperationHandler,
)
from ._token import WorkflowHandle
from ._util import (
get_callable_name,
get_temporal_operation_start_method_input_and_output_type_annotations,
get_workflow_run_start_method_input_and_output_type_annotations,
set_operation_factory,
)

ServiceHandlerT = TypeVar("ServiceHandlerT")


@overload
def workflow_run_operation(
start: Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
],
) -> Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]: ...

Expand All @@ -44,12 +54,12 @@ def workflow_run_operation(
) -> Callable[
[
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]
],
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
],
]: ...
Expand All @@ -59,26 +69,26 @@ def workflow_run_operation(
start: None
| (
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]
) = None,
*,
name: str | None = None,
) -> (
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]
| Callable[
[
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]
],
Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
],
]
Expand All @@ -87,11 +97,11 @@ def workflow_run_operation(

def decorator(
start: Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
],
) -> Callable[
[ServiceHandlerT, WorkflowRunOperationContext, InputT],
[NexusServiceType, WorkflowRunOperationContext, InputT],
Awaitable[WorkflowHandle[OutputT]],
]:
(
Expand All @@ -100,7 +110,7 @@ def decorator(
) = get_workflow_run_start_method_input_and_output_type_annotations(start)

def operation_handler_factory(
self: ServiceHandlerT,
self: NexusServiceType,
) -> OperationHandler[InputT, OutputT]:
async def _start(
ctx: StartOperationContext, input: InputT
Expand Down Expand Up @@ -130,3 +140,162 @@ async def _start(
return decorator

return decorator(start)


@overload
def temporal_operation(
start: Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
],
) -> Callable[
[NexusServiceType, TemporalNexusStartOperationContext, TemporalNexusClient, InputT],
Awaitable[TemporalOperationResult[OutputT]],
]: ...


@overload
def temporal_operation(
*,
name: str | None = None,
) -> Callable[
[
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]
],
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
],
]: ...


def temporal_operation(
start: None
| (
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]
) = None,
*,
name: str | None = None,
) -> (
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]
| Callable[
[
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]
],
Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
],
]
):
"""Decorator marking a method as the start method for an operation that interacts with Temporal.

.. warning::
This API is experimental and unstable.
"""

def decorator(
start: Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
],
) -> Callable[
[
NexusServiceType,
TemporalNexusStartOperationContext,
TemporalNexusClient,
InputT,
],
Awaitable[TemporalOperationResult[OutputT]],
]:
(
input_type,
output_type,
) = get_temporal_operation_start_method_input_and_output_type_annotations(start)

def operation_handler_factory(
self: NexusServiceType,
) -> OperationHandler[InputT, OutputT]:
async def _start(
ctx: TemporalNexusStartOperationContext,
client: TemporalNexusClient,
input: InputT,
) -> TemporalOperationResult[OutputT]:
return await start(
self,
ctx,
client,
input,
)

_start.__doc__ = start.__doc__
return _TemporalNexusOperationHandler(_start)

method_name = get_callable_name(start)
op = nexusrpc.Operation(
name=name or method_name,
input_type=input_type,
output_type=output_type,
)
op.method_name = method_name
nexusrpc.set_operation(operation_handler_factory, op)

set_operation_factory(start, operation_handler_factory)
return start

if start is None:
return decorator

return decorator(start)
Loading
Loading