diff --git a/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java b/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java index b0fae1d044..e43a9865d7 100644 --- a/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java +++ b/driver-core/src/main/com/mongodb/internal/connection/TlsChannelStreamFactoryFactory.java @@ -37,6 +37,7 @@ import javax.net.ssl.SSLParameters; import java.io.Closeable; import java.io.IOException; +import java.net.InetSocketAddress; import java.net.StandardSocketOptions; import java.nio.ByteBuffer; import java.nio.channels.CompletionHandler; @@ -209,35 +210,60 @@ private static class TlsChannelStream extends AsynchronousChannelStream { @Override public void openAsync(final OperationContext operationContext, final AsyncCompletionHandler handler) { isTrue("unopened", getChannel() == null); + SocketChannel socketChannel = null; + SelectorMonitor.SocketRegistration socketRegistration = null; + boolean registered = false; try { - SocketChannel socketChannel = SocketChannel.open(); - socketChannel.configureBlocking(false); + //getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception. + int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs(); + InetSocketAddress socketAddress = getSocketAddresses(getServerAddress(), inetAddressResolver).get(0); + SocketChannel openedSocketChannel = SocketChannel.open(); + socketChannel = openedSocketChannel; + openedSocketChannel.configureBlocking(false); - socketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true); - socketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true); + openedSocketChannel.setOption(StandardSocketOptions.TCP_NODELAY, true); + openedSocketChannel.setOption(StandardSocketOptions.SO_KEEPALIVE, true); if (getSettings().getReceiveBufferSize() > 0) { - socketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize()); + openedSocketChannel.setOption(StandardSocketOptions.SO_RCVBUF, getSettings().getReceiveBufferSize()); } if (getSettings().getSendBufferSize() > 0) { - socketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize()); + openedSocketChannel.setOption(StandardSocketOptions.SO_SNDBUF, getSettings().getSendBufferSize()); } - //getConnectTimeoutMs MUST be called before connection attempt, as it might throw MongoOperationTimeout exception. - int connectTimeoutMs = operationContext.getTimeoutContext().getConnectTimeoutMs(); - socketChannel.connect(getSocketAddresses(getServerAddress(), inetAddressResolver).get(0)); - SelectorMonitor.SocketRegistration socketRegistration = new SelectorMonitor.SocketRegistration( - socketChannel, () -> initializeTslChannel(handler, socketChannel)); + openedSocketChannel.connect(socketAddress); + socketRegistration = new SelectorMonitor.SocketRegistration( + openedSocketChannel, () -> initializeTslChannel(handler, openedSocketChannel)); if (connectTimeoutMs > 0) { scheduleTimeoutInterruption(handler, socketRegistration, connectTimeoutMs); } selectorMonitor.register(socketRegistration); + registered = true; } catch (IOException e) { + closeUnregisteredSocketChannel(socketChannel, socketRegistration, registered, e); handler.failed(new MongoSocketOpenException("Exception opening socket", getServerAddress(), e)); } catch (Throwable t) { + closeUnregisteredSocketChannel(socketChannel, socketRegistration, registered, t); handler.failed(t); } } + private void closeUnregisteredSocketChannel(@Nullable final SocketChannel socketChannel, + @Nullable final SelectorMonitor.SocketRegistration socketRegistration, + final boolean registered, final Throwable failure) { + if (!registered) { + if (socketRegistration != null) { + socketRegistration.tryCancelPendingConnection(); + } + if (socketChannel != null) { + try { + socketChannel.close(); + } catch (IOException e) { + failure.addSuppressed(e); + } + } + } + } + private void scheduleTimeoutInterruption(final AsyncCompletionHandler handler, final SelectorMonitor.SocketRegistration socketRegistration, final int connectTimeoutMs) { @@ -384,4 +410,3 @@ public void close() throws IOException { } } } - diff --git a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java index 3af1eaa33e..3e904cf5c4 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/connection/TlsChannelStreamFunctionalTest.java @@ -17,18 +17,22 @@ package com.mongodb.internal.connection; import com.mongodb.ClusterFixture; +import com.mongodb.MongoSocketException; import com.mongodb.MongoSocketOpenException; import com.mongodb.ServerAddress; +import com.mongodb.connection.AsyncCompletionHandler; import com.mongodb.connection.SocketSettings; import com.mongodb.connection.SslSettings; import com.mongodb.internal.TimeoutContext; import com.mongodb.internal.TimeoutSettings; +import com.mongodb.spi.dns.InetAddressResolver; import org.bson.ByteBuf; import org.bson.ByteBufNIO; import org.junit.jupiter.api.DisplayName; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; @@ -37,11 +41,13 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import java.io.IOException; +import java.net.InetAddress; import java.net.ServerSocket; import java.nio.ByteBuffer; import java.nio.channels.InterruptedByTimeoutException; import java.nio.channels.SocketChannel; import java.util.Collections; +import java.util.List; import java.util.concurrent.TimeUnit; import static com.mongodb.ClusterFixture.getPrimaryServerDescription; @@ -52,10 +58,12 @@ import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.atLeast; @@ -68,6 +76,69 @@ class TlsChannelStreamFunctionalTest { private static final String UNREACHABLE_PRIVATE_IP_ADDRESS = "10.255.255.1"; private static final int UNREACHABLE_PORT = 65333; + @Test + void shouldFailAsyncCompletionHandlerWithoutOpeningSocketChannelIfNameResolutionFails() { + //given + ServerAddress serverAddress = new ServerAddress(); + MongoSocketException exception = new MongoSocketException("Temporary failure in name resolution", serverAddress); + InetAddressResolver inetAddressResolver = new InetAddressResolver() { + @Override + public List lookupByName(final String host) { + throw exception; + } + }; + + try (StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver); + MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) { + StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder() + .connectTimeout(100, TimeUnit.MILLISECONDS) + .build(), SSL_SETTINGS); + Stream stream = streamFactory.create(serverAddress); + @SuppressWarnings("unchecked") + AsyncCompletionHandler handler = Mockito.mock(AsyncCompletionHandler.class); + + //when + stream.openAsync(createOperationContext(100), handler); + + //then + verify(handler).failed(exception); + verify(handler, times(0)).completed(null); + socketChannelMockedStatic.verify(SocketChannel::open, times(0)); + } + } + + @Test + void shouldCloseSocketChannelIfConnectFailsBeforeRegistration() throws IOException { + //given + ServerAddress serverAddress = new ServerAddress(); + IOException exception = new IOException("connect failed"); + InetAddressResolver inetAddressResolver = host -> Collections.singletonList(InetAddress.getLoopbackAddress()); + + try (SocketChannel socketChannel = Mockito.spy(SocketChannel.open()); + StreamFactoryFactory streamFactoryFactory = new TlsChannelStreamFactoryFactory(inetAddressResolver); + MockedStatic socketChannelMockedStatic = Mockito.mockStatic(SocketChannel.class)) { + socketChannelMockedStatic.when(SocketChannel::open).thenReturn(socketChannel); + Mockito.doThrow(exception).when(socketChannel).connect(any()); + StreamFactory streamFactory = streamFactoryFactory.create(SocketSettings.builder() + .connectTimeout(100, TimeUnit.MILLISECONDS) + .build(), SSL_SETTINGS); + Stream stream = streamFactory.create(serverAddress); + @SuppressWarnings("unchecked") + AsyncCompletionHandler handler = Mockito.mock(AsyncCompletionHandler.class); + ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(Throwable.class); + + //when + stream.openAsync(createOperationContext(100), handler); + + //then + verify(handler).failed(failureCaptor.capture()); + MongoSocketOpenException actualException = assertInstanceOf(MongoSocketOpenException.class, failureCaptor.getValue()); + assertSame(exception, actualException.getCause()); + verify(handler, times(0)).completed(null); + verify(socketChannel).close(); + } + } + @ParameterizedTest @ValueSource(ints = {500, 1000, 2000}) void shouldInterruptConnectionEstablishmentWhenConnectionTimeoutExpires(final int connectTimeoutMs) throws IOException {