Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions src/sentry/utils/session_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any
from uuid import uuid4

import sentry_sdk
Expand All @@ -7,6 +8,8 @@
from sentry.utils.json import dumps, loads

EXPIRATION_TTL = 10 * 60
# Allow a few retries to account for replica lag immediately after writing the session.
_PERSISTENCE_RECHECK_ATTEMPTS = 3


class RedisSessionStore:
Expand Down Expand Up @@ -55,6 +58,8 @@ def __init__(self, request, prefix, ttl=EXPIRATION_TTL):
self.request = request
self.prefix = prefix
self.ttl = ttl
self._state_cache: dict[str, Any] | None = None
self._pending_writes = 0

@property
def _client(self):
Expand Down Expand Up @@ -83,6 +88,8 @@ def regenerate(self, initial_state=None):

value = dumps(initial_state)
self._client.setex(redis_key, self.ttl, value)
self._state_cache = loads(value)
self._pending_writes = _PERSISTENCE_RECHECK_ATTEMPTS

def clear(self):
if not self.redis_key:
Expand All @@ -93,6 +100,8 @@ def clear(self):
session = self.request.session
del session[self.session_key]
self.mark_session()
self._state_cache = None
self._pending_writes = 0

def is_valid(self):
return bool(self.redis_key and self.get_state() is not None)
Expand All @@ -103,13 +112,20 @@ def get_state(self):

state_json = self._client.get(self.redis_key)
if not state_json:
if self._pending_writes > 0 and self._state_cache is not None:
self._pending_writes -= 1
return self._state_cache
return None

try:
return loads(state_json)
state = loads(state_json)
except Exception as e:
sentry_sdk.capture_exception(e)
return None
return None

self._state_cache = state
self._pending_writes = 0
return state


def redis_property(key: str):
Expand All @@ -130,6 +146,8 @@ def setter(store: "RedisSessionStore", value):
return

state[key] = value
store._state_cache = state
store._pending_writes = _PERSISTENCE_RECHECK_ATTEMPTS
store._client.setex(store.redis_key, store.ttl, dumps(state))

return property(getter, setter)
55 changes: 55 additions & 0 deletions tests/sentry/utils/test_session_store.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,35 @@
from unittest import TestCase
from unittest.mock import patch

from django.test import Client, RequestFactory

from sentry.utils.session_store import RedisSessionStore, redis_property


class FakeRedisClient:
def __init__(self) -> None:
self._data: dict[str, str] = {}
self.read_failures = 0

def setex(self, key, ttl, value):
if key is None:
return
self._data[key] = value

def get(self, key):
if key is None:
return None
if self.read_failures > 0:
self.read_failures -= 1
return None
return self._data.get(key)

def delete(self, key):
if key is None:
return
self._data.pop(key, None)


class RedisSessionStoreTestCase(TestCase):
class TestRedisSessionStore(RedisSessionStore):
some_value = redis_property("some_value")
Expand Down Expand Up @@ -65,3 +90,33 @@ def test_malformed_state(self) -> None:

assert self.store.is_valid() is False
assert self.store.get_state() is None

def test_write_survives_transient_read_failure(self) -> None:
fake_client = FakeRedisClient()
with patch("sentry.utils.session_store.redis.redis_clusters.get", return_value=fake_client):
store = self.TestRedisSessionStore(self.request, "test-store")
store.regenerate()

# Simulate a laggy replica not returning the freshly written state.
fake_client.read_failures = 1
store.some_value = "test_value"

store2 = self.TestRedisSessionStore(self.request, "test-store")
assert store2.some_value == "test_value"

def test_cache_eventually_invalidates_when_missing(self) -> None:
fake_client = FakeRedisClient()
with patch("sentry.utils.session_store.redis.redis_clusters.get", return_value=fake_client):
store = self.TestRedisSessionStore(self.request, "test-store")
store.regenerate()

fake_client.delete(store.redis_key)

state = None
for _ in range(10):
state = store.get_state()
if state is None:
break

assert state is None
assert store.is_valid() is False
Loading