Skip to content
72 changes: 42 additions & 30 deletions pymongo/pool_shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
sock.setblocking(False)
await asyncio.get_running_loop().sock_connect(sock, host)
return sock
except OSError:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
Comment thread
NoahStapp marked this conversation as resolved.

Expand Down Expand Up @@ -238,6 +239,10 @@ async def _async_create_connection(address: _Address, options: PoolOptions) -> s
except OSError as e:
sock.close()
err = e # type: ignore[assignment]
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise

if err is not None:
raise err
Expand Down Expand Up @@ -289,19 +294,25 @@ async def _async_configured_socket(
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
except BaseException:
# Protect against cancellation or interruption where the raw socket would otherwise leak
sock.close()
raise
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(ssl_sock.getpeercert(), hostname=host) # type:ignore[attr-defined, unused-ignore]
except _CertificateError:
ssl_sock.close()
raise

ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
ssl_sock.settimeout(options.socket_timeout)
return ssl_sock
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the raw socket would otherwise leak.
ssl_sock.close()
raise


async def _configured_protocol_interface(
Expand Down Expand Up @@ -362,26 +373,27 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
# mismatch, will be turned into ServerSelectionTimeoutErrors later.
details = _get_timeout_details(options)
_raise_connection_failure(address, exc, "SSL handshake failed: ", timeout_details=details)

if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
try:
try:
if (
ssl_context.verify_mode
and not ssl_context.check_hostname
and not options.tls_allow_invalid_hostnames
):
ssl.match_hostname(transport.get_extra_info("peercert"), hostname=host) # type:ignore[attr-defined,unused-ignore]
except _CertificateError:
transport.abort()
raise

if ssl_session_cache is not None:
ssl_obj = transport.get_extra_info("ssl_object")
if ssl_obj is not None:
new_session = ssl_obj.session
if new_session is not None:
ssl_session_cache[0] = new_session

return AsyncNetworkingInterface((transport, protocol))
if ssl_session_cache is not None:
ssl_obj = transport.get_extra_info("ssl_object")
if ssl_obj is not None:
new_session = ssl_obj.session
if new_session is not None:
ssl_session_cache[0] = new_session

return AsyncNetworkingInterface((transport, protocol))
except BaseException:
# Protect against cancellation, _CertificateError, or interruption
# where the transport would otherwise leak.
transport.abort()
raise


def _create_connection(address: _Address, options: PoolOptions) -> socket.socket:
Expand Down
102 changes: 102 additions & 0 deletions test/asynchronous/test_async_cancellation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@
from __future__ import annotations

import asyncio
import functools
import socket
import ssl
import sys
from unittest.mock import patch

from test.asynchronous.utils import async_get_pool
from test.utils_shared import delay, one

sys.path[0:0] = [""]

from pymongo import pool_shared
from test.asynchronous import AsyncIntegrationTest, async_client_context, connected


Expand Down Expand Up @@ -129,3 +134,100 @@ async def task():
await task

self.assertTrue(change_stream._closed)

async def test_cancellation_closes_socket_during_create_connection(self):
address = (await async_client_context.host, await async_client_context.port)
options = (await async_get_pool(self.client)).opts

created_sockets: list[socket.socket] = []
real_socket_cls = socket.socket
target_task = None

def tracking_socket(*args, **kwargs):
s = real_socket_cls(*args, **kwargs)
if asyncio.current_task() is target_task:
created_sockets.append(s)
return s

loop = asyncio.get_running_loop()
real_sock_connect = loop.sock_connect
started = asyncio.Event()
block_forever = asyncio.Event()

async def slow_sock_connect(sock, addr):
if sock in created_sockets:
started.set()
await block_forever.wait()
return None
return await real_sock_connect(sock, addr)

with (
patch.object(socket, "socket", tracking_socket),
patch.object(loop, "sock_connect", slow_sock_connect),
):
task = asyncio.create_task(pool_shared._async_create_connection(address, options))
target_task = task
await asyncio.wait_for(started.wait(), timeout=5)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task
self.assertTrue(created_sockets, "expected at least one socket to be created")
for sock in created_sockets:
self.assertEqual(
sock.fileno(),
-1,
f"socket leaked across cancellation: {sock!r}",
)

async def test_cancellation_closes_socket_during_ssl_wrap_socket(self):
address = (await async_client_context.host, await async_client_context.port)
options = (await async_get_pool(self.client)).opts
fake_ssl_context = ssl.create_default_context()

created_sockets: list[socket.socket] = []
real_socket_cls = socket.socket
target_task = None

def tracking_socket(*args, **kwargs):
s = real_socket_cls(*args, **kwargs)
if asyncio.current_task() is target_task:
created_sockets.append(s)
return s

loop = asyncio.get_running_loop()
real_run_in_executor = loop.run_in_executor
started = asyncio.Event()

def slow_run_in_executor(executor, func, *args):
# Need to unwrap the SNI branch here if present
inner = func.func if isinstance(func, functools.partial) else func
# Each `ctx.wrap_socket` access returns a fresh bound-method
# object, so we check the bound instance (__self__) instead
if (
getattr(inner, "__self__", None) is fake_ssl_context
and asyncio.current_task() is target_task
):
started.set()
# Return a future that never completes for cancellation.
return asyncio.get_running_loop().create_future()
return real_run_in_executor(executor, func, *args)

with (
patch.object(socket, "socket", tracking_socket),
patch.object(loop, "run_in_executor", slow_run_in_executor),
patch.object(options, "_PoolOptions__ssl_context", fake_ssl_context),
):
task = asyncio.create_task(pool_shared._async_configured_socket(address, options))
target_task = task
await asyncio.wait_for(started.wait(), timeout=5)
task.cancel()
with self.assertRaises(asyncio.CancelledError):
await task

self.assertTrue(created_sockets, "expected at least one socket to be created")
for sock in created_sockets:
self.assertEqual(
sock.fileno(),
-1,
f"socket leaked across cancellation: {sock!r}",
)
Loading