diff --git a/CHANGELOG.md b/CHANGELOG.md index ad3a7f83..eb14f163 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## [1.2.1] + +### 🛠️ Technical Improvements + +- **NOSCRIPT Error Recovery for Async Operations**: The `arun_sha()` function now automatically recovers from NOSCRIPT errors (e.g., after Redis restart) by re-registering Lua scripts and retrying. + - If scripts fail to execute after re-registration, raises `PersistentNoScriptError` indicating a server-side issue. + + ## [1.2.0] ### 🔄 Changed diff --git a/rapyer/base.py b/rapyer/base.py index 63eff2aa..60f2e995 100644 --- a/rapyer/base.py +++ b/rapyer/base.py @@ -36,7 +36,7 @@ from rapyer.fields.key import KeyAnnotation from rapyer.fields.safe_load import SafeLoadAnnotation from rapyer.links import REDIS_SUPPORTED_LINK -from rapyer.scripts import handle_noscript_error +from rapyer.scripts import registry as scripts_registry from rapyer.types.base import RedisType, REDIS_DUMP_FLAG_NAME, FAILED_FIELDS_KEY from rapyer.types.convert import RedisConverter from rapyer.typing_support import Self, Unpack @@ -570,7 +570,7 @@ async def apipeline( raise if noscript_on_first_attempt: - await handle_noscript_error(self.Meta.redis) + await scripts_registry.handle_noscript_error(self.Meta.redis, self.Meta) evalsha_commands = [ (args, options) for args, options in commands_backup diff --git a/rapyer/scripts/constants.py b/rapyer/scripts/constants.py index a3aa251a..a666bb41 100644 --- a/rapyer/scripts/constants.py +++ b/rapyer/scripts/constants.py @@ -1,3 +1,6 @@ +REDIS_VARIANT = "redis" +FAKEREDIS_VARIANT = "fakeredis" + REMOVE_RANGE_SCRIPT_NAME = "remove_range" NUM_MUL_SCRIPT_NAME = "num_mul" NUM_FLOORDIV_SCRIPT_NAME = "num_floordiv" diff --git a/rapyer/scripts/loader.py b/rapyer/scripts/loader.py index 72d1c06e..baa4386e 100644 --- a/rapyer/scripts/loader.py +++ b/rapyer/scripts/loader.py @@ -1,9 +1,10 @@ from functools import lru_cache from importlib import resources +from rapyer.scripts.constants import FAKEREDIS_VARIANT, REDIS_VARIANT VARIANTS = { - "redis": { + REDIS_VARIANT: { "EXTRACT_ARRAY": "local arr = cjson.decode(arr_json)[1]", "EXTRACT_VALUE": "local value = tonumber(cjson.decode(current_json)[1])", "EXTRACT_STR": "local value = cjson.decode(current_json)[1]", @@ -19,7 +20,7 @@ extracted = parsed end""", }, - "fakeredis": { + FAKEREDIS_VARIANT: { "EXTRACT_ARRAY": "local arr = cjson.decode(arr_json)", "EXTRACT_VALUE": "local value = tonumber(cjson.decode(current_json)[1])", "EXTRACT_STR": "local value = cjson.decode(current_json)[1]", @@ -45,7 +46,7 @@ def _load_template(category: str, name: str) -> str: return resources.files(package).joinpath(filename).read_text() -def load_script(category: str, name: str, variant: str = "redis") -> str: +def load_script(category: str, name: str, variant: str = REDIS_VARIANT) -> str: template = _load_template(category, name) replacements = VARIANTS[variant] result = template diff --git a/rapyer/scripts/registry.py b/rapyer/scripts/registry.py index 9923ddb5..3540dc8f 100644 --- a/rapyer/scripts/registry.py +++ b/rapyer/scripts/registry.py @@ -1,19 +1,27 @@ -from rapyer.errors import ScriptsNotInitializedError +from typing import TYPE_CHECKING + +from rapyer.errors import PersistentNoScriptError, ScriptsNotInitializedError from rapyer.scripts.constants import ( DATETIME_ADD_SCRIPT_NAME, DICT_POP_SCRIPT_NAME, DICT_POPITEM_SCRIPT_NAME, + FAKEREDIS_VARIANT, NUM_FLOORDIV_SCRIPT_NAME, NUM_MOD_SCRIPT_NAME, NUM_MUL_SCRIPT_NAME, NUM_POW_FLOAT_SCRIPT_NAME, NUM_POW_SCRIPT_NAME, NUM_TRUEDIV_SCRIPT_NAME, + REDIS_VARIANT, REMOVE_RANGE_SCRIPT_NAME, STR_APPEND_SCRIPT_NAME, STR_MUL_SCRIPT_NAME, ) from rapyer.scripts.loader import load_script +from redis.exceptions import NoScriptError + +if TYPE_CHECKING: + from rapyer.config import RedisConfig SCRIPT_REGISTRY: list[tuple[str, str, str]] = [ ("list", "remove_range", REMOVE_RANGE_SCRIPT_NAME), @@ -41,15 +49,15 @@ def _build_scripts(variant: str) -> dict[str, str]: def get_scripts() -> dict[str, str]: - return _build_scripts("redis") + return _build_scripts(REDIS_VARIANT) def get_scripts_fakeredis() -> dict[str, str]: - return _build_scripts("fakeredis") + return _build_scripts(FAKEREDIS_VARIANT) async def register_scripts(redis_client, is_fakeredis: bool = False) -> None: - variant = "fakeredis" if is_fakeredis else "redis" + variant = FAKEREDIS_VARIANT if is_fakeredis else REDIS_VARIANT scripts = _build_scripts(variant) for name, script_text in scripts.items(): sha = await redis_client.script_load(script_text) @@ -70,10 +78,25 @@ def run_sha(pipeline, script_name: str, keys: int, *args): pipeline.evalsha(sha, keys, *args) -async def arun_sha(client, script_name: str, keys: int, *args): +async def arun_sha( + client, redis_config: "RedisConfig", script_name: str, keys: int, *args +): + sha = get_script(script_name) + try: + return await client.evalsha(sha, keys, *args) + except NoScriptError: + pass + + await handle_noscript_error(client, redis_config) sha = get_script(script_name) - return await client.evalsha(sha, keys, *args) + try: + return await client.evalsha(sha, keys, *args) + except NoScriptError as e: + raise PersistentNoScriptError( + "NOSCRIPT error persisted after re-registering scripts. " + "This indicates a server-side problem with Redis." + ) from e -async def handle_noscript_error(redis_client) -> None: - await register_scripts(redis_client) +async def handle_noscript_error(redis_client, redis_config: "RedisConfig"): + await register_scripts(redis_client, is_fakeredis=redis_config.is_fake_redis) diff --git a/rapyer/types/dct.py b/rapyer/types/dct.py index b82f9537..b592e6c9 100644 --- a/rapyer/types/dct.py +++ b/rapyer/types/dct.py @@ -106,7 +106,13 @@ async def aupdate(self, **kwargs): async def apop(self, key, default=None): result = await arun_sha( - self.client, DICT_POP_SCRIPT_NAME, 1, self.key, self.json_path, key + self.client, + self.Meta, + DICT_POP_SCRIPT_NAME, + 1, + self.key, + self.json_path, + key, ) super().pop(key, None) await self.refresh_ttl_if_needed() @@ -120,7 +126,12 @@ async def apop(self, key, default=None): async def apopitem(self): result = await arun_sha( - self.client, DICT_POPITEM_SCRIPT_NAME, 1, self.key, self.json_path + self.client, + self.Meta, + DICT_POPITEM_SCRIPT_NAME, + 1, + self.key, + self.json_path, ) await self.refresh_ttl_if_needed() diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 8c10e48a..b9d7da13 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -1,9 +1,10 @@ import os from dataclasses import dataclass from typing import Generic, TypeVar +from unittest.mock import AsyncMock, patch +import pytest import pytest_asyncio - import rapyer from rapyer.scripts import register_scripts @@ -248,6 +249,18 @@ async def saved_model_with_reduced_ttl(real_redis_client): await model.adelete() +@pytest_asyncio.fixture +async def flush_scripts(real_redis_client): + await real_redis_client.execute_command("SCRIPT", "FLUSH") + yield + + +@pytest.fixture +def disable_noscript_recovery(): + with patch("rapyer.scripts.registry.handle_noscript_error", new_callable=AsyncMock): + yield + + @pytest_asyncio.fixture async def saved_no_refresh_model_with_reduced_ttl(real_redis_client): model = TTLRefreshDisabledModel( diff --git a/tests/integration/dct/test_redis_dict.py b/tests/integration/dct/test_redis_dict.py index 131af318..81d8126d 100644 --- a/tests/integration/dct/test_redis_dict.py +++ b/tests/integration/dct/test_redis_dict.py @@ -1,9 +1,11 @@ from datetime import datetime +from unittest.mock import AsyncMock, patch import pytest - from rapyer.base import AtomicRedisModel +from rapyer.errors import PersistentNoScriptError from rapyer.types.dct import RedisDict +from redis.exceptions import NoScriptError from tests.models.collection_types import ( IntDictModel, StrDictModel, @@ -16,6 +18,7 @@ ListDictModel, NestedDictModel, BaseDictMetadataModel, + ComprehensiveTestModel, ) from tests.models.common import Status, Person @@ -541,3 +544,21 @@ async def test_redis_dict__apop_empty_redis__check_no_default_sanity( # Assert assert result is None + + +@pytest.mark.asyncio +async def test_redis_dict__apop_raises_persistent_noscript_error_when_scripts_keep_failing( + disable_noscript_recovery, +): + # Arrange + model = ComprehensiveTestModel(metadata={"key1": "value1"}) + await model.asave() + + mock_evalsha = AsyncMock(side_effect=NoScriptError("NOSCRIPT")) + + # Act & Assert + with patch.object(model.Meta.redis, "evalsha", mock_evalsha): + with pytest.raises(PersistentNoScriptError) as exc_info: + await model.metadata.apop("key1") + + assert "server-side" in str(exc_info.value).lower() diff --git a/tests/integration/pipeline/test_pipeline_noscript_recovery.py b/tests/integration/pipeline/test_pipeline_noscript_recovery.py index 4ed50e59..47b803cf 100644 --- a/tests/integration/pipeline/test_pipeline_noscript_recovery.py +++ b/tests/integration/pipeline/test_pipeline_noscript_recovery.py @@ -1,5 +1,3 @@ -from unittest.mock import AsyncMock, patch - import pytest from rapyer.errors import PersistentNoScriptError @@ -8,7 +6,9 @@ @pytest.mark.asyncio -async def test_pipeline_recovers_from_noscript_error_after_script_flush_sanity(): +async def test_pipeline_recovers_from_noscript_error_after_script_flush_sanity( + flush_scripts, +): # Arrange model = ComprehensiveTestModel( tags=["a", "b", "c", "d", "e"], @@ -16,24 +16,22 @@ async def test_pipeline_recovers_from_noscript_error_after_script_flush_sanity() ) await model.asave() - # Act - flush scripts mid-pipeline to simulate Redis restart + # Act async with model.apipeline() as redis_model: - # Multiple pipeline operations to verify all are executed - redis_model.tags.append("f") # Regular pipeline command (ARRAPPEND) - redis_model.tags.remove_range(1, 3) # Uses evalsha (Lua script) - redis_model.metadata["key2"] = "value2" # Dict setitem (JSON.SET) - - # Simulate Redis restart by flushing all scripts - await model.Meta.redis.execute_command("SCRIPT", "FLUSH") + redis_model.tags.append("f") + redis_model.tags.remove_range(1, 3) + redis_model.metadata["key2"] = "value2" - # Assert - pipeline should have recovered and ALL changes applied + # Assert final_model = await ComprehensiveTestModel.aget(model.key) assert final_model.tags == ["a", "d", "e", "f"] assert final_model.metadata == {"key1": "value1", "key2": "value2"} @pytest.mark.asyncio -async def test_pipeline_recovers_with_all_redis_types_after_script_flush_sanity(): +async def test_pipeline_recovers_with_all_redis_types_after_script_flush_sanity( + flush_scripts, +): # Arrange model = TTLRefreshTestModel( name="original", @@ -44,27 +42,17 @@ async def test_pipeline_recovers_with_all_redis_types_after_script_flush_sanity( ) await model.asave() - # Act - operations on all Redis types then flush scripts + # Act async with model.apipeline() as redis_model: - # RedisInt operations redis_model.age += 5 - - # RedisFloat operations redis_model.score += 2.5 - - # RedisList operations redis_model.tags.append("f") redis_model.tags[0] = "new_a" - redis_model.tags.remove_range(1, 3) # Lua script (evalsha) - - # RedisDict operations + redis_model.tags.remove_range(1, 3) redis_model.settings["setting2"] = "value2" redis_model.settings.update({"setting3": "value3"}) - # Simulate Redis restart - await model.Meta.redis.execute_command("SCRIPT", "FLUSH") - - # Assert - all operations on all types should succeed + # Assert final_model = await TTLRefreshTestModel.aget(model.key) assert final_model.age == 15 assert final_model.score == 4.0 @@ -79,16 +67,17 @@ async def test_pipeline_recovers_with_all_redis_types_after_script_flush_sanity( @pytest.mark.asyncio -async def test_pipeline_raises_persistent_noscript_error_when_scripts_keep_failing_error(): +async def test_pipeline_raises_persistent_noscript_error_when_scripts_keep_failing_error( + flush_scripts, + disable_noscript_recovery, +): # Arrange model = ComprehensiveTestModel(tags=["a", "b", "c"]) await model.asave() - await model.Meta.redis.execute_command("SCRIPT", "FLUSH") - # Act & Assert - patch handle_noscript_error to not actually register scripts - with patch("rapyer.base.handle_noscript_error", new_callable=AsyncMock): - with pytest.raises(PersistentNoScriptError) as exc_info: - async with model.apipeline() as redis_model: - redis_model.tags.remove_range(0, 1) + # Act & Assert + with pytest.raises(PersistentNoScriptError) as exc_info: + async with model.apipeline() as redis_model: + redis_model.tags.remove_range(0, 1) - assert "server-side" in str(exc_info.value).lower() + assert "server-side" in str(exc_info.value).lower() diff --git a/tests/unit/test_scripts.py b/tests/unit/test_scripts.py index 6acef3fa..d8fea478 100644 --- a/tests/unit/test_scripts.py +++ b/tests/unit/test_scripts.py @@ -1,7 +1,6 @@ from unittest.mock import AsyncMock, MagicMock import pytest - from rapyer.errors import ScriptsNotInitializedError from rapyer.scripts import ( run_sha, @@ -52,15 +51,33 @@ async def test_handle_noscript_error_reloads_scripts_sanity(clear_script_state): # Arrange mock_redis = AsyncMock() mock_redis.script_load = AsyncMock(return_value="new_sha_456") + mock_config = MagicMock() + mock_config.is_fake_redis = False # Act - await handle_noscript_error(mock_redis) + await handle_noscript_error(mock_redis, mock_config) # Assert mock_redis.script_load.assert_called() assert _REGISTERED_SCRIPT_SHAS.get(REMOVE_RANGE_SCRIPT_NAME) == "new_sha_456" +@pytest.mark.asyncio +async def test_handle_noscript_error_reloads_scripts_with_fakeredis(clear_script_state): + # Arrange + mock_redis = AsyncMock() + mock_redis.script_load = AsyncMock(return_value="fakeredis_sha_789") + mock_config = MagicMock() + mock_config.is_fake_redis = True + + # Act + await handle_noscript_error(mock_redis, mock_config) + + # Assert + mock_redis.script_load.assert_called() + assert _REGISTERED_SCRIPT_SHAS.get(REMOVE_RANGE_SCRIPT_NAME) == "fakeredis_sha_789" + + @pytest.mark.asyncio async def test_register_scripts_stores_shas_sanity(clear_script_state): # Arrange diff --git a/tests/unit/types/test_dict_lua_scripts_with_fakeredis.py b/tests/unit/types/test_dict_lua_scripts_with_fakeredis.py index 8b6b68c3..16acc328 100644 --- a/tests/unit/types/test_dict_lua_scripts_with_fakeredis.py +++ b/tests/unit/types/test_dict_lua_scripts_with_fakeredis.py @@ -1,4 +1,5 @@ import pytest +import pytest_asyncio from tests.models.redis_types import DirectRedisDictModel @@ -43,3 +44,44 @@ async def test_redis_dict_apopitem_with_fakeredis_sanity(setup_fake_redis): assert result == "only_value" loaded = await DirectRedisDictModel.aget(model.key) assert len(loaded.metadata) == 0 + + +@pytest_asyncio.fixture +async def flush_fakeredis_scripts(setup_fake_redis): + await DirectRedisDictModel.Meta.redis.execute_command("SCRIPT", "FLUSH") + yield + + +@pytest.mark.asyncio +async def test_redis_dict_apop_recovers_from_noscript_with_fakeredis( + flush_fakeredis_scripts, +): + # Arrange + model = DirectRedisDictModel(metadata={"key1": "value1", "key2": "value2"}) + await model.asave() + + # Act + result = await model.metadata.apop("key1") + + # Assert + assert result == "value1" + loaded = await DirectRedisDictModel.aget(model.key) + assert "key1" not in loaded.metadata + assert loaded.metadata["key2"] == "value2" + + +@pytest.mark.asyncio +async def test_redis_dict_apopitem_recovers_from_noscript_with_fakeredis( + flush_fakeredis_scripts, +): + # Arrange + model = DirectRedisDictModel(metadata={"only_key": "only_value"}) + await model.asave() + + # Act + result = await model.metadata.apopitem() + + # Assert + assert result == "only_value" + loaded = await DirectRedisDictModel.aget(model.key) + assert len(loaded.metadata) == 0