From 372489e768c921dcd2222319bcbe79b903ec0cb9 Mon Sep 17 00:00:00 2001 From: Ali Al-Alak Date: Sun, 6 Jul 2025 18:07:37 +0200 Subject: [PATCH] Enabling global hooks to be registered. Hooks are registered similar to routing with the @before_message and @after_message decorators. The hooks are added to _message_hooks property of the ChargePoint class and executed either before or after routing the ocpp call. --- ocpp/charge_point.py | 31 ++++++++++++-- ocpp/routing.py | 86 ++++++++++++++++++++++++++++++++++++++ tests/test_charge_point.py | 82 ++++++++++++++++++++++++++++++++++++ tests/test_routing.py | 57 ++++++++++++++++++++++++- 4 files changed, 252 insertions(+), 4 deletions(-) diff --git a/ocpp/charge_point.py b/ocpp/charge_point.py index 0d6d11a66..79f19e631 100644 --- a/ocpp/charge_point.py +++ b/ocpp/charge_point.py @@ -9,7 +9,7 @@ from ocpp.exceptions import NotImplementedError, NotSupportedError, OCPPError from ocpp.messages import Call, MessageType, unpack, validate_payload -from ocpp.routing import create_route_map +from ocpp.routing import create_route_map, discover_message_hooks LOGGER = logging.getLogger("ocpp") @@ -227,6 +227,9 @@ def __init__(self, id, connection, response_timeout=30, logger=LOGGER): # if exists. self.route_map = create_route_map(self) + # A dictionary that holds global message hooks + self._message_hooks = discover_message_hooks(self) + self._call_lock = asyncio.Lock() # A queue used to pass CallResults and CallErrors from @@ -241,6 +244,23 @@ def __init__(self, id, connection, response_timeout=30, logger=LOGGER): # The logger used to log messages self.logger = logger + async def _execute_hooks(self, hook_type, *args, **kwargs): + """ + Execute all hooks of a given type with error handling. + + Args: + hook_type: Type of hook to execute ('before_message', 'after_message', etc.) + *args: Arguments to pass to the hook functions + **kwargs: Keyword arguments to pass to the hook functions + """ + for hook in self._message_hooks.get(hook_type, []): + try: + result = hook(*args, **kwargs) + if inspect.isawaitable(result): + await result + except Exception as e: + self.logger.exception(f"Error in {hook_type} hook {hook.__name__}: {e}") + async def start(self): while True: message = await self._connection.recv() @@ -256,12 +276,14 @@ async def route_message(self, raw_msg): If the message is of type CallResult or CallError the message is passed to the call() function via the response_queue. """ + # Execute before_message hooks + await self._execute_hooks("before_message", raw_msg) + try: msg = unpack(raw_msg) except OCPPError as e: self.logger.exception( - "Unable to parse message: '%s', it doesn't seem " - "to be valid OCPP: %s", + "Unable to parse message: '%s', it doesn't seem to be valid OCPP: %s", raw_msg, e, ) @@ -278,6 +300,9 @@ async def route_message(self, raw_msg): elif msg.message_type_id in [MessageType.CallResult, MessageType.CallError]: self._response_queue.put_nowait(msg) + # Execute after_message hooks + await self._execute_hooks("after_message", raw_msg, msg) + async def _handle_call(self, msg): """ Execute all hooks installed for based on the Action of the message. diff --git a/ocpp/routing.py b/ocpp/routing.py index bf4b12f2b..31352989d 100644 --- a/ocpp/routing.py +++ b/ocpp/routing.py @@ -1,6 +1,7 @@ import functools routables = [] +global_hooks = [] def on(action, *, skip_schema_validation=False): @@ -83,6 +84,65 @@ def inner(*args, **kwargs): return decorator +def before_message(func): + """ + Function decorator to mark function as a global hook that runs before + any message processing. The wrapped function may be async or sync. + + The hook function will receive the raw message as its first argument. + It's recommended you use `**kwargs` in your definition to ignore any + extra arguments that may be added in the future. + + The hook function should not return anything. + + It can be used like so: + + ``` + class MyChargePoint(cp): + @before_message + async def log_incoming_message(self, raw_msg, **kwargs): + await self.db.save_message(raw_msg) + ``` + """ + + @functools.wraps(func) + def inner(*args, **kwargs): + return func(*args, **kwargs) + + inner._before_message = True + if func.__name__ not in global_hooks: + global_hooks.append(func.__name__) + return inner + + +def after_message(func): + """ + Function decorator to mark function as a global hook that runs after + any message processing. The wrapped function may be async or sync. + + The hook function will receive the raw message as its first argument + and the parsed message as its second argument. + + It can be used like so: + + ``` + class MyChargePoint(cp): + @after_message + async def log_processed_message(self, raw_msg, parsed_msg, **kwargs): + await self.db.update_message_status(parsed_msg.unique_id, 'processed') + ``` + """ + + @functools.wraps(func) + def inner(*args, **kwargs): + return func(*args, **kwargs) + + inner._after_message = True + if func.__name__ not in global_hooks: + global_hooks.append(func.__name__) + return inner + + def create_route_map(obj): """ Iterates of all attributes of the class looking for attributes which @@ -137,3 +197,29 @@ def after_boot_notification(self, *args, **kwargs): continue return routes + + +def discover_message_hooks(obj): + """ + Discovers and organizes global message hooks from decorated methods. + + Returns a dictionary with hook types as keys and lists of handlers as values. + """ + hooks = { + "before_message": [], + "after_message": [], + } + + for attr_name in global_hooks: + try: + attr = getattr(obj, attr_name) + + if hasattr(attr, "_before_message"): + hooks["before_message"].append(attr) + if hasattr(attr, "_after_message"): + hooks["after_message"].append(attr) + + except AttributeError: + continue + + return hooks diff --git a/tests/test_charge_point.py b/tests/test_charge_point.py index 25a7c8e0d..a1899b504 100644 --- a/tests/test_charge_point.py +++ b/tests/test_charge_point.py @@ -1,8 +1,10 @@ from dataclasses import asdict +from unittest.mock import AsyncMock, Mock import pytest from ocpp.charge_point import ( + ChargePoint, camel_to_snake_case, remove_nones, serialize_as_dict, @@ -476,3 +478,83 @@ def after_boot_notification(self, *args, **kwargs): assert ChargerA.after_boot_notification_call_count == 1 assert ChargerB.on_boot_notification_call_count == 1 assert ChargerB.after_boot_notification_call_count == 1 + + +@pytest.mark.asyncio +async def test_execute_hooks_no_hooks(): + """Test _execute_hooks when no hooks exist for the hook type""" + # Arrange + mock_connection = AsyncMock() + mock_logger = Mock() + + charge_point = ChargePoint("test_id", mock_connection, logger=mock_logger) + charge_point._message_hooks = {} + + # Act + await charge_point._execute_hooks("before_message", "some_arg") + + # Assert + mock_logger.exception.assert_not_called() + + +@pytest.mark.asyncio +async def test_execute_hooks_sync(): + """Test _execute_hooks actually calls a registered hook""" + # Arrange + mock_connection = AsyncMock() + + charge_point = ChargePoint("test_id", mock_connection) + + # Create a mock hook + mock_hook = Mock() + charge_point._message_hooks = {"before_message": [mock_hook]} + + # Act + await charge_point._execute_hooks("before_message", "test_arg", kwarg1="test_value") + + # Assert + mock_hook.assert_called_once_with("test_arg", kwarg1="test_value") + + +@pytest.mark.asyncio +async def test_execute_hooks_async(): + """Test _execute_hooks calls and awaits an async hook""" + # Arrange + mock_connection = AsyncMock() + + charge_point = ChargePoint("test_id", mock_connection) + + # Create an async mock hook + mock_async_hook = AsyncMock() + charge_point._message_hooks = {"after_message": [mock_async_hook]} + + # Act + await charge_point._execute_hooks("after_message", "test_message") + + # Assert + mock_async_hook.assert_called_once_with("test_message") + mock_async_hook.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_execute_hooks_calls_multiple(): + """Test _execute_hooks calls all hooks of the same type""" + # Arrange + mock_connection = AsyncMock() + + charge_point = ChargePoint("test_id", mock_connection) + + # Create multiple mock hooks + hook1 = Mock() + hook2 = Mock() + hook3 = Mock() + + charge_point._message_hooks = {"before_message": [hook1, hook2, hook3]} + + # Act + await charge_point._execute_hooks("before_message", "msg_data", user_id="123") + + # Assert + hook1.assert_called_once_with("msg_data", user_id="123") + hook2.assert_called_once_with("msg_data", user_id="123") + hook3.assert_called_once_with("msg_data", user_id="123") diff --git a/tests/test_routing.py b/tests/test_routing.py index b443a3c1f..ffbcee626 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -1,4 +1,11 @@ -from ocpp.routing import after, create_route_map, on +from ocpp.routing import ( + after, + after_message, + before_message, + create_route_map, + discover_message_hooks, + on, +) from ocpp.v16.enums import Action @@ -39,3 +46,51 @@ def undecorated(self): "_skip_schema_validation": False, }, } + + +def test_discover_message_hooks(): + """ + This test validates that message hooks is created correctly and holds all + functions decorated with the @before_message and @after_message decorators. + + """ + + class ChargePoint: + @before_message + def before_message_hook(self): + pass + + @after_message + def after_message_hook(self): + pass + + def undecorated(self): + pass + + cp = ChargePoint() + hooks = discover_message_hooks(cp) + + assert hooks == { + "before_message": [cp.before_message_hook], + "after_message": [cp.after_message_hook], + } + + +def test_discover_message_hooks_empty(): + """Test that discover_message_hooks works with no hooks.""" + + class ChargePoint: + @on(Action.heartbeat) + def on_heartbeat(self): + pass + + def undecorated(self): + pass + + cp = ChargePoint() + hooks = discover_message_hooks(cp) + + assert hooks == { + "before_message": [], + "after_message": [], + }