Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions rock/actions/sandbox/sandbox_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@ class SandboxInfo(TypedDict, total=False):
create_time: str
start_time: str
stop_time: str
operator_type: str
2 changes: 1 addition & 1 deletion rock/admin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ async def lifespan(app: FastAPI):
nacos_provider=rock_config.nacos_provider,
k8s_config=rock_config.k8s,
)
operator = OperatorFactory.create_operator(operator_context)
operator = OperatorFactory.create_composite_operator(operator_context)

# init service
if rock_config.runtime.enable_auto_clear:
Expand Down
2 changes: 2 additions & 0 deletions rock/admin/proto/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ class SandboxStartRequest(BaseModel):
"""Password for Docker registry authentication. When both username and password are provided, docker login will be performed before pulling the image."""
use_kata_runtime: bool = False
"""Whether to use kata container runtime (io.containerd.kata.v2) instead of --privileged mode."""
operator_type: str | None = None
"""The operator type to use for this sandbox (e.g., 'ray', 'k8s'). If None, uses the default operator."""


class SandboxCommand(Command):
Expand Down
8 changes: 8 additions & 0 deletions rock/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class RuntimeConfig:
python_env_path: str = field(default_factory=lambda: env_vars.ROCK_PYTHON_ENV_PATH)
envhub_db_url: str = field(default_factory=lambda: env_vars.ROCK_ENVHUB_DB_URL)
operator_type: str = "ray"
operator_types: list[str] = field(default_factory=list)
standard_spec: StandardSpec = field(default_factory=StandardSpec)
max_allowed_spec: StandardSpec = field(default_factory=lambda: StandardSpec(cpus=16, memory="64g"))
use_standard_spec_only: bool = False
Expand All @@ -142,6 +143,13 @@ def __post_init__(self) -> None:
if isinstance(self.max_allowed_spec, dict):
self.max_allowed_spec = StandardSpec(**self.max_allowed_spec)

# Backward compatibility: if operator_types is empty, populate from operator_type
if not self.operator_types:
self.operator_types = [self.operator_type]
# Keep operator_type as the default (first in the list)
if self.operator_types:
self.operator_type = self.operator_types[0]

if not self.python_env_path:
raise Exception(
"ROCK_PYTHON_ENV_PATH is not set, please specify the actual Python environment path "
Expand Down
3 changes: 3 additions & 0 deletions rock/deployments/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,9 @@ class DockerDeploymentConfig(DeploymentConfig):
extended_params: dict[str, str] = Field(default_factory=dict)
"""Generic extension field for storing custom string key-value pairs."""

operator_type: str | None = None
"""The operator type to use for this sandbox (e.g., 'ray', 'k8s'). If None, uses the default operator."""

@model_validator(mode="before")
def validate_platform_args(cls, data: dict) -> dict:
"""Validate and extract platform arguments from docker_args.
Expand Down
149 changes: 149 additions & 0 deletions rock/sandbox/operator/composite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""Composite operator that delegates to multiple sub-operators based on operator type.

This module implements the Composite pattern: CompositeOperator inherits from
AbstractOperator and holds multiple concrete operators internally. It routes
each request to the appropriate sub-operator based on the operator_type field
in the deployment config (for submit) or in the Redis sandbox info (for
get_status/stop).

SandboxManager sees only a single AbstractOperator and requires no changes.
"""

from rock.actions.sandbox.sandbox_info import SandboxInfo
from rock.admin.core.redis_key import alive_sandbox_key
from rock.deployments.config import DeploymentConfig, DockerDeploymentConfig
from rock.logger import init_logger
from rock.sandbox.operator.abstract import AbstractOperator
from rock.utils.providers.redis_provider import RedisProvider

logger = init_logger(__name__)


class CompositeOperator(AbstractOperator):
    """Operator that holds multiple sub-operators and routes by operator_type.

    submit() picks the sub-operator from the operator_type on the deployment
    config (falling back to the configured default) and records the resolved
    type on the returned SandboxInfo so it can be persisted with the sandbox
    record.

    get_status() and stop() read the operator_type back from the sandbox
    record in Redis to find the sub-operator that owns the sandbox, falling
    back to the default when nothing usable is recorded.
    """

    def __init__(
        self,
        operators: dict[str, AbstractOperator],
        default_operator_type: str,
    ):
        """Initialize CompositeOperator.

        Args:
            operators: Mapping from operator type name (e.g., "ray", "k8s")
                to the corresponding AbstractOperator instance. Names are
                matched case-insensitively.
            default_operator_type: The operator type to use when the request
                does not specify one explicitly.

        Raises:
            ValueError: If ``operators`` is empty or the default type is not
                one of its keys.
        """
        if not operators:
            raise ValueError("At least one operator must be provided")

        # Every lookup in this class lower-cases the requested type, so the
        # registry keys have to be lower-cased too; otherwise an operator
        # registered as e.g. "K8s" could never be selected. If two keys differ
        # only by case, the later one wins.
        normalized_operators = {name.lower(): operator for name, operator in operators.items()}

        normalized_default = default_operator_type.lower()
        if normalized_default not in normalized_operators:
            raise ValueError(
                f"Default operator type '{default_operator_type}' not found "
                f"in provided operators: {list(operators.keys())}"
            )

        self._operators = normalized_operators
        self._default_operator_type = normalized_default
        logger.info(
            f"CompositeOperator initialized with operators: {list(normalized_operators.keys())}, "
            f"default: '{normalized_default}'"
        )

    def set_redis_provider(self, redis_provider: RedisProvider):
        """Store the redis provider and propagate it to all sub-operators."""
        self._redis_provider = redis_provider
        for operator in self._operators.values():
            operator.set_redis_provider(redis_provider)

    def _resolve_operator_for_config(self, config: DeploymentConfig) -> tuple[str, AbstractOperator]:
        """Resolve the sub-operator for a submit request based on config.

        Only DockerDeploymentConfig carries an operator_type; any other config
        class, or an unset/empty operator_type, resolves to the default.

        Returns:
            A tuple of (resolved_operator_type, operator_instance).

        Raises:
            ValueError: If the requested operator type is not registered.
        """
        requested_type = None
        if isinstance(config, DockerDeploymentConfig) and config.operator_type:
            requested_type = config.operator_type.lower()

        resolved_type = requested_type or self._default_operator_type
        operator = self._operators.get(resolved_type)
        if operator is None:
            # An explicit request for an unknown type is a caller error, so
            # fail loudly instead of silently using the default.
            available = list(self._operators.keys())
            raise ValueError(f"Unsupported operator type: '{resolved_type}'. Available types: {available}")
        return resolved_type, operator

    async def _resolve_operator_for_sandbox(self, sandbox_id: str) -> tuple[str, AbstractOperator]:
        """Resolve the sub-operator for an existing sandbox by looking up Redis.

        Falls back to the default operator if no redis provider is set, the
        sandbox has no operator_type recorded, or the recorded type is not
        registered. Errors raised by the Redis call itself are not caught.

        Returns:
            A tuple of (resolved_operator_type, operator_instance).
        """
        resolved_type = self._default_operator_type

        # This class only assigns _redis_provider in set_redis_provider();
        # getattr keeps the lookup safe if that has not been called yet.
        redis_provider = getattr(self, "_redis_provider", None)
        if redis_provider:
            sandbox_status = await redis_provider.json_get(alive_sandbox_key(sandbox_id), "$")
            if sandbox_status:
                stored_type = sandbox_status[0].get("operator_type")
                if stored_type:
                    resolved_type = stored_type.lower()

        operator = self._operators.get(resolved_type)
        if operator is None:
            logger.warning(
                f"Operator type '{resolved_type}' for sandbox '{sandbox_id}' not found, "
                f"falling back to default '{self._default_operator_type}'"
            )
            resolved_type = self._default_operator_type
            # __init__ guarantees the default type is a registered key.
            operator = self._operators[resolved_type]

        return resolved_type, operator

    async def submit(self, config: DeploymentConfig, user_info: dict | None = None) -> SandboxInfo:
        """Submit a sandbox creation request to the appropriate sub-operator.

        The operator_type is determined from config.operator_type (if set) or
        falls back to the default. The resolved operator_type is written into
        the returned SandboxInfo.

        Args:
            config: Deployment config describing the sandbox to create.
            user_info: Optional user metadata forwarded to the sub-operator;
                an empty dict is forwarded when omitted.

        Raises:
            ValueError: If the requested operator type is not registered.
        """
        # Use a fresh dict per call rather than a shared mutable default.
        if user_info is None:
            user_info = {}

        resolved_type, operator = self._resolve_operator_for_config(config)
        logger.info(
            f"Routing submit for sandbox '{getattr(config, 'container_name', 'unknown')}' "
            f"to operator '{resolved_type}'"
        )

        sandbox_info = await operator.submit(config, user_info)
        sandbox_info["operator_type"] = resolved_type
        return sandbox_info

    async def get_status(self, sandbox_id: str) -> SandboxInfo:
        """Get sandbox status by routing to the operator resolved for it.

        The resolved operator_type is always written into the returned
        SandboxInfo, so the field is present even if the sub-operator does
        not set it.
        """
        resolved_type, operator = await self._resolve_operator_for_sandbox(sandbox_id)
        logger.debug(f"Routing get_status for sandbox '{sandbox_id}' to operator '{resolved_type}'")
        sandbox_info = await operator.get_status(sandbox_id)
        sandbox_info["operator_type"] = resolved_type
        return sandbox_info

    async def stop(self, sandbox_id: str) -> bool:
        """Stop a sandbox by routing to the operator resolved for it.

        Returns:
            Whatever the sub-operator's stop() returns.
        """
        resolved_type, operator = await self._resolve_operator_for_sandbox(sandbox_id)
        logger.info(f"Routing stop for sandbox '{sandbox_id}' to operator '{resolved_type}'")
        return await operator.stop(sandbox_id)
80 changes: 74 additions & 6 deletions rock/sandbox/operator/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from rock.config import K8sConfig, RuntimeConfig
from rock.logger import init_logger
from rock.sandbox.operator.abstract import AbstractOperator
from rock.sandbox.operator.composite import CompositeOperator
from rock.sandbox.operator.k8s.operator import K8sOperator
from rock.sandbox.operator.ray import RayOperator
from rock.utils.providers.nacos_provider import NacosConfigProvider
Expand Down Expand Up @@ -40,10 +41,11 @@ class OperatorFactory:
"""

@staticmethod
def create_operator(context: OperatorContext) -> AbstractOperator:
"""Create an operator instance based on the runtime configuration.
def _create_single_operator(operator_type: str, context: OperatorContext) -> AbstractOperator:
"""Create a single operator instance by type.

Args:
operator_type: The operator type string (e.g., "ray", "k8s")
context: OperatorContext containing all necessary dependencies

Returns:
Expand All @@ -52,20 +54,86 @@ def create_operator(context: OperatorContext) -> AbstractOperator:
Raises:
ValueError: If operator_type is not supported or required dependencies are missing
"""
operator_type = context.runtime_config.operator_type.lower()
normalized_type = operator_type.lower()

if operator_type == "ray":
if normalized_type == "ray":
if context.ray_service is None:
raise ValueError("RayService is required for RayOperator")
logger.info("Creating RayOperator")
ray_operator = RayOperator(ray_service=context.ray_service, runtime_config=context.runtime_config)
if context.nacos_provider is not None:
ray_operator.set_nacos_provider(context.nacos_provider)
return ray_operator
elif operator_type == "k8s":
elif normalized_type == "k8s":
if context.k8s_config is None:
raise ValueError("K8sConfig is required for K8sOperator")
logger.info("Creating K8sOperator")
return K8sOperator(k8s_config=context.k8s_config)
else:
raise ValueError(f"Unsupported operator type: {operator_type}. " f"Supported types: ray, kubernetes")
raise ValueError(f"Unsupported operator type: {operator_type}. Supported types: ray, k8s")

@staticmethod
def create_operator(context: OperatorContext) -> AbstractOperator:
    """Build the single operator named by ``runtime_config.operator_type``.

    Args:
        context: OperatorContext carrying the runtime config and the
            dependencies the operator needs.

    Returns:
        AbstractOperator: The operator built by ``_create_single_operator``
        for the configured default type.
    """
    default_type = context.runtime_config.operator_type
    return OperatorFactory._create_single_operator(default_type, context)

@staticmethod
def create_operators(context: OperatorContext) -> dict[str, AbstractOperator]:
    """Build one operator per entry in ``runtime_config.operator_types``.

    Type names are lower-cased before use and become the keys of the
    returned mapping. A type that appears more than once is built only the
    first time; later occurrences are logged and skipped.

    Args:
        context: OperatorContext carrying the runtime config and the
            dependencies the operators need.

    Returns:
        dict[str, AbstractOperator]: Lower-cased operator type -> operator instance.

    Raises:
        ValueError: If operator_types is empty or any type is unsupported.
    """
    configured_types = context.runtime_config.operator_types
    if not configured_types:
        raise ValueError("operator_types list is empty, at least one operator type must be configured")

    created: dict[str, AbstractOperator] = {}
    for raw_type in configured_types:
        type_key = raw_type.lower()
        if type_key not in created:
            created[type_key] = OperatorFactory._create_single_operator(type_key, context)
            logger.info(f"Created operator for type '{type_key}'")
        else:
            logger.warning(f"Duplicate operator type '{type_key}' in config, skipping")

    logger.info(f"Initialized {len(created)} operator(s): {list(created)}")
    return created

@staticmethod
def create_composite_operator(context: OperatorContext) -> CompositeOperator:
    """Build a CompositeOperator over every configured sub-operator.

    Sub-operators come from ``create_operators`` (one per entry in
    ``runtime_config.operator_types``); the lower-cased
    ``runtime_config.operator_type`` is used as the composite's default.
    The result is itself an AbstractOperator, so it can be used wherever a
    single operator is expected.

    Args:
        context: OperatorContext carrying the runtime config and the
            dependencies the operators need.

    Returns:
        CompositeOperator: A composite operator wrapping all configured sub-operators.
    """
    sub_operators = OperatorFactory.create_operators(context)
    return CompositeOperator(
        operators=sub_operators,
        default_operator_type=context.runtime_config.operator_type.lower(),
    )
1 change: 1 addition & 0 deletions rock/sdk/sandbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ async def start(self):
"registry_username": self.config.registry_username,
"registry_password": self.config.registry_password,
"use_kata_runtime": self.config.use_kata_runtime,
"operator_type": self.config.operator_type,
}
try:
response = await HttpUtils.post(url, headers, data)
Expand Down
2 changes: 2 additions & 0 deletions rock/sdk/sandbox/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ class SandboxConfig(BaseConfig):
registry_username: str | None = None
registry_password: str | None = None
use_kata_runtime: bool = False
operator_type: str = "ray"
"""The operator type to use for this sandbox (e.g., 'ray', 'k8s'). If None, uses the server default."""


class SandboxGroupConfig(SandboxConfig):
Expand Down
63 changes: 63 additions & 0 deletions tests/unit/sdk/sandbox/test_sandbox_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from rock.sdk.sandbox.config import SandboxConfig, SandboxGroupConfig


def test_sandbox_config_operator_type_default_is_ray():
    """A SandboxConfig built with no arguments reports 'ray' as operator_type."""
    default_config = SandboxConfig()
    observed = default_config.operator_type
    assert observed == "ray"


def test_sandbox_config_operator_type_set_explicitly():
    """An explicitly supplied operator_type is kept on the config."""
    requested = "k8s"
    k8s_config = SandboxConfig(operator_type=requested)
    assert k8s_config.operator_type == "k8s"


def test_sandbox_config_operator_type_ray():
    """Passing 'ray' explicitly is accepted and stored unchanged."""
    ray_config = SandboxConfig(operator_type="ray")
    observed = ray_config.operator_type
    assert observed == "ray"


def test_sandbox_config_operator_type_with_other_fields():
    """operator_type can be set together with other fields without disturbing them."""
    config = SandboxConfig(
        image="ubuntu:22.04",
        memory="16g",
        cpus=4,
        cluster="us-east",
        operator_type="k8s",
    )
    # Checked in the same order as the original assertions.
    expected_fields = (
        ("operator_type", "k8s"),
        ("image", "ubuntu:22.04"),
        ("memory", "16g"),
        ("cpus", 4),
        ("cluster", "us-east"),
    )
    for field_name, expected_value in expected_fields:
        assert getattr(config, field_name) == expected_value


def test_sandbox_config_operator_type_serialization():
    """model_dump() output includes the explicitly set operator_type."""
    serialized = SandboxConfig(operator_type="ray").model_dump()
    assert "operator_type" in serialized
    assert serialized["operator_type"] == "ray"


def test_sandbox_config_operator_type_default_serialization():
    """model_dump() output includes operator_type='ray' when it was left at its default."""
    serialized = SandboxConfig().model_dump()
    assert "operator_type" in serialized
    assert serialized["operator_type"] == "ray"


def test_sandbox_group_config_has_operator_type():
    """SandboxGroupConfig accepts operator_type and exposes it as an attribute."""
    group_config = SandboxGroupConfig(operator_type="ray", size=2)
    observed = group_config.operator_type
    assert observed == "ray"


def test_sandbox_group_config_operator_type_default_ray():
    """SandboxGroupConfig built without operator_type reports 'ray'."""
    group_config = SandboxGroupConfig(size=2)
    observed = group_config.operator_type
    assert observed == "ray"
Loading
Loading