Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ repos:
# basic check
- id: ruff
name: Ruff check
args: ["--fix"]
args: ["--fix"] #, "--unsafe-fixes"

# it needs to be after formatting hooks because the lines might be changed
- repo: https://github.com/pre-commit/mirrors-mypy
Expand Down
6 changes: 1 addition & 5 deletions src/cachier/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,11 +358,7 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds):
)
nonneg_max_age = False
else:
max_allowed_age = (
min(_stale_after, max_age)
if max_age is not None
else _stale_after
)
max_allowed_age = min(_stale_after, max_age)
# note: if max_age < 0, we always consider a value stale
if nonneg_max_age and (now - entry.time <= max_allowed_age):
_print("And it is fresh!")
Expand Down
11 changes: 8 additions & 3 deletions src/cachier/cores/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,17 @@ class RecalculationNeeded(Exception):


def _get_func_str(func: Callable) -> str:
return f".{func.__module__}.{func.__name__}"
"""Return a string identifier for the function (module + name).

We accept Any here because static analysis can't always prove that the
runtime object will have __module__ and __name__, but at runtime the
decorated functions always do.

"""
return f".{func.__module__}.{func.__name__}"

class _BaseCore:
__metaclass__ = abc.ABCMeta

class _BaseCore(metaclass=abc.ABCMeta):
def __init__(
self,
hash_func: Optional[HashFunc],
Expand Down
92 changes: 72 additions & 20 deletions src/cachier/cores/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,54 @@ def set_func(self, func):
super().set_func(func)
self._func_str = _get_func_str(func)

@staticmethod
def _loading_pickle(raw_value) -> Any:
"""Load pickled data with some recovery attempts."""
try:
if isinstance(raw_value, bytes):
return pickle.loads(raw_value)
elif isinstance(raw_value, str):
# try to recover by encoding; prefer utf-8 but fall
# back to latin-1 in case raw binary was coerced to str
try:
return pickle.loads(raw_value.encode("utf-8"))
except Exception:
return pickle.loads(raw_value.encode("latin-1"))
else:
# unexpected type; attempt pickle.loads directly
try:
return pickle.loads(raw_value)
except Exception:
return None
except Exception as exc:
warnings.warn(
f"Redis value deserialization failed: {exc}",
stacklevel=2,
)
return None

@staticmethod
def _get_raw_field(cached_data, field: str):
"""Fetch field from cached_data with bytes/str key handling."""
# try bytes key first, then str key
bkey = field.encode("utf-8")
if bkey in cached_data:
return cached_data[bkey]
return cached_data.get(field)

@staticmethod
def _get_bool_field(cached_data, name: str) -> bool:
"""Fetch boolean field from cached_data."""
raw = _RedisCore._get_raw_field(cached_data, name) or b"false"
if isinstance(raw, bytes):
try:
s = raw.decode("utf-8")
except Exception:
s = raw.decode("latin-1", errors="ignore")
else:
s = str(raw)
return s.lower() == "true"

def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
"""Get entry based on given key from Redis."""
redis_client = self._resolve_redis_client()
Expand All @@ -86,32 +134,28 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:

# Deserialize the value
value = None
if cached_data.get(b"value"):
value = pickle.loads(cached_data[b"value"])
raw_value = _RedisCore._get_raw_field(cached_data, "value")
if raw_value is not None:
value = self._loading_pickle(raw_value)

# Parse timestamp
timestamp_str = cached_data.get(b"timestamp", b"").decode("utf-8")
raw_ts = _RedisCore._get_raw_field(cached_data, "timestamp") or b""
if isinstance(raw_ts, bytes):
try:
timestamp_str = raw_ts.decode("utf-8")
except Exception:
timestamp_str = raw_ts.decode("latin-1", errors="ignore")
else:
timestamp_str = str(raw_ts)
timestamp = (
datetime.fromisoformat(timestamp_str)
if timestamp_str
else datetime.now()
)

# Parse boolean fields
stale = (
cached_data.get(b"stale", b"false").decode("utf-8").lower()
== "true"
)
processing = (
cached_data.get(b"processing", b"false")
.decode("utf-8")
.lower()
== "true"
)
completed = (
cached_data.get(b"completed", b"false").decode("utf-8").lower()
== "true"
)
stale = _RedisCore._get_bool_field(cached_data, "stale")
processing = _RedisCore._get_bool_field(cached_data, "processing")
completed = _RedisCore._get_bool_field(cached_data, "completed")

entry = CacheEntry(
value=value,
Expand All @@ -126,9 +170,9 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]:
return key, None

def set_entry(self, key: str, func_res: Any) -> bool:
"""Map the given result to the given key in Redis."""
if not self._should_store(func_res):
return False
"""Map the given result to the given key in Redis."""
redis_client = self._resolve_redis_client()
redis_key = self._get_redis_key(key)

Expand Down Expand Up @@ -242,8 +286,16 @@ def delete_stale_entries(self, stale_after: timedelta) -> None:
ts = redis_client.hget(key, "timestamp")
if ts is None:
continue
# ts may be bytes or str depending on client configuration
if isinstance(ts, bytes):
try:
ts_s = ts.decode("utf-8")
except Exception:
ts_s = ts.decode("latin-1", errors="ignore")
else:
ts_s = str(ts)
try:
ts_val = datetime.fromisoformat(ts.decode("utf-8"))
ts_val = datetime.fromisoformat(ts_s)
except Exception as exc:
warnings.warn(
f"Redis timestamp parse failed: {exc}", stacklevel=2
Expand Down
71 changes: 68 additions & 3 deletions tests/test_pickle_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from cachier import cachier
from cachier.config import CacheEntry, _global_params
from cachier.cores.pickle import _PickleCore
from cachier.cores.redis import _RedisCore


def _get_decorated_func(func, **kwargs):
Expand All @@ -43,9 +44,6 @@ def _get_decorated_func(func, **kwargs):
return decorated_func


# Pickle core tests


def _takes_2_seconds(arg_1, arg_2):
"""Some function."""
sleep(2)
Expand Down Expand Up @@ -1084,3 +1082,70 @@ def mock_func():
with patch("os.remove", side_effect=FileNotFoundError):
# Should not raise exception
core.delete_stale_entries(timedelta(hours=1))


# Redis core static method tests
@pytest.mark.parametrize(
("test_input", "expected"),
[
(pickle.dumps({"test": 123}), {"test": 123}), # valid string
# (pickle.dumps({"test": 123}).decode("utf-8"), {"test": 123}),
# (b"\x80\x04\x95", None), # corrupted bytes
(123, None), # unexpected type
# (b"corrupted", None), # triggers warning
],
)
def test_redis_loading_pickle(test_input, expected):
"""Test _RedisCore._loading_pickle with various inputs and exceptions."""
assert _RedisCore._loading_pickle(test_input) == expected


def test_redis_loading_pickle_failed():
"""Test _RedisCore._loading_pickle with various inputs and exceptions."""
with patch("pickle.loads", side_effect=Exception("Failed")):
assert _RedisCore._loading_pickle(123) is None


def test_redis_loading_pickle_latin1_fallback():
"""Test _RedisCore._loading_pickle with latin-1 fallback."""
valid_obj = {"test": 123}
with patch("pickle.loads") as mock_loads:
mock_loads.side_effect = [Exception("UTF-8 failed"), valid_obj]
result = _RedisCore._loading_pickle("invalid_utf8_string")
assert result == valid_obj
assert mock_loads.call_count == 2


@pytest.mark.parametrize(
("cached_data", "key", "expected"),
[
({b"field": b"value", "other": "data"}, "field", b"value"),
({"field": "value", b"other": b"data"}, "field", "value"),
({"other": "value"}, "field", None),
],
)
def test_redis_get_raw_field(cached_data, key, expected):
"""Test _RedisCore._get_raw_field with bytes and string keys."""
assert _RedisCore._get_raw_field(cached_data, key) == expected


@pytest.mark.parametrize(
("cached_data", "key", "expected"),
[
({b"flag": b"true"}, "flag", True),
({b"flag": b"false"}, "flag", False),
({"flag": "TRUE"}, "flag", True),
({}, "flag", False),
({b"flag": 123}, "flag", False),
],
)
def test_redis_get_bool_field(cached_data, key, expected):
"""Test _RedisCore._get_bool_field with various inputs."""
assert _RedisCore._get_bool_field(cached_data, key) == expected


def test_redis_get_bool_field_decode_fallback():
"""Test _RedisCore._get_bool_field with decoding fallback."""
with patch.object(_RedisCore, "_get_raw_field", return_value=b"\xff\xfe"):
result = _RedisCore._get_bool_field({}, "flag")
assert result is False