diff --git a/customerio/client_base.py b/customerio/client_base.py index 6e95bad..6b33e95 100644 --- a/customerio/client_base.py +++ b/customerio/client_base.py @@ -3,19 +3,52 @@ """ import math +import socket from datetime import datetime, timezone from requests import Session -from requests.adapters import HTTPAdapter +from requests.adapters import DEFAULT_POOLBLOCK, HTTPAdapter +from urllib3.connection import HTTPConnection from urllib3.util.retry import Retry from .__version__ import __version__ as ClientVersion +TCP_KEEPALIVE_IDLE_TIMEOUT = 300 +TCP_KEEPALIVE_INTERVAL = 60 + class CustomerIOException(Exception): pass +def _tcp_keepalive_socket_options(): + tcp_protocol = getattr(socket, "SOL_TCP", socket.IPPROTO_TCP) + tcp_keepidle = getattr(socket, "TCP_KEEPIDLE", getattr(socket, "TCP_KEEPALIVE", None)) + + options = list(HTTPConnection.default_socket_options) + keepalive_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)] + if tcp_keepidle is not None: + keepalive_options.append((tcp_protocol, tcp_keepidle, TCP_KEEPALIVE_IDLE_TIMEOUT)) + if hasattr(socket, "TCP_KEEPINTVL"): + keepalive_options.append((tcp_protocol, socket.TCP_KEEPINTVL, TCP_KEEPALIVE_INTERVAL)) + + for option in keepalive_options: + if option not in options: + options.append(option) + + return options + + +class TCPKeepAliveHTTPAdapter(HTTPAdapter): + def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs): + pool_kwargs.setdefault("socket_options", _tcp_keepalive_socket_options()) + super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs) + + def proxy_manager_for(self, proxy, **proxy_kwargs): + proxy_kwargs.setdefault("socket_options", _tcp_keepalive_socket_options()) + return super().proxy_manager_for(proxy, **proxy_kwargs) + + class ClientBase: def __init__(self, retries=3, timeout=10, backoff_factor=0.02, use_connection_pooling=True): self.timeout = timeout @@ -95,7 +128,9 @@ def _build_session(self): session.mount( "https://", - HTTPAdapter(max_retries=Retry(total=self.retries, backoff_factor=self.backoff_factor)), + TCPKeepAliveHTTPAdapter( + max_retries=Retry(total=self.retries, backoff_factor=self.backoff_factor) + ), ) return session diff --git a/tests/test_customerio.py b/tests/test_customerio.py index 37dd4b0..5ae2d65 100644 --- a/tests/test_customerio.py +++ b/tests/test_customerio.py @@ -1,12 +1,15 @@ import json +import socket import unittest from datetime import datetime from functools import partial import urllib3 from requests.auth import _basic_auth_str +from urllib3.connection import HTTPConnection from customerio import CustomerIO, CustomerIOException, Regions +from customerio.client_base import TCP_KEEPALIVE_IDLE_TIMEOUT, TCP_KEEPALIVE_INTERVAL from customerio.constants import CIOID, EMAIL, ID from tests.server import HTTPSTestCase @@ -64,6 +67,26 @@ def test_client_setup(self): with self.assertRaises(CustomerIOException): CustomerIO(site_id="site_id", api_key="api_key", region="au") + def test_keepalive_socket_options_are_configured_on_adapter(self): + default_socket_options = list(HTTPConnection.default_socket_options) + client = CustomerIO(site_id="site_id", api_key="api_key") + socket_options = client.http.adapters["https://"].poolmanager.connection_pool_kw[ + "socket_options" + ] + tcp_protocol = getattr(socket, "SOL_TCP", socket.IPPROTO_TCP) + tcp_keepidle = getattr(socket, "TCP_KEEPIDLE", getattr(socket, "TCP_KEEPALIVE", None)) + + for option in default_socket_options: + self.assertIn(option, socket_options) + self.assertIn((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), socket_options) + if tcp_keepidle is not None: + self.assertIn((tcp_protocol, tcp_keepidle, TCP_KEEPALIVE_IDLE_TIMEOUT), socket_options) + if hasattr(socket, "TCP_KEEPINTVL"): + self.assertIn( + (tcp_protocol, socket.TCP_KEEPINTVL, TCP_KEEPALIVE_INTERVAL), socket_options + ) + self.assertEqual(HTTPConnection.default_socket_options, default_socket_options) + def test_client_connection_handling(self): retries = self.cio.retries # should not raise exception as i should be less than retries and