Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,13 @@
# Changelog

## [1.2.1]

### 🛠️ Technical Improvements

- **NOSCRIPT Error Recovery for Async Operations**: The `arun_sha()` function now automatically recovers from NOSCRIPT errors (e.g., after Redis restart) by re-registering Lua scripts and retrying.
- If scripts fail to execute after re-registration, raises `PersistentNoScriptError` indicating a server-side issue.


## [1.2.0]

### 🔄 Changed
Expand Down
4 changes: 2 additions & 2 deletions rapyer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from rapyer.fields.key import KeyAnnotation
from rapyer.fields.safe_load import SafeLoadAnnotation
from rapyer.links import REDIS_SUPPORTED_LINK
from rapyer.scripts import handle_noscript_error
from rapyer.scripts import registry as scripts_registry
from rapyer.types.base import RedisType, REDIS_DUMP_FLAG_NAME, FAILED_FIELDS_KEY
from rapyer.types.convert import RedisConverter
from rapyer.typing_support import Self, Unpack
Expand Down Expand Up @@ -570,7 +570,7 @@ async def apipeline(
raise

if noscript_on_first_attempt:
await handle_noscript_error(self.Meta.redis)
await scripts_registry.handle_noscript_error(self.Meta.redis, self.Meta)
evalsha_commands = [
(args, options)
for args, options in commands_backup
Expand Down
3 changes: 3 additions & 0 deletions rapyer/scripts/constants.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
REDIS_VARIANT = "redis"
FAKEREDIS_VARIANT = "fakeredis"

REMOVE_RANGE_SCRIPT_NAME = "remove_range"
NUM_MUL_SCRIPT_NAME = "num_mul"
NUM_FLOORDIV_SCRIPT_NAME = "num_floordiv"
Expand Down
7 changes: 4 additions & 3 deletions rapyer/scripts/loader.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from functools import lru_cache
from importlib import resources

from rapyer.scripts.constants import FAKEREDIS_VARIANT, REDIS_VARIANT

VARIANTS = {
"redis": {
REDIS_VARIANT: {
"EXTRACT_ARRAY": "local arr = cjson.decode(arr_json)[1]",
"EXTRACT_VALUE": "local value = tonumber(cjson.decode(current_json)[1])",
"EXTRACT_STR": "local value = cjson.decode(current_json)[1]",
Expand All @@ -19,7 +20,7 @@
extracted = parsed
end""",
},
"fakeredis": {
FAKEREDIS_VARIANT: {
"EXTRACT_ARRAY": "local arr = cjson.decode(arr_json)",
"EXTRACT_VALUE": "local value = tonumber(cjson.decode(current_json)[1])",
"EXTRACT_STR": "local value = cjson.decode(current_json)[1]",
Expand All @@ -45,7 +46,7 @@ def _load_template(category: str, name: str) -> str:
return resources.files(package).joinpath(filename).read_text()


def load_script(category: str, name: str, variant: str = "redis") -> str:
def load_script(category: str, name: str, variant: str = REDIS_VARIANT) -> str:
template = _load_template(category, name)
replacements = VARIANTS[variant]
result = template
Expand Down
39 changes: 31 additions & 8 deletions rapyer/scripts/registry.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
from rapyer.errors import ScriptsNotInitializedError
from typing import TYPE_CHECKING

from rapyer.errors import PersistentNoScriptError, ScriptsNotInitializedError
from rapyer.scripts.constants import (
DATETIME_ADD_SCRIPT_NAME,
DICT_POP_SCRIPT_NAME,
DICT_POPITEM_SCRIPT_NAME,
FAKEREDIS_VARIANT,
NUM_FLOORDIV_SCRIPT_NAME,
NUM_MOD_SCRIPT_NAME,
NUM_MUL_SCRIPT_NAME,
NUM_POW_FLOAT_SCRIPT_NAME,
NUM_POW_SCRIPT_NAME,
NUM_TRUEDIV_SCRIPT_NAME,
REDIS_VARIANT,
REMOVE_RANGE_SCRIPT_NAME,
STR_APPEND_SCRIPT_NAME,
STR_MUL_SCRIPT_NAME,
)
from rapyer.scripts.loader import load_script
from redis.exceptions import NoScriptError

if TYPE_CHECKING:
from rapyer.config import RedisConfig

SCRIPT_REGISTRY: list[tuple[str, str, str]] = [
("list", "remove_range", REMOVE_RANGE_SCRIPT_NAME),
Expand Down Expand Up @@ -41,15 +49,15 @@ def _build_scripts(variant: str) -> dict[str, str]:


def get_scripts() -> dict[str, str]:
return _build_scripts("redis")
return _build_scripts(REDIS_VARIANT)


def get_scripts_fakeredis() -> dict[str, str]:
return _build_scripts("fakeredis")
return _build_scripts(FAKEREDIS_VARIANT)


async def register_scripts(redis_client, is_fakeredis: bool = False) -> None:
variant = "fakeredis" if is_fakeredis else "redis"
variant = FAKEREDIS_VARIANT if is_fakeredis else REDIS_VARIANT
scripts = _build_scripts(variant)
for name, script_text in scripts.items():
sha = await redis_client.script_load(script_text)
Expand All @@ -70,10 +78,25 @@ def run_sha(pipeline, script_name: str, keys: int, *args):
pipeline.evalsha(sha, keys, *args)


async def arun_sha(client, script_name: str, keys: int, *args):
async def arun_sha(
client, redis_config: "RedisConfig", script_name: str, keys: int, *args
):
sha = get_script(script_name)
try:
return await client.evalsha(sha, keys, *args)
except NoScriptError:
pass

await handle_noscript_error(client, redis_config)
sha = get_script(script_name)
return await client.evalsha(sha, keys, *args)
try:
return await client.evalsha(sha, keys, *args)
except NoScriptError as e:
raise PersistentNoScriptError(
"NOSCRIPT error persisted after re-registering scripts. "
"This indicates a server-side problem with Redis."
) from e


async def handle_noscript_error(redis_client) -> None:
await register_scripts(redis_client)
async def handle_noscript_error(redis_client, redis_config: "RedisConfig"):
await register_scripts(redis_client, is_fakeredis=redis_config.is_fake_redis)
15 changes: 13 additions & 2 deletions rapyer/types/dct.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,13 @@ async def aupdate(self, **kwargs):

async def apop(self, key, default=None):
result = await arun_sha(
self.client, DICT_POP_SCRIPT_NAME, 1, self.key, self.json_path, key
self.client,
self.Meta,
DICT_POP_SCRIPT_NAME,
1,
self.key,
self.json_path,
key,
)
super().pop(key, None)
await self.refresh_ttl_if_needed()
Expand All @@ -120,7 +126,12 @@ async def apop(self, key, default=None):

async def apopitem(self):
result = await arun_sha(
self.client, DICT_POPITEM_SCRIPT_NAME, 1, self.key, self.json_path
self.client,
self.Meta,
DICT_POPITEM_SCRIPT_NAME,
1,
self.key,
self.json_path,
)
await self.refresh_ttl_if_needed()

Expand Down
15 changes: 14 additions & 1 deletion tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import os
from dataclasses import dataclass
from typing import Generic, TypeVar
from unittest.mock import AsyncMock, patch

import pytest
import pytest_asyncio

import rapyer
from rapyer.scripts import register_scripts

Expand Down Expand Up @@ -248,6 +249,18 @@ async def saved_model_with_reduced_ttl(real_redis_client):
await model.adelete()


@pytest_asyncio.fixture
async def flush_scripts(real_redis_client):
await real_redis_client.execute_command("SCRIPT", "FLUSH")
yield


@pytest.fixture
def disable_noscript_recovery():
with patch("rapyer.scripts.registry.handle_noscript_error", new_callable=AsyncMock):
yield


@pytest_asyncio.fixture
async def saved_no_refresh_model_with_reduced_ttl(real_redis_client):
model = TTLRefreshDisabledModel(
Expand Down
23 changes: 22 additions & 1 deletion tests/integration/dct/test_redis_dict.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from datetime import datetime
from unittest.mock import AsyncMock, patch

import pytest

from rapyer.base import AtomicRedisModel
from rapyer.errors import PersistentNoScriptError
from rapyer.types.dct import RedisDict
from redis.exceptions import NoScriptError
from tests.models.collection_types import (
IntDictModel,
StrDictModel,
Expand All @@ -16,6 +18,7 @@
ListDictModel,
NestedDictModel,
BaseDictMetadataModel,
ComprehensiveTestModel,
)
from tests.models.common import Status, Person

Expand Down Expand Up @@ -541,3 +544,21 @@ async def test_redis_dict__apop_empty_redis__check_no_default_sanity(

# Assert
assert result is None


@pytest.mark.asyncio
async def test_redis_dict__apop_raises_persistent_noscript_error_when_scripts_keep_failing(
disable_noscript_recovery,
):
# Arrange
model = ComprehensiveTestModel(metadata={"key1": "value1"})
await model.asave()

mock_evalsha = AsyncMock(side_effect=NoScriptError("NOSCRIPT"))

# Act & Assert
with patch.object(model.Meta.redis, "evalsha", mock_evalsha):
with pytest.raises(PersistentNoScriptError) as exc_info:
await model.metadata.apop("key1")

assert "server-side" in str(exc_info.value).lower()
57 changes: 23 additions & 34 deletions tests/integration/pipeline/test_pipeline_noscript_recovery.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from unittest.mock import AsyncMock, patch

import pytest

from rapyer.errors import PersistentNoScriptError
Expand All @@ -8,32 +6,32 @@


@pytest.mark.asyncio
async def test_pipeline_recovers_from_noscript_error_after_script_flush_sanity():
async def test_pipeline_recovers_from_noscript_error_after_script_flush_sanity(
flush_scripts,
):
# Arrange
model = ComprehensiveTestModel(
tags=["a", "b", "c", "d", "e"],
metadata={"key1": "value1"},
)
await model.asave()

# Act - flush scripts mid-pipeline to simulate Redis restart
# Act
async with model.apipeline() as redis_model:
# Multiple pipeline operations to verify all are executed
redis_model.tags.append("f") # Regular pipeline command (ARRAPPEND)
redis_model.tags.remove_range(1, 3) # Uses evalsha (Lua script)
redis_model.metadata["key2"] = "value2" # Dict setitem (JSON.SET)

# Simulate Redis restart by flushing all scripts
await model.Meta.redis.execute_command("SCRIPT", "FLUSH")
redis_model.tags.append("f")
redis_model.tags.remove_range(1, 3)
redis_model.metadata["key2"] = "value2"

# Assert - pipeline should have recovered and ALL changes applied
# Assert
final_model = await ComprehensiveTestModel.aget(model.key)
assert final_model.tags == ["a", "d", "e", "f"]
assert final_model.metadata == {"key1": "value1", "key2": "value2"}


@pytest.mark.asyncio
async def test_pipeline_recovers_with_all_redis_types_after_script_flush_sanity():
async def test_pipeline_recovers_with_all_redis_types_after_script_flush_sanity(
flush_scripts,
):
# Arrange
model = TTLRefreshTestModel(
name="original",
Expand All @@ -44,27 +42,17 @@ async def test_pipeline_recovers_with_all_redis_types_after_script_flush_sanity(
)
await model.asave()

# Act - operations on all Redis types then flush scripts
# Act
async with model.apipeline() as redis_model:
# RedisInt operations
redis_model.age += 5

# RedisFloat operations
redis_model.score += 2.5

# RedisList operations
redis_model.tags.append("f")
redis_model.tags[0] = "new_a"
redis_model.tags.remove_range(1, 3) # Lua script (evalsha)

# RedisDict operations
redis_model.tags.remove_range(1, 3)
redis_model.settings["setting2"] = "value2"
redis_model.settings.update({"setting3": "value3"})

# Simulate Redis restart
await model.Meta.redis.execute_command("SCRIPT", "FLUSH")

# Assert - all operations on all types should succeed
# Assert
final_model = await TTLRefreshTestModel.aget(model.key)
assert final_model.age == 15
assert final_model.score == 4.0
Expand All @@ -79,16 +67,17 @@ async def test_pipeline_recovers_with_all_redis_types_after_script_flush_sanity(


@pytest.mark.asyncio
async def test_pipeline_raises_persistent_noscript_error_when_scripts_keep_failing_error():
async def test_pipeline_raises_persistent_noscript_error_when_scripts_keep_failing_error(
flush_scripts,
disable_noscript_recovery,
):
# Arrange
model = ComprehensiveTestModel(tags=["a", "b", "c"])
await model.asave()
await model.Meta.redis.execute_command("SCRIPT", "FLUSH")

# Act & Assert - patch handle_noscript_error to not actually register scripts
with patch("rapyer.base.handle_noscript_error", new_callable=AsyncMock):
with pytest.raises(PersistentNoScriptError) as exc_info:
async with model.apipeline() as redis_model:
redis_model.tags.remove_range(0, 1)
# Act & Assert
with pytest.raises(PersistentNoScriptError) as exc_info:
async with model.apipeline() as redis_model:
redis_model.tags.remove_range(0, 1)

assert "server-side" in str(exc_info.value).lower()
assert "server-side" in str(exc_info.value).lower()
21 changes: 19 additions & 2 deletions tests/unit/test_scripts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from unittest.mock import AsyncMock, MagicMock

import pytest

from rapyer.errors import ScriptsNotInitializedError
from rapyer.scripts import (
run_sha,
Expand Down Expand Up @@ -52,15 +51,33 @@ async def test_handle_noscript_error_reloads_scripts_sanity(clear_script_state):
# Arrange
mock_redis = AsyncMock()
mock_redis.script_load = AsyncMock(return_value="new_sha_456")
mock_config = MagicMock()
mock_config.is_fake_redis = False

# Act
await handle_noscript_error(mock_redis)
await handle_noscript_error(mock_redis, mock_config)

# Assert
mock_redis.script_load.assert_called()
assert _REGISTERED_SCRIPT_SHAS.get(REMOVE_RANGE_SCRIPT_NAME) == "new_sha_456"


@pytest.mark.asyncio
async def test_handle_noscript_error_reloads_scripts_with_fakeredis(clear_script_state):
# Arrange
mock_redis = AsyncMock()
mock_redis.script_load = AsyncMock(return_value="fakeredis_sha_789")
mock_config = MagicMock()
mock_config.is_fake_redis = True

# Act
await handle_noscript_error(mock_redis, mock_config)

# Assert
mock_redis.script_load.assert_called()
assert _REGISTERED_SCRIPT_SHAS.get(REMOVE_RANGE_SCRIPT_NAME) == "fakeredis_sha_789"


@pytest.mark.asyncio
async def test_register_scripts_stores_shas_sanity(clear_script_state):
# Arrange
Expand Down
Loading
Loading