Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/basic.py → examples/mari_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,16 @@
def on_event(event: EdgeEvent, event_data: MariNode | Frame):
if event == EdgeEvent.NODE_JOINED:
assert isinstance(event_data, MariNode)
# print(f"Node {event_data} joined")
print(f"Node {event_data} joined")
# print("#", end="", flush=True)
elif event == EdgeEvent.NODE_LEFT:
assert isinstance(event_data, MariNode)
# print(f"Node {event_data} left")
print(f"Node {event_data} left")
# print("0", end="", flush=True)
elif event == EdgeEvent.NODE_DATA:
assert isinstance(event_data, Frame)
# print(f"Got frame from 0x{event_data.header.source:016x}: {event_data.payload.hex()}, rssi: {event_data.stats.rssi_dbm}")
# print(".", end="", flush=True)
print(".", end="", flush=True)


@click.command()
Expand All @@ -47,7 +47,7 @@ def main(port: str | None):
# # print(f"Sending frame to 0x{node.address:016x}")
# mari.send_frame(dst=node.address, payload=b"A" * 3)
time.sleep(0.3)
tui.render(mari)
# tui.render(mari)

except KeyboardInterrupt:
print("\nInterrupted by user")
Expand Down
89 changes: 89 additions & 0 deletions marilib/ela_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from dataclasses import dataclass
import lakers
import logging
from marilib.model import MariNode
from marilib.mari_protocol import Frame
import requests
import random

V = bytes.fromhex("72cc4761dbd4c78f758931aa589d348d1ef874a7e303ede2f140dcf3e6aa4aac")
CRED_V = bytes.fromhex(
"a2026b6578616d706c652e65647508a101a501020241322001215820bbc34960526ea4d32e940cad2a234148ddc21791a12afbcbac93622046dd44f02258204519e257236b2a0ce2023f0931f1f386ca7afda64fcde0108c224c51eabf6072"
)
CBOR_TRUE = bytes.fromhex("f5")

CRED_REQUEST_PATH = ".well-known/lake-authz/cred-request"

@dataclass
class ELAHandler:
edhoc_responder: lakers.EdhocResponder = lakers.EdhocResponder(V, CRED_V)
authz_authenticator: lakers.AuthzAutenticator = lakers.AuthzAutenticator()
loc_w: str | None = None
c_i: bytes | None = None
c_r: bytes | None = None
state: str = "init"
node: MariNode | None = None
received_voucher: bytes | None = None

def handle_join_request(self, frame: Frame):
res = self.handle_message_1(frame)
if not res:
print("Failed to handle message 1")
return False, None
res, message_2 = self.prepare_message_2()
if not res:
print("Failed to prepare message 2")
return False, None
message_2 = self.c_i + message_2
return True, message_2

def handle_message_1(self, frame: Frame):
# handle message 1
if not frame.payload.startswith(CBOR_TRUE):
self.state = "error"
return False
message_1 = frame.payload[1:]
c_i, ead_1 = self.edhoc_responder.process_message_1(message_1)
self.c_i = c_i
print(
f"edhoc_message_1: {message_1.hex(' ').upper()} ead_1: {ead_1.value().hex(' ').upper() if ead_1 else None}"
)

# request voucher
loc_w, voucher_request = self.authz_authenticator.process_ead_1(ead_1, message_1)
voucher_request_url = f"{loc_w}/.well-known/lake-authz/voucher-request"
print(
f"Requesting voucher: {voucher_request_url} {voucher_request.hex(' ').upper()}"
)
response = requests.post(voucher_request_url, data=voucher_request)
if response.status_code == 200:
print(
f"Got an ok voucher response: {response.content.hex(' ').upper()}"
)
self.received_voucher = response.content
self.state = "received_voucher"
else:
print(
f"Error requesting voucher: {response.status_code}"
)

return True

def prepare_message_2(self):
if self.state != "received_voucher":
return False, None
ead_2 = self.authz_authenticator.prepare_ead_2(self.received_voucher)
print(f"Prepared ead_2: {ead_2.value().hex(' ').upper()}")
self.c_r = [random.randint(0, 23)] # already cbor-encoded as single-byte integer
message_2 = self.edhoc_responder.prepare_message_2(lakers.CredentialTransfer.ByValue, self.c_r, ead_2)
print(f"Prepared message_2: {message_2.hex(' ').upper()}")
return True, message_2

@staticmethod
def fetch_credential_remotely(loc_w: str, id_cred_i: bytes) -> bytes:
url = f"{loc_w}/{CRED_REQUEST_PATH}"
res = requests.post(url, data=id_cred_i)
if res.status_code == 200:
return res.content
else:
raise Exception(f"Error fetching credential {id_cred_i} at {loc_w}")
37 changes: 27 additions & 10 deletions marilib/marilib.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@
from datetime import datetime
from typing import Callable

from marilib.mari_protocol import MARI_BROADCAST_ADDRESS, Frame, Header
from marilib.mari_protocol import MARI_BROADCAST_ADDRESS, Frame, PacketType, Header
from marilib.model import EdgeEvent, GatewayInfo, MariGateway, MariNode
from marilib.protocol import ProtocolPayloadParserException
from marilib.serial_adapter import SerialAdapter
from marilib.serial_uart import get_default_port
from marilib.ela_handler import ELAHandler


USE_ELA = True


@dataclass
Expand All @@ -18,6 +22,7 @@ class MariLib:
serial_interface: SerialAdapter | None = None
started_ts: datetime = field(default_factory=datetime.now)
last_received_serial_data: datetime = field(default_factory=datetime.now)
pending_edhoc_sessions: dict[int, ELAHandler] = field(default_factory=dict)

def __post_init__(self):
if self.port is None:
Expand All @@ -39,16 +44,21 @@ def on_data_received(self, data: bytes):
# print(bytes(data).hex())

if event_type == EdgeEvent.NODE_JOINED:
address = int.from_bytes(data[1:9], "little")
# print(f"Event: {EdgeEvent.NODE_JOINED.name} {address}")
node = self.gateway.add_node(address)
self.cb_application(EdgeEvent.NODE_JOINED, node)
frame_bytes = data[1:]
frame = Frame().from_bytes(frame_bytes)

elif event_type == EdgeEvent.NODE_LEFT:
address = int.from_bytes(data[1:9], "little")
# print(f"Event: {EdgeEvent.NODE_LEFT.name} {address}")
if node := self.gateway.remove_node(address):
self.cb_application(EdgeEvent.NODE_LEFT, node)
if USE_ELA:
print(f"Node trying to join: {frame}")
ela_handler = ELAHandler()
res, join_response_payload = ela_handler.handle_join_request(frame)
if res:
self.send_join_response(frame.header.destination, join_response_payload)
else:
print(f"Node joined: {frame.header.destination}")
address = frame.header.destination
# print(f"Event: {EdgeEvent.NODE_JOINED.name} {address}")
node = self.gateway.add_node(address)
self.cb_application(EdgeEvent.NODE_JOINED, node)

elif event_type == EdgeEvent.NODE_KEEP_ALIVE:
address = int.from_bytes(data[1:9], "little")
Expand Down Expand Up @@ -88,3 +98,10 @@ def send_frame(self, dst: int, payload: bytes):
for node in self.gateway.nodes:
node.register_sent_frame(mari_frame)
self.gateway.register_sent_frame(mari_frame)

def send_join_response(self, dst: int, payload: bytes):
assert self.serial_interface is not None
mari_frame = Frame(Header(destination=dst, type_=PacketType.JOIN_RESPONSE), payload=payload)
uart_frame_type = b"\x01"
uart_frame = uart_frame_type + mari_frame.to_bytes()
self.serial_interface.send_data(uart_frame)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ dependencies = [
"rich == 14.0.0",
"structlog == 24.4.0",
"tqdm == 4.66.5",
"requests == 2.32.3",
]
description = "MariLib is a Python library for interacting with the Mari network."
readme = "README.md"
Expand Down
Loading