diff --git a/customerio/client_base.py b/customerio/client_base.py index 3370665..0087faf 100644 --- a/customerio/client_base.py +++ b/customerio/client_base.py @@ -24,6 +24,19 @@ def __init__(self, retries=3, timeout=10, backoff_factor=0.02, use_connection_po self.use_connection_pooling = use_connection_pooling self._current_session = None + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + def close(self): + if self._current_session is not None: + try: + self._current_session.close() + finally: + self._current_session = None + @property def http(self): if self._current_session is None: diff --git a/tests/test_client_base.py b/tests/test_client_base.py index efaf624..7031519 100644 --- a/tests/test_client_base.py +++ b/tests/test_client_base.py @@ -134,6 +134,46 @@ def build_session(resp=response): result = client.send_request("POST", "https://example.com", {}) self.assertEqual(result.status_code, status) + def test_context_manager_closes_session(self): + client = ClientBase(use_connection_pooling=True) + session = FakeSession() + client._build_session = lambda: session + + with client: + client.send_request("POST", "https://example.com", {}) + self.assertFalse(session.closed) + + self.assertTrue(session.closed) + self.assertIsNone(client._current_session) + + def test_close_without_session(self): + client = ClientBase() + client.close() + self.assertIsNone(client._current_session) + + def test_close_resets_session(self): + client = ClientBase(use_connection_pooling=True) + session = FakeSession() + client._build_session = lambda: session + + client.send_request("POST", "https://example.com", {}) + self.assertIsNotNone(client._current_session) + + client.close() + self.assertTrue(session.closed) + self.assertIsNone(client._current_session) + + def test_close_resets_session_even_on_error(self): + client = ClientBase(use_connection_pooling=True) + session = FakeSession() + session.close = lambda: (_ for _ in ()).throw(RuntimeError("close failed")) + client._current_session = session + + with self.assertRaises(RuntimeError): + client.close() + + self.assertIsNone(client._current_session) + if __name__ == "__main__": unittest.main()