Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion dimos/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path

try:
# Not a dependency, just the best way to get config path if available.
from gi.repository import GLib # type: ignore[import-untyped,import-not-found]
except ImportError:
CONFIG_DIR = Path(os.environ.get("XDG_CONFIG_HOME", Path.home() / ".config"))
STATE_DIR = Path(os.environ.get("XDG_STATE_HOME", Path.home() / ".local" / "state")) / "dimos"
else:
CONFIG_DIR = Path(GLib.get_user_config_dir())
STATE_DIR = Path(GLib.get_user_state_dir()) / "dimos"

DIMOS_PROJECT_ROOT = Path(__file__).parent.parent

DIMOS_LOG_DIR = DIMOS_PROJECT_ROOT / "logs"
if (DIMOS_PROJECT_ROOT / ".git").exists():
# Running from Git repository
LOG_DIR = DIMOS_PROJECT_ROOT / "logs"
else:
# Running from an installed package - use XDG_STATE_HOME
LOG_DIR = STATE_DIR / "logs"

"""
Constants for shared memory
Expand Down
27 changes: 19 additions & 8 deletions dimos/core/blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from collections import defaultdict
from collections.abc import Callable, Mapping
from collections.abc import Callable, Mapping, MutableMapping
from dataclasses import dataclass, field, replace
from functools import cached_property, reduce
import operator
Expand All @@ -22,6 +22,8 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, Union, cast, get_args, get_origin, get_type_hints

from pydantic import BaseModel, create_model

if TYPE_CHECKING:
from dimos.protocol.service.system_configurator.base import SystemConfigurator

Expand Down Expand Up @@ -164,6 +166,11 @@ def create(cls, module: type[ModuleBase], **kwargs: Any) -> "Blueprint":
def disabled_modules(self, *modules: type[ModuleBase]) -> "Blueprint":
return replace(self, disabled_modules_tuple=self.disabled_modules_tuple + modules)

def config(self) -> type[BaseModel]:
configs = {b.module.name: (b.module.default_config | None, None) for b in self.blueprints}
configs["g"] = (GlobalConfig | None, None)
return create_model("BlueprintConfig", __config__={"extra": "forbid"}, **configs) # type: ignore[call-overload,no-any-return]

def transports(self, transports: dict[tuple[str, type], Any]) -> "Blueprint":
return replace(self, transport_map=MappingProxyType({**self.transport_map, **transports}))

Expand Down Expand Up @@ -290,13 +297,16 @@ def _verify_no_name_conflicts(self) -> None:
raise ValueError("\n".join(error_lines))

def _deploy_all_modules(
self, module_coordinator: ModuleCoordinator, global_config: GlobalConfig
self,
module_coordinator: ModuleCoordinator,
global_config: GlobalConfig,
blueprint_args: Mapping[str, Mapping[str, Any]],
) -> None:
module_specs: list[ModuleSpec] = []
for blueprint in self._active_blueprints:
module_specs.append((blueprint.module, global_config, blueprint.kwargs))
module_specs.append((blueprint.module, global_config, blueprint.kwargs.copy()))

module_coordinator.deploy_parallel(module_specs)
module_coordinator.deploy_parallel(module_specs, blueprint_args)

def _connect_streams(self, module_coordinator: ModuleCoordinator) -> None:
# dict when given (final/remapped) stream name+type, provides a list of modules + original (non-remapped) stream names
Expand Down Expand Up @@ -444,12 +454,13 @@ def _connect_module_refs(self, module_coordinator: ModuleCoordinator) -> None:

def build(
self,
cli_config_overrides: Mapping[str, Any] | None = None,
blueprint_args: MutableMapping[str, Any] | None = None,
) -> ModuleCoordinator:
logger.info("Building the blueprint")
global_config.update(**dict(self.global_config_overrides))
if cli_config_overrides:
global_config.update(**dict(cli_config_overrides))
blueprint_args = blueprint_args or {}
if "g" in blueprint_args:
global_config.update(**blueprint_args.pop("g"))

self._run_configurators()
self._check_requirements()
Expand All @@ -460,7 +471,7 @@ def build(
module_coordinator.start()

# all module constructors are called here (each of them setup their own)
self._deploy_all_modules(module_coordinator, global_config)
self._deploy_all_modules(module_coordinator, global_config, blueprint_args)
self._connect_streams(module_coordinator)
self._connect_module_refs(module_coordinator)

Expand Down
5 changes: 5 additions & 0 deletions dimos/core/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ def __init__(self, config_args: dict[str, Any]):
except ValueError:
...

@classproperty
def name(self) -> str:
"""Name for this module to be used for blueprint configs."""
return self.__name__.lower() # type: ignore[attr-defined,no-any-return]

@property
def frame_id(self) -> str:
base = self.config.frame_id or self.__class__.__name__
Expand Down
7 changes: 5 additions & 2 deletions dimos/core/module_coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from __future__ import annotations

from collections.abc import Mapping
import threading
from typing import TYPE_CHECKING, Any, TypeAlias

Expand Down Expand Up @@ -99,7 +100,9 @@ def deploy(
self._deployed_modules[module_class] = deployed_module # type: ignore[assignment]
return deployed_module # type: ignore[return-value]

def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]:
def deploy_parallel(
self, module_specs: list[ModuleSpec], blueprint_args: Mapping[str, Mapping[str, Any]]
) -> list[ModuleProxy]:
if not self._managers:
raise ValueError("Not started")

Expand All @@ -115,7 +118,7 @@ def deploy_parallel(self, module_specs: list[ModuleSpec]) -> list[ModuleProxy]:
results: list[Any] = [None] * len(module_specs)

def _deploy_group(dep: str) -> None:
deployed = self._managers[dep].deploy_parallel(specs_by_deployment[dep])
deployed = self._managers[dep].deploy_parallel(specs_by_deployment[dep], blueprint_args)
for index, module in zip(indices_by_deployment[dep], deployed, strict=True):
results[index] = module

Expand Down
13 changes: 2 additions & 11 deletions dimos/core/run_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,12 @@
import signal
import time

from dimos.constants import STATE_DIR
from dimos.utils.logging_config import setup_logger

logger = setup_logger()


def _get_state_dir() -> Path:
"""XDG_STATE_HOME compliant state directory for dimos."""
xdg = os.environ.get("XDG_STATE_HOME")
if xdg:
return Path(xdg) / "dimos"
return Path.home() / ".local" / "state" / "dimos"


REGISTRY_DIR = _get_state_dir() / "runs"
LOG_BASE_DIR = _get_state_dir() / "logs"
REGISTRY_DIR = STATE_DIR / "runs"


@dataclass
Expand Down
37 changes: 26 additions & 11 deletions dimos/core/test_blueprints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from types import MappingProxyType
from typing import Protocol

from pydantic import ValidationError
import pytest

from dimos.core._test_future_annotations_helper import (
Expand All @@ -38,9 +40,11 @@
from dimos.spec.utils import Spec

# Disable Rerun for tests (prevents viewer spawn and gRPC flush errors)
_BUILD_WITHOUT_RERUN = {
"cli_config_overrides": {"viewer": "none"},
}
_BUILD_WITHOUT_RERUN = MappingProxyType(
{
"g": {"viewer": "none"},
}
)


class Scratch:
Expand Down Expand Up @@ -141,6 +145,17 @@ def test_autoconnect() -> None:
)


def test_config() -> None:
blueprint = autoconnect(ModuleA.blueprint(), ModuleB.blueprint())
config = blueprint.config()
assert config.model_fields.keys() == {"modulea", "moduleb", "g"}
assert config.model_fields["modulea"].annotation == ModuleA.default_config | None
assert config.model_fields["moduleb"].annotation == ModuleB.default_config | None

with pytest.raises(ValidationError, match="invalid_key"):
config(module_a={"invalid_key": 5})


def test_transports() -> None:
custom_transport = LCMTransport("/custom_topic", Data1)
blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint()).transports(
Expand All @@ -166,7 +181,7 @@ def test_global_config() -> None:
def test_build_happy_path() -> None:
blueprint_set = autoconnect(ModuleA.blueprint(), ModuleB.blueprint(), ModuleC.blueprint())

coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
assert isinstance(coordinator, ModuleCoordinator)
Expand Down Expand Up @@ -295,7 +310,7 @@ def test_remapping() -> None:
assert ("color_image", Data1) not in blueprint_set._all_name_types

# Build and verify streams work
coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
source_instance = coordinator.get_instance(SourceModule)
Expand Down Expand Up @@ -345,7 +360,7 @@ def test_future_annotations_autoconnect() -> None:

blueprint_set = autoconnect(FutureModuleOut.blueprint(), FutureModuleIn.blueprint())

coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
out_instance = coordinator.get_instance(FutureModuleOut)
Expand Down Expand Up @@ -437,7 +452,7 @@ def test_module_ref_direct() -> None:
coordinator = autoconnect(
Calculator1.blueprint(),
Mod1.blueprint(),
).build(**_BUILD_WITHOUT_RERUN)
).build(_BUILD_WITHOUT_RERUN.copy())

try:
mod1 = coordinator.get_instance(Mod1)
Expand All @@ -453,7 +468,7 @@ def test_module_ref_spec() -> None:
coordinator = autoconnect(
Calculator1.blueprint(),
Mod2.blueprint(),
).build(**_BUILD_WITHOUT_RERUN)
).build(_BUILD_WITHOUT_RERUN.copy())

try:
mod2 = coordinator.get_instance(Mod2)
Expand All @@ -470,7 +485,7 @@ def test_disabled_modules_are_skipped_during_build() -> None:
ModuleA.blueprint(), ModuleB.blueprint(), ModuleC.blueprint()
).disabled_modules(ModuleC)

coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
assert coordinator.get_instance(ModuleA) is not None
Expand All @@ -488,7 +503,7 @@ def test_disabled_module_ref_gets_noop_proxy() -> None:
Mod2.blueprint(),
).disabled_modules(Calculator1)

coordinator = blueprint_set.build(**_BUILD_WITHOUT_RERUN)
coordinator = blueprint_set.build(_BUILD_WITHOUT_RERUN.copy())

try:
mod2 = coordinator.get_instance(Mod2)
Expand Down Expand Up @@ -528,7 +543,7 @@ def test_module_ref_remap_ambiguous() -> None:
(Mod2, "calc", Calculator1),
]
)
.build(**_BUILD_WITHOUT_RERUN)
.build(_BUILD_WITHOUT_RERUN.copy())
)

try:
Expand Down
3 changes: 2 additions & 1 deletion dimos/core/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,8 @@ def test_worker_manager_parallel_deployment(create_worker_manager):
(SimpleModule, global_config, {}),
(AnotherModule, global_config, {}),
(ThirdModule, global_config, {}),
]
],
{},
)

assert len(modules) == 3
Expand Down
2 changes: 1 addition & 1 deletion dimos/core/tests/test_docker_deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def test_deploy_parallel_deploys_docker_module(self, mock_proxy_cls, dimos_clust
specs = [
(FakeDockerModule, (), {}),
]
results = dimos_cluster.deploy_parallel(specs)
results = dimos_cluster.deploy_parallel(specs, {})

mock_proxy_cls.assert_called_once()
assert results[0] is mock_dm
Expand Down
Loading
Loading