diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py index 211657e688..1c677ea928 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_workflow_run.py @@ -160,6 +160,19 @@ def _resume_to_workflow_responses(resume_payload: Any) -> dict[str, Any]: return responses +def _merge_workflow_response_sources( + resume_responses: dict[str, Any], + message_responses: dict[str, Any], +) -> dict[str, Any]: + """Merge workflow response sources with explicit resume payloads taking precedence.""" + if not resume_responses: + return dict(message_responses) + + responses = dict(message_responses) + responses.update(resume_responses) + return responses + + def _coerce_json_value(value: Any) -> Any: """Parse JSON strings when possible; otherwise return the original value.""" if not isinstance(value, str): @@ -541,8 +554,10 @@ async def run_workflow_stream( last_assistant_text: str | None = None resume_payload = _extract_resume_payload(input_data) - responses = _resume_to_workflow_responses(resume_payload) - responses.update(_extract_responses_from_messages(messages)) + responses = _merge_workflow_response_sources( + _resume_to_workflow_responses(resume_payload), + _extract_responses_from_messages(messages), + ) pending_before_run = await _pending_request_events(workflow) responses = _coerce_responses_for_pending_requests(responses, pending_before_run) pending_interrupts = _interrupts_from_pending_requests(pending_before_run) diff --git a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py index a52cc4dd2c..6c3cb9c94b 100644 --- a/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py +++ b/python/packages/ag-ui/tests/ag_ui/test_workflow_run.py @@ -1352,6 +1352,87 @@ async def handle_approval(self, original_request: Content, response: Content, ct assert not resumed_finished.get("interrupt") +async def test_workflow_run_explicit_resume_overrides_stale_message_approval() -> None: + """Explicit resume payloads should not be overwritten by stale function_approvals in messages.""" + + class ApprovalExecutor(Executor): + def __init__(self) -> None: + super().__init__(id="approval_executor") + + @handler + async def start(self, message: Any, ctx: WorkflowContext) -> None: + del message + function_call = Content.from_function_call( + call_id="refund-call", + name="submit_refund", + arguments={"order_id": "12345", "amount": "$89.99"}, + ) + approval_request = Content.from_function_approval_request(id="approval-1", function_call=function_call) + await ctx.request_info(approval_request, Content, request_id="approval-1") + + @response_handler + async def handle_approval(self, original_request: Content, response: Content, ctx: WorkflowContext) -> None: + del original_request + status = "approved" if bool(response.approved) else "rejected" + await ctx.yield_output(f"Refund {status}.") + + workflow = WorkflowBuilder(start_executor=ApprovalExecutor()).build() + first_events = [ + event async for event in run_workflow_stream({"messages": [{"role": "user", "content": "go"}]}, workflow) + ] + first_finished_events = [event for event in first_events if event.type == "RUN_FINISHED"] + assert len(first_finished_events) == 1 + first_finished = first_finished_events[0].model_dump() + interrupt_payload = first_finished.get("interrupt") + assert isinstance(interrupt_payload, list) + assert len(interrupt_payload) == 1 + interrupt_entry = interrupt_payload[0] + assert isinstance(interrupt_entry, dict) + interrupt_value = interrupt_entry.get("value") + assert isinstance(interrupt_value, dict) + + resumed_events = [ + event + async for event in run_workflow_stream( + { + "messages": [ + { + "role": "user", + "content": "", + "function_approvals": [ + { + "approved": True, + "id": "approval-1", + "call_id": "refund-call", + "name": "submit_refund", + "arguments": {"order_id": "12345", "amount": "$89.99"}, + } + ], + } + ], + "resume": { + "interrupts": [ + { + "id": "approval-1", + "value": { + "type": "function_approval_response", + "approved": False, + "id": interrupt_value.get("id", "approval-1"), + "function_call": interrupt_value.get("function_call"), + }, + } + ] + }, + }, + workflow, + ) + ] + + assistant_text = "".join(event.delta for event in resumed_events if event.type == "TEXT_MESSAGE_CONTENT") + assert "rejected" in assistant_text + assert "approved" not in assistant_text + + async def test_workflow_run_approval_via_messages_denied() -> None: """Denied approval response sent via messages (function_approvals) should satisfy the pending request."""