From ef30e189487d279e9b0f67cb82dc72ac8881bf45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Ryannel?= Date: Thu, 30 Mar 2023 18:05:05 +0200 Subject: [PATCH 1/5] add ws client/server --- src/olink/ws/client.py | 38 ++++++++++++++++++ src/olink/ws/emitter.py | 14 +++++++ src/olink/ws/server.py | 85 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 137 insertions(+) create mode 100644 src/olink/ws/client.py create mode 100644 src/olink/ws/emitter.py create mode 100644 src/olink/ws/server.py diff --git a/src/olink/ws/client.py b/src/olink/ws/client.py new file mode 100644 index 0000000..fe872f1 --- /dev/null +++ b/src/olink/ws/client.py @@ -0,0 +1,38 @@ +import asyncio +import websockets as ws +from .emitter import Emitter + + +class Client: + send_queue = asyncio.Queue() + recv_queue = asyncio.Queue() + node = None + def __init__(self, node): + self.node = node + + def send(self, msg): + self.send_queue.put_nowait(msg) + + async def handle_send(self): + async for msg in self.send_queue: + data = self.serializer.serialize(msg) + await self.conn.send(data) + + async def handle_recv(self): + async for msg in self.recv_queue: + self.emitter.emit(msg.object, msg) + + async def recv(self): + async for data in self.conn: + msg = self.serializer.deserialize(data) + self.recv_queue.put_nowait(msg) + + async def connect(self, addr: str): + # connect to server + async for conn in ws.connect(addr): + self.conn = conn + # start send and recv tasks + await asyncio.gather(self.handle_send(), self.handle_recv(), self.recv()) + # wait for all queues to be empty + await self.send_queue.join() + await self.recv_queue.join() \ No newline at end of file diff --git a/src/olink/ws/emitter.py b/src/olink/ws/emitter.py new file mode 100644 index 0000000..6265ed5 --- /dev/null +++ b/src/olink/ws/emitter.py @@ -0,0 +1,14 @@ + +class Emitter: + def __init__(self): + self._callbacks = {} + + def on(self, event, callback): + self._callbacks[event] = callback + + def emit(self, event, *args): + self._callbacks[event](*args) + + def off(self, event): + self._callbacks.pop(event) + diff --git a/src/olink/ws/server.py b/src/olink/ws/server.py new file mode 100644 index 0000000..608a613 --- /dev/null +++ b/src/olink/ws/server.py @@ -0,0 +1,85 @@ +import websockets as ws +from .emitter import Emitter +from typing import Any +import asyncio +from .client import Client +from ..remotenode import IObjectSource, RemoteNode + +class SourceAdapter(IObjectSource): + node: RemoteNode = None + object_id: str = None + def __init__(self, objectId: str, impl) -> None: + self.object_id = objectId + self.impl = impl + RemoteNode.register_source(self) + + def olink_object_name() -> str: + return self.objectId + + def olink_invoke(self, name: str, args: list[Any]) -> Any: + path = Name.path_from_name(name) + func = getattr(self.impl, path) + try: + result = func(**args) + except Exception as e: + print('error: %s' % e) + result = None + return result + + def olink_set_property(self, name: str, value: Any): + # set property value on implementation + path = Name.path_from_name(name) + setattr(self, self.impl, value) + + def olink_linked(self, name: str, node: "RemoteNode"): + # called when the source is linked to a client node + self.node = node + + def olink_collect_properties(self) -> object: + # collect properties from implementation to send back to client node initially + return {k: getattr(self.impl, k) for k in ['count']} + + +class RemotePipe: + send_queue = asyncio.Queue() + recv_queue = asyncio.Queue() + node = RemoteNode() + def __init__(self, conn: ws.ClientConnection): + self.conn = conn + self.node.on_write(self._send) + + def _send(self, data): + self.send_queue.put_nowait(data) + + async def handle_send(self): + async for data in self.send_queue: + await self.conn.send(data) + + async def handle_recv(self): + async for data in self.recv_queue: + self.node.handle_message(data) + + async def recv(self): + async for data in self.conn: + self.recv_queue.put_nowait(data) + +class Server: + pipes = [] + def handle_connection(self, pipe: ws.WebSocketServerProtocol, path: str): + pipe = RemotePipe(pipe, self.serializer) + self.pipes.append(pipe) + + async def serve(self, host: str, port: int): + async with ws.serve(self.handle_connection, host, port): + await asyncio.Future() + + + + +def run_server(host: str, port: int): + server = Server() + asyncio.run(server.serve(host, port)) + + +if __name__ == "__main__": + run_server("localhost", 8152) From 38d51583f89fdcf3afbe73ec9c727a781615720c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Ryannel?= Date: Fri, 31 Mar 2023 12:50:49 +0200 Subject: [PATCH 2/5] rework olink core - better isolation of remote and client - better base types for source and sinks - cleaner structure --- setup.cfg | 2 +- src/olink/client/__init__.py | 4 + src/olink/{clientnode.py => client/node.py} | 109 +---------- src/olink/client/registry.py | 80 ++++++++ src/olink/client/sink.py | 40 ++++ src/olink/client/types.py | 23 +++ src/olink/core/__init__.py | 4 + src/olink/core/hook.py | 15 ++ src/olink/core/protocol.py | 4 +- src/olink/core/types.py | 4 +- src/olink/mocks/mocksink.py | 2 +- src/olink/mocks/mocksource.py | 4 +- src/olink/remote/__init__.py | 6 + src/olink/remote/adapter.py | 38 ++++ src/olink/remote/node.py | 78 ++++++++ src/olink/remote/registry.py | 91 +++++++++ src/olink/remote/types.py | 24 +++ src/olink/remotenode.py | 193 -------------------- src/olink/ws/__init__.py | 2 + src/olink/ws/client.py | 7 +- src/olink/ws/emitter.py | 14 -- src/olink/ws/server.py | 37 +--- 22 files changed, 421 insertions(+), 360 deletions(-) create mode 100644 src/olink/client/__init__.py rename src/olink/{clientnode.py => client/node.py} (56%) create mode 100644 src/olink/client/registry.py create mode 100644 src/olink/client/sink.py create mode 100644 src/olink/client/types.py create mode 100644 src/olink/core/hook.py create mode 100644 src/olink/remote/__init__.py create mode 100644 src/olink/remote/adapter.py create mode 100644 src/olink/remote/node.py create mode 100644 src/olink/remote/registry.py create mode 100644 src/olink/remote/types.py delete mode 100644 src/olink/remotenode.py create mode 100644 src/olink/ws/__init__.py delete mode 100644 src/olink/ws/emitter.py diff --git a/setup.cfg b/setup.cfg index d30f298..f3473b4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [metadata] -name = olink-core +name = olink version = 0.0.1 author = AppiGear author_email = info@apigear.io diff --git a/src/olink/client/__init__.py b/src/olink/client/__init__.py new file mode 100644 index 0000000..58d15bf --- /dev/null +++ b/src/olink/client/__init__.py @@ -0,0 +1,4 @@ +from .node import ClientNode, InvokeReplyArg, InvokeReplyFunc +from .registry import ClientRegistry, get_client_registry +from .types import IObjectSink +from .sink import AbstractSink \ No newline at end of file diff --git a/src/olink/clientnode.py b/src/olink/client/node.py similarity index 56% rename from src/olink/clientnode.py rename to src/olink/client/node.py index 9ff6c79..974f69d 100644 --- a/src/olink/clientnode.py +++ b/src/olink/client/node.py @@ -1,9 +1,8 @@ from re import A -from typing import Any, Callable, Optional, Protocol as ProtocolType -from olink.core.types import Base, LogLevel, MsgType, Name -from olink.core.node import BaseNode -from olink.core.protocol import Protocol - +from typing import Any, Callable, Optional +from olink.core import LogLevel, MsgType, BaseNode, Protocol +from .types import IObjectSink +from .registry import ClientRegistry, get_client_registry class InvokeReplyArg: def __init__(self, name: str, value: Any): @@ -16,106 +15,6 @@ def __init__(self, name: str, value: Any): InvokeReplyFunc = Callable[[InvokeReplyArg], None] -class IObjectSink(ProtocolType): - # interface for object sinks - def olink_object_name() -> str: - # return object name - raise NotImplementedError() - - def olink_on_signal(self, name: str, args: list[Any]) -> None: - # called on signal message - raise NotImplementedError() - - def olink_on_property_changed(self, name: str, value: Any) -> None: - # called on property changed message - raise NotImplementedError() - - def olink_on_init(self, name: str, props: object, node: "ClientNode"): - # called on init message - raise NotImplementedError() - - def olink_on_release(self) -> None: - # called when sink is released - raise NotImplementedError() - - -class SinkToClientEntry: - # entry in the client registry - sink: IObjectSink = None - node: "ClientNode" = None - - def __init__(self, sink=None): - self.sink = sink - self.node = None - - -class ClientRegistry(Base): - # client side registry to link sinks to nodes - entries: dict[str, SinkToClientEntry] = {} - - def remove_node(self, node: "ClientNode"): - # remove node from all sinks - for entry in self.entries.values(): - if entry.node is node: - entry.node = None - - def add_node_to_sink(self, name: str, node: "ClientNode"): - # add not to named sink - self._entry(name).node = node - - def remove_node_from_sink(self, name: str, node: "ClientNode"): - # remove node from named sink - resource = Name.resource_from_name(name) - if resource in self.entries: - if self.entries[resource].node is node: - self.entries[resource].node = None - else: - self.emit_log( - LogLevel.DEBUG, f"unlink node failed, not the same node: {resource}") - - def register_sink(self, sink: IObjectSink) -> "ClientNode": - # register sink using object name - name = sink.olink_object_name() - entry = self._entry(name) - entry.sink = sink - return entry.node - - def unregister_sink(self, sink: IObjectSink): - # unregister sink using object name - name = sink.olink_object_name() - self._remove_entry(name) - - def get_sink(self, name: str) -> Optional[IObjectSink]: - # get sink using name - return self._entry(name).sink - - def get_node(self, name: str) -> Optional["ClientNode"]: - # get node using name - return self._entry(name).node - - def _entry(self, name: str) -> SinkToClientEntry: - # get an entry by name - resource = Name.resource_from_name(name) - if not resource in self.entries: - self.emit_log(LogLevel.DEBUG, f"add new resource: {resource}") - self.entries[resource] = SinkToClientEntry() - return self.entries[resource] - - def _remove_entry(self, name: str) -> None: - # remove an entry by name - resource = Name.resource_from_name(name) - del self.entries[resource] - - -# global client registry -_registry = ClientRegistry() - - -def get_client_registry() -> ClientRegistry: - # get global client registry - return _registry - - class ClientNode(BaseNode): # client side node invokes_pending: dict[int, InvokeReplyFunc] = {} diff --git a/src/olink/client/registry.py b/src/olink/client/registry.py new file mode 100644 index 0000000..210e522 --- /dev/null +++ b/src/olink/client/registry.py @@ -0,0 +1,80 @@ +from .types import IObjectSink +from olink.core import Name, Base, LogLevel +from typing import Optional + +class SinkToClientEntry: + # entry in the client registry + sink: IObjectSink = None + node: "ClientNode" = None + + def __init__(self, sink=None): + self.sink = sink + self.node = None + + +class ClientRegistry(Base): + # client side registry to link sinks to nodes + entries: dict[str, SinkToClientEntry] = {} + + def remove_node(self, node: "ClientNode"): + # remove node from all sinks + for entry in self.entries.values(): + if entry.node is node: + entry.node = None + + def add_node_to_sink(self, name: str, node: "ClientNode"): + # add not to named sink + self._entry(name).node = node + + def remove_node_from_sink(self, name: str, node: "ClientNode"): + # remove node from named sink + resource = Name.resource_from_name(name) + if resource in self.entries: + if self.entries[resource].node is node: + self.entries[resource].node = None + else: + self.emit_log( + LogLevel.DEBUG, f"unlink node failed, not the same node: {resource}") + + def register_sink(self, sink: IObjectSink) -> "ClientNode": + # register sink using object name + name = sink.olink_object_name() + entry = self._entry(name) + entry.sink = sink + return entry.node + + def unregister_sink(self, sink: IObjectSink): + # unregister sink using object name + name = sink.olink_object_name() + self._remove_entry(name) + + def get_sink(self, name: str) -> Optional[IObjectSink]: + # get sink using name + return self._entry(name).sink + + def get_node(self, name: str) -> Optional["ClientNode"]: + # get node using name + return self._entry(name).node + + def _entry(self, name: str) -> SinkToClientEntry: + # get an entry by name + resource = Name.resource_from_name(name) + if not resource in self.entries: + self.emit_log(LogLevel.DEBUG, f"add new resource: {resource}") + self.entries[resource] = SinkToClientEntry() + return self.entries[resource] + + def _remove_entry(self, name: str) -> None: + # remove an entry by name + resource = Name.resource_from_name(name) + del self.entries[resource] + + +# global client registry +_registry = ClientRegistry() + + +def get_client_registry() -> ClientRegistry: + # get global client registry + return _registry + diff --git a/src/olink/client/sink.py b/src/olink/client/sink.py new file mode 100644 index 0000000..bdbdd1e --- /dev/null +++ b/src/olink/client/sink.py @@ -0,0 +1,40 @@ +import asyncio +from typing import Any +from olink.core import Name +from .node import IObjectSink, ClientNode +from olink.core.hook import EventHook + + +class AbstractSink(IObjectSink): + on_property_changed = EventHook() + object_id: str = None + + client = None + + def __init__(self, object_id: str): + self.object_id = object_id + self.client = ClientNode.register_sink(self) + + async def _invoke(self, name, args): + future = asyncio.get_running_loop().create_future() + def func(args): + return future.set_result(args.value) + self.client.invoke_remote(f'{self.object_id}/{name}', args, func) + return await asyncio.wait_for(future, 500) + + def olink_object_name(self): + return self.object_id + + def olink_on_init(self, name: str, props: object, node: ClientNode): + for k in props: + setattr(self, k, props[k]) + + def olink_on_property_changed(self, name: str, value: Any) -> None: + path = Name.path_from_name(name) + setattr(self, name, value) + self.on_property_changed.fire(path, value) + + def olink_on_signal(self, name: str, args: list[Any]): + path = Name.path_from_name(name) + hook = getattr(self, f'on_{path}') + hook.fire(*args) diff --git a/src/olink/client/types.py b/src/olink/client/types.py new file mode 100644 index 0000000..fb151c6 --- /dev/null +++ b/src/olink/client/types.py @@ -0,0 +1,23 @@ +from typing import Any, Protocol as ProtocolType + +class IObjectSink(ProtocolType): + # interface for object sinks + def olink_object_name() -> str: + # return object name + raise NotImplementedError() + + def olink_on_signal(self, name: str, args: list[Any]) -> None: + # called on signal message + raise NotImplementedError() + + def olink_on_property_changed(self, name: str, value: Any) -> None: + # called on property changed message + raise NotImplementedError() + + def olink_on_init(self, name: str, props: object, node: "ClientNode"): + # called on init message + raise NotImplementedError() + + def olink_on_release(self) -> None: + # called when sink is released + raise NotImplementedError() \ No newline at end of file diff --git a/src/olink/core/__init__.py b/src/olink/core/__init__.py index e69de29..325d9f4 100644 --- a/src/olink/core/__init__.py +++ b/src/olink/core/__init__.py @@ -0,0 +1,4 @@ +from .hook import EventHook +from .types import Name, Base, LogLevel, MsgType +from .protocol import IProtocolListener, Protocol +from .node import BaseNode diff --git a/src/olink/core/hook.py b/src/olink/core/hook.py new file mode 100644 index 0000000..becad12 --- /dev/null +++ b/src/olink/core/hook.py @@ -0,0 +1,15 @@ +class EventHook(object): + def __init__(self): + self.__handlers = [] + + def __iadd__(self, handler): + self.__handlers.append(handler) + return self + + def __isub__(self, handler): + self.__handlers.remove(handler) + return self + + def fire(self, *args, **kwargs): + for handler in self.__handlers: + handler(*args, **kwargs) \ No newline at end of file diff --git a/src/olink/core/protocol.py b/src/olink/core/protocol.py index 38a5da6..781a1e3 100644 --- a/src/olink/core/protocol.py +++ b/src/olink/core/protocol.py @@ -1,9 +1,9 @@ from typing import Any -from typing import Protocol as ProptocolType +from typing import Protocol as ProtocolType from .types import Base, LogLevel, MsgType -class IProtocolListener(ProptocolType): +class IProtocolListener(ProtocolType): # interface for protocol listeners def handle_link(self, name: str) -> None: # called when a link is created diff --git a/src/olink/core/types.py b/src/olink/core/types.py index 1b11b48..62edb44 100644 --- a/src/olink/core/types.py +++ b/src/olink/core/types.py @@ -1,6 +1,6 @@ from enum import IntEnum from typing import Any, Callable -from typing import Protocol as ProptocolType +from typing import Protocol as ProtocolType import json @@ -74,7 +74,7 @@ class LogLevel: WriteLogFunc = Callable[[LogLevel, str], None] -class ILogger(ProptocolType): +class ILogger(ProtocolType): def log(level: LogLevel, msg: str) -> None: raise NotImplementedError() diff --git a/src/olink/mocks/mocksink.py b/src/olink/mocks/mocksink.py index cc11322..87affd9 100644 --- a/src/olink/mocks/mocksink.py +++ b/src/olink/mocks/mocksink.py @@ -1,6 +1,6 @@ from olink.core.types import Name from typing import Any, Optional -from olink.clientnode import ClientNode, IObjectSink, InvokeReplyArg +from olink.client import ClientNode, IObjectSink, InvokeReplyArg class MockSink(IObjectSink): name: str diff --git a/src/olink/mocks/mocksource.py b/src/olink/mocks/mocksource.py index 4d65094..4699b6f 100644 --- a/src/olink/mocks/mocksource.py +++ b/src/olink/mocks/mocksource.py @@ -1,6 +1,6 @@ from olink.core.types import Name from typing import Any -from olink.remotenode import IObjectSource, RemoteNode +from olink.remote import IObjectSource, RemoteNode class MockSource(IObjectSource): name: str @@ -21,7 +21,7 @@ def olink_object_name(self) -> str: return self.name def olink_invoke(self, name: str, args: list[Any]): - self.events.append({ 'type': 'invole', 'name': name, 'args': args }) + self.events.append({ 'type': 'invoke', 'name': name, 'args': args }) return name def olink_set_property(self, name: str, value: Any): diff --git a/src/olink/remote/__init__.py b/src/olink/remote/__init__.py new file mode 100644 index 0000000..5bbb01b --- /dev/null +++ b/src/olink/remote/__init__.py @@ -0,0 +1,6 @@ +from .registry import RemoteRegistry, get_remote_registry +from .node import RemoteNode +from .types import IObjectSource +from .adapter import SourceAdapter + + diff --git a/src/olink/remote/adapter.py b/src/olink/remote/adapter.py new file mode 100644 index 0000000..d542dd7 --- /dev/null +++ b/src/olink/remote/adapter.py @@ -0,0 +1,38 @@ +from typing import Any +from olink.core.types import Name +from .node import RemoteNode +from .types import IObjectSource + +class SourceAdapter(IObjectSource): + node: RemoteNode = None + object_id: str = None + def __init__(self, objectId: str, impl) -> None: + self.object_id = objectId + self.impl = impl + RemoteNode.register_source(self) + + def olink_object_name(self) -> str: + return self.object_id + + def olink_invoke(self, name: str, args: list[Any]) -> Any: + path = Name.path_from_name(name) + func = getattr(self.impl, path) + try: + result = func(**args) + except Exception as e: + print('error: %s' % e) + result = None + return result + + def olink_set_property(self, name: str, value: Any): + # set property value on implementation + path = Name.path_from_name(name) + setattr(self, self.impl, value) + + def olink_linked(self, name: str, node: "RemoteNode"): + # called when the source is linked to a client node + self.node = node + + def olink_collect_properties(self) -> object: + # collect properties from implementation to send back to client node initially + return {k: getattr(self.impl, k) for k in ['count']} diff --git a/src/olink/remote/node.py b/src/olink/remote/node.py new file mode 100644 index 0000000..fda9374 --- /dev/null +++ b/src/olink/remote/node.py @@ -0,0 +1,78 @@ +from olink.core.protocol import Protocol +from olink.core.node import BaseNode +from typing import Any +from .registry import RemoteRegistry, get_remote_registry +from .types import IObjectSource + +class RemoteNode(BaseNode): + # a remote node is a node that is linked to a remote source + def __init__(self): + # initialise node and attaches this node to registry + super().__init__() + + def detach(self): + # detach this node from registry + self.registry().remove_node(self) + + def handle_link(self, name: str) -> None: + # handle link message from client node + # sends init message to client node + source = RemoteNode.get_source(name) + if source: + self.registry().add_node_to_source(name, self) + source.olink_linked(name, self) + props = source.olink_collect_properties() + self.emit_write(Protocol.init_message(name, props)) + + def handle_unlink(self, name: str): + # unlinks names source from registry + source = self.get_source(name) + if source: + self.registry().remove_node_from_source(name, self) + + def handle_set_property(self, name: str, value: Any): + # handle set property message from client node + # calls set property on source + source = self.get_source(name) + if source: + source.olink_set_property(name, value) + + def handle_invoke(self, id: int, name: str, args: list[Any]) -> None: + # handle invoke message from client node + # calls invoke on source + # returns invoke reply message to client node + source = self.get_source(name) + if source: + value = source.olink_invoke(name, args) + self.emit_write(Protocol.invoke_reply_message(id, name, value)) + + def registry(self) -> RemoteRegistry: + # returns global registry + return get_remote_registry() + + @staticmethod + def get_source(name) -> IObjectSource: + # get object source from registry + return get_remote_registry().get_source(name) + + @staticmethod + def register_source(source: IObjectSource): + # add object source to registry + return get_remote_registry().add_source(source) + + @staticmethod + def unregister_source(source: IObjectSource): + # remove object source from registry + return get_remote_registry().remove_source(source) + + @staticmethod + def notify_property_change(name: str, value: Any) -> None: + # notify property change to all named client nodes + for node in get_remote_registry().get_nodes(name): + node.emit_write(Protocol.property_change_message(name, value)) + + @staticmethod + def notify_signal(name: str, args: list[Any]): + # notify signal to all named client nodes + for node in get_remote_registry().get_nodes(name): + node.emit_write(Protocol.signal_message(name, args)) diff --git a/src/olink/remote/registry.py b/src/olink/remote/registry.py new file mode 100644 index 0000000..0313b62 --- /dev/null +++ b/src/olink/remote/registry.py @@ -0,0 +1,91 @@ +from olink.core.types import Name, Base, LogLevel + +from .types import IObjectSource + +class SourceToNodeEntry: + # entry in the remote registry + source: IObjectSource = None + nodes = set() # type: set["RemoteNode"] + + def __init__(self, source=None): + self.source = source + +class RemoteRegistry(Base): + # registry of remote sources + # links sources to nodes + entries: dict[str, SourceToNodeEntry] = {} + + def add_source(self, source: IObjectSource): + # add a source to registry by object name + name = source.olink_object_name() + self.emit_log(LogLevel.DEBUG, + f"RemoteRegistry.add_object_source: {name}") + self._entry(name).source = source + + def remove_source(self, source: IObjectSource): + # remove the given source from the registry + name = source.olink_object_name() + self._remove_entry(name) + + def get_source(self, name: str): + # return the source for the given name + return self._entry(name).source + + def get_nodes(self, name: str): + # return nodes attached to the named source + return self._entry(name).nodes + + def remove_node(self, node: "RemoteNode"): + # remove the given node from the registry + self.emit_log(LogLevel.DEBUG, "RemoteRegistry.detach_remote_node") + for entry in self.entries.values(): + if node in entry.nodes: + entry.nodes.remove(node) + + def add_node_to_source(self, name: str, node: "RemoteNode"): + # add a node to the named source + self._entry(name).nodes.add(node) + + def remove_node_from_source(self, name: str, node: "RemoteNode"): + # remove the given node from the named source + self._entry(name).nodes.remove(node) + + def _entry(self, name: str) -> SourceToNodeEntry: + # returns the entry for the given resource part of the name + resource = Name.resource_from_name(name) + if not resource in self.entries: + self.emit_log(LogLevel.DEBUG, f"add new resource: {resource}") + self.entries[resource] = SourceToNodeEntry() + return self.entries[resource] + + def _remove_entry(self, name: str) -> None: + # remove entry from registry + resource = Name.resource_from_name(name) + if resource in self.entries: + del self.entries[resource] + else: + self.emit_log( + LogLevel.DEBUG, f'remove resource failed, resource not exists: {resource}') + + def _has_entry(self, name: str) -> SourceToNodeEntry: + # checks if the registry has an entry for the given name + resource = Name.resource_from_name(name) + return resource in self.entries + + def init_entry(self, name: str): + # init a new entry for the given name + resource = Name.resource_from_name(name) + if resource in self.entries: + self.entries[resource] = SourceToNodeEntry() + + def clear(self): + self.entries = {} + + +_registry = RemoteRegistry() + + +def get_remote_registry() -> RemoteRegistry: + # returns the remote registry + return _registry + diff --git a/src/olink/remote/types.py b/src/olink/remote/types.py new file mode 100644 index 0000000..8d1ba75 --- /dev/null +++ b/src/olink/remote/types.py @@ -0,0 +1,24 @@ +from typing import Any, Protocol as ProtocolType + +class IObjectSource(ProtocolType): + # interface for object sources + def olink_object_name() -> str: + # returns the object name + raise NotImplementedError() + + def olink_invoke(self, name: str, args: list[Any]): + # called on incoming invoke message + # returns resulting value + raise NotImplementedError() + + def olink_set_property(self, name: str, value: Any): + # called on incoming set property message + raise NotImplementedError() + + def olink_linked(self, name: str, node: "RemoteNode"): + # called when a remote node is linked to this node + raise NotImplementedError() + + def olink_collect_properties(self) -> object: + # returns a dictionary of all properties + raise NotImplementedError() \ No newline at end of file diff --git a/src/olink/remotenode.py b/src/olink/remotenode.py deleted file mode 100644 index b475d45..0000000 --- a/src/olink/remotenode.py +++ /dev/null @@ -1,193 +0,0 @@ -from olink.core.protocol import Protocol -from olink.core.types import Base, LogLevel, Name -from olink.core.node import BaseNode - -from typing import Any, Protocol as ProtocolType - - -class IObjectSource(ProtocolType): - # interface for object sources - def olink_object_name() -> str: - # returns the object name - raise NotImplementedError() - - def olink_invoke(self, name: str, args: list[Any]): - # called on incoming invoke message - # returns resulting value - raise NotImplementedError() - - def olink_set_property(self, name: str, value: Any): - # called on incoming set property message - raise NotImplementedError() - - def olink_linked(self, name: str, node: "RemoteNode"): - # called when a remote node is linked to this node - raise NotImplementedError() - - def olink_collect_properties(self) -> object: - # returns a dictionary of all properties - raise NotImplementedError() - - -class SourceToNodeEntry: - # entry in the remote registry - source: IObjectSource = None - nodes: set["RemoteNode"] = set() - - def __init__(self, source=None): - self.source = source - self.nodes = set() - - -class RemoteRegistry(Base): - # registry of remote sources - # links sources to nodes - entries: dict[str, SourceToNodeEntry] = {} - - def add_source(self, source: IObjectSource): - # add a source to registry by object name - name = source.olink_object_name() - self.emit_log(LogLevel.DEBUG, - f"RemoteRegistry.add_object_source: {name}") - self._entry(name).source = source - - def remove_source(self, source: IObjectSource): - # remove the given source from the registry - name = source.olink_object_name() - self._remove_entry(name) - - def get_source(self, name: str): - # return the source for the given name - return self._entry(name).source - - def get_nodes(self, name: str): - # return nodes attached to the named source - return self._entry(name).nodes - - def remove_node(self, node: "RemoteNode"): - # remove the given node from the registry - self.emit_log(LogLevel.DEBUG, "RemoteRegistry.detach_remote_node") - for entry in self.entries.values(): - if node in entry.nodes: - entry.nodes.remove(node) - - def add_node_to_source(self, name: str, node: "RemoteNode"): - # add a node to the named source - self._entry(name).nodes.add(node) - - def remove_node_from_source(self, name: str, node: "RemoteNode"): - # remove the given node from the named source - self._entry(name).nodes.remove(node) - - def _entry(self, name: str) -> SourceToNodeEntry: - # returns the entry for the given resource part of the name - resource = Name.resource_from_name(name) - if not resource in self.entries: - self.emit_log(LogLevel.DEBUG, f"add new resource: {resource}") - self.entries[resource] = SourceToNodeEntry() - return self.entries[resource] - - def _remove_entry(self, name: str) -> None: - # remove entry from registry - resource = Name.resource_from_name(name) - if resource in self.entries: - del self.entries[resource] - else: - self.emit_log( - LogLevel.DEBUG, f'remove resource failed, resource not exists: {resource}') - - def _has_entry(self, name: str) -> SourceToNodeEntry: - # checks if the registry has an entry for the given name - resource = Name.resource_from_name(name) - return resource in self.entries - - def init_entry(self, name: str): - # init a new entry for the given name - resource = Name.resource_from_name(name) - if resource in self.entries: - self.entries[resource] = SourceToNodeEntry() - - def clear(self): - self.entries = {} - - -_registry = RemoteRegistry() - - -def get_remote_registry() -> RemoteRegistry: - # returns the remote registry - return _registry - - -class RemoteNode(BaseNode): - # a remote node is a node that is linked to a remote source - def __init__(self): - # initialise node and attaches this node to registry - super().__init__() - - def detach(self): - # detach this node from registry - self.registry().remove_node(self) - - def handle_link(self, name: str) -> None: - # handle link message from client node - # sends init message to client node - source = RemoteNode.get_source(name) - if source: - self.registry().add_node_to_source(name, self) - source.olink_linked(name, self) - props = source.olink_collect_properties() - self.emit_write(Protocol.init_message(name, props)) - - def handle_unlink(self, name: str): - # unlinks names source from registry - source = self.get_source(name) - if source: - self.registry().remove_node_from_source(name, self) - - def handle_set_property(self, name: str, value: Any): - # handle set property message from client node - # calls set property on source - source = self.get_source(name) - if source: - source.olink_set_property(name, value) - - def handle_invoke(self, id: int, name: str, args: list[Any]) -> None: - # handle invoke message from client node - # calls invoke on source - # returns invoke reply message to client node - source = self.get_source(name) - if source: - value = source.olink_invoke(name, args) - self.emit_write(Protocol.invoke_reply_message(id, name, value)) - - def registry(self) -> RemoteRegistry: - # returns global registry - return get_remote_registry() - - @staticmethod - def get_source(name) -> IObjectSource: - # get object source from registry - return get_remote_registry().get_source(name) - - @staticmethod - def register_source(source: IObjectSource): - # add object source to registry - return get_remote_registry().add_source(source) - - @staticmethod - def unregister_source(source: IObjectSource): - # remove object source from registry - return get_remote_registry().remove_source(source) - - @staticmethod - def notify_property_change(name: str, value: Any) -> None: - # notify property change to all named client nodes - for node in get_remote_registry().get_nodes(name): - node.emit_write(Protocol.property_change_message(name, value)) - - @staticmethod - def notify_signal(name: str, args: list[Any]): - # notify signal to all named client nodes - for node in get_remote_registry().get_nodes(name): - node.emit_write(Protocol.signal_message(name, args)) diff --git a/src/olink/ws/__init__.py b/src/olink/ws/__init__.py new file mode 100644 index 0000000..fcafd5d --- /dev/null +++ b/src/olink/ws/__init__.py @@ -0,0 +1,2 @@ +from .client import Connection +from .server import Server, run_server \ No newline at end of file diff --git a/src/olink/ws/client.py b/src/olink/ws/client.py index fe872f1..7772c3c 100644 --- a/src/olink/ws/client.py +++ b/src/olink/ws/client.py @@ -1,13 +1,12 @@ import asyncio import websockets as ws -from .emitter import Emitter +from olink.client import ClientNode - -class Client: +class Connection: send_queue = asyncio.Queue() recv_queue = asyncio.Queue() node = None - def __init__(self, node): + def __init__(self, node=ClientNode()): self.node = node def send(self, msg): diff --git a/src/olink/ws/emitter.py b/src/olink/ws/emitter.py deleted file mode 100644 index 6265ed5..0000000 --- a/src/olink/ws/emitter.py +++ /dev/null @@ -1,14 +0,0 @@ - -class Emitter: - def __init__(self): - self._callbacks = {} - - def on(self, event, callback): - self._callbacks[event] = callback - - def emit(self, event, *args): - self._callbacks[event](*args) - - def off(self, event): - self._callbacks.pop(event) - diff --git a/src/olink/ws/server.py b/src/olink/ws/server.py index 608a613..e81b453 100644 --- a/src/olink/ws/server.py +++ b/src/olink/ws/server.py @@ -1,43 +1,8 @@ import websockets as ws -from .emitter import Emitter from typing import Any import asyncio -from .client import Client -from ..remotenode import IObjectSource, RemoteNode +from olink.remote import RemoteNode -class SourceAdapter(IObjectSource): - node: RemoteNode = None - object_id: str = None - def __init__(self, objectId: str, impl) -> None: - self.object_id = objectId - self.impl = impl - RemoteNode.register_source(self) - - def olink_object_name() -> str: - return self.objectId - - def olink_invoke(self, name: str, args: list[Any]) -> Any: - path = Name.path_from_name(name) - func = getattr(self.impl, path) - try: - result = func(**args) - except Exception as e: - print('error: %s' % e) - result = None - return result - - def olink_set_property(self, name: str, value: Any): - # set property value on implementation - path = Name.path_from_name(name) - setattr(self, self.impl, value) - - def olink_linked(self, name: str, node: "RemoteNode"): - # called when the source is linked to a client node - self.node = node - - def olink_collect_properties(self) -> object: - # collect properties from implementation to send back to client node initially - return {k: getattr(self.impl, k) for k in ['count']} class RemotePipe: From 7c8f7bdffa819ce2bbf86626555e38034137c2da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Ryannel?= Date: Wed, 12 Apr 2023 17:32:00 +0200 Subject: [PATCH 3/5] tests are working again many changes to to wrong usage of class vs instance vars of python objects. --- .python-version | 1 + .vscode/launch.json | 16 +++++ demo_server.py | 33 ++++----- examples/server.py | 23 +++--- pytest.ini | 6 ++ src/olink/client/__init__.py | 12 ++-- src/olink/client/node.py | 82 +++++++++------------ src/olink/client/registry.py | 39 ++++------ src/olink/client/sink.py | 26 ++++--- src/olink/core/__init__.py | 8 +-- src/olink/core/hook.py | 10 +-- src/olink/core/node.py | 32 +++++---- src/olink/core/protocol.py | 29 ++++---- src/olink/core/types.py | 52 +++++++------- src/olink/mocks/mocksink.py | 28 ++++---- src/olink/mocks/mocksource.py | 42 ++++++----- src/olink/remote/__init__.py | 10 ++- src/olink/remote/adapter.py | 36 ++++++---- src/olink/remote/node.py | 47 ++++--------- src/olink/remote/registry.py | 76 +++++++++++--------- src/olink/ws/client.py | 41 ++++++----- src/olink/ws/server.py | 55 +++++++++------ src/olink/ws/test_client.py | 25 +++++++ tests/test_clientnode.py | 80 ++++++++++++++------- tests/test_comms.py | 129 ++++++++++++++++++++++++---------- tests/test_protocol.py | 6 +- tests/test_remotenode.py | 54 +++++++------- 27 files changed, 569 insertions(+), 429 deletions(-) create mode 100644 .python-version create mode 100644 .vscode/launch.json create mode 100644 pytest.ini create mode 100644 src/olink/ws/test_client.py diff --git a/.python-version b/.python-version new file mode 100644 index 0000000..afad818 --- /dev/null +++ b/.python-version @@ -0,0 +1 @@ +3.11.0 diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..306f58e --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": true + } + ] +} \ No newline at end of file diff --git a/demo_server.py b/demo_server.py index 76f9d03..4aef3bf 100644 --- a/demo_server.py +++ b/demo_server.py @@ -11,26 +11,28 @@ class Counter: - count = 0 + count: int _node: RemoteNode + def __init__(self): + self.count = 0 + def increment(self): self.count += 1 # notify all registered clients - RemoteNode.notify_property_change('demo.Counter/count', self.count) + if self._node: + self._node.notify_property_changed("demo.Counter/count", self.count) class CounterAdapter(IObjectSource): - node: RemoteNode = None + _node: RemoteNode def __init__(self, impl): self.impl = impl - # need to register this source with the registry - RemoteNode.register_source(self) def olink_object_name(self): # name this source is registered under - return 'demo.Counter' + return "demo.Counter" def olink_invoke(self, name: str, args: list[Any]) -> Any: # called on incoming invoke message @@ -48,7 +50,7 @@ def olink_linked(self, name: str, node: "RemoteNode"): self.impl._node = node def olink_collect_properties(self) -> object: - return {k: getattr(self.impl, k) for k in ['count']} + return {k: getattr(self.impl, k) for k in ["count"]} counter = Counter() @@ -61,26 +63,27 @@ class RemoteEndpoint(WebSocketEndpoint): queue = Queue() async def sender(self, ws): - print('start sender') + print("start sender") while True: - print('001') + print("001") msg = await self.queue.get() - print('send', msg) + print("send", msg) await ws.send_text(msg) self.queue.task_done() async def on_connect(self, ws: WebSocket): - print('on_connect') + print("on_connect") asyncio.create_task(self.sender(ws)) def writer(msg: str): - print('writer', msg) + print("writer", msg) self.queue.put_nowait(msg) + self.node.on_write(writer) await super().on_connect(ws) async def on_receive(self, ws: WebSocket, data: Any) -> None: - print('on_receive', data) + print("on_receive", data) self.node.handle_message(data) async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: @@ -89,8 +92,6 @@ async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: await self.queue.join() -routes = [ - WebSocketRoute("/ws", RemoteEndpoint) -] +routes = [WebSocketRoute("/ws", RemoteEndpoint)] app = Starlette(routes=routes) diff --git a/examples/server.py b/examples/server.py index 19a082c..842d9ec 100644 --- a/examples/server.py +++ b/examples/server.py @@ -16,7 +16,7 @@ class CounterService: def increment(self): self.count += 1 - self._node.notify_property_change('demo.Counter/count', self.count) + self._node.notify_property_changed("demo.Counter/count", self.count) class CounterWebsocketAdapter(IObjectSource): @@ -30,7 +30,7 @@ def __init__(self, impl): def olink_object_name(self): # return service name - return 'demo.Counter' + return "demo.Counter" def olink_invoke(self, name: str, args: list[Any]) -> Any: # handle the remote call from client node @@ -42,7 +42,7 @@ def olink_invoke(self, name: str, args: list[Any]) -> Any: result = func(**args) except Exception as e: # need to have proper exception handling here - print('error: %s' % e) + print("error: %s" % e) result = None # results will be send back to calling client node return result @@ -58,7 +58,7 @@ def olink_linked(self, name: str, node: "RemoteNode"): def olink_collect_properties(self) -> object: # collect properties from implementation to send back to client node initially - return {k: getattr(self.impl, k) for k in ['count']} + return {k: getattr(self.impl, k) for k in ["count"]} # create the service implementation @@ -77,23 +77,24 @@ class RemoteEndpoint(WebSocketEndpoint): async def sender(self, ws): # sender coroutine, messages from queue are send to client - print('start sender') + print("start sender") while True: msg = await self.queue.get() - print('send', msg) + print("send", msg) await ws.send_text(msg) self.queue.task_done() async def on_connect(self, ws: WebSocket): # handle a socket connection - print('on_connect') + print("on_connect") # register a sender to the connection asyncio.create_task(self.sender(ws)) # a writer function to queue messages def writer(msg: str): - print('write to queue:', msg) + print("write to queue:", msg) self.queue.put_nowait(msg) + # register the writer function to the node self.node.on_write(writer) # call the super connection handler @@ -101,7 +102,7 @@ def writer(msg: str): async def on_receive(self, ws: WebSocket, data: Any) -> None: # handle a message from a client socket - print('on_receive', data) + print("on_receive", data) self.node.handle_message(data) async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: @@ -114,9 +115,7 @@ async def on_disconnect(self, websocket: WebSocket, close_code: int) -> None: # see https://www.starlette.io/routing/ -routes = [ - WebSocketRoute("/ws", RemoteEndpoint) -] +routes = [WebSocketRoute("/ws", RemoteEndpoint)] # call with `uvicorn server:app --port 8282` diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..7a447b5 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,6 @@ +[pytest] +minversion = 6.0 +addopts = -ra -q +testpaths = + tests +pythonpath = src diff --git a/src/olink/client/__init__.py b/src/olink/client/__init__.py index 58d15bf..24b25f3 100644 --- a/src/olink/client/__init__.py +++ b/src/olink/client/__init__.py @@ -1,4 +1,8 @@ -from .node import ClientNode, InvokeReplyArg, InvokeReplyFunc -from .registry import ClientRegistry, get_client_registry -from .types import IObjectSink -from .sink import AbstractSink \ No newline at end of file +from .node import ( + ClientNode as ClientNode, + InvokeReplyArg as InvokeReplyArg, + InvokeReplyFunc as InvokeReplyFunc, +) +from .registry import ClientRegistry as ClientRegistry +from .types import IObjectSink as IObjectSink +from .sink import AbstractSink as AbstractSink diff --git a/src/olink/client/node.py b/src/olink/client/node.py index 974f69d..6e36cd6 100644 --- a/src/olink/client/node.py +++ b/src/olink/client/node.py @@ -2,79 +2,65 @@ from typing import Any, Callable, Optional from olink.core import LogLevel, MsgType, BaseNode, Protocol from .types import IObjectSink -from .registry import ClientRegistry, get_client_registry +from .registry import ClientRegistry + class InvokeReplyArg: def __init__(self, name: str, value: Any): self.name = name self.value = value - name: str - value: Any InvokeReplyFunc = Callable[[InvokeReplyArg], None] class ClientNode(BaseNode): - # client side node - invokes_pending: dict[int, InvokeReplyFunc] = {} - requestId = 0 + def __init__(self, registry: ClientRegistry): + super().__init__() + self._invokes_pending: dict[int, InvokeReplyFunc] = {} + self._requestId: int = 0 + self._registry = registry def registry(self) -> ClientRegistry: - return get_client_registry() + return self._registry def detach(self) -> None: self.registry().remove_node(self) def next_request_id(self) -> int: - self.requestId += 1 - return self.requestId + self._requestId += 1 + return self._requestId - def invoke_remote(self, name: str, args: list[Any], func: Optional[InvokeReplyFunc]) -> None: - self.emit_log(LogLevel.DEBUG, - f"ClientNode.invoke_remote: {name} {args}") + def invoke_remote( + self, name: str, args: list[Any], func: Optional[InvokeReplyFunc] + ) -> None: + self.emit_log(LogLevel.DEBUG, f"ClientNode.invoke_remote: {name} {args}") request_id = self.next_request_id() if func: - self.invokes_pending[request_id] = func + self._invokes_pending[request_id] = func self.emit_write(Protocol.invoke_message(request_id, name, args)) def set_remote_property(self, name: str, value: Any) -> None: - # send remote propertymessage - self.emit_log(LogLevel.DEBUG, - f"ClientNode.set_remote_property: {name} {value}") + # send remote property message + self.emit_log(LogLevel.DEBUG, f"ClientNode.set_remote_property: {name} {value}") self.emit_write(Protocol.set_property_message(name, value)) def link_node(self, name: str): # register this node to sink - self.registry().add_node_to_sink(name, self) + self.registry().add_node(name, self) def unlink_node(self, name: str) -> None: # unregister this node from sink self.registry().remove_node_from_sink(name, self) - @staticmethod - def register_sink(sink: IObjectSink) -> Optional["ClientNode"]: - # register sink to registry - return get_client_registry().register_sink(sink) - - @staticmethod - def unregister_sink(sink: IObjectSink) -> None: - # unregister sink from registry - return get_client_registry().unregister_sink(sink) - - @staticmethod - def get_sink(name: str) -> Optional[IObjectSink]: - # get sink from registry - return get_client_registry().get_sink(name) - def link_remote(self, name: str): # register this node from sink and send a link message self.emit_log(LogLevel.DEBUG, f"ClientNode.linkRemote: {name}") - self.registry().add_node_to_sink(name, self) + self.registry().add_node(name, self) self.emit_write(Protocol.link_message(name)) def unlink_remote(self, name: str): - # unlink this node froom sink and send an unlink message + # unlink this node from sink and send an unlink message self.emit_log(LogLevel.DEBUG, f"ClientNode.unlink_remote: {name}") self.emit_write(Protocol.unlink_message(name)) self.registry().remove_node_from_sink(name, self) @@ -82,40 +68,40 @@ def unlink_remote(self, name: str): def handle_init(self, name: str, props: object): # handle init message from source self.emit_log(LogLevel.DEBUG, f"ClientNode.handle_init: {name}") - sink = self.get_sink(name) + sink = self.registry().get_sink(name) if sink: sink.olink_on_init(name, props, self) def handle_property_change(self, name: str, value: Any) -> None: # handle property change message from source - self.emit_log(LogLevel.DEBUG, - f"ClientNode.handle_property_change: {name}") - sink = self.get_sink(name) + self.emit_log(LogLevel.DEBUG, f"ClientNode.handle_property_change: {name}") + sink = self.registry().get_sink(name) if sink: sink.olink_on_property_changed(name, value) def handle_invoke_reply(self, id: int, name: str, value: Any) -> None: # handle invoke reply message from source - self.emit_log(LogLevel.DEBUG, - f"ClientNode.handle_invoke_reply: {id} {name} {value}") - if id in self.invokes_pending: - func = self.invokes_pending[id] + self.emit_log( + LogLevel.DEBUG, f"ClientNode.handle_invoke_reply: {id} {name} {value}" + ) + if id in self._invokes_pending: + func = self._invokes_pending[id] if func: arg = InvokeReplyArg(name, value) func(arg) - del self.invokes_pending[id] + del self._invokes_pending[id] else: self.emit_log(LogLevel.DEBUG, f"no pending invoke: {id} {name}") def handle_signal(self, name: str, args: list[Any]) -> None: # handle signal message from source - self.emit_log(LogLevel.DEBUG, - f"ClientNode.handle_signal: {name} {args}") - sink = self.get_sink(name) + self.emit_log(LogLevel.DEBUG, f"ClientNode.handle_signal: {name} {args}") + sink = self.registry().get_sink(name) if sink: sink.olink_on_signal(name, args) def handle_error(self, msgType: MsgType, id: int, error: str): # handle error message from source - self.emit_log(LogLevel.DEBUG, - f"ClientNode.handle_error: {msgType} {id} {error}") + self.emit_log( + LogLevel.DEBUG, f"ClientNode.handle_error: {msgType} {id} {error}" + ) diff --git a/src/olink/client/registry.py b/src/olink/client/registry.py index 210e522..ad34849 100644 --- a/src/olink/client/registry.py +++ b/src/olink/client/registry.py @@ -1,20 +1,21 @@ from .types import IObjectSink from olink.core import Name, Base, LogLevel -from typing import Optional +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from .node import ClientNode -class SinkToClientEntry: - # entry in the client registry - sink: IObjectSink = None - node: "ClientNode" = None +class SinkToClientEntry: def __init__(self, sink=None): - self.sink = sink - self.node = None + self.sink: IObjectSink = sink + self.node: "ClientNode" = None class ClientRegistry(Base): - # client side registry to link sinks to nodes - entries: dict[str, SinkToClientEntry] = {} + def __init__(self) -> None: + super().__init__() + self.entries: dict[str, SinkToClientEntry] = {} def remove_node(self, node: "ClientNode"): # remove node from all sinks @@ -22,7 +23,7 @@ def remove_node(self, node: "ClientNode"): if entry.node is node: entry.node = None - def add_node_to_sink(self, name: str, node: "ClientNode"): + def add_node(self, name: str, node: "ClientNode"): # add not to named sink self._entry(name).node = node @@ -34,16 +35,16 @@ def remove_node_from_sink(self, name: str, node: "ClientNode"): self.entries[resource].node = None else: self.emit_log( - LogLevel.DEBUG, f"unlink node failed, not the same node: {resource}") + LogLevel.DEBUG, f"unlink node failed, not the same node: {resource}" + ) - def register_sink(self, sink: IObjectSink) -> "ClientNode": + def add_sink(self, sink: IObjectSink) -> "ClientNode": # register sink using object name name = sink.olink_object_name() entry = self._entry(name) entry.sink = sink - return entry.node - def unregister_sink(self, sink: IObjectSink): + def remove_sink(self, sink: IObjectSink): # unregister sink using object name name = sink.olink_object_name() self._remove_entry(name) @@ -68,13 +69,3 @@ def _remove_entry(self, name: str) -> None: # remove an entry by name resource = Name.resource_from_name(name) del self.entries[resource] - - -# global client registry -_registry = ClientRegistry() - - -def get_client_registry() -> ClientRegistry: - # get global client registry - return _registry - diff --git a/src/olink/client/sink.py b/src/olink/client/sink.py index bdbdd1e..cdd3075 100644 --- a/src/olink/client/sink.py +++ b/src/olink/client/sink.py @@ -6,26 +6,25 @@ class AbstractSink(IObjectSink): - on_property_changed = EventHook() - object_id: str = None - - client = None - def __init__(self, object_id: str): - self.object_id = object_id - self.client = ClientNode.register_sink(self) + self._object_id = object_id + self.on_property_changed = EventHook() + self._node: ClientNode = None async def _invoke(self, name, args): future = asyncio.get_running_loop().create_future() + def func(args): return future.set_result(args.value) - self.client.invoke_remote(f'{self.object_id}/{name}', args, func) + + self.get_node().invoke_remote(f"{self._object_id}/{name}", args, func) return await asyncio.wait_for(future, 500) def olink_object_name(self): - return self.object_id + return self._object_id - def olink_on_init(self, name: str, props: object, node: ClientNode): + def olink_on_init(self, name: str, props: object, node: ClientNode): + self._node = node for k in props: setattr(self, k, props[k]) @@ -36,5 +35,10 @@ def olink_on_property_changed(self, name: str, value: Any) -> None: def olink_on_signal(self, name: str, args: list[Any]): path = Name.path_from_name(name) - hook = getattr(self, f'on_{path}') + hook = getattr(self, f"on_{path}") hook.fire(*args) + + def get_node(self) -> ClientNode: + if self._node is None: + raise Exception("Sink not linked to node") + return self._node diff --git a/src/olink/core/__init__.py b/src/olink/core/__init__.py index 325d9f4..c253fa5 100644 --- a/src/olink/core/__init__.py +++ b/src/olink/core/__init__.py @@ -1,4 +1,4 @@ -from .hook import EventHook -from .types import Name, Base, LogLevel, MsgType -from .protocol import IProtocolListener, Protocol -from .node import BaseNode +from .hook import EventHook as EventHook +from .types import Name as Name, Base as Base, LogLevel as LogLevel, MsgType as MsgType +from .protocol import IProtocolListener as IProtocolListener, Protocol as Protocol +from .node import BaseNode as BaseNode diff --git a/src/olink/core/hook.py b/src/olink/core/hook.py index becad12..a702554 100644 --- a/src/olink/core/hook.py +++ b/src/olink/core/hook.py @@ -1,15 +1,15 @@ class EventHook(object): def __init__(self): - self.__handlers = [] + self._handlers: list[callable] = [] def __iadd__(self, handler): - self.__handlers.append(handler) + self._handlers.append(handler) return self def __isub__(self, handler): - self.__handlers.remove(handler) + self._handlers.remove(handler) return self def fire(self, *args, **kwargs): - for handler in self.__handlers: - handler(*args, **kwargs) \ No newline at end of file + for handler in self._handlers: + handler(*args, **kwargs) diff --git a/src/olink/core/node.py b/src/olink/core/node.py index 63658a4..477b0a4 100644 --- a/src/olink/core/node.py +++ b/src/olink/core/node.py @@ -1,32 +1,34 @@ from typing import Any from olink.core.protocol import IProtocolListener, Protocol -from olink.core.types import Base, LogLevel, MessageConverter, MessageFormat, WriteMessageFunc +from olink.core.types import ( + Base, + LogLevel, + MessageConverter, + MessageFormat, + WriteMessageFunc, +) class BaseNode(Base, IProtocolListener): - # base node class - write_func: WriteMessageFunc = None - converter: MessageConverter = None - protocol: Protocol = None - def __init__(self): - super() - self.protocol = Protocol(self) - self.converter = MessageConverter(MessageFormat.JSON) + super().__init__() + self._protocol = Protocol(self) + self._converter = MessageConverter(MessageFormat.JSON) + self._write_func: WriteMessageFunc = None def on_write(self, func: WriteMessageFunc) -> None: # set the write function - self.write_func = func + self._write_func = func def emit_write(self, msg: list[Any]) -> None: # emit a message using the write function - if self.write_func: - data = self.converter.to_string(msg) - self.write_func(data) + if self._write_func: + data = self._converter.to_string(msg) + self._write_func(data) else: self.emit_log(LogLevel.DEBUG, f"write not set on protocol: {msg}") def handle_message(self, data: str) -> None: # handle a message and pass is on to the protocol - msg = self.converter.from_string(data) - self.protocol.handle_message(msg) + msg = self._converter.from_string(data) + self._protocol.handle_message(msg) diff --git a/src/olink/core/protocol.py b/src/olink/core/protocol.py index 781a1e3..f7c2a04 100644 --- a/src/olink/core/protocol.py +++ b/src/olink/core/protocol.py @@ -43,11 +43,9 @@ def handle_error(self, msgType: int, id: int, error: str) -> None: class Protocol(Base): - listener: IProtocolListener = None - def __init__(self, listener: IProtocolListener): super() - self.listener = listener + self._listener = listener @staticmethod def link_message(name: str) -> list[Any]: @@ -69,7 +67,7 @@ def set_property_message(name: str, value: Any) -> list[Any]: return [MsgType.SET_PROPERTY, name, value] @staticmethod - def property_change_message(name: str, value: Any) -> list[Any]: + def property_changed_message(name: str, value: Any) -> list[Any]: """signal property change to the client linked to the remote objects""" return [MsgType.PROPERTY_CHANGE, name, value] @@ -92,39 +90,38 @@ def error_message(msgType: MsgType, id: int, error: str) -> list[Any]: return [MsgType.ERROR, msgType, id, error] def handle_message(self, msg: list[Any]) -> bool: - if not self.listener: + if not self._listener: self.emit_log(LogLevel.DEBUG, "no listener installed") return False msgType = msg[0] if msgType == MsgType.LINK: _, name = msg - self.listener.handle_link(name) + self._listener.handle_link(name) elif msgType == MsgType.INIT: _, name, props = msg - self.listener.handle_init(name, props) + self._listener.handle_init(name, props) elif msgType == MsgType.UNLINK: _, name = msg - self.listener.handle_unlink(name) + self._listener.handle_unlink(name) elif msgType == MsgType.SET_PROPERTY: _, name, value = msg - self.listener.handle_set_property(name, value) + self._listener.handle_set_property(name, value) elif msgType == MsgType.PROPERTY_CHANGE: _, name, value = msg - self.listener.handle_property_change(name, value) + self._listener.handle_property_change(name, value) elif msgType == MsgType.INVOKE: _, id, name, args = msg - self.listener.handle_invoke(id, name, args) + self._listener.handle_invoke(id, name, args) elif msgType == MsgType.INVOKE_REPLY: _, id, name, value = msg - self.listener.handle_invoke_reply(id, name, value) + self._listener.handle_invoke_reply(id, name, value) elif msgType == MsgType.SIGNAL: _, name, args = msg - self.listener.handle_signal(name, args) + self._listener.handle_signal(name, args) elif msgType == MsgType.ERROR: _, msgType, id, error = msg - self.listener.handle_error(msgType, id, error) + self._listener.handle_error(msgType, id, error) else: - self.emit_log(LogLevel.DEBUG, - f"not supported message type: {msgType}") + self.emit_log(LogLevel.DEBUG, f"not supported message type: {msgType}") return False return True diff --git a/src/olink/core/types.py b/src/olink/core/types.py index 62edb44..f433c7f 100644 --- a/src/olink/core/types.py +++ b/src/olink/core/types.py @@ -5,22 +5,22 @@ class MsgType(IntEnum): - LINK = 10, - INIT = 11, - UNLINK = 12, - SET_PROPERTY = 20, - PROPERTY_CHANGE = 21, - INVOKE = 30, - INVOKE_REPLY = 31, - SIGNAL = 40, - ERROR = 90, + LINK = (10,) + INIT = (11,) + UNLINK = (12,) + SET_PROPERTY = (20,) + PROPERTY_CHANGE = (21,) + INVOKE = (30,) + INVOKE_REPLY = (31,) + SIGNAL = (40,) + ERROR = (90,) class MessageFormat(IntEnum): - JSON = 1, - BSON = 2, - MSGPACK = 3, - CBOR = 4, + JSON = (1,) + BSON = (2,) + MSGPACK = (3,) + CBOR = (4,) class Name: @@ -29,30 +29,27 @@ class Name: @staticmethod def resource_from_name(name: str) -> str: # return the resource name from a name - return name.split('/')[0] + return name.split("/")[0] @staticmethod def path_from_name(name: str) -> str: # return the path from a name - return name.split('/')[-1] + return name.split("/")[-1] @staticmethod def has_path(name: str) -> bool: # return true if name has a path - return '/' in name + return "/" in name @staticmethod def create_name(resource: str, path: str) -> str: # create a name from a resource and a path - return f'{resource}/{path}' + return f"{resource}/{path}" class MessageConverter: - # convert a message from/to a string - format: MessageFormat = MessageFormat.JSON - - def __init__(self, format: MessageFormat): - self.format = format + def __init__(self, format: MessageFormat = MessageFormat.JSON): + self._format = format def from_string(self, message: str) -> list[Any]: return json.loads(message) @@ -65,10 +62,10 @@ def to_string(self, data: list[Any]) -> str: class LogLevel: - DEBUG = 1, - INFO = 2, - WARNING = 3, - ERROR = 4, + DEBUG = (1,) + INFO = (2,) + WARNING = (3,) + ERROR = (4,) WriteLogFunc = Callable[[LogLevel, str], None] @@ -80,7 +77,8 @@ def log(level: LogLevel, msg: str) -> None: class Base: - log_func: WriteLogFunc = None + def __init__(self) -> None: + self.log_func: WriteLogFunc = None def on_log(self, func: WriteLogFunc): self.log_func = func diff --git a/src/olink/mocks/mocksink.py b/src/olink/mocks/mocksink.py index 87affd9..7e2e6f4 100644 --- a/src/olink/mocks/mocksink.py +++ b/src/olink/mocks/mocksink.py @@ -2,45 +2,45 @@ from typing import Any, Optional from olink.client import ClientNode, IObjectSink, InvokeReplyArg + class MockSink(IObjectSink): - name: str - events: list[Any] = [] - node: Optional[ClientNode] = None - properties: dict[str, Any] = {} def __init__(self, name: str): self.name = name - self.node = ClientNode.register_sink(self) + self.events: list[Any] = [] + self.node: Optional[ClientNode] = None + self.properties: dict[str, Any] = {} def invoke(self, name: str, args: list[Any]): if self.node: + def func(arg: InvokeReplyArg): - self.events.append({'type': 'invoke-reply', 'name': arg.name, 'value': arg.value}) + self.events.append( + {"type": "invoke-reply", "name": arg.name, "value": arg.value} + ) + self.node.invoke_remote(name, args, func) - + def olink_object_name(self) -> str: return self.name def olink_on_signal(self, name: str, args: list[Any]) -> None: - self.events.append({'type': 'signal', 'name': name, 'args': args}) + self.events.append({"type": "signal", "name": name, "args": args}) def olink_on_property_changed(self, name: str, value: Any) -> None: path = Name.path_from_name(name) - self.events.append({ 'type': 'property_change', 'name': name, 'value': value}) + self.events.append({"type": "property_change", "name": name, "value": value}) self.properties[path] = value def olink_on_init(self, name: str, props: object, node: "ClientNode"): - self.events.append({'type': 'init', 'name': name, 'props': props}) + self.events.append({"type": "init", "name": name, "props": props}) self.node = node self.properties = props def olink_on_release(self) -> None: - self.events.append({'type': 'release'}) + self.events.append({"type": "release"}) self.node = None def clear(self): self.events = [] self.properties = {} self.node = None - - - diff --git a/src/olink/mocks/mocksource.py b/src/olink/mocks/mocksource.py index 4699b6f..25c22db 100644 --- a/src/olink/mocks/mocksource.py +++ b/src/olink/mocks/mocksource.py @@ -2,49 +2,47 @@ from typing import Any from olink.remote import IObjectSource, RemoteNode -class MockSource(IObjectSource): - name: str - events: list[Any] = [] - properties: dict[str, Any] = {} - node: RemoteNode = None +class MockSource(IObjectSource): def __init__(self, name: str): self.name = name + self.events: list[Any] = [] + self.properties: dict[str, Any] = {} + self.node: RemoteNode = None def set_property(self, name: str, value: Any): - RemoteNode.notify_property_change(name, value) + path = Name.path_from_name(name) + if self.properties.get(path) != value: + self.properties[path] = value + self.get_node().notify_property_changed(name, value) def notify_signal(self, name: str, args: list[Any]): - RemoteNode.notify_signal(name, args) + self.get_node().notify_signal(name, args) def olink_object_name(self) -> str: return self.name def olink_invoke(self, name: str, args: list[Any]): - self.events.append({ 'type': 'invoke', 'name': name, 'args': args }) + self.events.append({"type": "invoke", "name": name, "args": args}) return name def olink_set_property(self, name: str, value: Any): - path = Name.path_from_name(name) - self.events.append({'type': 'set_property', 'name': name, 'value': value}) - if not path in self.properties: - # assign new value - self.properties[path] = value - RemoteNode.notify_property_change(name, value) - else: - # update existing value - if not self.properties[path] == value: - self.properties[path] = value - RemoteNode.notify_property_change(name, value) - + self.events.append({"type": "set_property", "name": name, "value": value}) + self.set_property(name, value) + def olink_linked(self, name: str, node: RemoteNode): - self.events.append({'type': 'linked', 'name': name}) self.node = node + self.events.append({"type": "linked", "name": name}) def olink_collect_properties(self) -> object: return self.properties + def get_node(self) -> RemoteNode: + if not self.node: + raise Exception("Node not set") + return self.node + def clear(self): self.events = [] self.properties = {} - self.node = None \ No newline at end of file + self.node = None diff --git a/src/olink/remote/__init__.py b/src/olink/remote/__init__.py index 5bbb01b..5ca01a0 100644 --- a/src/olink/remote/__init__.py +++ b/src/olink/remote/__init__.py @@ -1,6 +1,4 @@ -from .registry import RemoteRegistry, get_remote_registry -from .node import RemoteNode -from .types import IObjectSource -from .adapter import SourceAdapter - - +from .registry import RemoteRegistry as RemoteRegistry +from .node import RemoteNode as RemoteNode +from .types import IObjectSource as IObjectSource +from .adapter import SourceAdapter as SourceAdapter diff --git a/src/olink/remote/adapter.py b/src/olink/remote/adapter.py index d542dd7..82777e9 100644 --- a/src/olink/remote/adapter.py +++ b/src/olink/remote/adapter.py @@ -2,37 +2,49 @@ from olink.core.types import Name from .node import RemoteNode from .types import IObjectSource +import logging + class SourceAdapter(IObjectSource): - node: RemoteNode = None - object_id: str = None def __init__(self, objectId: str, impl) -> None: - self.object_id = objectId + self._object_id = objectId self.impl = impl - RemoteNode.register_source(self) + self.impl._change += self.on_change + self.impl._emit += self.on_emit + self._node: RemoteNode = None + + def on_emit(self, path: str, args: list[Any]): + name = Name.create_name(self._object_id, path) + RemoteNode.notify_signal(name, args) + + def on_change(self, name: str, value: Any): + name = Name.create_name(self._object_id, name) + RemoteNode.notify_property_changed(name, value) def olink_object_name(self) -> str: - return self.object_id - + return self._object_id + def olink_invoke(self, name: str, args: list[Any]) -> Any: path = Name.path_from_name(name) func = getattr(self.impl, path) try: - result = func(**args) + result = func(*args) except Exception as e: - print('error: %s' % e) + logging.exception(e) + print("error: %s" % e) result = None + raise e return result - + def olink_set_property(self, name: str, value: Any): # set property value on implementation path = Name.path_from_name(name) - setattr(self, self.impl, value) + setattr(self.impl, path, value) def olink_linked(self, name: str, node: "RemoteNode"): # called when the source is linked to a client node - self.node = node + self._node = node def olink_collect_properties(self) -> object: # collect properties from implementation to send back to client node initially - return {k: getattr(self.impl, k) for k in ['count']} + return {k: getattr(self.impl, k) for k in ["count"]} diff --git a/src/olink/remote/node.py b/src/olink/remote/node.py index fda9374..98da84f 100644 --- a/src/olink/remote/node.py +++ b/src/olink/remote/node.py @@ -1,14 +1,14 @@ -from olink.core.protocol import Protocol -from olink.core.node import BaseNode +from olink.core import BaseNode, Name, Protocol from typing import Any -from .registry import RemoteRegistry, get_remote_registry +from .registry import RemoteRegistry from .types import IObjectSource + class RemoteNode(BaseNode): - # a remote node is a node that is linked to a remote source - def __init__(self): + def __init__(self, registry: RemoteRegistry): # initialise node and attaches this node to registry super().__init__() + self._registry = registry def detach(self): # detach this node from registry @@ -17,9 +17,9 @@ def detach(self): def handle_link(self, name: str) -> None: # handle link message from client node # sends init message to client node - source = RemoteNode.get_source(name) + source = self.get_source(name) if source: - self.registry().add_node_to_source(name, self) + self.registry().add_node(name, self) source.olink_linked(name, self) props = source.olink_collect_properties() self.emit_write(Protocol.init_message(name, props)) @@ -47,32 +47,13 @@ def handle_invoke(self, id: int, name: str, args: list[Any]) -> None: self.emit_write(Protocol.invoke_reply_message(id, name, value)) def registry(self) -> RemoteRegistry: - # returns global registry - return get_remote_registry() - - @staticmethod - def get_source(name) -> IObjectSource: - # get object source from registry - return get_remote_registry().get_source(name) - - @staticmethod - def register_source(source: IObjectSource): - # add object source to registry - return get_remote_registry().add_source(source) + return self._registry - @staticmethod - def unregister_source(source: IObjectSource): - # remove object source from registry - return get_remote_registry().remove_source(source) + def get_source(self, name: str) -> IObjectSource: + return self.registry().get_source(name) - @staticmethod - def notify_property_change(name: str, value: Any) -> None: - # notify property change to all named client nodes - for node in get_remote_registry().get_nodes(name): - node.emit_write(Protocol.property_change_message(name, value)) + def notify_property_changed(self, name: str, value: Any) -> None: + self.registry().notify_property_changed(name, value) - @staticmethod - def notify_signal(name: str, args: list[Any]): - # notify signal to all named client nodes - for node in get_remote_registry().get_nodes(name): - node.emit_write(Protocol.signal_message(name, args)) + def notify_signal(self, name: str, args: list[Any]) -> None: + self.registry().notify_signal(name, args) diff --git a/src/olink/remote/registry.py b/src/olink/remote/registry.py index 0313b62..17b67f6 100644 --- a/src/olink/remote/registry.py +++ b/src/olink/remote/registry.py @@ -1,25 +1,26 @@ -from olink.core.types import Name, Base, LogLevel - +from olink.core import Name, Base, LogLevel, Protocol +from typing import Any, TYPE_CHECKING from .types import IObjectSource -class SourceToNodeEntry: - # entry in the remote registry - source: IObjectSource = None - nodes = set() # type: set["RemoteNode"] +if TYPE_CHECKING: + from .node import RemoteNode + +class SourceToNodeEntry: def __init__(self, source=None): - self.source = source + self.source: IObjectSource = source + self.nodes: set["RemoteNode"] = set() + class RemoteRegistry(Base): - # registry of remote sources - # links sources to nodes - entries: dict[str, SourceToNodeEntry] = {} + def __init__(self) -> None: + super().__init__() + self._entries: dict[str, SourceToNodeEntry] = {} def add_source(self, source: IObjectSource): - # add a source to registry by object name + # register a new source in the registry name = source.olink_object_name() - self.emit_log(LogLevel.DEBUG, - f"RemoteRegistry.add_object_source: {name}") + self.emit_log(LogLevel.DEBUG, f"RemoteRegistry.add_object_source: {name}") self._entry(name).source = source def remove_source(self, source: IObjectSource): @@ -38,13 +39,15 @@ def get_nodes(self, name: str): def remove_node(self, node: "RemoteNode"): # remove the given node from the registry self.emit_log(LogLevel.DEBUG, "RemoteRegistry.detach_remote_node") - for entry in self.entries.values(): + for entry in self._entries.values(): if node in entry.nodes: entry.nodes.remove(node) - def add_node_to_source(self, name: str, node: "RemoteNode"): + def add_node(self, name: str, node: "RemoteNode"): # add a node to the named source - self._entry(name).nodes.add(node) + entry = self._entry(name) + if not node in entry.nodes: + entry.nodes.add(node) def remove_node_from_source(self, name: str, node: "RemoteNode"): # remove the given node from the named source @@ -53,39 +56,46 @@ def remove_node_from_source(self, name: str, node: "RemoteNode"): def _entry(self, name: str) -> SourceToNodeEntry: # returns the entry for the given resource part of the name resource = Name.resource_from_name(name) - if not resource in self.entries: + if not resource in self._entries: self.emit_log(LogLevel.DEBUG, f"add new resource: {resource}") - self.entries[resource] = SourceToNodeEntry() - return self.entries[resource] + self._entries[resource] = SourceToNodeEntry() + return self._entries[resource] def _remove_entry(self, name: str) -> None: # remove entry from registry resource = Name.resource_from_name(name) - if resource in self.entries: - del self.entries[resource] + if resource in self._entries: + del self._entries[resource] else: self.emit_log( - LogLevel.DEBUG, f'remove resource failed, resource not exists: {resource}') + LogLevel.DEBUG, + f"remove resource failed, resource not exists: {resource}", + ) def _has_entry(self, name: str) -> SourceToNodeEntry: # checks if the registry has an entry for the given name resource = Name.resource_from_name(name) - return resource in self.entries + return resource in self._entries def init_entry(self, name: str): # init a new entry for the given name resource = Name.resource_from_name(name) - if resource in self.entries: - self.entries[resource] = SourceToNodeEntry() + if resource in self._entries: + self._entries[resource] = SourceToNodeEntry() def clear(self): - self.entries = {} + self._entries.clear() + def notify_property_changed(self, name: str, value: Any): + # notify property change to all named client nodes + resource = Name.resource_from_name(name) + for node in self.get_nodes(resource): + msg = Protocol.property_changed_message(name, value) + node.emit_write(msg) -_registry = RemoteRegistry() - - -def get_remote_registry() -> RemoteRegistry: - # returns the remote registry - return _registry - + def notify_signal(self, name: str, args: tuple): + # notify signal to all named client nodes + resource = Name.resource_from_name(name) + for node in self.get_nodes(resource): + msg = Protocol.signal_message(name, args) + node.emit_write(msg) diff --git a/src/olink/ws/client.py b/src/olink/ws/client.py index 7772c3c..f92be59 100644 --- a/src/olink/ws/client.py +++ b/src/olink/ws/client.py @@ -2,36 +2,45 @@ import websockets as ws from olink.client import ClientNode + class Connection: - send_queue = asyncio.Queue() - recv_queue = asyncio.Queue() - node = None def __init__(self, node=ClientNode()): - self.node = node + self._node = node + self._send_queue = asyncio.Queue() + self._recv_queue = asyncio.Queue() + self._conn: ws.WebSocketClientProtocol = None def send(self, msg): - self.send_queue.put_nowait(msg) + self._send_queue.put_nowait(msg) async def handle_send(self): - async for msg in self.send_queue: + while self._conn is not None: + msg = await self._send_queue.get() data = self.serializer.serialize(msg) - await self.conn.send(data) + await self._conn.send(data) async def handle_recv(self): - async for msg in self.recv_queue: + while self._conn is not None: + msg = await self._recv_queue.get() self.emitter.emit(msg.object, msg) async def recv(self): - async for data in self.conn: + async for data in self._conn: msg = self.serializer.deserialize(data) - self.recv_queue.put_nowait(msg) + await self._recv_queue.put(msg) async def connect(self, addr: str): # connect to server async for conn in ws.connect(addr): - self.conn = conn - # start send and recv tasks - await asyncio.gather(self.handle_send(), self.handle_recv(), self.recv()) - # wait for all queues to be empty - await self.send_queue.join() - await self.recv_queue.join() \ No newline at end of file + try: + self._conn = conn + # start send and recv tasks + await asyncio.gather( + self.handle_send(), self.handle_recv(), self.recv() + ) + except ws.ConnectionClosed: + continue + + def cancel(self): + self._conn.close() + self._conn = None diff --git a/src/olink/ws/server.py b/src/olink/ws/server.py index e81b453..290d3bd 100644 --- a/src/olink/ws/server.py +++ b/src/olink/ws/server.py @@ -2,49 +2,64 @@ from typing import Any import asyncio from olink.remote import RemoteNode - +import logging class RemotePipe: - send_queue = asyncio.Queue() - recv_queue = asyncio.Queue() - node = RemoteNode() - def __init__(self, conn: ws.ClientConnection): - self.conn = conn - self.node.on_write(self._send) + def __init__(self, conn: ws.WebSocketServerProtocol): + self._conn = conn + self._node = RemoteNode() + self._node.on_write(self._send) + self._send_queue = asyncio.Queue() + self._recv_queue = asyncio.Queue() def _send(self, data): - self.send_queue.put_nowait(data) + self._send_queue.put_nowait(data) async def handle_send(self): - async for data in self.send_queue: - await self.conn.send(data) + while True: + data = await self._send_queue.get() + await self._conn.send(data) async def handle_recv(self): - async for data in self.recv_queue: - self.node.handle_message(data) + while True: + data = await self._recv_queue.get() + self._node.handle_message(data) async def recv(self): - async for data in self.conn: - self.recv_queue.put_nowait(data) + while True: + data = await self._conn.recv() + self._recv_queue.put_nowait(data) + + async def run(self): + await asyncio.gather( + self.handle_send(), + self.handle_recv(), + self.recv(), + ) + class Server: - pipes = [] - def handle_connection(self, pipe: ws.WebSocketServerProtocol, path: str): - pipe = RemotePipe(pipe, self.serializer) + def __init__(self) -> None: + self.pipes: list[RemotePipe] = [] + + async def handle_connection(self, socket: ws.WebSocketServerProtocol): + logging.info("New connection %s", socket) + pipe = RemotePipe(socket) self.pipes.append(pipe) + await pipe.run() async def serve(self, host: str, port: int): async with ws.serve(self.handle_connection, host, port): - await asyncio.Future() - - + await asyncio.Future() # run forever def run_server(host: str, port: int): + logging.info("Starting server on %s:%d", host, port) server = Server() asyncio.run(server.serve(host, port)) if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) run_server("localhost", 8152) diff --git a/src/olink/ws/test_client.py b/src/olink/ws/test_client.py new file mode 100644 index 0000000..8f8e0ec --- /dev/null +++ b/src/olink/ws/test_client.py @@ -0,0 +1,25 @@ +import pytest +import websockets as ws +import asyncio + +test_host = "localhost" +test_port = 8152 + + +async def setup_server( + done: asyncio.Future, queue: asyncio.Queue, host: str, port: int +): + async def handler(socket: ws.WebSocketServerProtocol): + async for msg in socket: + await queue.put(msg) + done.set_result(True) + + await ws.serve(handler, host, port) + await done + await ws.close() + + +@pytest.mark.asyncio +async def test_client_link(): + queue = asyncio.Queue() + await setup_server(queue, test_host, test_port) diff --git a/tests/test_clientnode.py b/tests/test_clientnode.py index 87cae39..46feb57 100644 --- a/tests/test_clientnode.py +++ b/tests/test_clientnode.py @@ -1,37 +1,67 @@ -from olink.clientnode import ClientNode +from olink.client import ClientNode, ClientRegistry from olink.mocks.mocksink import MockSink +import pytest -name = 'demo.Counter' -sink = MockSink(name) -client = ClientNode() -r = client.registry() +name = "demo.Counter" -def test_add_sink(): - ClientNode.register_sink(sink) - assert(r.get_sink(name) == sink) - assert(r.get_node(name) == None) +@pytest.fixture +def node(): # type: () -> ClientNode + sink = MockSink(name) + registry = ClientRegistry() + registry.add_sink(sink) + node = ClientNode(registry) + registry.add_node(name, node) + assert registry.get_sink(name) == sink + assert registry.get_node(name) == node + return node -def test_remove_sink(): - ClientNode.unregister_sink(sink) - assert(r.get_sink(name) == None) +@pytest.fixture +def registry(): # type: () -> ClientRegistry + name = "demo.Counter" + sink = MockSink(name) + registry = ClientRegistry() + registry.add_sink(sink) + return registry -def test_link_node_to_sink(): - assert(r.get_node(name) == None) - client.link_remote(name) - assert(r.get_node(name) == client) +def test_add_sink(node: ClientNode): + name = "demo.Counter" + registry = node.registry() + sink = registry.get_sink(name) + assert registry.get_sink(name) == sink + assert registry.get_node(name) == node + registry.remove_sink(sink) + assert registry.get_sink(name) == None + assert registry.get_node(name) == None -def test_unlink_node_from_sink(): - assert(r.get_node(name) == client) - client.unlink_remote(name) - assert(r.get_node(name) == None) +def test_remove_sink(node: ClientNode): + registry = node.registry() + sink = registry.get_sink(name) + registry.remove_sink(sink) + assert registry.get_sink(name) == None + assert registry.get_node(name) == None -def test_detach_node_from_all_sinks(): - client.link_remote(name) - assert(r.get_node(name) == client) - client.detach() - assert(r.get_node(name) == None) +def test_link_node_to_sink(node: ClientNode): + registry = node.registry() + assert registry.get_node(name) == node + node.link_remote(name) + assert registry.get_node(name) == node + + +def test_unlink_node_from_sink(node: ClientNode): + registry = node.registry() + assert registry.get_node(name) == node + node.unlink_remote(name) + assert registry.get_node(name) == None + + +def test_detach_node_from_all_sinks(node: ClientNode): + registry = node.registry() + node.link_remote(name) + assert registry.get_node(name) == node + node.detach() + assert registry.get_node(name) == None diff --git a/tests/test_comms.py b/tests/test_comms.py index 21982f7..6ed0134 100644 --- a/tests/test_comms.py +++ b/tests/test_comms.py @@ -1,75 +1,132 @@ -from olink.clientnode import ClientNode -from olink.remotenode import RemoteNode +from olink.client import ClientNode, ClientRegistry +from olink.remote import RemoteNode, RemoteRegistry from olink.mocks.mocksink import MockSink from olink.mocks.mocksource import MockSource +from typing import Tuple +import pytest -name = 'demo.Calc' -propName = 'demo.Calc/total' +name = "demo.Calc" +propName = "demo.Calc/total" propValue = 1 -invokeName = 'demo.Calc/add' +invokeName = "demo.Calc/add" invokeArgs = [1] -sigName = 'demo.Calc/down' +sigName = "demo.Calc/down" sigArgs = [5] -client = ClientNode() -remote = RemoteNode() -client.on_log(lambda level, msg: print(msg)) -remote.on_log(lambda level, msg: print(msg)) -client.on_write(lambda msg: remote.handle_message(msg)) -remote.on_write(lambda msg: client.handle_message(msg)) -sink = MockSink(name) -source = MockSource(name) -RemoteNode.register_source(source) +# remote_registry = RemoteRegistry() +# client_registry = ClientRegistry() +# client = ClientNode(client_registry) +# remote = RemoteNode(remote_registry) +# client.on_log(lambda level, msg: print(msg)) +# remote.on_log(lambda level, msg: print(msg)) +# client.on_write(lambda msg: remote.handle_message(msg)) +# remote.on_write(lambda msg: client.handle_message(msg)) +# sink = MockSink(name) +# source = MockSource(name) +# remote_registry.add_source(source) -def reset(): - sink.clear() - source.clear() +@pytest.fixture +def client(): + registry = ClientRegistry() + client = ClientNode(registry) + client.on_log(lambda level, msg: print(msg)) + sink = MockSink(name) + registry.add_sink(sink) + registry.add_node(name, client) + assert registry.get_sink(name) == sink + assert registry.get_node(name) == client + return client -def test_client_link(): +@pytest.fixture +def remote(): + registry = RemoteRegistry() + remote = RemoteNode(registry) + remote.on_log(lambda level, msg: print(msg)) + source = MockSource(name) + registry.add_source(source) + registry.add_node(name, remote) + assert registry.get_source(name) == source + assert len(registry.get_nodes(name)) == 1 + assert remote in registry.get_nodes(name) + return remote + + +@pytest.fixture +def conn(client, remote): + client.on_write(lambda msg: remote.handle_message(msg)) + remote.on_write(lambda msg: client.handle_message(msg)) + print("conn: remote node", hash(remote)) + return client, remote + + +def test_client_link(conn: Tuple[ClientNode, RemoteNode]): + client, _ = conn client.detach() - assert client.registry().get_node(name) == None + registry = client.registry() + sink = registry.get_sink(name) # type: MockSink + assert registry.get_node(name) == None client.link_remote(name) - assert client.registry().get_node(name) == client + assert registry.get_node(name) == client assert len(sink.events) == 1 - assert sink.events[0] == {'type': 'init', 'name': name, 'props': {}} + assert sink.events[0] == {"type": "init", "name": name, "props": {}} -def test_client_set_property(): - reset() +def test_client_set_property(conn: Tuple[ClientNode, RemoteNode]): + client, _ = conn + registry = client.registry() + sink = registry.get_sink(name) # type: MockSink client.link_remote(name) assert len(sink.events) == 1 client.set_remote_property(propName, propValue) assert len(sink.events) == 2 assert sink.events[1] == { - 'type': 'property_change', 'name': propName, 'value': propValue} + "type": "property_change", + "name": propName, + "value": propValue, + } -def test_client_invoke(): - reset() +def test_client_invoke(conn: Tuple[ClientNode, RemoteNode]): + client, _ = conn + registry = client.registry() + sink = registry.get_sink(name) # type: MockSink client.link_remote(name) assert len(sink.events) == 1 sink.invoke(invokeName, invokeArgs) assert len(sink.events) == 2 - assert sink.events[1] == {'type': 'invoke-reply', - 'name': invokeName, 'value': invokeName} + assert sink.events[1] == { + "type": "invoke-reply", + "name": invokeName, + "value": invokeName, + } -def test_remote_signal(): - reset() +def test_remote_signal(conn: Tuple[ClientNode, RemoteNode]): + client, remote = conn client.link_remote(name) + client_registry = client.registry() + sink = client_registry.get_sink(name) # type: MockSink assert len(sink.events) == 1 + remote_registry = remote.registry() + source = remote_registry.get_source(name) # type: MockSource source.notify_signal(sigName, sigArgs) assert len(sink.events) == 2 - assert sink.events[1] == {'type': 'signal', - 'name': sigName, 'args': sigArgs} + assert sink.events[1] == {"type": "signal", "name": sigName, "args": sigArgs} -def test_remote_set_property(): - reset() +def test_remote_set_property(conn: Tuple[ClientNode, RemoteNode]): + client, remote = conn + client_registry = client.registry() + sink = client_registry.get_sink(name) # type: MockSink + remote_registry = remote.registry() + source = remote_registry.get_source(name) # type: MockSource client.link_remote(name) assert len(sink.events) == 1 source.set_property(propName, propValue) assert len(sink.events) == 2 assert sink.events[1] == { - 'type': 'property_change', 'name': propName, 'value': propValue} + "type": "property_change", + "name": propName, + "value": propValue, + } diff --git a/tests/test_protocol.py b/tests/test_protocol.py index d29cbac..243171b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,8 +1,8 @@ from olink.core.protocol import Protocol from olink.core.types import MsgType -name = 'demo.Calc' -props = {'count': 1} +name = "demo.Calc" +props = {"count": 1} value = 1 id = 1 args = [1, 2] @@ -23,7 +23,7 @@ def test_messages(): msg = Protocol.set_property_message(name, value) assert msg == [MsgType.SET_PROPERTY, name, value] - msg = Protocol.property_change_message(name, value) + msg = Protocol.property_changed_message(name, value) assert msg == [MsgType.PROPERTY_CHANGE, name, value] msg = Protocol.invoke_message(id, name, args) diff --git a/tests/test_remotenode.py b/tests/test_remotenode.py index 51246fd..b9ad080 100644 --- a/tests/test_remotenode.py +++ b/tests/test_remotenode.py @@ -1,60 +1,60 @@ -from olink.remotenode import RemoteNode, get_remote_registry +from olink.remote import RemoteNode, RemoteRegistry from olink.mocks.mocksource import MockSource -name = 'demo.Counter' +name = "demo.Counter" source = MockSource(name) -remote = RemoteNode() -remote_registry = get_remote_registry() +registry = RemoteRegistry() +remote = RemoteNode(registry) def reset(): source.clear() - remote_registry.clear() + registry.clear() def test_add_source(): reset() - assert len(remote_registry.get_nodes(name)) == 0 - remote.register_source(source) - assert remote_registry.get_source(name) == source - assert len(remote_registry.get_nodes(name)) == 0 + assert len(registry.get_nodes(name)) == 0 + registry.add_source(source) + assert registry.get_source(name) == source + assert len(registry.get_nodes(name)) == 0 def test_remove_source(): reset() - RemoteNode.register_source(source) - assert remote_registry.get_source(name) == source - RemoteNode.unregister_source(source) - assert(remote_registry.get_source(name) == None) + registry.add_source(source) + assert registry.get_source(name) == source + registry.remove_source(source) + assert registry.get_source(name) == None def test_link_node_to_source(): reset() - RemoteNode.register_source(source) - assert remote_registry.get_nodes(name) == set() - remote_registry.add_node_to_source(name, remote) - assert remote_registry.get_nodes(name) == {remote} + registry.add_source(source) + assert registry.get_nodes(name) == set() + registry.add_node(name, remote) + assert registry.get_nodes(name) == {remote} def test_unlink_node_from_source(): reset() - RemoteNode.register_source(source) - remote_registry.add_node_to_source(name, remote) - assert remote_registry.get_nodes(name) == set([remote]) - remote_registry.remove_node_from_source(name, remote) - assert remote_registry.get_nodes(name) == set() + registry.add_source(source) + registry.add_node(name, remote) + assert remote in registry.get_nodes(name) + registry.remove_node_from_source(name, remote) + assert registry.get_nodes(name) == set() def test_detach_node_from_all_sources(): reset() - RemoteNode.register_source(source) - remote_registry.add_node_to_source(name, remote) - assert remote_registry.get_nodes(name) == set([remote]) + registry.add_source(source) + registry.add_node(name, remote) + assert registry.get_nodes(name) == set([remote]) remote.detach() - assert remote_registry.get_nodes(name) == set() + assert registry.get_nodes(name) == set() def test_get_registry(): reset() reg = remote.registry() - assert reg == remote_registry + assert reg == registry From aa014341152a108f06520a502de5182b4aa900d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Ryannel?= Date: Wed, 19 Apr 2023 16:43:59 +0200 Subject: [PATCH 4/5] not working. client awaits forever --- Taskfile.yml | 9 ++++ src/olink/client/node.py | 2 + src/olink/client/registry.py | 3 ++ src/olink/ws/__init__.py | 4 +- src/olink/ws/client.py | 76 +++++++++++++++++------------ src/olink/ws/server.py | 95 ++++++++++++++++++++++++------------ src/olink/ws/test_client.py | 1 + tests/test_ws_client.py | 51 +++++++++++++++++++ 8 files changed, 175 insertions(+), 66 deletions(-) create mode 100644 Taskfile.yml create mode 100644 tests/test_ws_client.py diff --git a/Taskfile.yml b/Taskfile.yml new file mode 100644 index 0000000..95eaa04 --- /dev/null +++ b/Taskfile.yml @@ -0,0 +1,9 @@ +version: 3 + +tasks: + test::debug: + cmds: + - pytest -s --log-cli-level=DEBUG ${@:1} + test: + cmds: + - pytest -s --log-cli-level=INFO ${@:1} diff --git a/src/olink/client/node.py b/src/olink/client/node.py index 6e36cd6..c8f1c94 100644 --- a/src/olink/client/node.py +++ b/src/olink/client/node.py @@ -3,6 +3,7 @@ from olink.core import LogLevel, MsgType, BaseNode, Protocol from .types import IObjectSink from .registry import ClientRegistry +import logging class InvokeReplyArg: @@ -47,6 +48,7 @@ def set_remote_property(self, name: str, value: Any) -> None: def link_node(self, name: str): # register this node to sink + logging.debug(f"ClientNode.linkNode: {name}") self.registry().add_node(name, self) def unlink_node(self, name: str) -> None: diff --git a/src/olink/client/registry.py b/src/olink/client/registry.py index ad34849..45cfb71 100644 --- a/src/olink/client/registry.py +++ b/src/olink/client/registry.py @@ -1,6 +1,7 @@ from .types import IObjectSink from olink.core import Name, Base, LogLevel from typing import Optional, TYPE_CHECKING +import logging if TYPE_CHECKING: from .node import ClientNode @@ -15,9 +16,11 @@ def __init__(self, sink=None): class ClientRegistry(Base): def __init__(self) -> None: super().__init__() + logging.debug("ClientRegistry.__init__") self.entries: dict[str, SinkToClientEntry] = {} def remove_node(self, node: "ClientNode"): + logging.debug("ClientRegistry.removeNode") # remove node from all sinks for entry in self.entries.values(): if entry.node is node: diff --git a/src/olink/ws/__init__.py b/src/olink/ws/__init__.py index fcafd5d..d9583aa 100644 --- a/src/olink/ws/__init__.py +++ b/src/olink/ws/__init__.py @@ -1,2 +1,2 @@ -from .client import Connection -from .server import Server, run_server \ No newline at end of file +from .client import Connection as Connection +from .server import Server as Server, run_server as run_server diff --git a/src/olink/ws/client.py b/src/olink/ws/client.py index f92be59..7670160 100644 --- a/src/olink/ws/client.py +++ b/src/olink/ws/client.py @@ -1,46 +1,58 @@ import asyncio -import websockets as ws +import websockets from olink.client import ClientNode +from olink.core import EventHook +import logging class Connection: - def __init__(self, node=ClientNode()): + def __init__(self, cancel: asyncio.Future, node: ClientNode): + self._cancel = cancel self._node = node + self._conn: websockets.WebSocketClientProtocol = None self._send_queue = asyncio.Queue() - self._recv_queue = asyncio.Queue() - self._conn: ws.WebSocketClientProtocol = None + self.on_recv = EventHook() - def send(self, msg): - self._send_queue.put_nowait(msg) + async def async_send(self, data): + await self._conn.send(data) - async def handle_send(self): - while self._conn is not None: + def send(self, data): + self._send_queue.put_nowait(data) + + async def _sender(self): + while True: msg = await self._send_queue.get() - data = self.serializer.serialize(msg) - await self._conn.send(data) - - async def handle_recv(self): - while self._conn is not None: - msg = await self._recv_queue.get() - self.emitter.emit(msg.object, msg) - - async def recv(self): - async for data in self._conn: - msg = self.serializer.deserialize(data) - await self._recv_queue.put(msg) - - async def connect(self, addr: str): - # connect to server - async for conn in ws.connect(addr): - try: - self._conn = conn - # start send and recv tasks - await asyncio.gather( - self.handle_send(), self.handle_recv(), self.recv() - ) - except ws.ConnectionClosed: - continue + if self._cancel.done() and self._send_queue.empty(): + await self._conn.close() + break + await self._conn.send(msg) + + async def _receiver(self): + try: + async for data in self._conn: + self.on_recv(data) + except websockets.ConnectionClosed: + pass + + async def connect(self, addr: str, done: asyncio.Future): + async with websockets.connect(addr) as conn: + logging.info("client connected") + self._conn = conn + receiver = asyncio.create_task(self._receiver()) + sender = asyncio.create_task(self._sender()) + logging.info("await client receiver, sender") + _, pending = await asyncio.wait( + [receiver, sender], return_when=asyncio.FIRST_COMPLETED + ) + logging.info("client receiver, sender done") + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass def cancel(self): + self._cancel.set_result(True) self._conn.close() self._conn = None diff --git a/src/olink/ws/server.py b/src/olink/ws/server.py index 290d3bd..7a2e6c6 100644 --- a/src/olink/ws/server.py +++ b/src/olink/ws/server.py @@ -1,65 +1,96 @@ -import websockets as ws +import websockets from typing import Any import asyncio -from olink.remote import RemoteNode +from olink.remote import RemoteNode, RemoteRegistry import logging class RemotePipe: - def __init__(self, conn: ws.WebSocketServerProtocol): + def __init__( + self, + cancel: asyncio.Future, + node: RemoteNode, + conn: websockets.WebSocketServerProtocol, + ): + self._cancel = cancel self._conn = conn - self._node = RemoteNode() + self._node = node self._node.on_write(self._send) self._send_queue = asyncio.Queue() - self._recv_queue = asyncio.Queue() def _send(self, data): self._send_queue.put_nowait(data) - async def handle_send(self): + async def _sender(self): while True: - data = await self._send_queue.get() - await self._conn.send(data) + msg = await self._send_queue.get() + if self._cancel.done() and self._send_queue.empty(): + await self._conn.close() + break + await self._conn.send(msg) - async def handle_recv(self): - while True: - data = await self._recv_queue.get() - self._node.handle_message(data) - - async def recv(self): - while True: - data = await self._conn.recv() - self._recv_queue.put_nowait(data) + async def _receiver(self): + try: + async for msg in self._conn: + self._node.handle_message(msg) + except websockets.ConnectionClosed: + logging.info("Connection closed") + pass async def run(self): - await asyncio.gather( - self.handle_send(), - self.handle_recv(), - self.recv(), + receiver = asyncio.create_task(self._receiver()) + sender = asyncio.create_task(self._sender()) + _, pending = await asyncio.wait( + [receiver, sender], return_when=asyncio.FIRST_COMPLETED ) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + logging.info("Task cancelled") + pass + + def cancel(self): + if self._cancel: + self._cancel.set_result(True) class Server: - def __init__(self) -> None: + def __init__(self, cancel: asyncio.Future) -> None: + logging.info("server init") + self._cancel = cancel self.pipes: list[RemotePipe] = [] + self._registry = RemoteRegistry() - async def handle_connection(self, socket: ws.WebSocketServerProtocol): - logging.info("New connection %s", socket) - pipe = RemotePipe(socket) + async def _handler(self, conn: websockets.WebSocketServerProtocol): + logging.info("server handle new connection %s", conn) + node = RemoteNode(self._registry) + pipe = RemotePipe(self._cancel, node, conn) self.pipes.append(pipe) await pipe.run() async def serve(self, host: str, port: int): - async with ws.serve(self.handle_connection, host, port): - await asyncio.Future() # run forever + logging.info("server serve") + async with websockets.serve(self._handler, host, port): + await self._cancel + + async def cancel(self): + logging.info("server cancel") + for pipe in self.pipes: + pipe.cancel() + self._cancel.set_result(True) -def run_server(host: str, port: int): - logging.info("Starting server on %s:%d", host, port) - server = Server() - asyncio.run(server.serve(host, port)) +async def run_server(cancel: asyncio.Future, host: str, port: int): + logging.info("run server on %s:%d", host, port) + server = Server(cancel) + logging.info("await serve") + await server.serve(host, port) + print("run_server done") + server.cancel() if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG) - run_server("localhost", 8152) + asyncio.run(run_server("localhost", 8152)) diff --git a/src/olink/ws/test_client.py b/src/olink/ws/test_client.py index 8f8e0ec..5b84cbb 100644 --- a/src/olink/ws/test_client.py +++ b/src/olink/ws/test_client.py @@ -21,5 +21,6 @@ async def handler(socket: ws.WebSocketServerProtocol): @pytest.mark.asyncio async def test_client_link(): + assert False queue = asyncio.Queue() await setup_server(queue, test_host, test_port) diff --git a/tests/test_ws_client.py b/tests/test_ws_client.py new file mode 100644 index 0000000..1a9b676 --- /dev/null +++ b/tests/test_ws_client.py @@ -0,0 +1,51 @@ +import asyncio +import pytest +from olink.ws import Connection, run_server +from olink.client import ClientNode, ClientRegistry +import logging + +test_host = "localhost" +test_port = 8152 +test_url = f"ws://{test_host}:{test_port}" +object_id = "demo.Calc" + + +async def delay(coro, delay: float): + await asyncio.sleep(delay) + await coro + + +@pytest.mark.asyncio +async def test_client_link(): + logging.info("test_client_link") + cancel = asyncio.Future() + server_task = asyncio.create_task(run_server(cancel, test_host, test_port)) + logging.info("server_task %s", server_task) + + async def run_client(cancel: asyncio.Future): + logging.info("run_client") + registry = ClientRegistry() + node = ClientNode(registry) + conn = Connection(cancel, node) + node.on_write(conn.send) + conn.on_recv += node.handle_message + logging.info("connecting to %s", test_url) + await conn.connect(test_url, cancel) + logging.info("connected") + node.link_node(object_id) + logging.info("client done") + + client_task = asyncio.create_task(delay(run_client(cancel), 0.5)) + logging.info("client_task %s", client_task) + + done, pending = await asyncio.wait( + [server_task, client_task], return_when=asyncio.FIRST_COMPLETED + ) + logging.info("done %s, pending %s", done, pending) + print(done, pending) + for task in pending: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass From 0ab2c675792ae7cb2dceeea45ef3b75649b15b4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BCrgen=20Ryannel?= Date: Thu, 20 Apr 2023 16:12:54 +0200 Subject: [PATCH 5/5] wip: not working --- src/olink/ws/__init__.py | 4 +- src/olink/ws/client.py | 58 ---------------------- src/olink/ws/conn.py | 79 ++++++++++++++++++++++++++++++ src/olink/ws/server.py | 101 +++++++++++---------------------------- tests/test_ws_client.py | 55 ++++++++++----------- 5 files changed, 135 insertions(+), 162 deletions(-) delete mode 100644 src/olink/ws/client.py create mode 100644 src/olink/ws/conn.py diff --git a/src/olink/ws/__init__.py b/src/olink/ws/__init__.py index d9583aa..5341ff3 100644 --- a/src/olink/ws/__init__.py +++ b/src/olink/ws/__init__.py @@ -1,2 +1,2 @@ -from .client import Connection as Connection -from .server import Server as Server, run_server as run_server +from .conn import Connection as Connection +from .server import Server as Server diff --git a/src/olink/ws/client.py b/src/olink/ws/client.py deleted file mode 100644 index 7670160..0000000 --- a/src/olink/ws/client.py +++ /dev/null @@ -1,58 +0,0 @@ -import asyncio -import websockets -from olink.client import ClientNode -from olink.core import EventHook -import logging - - -class Connection: - def __init__(self, cancel: asyncio.Future, node: ClientNode): - self._cancel = cancel - self._node = node - self._conn: websockets.WebSocketClientProtocol = None - self._send_queue = asyncio.Queue() - self.on_recv = EventHook() - - async def async_send(self, data): - await self._conn.send(data) - - def send(self, data): - self._send_queue.put_nowait(data) - - async def _sender(self): - while True: - msg = await self._send_queue.get() - if self._cancel.done() and self._send_queue.empty(): - await self._conn.close() - break - await self._conn.send(msg) - - async def _receiver(self): - try: - async for data in self._conn: - self.on_recv(data) - except websockets.ConnectionClosed: - pass - - async def connect(self, addr: str, done: asyncio.Future): - async with websockets.connect(addr) as conn: - logging.info("client connected") - self._conn = conn - receiver = asyncio.create_task(self._receiver()) - sender = asyncio.create_task(self._sender()) - logging.info("await client receiver, sender") - _, pending = await asyncio.wait( - [receiver, sender], return_when=asyncio.FIRST_COMPLETED - ) - logging.info("client receiver, sender done") - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass - - def cancel(self): - self._cancel.set_result(True) - self._conn.close() - self._conn = None diff --git a/src/olink/ws/conn.py b/src/olink/ws/conn.py new file mode 100644 index 0000000..4a431fc --- /dev/null +++ b/src/olink/ws/conn.py @@ -0,0 +1,79 @@ +import asyncio +from pyee.asyncio import AsyncIOEventEmitter +import logging +import websockets as ws + +log = logging.getLogger(__name__) + + +class Connection(AsyncIOEventEmitter): + def __init__( + self, cancel: asyncio.Event, socket: ws.WebSocketServerProtocol = None + ): + super().__init__() + self._cancel = cancel + self._socket = socket + self._send_queue = asyncio.Queue() + self._receiver_task = None + self._sender_task = None + if socket: # we are server + log.info("server conn init") + self._receiver_task = asyncio.create_task( + self._receiver(), name="server conn receiver" + ) + self._sender_task = asyncio.create_task( + self._sender(), name="server conn sender" + ) + + async def connect(self, addr: str): + log.info("client connected") + async with ws.connect(addr) as self._socket: + self._receiver_task = asyncio.create_task( + self._receiver(), name="client receiver" + ) + self._sender_task = asyncio.create_task( + self._sender(), name="client sender" + ) + + async def _sender(self): + while not self._cancel.is_set(): + try: + msg = await self._send_queue.get() + if not msg: + log.info("sender got None, closing") + break + await self._socket.send(msg) + self._send_queue.task_done() + except Exception as e: + log.info("connection sender closing: %s", e) + break + log.info("sender done") + + async def _receiver(self): + while not self._cancel.is_set(): + try: + msg = await self._socket.recv() + self.emit("message", msg) + except Exception: + log.info("Connection closed") + break + log.info("receiver done") + + def send(self, data): + self._send_queue.put_nowait(data) + + async def send_async(self, data): + await self._send_queue.put(data) + + async def cancel(self): + self._cancel.set() + if self._socket: + await self._socket.close() + self._socket = None + self._send_queue.put_nowait(None) + if self._receiver_task: + await self._receiver_task + if self._sender_task: + await self._sender_task + self._receiver_task = None + self._sender_task = None diff --git a/src/olink/ws/server.py b/src/olink/ws/server.py index 7a2e6c6..af3b6fd 100644 --- a/src/olink/ws/server.py +++ b/src/olink/ws/server.py @@ -1,96 +1,51 @@ -import websockets +import websockets as ws from typing import Any import asyncio from olink.remote import RemoteNode, RemoteRegistry import logging +from .conn import Connection -class RemotePipe: - def __init__( - self, - cancel: asyncio.Future, - node: RemoteNode, - conn: websockets.WebSocketServerProtocol, - ): - self._cancel = cancel - self._conn = conn - self._node = node - self._node.on_write(self._send) - self._send_queue = asyncio.Queue() - - def _send(self, data): - self._send_queue.put_nowait(data) +class RemoteHandler: + def __init__(self, conn: Connection, node: RemoteNode) -> None: + self.conn = conn + self.node = node + self.conn.on("message", self.on_recv_message) + self.node.on_write(self.on_send_message) - async def _sender(self): - while True: - msg = await self._send_queue.get() - if self._cancel.done() and self._send_queue.empty(): - await self._conn.close() - break - await self._conn.send(msg) + def on_recv_message(self, msg: str): + self.node.handle_message(msg) - async def _receiver(self): - try: - async for msg in self._conn: - self._node.handle_message(msg) - except websockets.ConnectionClosed: - logging.info("Connection closed") - pass + def on_send_message(self, msg: str): + self.conn.send(msg) - async def run(self): - receiver = asyncio.create_task(self._receiver()) - sender = asyncio.create_task(self._sender()) - _, pending = await asyncio.wait( - [receiver, sender], return_when=asyncio.FIRST_COMPLETED - ) - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - logging.info("Task cancelled") - pass - - def cancel(self): - if self._cancel: - self._cancel.set_result(True) + async def cancel(self): + await self.conn.cancel() + self.node.detach() class Server: - def __init__(self, cancel: asyncio.Future) -> None: + def __init__(self, cancel: asyncio.Event) -> None: logging.info("server init") self._cancel = cancel - self.pipes: list[RemotePipe] = [] + self._handlers: list[RemoteHandler] = [] self._registry = RemoteRegistry() - async def _handler(self, conn: websockets.WebSocketServerProtocol): - logging.info("server handle new connection %s", conn) + async def _handler(self, socket: ws.WebSocketServerProtocol): + logging.info("server handle new connection %s", socket) node = RemoteNode(self._registry) - pipe = RemotePipe(self._cancel, node, conn) - self.pipes.append(pipe) - await pipe.run() + conn = Connection(self._cancel, socket) + handler = RemoteHandler(conn, node) + self._handlers.append(handler) async def serve(self, host: str, port: int): logging.info("server serve") - async with websockets.serve(self._handler, host, port): - await self._cancel + async with ws.serve(self._handler, host, port): + await self._cancel.wait() + logging.info("server cancel wait done") async def cancel(self): logging.info("server cancel") - for pipe in self.pipes: - pipe.cancel() - self._cancel.set_result(True) - - -async def run_server(cancel: asyncio.Future, host: str, port: int): - logging.info("run server on %s:%d", host, port) - server = Server(cancel) - logging.info("await serve") - await server.serve(host, port) - print("run_server done") - server.cancel() - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - asyncio.run(run_server("localhost", 8152)) + for handler in self._handlers: + await handler.cancel() + self._cancel.set() diff --git a/tests/test_ws_client.py b/tests/test_ws_client.py index 1a9b676..54d421f 100644 --- a/tests/test_ws_client.py +++ b/tests/test_ws_client.py @@ -1,6 +1,6 @@ import asyncio import pytest -from olink.ws import Connection, run_server +from olink.ws import Connection, Server from olink.client import ClientNode, ClientRegistry import logging @@ -15,37 +15,34 @@ async def delay(coro, delay: float): await coro +async def delay_cancel(cancel: asyncio.Future, delay: float): + await asyncio.sleep(delay) + if not cancel.is_set(): + cancel.set() + + @pytest.mark.asyncio async def test_client_link(): logging.info("test_client_link") - cancel = asyncio.Future() - server_task = asyncio.create_task(run_server(cancel, test_host, test_port)) - logging.info("server_task %s", server_task) + cancel = asyncio.Event() + server = Server(cancel) + server_task = asyncio.create_task(server.serve(test_host, test_port), name="server") + + client_registry = ClientRegistry() + client_node = ClientNode(client_registry) + conn = Connection(cancel) + client_node.on_write(conn.send) + conn.on("message", client_node.handle_message) - async def run_client(cancel: asyncio.Future): + async def run_client(conn: Connection, cancel: asyncio.Event): logging.info("run_client") - registry = ClientRegistry() - node = ClientNode(registry) - conn = Connection(cancel, node) - node.on_write(conn.send) - conn.on_recv += node.handle_message logging.info("connecting to %s", test_url) - await conn.connect(test_url, cancel) - logging.info("connected") - node.link_node(object_id) - logging.info("client done") - - client_task = asyncio.create_task(delay(run_client(cancel), 0.5)) - logging.info("client_task %s", client_task) - - done, pending = await asyncio.wait( - [server_task, client_task], return_when=asyncio.FIRST_COMPLETED - ) - logging.info("done %s, pending %s", done, pending) - print(done, pending) - for task in pending: - task.cancel() - try: - await task - except asyncio.CancelledError: - pass + await conn.connect(test_url) + await cancel.wait() + + await asyncio.sleep(1) + client_task = asyncio.create_task(run_client(conn, cancel), name="client") + await delay_cancel(cancel, 0) + await server.cancel() + await conn.cancel() + await asyncio.wait([server_task, client_task])