diff --git a/customerio/client_base.py b/customerio/client_base.py index 6e95bad..df5dd03 100644 --- a/customerio/client_base.py +++ b/customerio/client_base.py @@ -52,17 +52,18 @@ def send_request(self, method, url, data): ) result_status = response.status_code - if result_status != 200: + if result_status < 200 or result_status >= 300: raise CustomerIOException(f"{result_status}: {url} {data} {response.text}") return response.text + except CustomerIOException: + raise except Exception as e: - # Raise exception alerting user that the system might be - # experiencing an outage and refer them to system status page. - message = f"""Failed to receive valid response after {self.retries} retries. -Check system status at http://status.customer.io. -Last caught exception -- {type(e)}: {e} - """ + message = ( + f"Failed to receive valid response after {self.retries} retries.\n" + f"Check system status at http://status.customer.io.\n" + f"Last caught exception -- {type(e)}: {e}" + ) raise CustomerIOException(message) from e def _sanitize(self, data): @@ -93,9 +94,12 @@ def _build_session(self): session = Session() session.headers["User-Agent"] = f"Customer.io Python Client/{ClientVersion}" - session.mount( - "https://", - HTTPAdapter(max_retries=Retry(total=self.retries, backoff_factor=self.backoff_factor)), + retry = Retry( + total=self.retries, + backoff_factor=self.backoff_factor, + allowed_methods=None, + status_forcelist=[500, 502, 503, 504], ) + session.mount("https://", HTTPAdapter(max_retries=retry)) return session diff --git a/tests/test_client_base.py b/tests/test_client_base.py index c6cef74..fc67484 100644 --- a/tests/test_client_base.py +++ b/tests/test_client_base.py @@ -1,7 +1,7 @@ import threading import unittest -from customerio.client_base import ClientBase +from customerio.client_base import ClientBase, CustomerIOException class FakeResponse: @@ -82,6 +82,55 @@ def send_request(): self.assertTrue(all(session.request_count == 1 for session in sessions)) self.assertIsNone(client._current_session) + def test_retry_config_allows_post(self): + client = ClientBase(retries=5, backoff_factor=0.1) + session = client._build_session() + adapter = session.get_adapter("https://example.com") + retry = adapter.max_retries + + self.assertEqual(retry.total, 5) + self.assertEqual(retry.backoff_factor, 0.1) + self.assertIsNone(retry.allowed_methods) + self.assertEqual(set(retry.status_forcelist), {500, 502, 503, 504}) + + def test_non_200_raises_without_retry_wrapper(self): + client = ClientBase() + + error_response = FakeResponse() + error_response.status_code = 400 + error_response.text = "bad request" + + def build_session(): + session = FakeSession() + session.request = lambda *a, **kw: error_response + return session + + client._build_session = build_session + + with self.assertRaises(CustomerIOException) as ctx: + client.send_request("POST", "https://example.com", {}) + + self.assertIn("400", str(ctx.exception)) + self.assertNotIn("retries", str(ctx.exception)) + + def test_2xx_status_codes_accepted(self): + client = ClientBase() + + for status in [200, 201, 202, 204]: + response = FakeResponse() + response.status_code = status + response.text = "ok" + + def build_session(resp=response): + session = FakeSession() + session.request = lambda *a, **kw: resp + return session + + client._build_session = build_session + client._current_session = None + result = client.send_request("POST", "https://example.com", {}) + self.assertEqual(result, "ok") + if __name__ == "__main__": unittest.main()