Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions customerio/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,19 @@ def __init__(self, retries=3, timeout=10, backoff_factor=0.02, use_connection_po
self.use_connection_pooling = use_connection_pooling
self._current_session = None

def __enter__(self):
return self

def __exit__(self, *args):
self.close()

def close(self):
if self._current_session is not None:
try:
self._current_session.close()
finally:
self._current_session = None

@property
def http(self):
if self._current_session is None:
Expand Down
40 changes: 40 additions & 0 deletions tests/test_client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,46 @@ def build_session(resp=response):
result = client.send_request("POST", "https://example.com", {})
self.assertEqual(result.status_code, status)

def test_context_manager_closes_session(self):
client = ClientBase(use_connection_pooling=True)
session = FakeSession()
client._build_session = lambda: session

with client:
client.send_request("POST", "https://example.com", {})
self.assertFalse(session.closed)

self.assertTrue(session.closed)
self.assertIsNone(client._current_session)

def test_close_without_session(self):
client = ClientBase()
client.close()
self.assertIsNone(client._current_session)

def test_close_resets_session(self):
client = ClientBase(use_connection_pooling=True)
session = FakeSession()
client._build_session = lambda: session

client.send_request("POST", "https://example.com", {})
self.assertIsNotNone(client._current_session)

client.close()
self.assertTrue(session.closed)
self.assertIsNone(client._current_session)

def test_close_resets_session_even_on_error(self):
client = ClientBase(use_connection_pooling=True)
session = FakeSession()
session.close = lambda: (_ for _ in ()).throw(RuntimeError("close failed"))
client._current_session = session

with self.assertRaises(RuntimeError):
client.close()

self.assertIsNone(client._current_session)


if __name__ == "__main__":
unittest.main()