From 451c87b062e0dfc181aff31e02b6db834fb78aa5 Mon Sep 17 00:00:00 2001 From: Andrej Simurka Date: Mon, 8 Jun 2026 08:53:11 +0200 Subject: [PATCH] Wired agent.run_stream_events in streaming_query --- src/app/endpoints/streaming_query.py | 32 +- .../llamastack/__init__.py | 3 +- src/utils/agents/streaming.py | 2 +- src/utils/pydantic_ai.py | 9 +- .../e2e/features/steps/llm_query_response.py | 5 - tests/integration/conftest.py | 122 +++++++- .../test_streaming_query_byok_integration.py | 283 +++++++----------- .../test_streaming_query_integration.py | 24 +- .../app/endpoints/test_streaming_query.py | 50 ++-- 9 files changed, 299 insertions(+), 231 deletions(-) diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index 913909bb7..0a11dea98 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -36,6 +36,7 @@ APIStatusError as LLSApiStatusError, ) from openai._exceptions import APIStatusError as OpenAIAPIStatusError +from typing_extensions import deprecated from authentication import get_auth_dependency from authentication.interface import AuthTuple @@ -74,6 +75,10 @@ from models.common.responses.types import ResponseInput from models.common.turn_summary import TurnSummary from models.config import Action +from utils.agents.streaming import ( + generate_agent_response, + retrieve_agent_response_generator, +) from utils.conversation_compaction import ( CompactionResult, CompactionStartedEvent, @@ -329,7 +334,7 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals media_type=response_media_type, ) - generator, turn_summary = await retrieve_response_generator( + generator, turn_summary = await retrieve_agent_response_generator( responses_params=responses_params, context=context, endpoint_path=endpoint_path, @@ -342,16 +347,21 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals ) return StreamingResponse( - generate_response( + generate_agent_response( generator=generator, context=context, responses_params=responses_params, turn_summary=turn_summary, + background_topic_summary_tasks=_background_topic_summary_tasks, ), media_type=response_media_type, ) +@deprecated( + "Deprecated in favor of utils.agents.streaming.retrieve_agent_response_generator.", + stacklevel=2, +) async def retrieve_response_generator( responses_params: ResponsesApiParams, context: ResponseGeneratorContext, @@ -474,7 +484,7 @@ async def generate_response_with_compaction( request_id=context.request_id, ) - compacted = False + _compacted = False compacted_original_input: Optional[ResponseInput] = None try: async for item in apply_compaction( @@ -491,10 +501,10 @@ async def generate_response_with_compaction( yield stream_compaction_event(context.conversation_id) elif isinstance(item, CompactionResult): responses_params = item.params - compacted = item.compacted + _compacted = item.compacted compacted_original_input = item.original_input - generator, turn_summary = await retrieve_response_generator( + generator, turn_summary = await retrieve_agent_response_generator( responses_params=responses_params, context=context, endpoint_path=endpoint_path, @@ -531,18 +541,22 @@ async def generate_response_with_compaction( # The start event was already emitted above; delegate the rest (re-yield, # finalization, compacted-turn storage) to the shared generator. - async for event in generate_response( + async for event in generate_agent_response( generator, context, responses_params, turn_summary, + background_topic_summary_tasks=_background_topic_summary_tasks, emit_start=False, - compacted=compacted, original_input=compacted_original_input, ): yield event +@deprecated( + "Deprecated in favor of utils.agents.streaming.generate_agent_response.", + stacklevel=2, +) async def generate_response( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-branches,too-many-statements generator: AsyncIterator[str], context: ResponseGeneratorContext, @@ -711,6 +725,10 @@ async def generate_response( # pylint: disable=too-many-arguments,too-many-posi ) +@deprecated( + "Deprecated in favor of utils.agents.streaming.agent_response_generator.", + stacklevel=2, +) async def response_generator( # pylint: disable=too-many-branches,too-many-statements,too-many-locals turn_response: AsyncIterator[OpenAIResponseObjectStream], context: ResponseGeneratorContext, diff --git a/src/pydantic_ai_lightspeed/llamastack/__init__.py b/src/pydantic_ai_lightspeed/llamastack/__init__.py index 47eda1e7d..fac9ee826 100644 --- a/src/pydantic_ai_lightspeed/llamastack/__init__.py +++ b/src/pydantic_ai_lightspeed/llamastack/__init__.py @@ -1,5 +1,6 @@ """Pydantic AI provider for Llama Stack.""" +from pydantic_ai_lightspeed.llamastack._model import LlamaStackResponsesModel from pydantic_ai_lightspeed.llamastack._provider import LlamaStackProvider -__all__ = ["LlamaStackProvider"] +__all__ = ["LlamaStackProvider", "LlamaStackResponsesModel"] diff --git a/src/utils/agents/streaming.py b/src/utils/agents/streaming.py index 07b15f7e1..138852bc2 100644 --- a/src/utils/agents/streaming.py +++ b/src/utils/agents/streaming.py @@ -24,7 +24,6 @@ TextPartDelta, ) -from app.endpoints.streaming_query import shield_violation_generator from configuration import configuration from constants import INTERRUPTED_RESPONSE_MESSAGE, MEDIA_TYPE_JSON from log import get_logger @@ -70,6 +69,7 @@ persist_interrupted_turn, register_interrupt_callback, ) +from utils.streaming_sse import shield_violation_generator AgentDispatchEvent: TypeAlias = AgentStreamEvent | AgentRunResultEvent diff --git a/src/utils/pydantic_ai.py b/src/utils/pydantic_ai.py index c655e67a5..f4a1cf18c 100644 --- a/src/utils/pydantic_ai.py +++ b/src/utils/pydantic_ai.py @@ -7,12 +7,15 @@ from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient from llama_stack_client import AsyncLlamaStackClient from pydantic_ai import Agent, AgentCapability -from pydantic_ai.models.openai import OpenAIResponsesModel, OpenAIResponsesModelSettings +from pydantic_ai.models.openai import OpenAIResponsesModelSettings from pydantic_ai_skills import SkillsCapability from models.common.responses.responses_api_params import ResponsesApiParams from models.config import SkillsConfiguration -from pydantic_ai_lightspeed.llamastack import LlamaStackProvider +from pydantic_ai_lightspeed.llamastack import ( + LlamaStackProvider, + LlamaStackResponsesModel, +) _LLS_RESPONSES_EXTRA_FIELDS: Final[frozenset[str]] = frozenset( { @@ -132,7 +135,7 @@ def build_agent( provider = _llama_stack_provider_from_client(client) settings = _model_settings_from_responses_params(responses_params) - model = OpenAIResponsesModel( + model = LlamaStackResponsesModel( responses_params.model, provider=provider, settings=settings, diff --git a/tests/e2e/features/steps/llm_query_response.py b/tests/e2e/features/steps/llm_query_response.py index 18c76a4cf..b0f992861 100644 --- a/tests/e2e/features/steps/llm_query_response.py +++ b/tests/e2e/features/steps/llm_query_response.py @@ -366,7 +366,6 @@ def _parse_streaming_response(response_text: str) -> dict: full_response = "" full_response_split = [] finished = False - first_token = True stream_error = ( None # {"status_code": int, "response": str, "cause": str} if event "error" ) @@ -380,10 +379,6 @@ def _parse_streaming_response(response_text: str) -> dict: if event == "start": conversation_id = data["data"]["conversation_id"] elif event == "token": - # Skip the first token (shield status message) - if first_token: - first_token = False - continue full_response_split.append(data["data"]["token"]) elif event == "turn_complete": full_response = data["data"]["token"] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 7f3a263da..3833071e4 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,7 +1,7 @@ """Shared fixtures for integration tests.""" import os -from collections.abc import Generator +from collections.abc import AsyncIterator, Generator from pathlib import Path from typing import Any, Optional @@ -10,12 +10,15 @@ from fastapi.testclient import TestClient from llama_stack_api.openai_responses import OpenAIResponseObject from llama_stack_client.types import VersionInfo +from pydantic_ai import AgentRunResultEvent from pydantic_ai.messages import ( ModelMessage, ModelRequest, ModelResponse, NativeToolCallPart, NativeToolReturnPart, + PartEndEvent, + PartStartEvent, TextPart, ToolCallPart, ToolReturnPart, @@ -347,6 +350,96 @@ def set_query_agent_run( mock_query_agent.run.return_value = create_agent_run_result(mocker, **kwargs) +def mock_agent_run_stream(events: list[Any]) -> Any: + """Build an async context manager that yields pydantic-ai stream events.""" + + async def _event_stream() -> AsyncIterator[Any]: + for event in events: + yield event + + class _RunStreamCtx: + """Minimal async context manager matching agent.run_stream_events.""" + + async def __aenter__(self) -> AsyncIterator[Any]: + return _event_stream() + + async def __aexit__(self, *_args: object) -> None: + return None + + return _RunStreamCtx() + + +def create_text_agent_stream_events( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str = "This is a test response about Ansible.", + response_id: str = "response-123", + input_tokens: int = 10, + output_tokens: int = 5, +) -> list[Any]: + """Create pydantic-ai stream events for a simple text agent run.""" + run_result = create_agent_run_result( + mocker, + content=content, + response_id=response_id, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + return [ + PartStartEvent(index=0, part=TextPart(content=content)), + AgentRunResultEvent(result=run_result), + ] + + +def create_file_search_agent_stream_events( # pylint: disable=too-many-arguments,too-many-positional-arguments + mocker: MockerFixture, + *, + content: str, + response_id: str = "response-tool-rag", + queries: Optional[list[str]] = None, + results: Optional[list[dict[str, Any]]] = None, + input_tokens: int = 10, + output_tokens: int = 5, +) -> list[Any]: + """Create pydantic-ai stream events for a file_search tool agent run.""" + run_result = create_file_search_agent_run_result( + mocker, + content=content, + response_id=response_id, + queries=queries, + results=results, + input_tokens=input_tokens, + output_tokens=output_tokens, + ) + call = NativeToolCallPart( + tool_name=FileSearchTool.kind, + args={"queries": queries or ["test query"]}, + tool_call_id="call-fs-1", + ) + return_part = NativeToolReturnPart( + tool_name=FileSearchTool.kind, + tool_call_id="call-fs-1", + content={"status": "success", "results": results or []}, + ) + return [ + PartEndEvent(index=0, part=call), + PartStartEvent(index=1, part=return_part), + PartStartEvent(index=2, part=TextPart(content=content)), + AgentRunResultEvent(result=run_result), + ] + + +def set_streaming_query_agent_run( + mock_streaming_query_agent: Any, + mocker: MockerFixture, + **kwargs: Any, +) -> None: + """Configure mock agent.run_stream_events for /streaming_query integration tests.""" + mock_streaming_query_agent.run_stream_events.return_value = mock_agent_run_stream( + create_text_agent_stream_events(mocker, **kwargs) + ) + + # ========================================== # Fixtures # ========================================== @@ -707,7 +800,12 @@ def mock_llama_stack_client_fixture( @pytest.fixture(name="mock_query_agent") def mock_query_agent_fixture(mocker: MockerFixture) -> Any: - """Patch build_agent for /query and return the mock agent.""" + """Patch build_agent for /query and return the mock agent. + + Client inference is replaced by patching ``build_agent``. Configure + ``mock_query_agent.run.return_value`` in tests/fixtures to control the + agent response (see ``create_agent_run_result``). + """ mock_agent = mocker.AsyncMock() mock_agent.run = mocker.AsyncMock(return_value=create_agent_run_result(mocker)) mock_agent.build_agent_mock = mocker.patch( @@ -715,3 +813,23 @@ def mock_query_agent_fixture(mocker: MockerFixture) -> Any: return_value=mock_agent, ) return mock_agent + + +@pytest.fixture(name="mock_streaming_query_agent") +def mock_streaming_query_agent_fixture(mocker: MockerFixture) -> Any: + """Patch build_agent for /streaming_query and return the mock agent. + + Mirrors ``mock_query_agent``: client inference is replaced by patching + ``build_agent``. The only difference is the mocked method — ``run`` vs + ``run_stream_events`` — and its return value (``AgentRunResult`` vs a + stream of pydantic-ai events). + """ + mock_agent = mocker.Mock() + mock_agent.run_stream_events = mocker.Mock( + return_value=mock_agent_run_stream(create_text_agent_stream_events(mocker)) + ) + mock_agent.build_agent_mock = mocker.patch( + "utils.agents.streaming.build_agent", + return_value=mock_agent, + ) + return mock_agent diff --git a/tests/integration/endpoints/test_streaming_query_byok_integration.py b/tests/integration/endpoints/test_streaming_query_byok_integration.py index c539d4294..f9c480ccc 100644 --- a/tests/integration/endpoints/test_streaming_query_byok_integration.py +++ b/tests/integration/endpoints/test_streaming_query_byok_integration.py @@ -3,13 +3,12 @@ # pylint: disable=too-many-lines import json -from collections.abc import AsyncIterator, Generator +from collections.abc import Generator from typing import Any import pytest from fastapi import Request, status from fastapi.responses import StreamingResponse -from llama_stack_api.openai_responses import OpenAIResponseObject from pytest_mock import AsyncMockType, MockerFixture import constants @@ -17,6 +16,11 @@ from authentication.interface import AuthTuple from configuration import AppConfig from models.api.requests import QueryRequest +from tests.integration.conftest import ( + create_file_search_agent_stream_events, + create_text_agent_stream_events, + mock_agent_run_stream, +) from tests.integration.endpoints.test_query_byok_integration import ( _build_base_mock_client, _make_byok_vector_io_response, @@ -51,44 +55,19 @@ def _build_base_streaming_mock_client(mocker: MockerFixture) -> Any: """Build a base mock Llama Stack client configured for streaming responses. Extends the base query mock client with streaming-specific stubs: - conversations.items.create and a streaming responses.create. + conversations.items.create and a non-streaming responses.create stub for + topic summary generation. Agent inference is mocked separately via + ``mock_streaming_query_agent``. """ mock_client = _build_base_mock_client(mocker) - - # Streaming additions mock_client.conversations.items.create = mocker.AsyncMock() - async def _mock_stream() -> AsyncIterator[Any]: - chunk = mocker.MagicMock() - chunk.type = "response.output_text.done" - chunk.text = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." - ) - yield chunk - - # Emit response.completed so referenced_documents propagate to end event - completed_chunk = mocker.MagicMock() - completed_chunk.type = "response.completed" - mock_final = mocker.MagicMock(spec=OpenAIResponseObject) - mock_final.id = "response-inline-stream" - mock_final.error = None - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 50 - mock_usage.output_tokens = 20 - mock_final.usage = mock_usage - mock_final.output = [] - completed_chunk.response = mock_final - yield completed_chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_stream() + async def _responses_create(**_kwargs: Any) -> Any: mock_resp = mocker.MagicMock() mock_resp.output = [mocker.MagicMock(content="topic summary")] return mock_resp mock_client.responses.create = mocker.AsyncMock(side_effect=_responses_create) - return mock_client @@ -100,12 +79,25 @@ async def _responses_create(**kwargs: Any) -> Any: @pytest.fixture(name="mock_streaming_byok_client") def mock_streaming_byok_client_fixture( mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, ) -> Generator[Any, None, None]: """Mock Llama Stack client with BYOK inline RAG configured for streaming. Configures vector_io.query to return BYOK RAG chunks and sets vector_stores.list to empty (no tool-based vector stores). """ + mock_streaming_query_agent.run_stream_events.return_value = mock_agent_run_stream( + create_text_agent_stream_events( + mocker, + content=( + "Based on the documentation, OpenShift is a Kubernetes distribution." + ), + response_id="response-byok", + input_tokens=50, + output_tokens=20, + ) + ) + mock_holder_class = mocker.patch( "app.endpoints.streaming_query.AsyncLlamaStackClientHolder" ) @@ -128,12 +120,37 @@ def mock_streaming_byok_client_fixture( @pytest.fixture(name="mock_streaming_byok_tool_client") def mock_streaming_byok_tool_client_fixture( # pylint: disable=too-many-statements mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, ) -> Generator[Any, None, None]: """Mock Llama Stack client with BYOK tool RAG (file_search) for streaming. - Configures vector_stores.list with a BYOK store and responses.create - to stream file_search_call output items alongside the assistant message. + Configures vector_stores.list with a BYOK store and agent stream events + that include a file_search tool call alongside the assistant message. """ + mock_streaming_query_agent.run_stream_events.return_value = mock_agent_run_stream( + create_file_search_agent_stream_events( + mocker, + content=( + "Based on the documentation, OpenShift is a Kubernetes distribution." + ), + response_id="response-tool-stream", + queries=["What is OpenShift?"], + results=[ + { + "text": "OpenShift is a Kubernetes distribution by Red Hat.", + "score": 0.92, + "attributes": { + "doc_url": "https://docs.redhat.com/ocp/overview", + "title": "openshift-docs.txt", + "document_id": "doc-ocp-1", + }, + } + ], + input_tokens=60, + output_tokens=25, + ) + ) + mock_holder_class = mocker.patch( "app.endpoints.streaming_query.AsyncLlamaStackClientHolder" ) @@ -152,79 +169,6 @@ def mock_streaming_byok_tool_client_fixture( # pylint: disable=too-many-stateme mock_list_result.data = [mock_vector_store] mock_client.vector_stores.list.return_value = mock_list_result - # Build a streaming response with file_search and completion events - async def _mock_tool_stream() -> AsyncIterator[Any]: - # file_search output item done - item_done_chunk = mocker.MagicMock() - item_done_chunk.type = "response.output_item.done" - item_done_chunk.output_index = 0 - - mock_item = mocker.MagicMock() - mock_item.type = "file_search_call" - mock_item.id = "call-fs-stream-1" - mock_item.queries = ["What is OpenShift?"] - mock_item.status = "completed" - - mock_result = mocker.MagicMock() - mock_result.file_id = "doc-ocp-1" - mock_result.filename = "openshift-docs.txt" - mock_result.score = 0.92 - mock_result.text = "OpenShift is a Kubernetes distribution by Red Hat." - mock_result.attributes = { - "doc_url": "https://docs.redhat.com/ocp/overview", - } - mock_result.model_dump = mocker.Mock( - return_value={ - "file_id": "doc-ocp-1", - "filename": "openshift-docs.txt", - "score": 0.92, - "text": "OpenShift is a Kubernetes distribution.", - "attributes": {"doc_url": "https://docs.redhat.com/ocp/overview"}, - } - ) - mock_item.results = [mock_result] - item_done_chunk.item = mock_item - yield item_done_chunk - - # Text done - text_done_chunk = mocker.MagicMock() - text_done_chunk.type = "response.output_text.done" - text_done_chunk.text = ( - "Based on the documentation, OpenShift is a Kubernetes distribution." - ) - yield text_done_chunk - - # Response completed - completed_chunk = mocker.MagicMock() - completed_chunk.type = "response.completed" - mock_final_response = mocker.MagicMock(spec=OpenAIResponseObject) - mock_final_response.id = "response-tool-stream" - mock_final_response.error = None - - mock_usage = mocker.MagicMock() - mock_usage.input_tokens = 60 - mock_usage.output_tokens = 25 - mock_final_response.usage = mock_usage - - # file_search results in the final response output - mock_fs_output = mocker.MagicMock() - mock_fs_output.type = "file_search_call" - mock_fs_output.id = "call-fs-stream-1" - mock_fs_output.results = [mock_result] - mock_final_response.output = [mock_fs_output] - - completed_chunk.response = mock_final_response - yield completed_chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_tool_stream() - mock_resp = mocker.MagicMock() - mock_resp.output = [mocker.MagicMock(content="topic summary")] - return mock_resp - - mock_client.responses.create = mocker.AsyncMock(side_effect=_responses_create) - mock_holder_class.return_value.get_client.return_value = mock_client yield mock_client @@ -287,13 +231,14 @@ def byok_tool_config_fixture( async def test_streaming_query_byok_inline_rag_injects_context( byok_config: AppConfig, mock_streaming_byok_client: AsyncMockType, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: """Test that inline BYOK RAG context is injected into streaming query input. Verifies: - - RAG context from vector_io.query is injected into responses.create input + - RAG context from vector_io.query is injected into the agent prompt - Input contains formatted file_search results """ _ = byok_config @@ -309,20 +254,18 @@ async def test_streaming_query_byok_inline_rag_injects_context( assert isinstance(response, StreamingResponse) - # Verify RAG context was injected into responses.create input - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - call_kwargs = create_call.kwargs - input_text = call_kwargs["input"] - assert "file_search found" in input_text - assert "OpenShift is a Kubernetes distribution" in input_text + # Verify RAG context was injected into the agent prompt + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] + assert "file_search found" in prompt + assert "OpenShift is a Kubernetes distribution" in prompt @pytest.mark.asyncio async def test_streaming_query_byok_inline_rag_with_request_vector_store_ids( test_config: AppConfig, mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -386,6 +329,7 @@ async def test_streaming_query_byok_inline_rag_with_request_vector_store_ids( async def test_streaming_query_byok_request_vector_store_ids_filters_configured_stores( test_config: AppConfig, mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -448,18 +392,17 @@ async def test_streaming_query_byok_request_vector_store_ids_filters_configured_ call_kwargs = mock_client.vector_io.query.call_args.kwargs assert call_kwargs["vector_store_id"] == "vs-source-a" - # Verify source-a context was injected into the LLM input - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] - assert "file_search found" in input_text + # Verify source-a context was injected into the agent prompt + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] + assert "file_search found" in prompt @pytest.mark.asyncio async def test_streaming_query_byok_inline_rag_empty_vector_store_ids_no_context( byok_config: AppConfig, mock_streaming_byok_client: AsyncMockType, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -484,17 +427,16 @@ async def test_streaming_query_byok_inline_rag_empty_vector_store_ids_no_context assert isinstance(response, StreamingResponse) mock_streaming_byok_client.vector_io.query.assert_not_called() - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] - assert "file_search found" not in input_text + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] + assert "file_search found" not in prompt @pytest.mark.asyncio async def test_streaming_query_byok_inline_rag_error_handled_gracefully( byok_config: AppConfig, mock_streaming_byok_client: AsyncMockType, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -525,12 +467,9 @@ async def test_streaming_query_byok_inline_rag_error_handled_gracefully( assert isinstance(response, StreamingResponse) # No inline RAG context should be injected when the search fails. - # "file_search found" is the header added by _format_rag_context when chunks are present. - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_streaming_byok_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] - assert "file_search found" not in input_text + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] + assert "file_search found" not in prompt @pytest.mark.asyncio @@ -674,6 +613,7 @@ async def test_streaming_query_byok_tool_rag_emits_referenced_documents( async def test_streaming_query_byok_combined_inline_and_tool_rag( test_config: AppConfig, mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -725,17 +665,13 @@ async def test_streaming_query_byok_combined_inline_and_tool_rag( assert isinstance(response, StreamingResponse) assert response.status_code == status.HTTP_200_OK - # Verify inline RAG context was injected - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - call_kwargs = create_call.kwargs - input_text = call_kwargs["input"] - assert "file_search found" in input_text - - # Verify tool RAG file_search was passed - assert call_kwargs.get("tools") is not None - assert any(tool.get("type") == "file_search" for tool in call_kwargs["tools"]) + # Verify inline RAG context was injected and tool RAG was configured + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] + assert "file_search found" in prompt + responses_params = mock_streaming_query_agent.build_agent_mock.call_args[0][1] + assert responses_params.tools is not None + assert any(tool.type == "file_search" for tool in responses_params.tools) # ============================================================================== @@ -747,6 +683,7 @@ async def test_streaming_query_byok_combined_inline_and_tool_rag( async def test_streaming_query_byok_only_configured_rag_id_is_queried( test_config: AppConfig, mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -812,11 +749,9 @@ async def test_streaming_query_byok_only_configured_rag_id_is_queried( ] assert "vs-source-b" not in queried_stores - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] - assert "file_search found" in input_text + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] + assert "file_search found" in prompt # ============================================================================== @@ -828,6 +763,7 @@ async def test_streaming_query_byok_only_configured_rag_id_is_queried( async def test_streaming_query_byok_score_multiplier_shifts_priority( # pylint: disable=too-many-locals test_config: AppConfig, mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -897,12 +833,10 @@ async def _side_effect(**kwargs: Any) -> Any: assert isinstance(response, StreamingResponse) # Verify Doc B (weighted 2.0) appears before Doc A (weighted 0.9) in context - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] - pos_b = input_text.find("Doc B low similarity boosted") - pos_a = input_text.find("Doc A high similarity") + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] + pos_b = prompt.find("Doc B low similarity boosted") + pos_a = prompt.find("Doc A high similarity") assert pos_b != -1 and pos_a != -1 assert pos_b < pos_a @@ -916,6 +850,7 @@ async def _side_effect(**kwargs: Any) -> Any: async def test_streaming_query_rag_content_limit_caps_context( # pylint: disable=too-many-locals test_config: AppConfig, mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -969,23 +904,22 @@ async def test_streaming_query_rag_content_limit_caps_context( # pylint: disabl assert isinstance(response, StreamingResponse) # Verify the context header reports the capped count - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] expected_header = f"file_search found {constants.INLINE_RAG_MAX_CHUNKS} chunks:" - assert expected_header in input_text + assert expected_header in prompt # The lowest-scoring chunk should NOT be in the context - assert "Chunk content 0" not in input_text + assert "Chunk content 0" not in prompt # The highest-scoring chunk should be in the context - assert f"Chunk content {num_chunks - 1}" in input_text + assert f"Chunk content {num_chunks - 1}" in prompt @pytest.mark.asyncio async def test_streaming_query_rag_content_limit_caps_across_multiple_sources( # pylint: disable=too-many-locals test_config: AppConfig, mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -1058,26 +992,25 @@ async def _side_effect(**kwargs: Any) -> Any: assert isinstance(response, StreamingResponse) - # responses.create is the mock for the OpenAI-compatible LLM API call. - # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] expected_header = f"file_search found {constants.INLINE_RAG_MAX_CHUNKS} chunks:" - assert expected_header in input_text + assert expected_header in prompt # Both sources must appear in the context (overlapping scores guarantee this) - assert "Source A chunk" in input_text - assert "Source B chunk" in input_text + assert "Source A chunk" in prompt + assert "Source B chunk" in prompt # Lowest-scoring chunks from each source must be dropped - assert "Source A chunk 0" not in input_text - assert "Source B chunk 0" not in input_text + assert "Source A chunk 0" not in prompt + assert "Source B chunk 0" not in prompt @pytest.mark.asyncio async def test_streaming_query_rag_content_limit_caps_inline_rag( # pylint: disable=too-many-locals test_config: AppConfig, mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -1132,12 +1065,12 @@ async def test_streaming_query_rag_content_limit_caps_inline_rag( # pylint: dis assert isinstance(response, StreamingResponse) - create_call = mock_client.responses.create.call_args_list[0] - input_text = create_call.kwargs["input"] + await _collect_sse_events(response) + prompt = mock_streaming_query_agent.run_stream_events.call_args.args[0] expected_header = "file_search found 3 chunks:" - assert expected_header in input_text + assert expected_header in prompt # The highest-scoring chunk should be in the context - assert f"Chunk content {num_chunks - 1}" in input_text + assert f"Chunk content {num_chunks - 1}" in prompt # Low-scoring chunks should be excluded - assert "Chunk content 0" not in input_text + assert "Chunk content 0" not in prompt diff --git a/tests/integration/endpoints/test_streaming_query_integration.py b/tests/integration/endpoints/test_streaming_query_integration.py index 5a7e51620..efe09b72d 100644 --- a/tests/integration/endpoints/test_streaming_query_integration.py +++ b/tests/integration/endpoints/test_streaming_query_integration.py @@ -1,6 +1,6 @@ """Integration tests for the /streaming_query endpoint (using Responses API).""" -from collections.abc import AsyncIterator, Generator +from collections.abc import Generator from typing import Any import pytest @@ -19,13 +19,15 @@ @pytest.fixture(name="mock_streaming_llama_stack_client") def mock_llama_stack_streaming_fixture( mocker: MockerFixture, + mock_streaming_query_agent: AsyncMockType, ) -> Generator[Any, None, None]: """Mock only the Llama Stack client (holder + client). Configures the client so the real handler runs: models, vector_stores, - conversations, shields, vector_io, and responses.create returning a minimal - stream. No other code paths are patched. + conversations, shields, vector_io, and responses.create for topic summary. + Agent inference is mocked separately via ``mock_streaming_query_agent``. """ + _ = mock_streaming_query_agent mock_holder_class = mocker.patch( "app.endpoints.streaming_query.AsyncLlamaStackClientHolder" ) @@ -56,15 +58,7 @@ def mock_llama_stack_streaming_fixture( mock_vector_io_response.scores = [] mock_client.vector_io.query = mocker.AsyncMock(return_value=mock_vector_io_response) - async def _mock_stream() -> AsyncIterator[Any]: - chunk = mocker.MagicMock() - chunk.type = "response.output_text.done" - chunk.text = "test" - yield chunk - - async def _responses_create(**kwargs: Any) -> Any: - if kwargs.get("stream", True): - return _mock_stream() + async def _responses_create(**_kwargs: Any) -> Any: mock_resp = mocker.MagicMock() mock_resp.output = [mocker.MagicMock(content="topic summary")] return mock_resp @@ -153,6 +147,7 @@ async def test_streaming_query_v2_endpoint_attachment_handling( # pylint: disab test_case: dict, test_config: AppConfig, mock_streaming_llama_stack_client: AsyncMockType, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, ) -> None: @@ -170,11 +165,13 @@ async def test_streaming_query_v2_endpoint_attachment_handling( # pylint: disab expected_status, expected_error) test_config: Test configuration mock_streaming_llama_stack_client: Mocked Llama Stack client + mock_streaming_query_agent: Mocked Pydantic AI agent for build_agent test_request: FastAPI request test_auth: noop authentication tuple """ _ = test_config _ = mock_streaming_llama_stack_client + _ = mock_streaming_query_agent attachments = test_case["attachments"] expected_status = test_case["expected_status"] @@ -254,6 +251,7 @@ async def test_streaming_query_endpoint_returns_401_for_mcp_oauth( # pylint: di test_case: dict, test_config: AppConfig, mock_streaming_llama_stack_client: Any, + mock_streaming_query_agent: AsyncMockType, test_request: Request, test_auth: AuthTuple, mocker: MockerFixture, @@ -272,12 +270,14 @@ async def test_streaming_query_endpoint_returns_401_for_mcp_oauth( # pylint: di expect_www_authenticate) test_config: Test configuration mock_streaming_llama_stack_client: Mocked Llama Stack client + mock_streaming_query_agent: Mocked Pydantic AI agent for build_agent test_request: FastAPI request test_auth: noop authentication tuple mocker: pytest-mock fixture """ _ = test_config _ = mock_streaming_llama_stack_client + _ = mock_streaming_query_agent www_authenticate = test_case["www_authenticate"] expect_www_authenticate = test_case["expect_www_authenticate"] diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index 45762c03e..dd5efd227 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -204,19 +204,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -291,19 +291,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -389,19 +389,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -485,19 +485,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id", @@ -583,19 +583,19 @@ async def mock_generator() -> AsyncIterator[str]: mock_turn_summary = TurnSummary() mocker.patch( - "app.endpoints.streaming_query.retrieve_response_generator", - return_value=(mock_generator(), mock_turn_summary), + "app.endpoints.streaming_query.retrieve_agent_response_generator", + new=mocker.AsyncMock(return_value=(mock_generator(), mock_turn_summary)), ) - async def mock_generate_response( + async def mock_generate_agent_response( *_args: Any, **_kwargs: Any ) -> AsyncIterator[str]: async for item in mock_generator(): yield item mocker.patch( - "app.endpoints.streaming_query.generate_response", - side_effect=mock_generate_response, + "app.endpoints.streaming_query.generate_agent_response", + side_effect=mock_generate_agent_response, ) mocker.patch( "app.endpoints.streaming_query.normalize_conversation_id",