"""
Unit tests for helper utilities in memori.utils.helpers.

Each test documents the expected behavior of key helper classes so future
refactors retain backwards compatibility.
"""

import sys
import types

# Install a stub for memori.config.pool_config BEFORE any memori import below,
# presumably so importing memori.utils does not pull in real pool/database
# configuration — TODO confirm against memori package import graph.
if "memori.config.pool_config" not in sys.modules:
    pool_module = types.ModuleType("memori.config.pool_config")

    class PoolConfig:
        # Minimal constants the real PoolConfig exposes; values are arbitrary
        # but must exist for the import to succeed.
        DEFAULT_POOL_SIZE = 2
        DEFAULT_MAX_OVERFLOW = 3
        DEFAULT_POOL_TIMEOUT = 30
        DEFAULT_POOL_RECYCLE = 3600
        DEFAULT_POOL_PRE_PING = True

    pool_module.PoolConfig = PoolConfig
    pool_module.pool_config = PoolConfig()
    sys.modules["memori.config.pool_config"] = pool_module

import asyncio
import os
from datetime import datetime, timedelta

import pytest

from memori.utils.helpers import (
    AsyncUtils,
    DateTimeUtils,
    FileUtils,
    JsonUtils,
    RetryUtils,
    StringUtils,
)


def test_string_utils_generate_id_and_prefix():
    """StringUtils should create unique identifiers and preserve prefixes."""
    generated = StringUtils.generate_id("mem-")
    other = StringUtils.generate_id("mem-")

    assert generated.startswith("mem-")
    assert generated != other  # consecutive calls must yield unique ids


def test_string_utils_truncate_and_sanitize_filename():
    """Verify truncate_text respects suffix and filenames are sanitized."""
    # Suffix counts toward max_length: 3 chars of text + 1-char suffix == 4.
    truncated = StringUtils.truncate_text("abcdef", max_length=4, suffix="?")
    assert truncated == "abc?"

    # Each illegal filename character (:, /, *, ?) becomes an underscore.
    sanitized = StringUtils.sanitize_filename("my:/invalid*file?.txt")
    assert sanitized == "my__invalid_file_.txt"


def test_string_utils_hash_and_keyword_extraction():
    """Ensure deterministic hashing and basic keyword extraction."""
    hashed = StringUtils.hash_text("memori")
    assert hashed == StringUtils.hash_text("memori")  # same input, same digest

    keywords = StringUtils.extract_keywords(
        "The Memori memory layer connects SQL databases effortlessly", max_keywords=3
    )
    assert len(keywords) == 3
    # Stop-words such as "the"/"and" must be filtered out of the result.
    assert all(word not in {"the", "and"} for word in keywords)


def test_datetime_utils_basic_operations():
    """DateTimeUtils helpers should format/parse and handle offsets."""
    now = DateTimeUtils.now()
    formatted = DateTimeUtils.format_datetime(now)
    parsed = DateTimeUtils.parse_datetime(formatted)

    # Round-trip through format/parse preserves the second-resolution value;
    # the canonical format is "%Y-%m-%d %H:%M:%S".
    assert isinstance(parsed, datetime)
    assert parsed.strftime("%Y-%m-%d %H:%M:%S") == formatted

    future = DateTimeUtils.add_days(now, 2)
    past = DateTimeUtils.subtract_days(now, 2)

    assert future - now == timedelta(days=2)
    assert now - past == timedelta(days=2)

    # A 5-hour-old timestamp is past a 1-hour expiry window.
    old_time = datetime.now() - timedelta(hours=5)
    assert DateTimeUtils.is_expired(old_time, expiry_hours=1)
    assert "minute" in DateTimeUtils.time_ago_string(
        datetime.now() - timedelta(minutes=3)
    )


def test_json_utils_safe_operations():
    """JsonUtils should safely merge/load/dump even with invalid inputs."""
    data = {"nested": {"value": 1}}
    # merge_dicts performs a deep merge: existing nested keys survive and
    # new nested keys are added alongside them.
    merged = JsonUtils.merge_dicts(data, {"nested": {"extra": 2}, "new": True})
    assert merged["nested"]["value"] == 1
    assert merged["nested"]["extra"] == 2
    assert merged["new"] is True

    # safe_loads returns the parsed value on valid JSON and the supplied
    # default on invalid JSON instead of raising.
    assert JsonUtils.safe_loads('{"valid": true}', default={}) == {"valid": True}
    assert JsonUtils.safe_loads("not json", default={"fallback": 1}) == {"fallback": 1}

    dumped = JsonUtils.safe_dumps({"a": 1})
    assert '"a": 1' in dumped


def test_file_utils_roundtrip(tmp_path):
    """Validate file helpers read/write/size-check and detect recency."""
    file_path = tmp_path / "memori" / "test.txt"
    FileUtils.ensure_directory(file_path.parent)

    assert FileUtils.safe_write_text(file_path, "hello")
    assert FileUtils.safe_read_text(file_path) == "hello"
    assert FileUtils.get_file_size(file_path) > 0
    assert FileUtils.is_file_recent(file_path)

    # Make file look old to ensure is_file_recent can return False:
    # rewind both atime and mtime three days via os.utime.
    old_timestamp = (datetime.now() - timedelta(days=3)).timestamp()
    os.utime(file_path, (old_timestamp, old_timestamp))
    assert not FileUtils.is_file_recent(file_path, hours=24)


def test_retry_utils_retries_until_success():
    """Retry decorator should retry until success within max attempts."""
    attempts = {"count": 0}  # mutable counter shared with the closure

    @RetryUtils.retry_on_exception(
        max_attempts=3, delay=0, backoff=1, exceptions=(ValueError,)
    )
    def flaky():
        # Fails twice, then succeeds on the third call — exactly at the
        # max_attempts boundary.
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise ValueError("boom")
        return "ok"

    assert flaky() == "ok"
    assert attempts["count"] == 3


def test_retry_utils_raises_on_failure():
    """Retry decorator should surface final exception when retries exhausted."""

    @RetryUtils.retry_on_exception(
        max_attempts=2, delay=0, backoff=1, exceptions=(RuntimeError,)
    )
    def always_fail():
        raise RuntimeError("nope")

    with pytest.raises(RuntimeError):
        always_fail()


@pytest.mark.asyncio
async def test_async_utils_gather_with_concurrency():
    """Async gather helper should respect concurrency limits and order."""

    async def echo(value, delay=0):
        await asyncio.sleep(delay)
        return value

    # Five coroutines, at most two in flight; results must keep input order.
    tasks = [echo(i, delay=0.01) for i in range(5)]
    results = await AsyncUtils.gather_with_concurrency(2, *tasks)
    assert results == list(range(5))
"""
Unit tests for memori.utils.log_sanitizer.

This suite documents expected redaction behavior so sensitive data never leaks
through logging regressions.
"""

import sys
import types

# The stub must be registered before any memori import so that pulling in
# memori.utils does not require the real pool configuration module.
if "memori.config.pool_config" not in sys.modules:
    stub_module = types.ModuleType("memori.config.pool_config")

    class PoolConfig:
        # Placeholder constants mirroring the real PoolConfig attributes.
        DEFAULT_POOL_SIZE = 2
        DEFAULT_MAX_OVERFLOW = 3
        DEFAULT_POOL_TIMEOUT = 30
        DEFAULT_POOL_RECYCLE = 3600
        DEFAULT_POOL_PRE_PING = True

    stub_module.PoolConfig = PoolConfig
    stub_module.pool_config = PoolConfig()
    sys.modules["memori.config.pool_config"] = stub_module

from memori.utils.log_sanitizer import (
    LogSanitizer,
    SanitizedLogger,
    sanitize_dict_for_logging,
    sanitize_for_logging,
)


def test_log_sanitizer_replaces_sensitive_tokens():
    """Emails, card numbers, and api tokens must all be redacted."""
    message = (
        "Contact me at user@example.com, token=abcd1234secret, "
        "and card 1234-5678-9012-3456"
    )
    cleaned = sanitize_for_logging(message, max_length=200)

    # The raw address must be gone and each placeholder present.
    assert "user@example.com" not in cleaned
    assert "[EMAIL_REDACTED]" in cleaned
    assert "token=[REDACTED]" in cleaned
    assert "[CARD_REDACTED]" in cleaned


def test_sanitize_dict_handles_multiple_values():
    """Dict values: strings are sanitized, non-strings are stringified."""
    source = {"email": "someone@example.com", "count": 5}
    cleaned = sanitize_dict_for_logging(source, max_length=50)

    assert cleaned["email"] == "[EMAIL_REDACTED]"
    assert cleaned["count"] == "5"  # int rendered as its string form


def test_log_sanitizer_truncates_long_text():
    """When max_length is set, output is cut and the suffix appended."""
    long_text = "a" * 50
    result = LogSanitizer.sanitize(long_text, max_length=10, truncate_suffix="...")

    assert result.startswith("a" * 10)
    assert result.endswith("...")


def test_sanitized_logger_sanitizes_messages():
    """SanitizedLogger should redact a message before handing it on."""
    captured = []

    class RecordingLogger:
        # Minimal logger double: remember every message passed to info().
        def info(self, message, *args, **kwargs):
            captured.append(message)

    wrapped = SanitizedLogger(logger_instance=RecordingLogger(), max_length=200)
    wrapped.info("My email is admin@example.com and password is secret")

    assert len(captured) == 1
    assert "[EMAIL_REDACTED]" in captured[0]
    # The word "password" itself is not sensitive data and survives.
    assert "password" in captured[0]
"""Storage and memory quotas should enforce limits and track increments.""" + limiter = RateLimiter() + + ok, _ = limiter.check_storage_quota("user", additional_bytes=50, limit_bytes=100) + assert ok + limiter.increment_quota("user", "storage_bytes", amount=50) + + allowed, message = limiter.check_storage_quota( + "user", additional_bytes=60, limit_bytes=80 + ) + assert not allowed + assert "Storage quota exceeded" in message + + limiter.increment_quota("user", "memory_count", amount=5) + allowed, message = limiter.check_memory_count_quota("user", limit=5) + assert not allowed + assert "Memory count quota exceeded" in message + + +def test_rate_limiter_api_calls_reset(monkeypatch): + """Daily API quota should reset when the last reset timestamp is stale.""" + limiter = RateLimiter() + + quota = limiter._quotas["user"] + quota.api_calls_today = 1_000 + quota.last_reset = quota.last_reset - rl.timedelta(days=2) + + allowed, _ = limiter.check_api_call_quota("user", limit=1_000) + assert allowed + assert quota.api_calls_today == 0 + + +def test_rate_limit_decorator_raises(monkeypatch): + """rate_limited decorator should raise when exceeding allowed invocations.""" + local_limiter = RateLimiter() + monkeypatch.setattr(rl, "_global_limiter", local_limiter) + + class Dummy: + user_id = "user" + + @rl.rate_limited("op", limit=1, window_seconds=60) + def action(self): + return "done" + + instance = Dummy() + assert instance.action() == "done" + + with pytest.raises(RateLimitExceeded): + instance.action() + + +def test_storage_quota_decorator(monkeypatch): + """storage_quota decorator should raise QuotaExceeded when payload is heavy.""" + local_limiter = RateLimiter() + monkeypatch.setattr(rl, "_global_limiter", local_limiter) + + class Dummy: + user_id = "user" + + @rl.storage_quota(limit_bytes=10) + def save(self, user_input="", ai_output=""): + return len(user_input) + len(ai_output) + + d = Dummy() + with pytest.raises(QuotaExceeded): + d.save(user_input="a" * 20, 
ai_output="b")