Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the "Problem" section of the PR description, the AI says:

The current implementation already has a broad catch block...So the literal "not passed to the handler" behavior is not obvious on current main or the 5.5.0 tag.

The reality is not that the bug ("behavior") is "not obvious", but that it does not exist (at least, the way it is described in the ticket JAVA-5855.

The AI should have clearly stated that the bug reported in the ticket does not exist, instead of claiming that the bug is not obvious, and then fixing a completely different bug.

Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import javax.net.ssl.SSLParameters;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.StandardSocketOptions;
import java.nio.ByteBuffer;
import java.nio.channels.CompletionHandler;
Expand Down Expand Up @@ -209,35 +210,60 @@ private static class TlsChannelStream extends AsynchronousChannelStream {
@Override
public void openAsync(final OperationContext operationContext, final AsyncCompletionHandler<Void> handler) {
isTrue("unopened", getChannel() == null);
SocketChannel socketChannel = null;
SelectorMonitor.SocketRegistration socketRegistration = null;
boolean registered = false;
try {
SocketChannel socketChannel = SocketChannel.open();
socketChannel.configureBlocking(false);
//getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
InetSocketAddress socketAddress = getSocketAddresses(getServerAddress(), inetAddressResolver).get(0);
SocketChannel openedSocketChannel = SocketChannel.open();
socketChannel = openedSocketChannel;
Comment on lines +220 to +221
Copy link
Copy Markdown
Member

@vbabanin vbabanin May 15, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

socketChannel = openedSocketChannel looks redundant.

Right now the channel is assigned to openedSocketChannel and then immediately re-assigned to socketChannel, but there’s no semantic distinction: socketChannel only ever comes from SocketChannel.open(), and the code doesn’t use the two variables to represent different states (partially-opened, closed, etc.) or different sources.

Suggestion: keep a single variable (SocketChannel socketChannel) and remove openedSocketChannel to reduce cognitive load and avoid implying there’s a meaningful difference when there isn’t.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its used in lambdas so is needed to be final. Hence the reassignment.

openedSocketChannel.configureBlocking(false);

socketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
socketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
openedSocketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true);
openedSocketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true);
if (getSettings().getReceiveBufferSize() > 0) {
socketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
openedSocketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize());
}
if (getSettings().getSendBufferSize() > 0) {
socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
openedSocketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize());
}
//getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception.
int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs();
socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0));
SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration(
socketChannel, () -> initializeTslChannel(handler, socketChannel));
openedSocketChannel.connect(socketAddress);
socketRegistration = new SelectorMonitor.SocketRegistration(
openedSocketChannel, () -> initializeTslChannel(handler, openedSocketChannel));

Comment on lines +232 to 235
if (connectTimeoutMs > 0) {
scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs);
}
selectorMonitor.register(socketRegistration);
registered = true;
} catch (IOException e) {
closeUnregisteredSocketChannel(socketChannel, socketRegistration, registered, e);
handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e));
} catch (Throwable t) {
closeUnregisteredSocketChannel(socketChannel, socketRegistration, registered, t);
handler.failed(t);
Comment on lines +242 to 246
Copy link
Copy Markdown
Member

@stIncMale stIncMale May 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before the PR, handler was guaranteed to be called if anything goes wrong (if Throwable happens), but that is not so anymore. Not completing a handler in asynchronous callback-based code is equivalent of a method never returning control in synchronous code. It's a serious bug, which causes any application code that was supposed to be executed, to never be executed. The latter may lead to, for example, resource leaks, dead locks caused by locks not being released.

Writing

try {
    closeUnregisteredSocketChannel(socketChannel, socketRegistration, registered, t);
} finally {
    handler.failed(t);
}

would have solved the problem.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have fixed in a reimplementation #1981 (after this change was reverted). Technically a regression from the previous unconditional guarantee. But the practical risk is near-zero because:

  1. tryCancelPendingConnection() is AtomicReference.getAndSet(null) — cannot throw
  2. socketChannel.close() IOException is already caught
  3. failure.addSuppressed(e) can't self-suppress (different exception instances)

The risk would be an unchecked exception from SocketChannel.close() on a non-standard implementation, which doesn't apply to JDK's SocketChannelImpl.

}
Comment on lines 238 to 247
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The registered flag is redundant here and adds complexity in both this method and closeUnregisteredSocketChannel. Given the current control flow, registered can’t be true in any of the catch blocks, so the parameter doesn’t appear to affect behavior.

Suggestion: remove the registered parameter entirely.

}

private void closeUnregisteredSocketChannel(@Nullable final SocketChannel socketChannel,
@Nullable final SelectorMonitor.SocketRegistration socketRegistration,
final boolean registered, final Throwable failure) {
if (!registered) {
if (socketRegistration != null) {
socketRegistration.tryCancelPendingConnection();
}
if (socketChannel != null) {
try {
socketChannel.close();
} catch (IOException e) {
failure.addSuppressed(e);
}
}
}
}

private void scheduleTimeoutInterruption(final AsyncCompletionHandler<Void> handler,
final SelectorMonitor.SocketRegistration socketRegistration,
final int connectTimeoutMs) {
Expand Down Expand Up @@ -384,4 +410,3 @@ public void close() throws IOException {
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,22 @@
package com.mongodb.internal.connection;

import com.mongodb.ClusterFixture;
import com.mongodb.MongoSocketException;
import com.mongodb.MongoSocketOpenException;
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.internal.TimeoutContext;
import com.mongodb.internal.TimeoutSettings;
import com.mongodb.spi.dns.InetAddressResolver;
import org.bson.ByteBuf;
import org.bson.ByteBufNIO;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import org.mockito.ArgumentCaptor;
import org.mockito.MockedStatic;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
Expand All @@ -37,11 +41,13 @@
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import java.io.IOException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.InterruptedByTimeoutException;
import java.nio.channels.SocketChannel;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static com.mongodb.ClusterFixture.getPrimaryServerDescription;
Expand All @@ -52,10 +58,12 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
import static org.junit.jupiter.api.Assumptions.assumeTrue;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyInt;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.atLeast;
Expand All @@ -68,6 +76,69 @@ class TlsChannelStreamFunctionalTest {
private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1";
private static final int UNREACHABLE_PORT = 65333;

@Test
void shouldFailAsyncCompletionHandlerWithoutOpeningSocketChannelIfNameResolutionFails() {
//given
ServerAddress serverAddress = new ServerAddress();
MongoSocketException exception = new MongoSocketException("Temporary failure in name resolution", serverAddress);
InetAddressResolver inetAddressResolver = new InetAddressResolver() {
@Override
public List<InetAddress> lookupByName(final String host) {
throw exception;
}
};
Comment on lines +84 to +89
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no reason to declare and instantiate an anonymous class here, and this is against the code style we currently use. Instead, a lambda expression should have been used:

InetAddressResolver inetAddressResolver = host -> {
    throw exception;
};

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a nit that in reality is the same thing. I would make that an optional fix for a PR reviewer. I'll look into hardening our AGENTS.md to help steer AI agents.


try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver);
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
.connectTimeout(100, TimeUnit.MILLISECONDS)
.build(), SSL_SETTINGS);
Stream stream = streamFactory.create(serverAddress);
@SuppressWarnings("unchecked")
AsyncCompletionHandler<Void> handler = Mockito.mock(AsyncCompletionHandler.class);

//when
stream.openAsync(createOperationContext(100), handler);

//then
verify(handler).failed(exception);
verify(handler, times(0)).completed(null);
socketChannelMockedStatic.verify(SocketChannel::open, times(0));
}
}

@Test
void shouldCloseSocketChannelIfConnectFailsBeforeRegistration() throws IOException {
//given
ServerAddress serverAddress = new ServerAddress();
IOException exception = new IOException("connect failed");
InetAddressResolver inetAddressResolver = host -> Collections.singletonList(InetAddress.getLoopbackAddress());

try (SocketChannel socketChannel = Mockito.spy(SocketChannel.open());
StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver);
MockedStatic<SocketChannel> socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) {
socketChannelMockedStatic.when(SocketChannel::open).thenReturn(socketChannel);
Mockito.doThrow(exception).when(socketChannel).connect(any());
StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder()
.connectTimeout(100, TimeUnit.MILLISECONDS)
.build(), SSL_SETTINGS);
Stream stream = streamFactory.create(serverAddress);
@SuppressWarnings("unchecked")
AsyncCompletionHandler<Void> handler = Mockito.mock(AsyncCompletionHandler.class);
ArgumentCaptor<Throwable> failureCaptor = ArgumentCaptor.forClass(Throwable.class);

//when
stream.openAsync(createOperationContext(100), handler);

//then
verify(handler).failed(failureCaptor.capture());
MongoSocketOpenException actualException = assertInstanceOf(MongoSocketOpenException.class, failureCaptor.getValue());
assertSame(exception, actualException.getCause());
verify(handler, times(0)).completed(null);
verify(socketChannel).close();
}
}

@ParameterizedTest
@ValueSource(ints = {500, 1000, 2000})
void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException {
Expand Down