Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions examples/redis_flag_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
Redis-based distributed cache for PostHog feature flag definitions.

This example demonstrates how to implement a FlagDefinitionCacheProvider
using Redis for multi-instance deployments (leader election pattern).

Usage:
import redis
from posthog import Posthog

redis_client = redis.Redis(host='localhost', port=6379, decode_responses=True)
cache = RedisFlagCache(redis_client, service_key="my-service")

posthog = Posthog(
"<project_api_key>",
personal_api_key="<personal_api_key>",
flag_definition_cache_provider=cache,
)

Requirements:
pip install redis
"""

import json
import uuid

from posthog import FlagDefinitionCacheData, FlagDefinitionCacheProvider
from redis import Redis
from typing import Optional


class RedisFlagCache(FlagDefinitionCacheProvider):
    """
    A distributed cache for PostHog feature flag definitions using Redis.

    In a multi-instance deployment (e.g., multiple serverless functions or containers),
    we want only ONE instance to poll PostHog for flag updates, while all instances
    share the cached results. This prevents N instances from making N redundant API calls.

    The implementation uses leader election:
    - One instance "wins" and becomes responsible for fetching
    - Other instances read from the shared cache
    - If the leader dies, the lock expires (TTL) and another instance takes over

    Uses Lua scripts for atomic operations, following Redis distributed lock best practices:
    https://redis.io/docs/latest/develop/clients/patterns/distributed-locks/
    """

    # Lifetime of the leader lock. It is refreshed on every
    # should_fetch_flag_definitions() call made by the current leader.
    LOCK_TTL_MS = 60 * 1000  # 60 seconds, should be longer than the flags poll interval
    # Lifetime of the cached flag definitions written by the leader.
    CACHE_TTL_SECONDS = 60 * 60 * 24  # 24 hours

    # Lua script: acquire lock if free, or extend if we own it.
    # KEYS[1] = lock key, ARGV[1] = instance id, ARGV[2] = TTL in milliseconds.
    # Returns 1 when the caller holds the lock afterwards, 0 otherwise.
    _LUA_TRY_LEAD = """
        local current = redis.call('GET', KEYS[1])
        if current == false then
            redis.call('SET', KEYS[1], ARGV[1], 'PX', ARGV[2])
            return 1
        elseif current == ARGV[1] then
            redis.call('PEXPIRE', KEYS[1], ARGV[2])
            return 1
        end
        return 0
    """

    # Lua script: release lock only if we own it.
    # KEYS[1] = lock key, ARGV[1] = instance id.
    _LUA_STOP_LEAD = """
        if redis.call('GET', KEYS[1]) == ARGV[1] then
            return redis.call('DEL', KEYS[1])
        end
        return 0
    """

    # NOTE: the `redis` annotation below is quoted on purpose. `Redis[str]` is
    # what type checkers expect, but an unquoted annotation is evaluated when
    # the class body runs, and subscripting the runtime `Redis` class can raise
    # TypeError when that class is not generic.
    # NOTE(review): confirm against the redis-py version this example targets.
    def __init__(self, redis: "Redis[str]", service_key: str):
        """
        Initialize the Redis flag cache.

        Args:
            redis: A redis-py client instance. Must be configured with
                decode_responses=True for correct string handling.
            service_key: A unique identifier for this service/environment.
                Used to scope Redis keys, allowing multiple services
                or environments to share the same Redis instance.
                Examples: "my-api-prod", "checkout-service", "staging".

        Redis Keys Created:
            - posthog:flags:{service_key} - Cached flag definitions (JSON)
            - posthog:flags:{service_key}:lock - Leader election lock

        Example:
            redis_client = redis.Redis(
                host='localhost',
                port=6379,
                decode_responses=True
            )
            cache = RedisFlagCache(redis_client, service_key="my-api-prod")
        """
        self._redis = redis
        self._cache_key = f"posthog:flags:{service_key}"
        self._lock_key = f"posthog:flags:{service_key}:lock"
        # Random per-instance token stored as the lock value, so the Lua
        # scripts can tell "our" lock apart from another instance's.
        self._instance_id = str(uuid.uuid4())
        self._try_lead = self._redis.register_script(self._LUA_TRY_LEAD)
        self._stop_lead = self._redis.register_script(self._LUA_STOP_LEAD)

    def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]:
        """
        Retrieve cached flag definitions from Redis.

        Returns:
            The JSON-decoded cache entry if one exists, None when the key is
            missing or holds an empty value.
        """
        cached = self._redis.get(self._cache_key)
        return json.loads(cached) if cached else None

    def should_fetch_flag_definitions(self) -> bool:
        """
        Determines if this instance should fetch flag definitions from PostHog.

        Atomically either:
        - Acquires the lock if no one holds it, OR
        - Extends the lock TTL if we already hold it

        Returns:
            True if this instance is the leader and should fetch, False otherwise.
        """
        result = self._try_lead(
            keys=[self._lock_key],
            args=[self._instance_id, self.LOCK_TTL_MS],
        )
        return result == 1

    def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None:
        """
        Store fetched flag definitions in Redis.

        The entry is written as JSON and expires after CACHE_TTL_SECONDS.

        Args:
            data: The flag definitions to cache.
        """
        self._redis.set(self._cache_key, json.dumps(data), ex=self.CACHE_TTL_SECONDS)

    def shutdown(self) -> None:
        """
        Release leadership if we hold it. Safe to call even if not the leader.

        The release script deletes the lock key only when its value matches
        this instance's id, so another instance's lock is never removed.
        """
        self._stop_lead(keys=[self._lock_key], args=[self._instance_id])
File renamed without changes.
5 changes: 0 additions & 5 deletions mypy-baseline.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,9 @@ posthog/client.py:0: error: Incompatible types in assignment (expression has typ
posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Any, Any]", variable has type "None") [assignment]
posthog/client.py:0: error: "None" has no attribute "__iter__" (not iterable) [attr-defined]
posthog/client.py:0: error: Statement is unreachable [unreachable]
posthog/client.py:0: error: Incompatible types in assignment (expression has type "Any | dict[Any, Any]", variable has type "None") [assignment]
posthog/client.py:0: error: Incompatible types in assignment (expression has type "Any | dict[Any, Any]", variable has type "None") [assignment]
posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "None") [assignment]
posthog/client.py:0: error: Incompatible types in assignment (expression has type "dict[Never, Never]", variable has type "None") [assignment]
posthog/client.py:0: error: Right operand of "and" is never evaluated [unreachable]
posthog/client.py:0: error: Incompatible types in assignment (expression has type "Poller", variable has type "None") [assignment]
posthog/client.py:0: error: "None" has no attribute "start" [attr-defined]
posthog/client.py:0: error: "None" has no attribute "get" [attr-defined]
posthog/client.py:0: error: Statement is unreachable [unreachable]
posthog/client.py:0: error: Statement is unreachable [unreachable]
posthog/client.py:0: error: Name "urlparse" already defined (possibly by an import) [no-redef]
Expand Down
4 changes: 4 additions & 0 deletions posthog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@
InconclusiveMatchError as InconclusiveMatchError,
RequiresServerEvaluation as RequiresServerEvaluation,
)
from posthog.flag_definition_cache import (
FlagDefinitionCacheData as FlagDefinitionCacheData,
FlagDefinitionCacheProvider as FlagDefinitionCacheProvider,
)
from posthog.request import (
disable_connection_reuse as disable_connection_reuse,
enable_keep_alive as enable_keep_alive,
Expand Down
123 changes: 103 additions & 20 deletions posthog/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
RequiresServerEvaluation,
match_feature_flag_properties,
)
from posthog.flag_definition_cache import (
FlagDefinitionCacheData,
FlagDefinitionCacheProvider,
)
from posthog.poller import Poller
from posthog.request import (
DEFAULT_HOST,
Expand Down Expand Up @@ -184,6 +188,7 @@ def __init__(
before_send=None,
flag_fallback_cache_url=None,
enable_local_evaluation=True,
flag_definition_cache_provider: Optional[FlagDefinitionCacheProvider] = None,
capture_exception_code_variables=False,
code_variables_mask_patterns=None,
code_variables_ignore_patterns=None,
Expand Down Expand Up @@ -222,8 +227,8 @@ def __init__(
self.timeout = timeout
self._feature_flags = None # private variable to store flags
self.feature_flags_by_key = None
self.group_type_mapping = None
self.cohorts = None
self.group_type_mapping: Optional[dict[str, str]] = None
self.cohorts: Optional[dict[str, Any]] = None
self.poll_interval = poll_interval
self.feature_flags_request_timeout_seconds = (
feature_flags_request_timeout_seconds
Expand All @@ -233,6 +238,7 @@ def __init__(
self.flag_cache = self._initialize_flag_cache(flag_fallback_cache_url)
self.flag_definition_version = 0
self._flags_etag: Optional[str] = None
self._flag_definition_cache_provider = flag_definition_cache_provider
self.disabled = disabled
self.disable_geoip = disable_geoip
self.historical_migration = historical_migration
Expand Down Expand Up @@ -1169,17 +1175,25 @@ def join(self):
posthog.join()
```
"""
for consumer in self.consumers:
consumer.pause()
try:
consumer.join()
except RuntimeError:
# consumer thread has not started
pass
if self.consumers:
for consumer in self.consumers:
consumer.pause()
try:
consumer.join()
except RuntimeError:
# consumer thread has not started
pass

if self.poller:
self.poller.stop()

# Shutdown the cache provider (release locks, cleanup)
if self._flag_definition_cache_provider:
try:
self._flag_definition_cache_provider.shutdown()
except Exception as e:
self.log.error(f"[FEATURE FLAGS] Cache provider shutdown error: {e}")

def shutdown(self):
"""
Flush all messages and cleanly shutdown the client. Call this before the process ends in serverless environments to avoid data loss.
Expand All @@ -1195,7 +1209,71 @@ def shutdown(self):
if self.exception_capture:
self.exception_capture.close()

def _update_flag_state(
self, data: FlagDefinitionCacheData, old_flags_by_key: Optional[dict] = None
) -> None:
"""Update internal flag state from cache data and invalidate evaluation cache if changed."""
self.feature_flags = data["flags"]
self.group_type_mapping = data["group_type_mapping"]
self.cohorts = data["cohorts"]

# Invalidate evaluation cache if flag definitions changed
if (
self.flag_cache
and old_flags_by_key is not None
and old_flags_by_key != (self.feature_flags_by_key or {})
):
old_version = self.flag_definition_version
self.flag_definition_version += 1
self.flag_cache.invalidate_version(old_version)

def _load_feature_flags(self):
should_fetch = True
if self._flag_definition_cache_provider:
try:
should_fetch = (
self._flag_definition_cache_provider.should_fetch_flag_definitions()
)
except Exception as e:
self.log.error(
f"[FEATURE FLAGS] Cache provider should_fetch error: {e}"
)
# Fail-safe: fetch from API if cache provider errors
should_fetch = True

# If not fetching, try to get from cache
if not should_fetch and self._flag_definition_cache_provider:
try:
cached_data = (
self._flag_definition_cache_provider.get_flag_definitions()
)
if cached_data:
self.log.debug(
"[FEATURE FLAGS] Using cached flag definitions from external cache"
)
self._update_flag_state(
cached_data, old_flags_by_key=self.feature_flags_by_key or {}
)
self._last_feature_flag_poll = datetime.now(tz=tzutc())
return
else:
# Emergency fallback: if cache is empty and we have no flags, fetch anyway.
# There's really no other way of recovering in this case.
if not self.feature_flags:
self.log.debug(
"[FEATURE FLAGS] Cache empty and no flags loaded, falling back to API fetch"
)
should_fetch = True
except Exception as e:
self.log.error(f"[FEATURE FLAGS] Cache provider get error: {e}")
# Fail-safe: fetch from API if cache provider errors
should_fetch = True

if should_fetch:
self._fetch_feature_flags_from_api()

def _fetch_feature_flags_from_api(self):
"""Fetch feature flags from the PostHog API."""
try:
# Store old flags to detect changes
old_flags_by_key: dict[str, dict] = self.feature_flags_by_key or {}
Expand Down Expand Up @@ -1225,17 +1303,21 @@ def _load_feature_flags(self):
)
return

self.feature_flags = response.data["flags"] or []
self.group_type_mapping = response.data["group_type_mapping"] or {}
self.cohorts = response.data["cohorts"] or {}
self._update_flag_state(response.data, old_flags_by_key=old_flags_by_key)

# Check if flag definitions changed and update version
if self.flag_cache and old_flags_by_key != (
self.feature_flags_by_key or {}
):
old_version = self.flag_definition_version
self.flag_definition_version += 1
self.flag_cache.invalidate_version(old_version)
# Store in external cache if provider is configured
if self._flag_definition_cache_provider:
try:
self._flag_definition_cache_provider.on_flag_definitions_received(
{
"flags": self.feature_flags or [],
"group_type_mapping": self.group_type_mapping or {},
"cohorts": self.cohorts or {},
}
)
except Exception as e:
self.log.error(f"[FEATURE FLAGS] Cache provider store error: {e}")
# Flags are already in memory, so continue normally

except APIError as e:
if e.status == 401:
Expand Down Expand Up @@ -1335,7 +1417,8 @@ def _compute_flag_locally(
flag_filters = feature_flag.get("filters") or {}
aggregation_group_type_index = flag_filters.get("aggregation_group_type_index")
if aggregation_group_type_index is not None:
group_name = self.group_type_mapping.get(str(aggregation_group_type_index))
group_type_mapping = self.group_type_mapping or {}
group_name = group_type_mapping.get(str(aggregation_group_type_index))

if not group_name:
self.log.warning(
Expand Down
Loading
Loading