From 336c850e6129c50abcf97c128c7f3a2a04de3fdd Mon Sep 17 00:00:00 2001 From: pauldambra Date: Wed, 7 Jan 2026 21:47:43 +0000 Subject: [PATCH 1/3] chore: add a test to describe upload behaviour when there are errors --- posthog/test/test_consumer.py | 37 +++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/posthog/test/test_consumer.py b/posthog/test/test_consumer.py index 606cdaad..b3f9c557 100644 --- a/posthog/test/test_consumer.py +++ b/posthog/test/test_consumer.py @@ -194,3 +194,40 @@ def mock_post_fn(_, data, **kwargs): q.put(track) q.join() self.assertEqual(mock_post.call_count, 2) + + def test_upload_exception_calls_on_error_and_does_not_raise(self): + on_error_called = [] + + def on_error(e, batch): + on_error_called.append((e, batch)) + + q = Queue() + consumer = Consumer(q, TEST_API_KEY, on_error=on_error) + track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"} + q.put(track) + + with mock.patch.object( + consumer, "request", side_effect=Exception("request failed") + ): + result = consumer.upload() + + self.assertFalse(result) + self.assertEqual(len(on_error_called), 1) + self.assertIsInstance(on_error_called[0][0], Exception) + self.assertEqual(str(on_error_called[0][0]), "request failed") + + def test_upload_exception_in_on_error_does_not_raise(self): + def on_error(e, batch): + raise Exception("on_error failed") + + q = Queue() + consumer = Consumer(q, TEST_API_KEY, on_error=on_error) + track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"} + q.put(track) + + with mock.patch.object( + consumer, "request", side_effect=Exception("request failed") + ): + result = consumer.upload() + + self.assertFalse(result) From 4e12dcffd51de376930c767e56e36fc39ad79aab Mon Sep 17 00:00:00 2001 From: pauldambra Date: Wed, 7 Jan 2026 22:04:22 +0000 Subject: [PATCH 2/3] refactor the test file --- posthog/test/test_consumer.py | 145 +++++++++++++--------------------- 1 file changed, 54 insertions(+), 91 deletions(-) diff --git a/posthog/test/test_consumer.py b/posthog/test/test_consumer.py index b3f9c557..1f5fc045 100644 --- a/posthog/test/test_consumer.py +++ b/posthog/test/test_consumer.py @@ -3,6 +3,7 @@ import unittest import mock +from parameterized import parameterized try: from queue import Queue @@ -14,6 +15,10 @@ from posthog.test.test_utils import TEST_API_KEY +def _track_event(event_name="python event"): + return {"type": "track", "event": event_name, "distinct_id": "distinct_id"} + + class TestConsumer(unittest.TestCase): def test_next(self): q = Queue() @@ -43,8 +48,7 @@ def test_dropping_oversize_msg(self): def test_upload(self): q = Queue() consumer = Consumer(q, TEST_API_KEY) - track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"} - q.put(track) + q.put(_track_event()) success = consumer.upload() self.assertTrue(success) @@ -57,13 +61,8 @@ def test_flush_interval(self): consumer = Consumer(q, TEST_API_KEY, flush_at=10, flush_interval=flush_interval) with mock.patch("posthog.consumer.batch_post") as mock_post: consumer.start() - for i in range(0, 3): - track = { - "type": "track", - "event": "python event %d" % i, - "distinct_id": "distinct_id", - } - q.put(track) + for i in range(3): + q.put(_track_event("python event %d" % i)) time.sleep(flush_interval * 1.1) self.assertEqual(mock_post.call_count, 3) @@ -78,81 +77,51 @@ def test_multiple_uploads_per_interval(self): ) with mock.patch("posthog.consumer.batch_post") as mock_post: consumer.start() - for i in range(0, flush_at * 2): - track = { - "type": "track", - "event": "python event %d" % i, - "distinct_id": "distinct_id", - } - q.put(track) + for i in range(flush_at * 2): + q.put(_track_event("python event %d" % i)) time.sleep(flush_interval * 1.1) self.assertEqual(mock_post.call_count, 2) def test_request(self): consumer = Consumer(None, TEST_API_KEY) - track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"} - consumer.request([track]) + consumer.request([_track_event()]) - def _test_request_retry(self, consumer, expected_exception, exception_count): - def mock_post(*args, **kwargs): - mock_post.call_count += 1 - if mock_post.call_count <= exception_count: - raise expected_exception + def _run_retry_test(self, exception, exception_count, retries=10): + call_count = [0] - mock_post.call_count = 0 + def mock_post(*args, **kwargs): + call_count[0] += 1 + if call_count[0] <= exception_count: + raise exception + consumer = Consumer(None, TEST_API_KEY, retries=retries) with mock.patch( "posthog.consumer.batch_post", mock.Mock(side_effect=mock_post) ): - track = { - "type": "track", - "event": "python event", - "distinct_id": "distinct_id", - } - # request() should succeed if the number of exceptions raised is - # less than the retries paramater. - if exception_count <= consumer.retries: - consumer.request([track]) + if exception_count <= retries: + consumer.request([_track_event()]) else: - # if exceptions are raised more times than the retries - # parameter, we expect the exception to be returned to - # the caller. - try: - consumer.request([track]) - except type(expected_exception) as exc: - self.assertEqual(exc, expected_exception) - else: - self.fail( - "request() should raise an exception if still failing after %d retries" - % consumer.retries - ) - - def test_request_retry(self): - # we should retry on general errors - consumer = Consumer(None, TEST_API_KEY) - self._test_request_retry(consumer, Exception("generic exception"), 2) - - # we should retry on server errors - consumer = Consumer(None, TEST_API_KEY) - self._test_request_retry(consumer, APIError(500, "Internal Server Error"), 2) - - # we should retry on HTTP 429 errors - consumer = Consumer(None, TEST_API_KEY) - self._test_request_retry(consumer, APIError(429, "Too Many Requests"), 2) - - # we should NOT retry on other client errors - consumer = Consumer(None, TEST_API_KEY) - api_error = APIError(400, "Client Errors") - try: - self._test_request_retry(consumer, api_error, 1) - except APIError: - pass - else: - self.fail("request() should not retry on client errors") - - # test for number of exceptions raise > retries value - consumer = Consumer(None, TEST_API_KEY, retries=3) - self._test_request_retry(consumer, APIError(500, "Internal Server Error"), 3) + with self.assertRaises(type(exception)): + consumer.request([_track_event()]) + + @parameterized.expand( + [ + ("general_errors", Exception("generic exception"), 2), + ("server_errors", APIError(500, "Internal Server Error"), 2), + ("rate_limit_errors", APIError(429, "Too Many Requests"), 2), + ] + ) + def test_request_retries_on_retriable_errors( + self, _name, exception, exception_count + ): + self._run_retry_test(exception, exception_count) + + def test_request_does_not_retry_client_errors(self): + with self.assertRaises(APIError): + self._run_retry_test(APIError(400, "Client Errors"), 1) + + def test_request_fails_when_exceptions_exceed_retries(self): + self._run_retry_test(APIError(500, "Internal Server Error"), 4, retries=3) def test_pause(self): consumer = Consumer(None, TEST_API_KEY) @@ -195,15 +164,25 @@ def mock_post_fn(_, data, **kwargs): q.join() self.assertEqual(mock_post.call_count, 2) - def test_upload_exception_calls_on_error_and_does_not_raise(self): + @parameterized.expand( + [ + ("on_error_succeeds", False), + ("on_error_raises", True), + ] + ) + def test_upload_exception_calls_on_error_and_does_not_raise( + self, _name, on_error_raises + ): on_error_called = [] def on_error(e, batch): on_error_called.append((e, batch)) + if on_error_raises: + raise Exception("on_error failed") q = Queue() consumer = Consumer(q, TEST_API_KEY, on_error=on_error) - track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"} + track = _track_event() q.put(track) with mock.patch.object( @@ -213,21 +192,5 @@ def on_error(e, batch): self.assertFalse(result) self.assertEqual(len(on_error_called), 1) - self.assertIsInstance(on_error_called[0][0], Exception) self.assertEqual(str(on_error_called[0][0]), "request failed") - - def test_upload_exception_in_on_error_does_not_raise(self): - def on_error(e, batch): - raise Exception("on_error failed") - - q = Queue() - consumer = Consumer(q, TEST_API_KEY, on_error=on_error) - track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"} - q.put(track) - - with mock.patch.object( - consumer, "request", side_effect=Exception("request failed") - ): - result = consumer.upload() - - self.assertFalse(result) + self.assertEqual(on_error_called[0][1], [track]) From 748013c94b37f2b9e2a6d5edf90410bb3c640f2a Mon Sep 17 00:00:00 2001 From: pauldambra Date: Wed, 7 Jan 2026 22:12:48 +0000 Subject: [PATCH 3/3] add typehints to the test file --- posthog/test/test_consumer.py | 45 +++++++++++++++++++---------------- 1 file changed, 24 insertions(+), 21 deletions(-) diff --git a/posthog/test/test_consumer.py b/posthog/test/test_consumer.py index 1f5fc045..28d51221 100644 --- a/posthog/test/test_consumer.py +++ b/posthog/test/test_consumer.py @@ -1,6 +1,7 @@ import json import time import unittest +from typing import Any import mock from parameterized import parameterized @@ -15,19 +16,19 @@ from posthog.test.test_utils import TEST_API_KEY -def _track_event(event_name="python event"): +def _track_event(event_name: str = "python event") -> dict[str, str]: return {"type": "track", "event": event_name, "distinct_id": "distinct_id"} class TestConsumer(unittest.TestCase): - def test_next(self): + def test_next(self) -> None: q = Queue() consumer = Consumer(q, "") q.put(1) next = consumer.next() self.assertEqual(next, [1]) - def test_next_limit(self): + def test_next_limit(self) -> None: q = Queue() flush_at = 50 consumer = Consumer(q, "", flush_at) @@ -36,7 +37,7 @@ def test_next_limit(self): next = consumer.next() self.assertEqual(next, list(range(flush_at))) - def test_dropping_oversize_msg(self): + def test_dropping_oversize_msg(self) -> None: q = Queue() consumer = Consumer(q, "") oversize_msg = {"m": "x" * MAX_MSG_SIZE} @@ -45,14 +46,14 @@ def test_dropping_oversize_msg(self): self.assertEqual(next, []) self.assertTrue(q.empty()) - def test_upload(self): + def test_upload(self) -> None: q = Queue() consumer = Consumer(q, TEST_API_KEY) q.put(_track_event()) success = consumer.upload() self.assertTrue(success) - def test_flush_interval(self): + def test_flush_interval(self) -> None: # Put _n_ items in the queue, pausing a little bit more than # _flush_interval_ after each one. # The consumer should upload _n_ times. @@ -66,7 +67,7 @@ def test_flush_interval(self): time.sleep(flush_interval * 1.1) self.assertEqual(mock_post.call_count, 3) - def test_multiple_uploads_per_interval(self): + def test_multiple_uploads_per_interval(self) -> None: # Put _flush_at*2_ items in the queue at once, then pause for # _flush_interval_. The consumer should upload 2 times. q = Queue() @@ -82,14 +83,16 @@ def test_multiple_uploads_per_interval(self): time.sleep(flush_interval * 1.1) self.assertEqual(mock_post.call_count, 2) - def test_request(self): + def test_request(self) -> None: consumer = Consumer(None, TEST_API_KEY) consumer.request([_track_event()]) - def _run_retry_test(self, exception, exception_count, retries=10): + def _run_retry_test( + self, exception: Exception, exception_count: int, retries: int = 10 + ) -> None: call_count = [0] - def mock_post(*args, **kwargs): + def mock_post(*args: Any, **kwargs: Any) -> None: call_count[0] += 1 if call_count[0] <= exception_count: raise exception @@ -112,23 +115,23 @@ def mock_post(*args, **kwargs): ] ) def test_request_retries_on_retriable_errors( - self, _name, exception, exception_count - ): + self, _name: str, exception: Exception, exception_count: int + ) -> None: self._run_retry_test(exception, exception_count) - def test_request_does_not_retry_client_errors(self): + def test_request_does_not_retry_client_errors(self) -> None: with self.assertRaises(APIError): self._run_retry_test(APIError(400, "Client Errors"), 1) - def test_request_fails_when_exceptions_exceed_retries(self): + def test_request_fails_when_exceptions_exceed_retries(self) -> None: self._run_retry_test(APIError(500, "Internal Server Error"), 4, retries=3) - def test_pause(self): + def test_pause(self) -> None: consumer = Consumer(None, TEST_API_KEY) consumer.pause() self.assertFalse(consumer.running) - def test_max_batch_size(self): + def test_max_batch_size(self) -> None: q = Queue() consumer = Consumer(q, TEST_API_KEY, flush_at=100000, flush_interval=3) properties = {} @@ -144,7 +147,7 @@ def test_max_batch_size(self): # Let's capture 8MB of data to trigger two batches n_msgs = int(8_000_000 / msg_size) - def mock_post_fn(_, data, **kwargs): + def mock_post_fn(_: str, data: str, **kwargs: Any) -> mock.Mock: res = mock.Mock() res.status_code = 200 request_size = len(data.encode()) @@ -171,11 +174,11 @@ def mock_post_fn(_, data, **kwargs): ] ) def test_upload_exception_calls_on_error_and_does_not_raise( - self, _name, on_error_raises - ): - on_error_called = [] + self, _name: str, on_error_raises: bool + ) -> None: + on_error_called: list[tuple[Exception, list[dict[str, str]]]] = [] - def on_error(e, batch): + def on_error(e: Exception, batch: list[dict[str, str]]) -> None: on_error_called.append((e, batch)) if on_error_raises: raise Exception("on_error failed")