diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java index 1b4eaca97..e00ea41a3 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/client/transport/StdioClientTransport.java @@ -227,16 +227,13 @@ private void handleIncomingErrors() { @Override public Mono sendMessage(JSONRPCMessage message) { - if (this.outboundSink.tryEmitNext(message).isSuccess()) { - // TODO: essentially we could reschedule ourselves in some time and make - // another attempt with the already read data but pause reading until - // success - // In this approach we delegate the retry and the backpressure onto the - // caller. This might be enough for most cases. + try { + // busyLooping retries FAIL_NON_SERIALIZED under concurrent senders + this.outboundSink.emitNext(message, Sinks.EmitFailureHandler.busyLooping(Duration.ofMillis(100))); return Mono.empty(); } - else { - return Mono.error(new RuntimeException("Failed to enqueue message")); + catch (Sinks.EmissionException e) { + return Mono.error(new RuntimeException("Failed to enqueue message", e)); } } diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java index 66cc304d6..2f25da7d6 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/StdioServerTransportProvider.java @@ -10,13 +10,13 @@ import java.io.InputStreamReader; import java.io.OutputStream; import java.nio.charset.StandardCharsets; +import java.time.Duration; import java.util.List; import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Function; import io.modelcontextprotocol.json.TypeRef; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage; import io.modelcontextprotocol.spec.McpServerSession; @@ -161,11 +161,12 @@ public StdioMcpSessionTransport() { public Mono sendMessage(McpSchema.JSONRPCMessage message) { return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> { - if (outboundSink.tryEmitNext(message).isSuccess()) { + try { + outboundSink.emitNext(message, Sinks.EmitFailureHandler.busyLooping(Duration.ofMillis(100))); return Mono.empty(); } - else { - return Mono.error(new RuntimeException("Failed to enqueue message")); + catch (Sinks.EmissionException e) { + return Mono.error(new RuntimeException("Failed to enqueue message", e)); } })); } diff --git a/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java index 6c2cc2bf4..198e1d89e 100644 --- a/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java +++ b/mcp-test/src/test/java/io/modelcontextprotocol/server/transport/StdioServerTransportProviderTests.java @@ -17,15 +17,15 @@ import java.util.concurrent.atomic.AtomicReference; import io.modelcontextprotocol.json.McpJsonDefaults; -import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpServerSession; import io.modelcontextprotocol.spec.McpServerTransport; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; import reactor.test.StepVerifier; import static org.assertj.core.api.Assertions.assertThat; @@ -98,7 +98,7 @@ void shouldCreateSessionWhenSessionFactoryIsSet() { } @Test - void shouldHandleIncomingMessages() throws Exception { + void shouldHandleIncomingMessages() { String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n"; InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8)); @@ -228,7 +228,7 @@ void shouldHandleNotificationBeforeSessionFactoryIsSet() { } @Test - void shouldHandleInvalidJsonMessage() throws Exception { + void shouldHandleInvalidJsonMessage() { // Write an invalid JSON message to the input stream String jsonMessage = "{invalid json}\n"; @@ -247,7 +247,7 @@ void shouldHandleInvalidJsonMessage() throws Exception { } @Test - void shouldHandleSessionClose() throws Exception { + void shouldHandleSessionClose() { // Set session factory transportProvider.setSessionFactory(sessionFactory); @@ -258,4 +258,47 @@ void shouldHandleSessionClose() throws Exception { verify(mockSession).closeGracefully(); } + @Test + void shouldHandleConcurrentSendMessage() { + // Redirect the transport output to a buffer so we can verify every message lands + ByteArrayOutputStream output = new ByteArrayOutputStream(); + PrintStream outputPrintStream = new PrintStream(output, true, StandardCharsets.UTF_8); + transportProvider = new StdioServerTransportProvider(McpJsonDefaults.getMapper(), System.in, outputPrintStream); + + // Capture the inner McpServerTransport handed to the session factory + AtomicReference transportRef = new AtomicReference<>(); + McpServerSession.Factory capturingFactory = transport -> { + transportRef.set(transport); + return mockSession; + }; + + // Set session factory + transportProvider.setSessionFactory(capturingFactory); + + McpServerTransport transport = transportRef.get(); + assertThat(transport).isNotNull(); + + // Fan sendMessage out across 16 parallel rails to race against the unicast sink + int messageCount = 500; + Flux concurrentSends = Flux.range(0, messageCount) + .parallel(16) + .runOn(Schedulers.parallel()) + .flatMap(i -> transport + .sendMessage( + new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, "test/notification", Map.of())) + .thenReturn(i)) + .sequential(); + + // Every send should complete successfully (no FAIL_NON_SERIALIZED errors) + StepVerifier.create(concurrentSends).expectNextCount(messageCount).verifyComplete(); + + // Writes happen asynchronously on the outbound scheduler; wait briefly for drain + // and verify every message was written as its own newline-delimited JSON line + StepVerifier + .create(Mono.delay(java.time.Duration.ofMillis(500)) + .then(Mono.fromCallable(() -> output.toString(StandardCharsets.UTF_8).lines().count()))) + .assertNext(lineCount -> assertThat(lineCount).isEqualTo(messageCount)) + .verifyComplete(); + } + }