Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -227,16 +227,13 @@ private void handleIncomingErrors() {

@Override
public Mono<Void> sendMessage(JSONRPCMessage message) {
if (this.outboundSink.tryEmitNext(message).isSuccess()) {
// TODO: essentially we could reschedule ourselves in some time and make
// another attempt with the already read data but pause reading until
// success
// In this approach we delegate the retry and the backpressure onto the
// caller. This might be enough for most cases.
try {
// busyLooping retries FAIL_NON_SERIALIZED under concurrent senders
this.outboundSink.emitNext(message, Sinks.EmitFailureHandler.busyLooping(Duration.ofMillis(100)));
return Mono.empty();
}
else {
return Mono.error(new RuntimeException("Failed to enqueue message"));
catch (Sinks.EmissionException e) {
return Mono.error(new RuntimeException("Failed to enqueue message", e));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;

import io.modelcontextprotocol.json.TypeRef;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpSchema.JSONRPCMessage;
import io.modelcontextprotocol.spec.McpServerSession;
Expand Down Expand Up @@ -161,11 +161,12 @@ public StdioMcpSessionTransport() {
public Mono<Void> sendMessage(McpSchema.JSONRPCMessage message) {

return Mono.zip(inboundReady.asMono(), outboundReady.asMono()).then(Mono.defer(() -> {
if (outboundSink.tryEmitNext(message).isSuccess()) {
try {
outboundSink.emitNext(message, Sinks.EmitFailureHandler.busyLooping(Duration.ofMillis(100)));
return Mono.empty();
}
else {
return Mono.error(new RuntimeException("Failed to enqueue message"));
catch (Sinks.EmissionException e) {
return Mono.error(new RuntimeException("Failed to enqueue message", e));
}
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
import java.util.concurrent.atomic.AtomicReference;

import io.modelcontextprotocol.json.McpJsonDefaults;
import io.modelcontextprotocol.spec.McpError;
import io.modelcontextprotocol.spec.McpSchema;
import io.modelcontextprotocol.spec.McpServerSession;
import io.modelcontextprotocol.spec.McpServerTransport;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.test.StepVerifier;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -98,7 +98,7 @@ void shouldCreateSessionWhenSessionFactoryIsSet() {
}

@Test
void shouldHandleIncomingMessages() throws Exception {
void shouldHandleIncomingMessages() {

String jsonMessage = "{\"jsonrpc\":\"2.0\",\"method\":\"test\",\"params\":{},\"id\":1}\n";
InputStream stream = new ByteArrayInputStream(jsonMessage.getBytes(StandardCharsets.UTF_8));
Expand Down Expand Up @@ -228,7 +228,7 @@ void shouldHandleNotificationBeforeSessionFactoryIsSet() {
}

@Test
void shouldHandleInvalidJsonMessage() throws Exception {
void shouldHandleInvalidJsonMessage() {

// Write an invalid JSON message to the input stream
String jsonMessage = "{invalid json}\n";
Expand All @@ -247,7 +247,7 @@ void shouldHandleInvalidJsonMessage() throws Exception {
}

@Test
void shouldHandleSessionClose() throws Exception {
void shouldHandleSessionClose() {
// Set session factory
transportProvider.setSessionFactory(sessionFactory);

Expand All @@ -258,4 +258,47 @@ void shouldHandleSessionClose() throws Exception {
verify(mockSession).closeGracefully();
}

@Test
void shouldHandleConcurrentSendMessage() {
// Redirect the transport output to a buffer so we can verify every message lands
ByteArrayOutputStream output = new ByteArrayOutputStream();
PrintStream outputPrintStream = new PrintStream(output, true, StandardCharsets.UTF_8);
transportProvider = new StdioServerTransportProvider(McpJsonDefaults.getMapper(), System.in, outputPrintStream);

// Capture the inner McpServerTransport handed to the session factory
AtomicReference<McpServerTransport> transportRef = new AtomicReference<>();
McpServerSession.Factory capturingFactory = transport -> {
transportRef.set(transport);
return mockSession;
};

// Set session factory
transportProvider.setSessionFactory(capturingFactory);

McpServerTransport transport = transportRef.get();
assertThat(transport).isNotNull();

// Fan sendMessage out across 16 parallel rails to race against the unicast sink
int messageCount = 500;
Flux<Integer> concurrentSends = Flux.range(0, messageCount)
.parallel(16)
.runOn(Schedulers.parallel())
.flatMap(i -> transport
.sendMessage(
new McpSchema.JSONRPCNotification(McpSchema.JSONRPC_VERSION, "test/notification", Map.of()))
.thenReturn(i))
.sequential();

// Every send should complete successfully (no FAIL_NON_SERIALIZED errors)
StepVerifier.create(concurrentSends).expectNextCount(messageCount).verifyComplete();

// Writes happen asynchronously on the outbound scheduler; wait briefly for drain
// and verify every message was written as its own newline-delimited JSON line
StepVerifier
.create(Mono.delay(java.time.Duration.ofMillis(500))
.then(Mono.fromCallable(() -> output.toString(StandardCharsets.UTF_8).lines().count())))
.assertNext(lineCount -> assertThat(lineCount).isEqualTo(messageCount))
.verifyComplete();
}

}