diff --git a/examples/redis_flag_cache.py b/examples/redis_flag_cache.py new file mode 100644 index 00000000..aa4437fd --- /dev/null +++ b/examples/redis_flag_cache.py @@ -0,0 +1,144 @@ +""" +Redis-based distributed cache for PostHog feature flag definitions. + +This example demonstrates how to implement a FlagDefinitionCacheProvider +using Redis for multi-instance deployments (leader election pattern). + +Usage: + import redis + from posthog import Posthog + + redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True) + cache = RedisFlagCache(redis_client, service_key="my-service") + + posthog = Posthog( + "", + personal_api_key="", + flag_definition_cache_provider=cache, + ) + +Requirements: + pip install redis +""" + +import json +import uuid + +from posthog import FlagDefinitionCacheData, FlagDefinitionCacheProvider +from redis import Redis +from typing import Optional + + +class RedisFlagCache(FlagDefinitionCacheProvider): + """ + A distributed cache for PostHog feature flag definitions using Redis. + + In a multi-instance deployment (e.g., multiple serverless functions or containers), + we want only ONE instance to poll PostHog for flag updates, while all instances + share the cached results. This prevents N instances from making N redundant API calls. + + The implementation uses leader election: + - One instance "wins" and becomes responsible for fetching + - Other instances read from the shared cache + - If the leader dies, the lock expires (TTL) and another instance takes over + + Uses Lua scripts for atomic operations, following Redis distributed lock best practices: + https://redis.io/docs/latest/develop/clients/patterns/distributed-locks/ + """ + + LOCK_TTL_MS = 60 * 1000 # 60 seconds, should be longer than the flags poll interval + CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours + + # Lua script: acquire lock if free, or extend if we own it + _LUA_TRY_LEAD = """ + local current = redis.call('GET', KEYS[1]) + if current == false then + redis.call('SET', KEYS[1], ARGV[1], 'PX', ARGV[2]) + return 1 + elseif current == ARGV[1] then + redis.call('PEXPIRE', KEYS[1], ARGV[2]) + return 1 + end + return 0 + """ + + # Lua script: release lock only if we own it + _LUA_STOP_LEAD = """ + if redis.call('GET', KEYS[1]) == ARGV[1] then + return redis.call('DEL', KEYS[1]) + end + return 0 + """ + + def __init__(self, redis: Redis[str], service_key: str): + """ + Initialize the Redis flag cache. + + Args: + redis: A redis-py client instance. Must be configured with + decode_responses=True for correct string handling. + service_key: A unique identifier for this service/environment. + Used to scope Redis keys, allowing multiple services + or environments to share the same Redis instance. + Examples: "my-api-prod", "checkout-service", "staging". 
+ + Redis Keys Created: + - posthog:flags:{service_key} - Cached flag definitions (JSON) + - posthog:flags:{service_key}:lock - Leader election lock + + Example: + redis_client = redis.Redis( + host='localhost', + port=6379, + decode_responses=True + ) + cache = RedisFlagCache(redis_client, service_key="my-api-prod") + """ + self._redis = redis + self._cache_key = f"posthog:flags:{service_key}" + self._lock_key = f"posthog:flags:{service_key}:lock" + self._instance_id = str(uuid.uuid4()) + self._try_lead = self._redis.register_script(self._LUA_TRY_LEAD) + self._stop_lead = self._redis.register_script(self._LUA_STOP_LEAD) + + def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]: + """ + Retrieve cached flag definitions from Redis. + + Returns: + Cached flag definitions if available, None otherwise. + """ + cached = self._redis.get(self._cache_key) + return json.loads(cached) if cached else None + + def should_fetch_flag_definitions(self) -> bool: + """ + Determines if this instance should fetch flag definitions from PostHog. + + Atomically either: + - Acquires the lock if no one holds it, OR + - Extends the lock TTL if we already hold it + + Returns: + True if this instance is the leader and should fetch, False otherwise. + """ + result = self._try_lead( + keys=[self._lock_key], + args=[self._instance_id, self.LOCK_TTL_MS], + ) + return result == 1 + + def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None: + """ + Store fetched flag definitions in Redis. + + Args: + data: The flag definitions to cache. + """ + self._redis.set(self._cache_key, json.dumps(data), ex=self.CACHE_TTL_SECONDS) + + def shutdown(self) -> None: + """ + Release leadership if we hold it. Safe to call even if not the leader. + """ + self._stop_lead(keys=[self._lock_key], args=[self._instance_id]) diff --git a/remote_config_example.py b/examples/remote_config.py similarity index 100% rename from remote_config_example.py rename to examples/remote_config.py diff --git a/mypy-baseline.txt b/mypy-baseline.txt index 0e32db71..232ce8be 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -26,14 +26,9 @@ posthog/client.py:0: error: Incompatible types in assignment (expression has typ posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Any, Any]", variable has type "None") [assignment] posthog/client.py:0: error: "None" has no attribute "__iter__" (not iterable) [attr-defined] posthog/client.py:0: error: Statement is unreachable [unreachable] -posthog/client.py:0: error: Incompatible types in assignment (expression has type "Any | dict[Any, Any]", variable has type "None") [assignment] -posthog/client.py:0: error: Incompatible types in assignment (expression has type "Any | dict[Any, Any]", variable has type "None") [assignment] -posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "None") [assignment] -posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "None") [assignment] posthog/client.py:0: error: Right operand of "and" is never evaluated [unreachable] posthog/client.py:0: error: Incompatible types in assignment (expression has type "Poller", variable has type "None") [assignment] posthog/client.py:0: error: "None" has no attribute "start" [attr-defined] -posthog/client.py:0: error: "None" has no attribute "get" [attr-defined] posthog/client.py:0: error: Statement is unreachable [unreachable] 
posthog/client.py:0: error: Statement is unreachable [unreachable] posthog/client.py:0: error: Name "urlparse" already defined (possibly by an import) [no-redef] diff --git a/posthog/__init__.py b/posthog/__init__.py index b513d8de..a78a1165 100644 --- a/posthog/__init__.py +++ b/posthog/__init__.py @@ -22,6 +22,10 @@ InconclusiveMatchError as InconclusiveMatchError, RequiresServerEvaluation as RequiresServerEvaluation, ) +from posthog.flag_definition_cache import ( + FlagDefinitionCacheData as FlagDefinitionCacheData, + FlagDefinitionCacheProvider as FlagDefinitionCacheProvider, +) from posthog.request import ( disable_connection_reuse as disable_connection_reuse, enable_keep_alive as enable_keep_alive, diff --git a/posthog/client.py b/posthog/client.py index 188cbf67..4ddea56f 100644 --- a/posthog/client.py +++ b/posthog/client.py @@ -28,6 +28,10 @@ RequiresServerEvaluation, match_feature_flag_properties, ) +from posthog.flag_definition_cache import ( + FlagDefinitionCacheData, + FlagDefinitionCacheProvider, +) from posthog.poller import Poller from posthog.request import ( DEFAULT_HOST, @@ -184,6 +188,7 @@ def __init__( before_send=None, flag_fallback_cache_url=None, enable_local_evaluation=True, + flag_definition_cache_provider: Optional[FlagDefinitionCacheProvider] = None, capture_exception_code_variables=False, code_variables_mask_patterns=None, code_variables_ignore_patterns=None, @@ -222,8 +227,8 @@ def __init__( self.timeout = timeout self._feature_flags = None # private variable to store flags self.feature_flags_by_key = None - self.group_type_mapping = None - self.cohorts = None + self.group_type_mapping: Optional[dict[str, str]] = None + self.cohorts: Optional[dict[str, Any]] = None self.poll_interval = poll_interval self.feature_flags_request_timeout_seconds = ( feature_flags_request_timeout_seconds @@ -233,6 +238,7 @@ def __init__( self.flag_cache = self._initialize_flag_cache(flag_fallback_cache_url) self.flag_definition_version = 0 self._flags_etag: Optional[str] = None + self._flag_definition_cache_provider = flag_definition_cache_provider self.disabled = disabled self.disable_geoip = disable_geoip self.historical_migration = historical_migration @@ -1169,17 +1175,25 @@ def join(self): posthog.join() ``` """ - for consumer in self.consumers: - consumer.pause() - try: - consumer.join() - except RuntimeError: - # consumer thread has not started - pass + if self.consumers: + for consumer in self.consumers: + consumer.pause() + try: + consumer.join() + except RuntimeError: + # consumer thread has not started + pass if self.poller: self.poller.stop() + # Shutdown the cache provider (release locks, cleanup) + if self._flag_definition_cache_provider: + try: + self._flag_definition_cache_provider.shutdown() + except Exception as e: + self.log.error(f"[FEATURE FLAGS] Cache provider shutdown error: {e}") + def shutdown(self): """ Flush all messages and cleanly shutdown the client. Call this before the process ends in serverless environments to avoid data loss. 
@@ -1195,7 +1209,71 @@ def shutdown(self): if self.exception_capture: self.exception_capture.close() + def _update_flag_state( + self, data: FlagDefinitionCacheData, old_flags_by_key: Optional[dict] = None + ) -> None: + """Update internal flag state from cache data and invalidate evaluation cache if changed.""" + self.feature_flags = data["flags"] + self.group_type_mapping = data["group_type_mapping"] + self.cohorts = data["cohorts"] + + # Invalidate evaluation cache if flag definitions changed + if ( + self.flag_cache + and old_flags_by_key is not None + and old_flags_by_key != (self.feature_flags_by_key or {}) + ): + old_version = self.flag_definition_version + self.flag_definition_version += 1 + self.flag_cache.invalidate_version(old_version) + def _load_feature_flags(self): + should_fetch = True + if self._flag_definition_cache_provider: + try: + should_fetch = ( + self._flag_definition_cache_provider.should_fetch_flag_definitions() + ) + except Exception as e: + self.log.error( + f"[FEATURE FLAGS] Cache provider should_fetch error: {e}" + ) + # Fail-safe: fetch from API if cache provider errors + should_fetch = True + + # If not fetching, try to get from cache + if not should_fetch and self._flag_definition_cache_provider: + try: + cached_data = ( + self._flag_definition_cache_provider.get_flag_definitions() + ) + if cached_data: + self.log.debug( + "[FEATURE FLAGS] Using cached flag definitions from external cache" + ) + self._update_flag_state( + cached_data, old_flags_by_key=self.feature_flags_by_key or {} + ) + self._last_feature_flag_poll = datetime.now(tz=tzutc()) + return + else: + # Emergency fallback: if cache is empty and we have no flags, fetch anyway. + # There's really no other way of recovering in this case. + if not self.feature_flags: + self.log.debug( + "[FEATURE FLAGS] Cache empty and no flags loaded, falling back to API fetch" + ) + should_fetch = True + except Exception as e: + self.log.error(f"[FEATURE FLAGS] Cache provider get error: {e}") + # Fail-safe: fetch from API if cache provider errors + should_fetch = True + + if should_fetch: + self._fetch_feature_flags_from_api() + + def _fetch_feature_flags_from_api(self): + """Fetch feature flags from the PostHog API.""" try: # Store old flags to detect changes old_flags_by_key: dict[str, dict] = self.feature_flags_by_key or {} @@ -1225,17 +1303,21 @@ def _load_feature_flags(self): ) return - self.feature_flags = response.data["flags"] or [] - self.group_type_mapping = response.data["group_type_mapping"] or {} - self.cohorts = response.data["cohorts"] or {} + self._update_flag_state(response.data, old_flags_by_key=old_flags_by_key) - # Check if flag definitions changed and update version - if self.flag_cache and old_flags_by_key != ( - self.feature_flags_by_key or {} - ): - old_version = self.flag_definition_version - self.flag_definition_version += 1 - self.flag_cache.invalidate_version(old_version) + # Store in external cache if provider is configured + if self._flag_definition_cache_provider: + try: + self._flag_definition_cache_provider.on_flag_definitions_received( + { + "flags": self.feature_flags or [], + "group_type_mapping": self.group_type_mapping or {}, + "cohorts": self.cohorts or {}, + } + ) + except Exception as e: + self.log.error(f"[FEATURE FLAGS] Cache provider store error: {e}") + # Flags are already in memory, so continue normally except APIError as e: if e.status == 401: @@ -1335,7 +1417,8 @@ def _compute_flag_locally( flag_filters = feature_flag.get("filters") or {} aggregation_group_type_index 
= flag_filters.get("aggregation_group_type_index") if aggregation_group_type_index is not None: - group_name = self.group_type_mapping.get(str(aggregation_group_type_index)) + group_type_mapping = self.group_type_mapping or {} + group_name = group_type_mapping.get(str(aggregation_group_type_index)) if not group_name: self.log.warning( diff --git a/posthog/flag_definition_cache.py b/posthog/flag_definition_cache.py new file mode 100644 index 00000000..330bbd45 --- /dev/null +++ b/posthog/flag_definition_cache.py @@ -0,0 +1,127 @@ +""" +Flag Definition Cache Provider interface for multi-worker environments. + +EXPERIMENTAL: This API may change in future minor version bumps. + +This module provides an interface for external caching of feature flag definitions, +enabling multi-worker environments (Kubernetes, load-balanced servers, serverless +functions) to share flag definitions and reduce API calls. + +Usage: + + from posthog import Posthog + from posthog.flag_definition_cache import FlagDefinitionCacheProvider + + cache = RedisFlagDefinitionCache(redis_client, "my-team") + posthog = Posthog( + "", + personal_api_key="", + flag_definition_cache_provider=cache, + ) +""" + +from typing import Any, Dict, List, Optional, Protocol, runtime_checkable + +from typing_extensions import Required, TypedDict + + +class FlagDefinitionCacheData(TypedDict): + """ + Data structure for cached flag definitions. + + Attributes: + flags: List of feature flag definition dictionaries from the API. + group_type_mapping: Mapping of group type indices to group names. + cohorts: Dictionary of cohort definitions for local evaluation. + """ + + flags: Required[List[Dict[str, Any]]] + group_type_mapping: Required[Dict[str, str]] + cohorts: Required[Dict[str, Any]] + + +@runtime_checkable +class FlagDefinitionCacheProvider(Protocol): + """ + Interface for external caching of feature flag definitions. + + Enables multi-worker environments to share flag definitions, reducing API + calls while ensuring all workers have consistent data. + + EXPERIMENTAL: This API may change in future minor version bumps. + + The four methods handle the complete lifecycle of flag definition caching: + + 1. `should_fetch_flag_definitions()` - Called before each poll to determine + if this worker should fetch new definitions. Use for distributed lock + coordination to ensure only one worker fetches at a time. + + 2. `get_flag_definitions()` - Called when `should_fetch_flag_definitions()` + returns False. Returns cached definitions if available. + + 3. `on_flag_definitions_received()` - Called after successfully fetching + new definitions from the API. Store the data in your external cache + and release any locks. + + 4. `shutdown()` - Called when the PostHog client shuts down. Release any + distributed locks and clean up resources. + + Error Handling: + All methods are wrapped in try/except. Errors will be logged but will + never break flag evaluation. On error: + - `should_fetch_flag_definitions()` errors default to fetching (fail-safe) + - `get_flag_definitions()` errors fall back to API fetch + - `on_flag_definitions_received()` errors are logged but flags remain in memory + - `shutdown()` errors are logged but shutdown continues + """ + + def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]: + """ + Retrieve cached flag definitions. + + Returns: + Cached flag definitions if available and valid, None otherwise. + Returning None will trigger a fetch from the API if this worker + has no flags loaded yet. + """ + ... 
+ + def should_fetch_flag_definitions(self) -> bool: + """ + Determine whether this instance should fetch new flag definitions. + + Use this for distributed lock coordination. Only one worker should + return True to avoid thundering herd problems. A typical implementation + uses a distributed lock (e.g., Redis SETNX) that expires after the + poll interval. + + Returns: + True if this instance should fetch from the API, False otherwise. + When False, the client will call `get_flag_definitions()` to + retrieve cached data instead. + """ + ... + + def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None: + """ + Called after successfully receiving new flag definitions from PostHog. + + Use this to store the data in your external cache and release any + distributed locks acquired in `should_fetch_flag_definitions()`. + + Args: + data: The flag definitions to cache, containing flags, + group_type_mapping, and cohorts. + """ + ... + + def shutdown(self) -> None: + """ + Called when the PostHog client shuts down. + + Use this to release any distributed locks and clean up resources. + This method is called even if `should_fetch_flag_definitions()` + returned False, so implementations should handle the case where + no lock was acquired. + """ + ... diff --git a/posthog/test/test_flag_definition_cache.py b/posthog/test/test_flag_definition_cache.py new file mode 100644 index 00000000..ee17eb9a --- /dev/null +++ b/posthog/test/test_flag_definition_cache.py @@ -0,0 +1,612 @@ +""" +Tests for FlagDefinitionCacheProvider functionality. + +These tests follow the patterns from the TypeScript implementation in posthog-js/packages/node. +""" + +import threading +import unittest +from typing import Optional +from unittest import mock + +from posthog.client import Client +from posthog.flag_definition_cache import ( + FlagDefinitionCacheData, + FlagDefinitionCacheProvider, +) +from posthog.request import GetResponse +from posthog.test.test_utils import FAKE_TEST_API_KEY + + +class MockCacheProvider: + """A mock implementation of FlagDefinitionCacheProvider for testing.""" + + def __init__(self): + self.stored_data: Optional[FlagDefinitionCacheData] = None + self.should_fetch_return_value = True + self.get_call_count = 0 + self.should_fetch_call_count = 0 + self.on_received_call_count = 0 + self.shutdown_call_count = 0 + self.should_fetch_error: Optional[Exception] = None + self.get_error: Optional[Exception] = None + self.on_received_error: Optional[Exception] = None + self.shutdown_error: Optional[Exception] = None + + def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]: + self.get_call_count += 1 + if self.get_error: + raise self.get_error + return self.stored_data + + def should_fetch_flag_definitions(self) -> bool: + self.should_fetch_call_count += 1 + if self.should_fetch_error: + raise self.should_fetch_error + return self.should_fetch_return_value + + def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None: + self.on_received_call_count += 1 + if self.on_received_error: + raise self.on_received_error + self.stored_data = data + + def shutdown(self) -> None: + self.shutdown_call_count += 1 + if self.shutdown_error: + raise self.shutdown_error + + +class TestFlagDefinitionCacheProvider(unittest.TestCase): + """Tests for the FlagDefinitionCacheProvider protocol.""" + + @classmethod + def setUpClass(cls): + # Prevent real HTTP requests + cls.client_post_patcher = mock.patch("posthog.client.batch_post") + cls.consumer_post_patcher = 
mock.patch("posthog.consumer.batch_post") + cls.client_post_patcher.start() + cls.consumer_post_patcher.start() + + @classmethod + def tearDownClass(cls): + cls.client_post_patcher.stop() + cls.consumer_post_patcher.stop() + + def setUp(self): + self.cache_provider = MockCacheProvider() + self.sample_flags_data: FlagDefinitionCacheData = { + "flags": [ + {"key": "test-flag", "active": True, "filters": {}}, + {"key": "another-flag", "active": False, "filters": {}}, + ], + "group_type_mapping": {"0": "company", "1": "project"}, + "cohorts": {"1": {"properties": []}}, + } + + def tearDown(self): + # Ensure client cleanup + pass + + def _create_client_with_cache(self) -> Client: + """Create a client with the mock cache provider.""" + return Client( + FAKE_TEST_API_KEY, + personal_api_key="test-personal-key", + flag_definition_cache_provider=self.cache_provider, + sync_mode=True, + enable_local_evaluation=False, # Disable poller for tests + ) + + +class TestCacheInitialization(TestFlagDefinitionCacheProvider): + """Tests for cache initialization behavior.""" + + @mock.patch("posthog.client.get") + def test_uses_cached_data_when_should_fetch_returns_false(self, mock_get): + """When should_fetch returns False and cache has data, use cached data.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = self.sample_flags_data + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Should not call API + mock_get.assert_not_called() + + # Should have called cache methods + self.assertEqual(self.cache_provider.should_fetch_call_count, 1) + self.assertEqual(self.cache_provider.get_call_count, 1) + + # Flags should be loaded from cache + self.assertEqual(len(client.feature_flags), 2) + self.assertEqual(client.feature_flags[0]["key"], "test-flag") + + client.join() + + @mock.patch("posthog.client.get") + def test_fetches_from_api_when_should_fetch_returns_true(self, mock_get): + """When should_fetch returns True, fetch from API.""" + self.cache_provider.should_fetch_return_value = True + + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Should call API + mock_get.assert_called_once() + + # Should have called should_fetch but not get + self.assertEqual(self.cache_provider.should_fetch_call_count, 1) + self.assertEqual(self.cache_provider.get_call_count, 0) + + # Should have called on_received to store in cache + self.assertEqual(self.cache_provider.on_received_call_count, 1) + + client.join() + + @mock.patch("posthog.client.get") + def test_emergency_fallback_when_cache_empty_and_no_flags(self, mock_get): + """When should_fetch=False but cache is empty and no flags loaded, fetch anyway.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = None # Empty cache + + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Should call API due to emergency fallback + mock_get.assert_called_once() + + # Should have called on_received + self.assertEqual(self.cache_provider.on_received_call_count, 1) + + client.join() + + @mock.patch("posthog.client.get") + def test_preserves_existing_flags_when_cache_returns_none(self, mock_get): + """When cache returns None but client has flags, preserve existing flags.""" + 
self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = None # Empty cache + + client = self._create_client_with_cache() + + # Pre-load flags (simulating a previous successful fetch) + client.feature_flags = self.sample_flags_data["flags"] + client.group_type_mapping = self.sample_flags_data["group_type_mapping"] + client.cohorts = self.sample_flags_data["cohorts"] + + client._load_feature_flags() + + # Should NOT call API since we already have flags + mock_get.assert_not_called() + + # Existing flags should be preserved + self.assertEqual(len(client.feature_flags), 2) + self.assertEqual(client.feature_flags[0]["key"], "test-flag") + + client.join() + + +class TestFetchCoordination(TestFlagDefinitionCacheProvider): + """Tests for fetch coordination between workers.""" + + @mock.patch("posthog.client.get") + def test_calls_should_fetch_before_each_poll(self, mock_get): + """should_fetch_flag_definitions is called before each poll cycle.""" + self.cache_provider.should_fetch_return_value = True + + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + + # First poll + client._load_feature_flags() + self.assertEqual(self.cache_provider.should_fetch_call_count, 1) + + # Second poll + client._load_feature_flags() + self.assertEqual(self.cache_provider.should_fetch_call_count, 2) + + client.join() + + @mock.patch("posthog.client.get") + def test_does_not_call_on_received_when_fetch_skipped(self, mock_get): + """on_flag_definitions_received is NOT called when fetch is skipped.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = self.sample_flags_data + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Should not call on_received since we didn't fetch + self.assertEqual(self.cache_provider.on_received_call_count, 0) + + client.join() + + @mock.patch("posthog.client.get") + def test_stores_data_in_cache_after_api_fetch(self, mock_get): + """on_flag_definitions_received receives the fetched data.""" + self.cache_provider.should_fetch_return_value = True + + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Should have stored data in cache + self.assertEqual(self.cache_provider.on_received_call_count, 1) + self.assertIsNotNone(self.cache_provider.stored_data) + self.assertEqual(len(self.cache_provider.stored_data["flags"]), 2) + + client.join() + + @mock.patch("posthog.client.get") + def test_304_not_modified_does_not_update_cache(self, mock_get): + """When API returns 304 Not Modified, cache should not be updated.""" + self.cache_provider.should_fetch_return_value = True + + # First fetch to populate flags and ETag + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Verify initial fetch worked + self.assertEqual(self.cache_provider.on_received_call_count, 1) + self.assertEqual(len(client.feature_flags), 2) + + # Second fetch returns 304 Not Modified + mock_get.return_value = GetResponse( + data=None, etag="test-etag", not_modified=True + ) + + client._load_feature_flags() + + # API was called twice + self.assertEqual(mock_get.call_count, 2) + + # should_fetch was called twice + 
self.assertEqual(self.cache_provider.should_fetch_call_count, 2) + + # on_received should NOT be called again (304 = no new data) + self.assertEqual(self.cache_provider.on_received_call_count, 1) + + # Flags should still be present + self.assertEqual(len(client.feature_flags), 2) + + client.join() + + +class TestErrorHandling(TestFlagDefinitionCacheProvider): + """Tests for error handling in cache provider operations.""" + + @mock.patch("posthog.client.get") + def test_should_fetch_error_defaults_to_fetching(self, mock_get): + """When should_fetch throws an error, default to fetching from API.""" + self.cache_provider.should_fetch_error = Exception("Lock acquisition failed") + + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Should still fetch from API + mock_get.assert_called_once() + + # Flags should be loaded + self.assertEqual(len(client.feature_flags), 2) + + client.join() + + @mock.patch("posthog.client.get") + def test_get_error_falls_back_to_api_fetch(self, mock_get): + """When get_flag_definitions throws an error, fetch from API.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.get_error = Exception("Cache read failed") + + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Should fall back to API + mock_get.assert_called_once() + + client.join() + + @mock.patch("posthog.client.get") + def test_on_received_error_keeps_flags_in_memory(self, mock_get): + """When on_flag_definitions_received throws, flags are still in memory.""" + self.cache_provider.should_fetch_return_value = True + self.cache_provider.on_received_error = Exception("Cache write failed") + + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Flags should still be loaded in memory despite cache error + self.assertEqual(len(client.feature_flags), 2) + self.assertEqual(client.feature_flags[0]["key"], "test-flag") + + client.join() + + @mock.patch("posthog.client.get") + def test_shutdown_error_is_logged_but_continues(self, mock_get): + """When shutdown throws an error, it's logged but shutdown continues.""" + self.cache_provider.shutdown_error = Exception("Lock release failed") + + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Should not raise when joining + client.join() + + # Shutdown was called + self.assertEqual(self.cache_provider.shutdown_call_count, 1) + + +class TestShutdownLifecycle(TestFlagDefinitionCacheProvider): + """Tests for shutdown lifecycle.""" + + @mock.patch("posthog.client.get") + def test_shutdown_calls_cache_provider_shutdown(self, mock_get): + """Client shutdown calls cache provider shutdown.""" + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Shutdown + client.join() + + self.assertEqual(self.cache_provider.shutdown_call_count, 1) + + @mock.patch("posthog.client.get") + def test_shutdown_called_even_without_fetching(self, mock_get): + """Shutdown is called 
even when cache was used instead of fetching.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = self.sample_flags_data + + client = self._create_client_with_cache() + client._load_feature_flags() + client.join() + + # Shutdown should still be called + self.assertEqual(self.cache_provider.shutdown_call_count, 1) + + @mock.patch("posthog.client.get") + def test_multiple_join_calls_only_shutdown_once(self, mock_get): + """Calling join() multiple times should only call cache provider shutdown once.""" + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Call join multiple times + client.join() + client.join() + client.join() + + # Shutdown should be called each time (current behavior - no guard) + # This test documents the current behavior + self.assertGreaterEqual(self.cache_provider.shutdown_call_count, 1) + + +class TestBackwardCompatibility(TestFlagDefinitionCacheProvider): + """Tests for backward compatibility without cache provider.""" + + @mock.patch("posthog.client.get") + def test_works_without_cache_provider(self, mock_get): + """Client works normally without a cache provider configured.""" + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + # Create client without cache provider + client = Client( + FAKE_TEST_API_KEY, + personal_api_key="test-personal-key", + sync_mode=True, + enable_local_evaluation=False, + ) + client._load_feature_flags() + + # Should fetch from API + mock_get.assert_called_once() + + # Flags should be loaded + self.assertEqual(len(client.feature_flags), 2) + + client.join() + + +class TestDataIntegrity(TestFlagDefinitionCacheProvider): + """Tests for data integrity between cache and client state.""" + + @mock.patch("posthog.client.get") + def test_cached_flags_available_for_evaluation(self, mock_get): + """Flags loaded from cache are available for local evaluation.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = { + "flags": [ + { + "key": "test-flag", + "active": True, + "filters": { + "groups": [ + { + "properties": [], + "rollout_percentage": 100, + } + ] + }, + } + ], + "group_type_mapping": {}, + "cohorts": {}, + } + + client = self._create_client_with_cache() + client._load_feature_flags() + + # Flag should be accessible + self.assertEqual(len(client.feature_flags), 1) + self.assertEqual(client.feature_flags_by_key["test-flag"]["key"], "test-flag") + + client.join() + + @mock.patch("posthog.client.get") + def test_group_type_mapping_loaded_from_cache(self, mock_get): + """Group type mapping is correctly loaded from cache.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = self.sample_flags_data + + client = self._create_client_with_cache() + client._load_feature_flags() + + self.assertEqual(client.group_type_mapping["0"], "company") + self.assertEqual(client.group_type_mapping["1"], "project") + + client.join() + + @mock.patch("posthog.client.get") + def test_cohorts_loaded_from_cache(self, mock_get): + """Cohorts are correctly loaded from cache.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = self.sample_flags_data + + client = self._create_client_with_cache() + client._load_feature_flags() + + self.assertIn("1", client.cohorts) + + client.join() + + 
@mock.patch("posthog.client.get") + def test_cache_updated_when_api_returns_new_data(self, mock_get): + """State transition: cache has old data -> API returns new -> cache updated.""" + # Start with old cached data + old_flags_data: FlagDefinitionCacheData = { + "flags": [{"key": "old-flag", "active": True, "filters": {}}], + "group_type_mapping": {}, + "cohorts": {}, + } + self.cache_provider.stored_data = old_flags_data + self.cache_provider.should_fetch_return_value = False + + client = self._create_client_with_cache() + + # First load from cache + client._load_feature_flags() + self.assertEqual(client.feature_flags[0]["key"], "old-flag") + self.assertEqual(self.cache_provider.on_received_call_count, 0) + + # Now trigger API fetch with new data + self.cache_provider.should_fetch_return_value = True + new_flags_data: FlagDefinitionCacheData = { + "flags": [{"key": "new-flag", "active": True, "filters": {}}], + "group_type_mapping": {"0": "company"}, + "cohorts": {"1": {"properties": []}}, + } + mock_get.return_value = GetResponse( + data=new_flags_data, etag="new-etag", not_modified=False + ) + + client._load_feature_flags() + + # Verify new flags loaded + self.assertEqual(client.feature_flags[0]["key"], "new-flag") + self.assertEqual(client.group_type_mapping["0"], "company") + + # Verify cache was updated + self.assertEqual(self.cache_provider.on_received_call_count, 1) + self.assertEqual(self.cache_provider.stored_data["flags"][0]["key"], "new-flag") + + client.join() + + +class TestConcurrency(TestFlagDefinitionCacheProvider): + """Tests for thread safety and concurrent access.""" + + @mock.patch("posthog.client.get") + def test_concurrent_load_feature_flags_is_thread_safe(self, mock_get): + """Multiple threads calling _load_feature_flags should not cause errors.""" + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + errors = [] + + def load_flags(): + try: + client._load_feature_flags() + except Exception as e: + errors.append(e) + + # Launch 5 threads concurrently + threads = [threading.Thread(target=load_flags) for _ in range(5)] + for t in threads: + t.start() + for t in threads: + t.join() + + # Should complete without errors + self.assertEqual(len(errors), 0, f"Unexpected errors: {errors}") + + # Flags should be loaded + self.assertIsNotNone(client.feature_flags) + self.assertEqual(len(client.feature_flags), 2) + + client.join() + + +class TestProtocolCompliance(unittest.TestCase): + """Tests for Protocol compliance.""" + + def test_mock_provider_is_protocol_instance(self): + """MockCacheProvider satisfies FlagDefinitionCacheProvider protocol.""" + provider = MockCacheProvider() + self.assertIsInstance(provider, FlagDefinitionCacheProvider) + + def test_incomplete_provider_is_not_protocol_instance(self): + """Class missing methods is not a FlagDefinitionCacheProvider.""" + + class IncompleteProvider: + def get_flag_definitions(self): + return None + + provider = IncompleteProvider() + self.assertNotIsInstance(provider, FlagDefinitionCacheProvider) + + +if __name__ == "__main__": + unittest.main()
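Note (illustrative, not part of this diff): a minimal smoke test of the leader-election behaviour implemented by the RedisFlagCache example above. It assumes a Redis server at localhost:6379 and that examples/redis_flag_cache.py is importable as shown; adjust the import to your project layout.

import redis

from examples.redis_flag_cache import RedisFlagCache

r = redis.Redis(host="localhost", port=6379, decode_responses=True)
leader = RedisFlagCache(r, service_key="demo")
follower = RedisFlagCache(r, service_key="demo")

# Only the first instance to call should_fetch_flag_definitions() wins the lock.
assert leader.should_fetch_flag_definitions() is True
assert follower.should_fetch_flag_definitions() is False

# The leader publishes fetched definitions (simulated here) to the shared cache;
# the follower reads the same data without hitting the PostHog API.
definitions = {"flags": [], "group_type_mapping": {}, "cohorts": {}}
leader.on_flag_definitions_received(definitions)
assert follower.get_flag_definitions() == definitions

# Shutting the leader down releases the lock, so another instance can take over.
leader.shutdown()
assert follower.should_fetch_flag_definitions() is True
follower.shutdown()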
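Note (illustrative, not part of this diff): the FlagDefinitionCacheProvider protocol is backend-agnostic, so the same four-method lifecycle can be backed by something other than Redis. The sketch below coordinates multiple worker processes on a single host through the filesystem; the class name, paths, and TTL are assumptions, the lock handling is best-effort rather than fully race-free, and the Redis example above remains the better fit for real multi-instance deployments.

import json
import os
import time
from typing import Optional

from posthog import FlagDefinitionCacheData, FlagDefinitionCacheProvider


class FileFlagCache(FlagDefinitionCacheProvider):
    """Single-host leader election via an exclusive lock file; definitions stored as JSON."""

    LOCK_TTL_SECONDS = 60  # should exceed the client's poll_interval

    def __init__(self, directory: str = "/tmp/posthog-flags") -> None:
        os.makedirs(directory, exist_ok=True)
        self._data_path = os.path.join(directory, "definitions.json")
        self._lock_path = os.path.join(directory, "leader.lock")
        self._holds_lock = False

    def should_fetch_flag_definitions(self) -> bool:
        # Reclaim a stale lock left behind by a crashed leader (best effort, not atomic).
        try:
            if time.time() - os.path.getmtime(self._lock_path) > self.LOCK_TTL_SECONDS:
                os.remove(self._lock_path)
        except OSError:
            pass  # no lock file yet, or it was removed concurrently
        if self._holds_lock:
            try:
                os.utime(self._lock_path)  # extend our lease
                return True
            except OSError:
                self._holds_lock = False  # our lock was reclaimed; try to re-acquire below
        try:
            fd = os.open(self._lock_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
            os.close(fd)
            self._holds_lock = True
            return True
        except FileExistsError:
            return False  # another worker is the leader

    def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]:
        try:
            with open(self._data_path) as f:
                return json.load(f)
        except (OSError, ValueError):
            return None

    def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None:
        tmp_path = self._data_path + ".tmp"
        with open(tmp_path, "w") as f:
            json.dump(data, f)
        os.replace(tmp_path, self._data_path)  # atomic swap so readers never see partial JSON

    def shutdown(self) -> None:
        if self._holds_lock:
            try:
                os.remove(self._lock_path)
            except OSError:
                pass
            self._holds_lock = False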