diff --git a/customerio/api.py b/customerio/api.py index 538bf8a..5eb637c 100644 --- a/customerio/api.py +++ b/customerio/api.py @@ -3,7 +3,6 @@ """ import base64 -import json from .client_base import ClientBase, CustomerIOException from .regions import Region, Regions @@ -95,31 +94,31 @@ def send_email(self, request): if isinstance(request, SendEmailRequest): request = request._to_dict() resp = self.send_request("POST", self.url + "/v1/send/email", request) - return json.loads(resp) + return resp.json() def send_push(self, request): if isinstance(request, SendPushRequest): request = request._to_dict() resp = self.send_request("POST", self.url + "/v1/send/push", request) - return json.loads(resp) + return resp.json() def send_sms(self, request): if isinstance(request, SendSMSRequest): request = request._to_dict() resp = self.send_request("POST", self.url + "/v1/send/sms", request) - return json.loads(resp) + return resp.json() def send_inbox_message(self, request): if isinstance(request, SendInboxMessageRequest): request = request._to_dict() resp = self.send_request("POST", self.url + "/v1/send/inbox_message", request) - return json.loads(resp) + return resp.json() def send_in_app(self, request): if isinstance(request, SendInAppRequest): request = request._to_dict() resp = self.send_request("POST", self.url + "/v1/send/in_app", request) - return json.loads(resp) + return resp.json() def _build_session(self): session = super()._build_session() diff --git a/customerio/client_base.py b/customerio/client_base.py index df5dd03..3370665 100644 --- a/customerio/client_base.py +++ b/customerio/client_base.py @@ -54,7 +54,7 @@ def send_request(self, method, url, data): result_status = response.status_code if result_status < 200 or result_status >= 300: raise CustomerIOException(f"{result_status}: {url} {data} {response.text}") - return response.text + return response except CustomerIOException: raise diff --git a/customerio/track.py b/customerio/track.py index 3e3ffa9..c183484 100644 --- a/customerio/track.py +++ b/customerio/track.py @@ -89,7 +89,7 @@ def identify(self, id, **kwargs): if not id: raise CustomerIOException("id cannot be blank in identify") url = self.get_customer_query_string(id) - self.send_request("PUT", url, kwargs) + return self.send_request("PUT", url, kwargs) def track(self, customer_id, name, **data): """Track an event for a given customer_id.""" @@ -100,7 +100,7 @@ def track(self, customer_id, name, **data): "name": name, "data": self._sanitize(data), } - self.send_request("POST", url, post_data) + return self.send_request("POST", url, post_data) def track_anonymous(self, anonymous_id, name, **data): """Track an event for a given anonymous_id.""" @@ -112,7 +112,7 @@ def track_anonymous(self, anonymous_id, name, **data): if anonymous_id: post_data["anonymous_id"] = anonymous_id - self.send_request("POST", url, post_data) + return self.send_request("POST", url, post_data) def pageview(self, customer_id, page, **data): """Track a pageview for a given customer_id.""" @@ -124,7 +124,7 @@ def pageview(self, customer_id, page, **data): "name": page, "data": self._sanitize(data), } - self.send_request("POST", url, post_data) + return self.send_request("POST", url, post_data) def backfill(self, customer_id, name, timestamp, **data): """Backfill an event (track with timestamp) for a given customer_id.""" @@ -147,7 +147,7 @@ def backfill(self, customer_id, name, timestamp, **data): "timestamp": timestamp, } - self.send_request("POST", url, post_data) + return self.send_request("POST", url, post_data) def delete(self, customer_id): """Delete a customer profile.""" @@ -155,7 +155,7 @@ def delete(self, customer_id): raise CustomerIOException("customer_id cannot be blank in delete") url = self.get_customer_query_string(customer_id) - self.send_request("DELETE", url, {}) + return self.send_request("DELETE", url, {}) def add_device(self, customer_id, device_id, platform, **data): """Add a device to a customer profile.""" @@ -176,7 +176,7 @@ def add_device(self, customer_id, device_id, platform, **data): ) payload = {"device": data} url = self.get_device_query_string(customer_id) - self.send_request("PUT", url, payload) + return self.send_request("PUT", url, payload) def delete_device(self, customer_id, device_id): """Delete a device from a customer profile.""" @@ -188,13 +188,13 @@ def delete_device(self, customer_id, device_id): url = self.get_device_query_string(customer_id) delete_url = f"{url}/{self._url_encode(device_id)}" - self.send_request("DELETE", delete_url, {}) + return self.send_request("DELETE", delete_url, {}) def suppress(self, customer_id): if not customer_id: raise CustomerIOException("customer_id cannot be blank in suppress") - self.send_request( + return self.send_request( "POST", f"{self.base_url}/customers/{self._url_encode(customer_id)}/suppress", {}, @@ -204,7 +204,7 @@ def unsuppress(self, customer_id): if not customer_id: raise CustomerIOException("customer_id cannot be blank in unsuppress") - self.send_request( + return self.send_request( "POST", f"{self.base_url}/customers/{self._url_encode(customer_id)}/unsuppress", {}, @@ -232,7 +232,7 @@ def merge_customers(self, primary_id_type, primary_id, secondary_id_type, second "primary": {primary_id_type: primary_id}, "secondary": {secondary_id_type: secondary_id}, } - self.send_request("POST", url, post_data) + return self.send_request("POST", url, post_data) def batch(self, operations): """Send multiple operations in a single request. @@ -247,7 +247,7 @@ def batch(self, operations): url = f"https://{self.host}/api/v2/batch" else: url = f"https://{self.host}:{self.port}/api/v2/batch" - self.send_request("POST", url, {"batch": operations}) + return self.send_request("POST", url, {"batch": operations}) def _build_session(self): session = super()._build_session() diff --git a/tests/test_client_base.py b/tests/test_client_base.py index fc67484..efaf624 100644 --- a/tests/test_client_base.py +++ b/tests/test_client_base.py @@ -43,8 +43,10 @@ def build_session(): client._build_session = build_session - self.assertEqual(client.send_request("POST", "https://example.com", {}), "ok") - self.assertEqual(client.send_request("POST", "https://example.com", {}), "ok") + resp1 = client.send_request("POST", "https://example.com", {}) + resp2 = client.send_request("POST", "https://example.com", {}) + self.assertEqual(resp1.status_code, 200) + self.assertEqual(resp2.status_code, 200) self.assertEqual(len(sessions), 1) self.assertFalse(sessions[0].closed) @@ -76,7 +78,8 @@ def send_request(): thread.join() self.assertEqual(errors, []) - self.assertEqual(sorted(responses), ["ok", "ok"]) + self.assertEqual(len(responses), 2) + self.assertTrue(all(r.status_code == 200 for r in responses)) self.assertEqual(len(sessions), 2) self.assertTrue(all(session.closed for session in sessions)) self.assertTrue(all(session.request_count == 1 for session in sessions)) @@ -129,7 +132,7 @@ def build_session(resp=response): client._build_session = build_session client._current_session = None result = client.send_request("POST", "https://example.com", {}) - self.assertEqual(result, "ok") + self.assertEqual(result.status_code, status) if __name__ == "__main__":