diff --git a/dimos/core/__init__.py b/dimos/core/__init__.py new file mode 100644 index 0000000000..95f2ec3401 --- /dev/null +++ b/dimos/core/__init__.py @@ -0,0 +1,2 @@ +from dimos.core.base import In, Module, Out, RemoteIn, RemoteOut, module, rpc +from dimos.core.dimosdask import initialize diff --git a/dimos/core/base.py b/dimos/core/base.py new file mode 100644 index 0000000000..93f7adae4d --- /dev/null +++ b/dimos/core/base.py @@ -0,0 +1,497 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Core building blocks for the *actors2* graph-stream framework. + +Public surface +-------------- +• In[T], Out[T] – local data streams +• RemoteIn[T], RemoteOut[T] – cross-process proxies +• Module – user subclass that represents a logical unit +• @module – decorator that wires IO descriptors and RPCs +• @rpc – tag to mark remotable methods +• State – simple lifecycle enum (for pretty printing only) +""" + +from __future__ import annotations + +import enum +import inspect +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Protocol, + TypeVar, + get_args, + get_origin, + get_type_hints, +) + +from distributed.actor import Actor # only imported for type-checking + +from dimos.core import colors +from dimos.core.o3dpickle import register_picklers + +register_picklers() + +T = TypeVar("T") + + +# Goals +# ---------- +# streams should be able to know: +# - if they are in or out +# - which actor is their owner +# - streams can implement their own transport (how does this work?) +# - if they are connected to another stream, know all of the above for it +# +# Usage within actor +# ------------------ +# LocalIn.subscribe(print) +# LocalOut.publish("hello") +# +# Usage from outside +# myActor.inputStream.connect(otherActor.outputStream) +# myActor.outputStream.connect(otherActor.inputStream) +# + + +# --------------------------------------------------------------------------- +# Helper decorators +# --------------------------------------------------------------------------- + + +def rpc(fn: Callable[..., Any]) -> Callable[..., Any]: + """Mark *fn* as remotely callable.""" + + fn.__rpc__ = True # type: ignore[attr-defined] + return fn + + +# --------------------------------------------------------------------------- +# Protocols (work in progress) +# --------------------------------------------------------------------------- + + +class MultiprocessingProtocol(Protocol): + def deploy(self, target): ... + + +class TransportProtocol(Protocol[T]): + def broadcast(self, selfstream: Out, value: T): ... + + +class DirectTransportProtocol(Protocol[T]): + def direct_msg(self, selfstream: Out, target: RemoteIn, value: T) -> None: ... + + +Transport = TransportProtocol | DirectTransportProtocol + + +class DaskTransport(DirectTransportProtocol): + def msg(self, selfstream: Out[T], target: RemoteIn[T], value: T) -> None: ... + + +daskTransport = DaskTransport() # singleton instance for use in Out/RemoteOut + + +# --------------------------------------------------------------------------- +# Stream primitives +# --------------------------------------------------------------------------- + + +class State(enum.Enum): + DORMANT = "dormant" # descriptor defined but not bound + READY = "ready" # bound to owner but not yet connected + CONNECTED = "connected" # input bound to an output + FLOWING = "flowing" # runtime: data observed + + +class Stream(Generic[T]): + """Base class shared by *In* and *Out* streams.""" + + transport: Transport = daskTransport # default transport + + def __init__(self, typ: type[T], name: str, transport: Transport = None): + self.type: type[T] = typ + self.name: str = name + if transport: + self.transport = transport + + # ------------------------------------------------------------------ + # Descriptor plumbing – auto-fill name when used as class attr + # ------------------------------------------------------------------ + def __set_name__(self, owner: type, attr_name: str) -> None: # noqa: D401 + if not getattr(self, "name", ""): + self.name = attr_name + + # ------------------------------------------------------------------ + # String helpers ---------------------------------------------------- + # ------------------------------------------------------------------ + @property + def type_name(self) -> str: + return getattr(self.type, "__name__", repr(self.type)) + + def _color_fn(self) -> Callable[[str], str]: + if self.state == State.DORMANT: + return colors.orange + if self.state == State.READY: + return colors.blue + if self.state == State.CONNECTED: + return colors.green + return lambda s: s + + def __str__(self) -> str: # noqa: D401 + return self._color_fn()(f"{self.name}[{self.type_name}]") + + # ------------------------------------------------------------------ + # Lifecycle – subclasses implement .state + # ------------------------------------------------------------------ + @property + def state(self) -> State: # pragma: no cover – abstract + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Outputs (producers) +# --------------------------------------------------------------------------- + + +class BaseOut(Stream[T]): + """Common behaviour shared by *local* and *remote* outputs.""" + + def __init__(self, typ: type[T], name: str = "Out", owner: Any | None = None, **kwargs): + super().__init__(typ, name, **kwargs) + self.owner: Any | None = owner + + @property + def state(self) -> State: # noqa: D401 + return State.DORMANT if self.owner is None else State.READY + + # API surface ------------------------------------------------------- + def publish(self, value: T) -> None: # pragma: no cover – abstract + raise NotImplementedError + + def subscribe(self, inp: "In[T]") -> None: # pragma: no cover – abstract + raise NotImplementedError + + def __str__(self) -> str: # noqa: D401 + return ( + self.__class__.__name__ + + " " + + self._color_fn()(f"{self.name}[{self.type_name}]") + + " @ " + + str(self.owner) + ) + + +class Out(BaseOut[T]): + """Local *Out* – synchronous fan-out to subscribers.""" + + transport: Transport = daskTransport + + def __init__(self, typ: type[T], name: str = "Out", owner: Any | None = None): + super().__init__(typ, name, owner) + self._subscribers: List[In[T]] = [] + + def publish(self, value: T) -> None: # noqa: D401 + """Send *value* to all subscribers. + + • Local `In` → direct callback dispatch via ``_receive`` + • Remote `In` (its ``owner`` is a *distributed.Actor*) → perform a + synchronous RPC so the receiving process can enqueue the message. + """ + for inp in list(self._subscribers): + owner = getattr(inp, "owner", None) + + if isinstance(owner, Actor): + # Cross-process: schedule RPC on remote actor. + try: + getattr(owner, "receive_msg")(inp.name, value) # type: ignore[misc] + except Exception: # pylint: disable=broad-except + continue # swallow network issues during shutdown + else: + # In-process delivery. + inp._receive(value) + + def subscribe(self, inp: "In[T]") -> None: # noqa: D401 + if inp not in self._subscribers: + self._subscribers.append(inp) + + def __reduce__(self): # noqa: D401 + # if self.owner is None or not hasattr(self.owner, "ref"): + # raise ValueError(f"{self} Cannot serialise Out without an owner ref") + return ( + RemoteOut, + (self.type, self.name, self.owner.ref if hasattr(self.owner, "ref") else None), + ) + + +class RemoteOut(BaseOut[T]): + """Proxy for an *Out* that lives on a remote *distributed.Actor*.""" + + def __init__(self, typ: type[T], name: str, owner: Actor | None = None): + super().__init__(typ, name, owner) + + def subscribe(self, inp: "In[T]") -> None: # noqa: D401 + if self.owner is None: + raise RuntimeError("RemoteOut has no associated Actor; cannot subscribe") + fut = self.owner.subscribe(self.name, inp) + try: + fut.result() + except AttributeError: + pass # non-future – best effort + + +# --------------------------------------------------------------------------- +# Inputs (consumers) +# --------------------------------------------------------------------------- + + +class In(Stream[T]): + """Local *In* – pull side of the data flow.""" + + def __init__( + self, + typ: type[T], + name: str = "In", + owner: Any | None = None, + source: BaseOut[T] | None = None, + ) -> None: + super().__init__(typ, name) + self.owner: Any | None = owner + self.source: BaseOut[T] | None = source + self._callbacks: List[Callable[[T], None]] = [] + + # ------------------------------------------------------------------ + # Introspection helpers + # ------------------------------------------------------------------ + @property + def state(self) -> State: # noqa: D401 + return State.CONNECTED if self.source else State.DORMANT + + def __str__(self) -> str: # noqa: D401 + if self.state == State.CONNECTED and self.source is not None: + return f"IN {super().__str__()} <- {self.source}" + return f"IN {super().__str__()}" + + # ------------------------------------------------------------------ + # Connectivity API + # ------------------------------------------------------------------ + def bind(self, out_stream: BaseOut[T]) -> None: + if self.source is not None: + raise RuntimeError("Input already connected") + self.source = out_stream + out_stream.subscribe(self) + + # Backwards-compat alias + connect = bind # type: ignore[attr-defined] + + def subscribe(self, callback: Callable[[T], None]) -> None: # noqa: D401 + if self.source is None: + raise ValueError("Cannot subscribe to an unconnected In stream") + if not self._callbacks: + self.source.subscribe(self) + self._callbacks.append(callback) + + # ------------------------------------------------------------------ + # Internal helper – called by Out.publish + # ------------------------------------------------------------------ + def _receive(self, value: T) -> None: + for cb in list(self._callbacks): + cb(value) + + # ------------------------------------------------------------------ + # Pickling – becomes RemoteIn on the other side + # ------------------------------------------------------------------ + def __reduce__(self): # noqa: D401 + if self.owner is None or not hasattr(self.owner, "ref"): + raise ValueError("Cannot serialise In without an owner ref") + return (RemoteIn, (self.type, self.name, self.owner.ref)) + + +class RemoteIn(In[T]): + """Proxy for an *In* that lives on a remote actor.""" + + def __init__(self, typ: type[T], name: str, owner: Actor): + super().__init__(typ, name, owner, None) + + def __str__(self) -> str: # noqa: D401 + return f"{self.__class__.__name__} {super().__str__()} @ {self.owner}" + + def stream_connect(self, source: Out[Any]) -> None: + self.owner.stream_connect(self.name, source) + + +# --------------------------------------------------------------------------- +# Module infrastructure +# --------------------------------------------------------------------------- + + +class Module: # pylint: disable=too-few-public-methods + """Base-class for user logic blocks (actors).""" + + inputs: Dict[str, In[Any]] = {} + outputs: Dict[str, Out[Any]] = {} + rpcs: Dict[str, Callable[..., Any]] = {} + + # ------------------------------------------------------------------ + # Runtime helpers + # ------------------------------------------------------------------ + def stream_connect(self, input_name: str, source: Out[Any]) -> None: + inp = In(source.type, input_name, self, source) + self.inputs[input_name] = inp + setattr(self, input_name, inp) + + def subscribe(self, output_name: str, remote_input: In[Any]) -> None: # noqa: D401 + getattr(self, output_name).subscribe(remote_input) + + def receive_msg(self, input_name: str, msg: Any) -> None: # noqa: D401 + self.inputs[input_name]._receive(msg) + + def set_ref(self, ref: Actor) -> None: # noqa: D401 + self.ref = ref # created dynamically elsewhere + + def __str__(self) -> str: # noqa: D401 + return f"{self.__class__.__name__}-Local" + + @classmethod + def io(cls) -> str: # noqa: D401 + def _boundary(seq, first: str, mid: str, last: str): + seq = list(seq) + for idx, s in enumerate(seq): + if idx == 0: + yield first + s + elif idx == len(seq) - 1: + yield last + s + else: + yield mid + s + + def _box(name: str) -> str: + return "\n".join( + [ + "┌┴" + "─" * (len(name) + 1) + "┐", + f"│ {name} │", + "└┬" + "─" * (len(name) + 1) + "┘", + ] + ) + + inputs = list(_boundary(map(str, cls.inputs.values()), " ┌─ ", " ├─ ", " ├─ ")) + + # RPC signatures ------------------------------------------------- + rpc_lines: List[str] = [] + for n, fn in cls.rpcs.items(): + sig = inspect.signature(fn) + hints = get_type_hints(fn, include_extras=True) + params: List[str] = [] + for p in sig.parameters: + if p in ("self", "cls"): + continue + ann = hints.get(p, Any) + params.append(f"{p}: {getattr(ann, '__name__', repr(ann))}") + ret_ann = hints.get("return", Any) + rpc_lines.append( + f"{n}({', '.join(params)}) → {getattr(ret_ann, '__name__', repr(ret_ann))}" + ) + + rpcs = list(_boundary(rpc_lines, " ├─ ", " ├─ ", " └─ ")) + + outputs = list( + _boundary( + map(str, cls.outputs.values()), + " ├─ ", + " ├─ ", + " ├─ " if rpcs else " └─ ", + ) + ) + + if rpcs: + rpcs.insert(0, " │") + + return "\n".join(inputs + [_box(cls.__name__)] + outputs + rpcs) + + +# --------------------------------------------------------------------------- +# @module decorator – reflection heavy-lifting +# --------------------------------------------------------------------------- + + +def module(cls: type) -> type: # noqa: D401 + """Decorate *cls* to inject IO descriptors and RPC metadata.""" + + # Guarantee dicts are *per-class*, not shared between subclasses + cls.inputs = dict(getattr(cls, "inputs", {})) # type: ignore[attr-defined] + cls.outputs = dict(getattr(cls, "outputs", {})) # type: ignore[attr-defined] + cls.rpcs = dict(getattr(cls, "rpcs", {})) # type: ignore[attr-defined] + + # 1) Handle class-level annotations -------------------------------- + for name, ann in get_type_hints(cls, include_extras=True).items(): + origin = get_origin(ann) + if origin is Out: + inner, *_ = get_args(ann) or (Any,) + stream = Out(inner, name) + cls.outputs[name] = stream + setattr(cls, name, stream) + elif origin is In: + inner, *_ = get_args(ann) or (Any,) + stream = In(inner, name) + cls.inputs[name] = stream + setattr(cls, name, stream) + + # 2) Gather RPCs ---------------------------------------------------- + for n, obj in cls.__dict__.items(): + if callable(obj) and getattr(obj, "__rpc__", False): + cls.rpcs[n] = obj + + # 3) Wrap __init__ -------------------------------------------------- + original_init = cls.__init__ # type: ignore[attr-defined] + + def _init_wrapper(self, *args, **kwargs): # noqa: D401 – inner func + # (a) bind owners for pre-declared streams + for s in cls.outputs.values(): + s.owner = self + for s in cls.inputs.values(): + s.owner = self + + # (b) convert RemoteOut kwargs → connected In + new_kwargs = {} + for k, v in kwargs.items(): + if isinstance(v, RemoteOut): + inp = In(v.type, v.name, self, v) + cls.inputs[k] = inp + new_kwargs[k] = inp + else: + new_kwargs[k] = v + + # (c) delegate + original_init(self, *args, **new_kwargs) + + cls.__init__ = _init_wrapper # type: ignore[assignment] + + return cls + + +__all__ = [ + "In", + "Out", + "RemoteIn", + "RemoteOut", + "Module", + "module", + "rpc", + "State", +] diff --git a/dimos/core/colors.py b/dimos/core/colors.py new file mode 100644 index 0000000000..f137523e67 --- /dev/null +++ b/dimos/core/colors.py @@ -0,0 +1,43 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def green(text: str) -> str: + """Return the given text in green color.""" + return f"\033[92m{text}\033[0m" + + +def blue(text: str) -> str: + """Return the given text in blue color.""" + return f"\033[94m{text}\033[0m" + + +def red(text: str) -> str: + """Return the given text in red color.""" + return f"\033[91m{text}\033[0m" + + +def yellow(text: str) -> str: + """Return the given text in yellow color.""" + return f"\033[93m{text}\033[0m" + + +def cyan(text: str) -> str: + """Return the given text in cyan color.""" + return f"\033[96m{text}\033[0m" + + +def orange(text: str) -> str: + """Return the given text in orange color.""" + return f"\033[38;5;208m{text}\033[0m" diff --git a/dimos/core/o3dpickle.py b/dimos/core/o3dpickle.py new file mode 100644 index 0000000000..a18916a06c --- /dev/null +++ b/dimos/core/o3dpickle.py @@ -0,0 +1,38 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copyreg + +import numpy as np +import open3d as o3d + + +def reduce_external(obj): + # Convert Vector3dVector to numpy array for pickling + points_array = np.asarray(obj.points) + return (reconstruct_pointcloud, (points_array,)) + + +def reconstruct_pointcloud(points_array): + # Create new PointCloud and assign the points + pc = o3d.geometry.PointCloud() + pc.points = o3d.utility.Vector3dVector(points_array) + return pc + + +def register_picklers(): + # Register for the actual PointCloud class that gets instantiated + # We need to create a dummy PointCloud to get its actual class + _dummy_pc = o3d.geometry.PointCloud() + copyreg.pickle(_dummy_pc.__class__, reduce_external) diff --git a/dimos/core/test_base.py b/dimos/core/test_base.py new file mode 100644 index 0000000000..ca66f04287 --- /dev/null +++ b/dimos/core/test_base.py @@ -0,0 +1,145 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from threading import Event, Thread + +from dimos.multiprocess.actors2.base import In, Module, Out, RemoteOut, module, rpc +from dimos.multiprocess.actors2.base_dask import dimos +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.vector import Vector +from dimos.utils.testing import SensorReplay + +# never delete this +if dimos: + ... + + +@module +class RobotClient(Module): + odometry: Out[Odometry] + lidar: Out[LidarMessage] + mov: In[Vector] + + mov_msg_count = 0 + + def mov_callback(self, msg): + self.mov_msg_count += 1 + print("MOV REQ", msg) + + def __init__(self): + self.odometry = Out(Odometry, "odometry", self) + self._stop_event = Event() + self._thread = None + + def start(self): + self._thread = Thread(target=self.odomloop) + self._thread.start() + self.mov.subscribe(self.mov_callback) + + def odomloop(self): + odomdata = SensorReplay("raw_odometry_rotate_walk", autocast=Odometry.from_msg) + lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + + lidariter = lidardata.iterate() + self._stop_event.clear() + while not self._stop_event.is_set(): + for odom in odomdata.iterate(): + if self._stop_event.is_set(): + return + # print(odom) + odom.pubtime = time.perf_counter() + self.odometry.publish(odom) + + lidarmsg = next(lidariter) + lidarmsg.pubtime = time.perf_counter() + self.lidar.publish(lidarmsg) + time.sleep(0.1) + + def stop(self): + self._stop_event.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=1.0) # Wait up to 1 second for clean shutdown + + +@module +class Navigation(Module): + mov: Out[Vector] + odom_msg_count = 0 + lidar_msg_count = 0 + + @rpc + def navigate_to(self, target: Vector) -> bool: ... + + def __init__( + self, + target_position: In[Vector], + lidar: In[LidarMessage], + odometry: In[Odometry], + ): + self.mov = Out(Vector, "mov", self) + self.target_position = target_position + self.lidar = lidar + self.odometry = odometry + + @rpc + def start(self): + def _odom(msg): + self.odom_msg_count += 1 + print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) + self.mov.publish(msg.pos) + + self.odometry.subscribe(_odom) + + def _lidar(msg): + self.lidar_msg_count += 1 + print("RCV:", (time.perf_counter() - msg.pubtime) * 1000, msg) + + self.lidar.subscribe(_lidar) + + +def test_deployment(dimos): + robot = dimos.deploy(RobotClient) + target_stream = RemoteOut[Vector](Vector, "map") + + print("\n") + print("lidar stream", robot.lidar, robot.lidar.owner) + print("target stream", target_stream) + print("odom stream", robot.odometry) + + nav = dimos.deploy( + Navigation, + target_position=target_stream, + lidar=robot.lidar, + odometry=robot.odometry, + ) + + print(robot.lidar) + robot.mov.connect(nav.mov) + + print("\n\n\n" + robot.io().result(), "\n") + print(nav.io().result(), "\n\n") + + robot.start().result() + nav.start().result() + time.sleep(1) + robot.stop().result() + print("robot.mov_msg_count", robot.mov_msg_count) + print("nav.odom_msg_count", nav.odom_msg_count) + print("nav.lidar_msg_count", nav.lidar_msg_count) + + assert robot.mov_msg_count >= 9 + assert nav.odom_msg_count >= 9 + assert nav.lidar_msg_count >= 9 diff --git a/dimos/core/test_o3dpickle.py b/dimos/core/test_o3dpickle.py new file mode 100644 index 0000000000..05d1cadde9 --- /dev/null +++ b/dimos/core/test_o3dpickle.py @@ -0,0 +1,36 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pickle + +from dimos.core.o3dpickle import register_picklers +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.utils.testing import SensorReplay + +register_picklers() + + +def test_enode_decode(): + lidardata = SensorReplay("office_lidar", autocast=LidarMessage.from_msg) + lidarmsg = next(lidardata.iterate()) + + binarypc = pickle.dumps(lidarmsg.pointcloud) + + # Test pickling and unpickling + binary = pickle.dumps(lidarmsg) + lidarmsg2 = pickle.loads(binary) + + # Verify the decoded message has the same properties + assert isinstance(lidarmsg2, LidarMessage) + assert len(lidarmsg2.pointcloud.points) == len(lidarmsg.pointcloud.points) diff --git a/dimos/robot/unitree_multiprocess/unitree_go2.py b/dimos/robot/unitree_multiprocess/unitree_go2.py new file mode 100644 index 0000000000..9c2637a752 --- /dev/null +++ b/dimos/robot/unitree_multiprocess/unitree_go2.py @@ -0,0 +1,48 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Literal, TypeAlias + +import numpy as np + +from dimos.core import In, Module, Out, RemoteIn, RemoteOut, module, rpc +from dimos.robot.unitree_webrtc.connection import WebRTCRobot +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.vector import Vector + +VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] + + +class Robot(Module, WebRTCRobot): + mov: In[Vector] + lidar: Out[LidarMessage] + odometry: Out[Odometry] + video: Out[VideoMessage] + + def __init__(self, ip: str): + super().__init__(ip, mode="ai") + self.lidar = Out(LidarMessage, "lidar", self) + self.odometry = Out(Odometry, "odometry", self) + self.mov = In(Vector, "mov", self) + + def start(self): + self.connect() + + self.odom_stream().subscribe(self.odometry.publish) + self.lidar_stream().subscribe(self.lidar.publish) + self.video_stream().subscribe(self.video.publish) + self.mov.subscribe(self.move) diff --git a/dimos/robot/unitree_webrtc/connection.py b/dimos/robot/unitree_webrtc/connection.py index 16697c4378..d81bb2a6da 100644 --- a/dimos/robot/unitree_webrtc/connection.py +++ b/dimos/robot/unitree_webrtc/connection.py @@ -12,25 +12,29 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools import asyncio +import functools import threading -from typing import TypeAlias, Literal -from dimos.utils.reactive import backpressure, callback_to_observable -from dimos.types.vector import Vector -from dimos.types.position import Position -from dimos.robot.unitree_webrtc.type.lidar import LidarMessage -from dimos.robot.unitree_webrtc.type.odometry import Odometry -from go2_webrtc_driver.webrtc_driver import Go2WebRTCConnection, WebRTCConnectionMethod # type: ignore[import-not-found] -from go2_webrtc_driver.constants import RTC_TOPIC, VUI_COLOR, SPORT_CMD -from reactivex.subject import Subject -from reactivex.observable import Observable +from typing import Literal, TypeAlias + import numpy as np -from reactivex import operators as ops from aiortc import MediaStreamTrack -from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg -from dimos.robot.abstract_robot import AbstractRobot +from go2_webrtc_driver.constants import RTC_TOPIC, SPORT_CMD, VUI_COLOR +from go2_webrtc_driver.webrtc_driver import ( # type: ignore[import-not-found] + Go2WebRTCConnection, + WebRTCConnectionMethod, +) +from reactivex import operators as ops +from reactivex.observable import Observable +from reactivex.subject import Subject +from dimos.robot.abstract_robot import AbstractRobot +from dimos.robot.unitree_webrtc.type.lidar import LidarMessage +from dimos.robot.unitree_webrtc.type.lowstate import LowStateMsg +from dimos.robot.unitree_webrtc.type.odometry import Odometry +from dimos.types.position import Position +from dimos.types.vector import Vector +from dimos.utils.reactive import backpressure, callback_to_observable VideoMessage: TypeAlias = np.ndarray[tuple[int, int, Literal[3]], np.uint8] @@ -40,7 +44,6 @@ def __init__(self, ip: str, mode: str = "ai"): self.ip = ip self.mode = mode self.conn = Go2WebRTCConnection(WebRTCConnectionMethod.LocalSTA, ip=self.ip) - self.connect() def connect(self): self.loop = asyncio.new_event_loop() diff --git a/dimos/utils/threadpool.py b/dimos/utils/threadpool.py index cd2e7b16e5..52d1f97785 100644 --- a/dimos/utils/threadpool.py +++ b/dimos/utils/threadpool.py @@ -18,9 +18,11 @@ ReactiveX scheduler, ensuring consistent thread management across the application. """ -import os import multiprocessing +import os + from reactivex.scheduler import ThreadPoolScheduler + from .logging_config import logger @@ -37,7 +39,7 @@ def get_max_workers() -> int: # Create a ThreadPoolScheduler with a configurable number of workers. try: - max_workers = get_max_workers() + max_workers = 6 scheduler = ThreadPoolScheduler(max_workers=max_workers) logger.info(f"Using {max_workers} workers") except Exception as e: diff --git a/tests/run_unitree_mp.py b/tests/run_unitree_mp.py new file mode 100644 index 0000000000..b28b8c0e81 --- /dev/null +++ b/tests/run_unitree_mp.py @@ -0,0 +1,63 @@ +# Copyright 2025 Dimensional Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from threading import Event, Thread + +from dimos.core import Module, Out, initialize, module, rpc +from dimos.robot.unitree_multiprocess.unitree_go2 import Robot +from dimos.types.vector import Vector + + +class Mover(Module): + mov: Out[Vector] + _stop_event: Event + + def __init__(self): + self.mov = Out(Vector, "mov", self) + self._stop_event = Event() + + @rpc + def start(self): + self._thread = Thread(target=self.movloop) + self._thread.start() + + def movloop(self): + self._stop_event.clear() + while not self._stop_event.is_set(): + self.mov.publish(Vector(0.0, 0.0, 0.2)) + time.sleep(0.1) # Add a small delay to prevent excessive publishing + + @rpc + def stop(self): + self._stop_event.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=1.0) + + +def test_mover(): + dimos = initialize() + mover = dimos.deploy(Mover) + + robot = dimos.deploy(Robot, "192.168.1.1") + + robot.mov.connect(mover.mov) + + robot.start().result() + mover.start().result() + time.sleep(3) + + +if __name__ == "__main__": + test_mover()