40 changes: 39 additions & 1 deletion juniper_data/storage/cached.py
@@ -5,12 +5,21 @@

import numpy as np

from juniper_data.api.observability import set_datasets_cached
from juniper_data.core.models import DatasetMeta
from juniper_data.storage.constants import DEFAULT_LIST_LIMIT, DEFAULT_LIST_OFFSET

logger = logging.getLogger(__name__)
from .base import DatasetStore

# Probe limit used when sampling the cache backend for the
# ``juniper_data_datasets_cached`` gauge. Mirrors the limit used by
# :meth:`CachedDatasetStore.warm_cache` so the gauge reflects the same
# population that warm_cache would touch. Cache backends are expected
# to be in-memory (Redis / InMemoryDatasetStore) so a SCAN over 10k
# keys is cheap relative to a dataset save/load.
_CACHE_COUNT_PROBE_LIMIT: int = 10_000


class CachedDatasetStore(DatasetStore):
"""Composable caching wrapper for dataset storage.
@@ -42,6 +51,22 @@ def __init__(
self._cache = cache
self._write_through = write_through

def _emit_cached_count(self) -> None:
"""Update the ``juniper_data_datasets_cached`` Prometheus gauge.

Probes the cache backend for its current dataset population and
publishes the count via :func:`juniper_data.api.observability.set_datasets_cached`.
Failures (cache backend unavailable, metric registry not yet
initialised, etc.) are swallowed so observability never breaks
the storage path -- mirrors the ``contextlib.suppress(Exception)``
discipline used everywhere else in this class.
"""
try:
count = len(self._cache.list_datasets(limit=_CACHE_COUNT_PROBE_LIMIT))
set_datasets_cached(count)
except Exception:
logger.debug("Failed to update juniper_data_datasets_cached gauge", exc_info=True)

def save(
self,
dataset_id: str,
@@ -60,6 +85,7 @@ def save(
if self._write_through:
with contextlib.suppress(Exception):
self._cache.save(dataset_id, meta, arrays)
self._emit_cached_count()

def get_meta(self, dataset_id: str) -> DatasetMeta | None:
"""Get metadata, checking cache first.
@@ -92,6 +118,7 @@ def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
artifact = self._primary.get_artifact_bytes(dataset_id)

if artifact is not None:
populated = False
with contextlib.suppress(Exception):
meta = self._primary.get_meta(dataset_id)
if meta is not None:
@@ -100,6 +127,9 @@ def get_artifact_bytes(self, dataset_id: str) -> bytes | None:
with np.load(io.BytesIO(artifact)) as npz:
arrays = {k: npz[k] for k in npz.files}
self._cache.save(dataset_id, meta, arrays)
populated = True
if populated:
self._emit_cached_count()
return artifact

def exists(self, dataset_id: str) -> bool:
@@ -125,8 +155,12 @@ def delete(self, dataset_id: str) -> bool:
Returns:
True if the dataset was deleted from primary, False otherwise.
"""
cache_touched = False
with contextlib.suppress(Exception):
self._cache.delete(dataset_id)
cache_touched = True
if cache_touched:
self._emit_cached_count()
return self._primary.delete(dataset_id)

def list_datasets(self, limit: int = DEFAULT_LIST_LIMIT, offset: int = DEFAULT_LIST_OFFSET) -> list[str]:
@@ -176,9 +210,11 @@ def invalidate_cache(self, dataset_id: str) -> bool:
True if entry was removed from cache, False otherwise.
"""
try:
return self._cache.delete(dataset_id)
result = self._cache.delete(dataset_id)
except Exception:
return False
self._emit_cached_count()
return result

def warm_cache(self, dataset_ids: list[str] | None = None) -> int:
"""Populate cache from primary store.
@@ -209,4 +245,6 @@ def warm_cache(self, dataset_ids: list[str] | None = None) -> int:
logger.warning("Failed to cache dataset %s", dataset_id, exc_info=True)
continue

if cached_count > 0:
self._emit_cached_count()
return cached_count
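
For reference, ``juniper_data/api/observability.py`` is not part of this diff. The helpers it exposes (``set_datasets_cached`` above, ``_ensure_dataset_metrics`` in the tests below) presumably wrap a lazily-initialised ``prometheus_client`` gauge. A minimal sketch of that shape, assuming the standard ``Gauge`` API; everything beyond those two names is an assumption:

# Hypothetical sketch of juniper_data/api/observability.py (not shown in this
# PR); only set_datasets_cached and _ensure_dataset_metrics are known names,
# the rest is assumed for illustration.
from __future__ import annotations

from prometheus_client import Gauge

# Lazily-initialised metric map; the test fixture below resets it to None.
_dataset_metrics: dict[str, Gauge] | None = None


def _ensure_dataset_metrics() -> dict[str, Gauge]:
    """Create and register the dataset gauges on first use."""
    global _dataset_metrics
    if _dataset_metrics is None:
        _dataset_metrics = {
            "datasets_cached": Gauge(
                "juniper_data_datasets_cached",
                "Number of datasets currently held in the cache backend.",
            ),
        }
    return _dataset_metrics


def set_datasets_cached(count: int) -> None:
    """Publish the current cache cardinality to the gauge."""
    _ensure_dataset_metrics()["datasets_cached"].set(count)
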
158 changes: 158 additions & 0 deletions juniper_data/tests/unit/test_cached_store.py
@@ -6,6 +6,7 @@
import numpy as np
import pytest

import juniper_data.api.observability as obs
from juniper_data.core.models import DatasetMeta
from juniper_data.storage import CachedDatasetStore, InMemoryDatasetStore

@@ -421,3 +422,160 @@ def test_exists_suppresses_cache_error(

assert cached.exists("test-1") is True
assert cached.exists("nonexistent") is False


@pytest.fixture
def _reset_dataset_metrics():
"""Reset the lazy-init dataset metrics + REGISTRY between tests.

Mirrors the autouse fixture in ``test_observability.py``: re-using
``CachedDatasetStore`` across tests would otherwise trip
``ValueError: Duplicated timeseries in CollectorRegistry`` when the
gauge is re-registered against the global ``prometheus_client``
REGISTRY.
"""
pytest.importorskip("prometheus_client")
from prometheus_client import REGISTRY

obs._dataset_metrics = None
yield
collectors = list(getattr(REGISTRY, "_collector_to_names", {}).keys())
for collector in collectors:
try:
REGISTRY.unregister(collector)
except KeyError:
# Collector may already be absent; teardown is best-effort.
continue
obs._dataset_metrics = None


def _read_datasets_cached_gauge() -> float:
"""Return the current value of the ``juniper_data_datasets_cached`` gauge."""
return obs._ensure_dataset_metrics()["datasets_cached"]._value.get()


@pytest.mark.unit
class TestDatasetsCachedGauge:
"""Wire-up tests for the ``juniper_data_datasets_cached`` Prometheus gauge.

Closes the production-caller gap surfaced by juniper-ml#223
(post-METRICS-MON state report §15): the gauge was defined and
helper-tested but had no production updater. ``CachedDatasetStore``
is the canonical cache layer, so its mutation paths now publish the
cache cardinality after every change.
"""

def test_save_emits_cache_count(
self,
_reset_dataset_metrics,
primary_store: InMemoryDatasetStore,
cache_store: InMemoryDatasetStore,
sample_meta: DatasetMeta,
sample_arrays: dict[str, np.ndarray],
) -> None:
"""save() with write_through should publish the cache cardinality."""
cached = CachedDatasetStore(primary_store, cache_store, write_through=True)

cached.save("test-1", sample_meta, sample_arrays)
assert _read_datasets_cached_gauge() == 1

cached.save("test-2", sample_meta, sample_arrays)
assert _read_datasets_cached_gauge() == 2

cached.save("test-3", sample_meta, sample_arrays)
assert _read_datasets_cached_gauge() == 3

def test_delete_emits_decremented_cache_count(
self,
_reset_dataset_metrics,
primary_store: InMemoryDatasetStore,
cache_store: InMemoryDatasetStore,
sample_meta: DatasetMeta,
sample_arrays: dict[str, np.ndarray],
) -> None:
"""delete() should publish the post-eviction cache cardinality."""
cached = CachedDatasetStore(primary_store, cache_store, write_through=True)

cached.save("test-1", sample_meta, sample_arrays)
cached.save("test-2", sample_meta, sample_arrays)
cached.save("test-3", sample_meta, sample_arrays)
assert _read_datasets_cached_gauge() == 3

cached.delete("test-2")
assert _read_datasets_cached_gauge() == 2

cached.delete("test-1")
assert _read_datasets_cached_gauge() == 1

def test_invalidate_cache_emits_decremented_cache_count(
self,
_reset_dataset_metrics,
primary_store: InMemoryDatasetStore,
cache_store: InMemoryDatasetStore,
sample_meta: DatasetMeta,
sample_arrays: dict[str, np.ndarray],
) -> None:
"""invalidate_cache() should publish the post-eviction cache cardinality."""
cached = CachedDatasetStore(primary_store, cache_store, write_through=True)
cached.save("test-1", sample_meta, sample_arrays)
cached.save("test-2", sample_meta, sample_arrays)
assert _read_datasets_cached_gauge() == 2

cached.invalidate_cache("test-1")

assert _read_datasets_cached_gauge() == 1
assert primary_store.exists("test-1") # primary untouched

def test_warm_cache_emits_populated_count(
self,
_reset_dataset_metrics,
primary_store: InMemoryDatasetStore,
cache_store: InMemoryDatasetStore,
sample_meta: DatasetMeta,
sample_arrays: dict[str, np.ndarray],
) -> None:
"""warm_cache() should publish the post-warm cache cardinality."""
cached = CachedDatasetStore(primary_store, cache_store, write_through=False)
primary_store.save("test-1", sample_meta, sample_arrays)
primary_store.save("test-2", sample_meta, sample_arrays)
# No write-through, so the cache is empty before warming.
assert _read_datasets_cached_gauge() == 0

count = cached.warm_cache()

assert count == 2
assert _read_datasets_cached_gauge() == 2

def test_get_artifact_bytes_read_through_emits_cache_count(
self,
_reset_dataset_metrics,
primary_store: InMemoryDatasetStore,
cache_store: InMemoryDatasetStore,
sample_meta: DatasetMeta,
sample_arrays: dict[str, np.ndarray],
) -> None:
"""Read-through cache population in get_artifact_bytes should publish."""
cached = CachedDatasetStore(primary_store, cache_store, write_through=False)
primary_store.save("test-1", sample_meta, sample_arrays)
assert _read_datasets_cached_gauge() == 0

artifact = cached.get_artifact_bytes("test-1")

assert artifact is not None
assert _read_datasets_cached_gauge() == 1

def test_save_without_write_through_does_not_emit(
self,
_reset_dataset_metrics,
primary_store: InMemoryDatasetStore,
cache_store: InMemoryDatasetStore,
sample_meta: DatasetMeta,
sample_arrays: dict[str, np.ndarray],
) -> None:
"""save() with write_through=False does not touch the cache, so no emit."""
cached = CachedDatasetStore(primary_store, cache_store, write_through=False)

cached.save("test-1", sample_meta, sample_arrays)

# Cache is empty; gauge stays at 0 (default for a fresh Gauge).
assert _read_datasets_cached_gauge() == 0
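
As a usage note, the gauge can also be read through the collector registry instead of the private ``_value`` attribute; a small sketch, assuming the gauge is registered against the default global ``REGISTRY`` (the helper name here is hypothetical):

# Hypothetical alternative to _read_datasets_cached_gauge() that avoids the
# private _value attribute by querying the global registry directly.
from prometheus_client import REGISTRY


def read_datasets_cached_from_registry() -> float:
    value = REGISTRY.get_sample_value("juniper_data_datasets_cached")
    # get_sample_value returns None when the gauge has not been created yet.
    return 0.0 if value is None else value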