Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions tests/utils/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""
Unit tests for helper utilities in memori.utils.helpers.
Each test documents the expected behavior of key helper classes so future
refactors retain backwards compatibility.
"""

import sys
import types

if "memori.config.pool_config" not in sys.modules:
pool_module = types.ModuleType("memori.config.pool_config")

class PoolConfig:
DEFAULT_POOL_SIZE = 2
DEFAULT_MAX_OVERFLOW = 3
DEFAULT_POOL_TIMEOUT = 30
DEFAULT_POOL_RECYCLE = 3600
DEFAULT_POOL_PRE_PING = True

pool_module.PoolConfig = PoolConfig
pool_module.pool_config = PoolConfig()
sys.modules["memori.config.pool_config"] = pool_module

import asyncio
import os
from datetime import datetime, timedelta

import pytest

from memori.utils.helpers import (
AsyncUtils,
DateTimeUtils,
FileUtils,
JsonUtils,
RetryUtils,
StringUtils,
)


def test_string_utils_generate_id_and_prefix():
"""StringUtils should create unique identifiers and preserve prefixes."""
generated = StringUtils.generate_id("mem-")
other = StringUtils.generate_id("mem-")

assert generated.startswith("mem-")
assert generated != other # should be unique


def test_string_utils_truncate_and_sanitize_filename():
"""Verify truncate_text respects suffix and filenames are sanitized."""
truncated = StringUtils.truncate_text("abcdef", max_length=4, suffix="?")
assert truncated == "abc?"

sanitized = StringUtils.sanitize_filename("my:/invalid*file?.txt")
assert sanitized == "my__invalid_file_.txt"


def test_string_utils_hash_and_keyword_extraction():
"""Ensure deterministic hashing and basic keyword extraction."""
hashed = StringUtils.hash_text("memori")
assert hashed == StringUtils.hash_text("memori")

keywords = StringUtils.extract_keywords(
"The Memori memory layer connects SQL databases effortlessly", max_keywords=3
)
assert len(keywords) == 3
assert all(word not in {"the", "and"} for word in keywords)


def test_datetime_utils_basic_operations():
"""DateTimeUtils helpers should format/parse and handle offsets."""
now = DateTimeUtils.now()
formatted = DateTimeUtils.format_datetime(now)
parsed = DateTimeUtils.parse_datetime(formatted)

assert isinstance(parsed, datetime)
assert parsed.strftime("%Y-%m-%d %H:%M:%S") == formatted

future = DateTimeUtils.add_days(now, 2)
past = DateTimeUtils.subtract_days(now, 2)

assert future - now == timedelta(days=2)
assert now - past == timedelta(days=2)

old_time = datetime.now() - timedelta(hours=5)
assert DateTimeUtils.is_expired(old_time, expiry_hours=1)
assert "minute" in DateTimeUtils.time_ago_string(
datetime.now() - timedelta(minutes=3)
)


def test_json_utils_safe_operations():
"""JsonUtils should safely merge/load/dump even with invalid inputs."""
data = {"nested": {"value": 1}}
merged = JsonUtils.merge_dicts(data, {"nested": {"extra": 2}, "new": True})
assert merged["nested"]["value"] == 1
assert merged["nested"]["extra"] == 2
assert merged["new"] is True

assert JsonUtils.safe_loads('{"valid": true}', default={}) == {"valid": True}
assert JsonUtils.safe_loads("not json", default={"fallback": 1}) == {"fallback": 1}

dumped = JsonUtils.safe_dumps({"a": 1})
assert '"a": 1' in dumped


def test_file_utils_roundtrip(tmp_path):
"""Validate file helpers read/write/size-check and detect recency."""
file_path = tmp_path / "memori" / "test.txt"
FileUtils.ensure_directory(file_path.parent)

assert FileUtils.safe_write_text(file_path, "hello")
assert FileUtils.safe_read_text(file_path) == "hello"
assert FileUtils.get_file_size(file_path) > 0
assert FileUtils.is_file_recent(file_path)

# Make file look old to ensure is_file_recent can return False
old_timestamp = (datetime.now() - timedelta(days=3)).timestamp()
os.utime(file_path, (old_timestamp, old_timestamp))
assert not FileUtils.is_file_recent(file_path, hours=24)


def test_retry_utils_retries_until_success():
"""Retry decorator should retry until success within max attempts."""
attempts = {"count": 0}

@RetryUtils.retry_on_exception(
max_attempts=3, delay=0, backoff=1, exceptions=(ValueError,)
)
def flaky():
attempts["count"] += 1
if attempts["count"] < 3:
raise ValueError("boom")
return "ok"

assert flaky() == "ok"
assert attempts["count"] == 3


def test_retry_utils_raises_on_failure():
"""Retry decorator should surface final exception when retries exhausted."""

@RetryUtils.retry_on_exception(
max_attempts=2, delay=0, backoff=1, exceptions=(RuntimeError,)
)
def always_fail():
raise RuntimeError("nope")

with pytest.raises(RuntimeError):
always_fail()


@pytest.mark.asyncio
async def test_async_utils_gather_with_concurrency():
"""Async gather helper should respect concurrency limits and order."""

async def echo(value, delay=0):
await asyncio.sleep(delay)
return value

tasks = [echo(i, delay=0.01) for i in range(5)]
results = await AsyncUtils.gather_with_concurrency(2, *tasks)
assert results == list(range(5))
77 changes: 77 additions & 0 deletions tests/utils/test_log_sanitizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Unit tests for memori.utils.log_sanitizer.
This suite documents expected redaction behavior so sensitive data never leaks
through logging regressions.
"""

import sys
import types

if "memori.config.pool_config" not in sys.modules:
pool_module = types.ModuleType("memori.config.pool_config")

class PoolConfig:
DEFAULT_POOL_SIZE = 2
DEFAULT_MAX_OVERFLOW = 3
DEFAULT_POOL_TIMEOUT = 30
DEFAULT_POOL_RECYCLE = 3600
DEFAULT_POOL_PRE_PING = True

pool_module.PoolConfig = PoolConfig
pool_module.pool_config = PoolConfig()
sys.modules["memori.config.pool_config"] = pool_module

from memori.utils.log_sanitizer import (
LogSanitizer,
SanitizedLogger,
sanitize_dict_for_logging,
sanitize_for_logging,
)


def test_log_sanitizer_replaces_sensitive_tokens():
"""Sensitive tokens (emails, cards, api keys) should be redacted."""
raw = (
"Contact me at user@example.com, token=abcd1234secret, "
"and card 1234-5678-9012-3456"
)
sanitized = sanitize_for_logging(raw, max_length=200)

assert "[EMAIL_REDACTED]" in sanitized
assert "[CARD_REDACTED]" in sanitized
assert "token=[REDACTED]" in sanitized
assert "user@example.com" not in sanitized


def test_sanitize_dict_handles_multiple_values():
"""Dictionary sanitizer should sanitize string values and stringify others."""
payload = {"email": "someone@example.com", "count": 5}
sanitized = sanitize_dict_for_logging(payload, max_length=50)

assert sanitized["email"] == "[EMAIL_REDACTED]"
assert sanitized["count"] == "5"


def test_log_sanitizer_truncates_long_text():
"""Sanitizer should truncate when max_length is provided."""
text = "a" * 50
sanitized = LogSanitizer.sanitize(text, max_length=10, truncate_suffix="...")
assert sanitized.startswith("a" * 10)
assert sanitized.endswith("...")


def test_sanitized_logger_sanitizes_messages():
"""SanitizedLogger should emit redacted output before logging."""
records = []

class DummyLogger:
def info(self, message, *args, **kwargs):
records.append(message)

logger = SanitizedLogger(logger_instance=DummyLogger(), max_length=200)
logger.info("My email is admin@example.com and password is secret")

assert len(records) == 1
assert "[EMAIL_REDACTED]" in records[0]
assert "password" in records[0]
116 changes: 116 additions & 0 deletions tests/utils/test_rate_limiter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Unit tests for memori.utils.rate_limiter covering window/quota helpers.
"""

import sys
import types

if "memori.config.pool_config" not in sys.modules:
pool_module = types.ModuleType("memori.config.pool_config")

class PoolConfig:
DEFAULT_POOL_SIZE = 2
DEFAULT_MAX_OVERFLOW = 3
DEFAULT_POOL_TIMEOUT = 30
DEFAULT_POOL_RECYCLE = 3600
DEFAULT_POOL_PRE_PING = True

pool_module.PoolConfig = PoolConfig
pool_module.pool_config = PoolConfig()
sys.modules["memori.config.pool_config"] = pool_module

import pytest

from memori.utils import rate_limiter as rl
from memori.utils.rate_limiter import QuotaExceeded, RateLimiter, RateLimitExceeded


def test_rate_limiter_enforces_limit_and_resets(monkeypatch):
"""Rate limiter should enforce limits per window and reset after expiry."""
fake_time = {"value": 0.0}

def _fake_time():
return fake_time["value"]

monkeypatch.setattr(rl.time, "time", _fake_time)

limiter = RateLimiter()
allowed, _ = limiter.check_rate_limit("user", "op", limit=1, window_seconds=10)
assert allowed

allowed, error = limiter.check_rate_limit("user", "op", limit=1, window_seconds=10)
assert not allowed
assert "Rate limit exceeded" in error

fake_time["value"] = 65
allowed, _ = limiter.check_rate_limit("user", "op", limit=1, window_seconds=10)
assert allowed # window reset


def test_rate_limiter_storage_and_memory_quotas():
"""Storage and memory quotas should enforce limits and track increments."""
limiter = RateLimiter()

ok, _ = limiter.check_storage_quota("user", additional_bytes=50, limit_bytes=100)
assert ok
limiter.increment_quota("user", "storage_bytes", amount=50)

allowed, message = limiter.check_storage_quota(
"user", additional_bytes=60, limit_bytes=80
)
assert not allowed
assert "Storage quota exceeded" in message

limiter.increment_quota("user", "memory_count", amount=5)
allowed, message = limiter.check_memory_count_quota("user", limit=5)
assert not allowed
assert "Memory count quota exceeded" in message


def test_rate_limiter_api_calls_reset(monkeypatch):
"""Daily API quota should reset when the last reset timestamp is stale."""
limiter = RateLimiter()

quota = limiter._quotas["user"]
quota.api_calls_today = 1_000
quota.last_reset = quota.last_reset - rl.timedelta(days=2)

allowed, _ = limiter.check_api_call_quota("user", limit=1_000)
assert allowed
assert quota.api_calls_today == 0


def test_rate_limit_decorator_raises(monkeypatch):
"""rate_limited decorator should raise when exceeding allowed invocations."""
local_limiter = RateLimiter()
monkeypatch.setattr(rl, "_global_limiter", local_limiter)

class Dummy:
user_id = "user"

@rl.rate_limited("op", limit=1, window_seconds=60)
def action(self):
return "done"

instance = Dummy()
assert instance.action() == "done"

with pytest.raises(RateLimitExceeded):
instance.action()


def test_storage_quota_decorator(monkeypatch):
"""storage_quota decorator should raise QuotaExceeded when payload is heavy."""
local_limiter = RateLimiter()
monkeypatch.setattr(rl, "_global_limiter", local_limiter)

class Dummy:
user_id = "user"

@rl.storage_quota(limit_bytes=10)
def save(self, user_input="", ai_output=""):
return len(user_input) + len(ai_output)

d = Dummy()
with pytest.raises(QuotaExceeded):
d.save(user_input="a" * 20, ai_output="b")