diff --git a/src/mcp/server/elicitation.py b/src/mcp/server/elicitation.py index a34708d38d..b09bee2160 100644 --- a/src/mcp/server/elicitation.py +++ b/src/mcp/server/elicitation.py @@ -9,6 +9,7 @@ from pydantic_core import core_schema from mcp.server.session import ServerSession +from mcp.shared.session import ProgressFnT from mcp.types import RequestId # Internal surface package; imported as the gate's source of truth for spec-valid property schemas. @@ -87,6 +88,7 @@ async def elicit_with_validation( message: str, schema: type[ElicitSchemaModelT], related_request_id: RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: """Elicit information from the client/user with schema validation (form mode). @@ -105,6 +107,7 @@ async def elicit_with_validation( message=message, requested_schema=json_schema, related_request_id=related_request_id, + progress_callback=progress_callback, ) if result.action == "accept" and result.content is not None: @@ -126,6 +129,7 @@ async def elicit_url( url: str, elicitation_id: str, related_request_id: RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> UrlElicitationResult: """Elicit information from the user via out-of-band URL navigation (URL mode). @@ -155,6 +159,7 @@ async def elicit_url( url=url, elicitation_id=elicitation_id, related_request_id=related_request_id, + progress_callback=progress_callback, ) if result.action == "accept": diff --git a/src/mcp/server/mcpserver/context.py b/src/mcp/server/mcpserver/context.py index 92de074d34..2392f6579e 100644 --- a/src/mcp/server/mcpserver/context.py +++ b/src/mcp/server/mcpserver/context.py @@ -14,6 +14,7 @@ elicit_with_validation, ) from mcp.server.lowlevel.helper_types import ReadResourceContents +from mcp.shared.dispatcher import ProgressFnT from mcp.types import LoggingLevel if TYPE_CHECKING: @@ -121,6 +122,7 @@ async def elicit( self, message: str, schema: type[ElicitSchemaModelT], + progress_callback: ProgressFnT | None = None, ) -> ElicitationResult[ElicitSchemaModelT]: """Elicit information from the client/user. @@ -134,6 +136,7 @@ async def elicit( message: Message to present to the user schema: A Pydantic model class defining the expected response structure. According to the specification, only primitive types are allowed. + progress_callback: Optional callback for receiving progress notifications. Returns: An ElicitationResult containing the action taken and the data if accepted @@ -148,6 +151,7 @@ async def elicit( message=message, schema=schema, related_request_id=self.request_id, + progress_callback=progress_callback, ) async def elicit_url( @@ -155,6 +159,7 @@ async def elicit_url( message: str, url: str, elicitation_id: str, + progress_callback: ProgressFnT | None = None, ) -> UrlElicitationResult: """Request URL mode elicitation from the client. @@ -173,6 +178,7 @@ async def elicit_url( message: Human-readable explanation of why the interaction is needed url: The URL the user should navigate to elicitation_id: Unique identifier for tracking this elicitation + progress_callback: Optional callback for receiving progress notifications. Returns: UrlElicitationResult indicating accept, decline, or cancel @@ -183,6 +189,7 @@ async def elicit_url( url=url, elicitation_id=elicitation_id, related_request_id=self.request_id, + progress_callback=progress_callback, ) async def log( diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 6254a01ee5..41ea70fc1e 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -158,6 +158,7 @@ async def create_message( tools: None = None, tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResult: """Overload: Without tools, returns single content.""" ... @@ -177,6 +178,7 @@ async def create_message( tools: list[types.Tool], tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResultWithTools: """Overload: With tools, returns array-capable content.""" ... @@ -195,6 +197,7 @@ async def create_message( tools: list[types.Tool] | None = None, tool_choice: types.ToolChoice | None = None, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.CreateMessageResult | types.CreateMessageResultWithTools: """Send a sampling/create_message request. @@ -214,6 +217,7 @@ async def create_message( tool_choice: Optional control over tool usage behavior. Requires client to have sampling.tools capability. related_request_id: Optional ID of a related request. + progress_callback: Optional callback for receiving progress notifications. Returns: The sampling result from the client. @@ -250,11 +254,13 @@ async def create_message( request=request, result_type=types.CreateMessageResultWithTools, metadata=metadata_obj, + progress_callback=progress_callback, ) return await self.send_request( request=request, result_type=types.CreateMessageResult, metadata=metadata_obj, + progress_callback=progress_callback, ) async def list_roots(self) -> types.ListRootsResult: @@ -271,6 +277,7 @@ async def elicit( message: str, requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a form mode elicitation/create request. @@ -278,6 +285,7 @@ async def elicit( message: The message to present to the user. requested_schema: Schema defining the expected response structure. related_request_id: Optional ID of the request that triggered this elicitation. + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response. @@ -286,13 +294,14 @@ async def elicit( This method is deprecated in favor of elicit_form(). It remains for backward compatibility but new code should use elicit_form(). """ - return await self.elicit_form(message, requested_schema, related_request_id) + return await self.elicit_form(message, requested_schema, related_request_id, progress_callback) async def elicit_form( self, message: str, requested_schema: types.ElicitRequestedSchema, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a form mode elicitation/create request. @@ -300,6 +309,7 @@ async def elicit_form( message: The message to present to the user. requested_schema: Schema defining the expected response structure. related_request_id: Optional ID of the request that triggered this elicitation. + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response with form data. @@ -318,6 +328,7 @@ async def elicit_form( ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), + progress_callback=progress_callback, ) async def elicit_url( @@ -326,6 +337,7 @@ async def elicit_url( url: str, elicitation_id: str, related_request_id: types.RequestId | None = None, + progress_callback: ProgressFnT | None = None, ) -> types.ElicitResult: """Send a URL mode elicitation/create request. @@ -337,6 +349,7 @@ async def elicit_url( url: The URL the user should navigate to. elicitation_id: Unique identifier for tracking this elicitation. related_request_id: Optional ID of the request that triggered this elicitation. + progress_callback: Optional callback for receiving progress notifications. Returns: The client's response indicating acceptance, decline, or cancellation. @@ -356,6 +369,7 @@ async def elicit_url( ), types.ElicitResult, metadata=ServerMessageMetadata(related_request_id=related_request_id), + progress_callback=progress_callback, ) async def send_ping(self) -> types.EmptyResult: diff --git a/tests/shared/test_progress_notifications.py b/tests/shared/test_progress_notifications.py new file mode 100644 index 0000000000..fcf4550daa --- /dev/null +++ b/tests/shared/test_progress_notifications.py @@ -0,0 +1,121 @@ +from typing import Any + +import pytest + +from mcp import Client +from mcp.client import ClientRequestContext +from mcp.server.mcpserver import Context, MCPServer +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + ElicitRequestParams, + ElicitResult, + SamplingMessage, + TextContent, +) + + +@pytest.mark.anyio +async def test_server_create_message_progress_callback(): + """Test that ServerSession.create_message() accepts and passes through progress_callback.""" + server = MCPServer("test") + + # Track progress updates received by the server's progress callback + progress_updates: list[dict[str, Any]] = [] + + async def my_progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append({"progress": progress, "total": total, "message": message}) + + @server.tool("trigger_sampling") + async def trigger_sampling_tool(text: str, ctx: Context) -> str: + result = await ctx.session.create_message( + messages=[SamplingMessage(role="user", content=TextContent(type="text", text=text))], + max_tokens=100, + progress_callback=my_progress_callback, + ) + assert isinstance(result.content, TextContent) + return result.content.text + + async def sampling_callback( + context: ClientRequestContext, + params: CreateMessageRequestParams, + ) -> CreateMessageResult: + # Send progress notifications back to the server using the progress token + if context.meta and "progress_token" in context.meta: # pragma: no branch + token = context.meta["progress_token"] + await context.session.send_progress_notification( + progress_token=token, + progress=0.5, + total=1.0, + message="Halfway done", + ) + await context.session.send_progress_notification( + progress_token=token, + progress=1.0, + total=1.0, + message="Complete", + ) + + return CreateMessageResult( + role="assistant", + content=TextContent(type="text", text="LLM response"), + model="test-model", + stop_reason="endTurn", + ) + + async with Client(server, sampling_callback=sampling_callback) as client: + result = await client.call_tool("trigger_sampling", {"text": "Hello"}) + assert result.is_error is False + + # Verify the progress callback was invoked with correct values + assert len(progress_updates) == 2 + assert progress_updates[0] == {"progress": 0.5, "total": 1.0, "message": "Halfway done"} + assert progress_updates[1] == {"progress": 1.0, "total": 1.0, "message": "Complete"} + + +@pytest.mark.anyio +async def test_server_elicit_form_progress_callback(): + """Test that ServerSession.elicit_form() accepts and passes through progress_callback.""" + server = MCPServer("test") + + # Track progress updates received by the server's progress callback + progress_updates: list[dict[str, Any]] = [] + + async def my_progress_callback(progress: float, total: float | None, message: str | None) -> None: + progress_updates.append({"progress": progress, "total": total, "message": message}) + + @server.tool("trigger_elicitation") + async def trigger_elicitation_tool(text: str, ctx: Context) -> str: + result = await ctx.session.elicit_form( + message=text, + requested_schema={"type": "object", "properties": {"name": {"type": "string"}}}, + progress_callback=my_progress_callback, + ) + return result.action + + async def elicitation_callback( + context: ClientRequestContext, + params: ElicitRequestParams, + ) -> ElicitResult: + # Send progress notifications back to the server using the progress token + if context.meta and "progress_token" in context.meta: # pragma: no branch + token = context.meta["progress_token"] + await context.session.send_progress_notification( + progress_token=token, + progress=1.0, + total=1.0, + message="User responded", + ) + + return ElicitResult( + action="accept", + content={"name": "test"}, + ) + + async with Client(server, elicitation_callback=elicitation_callback) as client: + result = await client.call_tool("trigger_elicitation", {"text": "Enter name"}) + assert result.is_error is False + + # Verify the progress callback was invoked + assert len(progress_updates) == 1 + assert progress_updates[0] == {"progress": 1.0, "total": 1.0, "message": "User responded"}