diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java index e5f57bad8..f7cb4d619 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpAsyncServer.java @@ -999,12 +999,9 @@ private McpRequestHandler completionCompleteRequestHan .message("Prompt not found: " + promptReference.name()) .build()); } - if (!promptSpec.prompt() - .arguments() - .stream() - .filter(arg -> arg.name().equals(argumentName)) - .findFirst() - .isPresent()) { + List arguments = promptSpec.prompt().arguments(); + if (arguments == null + || !arguments.stream().filter(arg -> arg.name().equals(argumentName)).findFirst().isPresent()) { logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java index 18fc85786..46fdf0aab 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/McpStatelessAsyncServer.java @@ -744,12 +744,9 @@ private McpStatelessRequestHandler completionCompleteR .message("Prompt not found: " + promptReference.name()) .build()); } - if (!promptSpec.prompt() - .arguments() - .stream() - .filter(arg -> arg.name().equals(argumentName)) - .findFirst() - .isPresent()) { + List arguments = promptSpec.prompt().arguments(); + if (arguments == null + || !arguments.stream().filter(arg -> arg.name().equals(argumentName)).findFirst().isPresent()) { logger.warn("Argument not found: {} in prompt: {}", argumentName, promptReference.name()); diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java index 3d40453a3..4a18fa1cd 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/HttpServletStatelessIntegrationTests.java @@ -184,10 +184,10 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { true // hasMore )); - AtomicReference samplingRequest = new AtomicReference<>(); + AtomicReference completeRequest = new AtomicReference<>(); BiFunction completionHandler = (transportContext, request) -> { - samplingRequest.set(request); + completeRequest.set(request); return completionResponse; }; @@ -214,9 +214,9 @@ void testCompletionShouldReturnExpectedSuggestions(String clientType) { assertThat(result).isNotNull(); - assertThat(samplingRequest.get().argument().name()).isEqualTo("language"); - assertThat(samplingRequest.get().argument().value()).isEqualTo("py"); - assertThat(samplingRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); + assertThat(completeRequest.get().argument().name()).isEqualTo("language"); + assertThat(completeRequest.get().argument().value()).isEqualTo("py"); + assertThat(completeRequest.get().ref().type()).isEqualTo(PromptReference.TYPE); } finally { mcpServer.close(); diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java index 54fb80a78..5a26402c7 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/McpCompletionTests.java @@ -13,11 +13,9 @@ import org.apache.catalina.LifecycleState; import org.apache.catalina.startup.Tomcat; -import static org.assertj.core.api.Assertions.assertThat; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import io.modelcontextprotocol.client.McpClient; import io.modelcontextprotocol.client.transport.HttpClientSseClientTransport; @@ -37,6 +35,9 @@ import io.modelcontextprotocol.spec.McpSchema.ServerCapabilities; import io.modelcontextprotocol.spec.McpError; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + /** * Tests for completion functionality with context support. * @@ -324,4 +325,35 @@ void testCompletionErrorOnMissingContext() { mcpServer.close(); } + @Test + void testPromptWithoutArgumentsCompletionForArgument() { + BiFunction completionHandler = (exchange, + request) -> new CompleteResult(new CompleteResult.CompleteCompletion(List.of("test"), 1, false)); + + McpSchema.Prompt prompt = new Prompt("test-prompt", "this is a test prompt", null); + + var mcpServer = McpServer.sync(mcpServerTransportProvider) + .capabilities(ServerCapabilities.builder().completions().build()) + .prompts(new McpServerFeatures.SyncPromptSpecification(prompt, + (mcpSyncServerExchange, getPromptRequest) -> null)) + .completions(new McpServerFeatures.SyncCompletionSpecification( + new PromptReference(PromptReference.TYPE, "test-prompt"), completionHandler)) + .build(); + + try (var mcpClient = clientBuilder.clientInfo(new McpSchema.Implementation("Sample " + "client", "0.0.0")) + .build()) { + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + + // try completing an argument knowing that the prompt is not parameterized + CompleteRequest request = new CompleteRequest(new PromptReference(PromptReference.TYPE, "test-prompt"), + new CompleteRequest.CompleteArgument("arg", "val")); + + CompleteResult completeResult = mcpClient.completeCompletion(request); + assertThat(completeResult.completion().values()).isEmpty(); + } + + mcpServer.close(); + } + } \ No newline at end of file