Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions ocpp/charge_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from ocpp.exceptions import NotImplementedError, NotSupportedError, OCPPError
from ocpp.messages import Call, MessageType, unpack, validate_payload
from ocpp.routing import create_route_map
from ocpp.routing import create_route_map, discover_message_hooks

LOGGER = logging.getLogger("ocpp")

Expand Down Expand Up @@ -227,6 +227,9 @@ def __init__(self, id, connection, response_timeout=30, logger=LOGGER):
# if exists.
self.route_map = create_route_map(self)

# A dictionary that holds global message hooks
self._message_hooks = discover_message_hooks(self)

self._call_lock = asyncio.Lock()

# A queue used to pass CallResults and CallErrors from
Expand All @@ -241,6 +244,23 @@ def __init__(self, id, connection, response_timeout=30, logger=LOGGER):
# The logger used to log messages
self.logger = logger

async def _execute_hooks(self, hook_type, *args, **kwargs):
"""
Execute all hooks of a given type with error handling.

Args:
hook_type: Type of hook to execute ('before_message', 'after_message', etc.)
*args: Arguments to pass to the hook functions
**kwargs: Keyword arguments to pass to the hook functions
"""
for hook in self._message_hooks.get(hook_type, []):
try:
result = hook(*args, **kwargs)
if inspect.isawaitable(result):
await result
except Exception as e:
self.logger.exception(f"Error in {hook_type} hook {hook.__name__}: {e}")

async def start(self):
while True:
message = await self._connection.recv()
Expand All @@ -256,12 +276,14 @@ async def route_message(self, raw_msg):
If the message is of type CallResult or CallError the message is passed
to the call() function via the response_queue.
"""
# Execute before_message hooks
await self._execute_hooks("before_message", raw_msg)

try:
msg = unpack(raw_msg)
except OCPPError as e:
self.logger.exception(
"Unable to parse message: '%s', it doesn't seem "
"to be valid OCPP: %s",
"Unable to parse message: '%s', it doesn't seem to be valid OCPP: %s",
raw_msg,
e,
)
Expand All @@ -278,6 +300,9 @@ async def route_message(self, raw_msg):
elif msg.message_type_id in [MessageType.CallResult, MessageType.CallError]:
self._response_queue.put_nowait(msg)

# Execute after_message hooks
await self._execute_hooks("after_message", raw_msg, msg)

async def _handle_call(self, msg):
"""
Execute all hooks installed for based on the Action of the message.
Expand Down
86 changes: 86 additions & 0 deletions ocpp/routing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools

routables = []
global_hooks = []


def on(action, *, skip_schema_validation=False):
Expand Down Expand Up @@ -83,6 +84,65 @@ def inner(*args, **kwargs):
return decorator


def before_message(func):
"""
Function decorator to mark function as a global hook that runs before
any message processing. The wrapped function may be async or sync.

The hook function will receive the raw message as its first argument.
It's recommended you use `**kwargs` in your definition to ignore any
extra arguments that may be added in the future.

The hook function should not return anything.

It can be used like so:

```
class MyChargePoint(cp):
@before_message
async def log_incoming_message(self, raw_msg, **kwargs):
await self.db.save_message(raw_msg)
```
"""

@functools.wraps(func)
def inner(*args, **kwargs):
return func(*args, **kwargs)

inner._before_message = True
if func.__name__ not in global_hooks:
global_hooks.append(func.__name__)
return inner


def after_message(func):
"""
Function decorator to mark function as a global hook that runs after
any message processing. The wrapped function may be async or sync.

The hook function will receive the raw message as its first argument
and the parsed message as its second argument.

It can be used like so:

```
class MyChargePoint(cp):
@after_message
async def log_processed_message(self, raw_msg, parsed_msg, **kwargs):
await self.db.update_message_status(parsed_msg.unique_id, 'processed')
```
"""

@functools.wraps(func)
def inner(*args, **kwargs):
return func(*args, **kwargs)

inner._after_message = True
if func.__name__ not in global_hooks:
global_hooks.append(func.__name__)
return inner


def create_route_map(obj):
"""
Iterates of all attributes of the class looking for attributes which
Expand Down Expand Up @@ -137,3 +197,29 @@ def after_boot_notification(self, *args, **kwargs):
continue

return routes


def discover_message_hooks(obj):
"""
Discovers and organizes global message hooks from decorated methods.

Returns a dictionary with hook types as keys and lists of handlers as values.
"""
hooks = {
"before_message": [],
"after_message": [],
}

for attr_name in global_hooks:
try:
attr = getattr(obj, attr_name)

if hasattr(attr, "_before_message"):
hooks["before_message"].append(attr)
if hasattr(attr, "_after_message"):
hooks["after_message"].append(attr)

except AttributeError:
continue

return hooks
82 changes: 82 additions & 0 deletions tests/test_charge_point.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from dataclasses import asdict
from unittest.mock import AsyncMock, Mock

import pytest

from ocpp.charge_point import (
ChargePoint,
camel_to_snake_case,
remove_nones,
serialize_as_dict,
Expand Down Expand Up @@ -476,3 +478,83 @@ def after_boot_notification(self, *args, **kwargs):
assert ChargerA.after_boot_notification_call_count == 1
assert ChargerB.on_boot_notification_call_count == 1
assert ChargerB.after_boot_notification_call_count == 1


@pytest.mark.asyncio
async def test_execute_hooks_no_hooks():
"""Test _execute_hooks when no hooks exist for the hook type"""
# Arrange
mock_connection = AsyncMock()
mock_logger = Mock()

charge_point = ChargePoint("test_id", mock_connection, logger=mock_logger)
charge_point._message_hooks = {}

# Act
await charge_point._execute_hooks("before_message", "some_arg")

# Assert
mock_logger.exception.assert_not_called()


@pytest.mark.asyncio
async def test_execute_hooks_sync():
"""Test _execute_hooks actually calls a registered hook"""
# Arrange
mock_connection = AsyncMock()

charge_point = ChargePoint("test_id", mock_connection)

# Create a mock hook
mock_hook = Mock()
charge_point._message_hooks = {"before_message": [mock_hook]}

# Act
await charge_point._execute_hooks("before_message", "test_arg", kwarg1="test_value")

# Assert
mock_hook.assert_called_once_with("test_arg", kwarg1="test_value")


@pytest.mark.asyncio
async def test_execute_hooks_async():
"""Test _execute_hooks calls and awaits an async hook"""
# Arrange
mock_connection = AsyncMock()

charge_point = ChargePoint("test_id", mock_connection)

# Create an async mock hook
mock_async_hook = AsyncMock()
charge_point._message_hooks = {"after_message": [mock_async_hook]}

# Act
await charge_point._execute_hooks("after_message", "test_message")

# Assert
mock_async_hook.assert_called_once_with("test_message")
mock_async_hook.assert_awaited_once()


@pytest.mark.asyncio
async def test_execute_hooks_calls_multiple():
"""Test _execute_hooks calls all hooks of the same type"""
# Arrange
mock_connection = AsyncMock()

charge_point = ChargePoint("test_id", mock_connection)

# Create multiple mock hooks
hook1 = Mock()
hook2 = Mock()
hook3 = Mock()

charge_point._message_hooks = {"before_message": [hook1, hook2, hook3]}

# Act
await charge_point._execute_hooks("before_message", "msg_data", user_id="123")

# Assert
hook1.assert_called_once_with("msg_data", user_id="123")
hook2.assert_called_once_with("msg_data", user_id="123")
hook3.assert_called_once_with("msg_data", user_id="123")
57 changes: 56 additions & 1 deletion tests/test_routing.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from ocpp.routing import after, create_route_map, on
from ocpp.routing import (
after,
after_message,
before_message,
create_route_map,
discover_message_hooks,
on,
)
from ocpp.v16.enums import Action


Expand Down Expand Up @@ -39,3 +46,51 @@ def undecorated(self):
"_skip_schema_validation": False,
},
}


def test_discover_message_hooks():
"""
This test validates that message hooks is created correctly and holds all
functions decorated with the @before_message and @after_message decorators.

"""

class ChargePoint:
@before_message
def before_message_hook(self):
pass

@after_message
def after_message_hook(self):
pass

def undecorated(self):
pass

cp = ChargePoint()
hooks = discover_message_hooks(cp)

assert hooks == {
"before_message": [cp.before_message_hook],
"after_message": [cp.after_message_hook],
}


def test_discover_message_hooks_empty():
"""Test that discover_message_hooks works with no hooks."""

class ChargePoint:
@on(Action.heartbeat)
def on_heartbeat(self):
pass

def undecorated(self):
pass

cp = ChargePoint()
hooks = discover_message_hooks(cp)

assert hooks == {
"before_message": [],
"after_message": [],
}