Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 37 additions & 2 deletions customerio/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,52 @@
"""

import math
import socket
from datetime import datetime, timezone

from requests import Session
from requests.adapters import HTTPAdapter
from requests.adapters import DEFAULT_POOLBLOCK, HTTPAdapter
from urllib3.connection import HTTPConnection
from urllib3.util.retry import Retry

from .__version__ import __version__ as ClientVersion

TCP_KEEPALIVE_IDLE_TIMEOUT = 300
TCP_KEEPALIVE_INTERVAL = 60


class CustomerIOException(Exception):
pass


def _tcp_keepalive_socket_options():
tcp_protocol = getattr(socket, "SOL_TCP", socket.IPPROTO_TCP)
tcp_keepidle = getattr(socket, "TCP_KEEPIDLE", getattr(socket, "TCP_KEEPALIVE", None))

options = list(HTTPConnection.default_socket_options)
keepalive_options = [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)]
if tcp_keepidle is not None:
keepalive_options.append((tcp_protocol, tcp_keepidle, TCP_KEEPALIVE_IDLE_TIMEOUT))
if hasattr(socket, "TCP_KEEPINTVL"):
keepalive_options.append((tcp_protocol, socket.TCP_KEEPINTVL, TCP_KEEPALIVE_INTERVAL))

for option in keepalive_options:
if option not in options:
options.append(option)

return options


class TCPKeepAliveHTTPAdapter(HTTPAdapter):
def init_poolmanager(self, connections, maxsize, block=DEFAULT_POOLBLOCK, **pool_kwargs):
pool_kwargs.setdefault("socket_options", _tcp_keepalive_socket_options())
super().init_poolmanager(connections, maxsize, block=block, **pool_kwargs)

def proxy_manager_for(self, proxy, **proxy_kwargs):
proxy_kwargs.setdefault("socket_options", _tcp_keepalive_socket_options())
return super().proxy_manager_for(proxy, **proxy_kwargs)


class ClientBase:
def __init__(self, retries=3, timeout=10, backoff_factor=0.02, use_connection_pooling=True):
self.timeout = timeout
Expand Down Expand Up @@ -95,7 +128,9 @@ def _build_session(self):

session.mount(
"https://",
HTTPAdapter(max_retries=Retry(total=self.retries, backoff_factor=self.backoff_factor)),
TCPKeepAliveHTTPAdapter(
max_retries=Retry(total=self.retries, backoff_factor=self.backoff_factor)
),
)

return session
23 changes: 23 additions & 0 deletions tests/test_customerio.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import json
import socket
import unittest
from datetime import datetime
from functools import partial

import urllib3
from requests.auth import _basic_auth_str
from urllib3.connection import HTTPConnection

from customerio import CustomerIO, CustomerIOException, Regions
from customerio.client_base import TCP_KEEPALIVE_IDLE_TIMEOUT, TCP_KEEPALIVE_INTERVAL
from customerio.constants import CIOID, EMAIL, ID
from tests.server import HTTPSTestCase

Expand Down Expand Up @@ -64,6 +67,26 @@ def test_client_setup(self):
with self.assertRaises(CustomerIOException):
CustomerIO(site_id="site_id", api_key="api_key", region="au")

def test_keepalive_socket_options_are_configured_on_adapter(self):
default_socket_options = list(HTTPConnection.default_socket_options)
client = CustomerIO(site_id="site_id", api_key="api_key")
socket_options = client.http.adapters["https://"].poolmanager.connection_pool_kw[
"socket_options"
]
tcp_protocol = getattr(socket, "SOL_TCP", socket.IPPROTO_TCP)
tcp_keepidle = getattr(socket, "TCP_KEEPIDLE", getattr(socket, "TCP_KEEPALIVE", None))

for option in default_socket_options:
self.assertIn(option, socket_options)
self.assertIn((socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1), socket_options)
if tcp_keepidle is not None:
self.assertIn((tcp_protocol, tcp_keepidle, TCP_KEEPALIVE_IDLE_TIMEOUT), socket_options)
if hasattr(socket, "TCP_KEEPINTVL"):
self.assertIn(
(tcp_protocol, socket.TCP_KEEPINTVL, TCP_KEEPALIVE_INTERVAL), socket_options
)
self.assertEqual(HTTPConnection.default_socket_options, default_socket_options)

def test_client_connection_handling(self):
retries = self.cio.retries
# should not raise exception as i should be less than retries and
Expand Down