diff --git a/examples/configs/content_safety/README.md b/examples/configs/content_safety/README.md index 35a2d2a45..0645e9043 100644 --- a/examples/configs/content_safety/README.md +++ b/examples/configs/content_safety/README.md @@ -1,10 +1,117 @@ -# NemoGuard ContentSafety Usage Example +# Content Safety Configuration -This example showcases the use of NVIDIA's [NemoGuard ContentSafety model](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) for topical and dialogue moderation. +This example demonstrates how to configure content safety rails with NeMo Guardrails, including optional cache persistence. -The structure of the config folder is the following: +## Features -- `config.yml` - The config file holding all the configuration options for the model. -- `prompts.yml` - The config file holding the topical rules used for topical and dialogue moderation by the current guardrail configuration. +- **Input Safety Checks**: Validates user inputs before processing +- **Output Safety Checks**: Ensures bot responses are appropriate +- **Caching**: Reduces redundant API calls with LFU cache +- **Persistence**: Optional cache persistence for resilience across restarts +- **Thread Safety**: Fully thread-safe for use in multi-threaded web servers -Please see the docs for more details about the [recommended ContentSafety deployment](./../../../docs/user-guides/advanced/nemoguard-contentsafety-deployment.md) methods, either using locally downloaded NIMs or NVIDIA AI Enterprise (NVAIE). +## Configuration Overview + +The configuration includes: + +1. **Main Model**: The primary LLM for conversations (Llama 3.3 70B) +2. **Content Safety Model**: Dedicated model for safety checks (NemoGuard 8B) +3. **Rails**: Input and output safety check flows +4. **Cache Configuration**: Memory cache with optional persistence + +## How It Works + +1. **User Input**: When a user sends a message, it's checked by the content safety model +2. **Cache Check**: The system first checks if this content was already evaluated (cache hit) +3. **Safety Evaluation**: If not cached, the content safety model evaluates the input +4. **Result Caching**: The safety check result is cached for future use +5. **Response Generation**: If safe, the main model generates a response +6. **Output Check**: The response is also checked for safety before returning to the user + +## Cache Persistence + +The cache configuration includes: + +- **Automatic Saves**: Every 5 minutes (configurable) +- **Shutdown Saves**: Caches are automatically persisted when the application closes +- **Crash Recovery**: Cache reloads from disk on restart +- **Per-Model Storage**: Each model gets its own cache file + +To disable persistence, you can either: + +1. Set `enabled: false` in the persistence section +2. Remove the `persistence` section entirely +3. Set `interval` to `null` or remove it + +Note: Persistence requires both `enabled: true` and a valid `interval` value to be active. 
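+
+For reference, this is the persistence block used by `config.yml` in this example (values shown mirror that file; setting `enabled: false` under `persistence` switches saving off while keeping the rest of the cache configuration in place):
+
+```yaml
+rails:
+  config:
+    content_safety:
+      cache:
+        enabled: true
+        capacity_per_model: 5000
+        store: memory
+        persistence:
+          enabled: true                      # set to false to disable persistence
+          interval: 300.0                    # seconds between automatic saves
+          path: ./content_safety_cache.json  # where the cache file is written
+```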
+ +## Thread Safety + +The content safety implementation is fully thread-safe: + +- **Concurrent Requests**: Safely handles multiple simultaneous safety checks +- **No Data Corruption**: Thread-safe cache operations prevent data corruption +- **Efficient Locking**: Uses RLock for minimal performance impact +- **Atomic Operations**: Prevents duplicate LLM calls for the same content + +This makes it suitable for: + +- Multi-threaded web servers (FastAPI, Flask, Django) +- Concurrent request processing +- High-traffic applications + +### Proper Shutdown + +For best results, use one of these patterns: + +```python +# Context manager (recommended) +with LLMRails(config) as rails: + # Your code here + pass +# Caches automatically persisted on exit + +# Or manual cleanup +rails = LLMRails(config) +# Your code here +rails.close() # Persist caches +``` + +## Running the Example + +```bash +# From the NeMo-Guardrails root directory +nemoguardrails server --config examples/configs/content_safety/ +``` + +## Customization + +### Adjust Cache Settings + +```yaml +cache: + enabled: true # Enable/disable caching + capacity_per_model: 5000 # Maximum entries per model + persistence: + interval: 300.0 # Seconds between saves + path: ./my_cache.json # Custom path +``` + +### Memory-Only Cache + +For memory-only caching without persistence: + +```yaml +cache: + enabled: true + capacity_per_model: 5000 + store: memory + # No persistence section +``` + +## Benefits + +1. **Performance**: Avoid redundant content safety API calls +2. **Cost Savings**: Reduce API usage for repeated content +3. **Reliability**: Cache survives process restarts +4. **Flexibility**: Easy to enable/disable features as needed diff --git a/examples/configs/content_safety/config.yml b/examples/configs/content_safety/config.yml index f6808bf14..5ca7c8608 100644 --- a/examples/configs/content_safety/config.yml +++ b/examples/configs/content_safety/config.yml @@ -14,3 +14,18 @@ rails: output: flows: - content safety check output $model=content_safety + + # Content safety cache configuration with persistence and stats + config: + content_safety: + cache: + enabled: true + capacity_per_model: 5000 + store: memory # In-memory cache with optional disk persistence + persistence: + enabled: true # Enable persistence (requires interval to be set) + interval: 300.0 # Persist every 5 minutes + path: ./content_safety_cache.json # Where to save cache + stats: + enabled: true # Enable statistics tracking + log_interval: 60.0 # Log cache statistics every minute diff --git a/nemoguardrails/cache/README.md b/nemoguardrails/cache/README.md new file mode 100644 index 000000000..369e82f5e --- /dev/null +++ b/nemoguardrails/cache/README.md @@ -0,0 +1,271 @@ +# Content Safety LLM Call Caching + +## Overview + +The content safety checks in `actions.py` now use an LFU (Least Frequently Used) cache to improve performance by avoiding redundant LLM calls for identical safety checks. The cache supports optional persistence to disk for resilience across restarts. + +## Implementation Details + +### Cache Configuration + +- Per-model caches: Each model gets its own LFU cache instance +- Default capacity: 50,000 entries per model +- Eviction policy: LFU with LRU tiebreaker +- Statistics tracking: Enabled by default +- Tracks timestamps: `created_at` and `accessed_at` for each entry +- Cache creation: Automatic when a model is first used +- Persistence: Optional periodic save to disk with configurable interval + +### Cached Functions + +1. 
`content_safety_check_input()` - Caches safety checks for user inputs + +Note: `content_safety_check_output()` does not use caching to ensure fresh evaluation of bot responses. + +### Cache Key Components + +The cache key is a SHA256 hash of: + +- The rendered prompt only (can be a string or list of strings) + +Since temperature is fixed (1e-20) and stop/max_tokens are derived from the model configuration, they don't need to be part of the cache key. + +### How It Works + +1. **Before LLM Call**: + - Generate cache key from request parameters + - Check if result exists in cache + - If found, return cached result (cache hit) + +2. **After LLM Call**: + - If not in cache, make the actual LLM call + - Store the result in cache for future use + +### Cache Management + +The caching system automatically creates and manages separate caches for each model. Key features: + +- **Automatic Creation**: Caches are created on first use for each model +- **Isolated Storage**: Each model maintains its own cache, preventing cross-model interference +- **Default Settings**: Each cache has 50,000 entry capacity with stats tracking enabled + +```python +# Internal cache access (for debugging/monitoring): +from nemoguardrails.library.content_safety.actions import _MODEL_CACHES + +# View which models have caches +models_with_caches = list(_MODEL_CACHES.keys()) + +# Get stats for a specific model's cache +if "llama_guard" in _MODEL_CACHES: + stats = _MODEL_CACHES["llama_guard"].get_stats() +``` + +### Persistence Configuration + +The cache supports optional persistence to disk for resilience across restarts: + +```yaml +rails: + config: + content_safety: + cache: + enabled: true + capacity_per_model: 5000 + persistence: + interval: 300.0 # Persist every 5 minutes + path: ./cache_{model_name}.json # {model_name} is replaced +``` + +**Configuration Options:** + +- `persistence.interval`: Seconds between automatic saves (None = no persistence) +- `persistence.path`: Where to save cache data (can include `{model_name}` placeholder) + +**How Persistence Works:** + +1. **Automatic Saves**: Cache checks trigger persistence if interval has passed +2. **On Shutdown**: Caches are automatically persisted when LLMRails is closed or garbage collected +3. **On Restart**: Cache loads from disk if persistence file exists +4. **Preserves State**: Frequencies and access patterns are maintained +5. **Per-Model Files**: Each model gets its own persistence file + +**Manual Persistence:** + +```python +# Force immediate persistence of all caches +content_safety_manager.persist_all_caches() +``` + +This is useful for graceful shutdown scenarios. + +**Notes on Persistence:** + +- Persistence only works with "memory" store type +- Cache files are JSON format for easy inspection and debugging +- Set `persistence.interval` to None to disable persistence +- The cache automatically persists on each check if the interval has passed + +### Statistics and Monitoring + +The cache supports detailed statistics tracking and periodic logging for monitoring cache performance: + +```yaml +rails: + config: + content_safety: + cache: + enabled: true + capacity_per_model: 10000 + stats: + enabled: true # Enable stats tracking + log_interval: 60.0 # Log stats every minute +``` + +**Statistics Features:** + +1. **Tracking Only**: Set `stats.enabled: true` with no `log_interval` to track stats without logging +2. **Automatic Logging**: Set both `stats.enabled: true` and `log_interval` for periodic logging +3. 
**Manual Logging**: Force immediate stats logging with `cache.log_stats_now()` + +**Statistics Tracked:** + +- **Hits**: Number of cache hits (successful lookups) +- **Misses**: Number of cache misses (failed lookups) +- **Hit Rate**: Percentage of requests served from cache +- **Evictions**: Number of items removed due to capacity +- **Puts**: Number of new items added to cache +- **Updates**: Number of existing items updated +- **Current Size**: Number of items currently in cache + +**Log Format:** + +``` +LFU Cache Statistics - Size: 2456/10000 | Hits: 15234 | Misses: 2456 | Hit Rate: 86.11% | Evictions: 0 | Puts: 2456 | Updates: 0 +``` + +**Usage Examples:** + +```python +# Programmatically access stats +if "safety_model" in _MODEL_CACHES: + cache = _MODEL_CACHES["safety_model"] + stats = cache.get_stats() + print(f"Cache hit rate: {stats['hit_rate']:.2%}") + + # Force immediate stats logging + if cache.supports_stats_logging(): + cache.log_stats_now() +``` + +**Configuration Options:** + +- `stats.enabled`: Enable/disable statistics tracking (default: false) +- `stats.log_interval`: Seconds between automatic stats logs (None = no logging) + +**Notes:** + +- Stats logging requires stats tracking to be enabled +- Logs appear at INFO level in the `nemoguardrails.cache.lfu` logger +- Stats are reset when cache is cleared or when `reset_stats()` is called +- Each model maintains independent statistics + +### Example Configuration Usage + +```python +from nemoguardrails import RailsConfig, LLMRails + +# Method 1: Using context manager (recommended - ensures cleanup) +config = RailsConfig.from_path("./config.yml") +with LLMRails(config) as rails: + # Content safety checks will be cached and persisted automatically + response = await rails.generate_async( + messages=[{"role": "user", "content": "Hello, how are you?"}] + ) +# Caches are automatically persisted on exit + +# Method 2: Manual cleanup +rails = LLMRails(config) +response = await rails.generate_async( + messages=[{"role": "user", "content": "Hello, how are you?"}] +) +rails.close() # Manually persist caches + +# Note: If neither method is used, caches will still be persisted +# when the object is garbage collected (__del__) +``` + +### Thread Safety + +The content safety caching system is **thread-safe** for single-node deployments: + +1. **LFUCache Implementation**: + - Uses `threading.RLock` for all operations + - All public methods (`get`, `put`, `size`, `clear`, etc.) are protected by locks + - Supports atomic `get_or_compute()` operations that prevent duplicate computations + +2. **ContentSafetyManager**: + - Thread-safe cache creation using double-checked locking pattern + - Ensures only one cache instance per model across all threads + - Thread-safe persistence operations + +3. **Key Features**: + - **No Data Corruption**: Concurrent operations maintain data integrity + - **No Race Conditions**: Proper locking prevents race conditions + - **Atomic Operations**: `get_or_compute()` ensures expensive computations happen only once + - **Minimal Lock Contention**: Efficient locking patterns minimize performance impact + +4. **Usage in Web Servers**: + - Safe for use in multi-threaded web servers (FastAPI, Flask, etc.) + - Handles concurrent requests without issues + - Each thread sees consistent cache state + +**Note**: This implementation is designed for single-node deployments. For distributed systems, consider using external caching solutions like Redis. + +### Benefits + +1. 
**Performance**: Eliminates redundant LLM calls for identical inputs +2. **Cost Savings**: Reduces API calls to LLM services +3. **Consistency**: Ensures identical inputs always produce identical outputs +4. **Smart Eviction**: LFU policy keeps frequently checked content in cache +5. **Model Isolation**: Each model has its own cache, preventing interference between different safety models +6. **Statistics Tracking**: Monitor cache performance with hit rates, evictions, and more per model +7. **Timestamp Tracking**: Track when entries were created and last accessed +8. **Resilience**: Cache survives process restarts without losing data when persistence is enabled +9. **Efficiency**: LFU eviction algorithm ensures the most useful entries remain in cache +10. **Thread Safety**: Safe for concurrent access in multi-threaded environments + +### Example Usage Pattern + +```python +# First call - takes ~500ms (LLM API call) +result = await content_safety_check_input( + llms=llms, + llm_task_manager=task_manager, + model_name="safety_model", + context={"user_message": "Hello world"} +) + +# Subsequent identical calls - takes ~1ms (cache hit) +result = await content_safety_check_input( + llms=llms, + llm_task_manager=task_manager, + model_name="safety_model", + context={"user_message": "Hello world"} +) +``` + +### Logging + +The implementation includes debug logging: + +- Cache creation: `"Created cache for model '{model_name}' with capacity {capacity}"` +- Cache hits: `"Content safety cache hit for model '{model_name}', key: {key[:8]}..."` +- Cache stores: `"Content safety result cached for model '{model_name}', key: {key[:8]}..."` + +Enable debug logging to monitor cache behavior: + +```python +import logging +logging.getLogger("nemoguardrails.library.content_safety.actions").setLevel(logging.DEBUG) +``` diff --git a/nemoguardrails/cache/__init__.py b/nemoguardrails/cache/__init__.py new file mode 100644 index 000000000..e7f22f070 --- /dev/null +++ b/nemoguardrails/cache/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General-purpose caching utilities for NeMo Guardrails.""" + +from nemoguardrails.cache.interface import CacheInterface +from nemoguardrails.cache.lfu import LFUCache + +__all__ = ["CacheInterface", "LFUCache"] diff --git a/nemoguardrails/cache/interface.py b/nemoguardrails/cache/interface.py new file mode 100644 index 000000000..d724d6999 --- /dev/null +++ b/nemoguardrails/cache/interface.py @@ -0,0 +1,236 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Cache interface for NeMo Guardrails caching system. + +This module defines the abstract base class for cache implementations +that can be used interchangeably throughout the guardrails system. + +Cache implementations may optionally support persistence by overriding +the persist_now() method and supports_persistence() method. Persistence +allows cache state to be saved to and loaded from external storage. +""" + +from abc import ABC, abstractmethod +from typing import Any, Callable, Optional + + +class CacheInterface(ABC): + """ + Abstract base class defining the interface for cache implementations. + + All cache implementations must inherit from this class and implement + the required methods to ensure compatibility with the caching system. + """ + + @abstractmethod + def get(self, key: Any, default: Any = None) -> Any: + """ + Retrieve an item from the cache. + + Args: + key: The key to look up in the cache. + default: Value to return if key is not found (default: None). + + Returns: + The value associated with the key, or default if not found. + """ + pass + + @abstractmethod + def put(self, key: Any, value: Any) -> None: + """ + Store an item in the cache. + + If the cache is at capacity, this method should evict an item + according to the cache's eviction policy (e.g., LFU, LRU, etc.). + + Args: + key: The key to store. + value: The value to associate with the key. + """ + pass + + @abstractmethod + def size(self) -> int: + """ + Get the current number of items in the cache. + + Returns: + The number of items currently stored in the cache. + """ + pass + + @abstractmethod + def is_empty(self) -> bool: + """ + Check if the cache is empty. + + Returns: + True if the cache contains no items, False otherwise. + """ + pass + + @abstractmethod + def clear(self) -> None: + """ + Remove all items from the cache. + + After calling this method, the cache should be empty. + """ + pass + + def contains(self, key: Any) -> bool: + """ + Check if a key exists in the cache. + + This is an optional method that can be overridden for efficiency. + The default implementation uses get() to check existence. + + Args: + key: The key to check. + + Returns: + True if the key exists in the cache, False otherwise. + """ + # Default implementation - can be overridden for efficiency + sentinel = object() + return self.get(key, sentinel) is not sentinel + + @property + @abstractmethod + def capacity(self) -> int: + """ + Get the maximum capacity of the cache. + + Returns: + The maximum number of items the cache can hold. + """ + pass + + def persist_now(self) -> None: + """ + Force immediate persistence of cache to storage. + + This is an optional method that cache implementations can override + if they support persistence. The default implementation does nothing. + + Implementations that support persistence should save the current + cache state to their configured storage backend. + """ + # Default no-op implementation + pass + + def supports_persistence(self) -> bool: + """ + Check if this cache implementation supports persistence. 
+ + Returns: + True if the cache supports persistence, False otherwise. + + The default implementation returns False. Cache implementations + that support persistence should override this to return True. + """ + return False + + def get_stats(self) -> dict: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics. The format and contents + may vary by implementation. Common fields include: + - hits: Number of cache hits + - misses: Number of cache misses + - evictions: Number of items evicted + - hit_rate: Percentage of requests that were hits + - current_size: Current number of items in cache + - capacity: Maximum capacity of the cache + + The default implementation returns a message indicating that + statistics tracking is not supported. + """ + return { + "message": "Statistics tracking is not supported by this cache implementation" + } + + def reset_stats(self) -> None: + """ + Reset cache statistics. + + This is an optional method that cache implementations can override + if they support statistics tracking. The default implementation does nothing. + """ + # Default no-op implementation + pass + + def log_stats_now(self) -> None: + """ + Force immediate logging of cache statistics. + + This is an optional method that cache implementations can override + if they support statistics logging. The default implementation does nothing. + + Implementations that support statistics logging should output the + current cache statistics to their configured logging backend. + """ + # Default no-op implementation + pass + + def supports_stats_logging(self) -> bool: + """ + Check if this cache implementation supports statistics logging. + + Returns: + True if the cache supports statistics logging, False otherwise. + + The default implementation returns False. Cache implementations + that support statistics logging should override this to return True + when logging is enabled. + """ + return False + + async def get_or_compute( + self, key: Any, compute_fn: Callable[[], Any], default: Any = None + ) -> Any: + """ + Atomically get a value from the cache or compute it if not present. + + This method ensures that the compute function is called at most once + even in the presence of concurrent requests for the same key. + + Args: + key: The key to look up + compute_fn: Async function to compute the value if key is not found + default: Value to return if compute_fn raises an exception + + Returns: + The cached value or the computed value + + This is an optional method with a default implementation. Cache + implementations should override this for better thread-safety guarantees. + """ + # Default implementation - not thread-safe for computation + value = self.get(key) + if value is not None: + return value + + try: + computed_value = await compute_fn() + self.put(key, computed_value) + return computed_value + except Exception: + return default diff --git a/nemoguardrails/cache/lfu.py b/nemoguardrails/cache/lfu.py new file mode 100644 index 000000000..4f8e450c0 --- /dev/null +++ b/nemoguardrails/cache/lfu.py @@ -0,0 +1,677 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Least Frequently Used (LFU) cache implementation.""" + +import asyncio +import json +import logging +import os +import threading +import time +from typing import Any, Callable, Optional + +from nemoguardrails.cache.interface import CacheInterface + +log = logging.getLogger(__name__) + + +class LFUNode: + """Node for the LFU cache doubly linked list.""" + + def __init__(self, key: Any, value: Any) -> None: + self.key = key + self.value = value + self.freq = 1 + self.prev: Optional["LFUNode"] = None + self.next: Optional["LFUNode"] = None + self.created_at = time.time() + self.accessed_at = self.created_at + + +class DoublyLinkedList: + """Doubly linked list to maintain nodes with the same frequency.""" + + def __init__(self) -> None: + # Create dummy head and tail nodes + self.head = LFUNode(None, None) + self.tail = LFUNode(None, None) + self.head.next = self.tail + self.tail.prev = self.head + self.size = 0 + + def append(self, node: LFUNode) -> None: + """Add node to the end of the list (before tail).""" + node.prev = self.tail.prev + node.next = self.tail + self.tail.prev.next = node + self.tail.prev = node + self.size += 1 + + def pop(self, node: Optional[LFUNode] = None) -> Optional[LFUNode]: + """Remove and return a node. If no node specified, removes the first node.""" + if self.size == 0: + return None + + if node is None: + node = self.head.next + + # Remove node from the list + node.prev.next = node.next + node.next.prev = node.prev + self.size -= 1 + + return node + + +class LFUCache(CacheInterface): + """ + Least Frequently Used (LFU) Cache implementation. + + When the cache reaches capacity, it evicts the least frequently used item. + If there are ties in frequency, it evicts the least recently used among them. + """ + + def __init__( + self, + capacity: int, + track_stats: bool = False, + persistence_interval: Optional[float] = None, + persistence_path: Optional[str] = None, + stats_logging_interval: Optional[float] = None, + ) -> None: + """ + Initialize the LFU cache. 
+ + Args: + capacity: Maximum number of items the cache can hold + track_stats: Enable tracking of cache statistics + persistence_interval: Seconds between periodic dumps to disk (None disables persistence) + persistence_path: Path to persistence file (defaults to 'lfu_cache.json' if persistence enabled) + stats_logging_interval: Seconds between periodic stats logging (None disables logging) + """ + if capacity < 0: + raise ValueError("Capacity must be non-negative") + + self._capacity = capacity + self.track_stats = track_stats + self._lock = threading.RLock() # Thread-safe access + self._computing: dict[Any, asyncio.Future] = {} # Track keys being computed + + self.key_map: dict[Any, LFUNode] = {} # key -> node mapping + self.freq_map: dict[int, DoublyLinkedList] = {} # frequency -> list of nodes + self.min_freq = 0 # Track minimum frequency for eviction + + # Persistence configuration + self.persistence_interval = persistence_interval + self.persistence_path = persistence_path or "lfu_cache.json" + # Initialize to None to ensure first check doesn't trigger immediately + self.last_persist_time = None + + # Stats logging configuration + self.stats_logging_interval = stats_logging_interval + # Initialize to None to ensure first check doesn't trigger immediately + self.last_stats_log_time = None + + # Statistics tracking + if self.track_stats: + self.stats = { + "hits": 0, + "misses": 0, + "evictions": 0, + "puts": 0, + "updates": 0, + } + + # Load from disk if persistence is enabled and file exists + if self.persistence_interval is not None: + self._load_from_disk() + + def _update_node_freq(self, node: LFUNode) -> None: + """Update the frequency of a node and move it to the appropriate frequency list.""" + old_freq = node.freq + old_list = self.freq_map[old_freq] + + # Remove node from current frequency list + old_list.pop(node) + + # Update min_freq if necessary + if self.min_freq == old_freq and old_list.size == 0: + self.min_freq += 1 + # Clean up empty frequency lists + del self.freq_map[old_freq] + + # Increment frequency and add to new list + node.freq += 1 + new_freq = node.freq + node.accessed_at = time.time() # Update access time + + if new_freq not in self.freq_map: + self.freq_map[new_freq] = DoublyLinkedList() + + self.freq_map[new_freq].append(node) + + def get(self, key: Any, default: Any = None) -> Any: + """ + Get an item from the cache. + + Args: + key: The key to look up + default: Value to return if key is not found + + Returns: + The value associated with the key, or default if not found + """ + with self._lock: + # Check if we should persist + self._check_and_persist() + + # Check if we should log stats + self._check_and_log_stats() + + if key not in self.key_map: + if self.track_stats: + self.stats["misses"] += 1 + return default + + node = self.key_map[key] + + if self.track_stats: + self.stats["hits"] += 1 + + self._update_node_freq(node) + return node.value + + def put(self, key: Any, value: Any) -> None: + """ + Put an item into the cache. 
+ + Args: + key: The key to store + value: The value to associate with the key + """ + with self._lock: + # Check if we should persist + self._check_and_persist() + + # Check if we should log stats + self._check_and_log_stats() + + if self._capacity == 0: + return + + if key in self.key_map: + # Update existing key + node = self.key_map[key] + node.value = value + node.created_at = time.time() # Reset creation time on update + self._update_node_freq(node) + if self.track_stats: + self.stats["updates"] += 1 + else: + # Add new key + if len(self.key_map) >= self._capacity: + # Need to evict least frequently used item + self._evict_lfu() + + # Create new node and add to cache + new_node = LFUNode(key, value) + self.key_map[key] = new_node + + # Add to frequency 1 list + if 1 not in self.freq_map: + self.freq_map[1] = DoublyLinkedList() + + self.freq_map[1].append(new_node) + self.min_freq = 1 + + if self.track_stats: + self.stats["puts"] += 1 + + def _evict_lfu(self) -> None: + """Evict the least frequently used item from the cache.""" + if self.min_freq in self.freq_map: + lfu_list = self.freq_map[self.min_freq] + node_to_evict = lfu_list.pop() # Remove least recently used among LFU + + if node_to_evict: + del self.key_map[node_to_evict.key] + + if self.track_stats: + self.stats["evictions"] += 1 + + # Clean up empty frequency list + if lfu_list.size == 0: + del self.freq_map[self.min_freq] + + def size(self) -> int: + """Return the current size of the cache.""" + with self._lock: + return len(self.key_map) + + def is_empty(self) -> bool: + """Check if the cache is empty.""" + with self._lock: + return len(self.key_map) == 0 + + def clear(self) -> None: + """Clear all items from the cache.""" + with self._lock: + if self.track_stats: + # Track number of items evicted + self.stats["evictions"] += len(self.key_map) + + self.key_map.clear() + self.freq_map.clear() + self.min_freq = 0 + + def get_stats(self) -> dict: + """ + Get cache statistics. + + Returns: + Dictionary with cache statistics (if tracking is enabled) + """ + with self._lock: + if not self.track_stats: + return {"message": "Statistics tracking is disabled"} + + stats = self.stats.copy() + stats["current_size"] = len(self.key_map) # Direct access within lock + stats["capacity"] = self._capacity + + # Calculate hit rate + total_requests = stats["hits"] + stats["misses"] + stats["hit_rate"] = ( + stats["hits"] / total_requests if total_requests > 0 else 0.0 + ) + + return stats + + def reset_stats(self) -> None: + """Reset cache statistics.""" + with self._lock: + if self.track_stats: + self.stats = { + "hits": 0, + "misses": 0, + "evictions": 0, + "puts": 0, + "updates": 0, + } + + def _check_and_persist(self) -> None: + """Check if enough time has passed and persist to disk if needed.""" + if self.persistence_interval is None: + return + + current_time = time.time() + + # Initialize timestamp on first check + if self.last_persist_time is None: + self.last_persist_time = current_time + return + + if current_time - self.last_persist_time >= self.persistence_interval: + self._persist_to_disk() + self.last_persist_time = current_time + + def _persist_to_disk(self) -> None: + """ + Serialize cache to disk. + + Stores cache data as JSON with node information including keys, values, + frequencies, and timestamps for reconstruction. 
+ """ + if not self.key_map: + # If cache is empty, remove the persistence file + if os.path.exists(self.persistence_path): + os.remove(self.persistence_path) + return + + cache_data = { + "capacity": self._capacity, + "min_freq": self.min_freq, + "nodes": [], + } + + # Serialize all nodes + for key, node in self.key_map.items(): + cache_data["nodes"].append( + { + "key": key, + "value": node.value, + "freq": node.freq, + "created_at": node.created_at, + "accessed_at": node.accessed_at, + } + ) + + # Write to disk + try: + with open(self.persistence_path, "w") as f: + json.dump(cache_data, f, indent=2) + except Exception as e: + # Silently fail on persistence errors to not disrupt cache operations + pass + + def _load_from_disk(self) -> None: + """ + Load cache from disk if persistence file exists. + + Reconstructs the cache state including frequency lists and node relationships. + """ + if not os.path.exists(self.persistence_path): + return + + try: + with open(self.persistence_path, "r") as f: + cache_data = json.load(f) + + # Reconstruct cache + self.min_freq = cache_data.get("min_freq", 0) + + for node_data in cache_data.get("nodes", []): + # Create node + node = LFUNode(node_data["key"], node_data["value"]) + node.freq = node_data["freq"] + node.created_at = node_data["created_at"] + node.accessed_at = node_data["accessed_at"] + + # Add to key map + self.key_map[node.key] = node + + # Add to appropriate frequency list + if node.freq not in self.freq_map: + self.freq_map[node.freq] = DoublyLinkedList() + self.freq_map[node.freq].append(node) + + except Exception as e: + # If loading fails, start with empty cache + self.key_map.clear() + self.freq_map.clear() + self.min_freq = 0 + + def persist_now(self) -> None: + """Force immediate persistence to disk (useful for shutdown).""" + with self._lock: + if self.persistence_interval is not None: + self._persist_to_disk() + self.last_persist_time = time.time() + + def supports_persistence(self) -> bool: + """Check if this cache instance supports persistence.""" + return self.persistence_interval is not None + + def _check_and_log_stats(self) -> None: + """Check if enough time has passed and log stats if needed.""" + if not self.track_stats or self.stats_logging_interval is None: + return + + current_time = time.time() + + # Initialize timestamp on first check + if self.last_stats_log_time is None: + self.last_stats_log_time = current_time + return + + if current_time - self.last_stats_log_time >= self.stats_logging_interval: + self._log_stats() + self.last_stats_log_time = current_time + + def _log_stats(self) -> None: + """Log current cache statistics.""" + stats = self.get_stats() + + # Format the log message + log_msg = ( + f"LFU Cache Statistics - " + f"Size: {stats['current_size']}/{stats['capacity']} | " + f"Hits: {stats['hits']} | " + f"Misses: {stats['misses']} | " + f"Hit Rate: {stats['hit_rate']:.2%} | " + f"Evictions: {stats['evictions']} | " + f"Puts: {stats['puts']} | " + f"Updates: {stats['updates']}" + ) + + log.info(log_msg) + + def log_stats_now(self) -> None: + """Force immediate logging of cache statistics.""" + if self.track_stats: + self._log_stats() + self.last_stats_log_time = time.time() + + def supports_stats_logging(self) -> bool: + """Check if this cache instance supports stats logging.""" + return self.track_stats and self.stats_logging_interval is not None + + async def get_or_compute( + self, key: Any, compute_fn: Callable[[], Any], default: Any = None + ) -> Any: + """ + Atomically get a value from the cache or 
compute it if not present. + + This method ensures that the compute function is called at most once + even in the presence of concurrent requests for the same key. + + Args: + key: The key to look up + compute_fn: Async function to compute the value if key is not found + default: Value to return if compute_fn raises an exception + + Returns: + The cached value or the computed value + """ + # First check if the value is already in cache + future = None + with self._lock: + if key in self.key_map: + node = self.key_map[key] + if self.track_stats: + self.stats["hits"] += 1 + self._update_node_freq(node) + return node.value + + # Check if this key is already being computed + if key in self._computing: + future = self._computing[key] + + # If the key is being computed, wait for it outside the lock + if future is not None: + try: + return await future + except Exception: + return default + + # Create a future for this computation + future = asyncio.Future() + with self._lock: + # Double-check the cache and computing dict + if key in self.key_map: + node = self.key_map[key] + if self.track_stats: + self.stats["hits"] += 1 + self._update_node_freq(node) + return node.value + + if key in self._computing: + # Another thread started computing while we were waiting + future = self._computing[key] + else: + # We'll be the ones computing + self._computing[key] = future + + # If another thread is computing, wait for it + if not future.done() and self._computing.get(key) is not future: + try: + return await self._computing[key] + except Exception: + return default + + # We're responsible for computing the value + try: + computed_value = await compute_fn() + + # Store the computed value in cache + with self._lock: + # Remove from computing dict + self._computing.pop(key, None) + + # Check one more time if someone else added it + if key in self.key_map: + node = self.key_map[key] + if self.track_stats: + self.stats["hits"] += 1 + self._update_node_freq(node) + future.set_result(node.value) + return node.value + + # Now add to cache using internal logic + if self._capacity == 0: + future.set_result(computed_value) + return computed_value + + # Add new key + if len(self.key_map) >= self._capacity: + self._evict_lfu() + + # Create new node and add to cache + new_node = LFUNode(key, computed_value) + self.key_map[key] = new_node + + # Add to frequency 1 list + if 1 not in self.freq_map: + self.freq_map[1] = DoublyLinkedList() + + self.freq_map[1].append(new_node) + self.min_freq = 1 + + if self.track_stats: + self.stats["puts"] += 1 + + # Set the result in the future + future.set_result(computed_value) + return computed_value + + except Exception as e: + with self._lock: + self._computing.pop(key, None) + future.set_exception(e) + return default + + def contains(self, key: Any) -> bool: + """ + Check if a key exists in the cache without updating its frequency. + + This is more efficient than the default implementation which uses get() + and has the side effect of updating frequency counts. 
+ + Args: + key: The key to check + + Returns: + True if the key exists in the cache, False otherwise + """ + with self._lock: + return key in self.key_map + + @property + def capacity(self) -> int: + """Get the maximum capacity of the cache.""" + return self._capacity + + +# Example usage and testing +if __name__ == "__main__": + print("=== Basic LFU Cache Example ===") + # Create a basic LFU cache + cache = LFUCache(3) + + cache.put("a", 1) + cache.put("b", 2) + cache.put("c", 3) + + print(f"Get 'a': {cache.get('a')}") # Returns 1, frequency of 'a' becomes 2 + print(f"Get 'b': {cache.get('b')}") # Returns 2, frequency of 'b' becomes 2 + + cache.put("d", 4) # Evicts 'c' (least frequently used) + + print(f"Get 'c': {cache.get('c', 'Not found')}") # Returns 'Not found' + print(f"Get 'd': {cache.get('d')}") # Returns 4 + print(f"Cache size: {cache.size()}") # Returns 3 + + print("\n=== Cache with Statistics Tracking ===") + + # Create cache with statistics tracking + stats_cache = LFUCache(capacity=5, track_stats=True) + + # Add some items + for i in range(6): + stats_cache.put(f"key{i}", f"value{i}") + + # Access some items to change frequencies + for _ in range(3): + stats_cache.get("key4") # Increase frequency + stats_cache.get("key5") # Increase frequency + + # Some cache misses + stats_cache.get("nonexistent1") + stats_cache.get("nonexistent2") + + # Check statistics + print(f"\nCache statistics: {stats_cache.get_stats()}") + + # Update existing key + stats_cache.put("key4", "updated_value4") + + # Check updated statistics + print(f"\nUpdated statistics: {stats_cache.get_stats()}") + + # Reset statistics + stats_cache.reset_stats() + print(f"\nAfter reset: {stats_cache.get_stats()}") + + print("\n=== Cache with Persistence ===") + + # Create cache with persistence (5 second interval) + persist_cache = LFUCache( + capacity=3, persistence_interval=5.0, persistence_path="test_cache.json" + ) + + # Add some items + persist_cache.put("item1", "value1") + persist_cache.put("item2", "value2") + persist_cache.put("item3", "value3") + + # Force immediate persistence + persist_cache.persist_now() + print("Cache persisted to disk") + + # Create new cache instance that will load from disk + new_cache = LFUCache( + capacity=3, persistence_interval=5.0, persistence_path="test_cache.json" + ) + + # Verify data was loaded + print(f"Loaded item1: {new_cache.get('item1')}") # Should return 'value1' + print(f"Loaded item2: {new_cache.get('item2')}") # Should return 'value2' + print(f"Cache size after loading: {new_cache.size()}") # Should return 3 + + # Clean up + if os.path.exists("test_cache.json"): + os.remove("test_cache.json") + print("Cleaned up test persistence file") diff --git a/nemoguardrails/library/content_safety/actions.py b/nemoguardrails/library/content_safety/actions.py index c8e64d3de..92da5831b 100644 --- a/nemoguardrails/library/content_safety/actions.py +++ b/nemoguardrails/library/content_safety/actions.py @@ -13,14 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import logging -from typing import Dict, Optional +import re +from typing import Dict, List, Optional, Union from langchain_core.language_models.llms import BaseLLM from nemoguardrails.actions.actions import action from nemoguardrails.actions.llm.utils import llm_call from nemoguardrails.context import llm_call_info_var +from nemoguardrails.library.content_safety.manager import ContentSafetyManager from nemoguardrails.llm.params import llm_params from nemoguardrails.llm.taskmanager import LLMTaskManager from nemoguardrails.logging.explain import LLMCallInfo @@ -28,10 +31,39 @@ log = logging.getLogger(__name__) +PROMPT_PATTERN_WHITESPACES = re.compile(r"\s+") + + +def _create_cache_key(prompt: Union[str, List[str]]) -> str: + """Create a cache key from the prompt.""" + # can the prompt really be a list? + if isinstance(prompt, list): + prompt_str = json.dumps(prompt) + else: + prompt_str = prompt + + # normalize the prompt to a string + # should we do more normalizations? + return PROMPT_PATTERN_WHITESPACES.sub(" ", prompt_str).strip() + + +# Thread Safety Note: +# The content safety caching mechanism is thread-safe for single-node deployments. +# The underlying LFUCache uses threading.RLock to ensure atomic operations. +# ContentSafetyManager uses double-checked locking for efficient cache creation. +# +# However, this implementation is NOT suitable for distributed environments. +# For multi-node deployments, consider using distributed caching solutions +# like Redis or a shared database. + + @action() async def content_safety_check_input( llms: Dict[str, BaseLLM], llm_task_manager: LLMTaskManager, + content_safety_manager: Optional[ + "ContentSafetyManager" + ] = None, # Optional for backward compatibility model_name: Optional[str] = None, context: Optional[dict] = None, **kwargs, @@ -76,6 +108,20 @@ async def content_safety_check_input( max_tokens = max_tokens or _MAX_TOKENS + # Check cache if content safety manager is available + cached_result = None + cache_key = None + + if content_safety_manager and model_name: + cache = content_safety_manager.get_cache_for_model(model_name) + if cache: + cache_key = _create_cache_key(check_input_prompt) + cached_result = cache.get(cache_key) + if cached_result is not None: + log.debug(f"Content safety cache hit for model '{model_name}'") + return cached_result + + # Make the actual LLM call with llm_params(llm, temperature=1e-20, max_tokens=max_tokens): result = await llm_call(llm, check_input_prompt, stop=stop) @@ -84,7 +130,17 @@ async def content_safety_check_input( is_safe, *violated_policies = result - return {"allowed": is_safe, "policy_violations": violated_policies} + final_result = {"allowed": is_safe, "policy_violations": violated_policies} + + # Store in cache if available + if cache_key: + assert content_safety_manager is not None and model_name is not None + cache = content_safety_manager.get_cache_for_model(model_name) + if cache: + cache.put(cache_key, final_result) + log.debug(f"Content safety result cached for model '{model_name}'") + + return final_result def content_safety_check_output_mapping(result: dict) -> bool: diff --git a/nemoguardrails/library/content_safety/manager.py b/nemoguardrails/library/content_safety/manager.py new file mode 100644 index 000000000..b864d4faa --- /dev/null +++ b/nemoguardrails/library/content_safety/manager.py @@ -0,0 +1,120 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import threading +from typing import Dict, Optional + +from nemoguardrails.cache.interface import CacheInterface +from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.rails.llm.config import ModelConfig + +log = logging.getLogger(__name__) + + +class ContentSafetyManager: + """Manages all content safety related functionality.""" + + def __init__(self, config: ModelConfig): + self.config = config + self._caches: Dict[str, CacheInterface] = {} + self._lock = threading.RLock() # Thread-safe cache creation + self._initialize_caches() + + def _initialize_caches(self): + """Initialize per-model caches based on configuration.""" + if not self.config.cache.enabled: + return + + # We'll create caches on-demand for each model + self._cache_config = self.config.cache + + def get_cache_for_model(self, model_name: str) -> Optional[CacheInterface]: + """Get or create cache for a specific model.""" + if not self.config.cache.enabled: + return None + + # Resolve model alias if configured + actual_model = self.config.model_mapping.get(model_name, model_name) + + # Double-checked locking pattern for efficiency + if actual_model not in self._caches: + with self._lock: + # Check again inside the lock + if actual_model not in self._caches: + # Create cache based on store type + if self._cache_config.store == "memory": + # Determine persistence settings for this model + persistence_path = None + persistence_interval = None + + # Check if persistence is enabled and has a valid interval + if ( + self._cache_config.persistence.enabled + and self._cache_config.persistence.interval is not None + ): + persistence_interval = ( + self._cache_config.persistence.interval + ) + + if self._cache_config.persistence.path: + # Use configured path, replacing {model_name} if present + persistence_path = ( + self._cache_config.persistence.path.replace( + "{model_name}", actual_model + ) + ) + else: + # Default path if persistence is enabled but no path specified + persistence_path = f"cache_{actual_model}.json" + + # Determine stats logging settings + stats_logging_interval = None + if ( + self._cache_config.stats.enabled + and self._cache_config.stats.log_interval is not None + ): + stats_logging_interval = ( + self._cache_config.stats.log_interval + ) + + self._caches[actual_model] = LFUCache( + capacity=self._cache_config.capacity_per_model, + track_stats=self._cache_config.stats.enabled, + persistence_interval=persistence_interval, + persistence_path=persistence_path, + stats_logging_interval=stats_logging_interval, + ) + # elif self._cache_config.store == "filesystem": + # self._caches[actual_model] = FilesystemCache(...) + # elif self._cache_config.store == "redis": + # self._caches[actual_model] = RedisCache(...) 
+ + return self._caches[actual_model] + + def persist_all_caches(self): + """Force immediate persistence of all caches that support it.""" + with self._lock: + # Create a list of caches to persist to avoid holding lock during I/O + caches_to_persist = [ + (model_name, cache) + for model_name, cache in self._caches.items() + if cache.supports_persistence() + ] + + # Persist outside the lock to avoid blocking other operations + for model_name, cache in caches_to_persist: + cache.persist_now() + log.info(f"Persisted cache for model: {model_name}") diff --git a/nemoguardrails/rails/llm/config.py b/nemoguardrails/rails/llm/config.py index bc12569a1..58a74f0ef 100644 --- a/nemoguardrails/rails/llm/config.py +++ b/nemoguardrails/rails/llm/config.py @@ -830,6 +830,77 @@ def get_validator_config(self, name: str) -> Optional[GuardrailsAIValidatorConfi return None +class CachePersistenceConfig(BaseModel): + """Configuration for cache persistence to disk.""" + + enabled: bool = Field( + default=True, + description="Whether cache persistence is enabled (persistence requires both enabled=True and a valid interval)", + ) + interval: Optional[float] = Field( + default=None, + description="Seconds between periodic cache persistence to disk (None disables persistence)", + ) + path: Optional[str] = Field( + default=None, + description="Path to persistence file for cache data (defaults to 'cache_{model_name}.json' if persistence is enabled)", + ) + + +class CacheStatsConfig(BaseModel): + """Configuration for cache statistics tracking and logging.""" + + enabled: bool = Field( + default=False, + description="Whether cache statistics tracking is enabled", + ) + log_interval: Optional[float] = Field( + default=None, + description="Seconds between periodic cache stats logging to logs (None disables logging)", + ) + + +class ModelCacheConfig(BaseModel): + """Configuration for model caching.""" + + enabled: bool = Field( + default=False, + description="Whether caching is enabled for content safety checks", + ) + capacity_per_model: int = Field( + default=50000, description="Maximum number of entries in the cache per model" + ) + store: str = Field( + default="memory", description="Cache store: 'memory', 'filesystem', 'redis'" + ) + store_config: Dict[str, Any] = Field( + default_factory=dict, description="Backend-specific configuration" + ) + persistence: CachePersistenceConfig = Field( + default_factory=CachePersistenceConfig, + description="Configuration for cache persistence", + ) + stats: CacheStatsConfig = Field( + default_factory=CacheStatsConfig, + description="Configuration for cache statistics tracking and logging", + ) + + +class ModelConfig(BaseModel): + """Configuration for content safety features.""" + + cache: ModelCacheConfig = Field( + default_factory=ModelCacheConfig, + description="Configuration for content safety caching", + ) + + # Model mapping for backward compatibility + model_mapping: Dict[str, str] = Field( + default_factory=dict, + description="Mapping of model aliases to actual model types (e.g., 'content_safety' -> 'llama_guard')", + ) + + class RailsConfigData(BaseModel): """Configuration data for specific rails that are supported out-of-the-box.""" @@ -888,6 +959,11 @@ class RailsConfigData(BaseModel): description="Configuration for Guardrails AI validators.", ) + content_safety: ModelConfig = Field( + default_factory=ModelConfig, + description="Configuration for content safety features.", + ) + class Rails(BaseModel): """Configuration of specific rails.""" diff --git 
a/nemoguardrails/rails/llm/llmrails.py b/nemoguardrails/rails/llm/llmrails.py index 0027b7fc5..a322412e6 100644 --- a/nemoguardrails/rails/llm/llmrails.py +++ b/nemoguardrails/rails/llm/llmrails.py @@ -114,6 +114,7 @@ def __init__( self.config = config self.llm = llm self.verbose = verbose + self._content_safety_manager = None if self.verbose: set_verbose(True, llm_calls=True) @@ -509,6 +510,23 @@ def _init_llms(self): self.runtime.register_action_param("llms", llms) + # Register content safety manager if content safety features are used + if self._has_content_safety_rails(): + from nemoguardrails.library.content_safety.manager import ( + ContentSafetyManager, + ) + + content_safety_config = self.config.rails.config.content_safety + self._content_safety_manager = ContentSafetyManager(content_safety_config) + self.runtime.register_action_param( + "content_safety_manager", self._content_safety_manager + ) + + log.info( + "Initialized ContentSafetyManager with cache %s", + "enabled" if content_safety_config.cache.enabled else "disabled", + ) + def _create_isolated_llms_for_actions(self): """Create isolated LLM copies for all actions that accept 'llm' parameter.""" if not self.llm: @@ -1507,6 +1525,16 @@ def register_embedding_provider( register_embedding_provider(engine_name=name, model=cls) return self + def _has_content_safety_rails(self) -> bool: + """Check if any content safety rails are configured in flows. + At the moment, we only support content safety manager in input flows. + """ + flows = self.config.rails.input.flows + for flow in flows: + if "content safety check input" in flow: + return True + return False + def explain(self) -> ExplainInfo: """Helper function to return the latest ExplainInfo object.""" return self.explain_info @@ -1787,3 +1815,26 @@ def _prepare_params( # yield the individual chunks directly from the buffer strategy for chunk in user_output_chunks: yield chunk + + def close(self): + """Properly close and clean up resources, including persisting caches.""" + if self._content_safety_manager: + log.info("Persisting content safety caches on close") + self._content_safety_manager.persist_all_caches() + + def __del__(self): + """Ensure caches are persisted when the object is garbage collected.""" + try: + self.close() + except Exception as e: + # Silently fail in destructor to avoid issues during shutdown + log.debug(f"Error during LLMRails cleanup: {e}") + + def __enter__(self): + """Context manager entry.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit - ensure cleanup.""" + self.close() + return False diff --git a/tests/test_cache_lfu.py b/tests/test_cache_lfu.py new file mode 100644 index 000000000..f2941482a --- /dev/null +++ b/tests/test_cache_lfu.py @@ -0,0 +1,1861 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Comprehensive test suite for LFU Cache implementation. 
+ +Tests all functionality including basic operations, eviction policies, +capacity management, edge cases, and persistence functionality. +""" + +import asyncio +import json +import os +import tempfile +import threading +import time +import unittest +from concurrent.futures import ThreadPoolExecutor +from typing import Any +from unittest.mock import MagicMock, patch + +from nemoguardrails.cache.lfu import LFUCache +from nemoguardrails.library.content_safety.manager import ContentSafetyManager + + +class TestLFUCache(unittest.TestCase): + """Test cases for LFU Cache implementation.""" + + def setUp(self): + """Set up test fixtures.""" + self.cache = LFUCache(3) + + def test_initialization(self): + """Test cache initialization with various capacities.""" + # Normal capacity + cache = LFUCache(5) + self.assertEqual(cache.size(), 0) + self.assertTrue(cache.is_empty()) + + # Zero capacity + cache_zero = LFUCache(0) + self.assertEqual(cache_zero.size(), 0) + + # Negative capacity should raise error + with self.assertRaises(ValueError): + LFUCache(-1) + + def test_basic_put_get(self): + """Test basic put and get operations.""" + # Put and get single item + self.cache.put("key1", "value1") + self.assertEqual(self.cache.get("key1"), "value1") + self.assertEqual(self.cache.size(), 1) + + # Put and get multiple items + self.cache.put("key2", "value2") + self.cache.put("key3", "value3") + + self.assertEqual(self.cache.get("key1"), "value1") + self.assertEqual(self.cache.get("key2"), "value2") + self.assertEqual(self.cache.get("key3"), "value3") + self.assertEqual(self.cache.size(), 3) + + def test_get_nonexistent_key(self): + """Test getting non-existent keys.""" + # Default behavior (returns None) + self.assertIsNone(self.cache.get("nonexistent")) + + # With custom default + self.assertEqual(self.cache.get("nonexistent", "default"), "default") + + # After adding some items + self.cache.put("key1", "value1") + self.assertIsNone(self.cache.get("key2")) + self.assertEqual(self.cache.get("key2", 42), 42) + + def test_update_existing_key(self): + """Test updating values for existing keys.""" + self.cache.put("key1", "value1") + self.cache.put("key2", "value2") + + # Update existing key + self.cache.put("key1", "new_value1") + self.assertEqual(self.cache.get("key1"), "new_value1") + + # Size should not change + self.assertEqual(self.cache.size(), 2) + + def test_lfu_eviction_basic(self): + """Test basic LFU eviction when cache is full.""" + # Fill cache + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Access 'a' and 'b' to increase their frequency + self.cache.get("a") # freq: 2 + self.cache.get("b") # freq: 2 + # 'c' remains at freq: 1 + + # Add new item - should evict 'c' (lowest frequency) + self.cache.put("d", 4) + + self.assertEqual(self.cache.get("a"), 1) + self.assertEqual(self.cache.get("b"), 2) + self.assertEqual(self.cache.get("d"), 4) + self.assertIsNone(self.cache.get("c")) # Should be evicted + + def test_lfu_with_lru_tiebreaker(self): + """Test LRU eviction among items with same frequency.""" + # Fill cache - all items have frequency 1 + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Add new item - should evict 'a' (least recently used among freq 1) + self.cache.put("d", 4) + + self.assertIsNone(self.cache.get("a")) # Should be evicted + self.assertEqual(self.cache.get("b"), 2) + self.assertEqual(self.cache.get("c"), 3) + self.assertEqual(self.cache.get("d"), 4) + + def test_complex_eviction_scenario(self): + """Test complex 
eviction scenario with multiple frequency levels.""" + # Create a new cache for this test + cache = LFUCache(4) + + # Add items and create different frequency levels + cache.put("a", 1) + cache.put("b", 2) + cache.put("c", 3) + cache.put("d", 4) + + # Create frequency pattern: + # a: freq 3 (accessed 2 more times) + # b: freq 2 (accessed 1 more time) + # c: freq 2 (accessed 1 more time) + # d: freq 1 (not accessed) + + cache.get("a") + cache.get("a") + cache.get("b") + cache.get("c") + + # Add new item - should evict 'd' (freq 1) + cache.put("e", 5) + self.assertIsNone(cache.get("d")) + + # Add another item - should evict one of the least frequently used + cache.put("f", 6) + + # After eviction, we should have: + # - 'a' (freq 3) - definitely kept + # - 'b' (freq 2) and 'c' (freq 2) - higher frequency, both kept + # - 'f' (freq 1) - just added + # - 'e' (freq 1) was evicted as it was least recently used among freq 1 items + + # Check that we're at capacity + self.assertEqual(cache.size(), 4) + + # 'a' should definitely still be there (highest frequency) + self.assertEqual(cache.get("a"), 1) + + # 'b' and 'c' should both be there (freq 2) + self.assertEqual(cache.get("b"), 2) + self.assertEqual(cache.get("c"), 3) + + # 'f' should be there (just added) + self.assertEqual(cache.get("f"), 6) + + # 'e' should have been evicted (freq 1, LRU among freq 1 items) + self.assertIsNone(cache.get("e")) + + def test_zero_capacity_cache(self): + """Test cache with zero capacity.""" + cache = LFUCache(0) + + # Put should not store anything + cache.put("key", "value") + self.assertEqual(cache.size(), 0) + self.assertIsNone(cache.get("key")) + + # Multiple puts + for i in range(10): + cache.put(f"key{i}", f"value{i}") + + self.assertEqual(cache.size(), 0) + self.assertTrue(cache.is_empty()) + + def test_clear_method(self): + """Test clearing the cache.""" + # Add items + self.cache.put("a", 1) + self.cache.put("b", 2) + self.cache.put("c", 3) + + # Verify items exist + self.assertEqual(self.cache.size(), 3) + self.assertFalse(self.cache.is_empty()) + + # Clear cache + self.cache.clear() + + # Verify cache is empty + self.assertEqual(self.cache.size(), 0) + self.assertTrue(self.cache.is_empty()) + + # Verify items are gone + self.assertIsNone(self.cache.get("a")) + self.assertIsNone(self.cache.get("b")) + self.assertIsNone(self.cache.get("c")) + + # Can still use cache after clear + self.cache.put("new_key", "new_value") + self.assertEqual(self.cache.get("new_key"), "new_value") + + def test_various_data_types(self): + """Test cache with various data types as keys and values.""" + # Integer keys + self.cache.put(1, "one") + self.cache.put(2, "two") + self.assertEqual(self.cache.get(1), "one") + self.assertEqual(self.cache.get(2), "two") + + # Tuple keys + self.cache.put((1, 2), "tuple_value") + self.assertEqual(self.cache.get((1, 2)), "tuple_value") + + # Clear for more tests + self.cache.clear() + + # Complex values + self.cache.put("list", [1, 2, 3]) + self.cache.put("dict", {"a": 1, "b": 2}) + self.cache.put("set", {1, 2, 3}) + + self.assertEqual(self.cache.get("list"), [1, 2, 3]) + self.assertEqual(self.cache.get("dict"), {"a": 1, "b": 2}) + self.assertEqual(self.cache.get("set"), {1, 2, 3}) + + def test_none_values(self): + """Test storing None as a value.""" + self.cache.put("key", None) + # get should return None for the value, not the default + self.assertIsNone(self.cache.get("key")) + self.assertEqual(self.cache.get("key", "default"), None) + + # Verify key exists + 
self.assertEqual(self.cache.size(), 1) + + def test_size_and_capacity(self): + """Test size tracking and capacity limits.""" + # Start empty + self.assertEqual(self.cache.size(), 0) + + # Add items up to capacity + for i in range(3): + self.cache.put(f"key{i}", f"value{i}") + self.assertEqual(self.cache.size(), i + 1) + + # Add more items - size should stay at capacity + for i in range(3, 10): + self.cache.put(f"key{i}", f"value{i}") + self.assertEqual(self.cache.size(), 3) + + def test_is_empty(self): + """Test is_empty method in various states.""" + # Initially empty + self.assertTrue(self.cache.is_empty()) + + # After adding item + self.cache.put("key", "value") + self.assertFalse(self.cache.is_empty()) + + # After clearing + self.cache.clear() + self.assertTrue(self.cache.is_empty()) + + def test_repeated_puts_same_key(self): + """Test repeated puts with the same key maintain size=1 and update frequency.""" + self.cache.put("key", "value1") + self.assertEqual(self.cache.size(), 1) + + # Track initial state + initial_stats = self.cache.get_stats() if self.cache.track_stats else None + + # Update same key multiple times + for i in range(10): + self.cache.put("key", f"value{i}") + self.assertEqual(self.cache.size(), 1) + + # Final value should be the last one + self.assertEqual(self.cache.get("key"), "value9") + + # Verify stats if tracking enabled + if self.cache.track_stats: + final_stats = self.cache.get_stats() + # Should have 10 updates (after initial put) + self.assertEqual(final_stats["updates"], 10) + + def test_access_pattern_preserves_frequently_used(self): + """Test that frequently accessed items are preserved during evictions.""" + # Create specific access pattern + cache = LFUCache(3) + + # Add three items + cache.put("rarely_used", 1) + cache.put("sometimes_used", 2) + cache.put("frequently_used", 3) + + # Create access pattern + # frequently_used: access 10 times + for _ in range(10): + cache.get("frequently_used") + + # sometimes_used: access 3 times + for _ in range(3): + cache.get("sometimes_used") + + # rarely_used: no additional access (freq = 1) + + # Add new items to trigger evictions + cache.put("new1", 4) # Should evict rarely_used + cache.put("new2", 5) # Should evict new1 (freq = 1) + + # frequently_used and sometimes_used should still be there + self.assertEqual(cache.get("frequently_used"), 3) + self.assertEqual(cache.get("sometimes_used"), 2) + + # rarely_used and new1 should be evicted + self.assertIsNone(cache.get("rarely_used")) + self.assertIsNone(cache.get("new1")) + + # new2 should be there + self.assertEqual(cache.get("new2"), 5) + + +class TestLFUCacheInterface(unittest.TestCase): + """Test that LFUCache properly implements CacheInterface.""" + + def test_interface_methods_exist(self): + """Verify all interface methods are implemented.""" + cache = LFUCache(5) + + # Check all required methods exist and are callable + self.assertTrue(callable(getattr(cache, "get", None))) + self.assertTrue(callable(getattr(cache, "put", None))) + self.assertTrue(callable(getattr(cache, "size", None))) + self.assertTrue(callable(getattr(cache, "is_empty", None))) + self.assertTrue(callable(getattr(cache, "clear", None))) + + # Check property + self.assertEqual(cache.capacity, 5) + + def test_persistence_interface_methods(self): + """Verify persistence interface methods are implemented.""" + # Cache without persistence + cache_no_persist = LFUCache(5) + self.assertTrue(callable(getattr(cache_no_persist, "persist_now", None))) + self.assertTrue( + 
callable(getattr(cache_no_persist, "supports_persistence", None)) + ) + self.assertFalse(cache_no_persist.supports_persistence()) + + # Cache with persistence + temp_file = os.path.join(tempfile.mkdtemp(), "test_interface.json") + try: + cache_with_persist = LFUCache( + 5, persistence_interval=10.0, persistence_path=temp_file + ) + self.assertTrue(cache_with_persist.supports_persistence()) + + # persist_now should work without errors + cache_with_persist.put("key", "value") + cache_with_persist.persist_now() # Should not raise any exception + finally: + if os.path.exists(temp_file): + os.remove(temp_file) + if os.path.exists(os.path.dirname(temp_file)): + os.rmdir(os.path.dirname(temp_file)) + + +class TestLFUCachePersistence(unittest.TestCase): + """Test cases for LFU Cache persistence functionality.""" + + def setUp(self): + """Set up test fixtures.""" + # Create temporary directory for test files + self.temp_dir = tempfile.mkdtemp() + self.test_file = os.path.join(self.temp_dir, "test_cache.json") + + def tearDown(self): + """Clean up test files.""" + # Clean up any created files + if os.path.exists(self.test_file): + os.remove(self.test_file) + # Remove temporary directory + if os.path.exists(self.temp_dir): + os.rmdir(self.temp_dir) + + def test_basic_persistence(self): + """Test basic save and load functionality.""" + # Create cache and add items + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + + cache.put("key1", "value1") + cache.put("key2", {"nested": "value"}) + cache.put("key3", [1, 2, 3]) + + # Force persistence + cache.persist_now() + + # Verify file was created + self.assertTrue(os.path.exists(self.test_file)) + + # Load into new cache + new_cache = LFUCache( + 5, persistence_interval=10.0, persistence_path=self.test_file + ) + + # Verify data was loaded correctly + self.assertEqual(new_cache.size(), 3) + self.assertEqual(new_cache.get("key1"), "value1") + self.assertEqual(new_cache.get("key2"), {"nested": "value"}) + self.assertEqual(new_cache.get("key3"), [1, 2, 3]) + + def test_frequency_preservation(self): + """Test that frequencies are preserved across persistence.""" + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + + # Create different frequency levels + cache.put("freq1", "value1") + cache.put("freq3", "value3") + cache.put("freq5", "value5") + + # Access items to create different frequencies + cache.get("freq3") # freq = 2 + cache.get("freq3") # freq = 3 + + cache.get("freq5") # freq = 2 + cache.get("freq5") # freq = 3 + cache.get("freq5") # freq = 4 + cache.get("freq5") # freq = 5 + + # Force persistence + cache.persist_now() + + # Load into new cache + new_cache = LFUCache( + 5, persistence_interval=10.0, persistence_path=self.test_file + ) + + # Add new items to test eviction order + new_cache.put("new1", "newvalue1") + new_cache.put("new2", "newvalue2") + new_cache.put("new3", "newvalue3") + + # freq1 should be evicted first (lowest frequency) + self.assertIsNone(new_cache.get("freq1")) + # freq3 and freq5 should still be there + self.assertEqual(new_cache.get("freq3"), "value3") + self.assertEqual(new_cache.get("freq5"), "value5") + + def test_periodic_persistence(self): + """Test automatic periodic persistence.""" + # Use short interval for testing + cache = LFUCache(5, persistence_interval=0.5, persistence_path=self.test_file) + + cache.put("key1", "value1") + + # File shouldn't exist yet + self.assertFalse(os.path.exists(self.test_file)) + + # Wait for interval to pass + time.sleep(0.6) + + 
# Access cache to trigger persistence check + cache.get("key1") + + # File should now exist + self.assertTrue(os.path.exists(self.test_file)) + + # Verify content + with open(self.test_file, "r") as f: + data = json.load(f) + + self.assertEqual(data["capacity"], 5) + self.assertEqual(len(data["nodes"]), 1) + self.assertEqual(data["nodes"][0]["key"], "key1") + + def test_persistence_with_empty_cache(self): + """Test persistence behavior with empty cache.""" + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + + # Add and remove items + cache.put("key1", "value1") + cache.clear() + + # Force persistence + cache.persist_now() + + # File should be removed when cache is empty + self.assertFalse(os.path.exists(self.test_file)) + + def test_no_persistence_when_disabled(self): + """Test that persistence doesn't occur when not configured.""" + # Create cache without persistence + cache = LFUCache(5) + + cache.put("key1", "value1") + cache.persist_now() # Should do nothing + + # No file should be created + self.assertFalse(os.path.exists("lfu_cache.json")) + + def test_load_from_nonexistent_file(self): + """Test loading when persistence file doesn't exist.""" + # Create cache with non-existent file + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + + # Should start empty + self.assertEqual(cache.size(), 0) + self.assertTrue(cache.is_empty()) + + def test_persistence_with_complex_data(self): + """Test persistence with various data types.""" + cache = LFUCache(10, persistence_interval=10.0, persistence_path=self.test_file) + + # Add various data types + test_data = { + "string": "hello world", + "int": 42, + "float": 3.14, + "bool": True, + "none": None, + "list": [1, 2, [3, 4]], + "dict": {"a": 1, "b": {"c": 2}}, + "tuple_key": "value_for_tuple", # Will use string key since tuples aren't JSON serializable + } + + for key, value in test_data.items(): + cache.put(key, value) + + # Force persistence + cache.persist_now() + + # Load into new cache + new_cache = LFUCache( + 10, persistence_interval=10.0, persistence_path=self.test_file + ) + + # Verify all data types + for key, value in test_data.items(): + self.assertEqual(new_cache.get(key), value) + + def test_persistence_file_corruption_handling(self): + """Test handling of corrupted persistence files.""" + # Create invalid JSON file + with open(self.test_file, "w") as f: + f.write("{ invalid json content") + + # Should handle gracefully and start with empty cache + cache = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + self.assertEqual(cache.size(), 0) + + # Cache should still be functional + cache.put("key1", "value1") + self.assertEqual(cache.get("key1"), "value1") + + def test_multiple_persistence_cycles(self): + """Test multiple save/load cycles.""" + # First cycle + cache1 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + cache1.put("key1", "value1") + cache1.put("key2", "value2") + cache1.persist_now() + + # Second cycle - load and modify + cache2 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + self.assertEqual(cache2.size(), 2) + cache2.put("key3", "value3") + cache2.persist_now() + + # Third cycle - verify all changes + cache3 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + self.assertEqual(cache3.size(), 3) + self.assertEqual(cache3.get("key1"), "value1") + self.assertEqual(cache3.get("key2"), "value2") + self.assertEqual(cache3.get("key3"), "value3") + + def 
test_capacity_change_on_load(self): + """Test loading cache data into cache with different capacity.""" + # Create cache with capacity 5 + cache1 = LFUCache(5, persistence_interval=10.0, persistence_path=self.test_file) + for i in range(5): + cache1.put(f"key{i}", f"value{i}") + cache1.persist_now() + + # Load into cache with smaller capacity + cache2 = LFUCache(3, persistence_interval=10.0, persistence_path=self.test_file) + + # Current design: loads all persisted items regardless of new capacity + # This is a valid design choice - preserve data integrity on load + self.assertEqual(cache2.size(), 5) + + # The cache continues to operate with loaded items + # New items can still be added, and the cache will manage its size + cache2.put("new_key", "new_value") + + # Verify the cache is still functional and contains the new item + self.assertEqual(cache2.get("new_key"), "new_value") + self.assertGreaterEqual( + cache2.size(), 4 + ) # At least has the new item plus some old ones + + def test_persistence_timing(self): + """Test that persistence doesn't happen too frequently.""" + cache = LFUCache(5, persistence_interval=1.0, persistence_path=self.test_file) + + cache.put("key1", "value1") + + # Multiple operations within interval shouldn't trigger persistence + for i in range(10): + cache.get("key1") + self.assertFalse(os.path.exists(self.test_file)) + time.sleep(0.05) # Total time still less than interval + + # Wait for interval to pass + time.sleep(0.6) + cache.get("key1") + + # Now file should exist + self.assertTrue(os.path.exists(self.test_file)) + + def test_persistence_with_statistics(self): + """Test persistence doesn't interfere with statistics tracking.""" + cache = LFUCache( + 5, + track_stats=True, + persistence_interval=0.5, + persistence_path=self.test_file, + ) + + # Perform operations + cache.put("key1", "value1") + cache.put("key2", "value2") + cache.get("key1") + cache.get("nonexistent") + + # Wait for persistence + time.sleep(0.6) + cache.get("key1") # Trigger persistence + + # Check stats are still correct + stats = cache.get_stats() + self.assertEqual(stats["puts"], 2) + self.assertEqual(stats["hits"], 2) + self.assertEqual(stats["misses"], 1) + + # Load into new cache with stats + new_cache = LFUCache( + 5, + track_stats=True, + persistence_interval=0.5, + persistence_path=self.test_file, + ) + + # Stats should be reset in new instance + new_stats = new_cache.get_stats() + self.assertEqual(new_stats["puts"], 0) + self.assertEqual(new_stats["hits"], 0) + + # But data should be loaded + self.assertEqual(new_cache.size(), 2) + + +class TestLFUCacheStatsLogging(unittest.TestCase): + """Test cases for LFU Cache statistics logging functionality.""" + + def setUp(self): + """Set up test fixtures.""" + self.test_file = tempfile.mktemp() + + def tearDown(self): + """Clean up test files.""" + if os.path.exists(self.test_file): + os.remove(self.test_file) + + def test_stats_logging_disabled_by_default(self): + """Test that stats logging is disabled when not configured.""" + cache = LFUCache(5, track_stats=True) + self.assertFalse(cache.supports_stats_logging()) + + def test_stats_logging_requires_tracking(self): + """Test that stats logging requires stats tracking to be enabled.""" + # Logging without tracking + cache = LFUCache(5, track_stats=False, stats_logging_interval=1.0) + self.assertFalse(cache.supports_stats_logging()) + + # Both enabled + cache = LFUCache(5, track_stats=True, stats_logging_interval=1.0) + self.assertTrue(cache.supports_stats_logging()) + + def 
test_log_stats_now(self): + """Test immediate stats logging.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=60.0) + + # Add some data + cache.put("key1", "value1") + cache.put("key2", "value2") + cache.get("key1") + cache.get("nonexistent") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + # Verify log was called + self.assertEqual(mock_log.call_count, 1) + log_message = mock_log.call_args[0][0] + + # Check log format + self.assertIn("LFU Cache Statistics", log_message) + self.assertIn("Size: 2/5", log_message) + self.assertIn("Hits: 1", log_message) + self.assertIn("Misses: 1", log_message) + self.assertIn("Hit Rate: 50.00%", log_message) + self.assertIn("Evictions: 0", log_message) + self.assertIn("Puts: 2", log_message) + self.assertIn("Updates: 0", log_message) + + def test_periodic_stats_logging(self): + """Test automatic periodic stats logging.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.5) + + # Add some data + cache.put("key1", "value1") + cache.put("key2", "value2") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + # Initial operations shouldn't trigger logging + cache.get("key1") + self.assertEqual(mock_log.call_count, 0) + + # Wait for interval to pass + time.sleep(0.6) + + # Next operation should trigger logging + cache.get("key1") + self.assertEqual(mock_log.call_count, 1) + + # Another operation without waiting shouldn't trigger + cache.get("key2") + self.assertEqual(mock_log.call_count, 1) + + # Wait again + time.sleep(0.6) + cache.put("key3", "value3") + self.assertEqual(mock_log.call_count, 2) + + def test_stats_logging_with_empty_cache(self): + """Test stats logging with empty cache.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + # Generate a miss first + cache.get("nonexistent") + + # Wait for interval to pass + time.sleep(0.2) + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + # This will trigger stats logging with the previous miss already counted + cache.get("another_nonexistent") # Trigger check + + self.assertEqual(mock_log.call_count, 1) + log_message = mock_log.call_args[0][0] + + self.assertIn("Size: 0/5", log_message) + self.assertIn("Hits: 0", log_message) + self.assertIn("Misses: 1", log_message) # The first miss is logged + self.assertIn("Hit Rate: 0.00%", log_message) + + def test_stats_logging_with_full_cache(self): + """Test stats logging when cache is at capacity.""" + import logging + from unittest.mock import patch + + cache = LFUCache(3, track_stats=True, stats_logging_interval=0.1) + + # Fill cache + cache.put("key1", "value1") + cache.put("key2", "value2") + cache.put("key3", "value3") + + # Cause eviction + cache.put("key4", "value4") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + time.sleep(0.2) + cache.get("key4") # Trigger check + + log_message = mock_log.call_args[0][0] + self.assertIn("Size: 3/3", log_message) + self.assertIn("Evictions: 1", log_message) + self.assertIn("Puts: 4", log_message) + + def test_stats_logging_high_hit_rate(self): + """Test stats logging with high hit rate.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, 
stats_logging_interval=0.1) + + cache.put("key1", "value1") + + # Many hits + for _ in range(99): + cache.get("key1") + + # One miss + cache.get("nonexistent") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + log_message = mock_log.call_args[0][0] + self.assertIn("Hit Rate: 99.00%", log_message) + self.assertIn("Hits: 99", log_message) + self.assertIn("Misses: 1", log_message) + + def test_stats_logging_without_tracking(self): + """Test that log_stats_now does nothing when tracking is disabled.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=False) + + cache.put("key1", "value1") + cache.get("key1") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + # Should not log anything + self.assertEqual(mock_log.call_count, 0) + + def test_stats_logging_interval_timing(self): + """Test that stats logging respects the interval timing.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=1.0) + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + # Multiple operations within interval + for i in range(10): + cache.put(f"key{i}", f"value{i}") + cache.get(f"key{i}") + time.sleep(0.05) # Total time < 1.0 + + # Should not have logged yet + self.assertEqual(mock_log.call_count, 0) + + # Wait for interval to pass + time.sleep(0.6) + cache.get("key1") # Trigger check + + # Now should have logged once + self.assertEqual(mock_log.call_count, 1) + + def test_stats_logging_with_updates(self): + """Test stats logging includes update counts.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + cache.put("key1", "value1") + cache.put("key1", "updated_value1") # Update + cache.put("key1", "updated_again") # Another update + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + log_message = mock_log.call_args[0][0] + self.assertIn("Updates: 2", log_message) + self.assertIn("Puts: 1", log_message) + + def test_stats_logging_combined_with_persistence(self): + """Test that stats logging and persistence work together.""" + import logging + from unittest.mock import patch + + cache = LFUCache( + 5, + track_stats=True, + persistence_interval=1.0, + persistence_path=self.test_file, + stats_logging_interval=0.5, + ) + + cache.put("key1", "value1") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + # Wait for stats logging interval + time.sleep(0.6) + cache.get("key1") # Trigger stats log + + self.assertEqual(mock_log.call_count, 1) + self.assertFalse(os.path.exists(self.test_file)) # Not persisted yet + + # Wait for persistence interval + time.sleep(0.5) + cache.get("key1") # Trigger persistence + + self.assertTrue(os.path.exists(self.test_file)) # Now persisted + # Stats log might trigger again if interval passed + self.assertGreaterEqual(mock_log.call_count, 1) + + def test_stats_log_format_percentages(self): + """Test that percentages in stats log are formatted correctly.""" + import logging + from unittest.mock import patch + + cache = LFUCache(5, track_stats=True, stats_logging_interval=0.1) + + # Test various hit rates + test_cases = [ + (0, 0, "0.00%"), # No requests + (1, 0, "100.00%"), # All hits + (0, 1, "0.00%"), # All misses + (1, 1, 
"50.00%"), # 50/50 + (2, 1, "66.67%"), # 2/3 + (99, 1, "99.00%"), # High hit rate + ] + + for hits, misses, expected_rate in test_cases: + cache.reset_stats() + + # Generate hits + if hits > 0: + cache.put("hit_key", "value") + for _ in range(hits): + cache.get("hit_key") + + # Generate misses + for i in range(misses): + cache.get(f"miss_key_{i}") + + with patch.object( + logging.getLogger("nemoguardrails.cache.lfu"), "info" + ) as mock_log: + cache.log_stats_now() + + if hits > 0 or misses > 0: + log_message = mock_log.call_args[0][0] + self.assertIn(f"Hit Rate: {expected_rate}", log_message) + + +class TestContentSafetyCacheStatsConfig(unittest.TestCase): + """Test cache stats configuration in content safety context.""" + + def setUp(self): + """Set up test fixtures.""" + self.test_file = tempfile.mktemp() + + def tearDown(self): + """Clean up test files.""" + if os.path.exists(self.test_file): + os.remove(self.test_file) + + def test_cache_config_with_stats_disabled(self): + """Test cache configuration with stats disabled.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, capacity_per_model=1000, stats=CacheStatsConfig(enabled=False) + ) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertFalse(cache.track_stats) + self.assertFalse(cache.supports_stats_logging()) + + def test_cache_config_with_stats_tracking_only(self): + """Test cache configuration with stats tracking but no logging.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, + capacity_per_model=1000, + stats=CacheStatsConfig(enabled=True, log_interval=None), + ) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertTrue(cache.track_stats) + self.assertFalse(cache.supports_stats_logging()) + self.assertIsNone(cache.stats_logging_interval) + + def test_cache_config_with_stats_logging(self): + """Test cache configuration with stats tracking and logging.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, + capacity_per_model=1000, + stats=CacheStatsConfig(enabled=True, log_interval=60.0), + ) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertTrue(cache.track_stats) + self.assertTrue(cache.supports_stats_logging()) + self.assertEqual(cache.stats_logging_interval, 60.0) + + def test_cache_config_default_stats(self): + """Test cache configuration with default stats settings.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ModelCacheConfig, ModelConfig + + cache_config = ModelCacheConfig(enabled=True) + + model_config = ModelConfig(cache=cache_config) + manager = 
ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertFalse(cache.track_stats) # Default is disabled + self.assertFalse(cache.supports_stats_logging()) + + def test_cache_config_stats_with_persistence(self): + """Test cache configuration with both stats and persistence.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CachePersistenceConfig, + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, + capacity_per_model=1000, + stats=CacheStatsConfig(enabled=True, log_interval=30.0), + persistence=CachePersistenceConfig( + enabled=True, interval=60.0, path=self.test_file + ), + ) + + model_config = ModelConfig(cache=cache_config) + manager = ContentSafetyManager(model_config) + + cache = manager.get_cache_for_model("test_model") + self.assertIsNotNone(cache) + self.assertTrue(cache.track_stats) + self.assertTrue(cache.supports_stats_logging()) + self.assertEqual(cache.stats_logging_interval, 30.0) + self.assertTrue(cache.supports_persistence()) + self.assertEqual(cache.persistence_interval, 60.0) + + def test_cache_config_from_dict(self): + """Test cache configuration creation from dictionary.""" + from nemoguardrails.rails.llm.config import ModelCacheConfig + + config_dict = { + "enabled": True, + "capacity_per_model": 5000, + "stats": {"enabled": True, "log_interval": 120.0}, + } + + cache_config = ModelCacheConfig(**config_dict) + self.assertTrue(cache_config.enabled) + self.assertEqual(cache_config.capacity_per_model, 5000) + self.assertTrue(cache_config.stats.enabled) + self.assertEqual(cache_config.stats.log_interval, 120.0) + + def test_cache_config_stats_validation(self): + """Test cache configuration validation for stats settings.""" + from nemoguardrails.rails.llm.config import CacheStatsConfig + + # Valid configurations + stats1 = CacheStatsConfig(enabled=True, log_interval=60.0) + self.assertTrue(stats1.enabled) + self.assertEqual(stats1.log_interval, 60.0) + + stats2 = CacheStatsConfig(enabled=True, log_interval=None) + self.assertTrue(stats2.enabled) + self.assertIsNone(stats2.log_interval) + + stats3 = CacheStatsConfig(enabled=False, log_interval=60.0) + self.assertFalse(stats3.enabled) + self.assertEqual(stats3.log_interval, 60.0) + + def test_multiple_model_caches_with_stats(self): + """Test multiple model caches each with their own stats configuration.""" + from nemoguardrails.library.content_safety.manager import ContentSafetyManager + from nemoguardrails.rails.llm.config import ( + CacheStatsConfig, + ModelCacheConfig, + ModelConfig, + ) + + cache_config = ModelCacheConfig( + enabled=True, + capacity_per_model=1000, + stats=CacheStatsConfig(enabled=True, log_interval=30.0), + ) + + model_config = ModelConfig( + cache=cache_config, model_mapping={"model_alias": "actual_model"} + ) + manager = ContentSafetyManager(model_config) + + # Get caches for different models + cache1 = manager.get_cache_for_model("model1") + cache2 = manager.get_cache_for_model("model2") + cache_alias = manager.get_cache_for_model("model_alias") + cache_actual = manager.get_cache_for_model("actual_model") + + # All should have stats enabled + self.assertTrue(cache1.track_stats) + self.assertTrue(cache2.track_stats) + self.assertTrue(cache_alias.track_stats) + + # Alias should resolve to same cache as actual + self.assertIs(cache_alias, cache_actual) + + +class 
TestLFUCacheThreadSafety(unittest.TestCase): + """Test thread safety of LFU Cache implementation.""" + + def setUp(self): + """Set up test fixtures.""" + self.cache = LFUCache(100, track_stats=True) + + def test_concurrent_reads_writes(self): + """Test that concurrent reads and writes don't corrupt the cache.""" + num_threads = 10 + operations_per_thread = 100 + # Use a larger cache to avoid evictions during the test + large_cache = LFUCache(2000, track_stats=True) + errors = [] + + def worker(thread_id): + """Worker function that performs cache operations.""" + for i in range(operations_per_thread): + key = f"thread_{thread_id}_key_{i}" + value = f"thread_{thread_id}_value_{i}" + + # Put operation + large_cache.put(key, value) + + # Get operation - should always succeed with large cache + retrieved = large_cache.get(key) + + # Verify data integrity + if retrieved != value: + errors.append( + f"Data corruption for {key}: expected {value}, got {retrieved}" + ) + + # Access some shared keys + shared_key = f"shared_key_{i % 10}" + large_cache.put(shared_key, f"shared_value_{thread_id}_{i}") + large_cache.get(shared_key) + + # Run threads + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + for future in futures: + future.result() # Wait for completion and raise any exceptions + + # Check for any errors + self.assertEqual(len(errors), 0, f"Errors occurred: {errors[:5]}...") + + # Verify cache is still functional + test_key = "test_after_concurrent" + test_value = "test_value" + large_cache.put(test_key, test_value) + self.assertEqual(large_cache.get(test_key), test_value) + + # Check statistics are reasonable + stats = large_cache.get_stats() + self.assertGreater(stats["hits"], 0) + self.assertGreater(stats["puts"], 0) + + def test_concurrent_evictions(self): + """Test that concurrent operations during evictions don't corrupt the cache.""" + # Use a small cache to trigger frequent evictions + small_cache = LFUCache(10) + num_threads = 5 + operations_per_thread = 50 + + def worker(thread_id): + """Worker that adds many items to trigger evictions.""" + for i in range(operations_per_thread): + key = f"t{thread_id}_k{i}" + value = f"t{thread_id}_v{i}" + small_cache.put(key, value) + + # Try to get recently added items + if i > 0: + prev_key = f"t{thread_id}_k{i-1}" + small_cache.get(prev_key) # May or may not exist + + # Run threads + with ThreadPoolExecutor(max_workers=num_threads) as executor: + futures = [executor.submit(worker, i) for i in range(num_threads)] + for future in futures: + future.result() + + # Cache should still be at capacity + self.assertEqual(small_cache.size(), 10) + + def test_concurrent_clear_operations(self): + """Test concurrent clear operations with other operations.""" + + def writer(): + """Continuously write to cache.""" + for i in range(100): + self.cache.put(f"key_{i}", f"value_{i}") + time.sleep(0.001) # Small delay + + def clearer(): + """Periodically clear the cache.""" + for _ in range(5): + time.sleep(0.01) + self.cache.clear() + + def reader(): + """Continuously read from cache.""" + for i in range(100): + self.cache.get(f"key_{i}") + time.sleep(0.001) + + # Run operations concurrently + threads = [ + threading.Thread(target=writer), + threading.Thread(target=clearer), + threading.Thread(target=reader), + ] + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + # Cache should still be functional + self.cache.put("final_key", "final_value") + 
self.assertEqual(self.cache.get("final_key"), "final_value") + + def test_concurrent_stats_operations(self): + """Test that concurrent operations don't corrupt statistics.""" + + def worker(thread_id): + """Worker that performs operations and checks stats.""" + for i in range(50): + key = f"stats_key_{thread_id}_{i}" + self.cache.put(key, i) + self.cache.get(key) # Hit + self.cache.get(f"nonexistent_{thread_id}_{i}") # Miss + + # Periodically check stats + if i % 10 == 0: + stats = self.cache.get_stats() + # Just verify we can get stats without error + self.assertIsInstance(stats, dict) + self.assertIn("hits", stats) + self.assertIn("misses", stats) + + # Run threads + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Final stats check + final_stats = self.cache.get_stats() + self.assertGreater(final_stats["hits"], 0) + self.assertGreater(final_stats["misses"], 0) + self.assertGreater(final_stats["puts"], 0) + + def test_get_or_compute_thread_safety(self): + """Test thread safety of get_or_compute method.""" + compute_count = threading.local() + compute_count.value = 0 + total_computes = [] + lock = threading.Lock() + + async def expensive_compute(): + """Simulate expensive computation that should only run once.""" + # Track how many times this is called + if not hasattr(compute_count, "value"): + compute_count.value = 0 + compute_count.value += 1 + + with lock: + total_computes.append(1) + + # Simulate expensive operation + await asyncio.sleep(0.1) + return f"computed_value_{len(total_computes)}" + + async def worker(thread_id): + """Worker that tries to get or compute the same key.""" + result = await self.cache.get_or_compute( + "shared_compute_key", expensive_compute, default="default" + ) + return result + + async def run_test(): + """Run the async test.""" + # Run multiple workers concurrently + tasks = [worker(i) for i in range(10)] + results = await asyncio.gather(*tasks) + + # All should get the same value + self.assertTrue( + all(r == results[0] for r in results), + f"All threads should get same value, got: {results}", + ) + + # Compute should have been called only once + self.assertEqual( + len(total_computes), + 1, + f"Compute should be called once, called {len(total_computes)} times", + ) + + return results[0] + + # Run the async test + result = asyncio.run(run_test()) + self.assertEqual(result, "computed_value_1") + + def test_get_or_compute_exception_handling(self): + """Test get_or_compute handles exceptions properly.""" + call_count = [0] + + async def failing_compute(): + """Compute function that fails.""" + call_count[0] += 1 + raise ValueError("Computation failed") + + async def worker(): + """Worker that tries to compute.""" + result = await self.cache.get_or_compute( + "failing_key", failing_compute, default="fallback" + ) + return result + + async def run_test(): + """Run the async test.""" + # Multiple workers should all get the default value + tasks = [worker() for _ in range(5)] + results = await asyncio.gather(*tasks) + + # All should get the default value + self.assertTrue(all(r == "fallback" for r in results)) + + # The compute function might be called multiple times + # since failed computations aren't cached + self.assertGreaterEqual(call_count[0], 1) + + asyncio.run(run_test()) + + def test_concurrent_persistence(self): + """Test thread safety of persistence operations.""" + with tempfile.NamedTemporaryFile(mode="w", delete=False, suffix=".json") as 
f: + cache_file = f.name + + try: + # Create cache with persistence + cache = LFUCache( + capacity=50, + track_stats=True, + persistence_interval=0.1, # Short interval for testing + persistence_path=cache_file, + ) + + def worker(thread_id): + """Worker that performs operations.""" + for i in range(20): + cache.put(f"persist_key_{thread_id}_{i}", f"value_{thread_id}_{i}") + cache.get(f"persist_key_{thread_id}_{i}") + + # Force persistence sometimes + if i % 5 == 0: + cache.persist_now() + + # Run workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Final persist + cache.persist_now() + + # Load the persisted data + new_cache = LFUCache( + capacity=50, persistence_interval=1.0, persistence_path=cache_file + ) + + # Verify some data was persisted correctly + # (Due to capacity limits, not all items will be present) + self.assertGreater(new_cache.size(), 0) + self.assertLessEqual(new_cache.size(), 50) + + finally: + # Clean up + if os.path.exists(cache_file): + os.unlink(cache_file) + + def test_thread_safe_size_operations(self): + """Test that size-related operations are thread-safe.""" + results = [] + + def worker(thread_id): + """Worker that checks size consistency.""" + for i in range(100): + # Add item + self.cache.put(f"size_key_{thread_id}_{i}", i) + + # Check size + size = self.cache.size() + is_empty = self.cache.is_empty() + + # Size should never be negative or exceed capacity + if size < 0 or size > 100: + results.append(f"Invalid size: {size}") + + # is_empty should match size + if (size == 0) != is_empty: + results.append(f"Size {size} but is_empty={is_empty}") + + # Run workers + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, i) for i in range(10)] + for future in futures: + future.result() + + # Check for any inconsistencies + self.assertEqual(len(results), 0, f"Inconsistencies found: {results}") + + def test_concurrent_contains_operations(self): + """Test thread safety of contains method.""" + # Use a larger cache to avoid evictions during the test + # Need capacity for: 50 existing + (5 threads × 100 new keys) = 550+ + large_cache = LFUCache(1000, track_stats=True) + + # Pre-populate cache + for i in range(50): + large_cache.put(f"existing_key_{i}", f"value_{i}") + + results = [] + eviction_warnings = [] + + def worker(thread_id): + """Worker that checks contains and manipulates cache.""" + for i in range(100): + # Check existing keys + key = f"existing_key_{i % 50}" + if not large_cache.contains(key): + results.append(f"Thread {thread_id}: Missing key {key}") + + # Add new keys + new_key = f"new_key_{thread_id}_{i}" + large_cache.put(new_key, f"value_{thread_id}_{i}") + + # Check new key immediately + if not large_cache.contains(new_key): + # This could happen if cache is full and eviction occurred + # Track it separately as it's not a thread safety issue + eviction_warnings.append( + f"Thread {thread_id}: Key {new_key} possibly evicted" + ) + + # Check non-existent keys + if large_cache.contains(f"non_existent_{thread_id}_{i}"): + results.append(f"Thread {thread_id}: Found non-existent key") + + # Run workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Check for any errors (not counting eviction warnings) + self.assertEqual(len(results), 0, f"Errors found: {results}") + + # Eviction warnings should 
be minimal with large cache + if eviction_warnings: + print(f"Note: {len(eviction_warnings)} keys were evicted during test") + + def test_concurrent_reset_stats(self): + """Test thread safety of reset_stats operations.""" + errors = [] + + def worker(thread_id): + """Worker that performs operations and resets stats.""" + for i in range(50): + # Perform operations + self.cache.put(f"key_{thread_id}_{i}", i) + self.cache.get(f"key_{thread_id}_{i}") + self.cache.get("non_existent") + + # Periodically reset stats + if i % 10 == 0: + self.cache.reset_stats() + + # Check stats integrity + stats = self.cache.get_stats() + if any(v < 0 for v in stats.values() if isinstance(v, (int, float))): + errors.append(f"Thread {thread_id}: Negative stat value: {stats}") + + # Run workers + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + for future in futures: + future.result() + + # Verify no errors + self.assertEqual(len(errors), 0, f"Stats errors: {errors[:5]}") + + def test_get_or_compute_concurrent_different_keys(self): + """Test get_or_compute with different keys being computed concurrently.""" + compute_counts = {} + lock = threading.Lock() + + async def compute_for_key(key): + """Compute function that tracks calls per key.""" + with lock: + compute_counts[key] = compute_counts.get(key, 0) + 1 + await asyncio.sleep(0.05) # Simulate work + return f"value_for_{key}" + + async def worker(thread_id, key_id): + """Worker that computes values for specific keys.""" + key = f"key_{key_id}" + result = await self.cache.get_or_compute( + key, lambda: compute_for_key(key), default="error" + ) + return key, result + + async def run_test(): + """Run concurrent computations for different keys.""" + # Create tasks for multiple keys, with some overlap + tasks = [] + for key_id in range(5): + for thread_id in range(3): # 3 threads per key + tasks.append(worker(thread_id, key_id)) + + results = await asyncio.gather(*tasks) + + # Verify each key was computed exactly once + for key_id in range(5): + key = f"key_{key_id}" + self.assertEqual( + compute_counts.get(key, 0), + 1, + f"{key} should be computed exactly once", + ) + + # Verify all threads got correct values + for key, value in results: + expected = f"value_for_{key}" + self.assertEqual(value, expected) + + asyncio.run(run_test()) + + def test_concurrent_operations_with_evictions(self): + """Test thread safety when cache is at capacity and evictions occur.""" + # Small cache to force evictions + small_cache = LFUCache(50, track_stats=True) + data_integrity_errors = [] + + def worker(thread_id): + """Worker that handles potential evictions gracefully.""" + for i in range(100): + key = f"t{thread_id}_k{i}" + value = f"t{thread_id}_v{i}" + + # Put value + small_cache.put(key, value) + + # Immediately access to increase frequency + retrieved = small_cache.get(key) + + # Value might be None if evicted immediately (unlikely but possible) + if retrieved is not None and retrieved != value: + # This would indicate actual data corruption + data_integrity_errors.append( + f"Wrong value for {key}: expected {value}, got {retrieved}" + ) + + # Also work with some persistent keys (access multiple times) + persistent_key = f"persistent_{thread_id % 5}" + for _ in range(3): # Access 3 times to increase frequency + small_cache.put(persistent_key, f"persistent_value_{thread_id}") + small_cache.get(persistent_key) + + # Run workers + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, 
i) for i in range(10)] + for future in futures: + future.result() + + # Should have no data integrity errors (wrong values) + self.assertEqual( + len(data_integrity_errors), + 0, + f"Data integrity errors: {data_integrity_errors}", + ) + + # Cache should be at capacity + self.assertEqual(small_cache.size(), 50) + + # Stats should show many evictions + stats = small_cache.get_stats() + self.assertGreater(stats["evictions"], 0) + self.assertGreater(stats["puts"], 0) + + +class TestContentSafetyManagerThreadSafety(unittest.TestCase): + """Test thread safety of ContentSafetyManager.""" + + def setUp(self): + """Set up test fixtures.""" + # Create mock cache config + self.cache_config = MagicMock() + self.cache_config.enabled = True + self.cache_config.store = "memory" + self.cache_config.capacity_per_model = 100 + self.cache_config.stats.enabled = True + self.cache_config.stats.log_interval = None + self.cache_config.persistence.enabled = False + self.cache_config.persistence.interval = None + self.cache_config.persistence.path = None + + # Create mock model config + self.model_config = MagicMock() + self.model_config.cache = self.cache_config + self.model_config.model_mapping = {"alias_model": "actual_model"} + + def test_concurrent_cache_creation(self): + """Test that concurrent cache creation returns the same instance.""" + manager = ContentSafetyManager(self.model_config) + caches = [] + + def worker(thread_id): + """Worker that gets cache for model.""" + cache = manager.get_cache_for_model("test_model") + caches.append((thread_id, cache)) + return cache + + # Run many threads to increase chance of race condition + with ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(worker, i) for i in range(20)] + for future in futures: + future.result() + + # All caches should be the same instance + first_cache = caches[0][1] + for thread_id, cache in caches: + self.assertIs( + cache, first_cache, f"Thread {thread_id} got different cache instance" + ) + + def test_concurrent_multi_model_caches(self): + """Test concurrent access to caches for different models.""" + manager = ContentSafetyManager(self.model_config) + results = [] + + def worker(thread_id): + """Worker that accesses multiple model caches.""" + model_names = [f"model_{i}" for i in range(5)] + + for model_name in model_names: + cache = manager.get_cache_for_model(model_name) + + # Perform operations + key = f"thread_{thread_id}_key" + value = f"thread_{thread_id}_value" + cache.put(key, value) + retrieved = cache.get(key) + + if retrieved != value: + results.append(f"Mismatch for {model_name}: {retrieved} != {value}") + + # Run workers + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, i) for i in range(10)] + for future in futures: + future.result() + + # Check for errors + self.assertEqual(len(results), 0, f"Errors found: {results}") + + def test_concurrent_persist_all_caches(self): + """Test thread safety of persist_all_caches method.""" + with tempfile.TemporaryDirectory() as temp_dir: + # Create mock config with persistence + cache_config = MagicMock() + cache_config.enabled = True + cache_config.store = "memory" + cache_config.capacity_per_model = 50 + cache_config.persistence.enabled = True + cache_config.persistence.interval = 1.0 + cache_config.persistence.path = f"{temp_dir}/cache_{{model_name}}.json" + cache_config.stats.enabled = True + cache_config.stats.log_interval = None + + model_config = MagicMock() + model_config.cache = cache_config + 
model_config.model_mapping = {} + + manager = ContentSafetyManager(model_config) + + # Create caches for multiple models + for i in range(5): + cache = manager.get_cache_for_model(f"model_{i}") + for j in range(10): + cache.put(f"key_{j}", f"value_{j}") + + persist_count = [0] + + def persist_worker(): + """Worker that calls persist_all_caches.""" + manager.persist_all_caches() + persist_count[0] += 1 + + def modify_worker(): + """Worker that modifies caches while persistence happens.""" + for i in range(20): + model_name = f"model_{i % 5}" + cache = manager.get_cache_for_model(model_name) + cache.put(f"new_key_{i}", f"new_value_{i}") + time.sleep(0.001) + + # Run persistence and modifications concurrently + threads = [] + + # Multiple persist threads + for _ in range(3): + t = threading.Thread(target=persist_worker) + threads.append(t) + t.start() + + # Modification thread + t = threading.Thread(target=modify_worker) + threads.append(t) + t.start() + + # Wait for all threads + for t in threads: + t.join() + + # Verify persistence was called + self.assertEqual(persist_count[0], 3) + + def test_model_alias_thread_safety(self): + """Test thread safety when using model aliases.""" + manager = ContentSafetyManager(self.model_config) + caches = [] + + def worker(use_alias): + """Worker that gets cache using alias or actual name.""" + if use_alias: + cache = manager.get_cache_for_model("alias_model") + else: + cache = manager.get_cache_for_model("actual_model") + caches.append(cache) + + # Mix of threads using alias and actual name + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [] + for i in range(10): + use_alias = i % 2 == 0 + futures.append(executor.submit(worker, use_alias)) + + for future in futures: + future.result() + + # All should get the same cache instance + first_cache = caches[0] + for cache in caches: + self.assertIs( + cache, + first_cache, + "Alias and actual model should resolve to same cache", + ) + + +if __name__ == "__main__": + unittest.main()
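
A minimal usage sketch of the `LFUCache` exercised by the tests above, inferred only from the constructor arguments and methods they call (`capacity`, `track_stats`, `persistence_interval`, `persistence_path`, `put`, `get`, `get_stats`, `persist_now`); the key, value, and path below are illustrative assumptions, not part of the patch.

```python
# Minimal sketch: an LFU cache with stats tracking and periodic disk
# persistence, using only the API surface shown in the tests above.
from nemoguardrails.cache.lfu import LFUCache

cache = LFUCache(
    capacity=5000,                        # maximum entries before LFU eviction
    track_stats=True,                     # enables get_stats() counters
    persistence_interval=60.0,            # seconds between automatic saves
    persistence_path="./lfu_cache.json",  # illustrative path
)

# Hypothetical key/value shape for a cached safety verdict.
cache.put("input:prompt-hash", {"allowed": True})
print(cache.get("input:prompt-hash"))  # -> {'allowed': True}
print(cache.get_stats())               # hits / misses / puts / updates / evictions
cache.persist_now()                    # force an immediate save, e.g. at shutdown
```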