diff --git a/examples/basic.py b/examples/mari_edge.py similarity index 91% rename from examples/basic.py rename to examples/mari_edge.py index 549bbde..2c5d659 100644 --- a/examples/basic.py +++ b/examples/mari_edge.py @@ -14,16 +14,16 @@ def on_event(event: EdgeEvent, event_data: MariNode | Frame): if event == EdgeEvent.NODE_JOINED: assert isinstance(event_data, MariNode) - # print(f"Node {event_data} joined") + print(f"Node {event_data} joined") # print("#", end="", flush=True) elif event == EdgeEvent.NODE_LEFT: assert isinstance(event_data, MariNode) - # print(f"Node {event_data} left") + print(f"Node {event_data} left") # print("0", end="", flush=True) elif event == EdgeEvent.NODE_DATA: assert isinstance(event_data, Frame) # print(f"Got frame from 0x{event_data.header.source:016x}: {event_data.payload.hex()}, rssi: {event_data.stats.rssi_dbm}") - # print(".", end="", flush=True) + print(".", end="", flush=True) @click.command() @@ -47,7 +47,7 @@ def main(port: str | None): # # print(f"Sending frame to 0x{node.address:016x}") # mari.send_frame(dst=node.address, payload=b"A" * 3) time.sleep(0.3) - tui.render(mari) + # tui.render(mari) except KeyboardInterrupt: print("\nInterrupted by user") diff --git a/marilib/ela_handler.py b/marilib/ela_handler.py new file mode 100644 index 0000000..b0a04b2 --- /dev/null +++ b/marilib/ela_handler.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass +import lakers +import logging +from marilib.model import MariNode +from marilib.mari_protocol import Frame +import requests +import random + +V = bytes.fromhex("72cc4761dbd4c78f758931aa589d348d1ef874a7e303ede2f140dcf3e6aa4aac") +CRED_V = bytes.fromhex( + "a2026b6578616d706c652e65647508a101a501020241322001215820bbc34960526ea4d32e940cad2a234148ddc21791a12afbcbac93622046dd44f02258204519e257236b2a0ce2023f0931f1f386ca7afda64fcde0108c224c51eabf6072" +) +CBOR_TRUE = bytes.fromhex("f5") + +CRED_REQUEST_PATH = ".well-known/lake-authz/cred-request" + +@dataclass +class ELAHandler: + edhoc_responder: lakers.EdhocResponder = lakers.EdhocResponder(V, CRED_V) + authz_authenticator: lakers.AuthzAutenticator = lakers.AuthzAutenticator() + loc_w: str | None = None + c_i: bytes | None = None + c_r: bytes | None = None + state: str = "init" + node: MariNode | None = None + received_voucher: bytes | None = None + + def handle_join_request(self, frame: Frame): + res = self.handle_message_1(frame) + if not res: + print("Failed to handle message 1") + return False, None + res, message_2 = self.prepare_message_2() + if not res: + print("Failed to prepare message 2") + return False, None + message_2 = self.c_i + message_2 + return True, message_2 + + def handle_message_1(self, frame: Frame): + # handle message 1 + if not frame.payload.startswith(CBOR_TRUE): + self.state = "error" + return False + message_1 = frame.payload[1:] + c_i, ead_1 = self.edhoc_responder.process_message_1(message_1) + self.c_i = c_i + print( + f"edhoc_message_1: {message_1.hex(' ').upper()} ead_1: {ead_1.value().hex(' ').upper() if ead_1 else None}" + ) + + # request voucher + loc_w, voucher_request = self.authz_authenticator.process_ead_1(ead_1, message_1) + voucher_request_url = f"{loc_w}/.well-known/lake-authz/voucher-request" + print( + f"Requesting voucher: {voucher_request_url} {voucher_request.hex(' ').upper()}" + ) + response = requests.post(voucher_request_url, data=voucher_request) + if response.status_code == 200: + print( + f"Got an ok voucher response: {response.content.hex(' ').upper()}" + ) + self.received_voucher = response.content + self.state = "received_voucher" + else: + print( + f"Error requesting voucher: {response.status_code}" + ) + + return True + + def prepare_message_2(self): + if self.state != "received_voucher": + return False, None + ead_2 = self.authz_authenticator.prepare_ead_2(self.received_voucher) + print(f"Prepared ead_2: {ead_2.value().hex(' ').upper()}") + self.c_r = [random.randint(0, 23)] # already cbor-encoded as single-byte integer + message_2 = self.edhoc_responder.prepare_message_2(lakers.CredentialTransfer.ByValue, self.c_r, ead_2) + print(f"Prepared message_2: {message_2.hex(' ').upper()}") + return True, message_2 + + @staticmethod + def fetch_credential_remotely(loc_w: str, id_cred_i: bytes) -> bytes: + url = f"{loc_w}/{CRED_REQUEST_PATH}" + res = requests.post(url, data=id_cred_i) + if res.status_code == 200: + return res.content + else: + raise Exception(f"Error fetching credential {id_cred_i} at {loc_w}") diff --git a/marilib/marilib.py b/marilib/marilib.py index a837753..570593b 100644 --- a/marilib/marilib.py +++ b/marilib/marilib.py @@ -2,11 +2,15 @@ from datetime import datetime from typing import Callable -from marilib.mari_protocol import MARI_BROADCAST_ADDRESS, Frame, Header +from marilib.mari_protocol import MARI_BROADCAST_ADDRESS, Frame, PacketType, Header from marilib.model import EdgeEvent, GatewayInfo, MariGateway, MariNode from marilib.protocol import ProtocolPayloadParserException from marilib.serial_adapter import SerialAdapter from marilib.serial_uart import get_default_port +from marilib.ela_handler import ELAHandler + + +USE_ELA = True @dataclass @@ -18,6 +22,7 @@ class MariLib: serial_interface: SerialAdapter | None = None started_ts: datetime = field(default_factory=datetime.now) last_received_serial_data: datetime = field(default_factory=datetime.now) + pending_edhoc_sessions: dict[int, ELAHandler] = field(default_factory=dict) def __post_init__(self): if self.port is None: @@ -39,16 +44,21 @@ def on_data_received(self, data: bytes): # print(bytes(data).hex()) if event_type == EdgeEvent.NODE_JOINED: - address = int.from_bytes(data[1:9], "little") - # print(f"Event: {EdgeEvent.NODE_JOINED.name} {address}") - node = self.gateway.add_node(address) - self.cb_application(EdgeEvent.NODE_JOINED, node) + frame_bytes = data[1:] + frame = Frame().from_bytes(frame_bytes) - elif event_type == EdgeEvent.NODE_LEFT: - address = int.from_bytes(data[1:9], "little") - # print(f"Event: {EdgeEvent.NODE_LEFT.name} {address}") - if node := self.gateway.remove_node(address): - self.cb_application(EdgeEvent.NODE_LEFT, node) + if USE_ELA: + print(f"Node trying to join: {frame}") + ela_handler = ELAHandler() + res, join_response_payload = ela_handler.handle_join_request(frame) + if res: + self.send_join_response(frame.header.destination, join_response_payload) + else: + print(f"Node joined: {frame.header.destination}") + address = frame.header.destination + # print(f"Event: {EdgeEvent.NODE_JOINED.name} {address}") + node = self.gateway.add_node(address) + self.cb_application(EdgeEvent.NODE_JOINED, node) elif event_type == EdgeEvent.NODE_KEEP_ALIVE: address = int.from_bytes(data[1:9], "little") @@ -88,3 +98,10 @@ def send_frame(self, dst: int, payload: bytes): for node in self.gateway.nodes: node.register_sent_frame(mari_frame) self.gateway.register_sent_frame(mari_frame) + + def send_join_response(self, dst: int, payload: bytes): + assert self.serial_interface is not None + mari_frame = Frame(Header(destination=dst, type_=PacketType.JOIN_RESPONSE), payload=payload) + uart_frame_type = b"\x01" + uart_frame = uart_frame_type + mari_frame.to_bytes() + self.serial_interface.send_data(uart_frame) diff --git a/pyproject.toml b/pyproject.toml index 17cb100..5d6c45c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "rich == 14.0.0", "structlog == 24.4.0", "tqdm == 4.66.5", + "requests == 2.32.3", ] description = "MariLib is a Python library for interacting with the Mari network." readme = "README.md"