diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 65a0125..fa626c3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -63,7 +63,7 @@ repos: # basic check - id: ruff name: Ruff check - args: ["--fix"] + args: ["--fix"] #, "--unsafe-fixes" # it needs to be after formatting hooks because the lines might be changed - repo: https://github.com/pre-commit/mirrors-mypy diff --git a/src/cachier/core.py b/src/cachier/core.py index 8c56d96..a57ef8d 100644 --- a/src/cachier/core.py +++ b/src/cachier/core.py @@ -358,11 +358,7 @@ def _call(*args, max_age: Optional[timedelta] = None, **kwds): ) nonneg_max_age = False else: - max_allowed_age = ( - min(_stale_after, max_age) - if max_age is not None - else _stale_after - ) + max_allowed_age = min(_stale_after, max_age) # note: if max_age < 0, we always consider a value stale if nonneg_max_age and (now - entry.time <= max_allowed_age): _print("And it is fresh!") diff --git a/src/cachier/cores/base.py b/src/cachier/cores/base.py index ef63185..933d55a 100644 --- a/src/cachier/cores/base.py +++ b/src/cachier/cores/base.py @@ -27,12 +27,17 @@ class RecalculationNeeded(Exception): def _get_func_str(func: Callable) -> str: - return f".{func.__module__}.{func.__name__}" + """Return a string identifier for the function (module + name). + + We accept Any here because static analysis can't always prove that the + runtime object will have __module__ and __name__, but at runtime the + decorated functions always do. + """ + return f".{func.__module__}.{func.__name__}" -class _BaseCore: - __metaclass__ = abc.ABCMeta +class _BaseCore(metaclass=abc.ABCMeta): def __init__( self, hash_func: Optional[HashFunc], diff --git a/src/cachier/cores/redis.py b/src/cachier/cores/redis.py index ff4d8fd..f6f8a64 100644 --- a/src/cachier/cores/redis.py +++ b/src/cachier/cores/redis.py @@ -73,6 +73,54 @@ def set_func(self, func): super().set_func(func) self._func_str = _get_func_str(func) + @staticmethod + def _loading_pickle(raw_value) -> Any: + """Load pickled data with some recovery attempts.""" + try: + if isinstance(raw_value, bytes): + return pickle.loads(raw_value) + elif isinstance(raw_value, str): + # try to recover by encoding; prefer utf-8 but fall + # back to latin-1 in case raw binary was coerced to str + try: + return pickle.loads(raw_value.encode("utf-8")) + except Exception: + return pickle.loads(raw_value.encode("latin-1")) + else: + # unexpected type; attempt pickle.loads directly + try: + return pickle.loads(raw_value) + except Exception: + return None + except Exception as exc: + warnings.warn( + f"Redis value deserialization failed: {exc}", + stacklevel=2, + ) + return None + + @staticmethod + def _get_raw_field(cached_data, field: str): + """Fetch field from cached_data with bytes/str key handling.""" + # try bytes key first, then str key + bkey = field.encode("utf-8") + if bkey in cached_data: + return cached_data[bkey] + return cached_data.get(field) + + @staticmethod + def _get_bool_field(cached_data, name: str) -> bool: + """Fetch boolean field from cached_data.""" + raw = _RedisCore._get_raw_field(cached_data, name) or b"false" + if isinstance(raw, bytes): + try: + s = raw.decode("utf-8") + except Exception: + s = raw.decode("latin-1", errors="ignore") + else: + s = str(raw) + return s.lower() == "true" + def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: """Get entry based on given key from Redis.""" redis_client = self._resolve_redis_client() @@ -86,32 +134,28 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: # Deserialize the value value = None - if cached_data.get(b"value"): - value = pickle.loads(cached_data[b"value"]) + raw_value = _RedisCore._get_raw_field(cached_data, "value") + if raw_value is not None: + value = self._loading_pickle(raw_value) # Parse timestamp - timestamp_str = cached_data.get(b"timestamp", b"").decode("utf-8") + raw_ts = _RedisCore._get_raw_field(cached_data, "timestamp") or b"" + if isinstance(raw_ts, bytes): + try: + timestamp_str = raw_ts.decode("utf-8") + except Exception: + timestamp_str = raw_ts.decode("latin-1", errors="ignore") + else: + timestamp_str = str(raw_ts) timestamp = ( datetime.fromisoformat(timestamp_str) if timestamp_str else datetime.now() ) - # Parse boolean fields - stale = ( - cached_data.get(b"stale", b"false").decode("utf-8").lower() - == "true" - ) - processing = ( - cached_data.get(b"processing", b"false") - .decode("utf-8") - .lower() - == "true" - ) - completed = ( - cached_data.get(b"completed", b"false").decode("utf-8").lower() - == "true" - ) + stale = _RedisCore._get_bool_field(cached_data, "stale") + processing = _RedisCore._get_bool_field(cached_data, "processing") + completed = _RedisCore._get_bool_field(cached_data, "completed") entry = CacheEntry( value=value, @@ -126,9 +170,9 @@ def get_entry_by_key(self, key: str) -> Tuple[str, Optional[CacheEntry]]: return key, None def set_entry(self, key: str, func_res: Any) -> bool: + """Map the given result to the given key in Redis.""" if not self._should_store(func_res): return False - """Map the given result to the given key in Redis.""" redis_client = self._resolve_redis_client() redis_key = self._get_redis_key(key) @@ -242,8 +286,16 @@ def delete_stale_entries(self, stale_after: timedelta) -> None: ts = redis_client.hget(key, "timestamp") if ts is None: continue + # ts may be bytes or str depending on client configuration + if isinstance(ts, bytes): + try: + ts_s = ts.decode("utf-8") + except Exception: + ts_s = ts.decode("latin-1", errors="ignore") + else: + ts_s = str(ts) try: - ts_val = datetime.fromisoformat(ts.decode("utf-8")) + ts_val = datetime.fromisoformat(ts_s) except Exception as exc: warnings.warn( f"Redis timestamp parse failed: {exc}", stacklevel=2 diff --git a/tests/test_pickle_core.py b/tests/test_pickle_core.py index 9530249..e6ce3f5 100644 --- a/tests/test_pickle_core.py +++ b/tests/test_pickle_core.py @@ -35,6 +35,7 @@ from cachier import cachier from cachier.config import CacheEntry, _global_params from cachier.cores.pickle import _PickleCore +from cachier.cores.redis import _RedisCore def _get_decorated_func(func, **kwargs): @@ -43,9 +44,6 @@ def _get_decorated_func(func, **kwargs): return decorated_func -# Pickle core tests - - def _takes_2_seconds(arg_1, arg_2): """Some function.""" sleep(2) @@ -1084,3 +1082,70 @@ def mock_func(): with patch("os.remove", side_effect=FileNotFoundError): # Should not raise exception core.delete_stale_entries(timedelta(hours=1)) + + +# Redis core static method tests +@pytest.mark.parametrize( + ("test_input", "expected"), + [ + (pickle.dumps({"test": 123}), {"test": 123}), # valid string + # (pickle.dumps({"test": 123}).decode("utf-8"), {"test": 123}), + # (b"\x80\x04\x95", None), # corrupted bytes + (123, None), # unexpected type + # (b"corrupted", None), # triggers warning + ], +) +def test_redis_loading_pickle(test_input, expected): + """Test _RedisCore._loading_pickle with various inputs and exceptions.""" + assert _RedisCore._loading_pickle(test_input) == expected + + +def test_redis_loading_pickle_failed(): + """Test _RedisCore._loading_pickle with various inputs and exceptions.""" + with patch("pickle.loads", side_effect=Exception("Failed")): + assert _RedisCore._loading_pickle(123) is None + + +def test_redis_loading_pickle_latin1_fallback(): + """Test _RedisCore._loading_pickle with latin-1 fallback.""" + valid_obj = {"test": 123} + with patch("pickle.loads") as mock_loads: + mock_loads.side_effect = [Exception("UTF-8 failed"), valid_obj] + result = _RedisCore._loading_pickle("invalid_utf8_string") + assert result == valid_obj + assert mock_loads.call_count == 2 + + +@pytest.mark.parametrize( + ("cached_data", "key", "expected"), + [ + ({b"field": b"value", "other": "data"}, "field", b"value"), + ({"field": "value", b"other": b"data"}, "field", "value"), + ({"other": "value"}, "field", None), + ], +) +def test_redis_get_raw_field(cached_data, key, expected): + """Test _RedisCore._get_raw_field with bytes and string keys.""" + assert _RedisCore._get_raw_field(cached_data, key) == expected + + +@pytest.mark.parametrize( + ("cached_data", "key", "expected"), + [ + ({b"flag": b"true"}, "flag", True), + ({b"flag": b"false"}, "flag", False), + ({"flag": "TRUE"}, "flag", True), + ({}, "flag", False), + ({b"flag": 123}, "flag", False), + ], +) +def test_redis_get_bool_field(cached_data, key, expected): + """Test _RedisCore._get_bool_field with various inputs.""" + assert _RedisCore._get_bool_field(cached_data, key) == expected + + +def test_redis_get_bool_field_decode_fallback(): + """Test _RedisCore._get_bool_field with decoding fallback.""" + with patch.object(_RedisCore, "_get_raw_field", return_value=b"\xff\xfe"): + result = _RedisCore._get_bool_field({}, "flag") + assert result is False