From 06b4b56f77c8ac36682d56f1e6472cee2ffc6bda Mon Sep 17 00:00:00 2001 From: daifangwen Date: Wed, 4 Mar 2026 09:43:27 +0000 Subject: [PATCH 1/3] support multi operator --- rock/actions/sandbox/sandbox_info.py | 1 + rock/admin/main.py | 2 +- rock/admin/proto/request.py | 2 + rock/config.py | 8 + rock/deployments/config.py | 3 + rock/sandbox/operator/composite.py | 149 ++++++ rock/sandbox/operator/factory.py | 80 ++- rock/sdk/sandbox/client.py | 1 + rock/sdk/sandbox/config.py | 2 + tests/unit/test_composite_operator.py | 682 ++++++++++++++++++++++++++ tests/unit/test_sdk_operator_type.py | 276 +++++++++++ 11 files changed, 1199 insertions(+), 7 deletions(-) create mode 100644 rock/sandbox/operator/composite.py create mode 100644 tests/unit/test_composite_operator.py create mode 100644 tests/unit/test_sdk_operator_type.py diff --git a/rock/actions/sandbox/sandbox_info.py b/rock/actions/sandbox/sandbox_info.py index a6a28816a..b54716575 100644 --- a/rock/actions/sandbox/sandbox_info.py +++ b/rock/actions/sandbox/sandbox_info.py @@ -24,3 +24,4 @@ class SandboxInfo(TypedDict, total=False): create_time: str start_time: str stop_time: str + operator_type: str diff --git a/rock/admin/main.py b/rock/admin/main.py index 2cb242833..b5eed2485 100644 --- a/rock/admin/main.py +++ b/rock/admin/main.py @@ -81,7 +81,7 @@ async def lifespan(app: FastAPI): nacos_provider=rock_config.nacos_provider, k8s_config=rock_config.k8s, ) - operator = OperatorFactory.create_operator(operator_context) + operator = OperatorFactory.create_composite_operator(operator_context) # init service if rock_config.runtime.enable_auto_clear: diff --git a/rock/admin/proto/request.py b/rock/admin/proto/request.py index b7b45ce95..a3520b918 100644 --- a/rock/admin/proto/request.py +++ b/rock/admin/proto/request.py @@ -33,6 +33,8 @@ class SandboxStartRequest(BaseModel): """Password for Docker registry authentication. When both username and password are provided, docker login will be performed before pulling the image.""" use_kata_runtime: bool = False """Whether to use kata container runtime (io.containerd.kata.v2) instead of --privileged mode.""" + operator_type: str | None = None + """The operator type to use for this sandbox (e.g., 'ray', 'k8s'). If None, uses the default operator.""" class SandboxCommand(Command): diff --git a/rock/config.py b/rock/config.py index b83154b32..5988d5901 100644 --- a/rock/config.py +++ b/rock/config.py @@ -129,6 +129,7 @@ class RuntimeConfig: python_env_path: str = field(default_factory=lambda: env_vars.ROCK_PYTHON_ENV_PATH) envhub_db_url: str = field(default_factory=lambda: env_vars.ROCK_ENVHUB_DB_URL) operator_type: str = "ray" + operator_types: list[str] = field(default_factory=list) standard_spec: StandardSpec = field(default_factory=StandardSpec) max_allowed_spec: StandardSpec = field(default_factory=lambda: StandardSpec(cpus=16, memory="64g")) use_standard_spec_only: bool = False @@ -142,6 +143,13 @@ def __post_init__(self) -> None: if isinstance(self.max_allowed_spec, dict): self.max_allowed_spec = StandardSpec(**self.max_allowed_spec) + # Backward compatibility: if operator_types is empty, populate from operator_type + if not self.operator_types: + self.operator_types = [self.operator_type] + # Keep operator_type as the default (first in the list) + if self.operator_types: + self.operator_type = self.operator_types[0] + if not self.python_env_path: raise Exception( "ROCK_PYTHON_ENV_PATH is not set, please specify the actual Python environment path " diff --git a/rock/deployments/config.py b/rock/deployments/config.py index 7b8b66f9a..94a3be718 100644 --- a/rock/deployments/config.py +++ b/rock/deployments/config.py @@ -121,6 +121,9 @@ class DockerDeploymentConfig(DeploymentConfig): extended_params: dict[str, str] = Field(default_factory=dict) """Generic extension field for storing custom string key-value pairs.""" + operator_type: str | None = None + """The operator type to use for this sandbox (e.g., 'ray', 'k8s'). If None, uses the default operator.""" + @model_validator(mode="before") def validate_platform_args(cls, data: dict) -> dict: """Validate and extract platform arguments from docker_args. diff --git a/rock/sandbox/operator/composite.py b/rock/sandbox/operator/composite.py new file mode 100644 index 000000000..15bdc2254 --- /dev/null +++ b/rock/sandbox/operator/composite.py @@ -0,0 +1,149 @@ +"""Composite operator that delegates to multiple sub-operators based on operator type. + +This module implements the Composite pattern: CompositeOperator inherits from +AbstractOperator and holds multiple concrete operators internally. It routes +each request to the appropriate sub-operator based on the operator_type field +in the deployment config (for submit) or in the Redis sandbox info (for +get_status/stop). + +SandboxManager sees only a single AbstractOperator and requires no changes. +""" + +from rock.actions.sandbox.sandbox_info import SandboxInfo +from rock.admin.core.redis_key import alive_sandbox_key +from rock.deployments.config import DeploymentConfig, DockerDeploymentConfig +from rock.logger import init_logger +from rock.sandbox.operator.abstract import AbstractOperator +from rock.utils.providers.redis_provider import RedisProvider + +logger = init_logger(__name__) + + +class CompositeOperator(AbstractOperator): + """Operator that holds multiple sub-operators and routes by operator_type. + + When a sandbox is created via submit(), the operator_type from the + DeploymentConfig determines which sub-operator handles the request. + The chosen operator_type is recorded in the returned SandboxInfo so that + SandboxManager persists it to Redis. + + For get_status() and stop(), the operator_type is looked up from Redis + to route to the correct sub-operator. + """ + + def __init__( + self, + operators: dict[str, AbstractOperator], + default_operator_type: str, + ): + """Initialize CompositeOperator. + + Args: + operators: Mapping from operator type name (e.g., "ray", "k8s") + to the corresponding AbstractOperator instance. + default_operator_type: The operator type to use when the request + does not specify one explicitly. + """ + if not operators: + raise ValueError("At least one operator must be provided") + + normalized_default = default_operator_type.lower() + if normalized_default not in operators: + raise ValueError( + f"Default operator type '{default_operator_type}' not found " + f"in provided operators: {list(operators.keys())}" + ) + + self._operators = operators + self._default_operator_type = normalized_default + logger.info( + f"CompositeOperator initialized with operators: {list(operators.keys())}, " + f"default: '{normalized_default}'" + ) + + def set_redis_provider(self, redis_provider: RedisProvider): + """Propagate the redis provider to all sub-operators.""" + self._redis_provider = redis_provider + for operator in self._operators.values(): + operator.set_redis_provider(redis_provider) + + def _resolve_operator_for_config(self, config: DeploymentConfig) -> tuple[str, AbstractOperator]: + """Resolve the sub-operator for a submit request based on config. + + Returns: + A tuple of (resolved_operator_type, operator_instance). + """ + requested_type = None + if isinstance(config, DockerDeploymentConfig) and config.operator_type: + requested_type = config.operator_type.lower() + + resolved_type = requested_type or self._default_operator_type + operator = self._operators.get(resolved_type) + if operator is None: + available = list(self._operators.keys()) + raise ValueError(f"Unsupported operator type: '{resolved_type}'. Available types: {available}") + return resolved_type, operator + + async def _resolve_operator_for_sandbox(self, sandbox_id: str) -> tuple[str, AbstractOperator]: + """Resolve the sub-operator for an existing sandbox by looking up Redis. + + Falls back to the default operator if Redis is unavailable or the + sandbox has no operator_type recorded. + + Returns: + A tuple of (resolved_operator_type, operator_instance). + """ + resolved_type = self._default_operator_type + + if self._redis_provider: + sandbox_status = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$") + if sandbox_status and len(sandbox_status) > 0: + stored_type = sandbox_status[0].get("operator_type") + if stored_type: + resolved_type = stored_type.lower() + + operator = self._operators.get(resolved_type) + if operator is None: + logger.warning( + f"Operator type '{resolved_type}' for sandbox '{sandbox_id}' not found, " + f"falling back to default '{self._default_operator_type}'" + ) + resolved_type = self._default_operator_type + operator = self._operators[resolved_type] + + return resolved_type, operator + + async def submit(self, config: DeploymentConfig, user_info: dict = {}) -> SandboxInfo: + """Submit a sandbox creation request to the appropriate sub-operator. + + The operator_type is determined from config.operator_type (if set) or + falls back to the default. The resolved operator_type is written into + the returned SandboxInfo so that SandboxManager persists it to Redis. + """ + resolved_type, operator = self._resolve_operator_for_config(config) + logger.info( + f"Routing submit for sandbox '{getattr(config, 'container_name', 'unknown')}' " + f"to operator '{resolved_type}'" + ) + + sandbox_info = await operator.submit(config, user_info) + sandbox_info["operator_type"] = resolved_type + return sandbox_info + + async def get_status(self, sandbox_id: str) -> SandboxInfo: + """Get sandbox status by routing to the operator that created it. + + Ensures the returned SandboxInfo always contains operator_type so that + SandboxManager.get_status() does not lose it when writing back to Redis. + """ + resolved_type, operator = await self._resolve_operator_for_sandbox(sandbox_id) + logger.debug(f"Routing get_status for sandbox '{sandbox_id}' to operator '{resolved_type}'") + sandbox_info = await operator.get_status(sandbox_id) + sandbox_info["operator_type"] = resolved_type + return sandbox_info + + async def stop(self, sandbox_id: str) -> bool: + """Stop a sandbox by routing to the operator that created it.""" + resolved_type, operator = await self._resolve_operator_for_sandbox(sandbox_id) + logger.info(f"Routing stop for sandbox '{sandbox_id}' to operator '{resolved_type}'") + return await operator.stop(sandbox_id) diff --git a/rock/sandbox/operator/factory.py b/rock/sandbox/operator/factory.py index b70ae6555..6bfc96ff0 100644 --- a/rock/sandbox/operator/factory.py +++ b/rock/sandbox/operator/factory.py @@ -7,6 +7,7 @@ from rock.config import K8sConfig, RuntimeConfig from rock.logger import init_logger from rock.sandbox.operator.abstract import AbstractOperator +from rock.sandbox.operator.composite import CompositeOperator from rock.sandbox.operator.k8s.operator import K8sOperator from rock.sandbox.operator.ray import RayOperator from rock.utils.providers.nacos_provider import NacosConfigProvider @@ -40,10 +41,11 @@ class OperatorFactory: """ @staticmethod - def create_operator(context: OperatorContext) -> AbstractOperator: - """Create an operator instance based on the runtime configuration. + def _create_single_operator(operator_type: str, context: OperatorContext) -> AbstractOperator: + """Create a single operator instance by type. Args: + operator_type: The operator type string (e.g., "ray", "k8s") context: OperatorContext containing all necessary dependencies Returns: @@ -52,9 +54,9 @@ def create_operator(context: OperatorContext) -> AbstractOperator: Raises: ValueError: If operator_type is not supported or required dependencies are missing """ - operator_type = context.runtime_config.operator_type.lower() + normalized_type = operator_type.lower() - if operator_type == "ray": + if normalized_type == "ray": if context.ray_service is None: raise ValueError("RayService is required for RayOperator") logger.info("Creating RayOperator") @@ -62,10 +64,76 @@ def create_operator(context: OperatorContext) -> AbstractOperator: if context.nacos_provider is not None: ray_operator.set_nacos_provider(context.nacos_provider) return ray_operator - elif operator_type == "k8s": + elif normalized_type == "k8s": if context.k8s_config is None: raise ValueError("K8sConfig is required for K8sOperator") logger.info("Creating K8sOperator") return K8sOperator(k8s_config=context.k8s_config) else: - raise ValueError(f"Unsupported operator type: {operator_type}. " f"Supported types: ray, kubernetes") + raise ValueError(f"Unsupported operator type: {operator_type}. Supported types: ray, k8s") + + @staticmethod + def create_operator(context: OperatorContext) -> AbstractOperator: + """Create a single operator instance based on the default operator_type in runtime config. + + Args: + context: OperatorContext containing all necessary dependencies + + Returns: + AbstractOperator: The created operator instance + """ + return OperatorFactory._create_single_operator(context.runtime_config.operator_type, context) + + @staticmethod + def create_operators(context: OperatorContext) -> dict[str, AbstractOperator]: + """Create multiple operator instances based on operator_types list in runtime config. + + Iterates over runtime_config.operator_types and creates an operator for each type. + The returned dict is keyed by the normalized operator type string. + + Args: + context: OperatorContext containing all necessary dependencies + + Returns: + dict[str, AbstractOperator]: Mapping from operator type to operator instance + + Raises: + ValueError: If operator_types is empty or any type is unsupported + """ + operator_types = context.runtime_config.operator_types + if not operator_types: + raise ValueError("operator_types list is empty, at least one operator type must be configured") + + operators: dict[str, AbstractOperator] = {} + for operator_type in operator_types: + normalized_type = operator_type.lower() + if normalized_type in operators: + logger.warning(f"Duplicate operator type '{normalized_type}' in config, skipping") + continue + operator = OperatorFactory._create_single_operator(normalized_type, context) + operators[normalized_type] = operator + logger.info(f"Created operator for type '{normalized_type}'") + + logger.info(f"Initialized {len(operators)} operator(s): {list(operators.keys())}") + return operators + + @staticmethod + def create_composite_operator(context: OperatorContext) -> CompositeOperator: + """Create a CompositeOperator that wraps multiple sub-operators. + + This is the recommended entry point for multi-operator setups. It reads + operator_types from the runtime config, creates each sub-operator, and + wraps them in a CompositeOperator that implements AbstractOperator. + + The returned CompositeOperator can be passed directly to SandboxManager + as a single operator — no changes to SandboxManager are needed. + + Args: + context: OperatorContext containing all necessary dependencies + + Returns: + CompositeOperator: A composite operator wrapping all configured sub-operators + """ + operators = OperatorFactory.create_operators(context) + default_type = context.runtime_config.operator_type.lower() + return CompositeOperator(operators=operators, default_operator_type=default_type) diff --git a/rock/sdk/sandbox/client.py b/rock/sdk/sandbox/client.py index ed85f7b52..a12eea5cf 100644 --- a/rock/sdk/sandbox/client.py +++ b/rock/sdk/sandbox/client.py @@ -174,6 +174,7 @@ async def start(self): "registry_username": self.config.registry_username, "registry_password": self.config.registry_password, "use_kata_runtime": self.config.use_kata_runtime, + "operator_type": self.config.operator_type, } try: response = await HttpUtils.post(url, headers, data) diff --git a/rock/sdk/sandbox/config.py b/rock/sdk/sandbox/config.py index 46d97dfcc..7e31cc33a 100644 --- a/rock/sdk/sandbox/config.py +++ b/rock/sdk/sandbox/config.py @@ -42,6 +42,8 @@ class SandboxConfig(BaseConfig): registry_username: str | None = None registry_password: str | None = None use_kata_runtime: bool = False + operator_type: str = "ray" + """The operator type to use for this sandbox (e.g., 'ray', 'k8s'). If None, uses the server default.""" class SandboxGroupConfig(SandboxConfig): diff --git a/tests/unit/test_composite_operator.py b/tests/unit/test_composite_operator.py new file mode 100644 index 000000000..f6d95bdaf --- /dev/null +++ b/tests/unit/test_composite_operator.py @@ -0,0 +1,682 @@ +"""Unit tests for CompositeOperator and multi-operator support. + +Covers: +- RuntimeConfig operator_types backward compatibility +- CompositeOperator routing logic (submit / get_status / stop) +- CompositeOperator.set_redis_provider propagation to sub-operators +- get_status does NOT overwrite operator_type in Redis (critical) +- OperatorFactory.create_composite_operator +- SandboxStartRequest / DockerDeploymentConfig operator_type field +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fakeredis import aioredis + +from rock.actions.sandbox.response import State +from rock.actions.sandbox.sandbox_info import SandboxInfo +from rock.admin.core.redis_key import alive_sandbox_key +from rock.admin.proto.request import SandboxStartRequest +from rock.config import RuntimeConfig +from rock.deployments.config import DockerDeploymentConfig +from rock.sandbox.operator.abstract import AbstractOperator +from rock.sandbox.operator.composite import CompositeOperator +from rock.utils.providers.redis_provider import RedisProvider + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_mock_operator(operator_name: str = "mock") -> AbstractOperator: + """Create a mock AbstractOperator with async methods.""" + operator = AsyncMock(spec=AbstractOperator) + operator._operator_name = operator_name + operator.set_redis_provider = MagicMock() + operator.set_nacos_provider = MagicMock() + return operator + + +def _make_sandbox_info(**overrides) -> SandboxInfo: + """Create a minimal SandboxInfo dict for testing.""" + info: SandboxInfo = { + "sandbox_id": "test-sandbox-001", + "host_ip": "10.0.0.1", + "host_name": "test-host", + "image": "python:3.11", + "state": State.PENDING, + "cpus": 2, + "memory": "8g", + "phases": {}, + "port_mapping": {}, + } + info.update(overrides) + return info + + +@pytest.fixture +async def fake_redis_provider(): + """Create a RedisProvider backed by fakeredis.""" + provider = RedisProvider(host=None, port=None, password="") + provider.client = aioredis.FakeRedis(decode_responses=True) + yield provider + await provider.close_pool() + + +# =========================================================================== +# 1. RuntimeConfig operator_types backward compatibility +# =========================================================================== + + +class TestRuntimeConfigOperatorTypes: + """Test RuntimeConfig.operator_types field and backward compatibility.""" + + def test_default_operator_types_from_operator_type(self): + """When operator_types is empty, it should be populated from operator_type.""" + config = RuntimeConfig(operator_type="ray") + assert config.operator_types == ["ray"] + assert config.operator_type == "ray" + + def test_explicit_operator_types_list(self): + """When operator_types is explicitly set, it should be used as-is.""" + config = RuntimeConfig(operator_types=["ray", "k8s"]) + assert config.operator_types == ["ray", "k8s"] + # operator_type should be set to the first in the list + assert config.operator_type == "ray" + + def test_operator_types_overrides_operator_type(self): + """operator_types takes precedence; operator_type is synced to the first element.""" + config = RuntimeConfig(operator_type="ray", operator_types=["k8s", "ray"]) + assert config.operator_types == ["k8s", "ray"] + assert config.operator_type == "k8s" + + def test_single_operator_types(self): + """Single-element operator_types should work like the old operator_type.""" + config = RuntimeConfig(operator_types=["k8s"]) + assert config.operator_types == ["k8s"] + assert config.operator_type == "k8s" + + +# =========================================================================== +# 2. CompositeOperator initialization +# =========================================================================== + + +class TestCompositeOperatorInit: + """Test CompositeOperator construction and validation.""" + + def test_init_with_valid_operators(self): + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + assert composite._default_operator_type == "ray" + assert len(composite._operators) == 2 + + def test_init_with_empty_operators_raises(self): + with pytest.raises(ValueError, match="At least one operator"): + CompositeOperator(operators={}, default_operator_type="ray") + + def test_init_with_invalid_default_raises(self): + ray_op = _make_mock_operator("ray") + with pytest.raises(ValueError, match="not found in provided operators"): + CompositeOperator(operators={"ray": ray_op}, default_operator_type="k8s") + + def test_init_normalizes_default_type(self): + ray_op = _make_mock_operator("ray") + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="RAY", + ) + assert composite._default_operator_type == "ray" + + +# =========================================================================== +# 3. CompositeOperator.set_redis_provider propagation +# =========================================================================== + + +class TestCompositeOperatorRedisProviderPropagation: + """Test that set_redis_provider propagates to all sub-operators.""" + + def test_set_redis_provider_propagates_to_all(self): + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + + mock_redis = MagicMock(spec=RedisProvider) + composite.set_redis_provider(mock_redis) + + ray_op.set_redis_provider.assert_called_once_with(mock_redis) + k8s_op.set_redis_provider.assert_called_once_with(mock_redis) + assert composite._redis_provider is mock_redis + + +# =========================================================================== +# 4. CompositeOperator.submit routing +# =========================================================================== + + +class TestCompositeOperatorSubmit: + """Test submit() routes to the correct sub-operator.""" + + @pytest.mark.asyncio + async def test_submit_routes_to_specified_operator(self): + """When config.operator_type is set, submit should route to that operator.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + + ray_op.submit.return_value = _make_sandbox_info(sandbox_id="ray-sandbox") + k8s_op.submit.return_value = _make_sandbox_info(sandbox_id="k8s-sandbox") + + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test-k8s", + operator_type="k8s", + ) + result = await composite.submit(config, {"user_id": "u1"}) + + k8s_op.submit.assert_awaited_once_with(config, {"user_id": "u1"}) + ray_op.submit.assert_not_awaited() + assert result["operator_type"] == "k8s" + + @pytest.mark.asyncio + async def test_submit_uses_default_when_no_operator_type(self): + """When config.operator_type is None, submit should use the default operator.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + + ray_op.submit.return_value = _make_sandbox_info(sandbox_id="ray-sandbox") + + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test-default", + operator_type=None, + ) + result = await composite.submit(config, {}) + + ray_op.submit.assert_awaited_once() + k8s_op.submit.assert_not_awaited() + assert result["operator_type"] == "ray" + + @pytest.mark.asyncio + async def test_submit_sets_operator_type_in_sandbox_info(self): + """submit() must write operator_type into the returned SandboxInfo.""" + ray_op = _make_mock_operator("ray") + ray_op.submit.return_value = _make_sandbox_info() + + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) + + config = DockerDeploymentConfig(image="python:3.11", container_name="test") + result = await composite.submit(config, {}) + + assert "operator_type" in result + assert result["operator_type"] == "ray" + + @pytest.mark.asyncio + async def test_submit_with_unsupported_operator_type_raises(self): + """submit() should raise ValueError for unsupported operator_type.""" + ray_op = _make_mock_operator("ray") + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) + + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test", + operator_type="docker_swarm", + ) + with pytest.raises(ValueError, match="Unsupported operator type"): + await composite.submit(config, {}) + + +# =========================================================================== +# 5. CompositeOperator.get_status routing via Redis +# =========================================================================== + + +class TestCompositeOperatorGetStatus: + """Test get_status() routes based on operator_type stored in Redis.""" + + @pytest.mark.asyncio + async def test_get_status_routes_by_redis_operator_type(self, fake_redis_provider): + """get_status should look up operator_type from Redis and route accordingly.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + + k8s_status = _make_sandbox_info(sandbox_id="sandbox-1", state=State.RUNNING) + k8s_op.get_status.return_value = k8s_status + + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) + + # Pre-populate Redis with sandbox info that has operator_type="k8s" + sandbox_info_in_redis = _make_sandbox_info(sandbox_id="sandbox-1", operator_type="k8s") + await fake_redis_provider.json_set(alive_sandbox_key("sandbox-1"), "$", sandbox_info_in_redis) + + await composite.get_status("sandbox-1") + + k8s_op.get_status.assert_awaited_once_with("sandbox-1") + ray_op.get_status.assert_not_awaited() + + @pytest.mark.asyncio + async def test_get_status_falls_back_to_default_without_redis(self): + """Without Redis, get_status should fall back to the default operator.""" + ray_op = _make_mock_operator("ray") + ray_op.get_status.return_value = _make_sandbox_info(state=State.RUNNING) + + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) + # No redis provider set + + await composite.get_status("sandbox-no-redis") + ray_op.get_status.assert_awaited_once_with("sandbox-no-redis") + + @pytest.mark.asyncio + async def test_get_status_falls_back_when_redis_has_no_operator_type(self, fake_redis_provider): + """If Redis entry has no operator_type, fall back to default.""" + ray_op = _make_mock_operator("ray") + ray_op.get_status.return_value = _make_sandbox_info(state=State.RUNNING) + + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) + + # Store sandbox info WITHOUT operator_type + sandbox_info_no_type = _make_sandbox_info(sandbox_id="sandbox-2") + await fake_redis_provider.json_set(alive_sandbox_key("sandbox-2"), "$", sandbox_info_no_type) + + await composite.get_status("sandbox-2") + ray_op.get_status.assert_awaited_once_with("sandbox-2") + + +# =========================================================================== +# 6. CompositeOperator.stop routing via Redis +# =========================================================================== + + +class TestCompositeOperatorStop: + """Test stop() routes based on operator_type stored in Redis.""" + + @pytest.mark.asyncio + async def test_stop_routes_by_redis_operator_type(self, fake_redis_provider): + """stop should look up operator_type from Redis and route accordingly.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + k8s_op.stop.return_value = True + + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) + + sandbox_info_in_redis = _make_sandbox_info(sandbox_id="sandbox-stop", operator_type="k8s") + await fake_redis_provider.json_set(alive_sandbox_key("sandbox-stop"), "$", sandbox_info_in_redis) + + result = await composite.stop("sandbox-stop") + + k8s_op.stop.assert_awaited_once_with("sandbox-stop") + ray_op.stop.assert_not_awaited() + assert result is True + + @pytest.mark.asyncio + async def test_stop_falls_back_to_default_without_redis(self): + """Without Redis, stop should fall back to the default operator.""" + ray_op = _make_mock_operator("ray") + ray_op.stop.return_value = True + + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) + + await composite.stop("sandbox-no-redis") + ray_op.stop.assert_awaited_once_with("sandbox-no-redis") + + +# =========================================================================== +# 7. CRITICAL: get_status does NOT overwrite operator_type in Redis +# =========================================================================== + + +class TestGetStatusPreservesOperatorTypeInRedis: + """Critical tests: verify that the full start_async -> get_status flow + does NOT lose or overwrite the operator_type field in Redis. + + This simulates the SandboxManager flow: + 1. CompositeOperator.submit() sets operator_type in SandboxInfo + 2. SandboxManager.start_async() writes SandboxInfo to Redis + 3. SandboxManager.get_status() calls operator.get_status() and writes back + 4. operator_type must still be present in Redis after step 3 + """ + + @pytest.mark.asyncio + async def test_operator_type_survives_submit_and_get_status_cycle(self, fake_redis_provider): + """Full cycle: submit writes operator_type, get_status preserves it.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + + # submit returns sandbox info (CompositeOperator will add operator_type) + k8s_op.submit.return_value = _make_sandbox_info(sandbox_id="cycle-test") + + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) + + # Step 1: submit (simulates CompositeOperator.submit) + config = DockerDeploymentConfig( + image="python:3.11", + container_name="cycle-test", + operator_type="k8s", + ) + sandbox_info = await composite.submit(config, {}) + assert sandbox_info["operator_type"] == "k8s" + + # Step 2: simulate SandboxManager writing to Redis + await fake_redis_provider.json_set(alive_sandbox_key("cycle-test"), "$", sandbox_info) + + # Verify operator_type is in Redis + redis_data = await fake_redis_provider.json_get(alive_sandbox_key("cycle-test"), "$") + assert redis_data[0]["operator_type"] == "k8s" + + # Step 3: get_status returns info WITHOUT operator_type (like real sub-operators do) + k8s_op.get_status.return_value = _make_sandbox_info( + sandbox_id="cycle-test", + state=State.RUNNING, + # Note: no operator_type here, simulating real K8sOperator/RayOperator behavior + ) + status_info = await composite.get_status("cycle-test") + + # Step 4: simulate SandboxManager writing get_status result back to Redis + # This is what SandboxManager.get_status() does: + # sandbox_info = await self._operator.get_status(sandbox_id) + # await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) + await fake_redis_provider.json_set(alive_sandbox_key("cycle-test"), "$", status_info) + + # CRITICAL CHECK: operator_type must still be in Redis + # The sub-operator's get_status doesn't return operator_type, + # so if the full sandbox_info is overwritten, operator_type would be lost. + redis_data_after = await fake_redis_provider.json_get(alive_sandbox_key("cycle-test"), "$") + # This test verifies the CURRENT behavior. If operator_type is missing here, + # it means get_status overwrites it and we have a bug. + # + # In the current design, sub-operators (RayOperator, K8sOperator) merge + # redis_info with fresh status via redis_info.update(sandbox_info), which + # preserves operator_type because the fresh status dict doesn't contain it. + # However, SandboxManager.get_status() does a full json_set with the + # returned sandbox_info. If the sub-operator returns a dict without + # operator_type, it WILL be lost. + # + # Let's check what actually happens: + has_operator_type = "operator_type" in redis_data_after[0] + + if not has_operator_type: + # This means the current flow DOES lose operator_type. + # We need to verify this is the case and document it. + pytest.fail( + "operator_type was lost from Redis after get_status! " + "The sub-operator's get_status() did not include operator_type, " + "and SandboxManager overwrote Redis with the incomplete data." + ) + + @pytest.mark.asyncio + async def test_ray_operator_get_status_preserves_operator_type_via_redis_merge(self, fake_redis_provider): + """Simulate RayOperator.get_status() redis merge path. + + RayOperator.get_status() (non-rocklet path) does: + redis_info = await self.get_sandbox_info_from_redis(sandbox_id) + if redis_info: + redis_info.update(sandbox_info) # sandbox_info has no operator_type + return redis_info # redis_info still has operator_type + + This test verifies that the merge preserves operator_type. + """ + # Simulate what's in Redis (with operator_type) + redis_sandbox_info = _make_sandbox_info( + sandbox_id="ray-merge-test", + operator_type="ray", + user_id="user-1", + experiment_id="exp-1", + ) + + # Simulate what RayOperator gets from the actor (no operator_type) + actor_sandbox_info = _make_sandbox_info( + sandbox_id="ray-merge-test", + state=State.RUNNING, + ) + # Actor info typically doesn't have operator_type + assert "operator_type" not in actor_sandbox_info + + # Simulate the merge: redis_info.update(sandbox_info) + redis_sandbox_info.update(actor_sandbox_info) + + # operator_type should survive because actor_sandbox_info doesn't have it + assert redis_sandbox_info.get("operator_type") == "ray" + assert redis_sandbox_info.get("state") == State.RUNNING + + @pytest.mark.asyncio + async def test_k8s_operator_get_status_preserves_operator_type_via_redis_merge(self, fake_redis_provider): + """Simulate K8sOperator.get_status() redis merge path. + + K8sOperator.get_status() does: + sandbox_info = await self._provider.get_status(sandbox_id) + if self._redis_provider: + redis_info = await self._get_sandbox_info_from_redis(sandbox_id) + if redis_info: + redis_info.update(sandbox_info) + return redis_info + + This test verifies that the merge preserves operator_type. + """ + # Simulate what's in Redis (with operator_type) + redis_sandbox_info = _make_sandbox_info( + sandbox_id="k8s-merge-test", + operator_type="k8s", + user_id="user-1", + ) + + # Simulate what K8s provider returns (no operator_type) + provider_sandbox_info: SandboxInfo = { + "sandbox_id": "k8s-merge-test", + "host_ip": "10.0.0.2", + "state": State.RUNNING, + "phases": {}, + "port_mapping": {8000: 30001}, + } + assert "operator_type" not in provider_sandbox_info + + # Simulate the merge: redis_info.update(sandbox_info) + redis_sandbox_info.update(provider_sandbox_info) + + # operator_type should survive + assert redis_sandbox_info.get("operator_type") == "k8s" + assert redis_sandbox_info.get("state") == State.RUNNING + assert redis_sandbox_info.get("host_ip") == "10.0.0.2" + + @pytest.mark.asyncio + async def test_sandbox_manager_get_status_preserves_operator_type(self, fake_redis_provider): + """End-to-end: SandboxManager.get_status() must preserve operator_type in Redis. + + SandboxManager.get_status() does: + sandbox_info = await self._operator.get_status(sandbox_id) + await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) + + If the operator returns sandbox_info WITH operator_type (because sub-operators + merge from Redis), then the json_set will preserve it. + """ + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) + + # Pre-populate Redis with sandbox info including operator_type + initial_info = _make_sandbox_info( + sandbox_id="e2e-test", + operator_type="k8s", + user_id="user-1", + ) + await fake_redis_provider.json_set(alive_sandbox_key("e2e-test"), "$", initial_info) + + # K8sOperator.get_status() merges redis_info with provider status, + # so the returned dict should still contain operator_type + merged_status = _make_sandbox_info( + sandbox_id="e2e-test", + operator_type="k8s", # preserved from Redis merge + user_id="user-1", + state=State.RUNNING, + host_ip="10.0.0.5", + ) + k8s_op.get_status.return_value = merged_status + + # CompositeOperator.get_status routes to k8s_op + result = await composite.get_status("e2e-test") + + # Simulate SandboxManager writing back to Redis + await fake_redis_provider.json_set(alive_sandbox_key("e2e-test"), "$", result) + + # Verify operator_type is preserved + final_redis = await fake_redis_provider.json_get(alive_sandbox_key("e2e-test"), "$") + assert final_redis[0]["operator_type"] == "k8s" + assert final_redis[0]["state"] == State.RUNNING + + +# =========================================================================== +# 8. OperatorFactory.create_composite_operator +# =========================================================================== + + +class TestOperatorFactoryCreateComposite: + """Test OperatorFactory.create_composite_operator method.""" + + def test_create_composite_operator_single_type(self): + """create_composite_operator with a single operator type.""" + from rock.sandbox.operator.factory import OperatorContext, OperatorFactory + + runtime_config = RuntimeConfig(operator_types=["ray"]) + ray_service = MagicMock() + + context = OperatorContext( + runtime_config=runtime_config, + ray_service=ray_service, + ) + + with patch("rock.sandbox.operator.factory.OperatorFactory._create_single_operator") as mock_create: + mock_ray_op = _make_mock_operator("ray") + mock_create.return_value = mock_ray_op + + composite = OperatorFactory.create_composite_operator(context) + + assert isinstance(composite, CompositeOperator) + assert composite._default_operator_type == "ray" + assert "ray" in composite._operators + + def test_create_composite_operator_multiple_types(self): + """create_composite_operator with multiple operator types.""" + from rock.sandbox.operator.factory import OperatorContext, OperatorFactory + + runtime_config = RuntimeConfig(operator_types=["ray", "k8s"]) + ray_service = MagicMock() + + context = OperatorContext( + runtime_config=runtime_config, + ray_service=ray_service, + ) + + call_count = 0 + + def side_effect(op_type, ctx): + nonlocal call_count + call_count += 1 + return _make_mock_operator(op_type) + + with patch( + "rock.sandbox.operator.factory.OperatorFactory._create_single_operator", + side_effect=side_effect, + ): + composite = OperatorFactory.create_composite_operator(context) + + assert isinstance(composite, CompositeOperator) + assert composite._default_operator_type == "ray" + assert "ray" in composite._operators + assert "k8s" in composite._operators + assert call_count == 2 + + +# =========================================================================== +# 9. SandboxStartRequest and DockerDeploymentConfig operator_type field +# =========================================================================== + + +class TestOperatorTypeFieldInModels: + """Test that operator_type field exists and works in request/config models.""" + + def test_sandbox_start_request_has_operator_type(self): + request = SandboxStartRequest( + image="python:3.11", + operator_type="k8s", + ) + assert request.operator_type == "k8s" + + def test_sandbox_start_request_operator_type_default_none(self): + request = SandboxStartRequest(image="python:3.11") + assert request.operator_type is None + + def test_docker_deployment_config_has_operator_type(self): + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test", + operator_type="ray", + ) + assert config.operator_type == "ray" + + def test_docker_deployment_config_operator_type_default_none(self): + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test", + ) + assert config.operator_type is None + + def test_docker_deployment_config_from_request_preserves_operator_type(self): + """DockerDeploymentConfig.from_request should carry over operator_type.""" + request = SandboxStartRequest( + image="python:3.11", + sandbox_id="test-sandbox", + operator_type="k8s", + ) + config = DockerDeploymentConfig.from_request(request) + assert config.operator_type == "k8s" + assert config.container_name == "test-sandbox" diff --git a/tests/unit/test_sdk_operator_type.py b/tests/unit/test_sdk_operator_type.py new file mode 100644 index 000000000..42c5d1658 --- /dev/null +++ b/tests/unit/test_sdk_operator_type.py @@ -0,0 +1,276 @@ +"""Unit tests for SDK operator_type support. + +Covers: +- SandboxConfig.operator_type field (default, explicit value) +- Sandbox.__init__ stores operator_type from config +- Sandbox.start() includes operator_type in the request payload +- SandboxGroup propagates operator_type to all child sandboxes +""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from rock.sdk.sandbox.client import Sandbox, SandboxGroup +from rock.sdk.sandbox.config import SandboxConfig, SandboxGroupConfig + +# =========================================================================== +# 1. SandboxConfig operator_type field +# =========================================================================== + + +class TestSandboxConfigOperatorType: + """Test SandboxConfig.operator_type field behavior.""" + + def test_operator_type_default_is_none(self): + """operator_type should default to None when not specified.""" + config = SandboxConfig() + assert config.operator_type is None + + def test_operator_type_set_explicitly(self): + """operator_type should be stored when explicitly set.""" + config = SandboxConfig(operator_type="k8s") + assert config.operator_type == "k8s" + + def test_operator_type_ray(self): + """operator_type should accept 'ray' value.""" + config = SandboxConfig(operator_type="ray") + assert config.operator_type == "ray" + + def test_operator_type_with_other_fields(self): + """operator_type should coexist with other config fields.""" + config = SandboxConfig( + image="ubuntu:22.04", + memory="16g", + cpus=4, + cluster="us-east", + operator_type="k8s", + ) + assert config.operator_type == "k8s" + assert config.image == "ubuntu:22.04" + assert config.memory == "16g" + assert config.cpus == 4 + assert config.cluster == "us-east" + + def test_operator_type_serialization(self): + """operator_type should appear in model_dump output.""" + config = SandboxConfig(operator_type="ray") + dumped = config.model_dump() + assert "operator_type" in dumped + assert dumped["operator_type"] == "ray" + + def test_operator_type_none_serialization(self): + """operator_type=None should appear in model_dump output.""" + config = SandboxConfig() + dumped = config.model_dump() + assert "operator_type" in dumped + assert dumped["operator_type"] is None + + +# =========================================================================== +# 2. Sandbox.__init__ with operator_type +# =========================================================================== + + +class TestSandboxInitOperatorType: + """Test that Sandbox stores operator_type from SandboxConfig.""" + + def test_sandbox_stores_operator_type_from_config(self): + """Sandbox should store the config with operator_type.""" + config = SandboxConfig(operator_type="k8s") + sandbox = Sandbox(config) + assert sandbox.config.operator_type == "k8s" + + def test_sandbox_stores_none_operator_type(self): + """Sandbox should store None operator_type when not specified.""" + config = SandboxConfig() + sandbox = Sandbox(config) + assert sandbox.config.operator_type is None + + +# =========================================================================== +# 3. Sandbox.start() includes operator_type in request +# =========================================================================== + + +class TestSandboxStartOperatorType: + """Test that Sandbox.start() sends operator_type in the POST payload.""" + + @pytest.mark.asyncio + async def test_start_sends_operator_type_in_payload(self): + """start() should include operator_type in the request data.""" + config = SandboxConfig(operator_type="k8s", startup_timeout=5) + sandbox = Sandbox(config) + + mock_response = { + "status": "Success", + "result": { + "sandbox_id": "test-sandbox-001", + "host_name": "test-host", + "host_ip": "10.0.0.1", + }, + } + + # Mock get_status to return alive immediately + mock_status = AsyncMock() + mock_status.is_alive = True + + with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( + sandbox, "get_status", return_value=mock_status + ): + mock_post.return_value = mock_response + await sandbox.start() + + # Verify the POST was called with operator_type in data + mock_post.assert_called_once() + call_args = mock_post.call_args + posted_data = call_args[0][2] # third positional arg is data + assert "operator_type" in posted_data + assert posted_data["operator_type"] == "k8s" + + @pytest.mark.asyncio + async def test_start_sends_none_operator_type_when_not_set(self): + """start() should send operator_type=None when not configured.""" + config = SandboxConfig(startup_timeout=5) + sandbox = Sandbox(config) + + mock_response = { + "status": "Success", + "result": { + "sandbox_id": "test-sandbox-002", + "host_name": "test-host", + "host_ip": "10.0.0.1", + }, + } + + mock_status = AsyncMock() + mock_status.is_alive = True + + with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( + sandbox, "get_status", return_value=mock_status + ): + mock_post.return_value = mock_response + await sandbox.start() + + call_args = mock_post.call_args + posted_data = call_args[0][2] + assert "operator_type" in posted_data + assert posted_data["operator_type"] is None + + @pytest.mark.asyncio + async def test_start_sends_ray_operator_type(self): + """start() should correctly send operator_type='ray'.""" + config = SandboxConfig(operator_type="ray", startup_timeout=5) + sandbox = Sandbox(config) + + mock_response = { + "status": "Success", + "result": { + "sandbox_id": "test-sandbox-003", + "host_name": "test-host", + "host_ip": "10.0.0.1", + }, + } + + mock_status = AsyncMock() + mock_status.is_alive = True + + with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( + sandbox, "get_status", return_value=mock_status + ): + mock_post.return_value = mock_response + await sandbox.start() + + call_args = mock_post.call_args + posted_data = call_args[0][2] + assert posted_data["operator_type"] == "ray" + + @pytest.mark.asyncio + async def test_start_payload_contains_all_expected_fields(self): + """start() payload should contain operator_type alongside all other fields.""" + config = SandboxConfig( + image="ubuntu:22.04", + memory="16g", + cpus=4, + operator_type="k8s", + startup_timeout=5, + ) + sandbox = Sandbox(config) + + mock_response = { + "status": "Success", + "result": { + "sandbox_id": "test-sandbox-004", + "host_name": "test-host", + "host_ip": "10.0.0.1", + }, + } + + mock_status = AsyncMock() + mock_status.is_alive = True + + with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( + sandbox, "get_status", return_value=mock_status + ): + mock_post.return_value = mock_response + await sandbox.start() + + call_args = mock_post.call_args + posted_data = call_args[0][2] + + # Verify all expected fields are present + assert posted_data["image"] == "ubuntu:22.04" + assert posted_data["memory"] == "16g" + assert posted_data["cpus"] == 4 + assert posted_data["operator_type"] == "k8s" + assert "use_kata_runtime" in posted_data + assert "registry_username" in posted_data + assert "registry_password" in posted_data + + +# =========================================================================== +# 4. SandboxGroup propagates operator_type +# =========================================================================== + + +class TestSandboxGroupOperatorType: + """Test that SandboxGroup propagates operator_type to child sandboxes.""" + + def test_group_propagates_operator_type_to_children(self): + """All sandboxes in a group should inherit operator_type from config.""" + config = SandboxGroupConfig( + size=3, + operator_type="k8s", + ) + group = SandboxGroup(config) + + assert len(group.sandbox_list) == 3 + for sandbox in group.sandbox_list: + assert sandbox.config.operator_type == "k8s" + + def test_group_propagates_none_operator_type(self): + """All sandboxes in a group should have None operator_type when not set.""" + config = SandboxGroupConfig(size=2) + group = SandboxGroup(config) + + for sandbox in group.sandbox_list: + assert sandbox.config.operator_type is None + + +# =========================================================================== +# 5. SandboxGroupConfig inherits operator_type +# =========================================================================== + + +class TestSandboxGroupConfigOperatorType: + """Test SandboxGroupConfig inherits operator_type from SandboxConfig.""" + + def test_group_config_has_operator_type(self): + """SandboxGroupConfig should support operator_type field.""" + config = SandboxGroupConfig(operator_type="ray", size=2) + assert config.operator_type == "ray" + + def test_group_config_operator_type_default_none(self): + """SandboxGroupConfig.operator_type should default to None.""" + config = SandboxGroupConfig(size=2) + assert config.operator_type is None From c478f013b4829dc64bdb79b12c550e7817bb09f7 Mon Sep 17 00:00:00 2001 From: daifangwen Date: Wed, 4 Mar 2026 11:47:45 +0000 Subject: [PATCH 2/3] fix ut --- tests/unit/test_sdk_operator_type.py | 36 ++++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/unit/test_sdk_operator_type.py b/tests/unit/test_sdk_operator_type.py index 42c5d1658..359ed4118 100644 --- a/tests/unit/test_sdk_operator_type.py +++ b/tests/unit/test_sdk_operator_type.py @@ -22,10 +22,10 @@ class TestSandboxConfigOperatorType: """Test SandboxConfig.operator_type field behavior.""" - def test_operator_type_default_is_none(self): - """operator_type should default to None when not specified.""" + def test_operator_type_default_is_ray(self): + """operator_type should default to 'ray' when not specified.""" config = SandboxConfig() - assert config.operator_type is None + assert config.operator_type == "ray" def test_operator_type_set_explicitly(self): """operator_type should be stored when explicitly set.""" @@ -59,12 +59,12 @@ def test_operator_type_serialization(self): assert "operator_type" in dumped assert dumped["operator_type"] == "ray" - def test_operator_type_none_serialization(self): - """operator_type=None should appear in model_dump output.""" + def test_operator_type_default_serialization(self): + """Default operator_type='ray' should appear in model_dump output.""" config = SandboxConfig() dumped = config.model_dump() assert "operator_type" in dumped - assert dumped["operator_type"] is None + assert dumped["operator_type"] == "ray" # =========================================================================== @@ -81,11 +81,11 @@ def test_sandbox_stores_operator_type_from_config(self): sandbox = Sandbox(config) assert sandbox.config.operator_type == "k8s" - def test_sandbox_stores_none_operator_type(self): - """Sandbox should store None operator_type when not specified.""" + def test_sandbox_stores_default_operator_type(self): + """Sandbox should store default operator_type='ray' when not specified.""" config = SandboxConfig() sandbox = Sandbox(config) - assert sandbox.config.operator_type is None + assert sandbox.config.operator_type == "ray" # =========================================================================== @@ -129,8 +129,8 @@ async def test_start_sends_operator_type_in_payload(self): assert posted_data["operator_type"] == "k8s" @pytest.mark.asyncio - async def test_start_sends_none_operator_type_when_not_set(self): - """start() should send operator_type=None when not configured.""" + async def test_start_sends_default_operator_type_when_not_set(self): + """start() should send operator_type='ray' when using default config.""" config = SandboxConfig(startup_timeout=5) sandbox = Sandbox(config) @@ -155,7 +155,7 @@ async def test_start_sends_none_operator_type_when_not_set(self): call_args = mock_post.call_args posted_data = call_args[0][2] assert "operator_type" in posted_data - assert posted_data["operator_type"] is None + assert posted_data["operator_type"] == "ray" @pytest.mark.asyncio async def test_start_sends_ray_operator_type(self): @@ -248,13 +248,13 @@ def test_group_propagates_operator_type_to_children(self): for sandbox in group.sandbox_list: assert sandbox.config.operator_type == "k8s" - def test_group_propagates_none_operator_type(self): - """All sandboxes in a group should have None operator_type when not set.""" + def test_group_propagates_default_operator_type(self): + """All sandboxes in a group should have default operator_type='ray' when not set.""" config = SandboxGroupConfig(size=2) group = SandboxGroup(config) for sandbox in group.sandbox_list: - assert sandbox.config.operator_type is None + assert sandbox.config.operator_type == "ray" # =========================================================================== @@ -270,7 +270,7 @@ def test_group_config_has_operator_type(self): config = SandboxGroupConfig(operator_type="ray", size=2) assert config.operator_type == "ray" - def test_group_config_operator_type_default_none(self): - """SandboxGroupConfig.operator_type should default to None.""" + def test_group_config_operator_type_default_ray(self): + """SandboxGroupConfig.operator_type should default to 'ray'.""" config = SandboxGroupConfig(size=2) - assert config.operator_type is None + assert config.operator_type == "ray" From 9181ce19bae223a74c5417725ae2e9112f0b3e6a Mon Sep 17 00:00:00 2001 From: daifangwen Date: Thu, 5 Mar 2026 07:01:21 +0000 Subject: [PATCH 3/3] optimize ut --- tests/unit/sdk/sandbox/test_sandbox_config.py | 63 ++ tests/unit/sdk/sandbox/test_sandbox_sdk.py | 171 ++++ tests/unit/test_composite_operator.py | 956 ++++++++---------- tests/unit/test_sdk_operator_type.py | 276 ----- 4 files changed, 664 insertions(+), 802 deletions(-) create mode 100644 tests/unit/sdk/sandbox/test_sandbox_config.py create mode 100644 tests/unit/sdk/sandbox/test_sandbox_sdk.py delete mode 100644 tests/unit/test_sdk_operator_type.py diff --git a/tests/unit/sdk/sandbox/test_sandbox_config.py b/tests/unit/sdk/sandbox/test_sandbox_config.py new file mode 100644 index 000000000..a8911ebfd --- /dev/null +++ b/tests/unit/sdk/sandbox/test_sandbox_config.py @@ -0,0 +1,63 @@ +from rock.sdk.sandbox.config import SandboxConfig, SandboxGroupConfig + + +def test_sandbox_config_operator_type_default_is_ray(): + """operator_type should default to 'ray' when not specified.""" + config = SandboxConfig() + assert config.operator_type == "ray" + + +def test_sandbox_config_operator_type_set_explicitly(): + """operator_type should be stored when explicitly set.""" + config = SandboxConfig(operator_type="k8s") + assert config.operator_type == "k8s" + + +def test_sandbox_config_operator_type_ray(): + """operator_type should accept 'ray' value.""" + config = SandboxConfig(operator_type="ray") + assert config.operator_type == "ray" + + +def test_sandbox_config_operator_type_with_other_fields(): + """operator_type should coexist with other config fields.""" + config = SandboxConfig( + image="ubuntu:22.04", + memory="16g", + cpus=4, + cluster="us-east", + operator_type="k8s", + ) + assert config.operator_type == "k8s" + assert config.image == "ubuntu:22.04" + assert config.memory == "16g" + assert config.cpus == 4 + assert config.cluster == "us-east" + + +def test_sandbox_config_operator_type_serialization(): + """operator_type should appear in model_dump output.""" + config = SandboxConfig(operator_type="ray") + dumped = config.model_dump() + assert "operator_type" in dumped + assert dumped["operator_type"] == "ray" + + +def test_sandbox_config_operator_type_default_serialization(): + """Default operator_type='ray' should appear in model_dump output.""" + config = SandboxConfig() + dumped = config.model_dump() + assert "operator_type" in dumped + assert dumped["operator_type"] == "ray" + + +def test_sandbox_group_config_has_operator_type(): + """SandboxGroupConfig should support operator_type field.""" + config = SandboxGroupConfig(operator_type="ray", size=2) + assert config.operator_type == "ray" + + +def test_sandbox_group_config_operator_type_default_ray(): + """SandboxGroupConfig.operator_type should default to 'ray'.""" + config = SandboxGroupConfig(size=2) + assert config.operator_type == "ray" diff --git a/tests/unit/sdk/sandbox/test_sandbox_sdk.py b/tests/unit/sdk/sandbox/test_sandbox_sdk.py new file mode 100644 index 000000000..434138f6e --- /dev/null +++ b/tests/unit/sdk/sandbox/test_sandbox_sdk.py @@ -0,0 +1,171 @@ +from unittest.mock import AsyncMock, patch + +import pytest + +from rock.sdk.sandbox.client import Sandbox, SandboxGroup +from rock.sdk.sandbox.config import SandboxConfig, SandboxGroupConfig + + +def test_sandbox_stores_operator_type_from_config(): + """Sandbox should store the config with operator_type.""" + config = SandboxConfig(operator_type="k8s") + sandbox = Sandbox(config) + assert sandbox.config.operator_type == "k8s" + + +def test_sandbox_stores_default_operator_type(): + """Sandbox should store default operator_type='ray' when not specified.""" + config = SandboxConfig() + sandbox = Sandbox(config) + assert sandbox.config.operator_type == "ray" + + +@pytest.mark.asyncio +async def test_start_sends_operator_type_in_payload(): + """start() should include operator_type in the request data.""" + config = SandboxConfig(operator_type="k8s", startup_timeout=5) + sandbox = Sandbox(config) + + mock_response = { + "status": "Success", + "result": { + "sandbox_id": "test-sandbox-001", + "host_name": "test-host", + "host_ip": "10.0.0.1", + }, + } + + mock_status = AsyncMock() + mock_status.is_alive = True + + with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( + sandbox, "get_status", return_value=mock_status + ): + mock_post.return_value = mock_response + await sandbox.start() + + mock_post.assert_called_once() + call_args = mock_post.call_args + posted_data = call_args[0][2] + assert "operator_type" in posted_data + assert posted_data["operator_type"] == "k8s" + + +@pytest.mark.asyncio +async def test_start_sends_default_operator_type_when_not_set(): + """start() should send operator_type='ray' when using default config.""" + config = SandboxConfig(startup_timeout=5) + sandbox = Sandbox(config) + + mock_response = { + "status": "Success", + "result": { + "sandbox_id": "test-sandbox-002", + "host_name": "test-host", + "host_ip": "10.0.0.1", + }, + } + + mock_status = AsyncMock() + mock_status.is_alive = True + + with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( + sandbox, "get_status", return_value=mock_status + ): + mock_post.return_value = mock_response + await sandbox.start() + + call_args = mock_post.call_args + posted_data = call_args[0][2] + assert "operator_type" in posted_data + assert posted_data["operator_type"] == "ray" + + +@pytest.mark.asyncio +async def test_start_sends_ray_operator_type(): + """start() should correctly send operator_type='ray'.""" + config = SandboxConfig(operator_type="ray", startup_timeout=5) + sandbox = Sandbox(config) + + mock_response = { + "status": "Success", + "result": { + "sandbox_id": "test-sandbox-003", + "host_name": "test-host", + "host_ip": "10.0.0.1", + }, + } + + mock_status = AsyncMock() + mock_status.is_alive = True + + with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( + sandbox, "get_status", return_value=mock_status + ): + mock_post.return_value = mock_response + await sandbox.start() + + call_args = mock_post.call_args + posted_data = call_args[0][2] + assert posted_data["operator_type"] == "ray" + + +@pytest.mark.asyncio +async def test_start_payload_contains_all_expected_fields(): + """start() payload should contain operator_type alongside all other fields.""" + config = SandboxConfig( + image="ubuntu:22.04", + memory="16g", + cpus=4, + operator_type="k8s", + startup_timeout=5, + ) + sandbox = Sandbox(config) + + mock_response = { + "status": "Success", + "result": { + "sandbox_id": "test-sandbox-004", + "host_name": "test-host", + "host_ip": "10.0.0.1", + }, + } + + mock_status = AsyncMock() + mock_status.is_alive = True + + with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( + sandbox, "get_status", return_value=mock_status + ): + mock_post.return_value = mock_response + await sandbox.start() + + call_args = mock_post.call_args + posted_data = call_args[0][2] + + assert posted_data["image"] == "ubuntu:22.04" + assert posted_data["memory"] == "16g" + assert posted_data["cpus"] == 4 + assert posted_data["operator_type"] == "k8s" + assert "use_kata_runtime" in posted_data + assert "registry_username" in posted_data + assert "registry_password" in posted_data + + +def test_group_propagates_operator_type_to_children(): + """All sandboxes in a group should inherit operator_type from config.""" + config = SandboxGroupConfig(size=3, operator_type="k8s") + group = SandboxGroup(config) + + assert len(group.sandbox_list) == 3 + for sandbox in group.sandbox_list: + assert sandbox.config.operator_type == "k8s" + + +def test_group_propagates_default_operator_type(): + """All sandboxes in a group should have default operator_type='ray' when not set.""" + config = SandboxGroupConfig(size=2) + group = SandboxGroup(config) + + for sandbox in group.sandbox_list: + assert sandbox.config.operator_type == "ray" diff --git a/tests/unit/test_composite_operator.py b/tests/unit/test_composite_operator.py index f6d95bdaf..6636b0513 100644 --- a/tests/unit/test_composite_operator.py +++ b/tests/unit/test_composite_operator.py @@ -24,10 +24,6 @@ from rock.sandbox.operator.composite import CompositeOperator from rock.utils.providers.redis_provider import RedisProvider -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - def _make_mock_operator(operator_name: str = "mock") -> AbstractOperator: """Create a mock AbstractOperator with async methods.""" @@ -64,619 +60,527 @@ async def fake_redis_provider(): await provider.close_pool() -# =========================================================================== -# 1. RuntimeConfig operator_types backward compatibility -# =========================================================================== - - -class TestRuntimeConfigOperatorTypes: - """Test RuntimeConfig.operator_types field and backward compatibility.""" - - def test_default_operator_types_from_operator_type(self): - """When operator_types is empty, it should be populated from operator_type.""" - config = RuntimeConfig(operator_type="ray") - assert config.operator_types == ["ray"] - assert config.operator_type == "ray" - - def test_explicit_operator_types_list(self): - """When operator_types is explicitly set, it should be used as-is.""" - config = RuntimeConfig(operator_types=["ray", "k8s"]) - assert config.operator_types == ["ray", "k8s"] - # operator_type should be set to the first in the list - assert config.operator_type == "ray" - - def test_operator_types_overrides_operator_type(self): - """operator_types takes precedence; operator_type is synced to the first element.""" - config = RuntimeConfig(operator_type="ray", operator_types=["k8s", "ray"]) - assert config.operator_types == ["k8s", "ray"] - assert config.operator_type == "k8s" - - def test_single_operator_types(self): - """Single-element operator_types should work like the old operator_type.""" - config = RuntimeConfig(operator_types=["k8s"]) - assert config.operator_types == ["k8s"] - assert config.operator_type == "k8s" - - -# =========================================================================== -# 2. CompositeOperator initialization -# =========================================================================== - - -class TestCompositeOperatorInit: - """Test CompositeOperator construction and validation.""" - - def test_init_with_valid_operators(self): - ray_op = _make_mock_operator("ray") - k8s_op = _make_mock_operator("k8s") - composite = CompositeOperator( - operators={"ray": ray_op, "k8s": k8s_op}, - default_operator_type="ray", - ) - assert composite._default_operator_type == "ray" - assert len(composite._operators) == 2 - - def test_init_with_empty_operators_raises(self): - with pytest.raises(ValueError, match="At least one operator"): - CompositeOperator(operators={}, default_operator_type="ray") - - def test_init_with_invalid_default_raises(self): - ray_op = _make_mock_operator("ray") - with pytest.raises(ValueError, match="not found in provided operators"): - CompositeOperator(operators={"ray": ray_op}, default_operator_type="k8s") - - def test_init_normalizes_default_type(self): - ray_op = _make_mock_operator("ray") - composite = CompositeOperator( - operators={"ray": ray_op}, - default_operator_type="RAY", - ) - assert composite._default_operator_type == "ray" - - -# =========================================================================== -# 3. CompositeOperator.set_redis_provider propagation -# =========================================================================== - - -class TestCompositeOperatorRedisProviderPropagation: - """Test that set_redis_provider propagates to all sub-operators.""" - - def test_set_redis_provider_propagates_to_all(self): - ray_op = _make_mock_operator("ray") - k8s_op = _make_mock_operator("k8s") - composite = CompositeOperator( - operators={"ray": ray_op, "k8s": k8s_op}, - default_operator_type="ray", - ) - - mock_redis = MagicMock(spec=RedisProvider) - composite.set_redis_provider(mock_redis) - - ray_op.set_redis_provider.assert_called_once_with(mock_redis) - k8s_op.set_redis_provider.assert_called_once_with(mock_redis) - assert composite._redis_provider is mock_redis +def test_runtime_config_default_operator_types_from_operator_type(): + """When operator_types is empty, it should be populated from operator_type.""" + config = RuntimeConfig(operator_type="ray") + assert config.operator_types == ["ray"] + assert config.operator_type == "ray" -# =========================================================================== -# 4. CompositeOperator.submit routing -# =========================================================================== +def test_runtime_config_explicit_operator_types_list(): + """When operator_types is explicitly set, it should be used as-is.""" + config = RuntimeConfig(operator_types=["ray", "k8s"]) + assert config.operator_types == ["ray", "k8s"] + assert config.operator_type == "ray" -class TestCompositeOperatorSubmit: - """Test submit() routes to the correct sub-operator.""" +def test_runtime_config_operator_types_overrides_operator_type(): + """operator_types takes precedence; operator_type is synced to the first element.""" + config = RuntimeConfig(operator_type="ray", operator_types=["k8s", "ray"]) + assert config.operator_types == ["k8s", "ray"] + assert config.operator_type == "k8s" - @pytest.mark.asyncio - async def test_submit_routes_to_specified_operator(self): - """When config.operator_type is set, submit should route to that operator.""" - ray_op = _make_mock_operator("ray") - k8s_op = _make_mock_operator("k8s") - ray_op.submit.return_value = _make_sandbox_info(sandbox_id="ray-sandbox") - k8s_op.submit.return_value = _make_sandbox_info(sandbox_id="k8s-sandbox") +def test_runtime_config_single_operator_types(): + """Single-element operator_types should work like the old operator_type.""" + config = RuntimeConfig(operator_types=["k8s"]) + assert config.operator_types == ["k8s"] + assert config.operator_type == "k8s" - composite = CompositeOperator( - operators={"ray": ray_op, "k8s": k8s_op}, - default_operator_type="ray", - ) - - config = DockerDeploymentConfig( - image="python:3.11", - container_name="test-k8s", - operator_type="k8s", - ) - result = await composite.submit(config, {"user_id": "u1"}) - - k8s_op.submit.assert_awaited_once_with(config, {"user_id": "u1"}) - ray_op.submit.assert_not_awaited() - assert result["operator_type"] == "k8s" - @pytest.mark.asyncio - async def test_submit_uses_default_when_no_operator_type(self): - """When config.operator_type is None, submit should use the default operator.""" - ray_op = _make_mock_operator("ray") - k8s_op = _make_mock_operator("k8s") +def test_composite_init_with_valid_operators(): + """CompositeOperator should initialize correctly with valid operators.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + assert composite._default_operator_type == "ray" + assert len(composite._operators) == 2 - ray_op.submit.return_value = _make_sandbox_info(sandbox_id="ray-sandbox") - composite = CompositeOperator( - operators={"ray": ray_op, "k8s": k8s_op}, - default_operator_type="ray", - ) +def test_composite_init_with_empty_operators_raises(): + """CompositeOperator should raise ValueError with empty operators dict.""" + with pytest.raises(ValueError, match="At least one operator"): + CompositeOperator(operators={}, default_operator_type="ray") - config = DockerDeploymentConfig( - image="python:3.11", - container_name="test-default", - operator_type=None, - ) - result = await composite.submit(config, {}) - ray_op.submit.assert_awaited_once() - k8s_op.submit.assert_not_awaited() - assert result["operator_type"] == "ray" +def test_composite_init_with_invalid_default_raises(): + """CompositeOperator should raise ValueError when default type is not in operators.""" + ray_op = _make_mock_operator("ray") + with pytest.raises(ValueError, match="not found in provided operators"): + CompositeOperator(operators={"ray": ray_op}, default_operator_type="k8s") - @pytest.mark.asyncio - async def test_submit_sets_operator_type_in_sandbox_info(self): - """submit() must write operator_type into the returned SandboxInfo.""" - ray_op = _make_mock_operator("ray") - ray_op.submit.return_value = _make_sandbox_info() - composite = CompositeOperator( - operators={"ray": ray_op}, - default_operator_type="ray", - ) +def test_composite_init_normalizes_default_type(): + """CompositeOperator should normalize default_operator_type to lowercase.""" + ray_op = _make_mock_operator("ray") + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="RAY", + ) + assert composite._default_operator_type == "ray" + + +def test_set_redis_provider_propagates_to_all(): + """set_redis_provider should propagate to all sub-operators.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + + mock_redis = MagicMock(spec=RedisProvider) + composite.set_redis_provider(mock_redis) + + ray_op.set_redis_provider.assert_called_once_with(mock_redis) + k8s_op.set_redis_provider.assert_called_once_with(mock_redis) + assert composite._redis_provider is mock_redis + + +@pytest.mark.asyncio +async def test_submit_routes_to_specified_operator(): + """When config.operator_type is set, submit should route to that operator.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + + ray_op.submit.return_value = _make_sandbox_info(sandbox_id="ray-sandbox") + k8s_op.submit.return_value = _make_sandbox_info(sandbox_id="k8s-sandbox") - config = DockerDeploymentConfig(image="python:3.11", container_name="test") - result = await composite.submit(config, {}) + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) - assert "operator_type" in result - assert result["operator_type"] == "ray" + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test-k8s", + operator_type="k8s", + ) + result = await composite.submit(config, {"user_id": "u1"}) + + k8s_op.submit.assert_awaited_once_with(config, {"user_id": "u1"}) + ray_op.submit.assert_not_awaited() + assert result["operator_type"] == "k8s" - @pytest.mark.asyncio - async def test_submit_with_unsupported_operator_type_raises(self): - """submit() should raise ValueError for unsupported operator_type.""" - ray_op = _make_mock_operator("ray") - composite = CompositeOperator( - operators={"ray": ray_op}, - default_operator_type="ray", - ) - config = DockerDeploymentConfig( - image="python:3.11", - container_name="test", - operator_type="docker_swarm", - ) - with pytest.raises(ValueError, match="Unsupported operator type"): - await composite.submit(config, {}) +@pytest.mark.asyncio +async def test_submit_uses_default_when_no_operator_type(): + """When config.operator_type is None, submit should use the default operator.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + ray_op.submit.return_value = _make_sandbox_info(sandbox_id="ray-sandbox") + + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test-default", + operator_type=None, + ) + result = await composite.submit(config, {}) -# =========================================================================== -# 5. CompositeOperator.get_status routing via Redis -# =========================================================================== + ray_op.submit.assert_awaited_once() + k8s_op.submit.assert_not_awaited() + assert result["operator_type"] == "ray" + + +@pytest.mark.asyncio +async def test_submit_sets_operator_type_in_sandbox_info(): + """submit() must write operator_type into the returned SandboxInfo.""" + ray_op = _make_mock_operator("ray") + ray_op.submit.return_value = _make_sandbox_info() + + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) + + config = DockerDeploymentConfig(image="python:3.11", container_name="test") + result = await composite.submit(config, {}) + assert "operator_type" in result + assert result["operator_type"] == "ray" -class TestCompositeOperatorGetStatus: - """Test get_status() routes based on operator_type stored in Redis.""" - @pytest.mark.asyncio - async def test_get_status_routes_by_redis_operator_type(self, fake_redis_provider): - """get_status should look up operator_type from Redis and route accordingly.""" - ray_op = _make_mock_operator("ray") - k8s_op = _make_mock_operator("k8s") +@pytest.mark.asyncio +async def test_submit_with_unsupported_operator_type_raises(): + """submit() should raise ValueError for unsupported operator_type.""" + ray_op = _make_mock_operator("ray") + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) + + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test", + operator_type="docker_swarm", + ) + with pytest.raises(ValueError, match="Unsupported operator type"): + await composite.submit(config, {}) - k8s_status = _make_sandbox_info(sandbox_id="sandbox-1", state=State.RUNNING) - k8s_op.get_status.return_value = k8s_status - composite = CompositeOperator( - operators={"ray": ray_op, "k8s": k8s_op}, - default_operator_type="ray", - ) - composite.set_redis_provider(fake_redis_provider) +@pytest.mark.asyncio +async def test_get_status_routes_by_redis_operator_type(fake_redis_provider): + """get_status should look up operator_type from Redis and route accordingly.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") - # Pre-populate Redis with sandbox info that has operator_type="k8s" - sandbox_info_in_redis = _make_sandbox_info(sandbox_id="sandbox-1", operator_type="k8s") - await fake_redis_provider.json_set(alive_sandbox_key("sandbox-1"), "$", sandbox_info_in_redis) + k8s_status = _make_sandbox_info(sandbox_id="sandbox-1", state=State.RUNNING) + k8s_op.get_status.return_value = k8s_status - await composite.get_status("sandbox-1") + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) - k8s_op.get_status.assert_awaited_once_with("sandbox-1") - ray_op.get_status.assert_not_awaited() + sandbox_info_in_redis = _make_sandbox_info(sandbox_id="sandbox-1", operator_type="k8s") + await fake_redis_provider.json_set(alive_sandbox_key("sandbox-1"), "$", sandbox_info_in_redis) - @pytest.mark.asyncio - async def test_get_status_falls_back_to_default_without_redis(self): - """Without Redis, get_status should fall back to the default operator.""" - ray_op = _make_mock_operator("ray") - ray_op.get_status.return_value = _make_sandbox_info(state=State.RUNNING) + await composite.get_status("sandbox-1") - composite = CompositeOperator( - operators={"ray": ray_op}, - default_operator_type="ray", - ) - # No redis provider set + k8s_op.get_status.assert_awaited_once_with("sandbox-1") + ray_op.get_status.assert_not_awaited() - await composite.get_status("sandbox-no-redis") - ray_op.get_status.assert_awaited_once_with("sandbox-no-redis") - @pytest.mark.asyncio - async def test_get_status_falls_back_when_redis_has_no_operator_type(self, fake_redis_provider): - """If Redis entry has no operator_type, fall back to default.""" - ray_op = _make_mock_operator("ray") - ray_op.get_status.return_value = _make_sandbox_info(state=State.RUNNING) +@pytest.mark.asyncio +async def test_get_status_falls_back_to_default_without_redis(): + """Without Redis, get_status should fall back to the default operator.""" + ray_op = _make_mock_operator("ray") + ray_op.get_status.return_value = _make_sandbox_info(state=State.RUNNING) - composite = CompositeOperator( - operators={"ray": ray_op}, - default_operator_type="ray", - ) - composite.set_redis_provider(fake_redis_provider) + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) - # Store sandbox info WITHOUT operator_type - sandbox_info_no_type = _make_sandbox_info(sandbox_id="sandbox-2") - await fake_redis_provider.json_set(alive_sandbox_key("sandbox-2"), "$", sandbox_info_no_type) + await composite.get_status("sandbox-no-redis") + ray_op.get_status.assert_awaited_once_with("sandbox-no-redis") - await composite.get_status("sandbox-2") - ray_op.get_status.assert_awaited_once_with("sandbox-2") +@pytest.mark.asyncio +async def test_get_status_falls_back_when_redis_has_no_operator_type(fake_redis_provider): + """If Redis entry has no operator_type, fall back to default.""" + ray_op = _make_mock_operator("ray") + ray_op.get_status.return_value = _make_sandbox_info(state=State.RUNNING) -# =========================================================================== -# 6. CompositeOperator.stop routing via Redis -# =========================================================================== + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) + sandbox_info_no_type = _make_sandbox_info(sandbox_id="sandbox-2") + await fake_redis_provider.json_set(alive_sandbox_key("sandbox-2"), "$", sandbox_info_no_type) -class TestCompositeOperatorStop: - """Test stop() routes based on operator_type stored in Redis.""" + await composite.get_status("sandbox-2") + ray_op.get_status.assert_awaited_once_with("sandbox-2") - @pytest.mark.asyncio - async def test_stop_routes_by_redis_operator_type(self, fake_redis_provider): - """stop should look up operator_type from Redis and route accordingly.""" - ray_op = _make_mock_operator("ray") - k8s_op = _make_mock_operator("k8s") - k8s_op.stop.return_value = True - composite = CompositeOperator( - operators={"ray": ray_op, "k8s": k8s_op}, - default_operator_type="ray", - ) - composite.set_redis_provider(fake_redis_provider) +@pytest.mark.asyncio +async def test_stop_routes_by_redis_operator_type(fake_redis_provider): + """stop should look up operator_type from Redis and route accordingly.""" + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + k8s_op.stop.return_value = True - sandbox_info_in_redis = _make_sandbox_info(sandbox_id="sandbox-stop", operator_type="k8s") - await fake_redis_provider.json_set(alive_sandbox_key("sandbox-stop"), "$", sandbox_info_in_redis) + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) - result = await composite.stop("sandbox-stop") + sandbox_info_in_redis = _make_sandbox_info(sandbox_id="sandbox-stop", operator_type="k8s") + await fake_redis_provider.json_set(alive_sandbox_key("sandbox-stop"), "$", sandbox_info_in_redis) - k8s_op.stop.assert_awaited_once_with("sandbox-stop") - ray_op.stop.assert_not_awaited() - assert result is True + result = await composite.stop("sandbox-stop") - @pytest.mark.asyncio - async def test_stop_falls_back_to_default_without_redis(self): - """Without Redis, stop should fall back to the default operator.""" - ray_op = _make_mock_operator("ray") - ray_op.stop.return_value = True + k8s_op.stop.assert_awaited_once_with("sandbox-stop") + ray_op.stop.assert_not_awaited() + assert result is True - composite = CompositeOperator( - operators={"ray": ray_op}, - default_operator_type="ray", - ) - await composite.stop("sandbox-no-redis") - ray_op.stop.assert_awaited_once_with("sandbox-no-redis") +@pytest.mark.asyncio +async def test_stop_falls_back_to_default_without_redis(): + """Without Redis, stop should fall back to the default operator.""" + ray_op = _make_mock_operator("ray") + ray_op.stop.return_value = True + composite = CompositeOperator( + operators={"ray": ray_op}, + default_operator_type="ray", + ) -# =========================================================================== -# 7. CRITICAL: get_status does NOT overwrite operator_type in Redis -# =========================================================================== + await composite.stop("sandbox-no-redis") + ray_op.stop.assert_awaited_once_with("sandbox-no-redis") -class TestGetStatusPreservesOperatorTypeInRedis: - """Critical tests: verify that the full start_async -> get_status flow - does NOT lose or overwrite the operator_type field in Redis. +@pytest.mark.asyncio +async def test_operator_type_survives_submit_and_get_status_cycle(fake_redis_provider): + """Critical: full cycle submit -> Redis write -> get_status -> Redis write + must preserve operator_type in Redis. - This simulates the SandboxManager flow: + Simulates the SandboxManager flow: 1. CompositeOperator.submit() sets operator_type in SandboxInfo 2. SandboxManager.start_async() writes SandboxInfo to Redis 3. SandboxManager.get_status() calls operator.get_status() and writes back 4. operator_type must still be present in Redis after step 3 """ + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") + + k8s_op.submit.return_value = _make_sandbox_info(sandbox_id="cycle-test") + + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) + + config = DockerDeploymentConfig( + image="python:3.11", + container_name="cycle-test", + operator_type="k8s", + ) + sandbox_info = await composite.submit(config, {}) + assert sandbox_info["operator_type"] == "k8s" + + await fake_redis_provider.json_set(alive_sandbox_key("cycle-test"), "$", sandbox_info) + + redis_data = await fake_redis_provider.json_get(alive_sandbox_key("cycle-test"), "$") + assert redis_data[0]["operator_type"] == "k8s" + + k8s_op.get_status.return_value = _make_sandbox_info( + sandbox_id="cycle-test", + state=State.RUNNING, + ) + status_info = await composite.get_status("cycle-test") + + await fake_redis_provider.json_set(alive_sandbox_key("cycle-test"), "$", status_info) + + redis_data_after = await fake_redis_provider.json_get(alive_sandbox_key("cycle-test"), "$") + has_operator_type = "operator_type" in redis_data_after[0] + + if not has_operator_type: + pytest.fail( + "operator_type was lost from Redis after get_status! " + "The sub-operator's get_status() did not include operator_type, " + "and SandboxManager overwrote Redis with the incomplete data." + ) - @pytest.mark.asyncio - async def test_operator_type_survives_submit_and_get_status_cycle(self, fake_redis_provider): - """Full cycle: submit writes operator_type, get_status preserves it.""" - ray_op = _make_mock_operator("ray") - k8s_op = _make_mock_operator("k8s") - - # submit returns sandbox info (CompositeOperator will add operator_type) - k8s_op.submit.return_value = _make_sandbox_info(sandbox_id="cycle-test") - composite = CompositeOperator( - operators={"ray": ray_op, "k8s": k8s_op}, - default_operator_type="ray", - ) - composite.set_redis_provider(fake_redis_provider) +@pytest.mark.asyncio +async def test_ray_operator_get_status_preserves_operator_type_via_redis_merge(): + """Simulate RayOperator.get_status() redis merge path. - # Step 1: submit (simulates CompositeOperator.submit) - config = DockerDeploymentConfig( - image="python:3.11", - container_name="cycle-test", - operator_type="k8s", - ) - sandbox_info = await composite.submit(config, {}) - assert sandbox_info["operator_type"] == "k8s" + RayOperator.get_status() (non-rocklet path) does: + redis_info = await self.get_sandbox_info_from_redis(sandbox_id) + if redis_info: + redis_info.update(sandbox_info) + return redis_info - # Step 2: simulate SandboxManager writing to Redis - await fake_redis_provider.json_set(alive_sandbox_key("cycle-test"), "$", sandbox_info) + The merge preserves operator_type because actor_sandbox_info doesn't contain it. + """ + redis_sandbox_info = _make_sandbox_info( + sandbox_id="ray-merge-test", + operator_type="ray", + user_id="user-1", + experiment_id="exp-1", + ) - # Verify operator_type is in Redis - redis_data = await fake_redis_provider.json_get(alive_sandbox_key("cycle-test"), "$") - assert redis_data[0]["operator_type"] == "k8s" + actor_sandbox_info = _make_sandbox_info( + sandbox_id="ray-merge-test", + state=State.RUNNING, + ) + assert "operator_type" not in actor_sandbox_info - # Step 3: get_status returns info WITHOUT operator_type (like real sub-operators do) - k8s_op.get_status.return_value = _make_sandbox_info( - sandbox_id="cycle-test", - state=State.RUNNING, - # Note: no operator_type here, simulating real K8sOperator/RayOperator behavior - ) - status_info = await composite.get_status("cycle-test") - - # Step 4: simulate SandboxManager writing get_status result back to Redis - # This is what SandboxManager.get_status() does: - # sandbox_info = await self._operator.get_status(sandbox_id) - # await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) - await fake_redis_provider.json_set(alive_sandbox_key("cycle-test"), "$", status_info) - - # CRITICAL CHECK: operator_type must still be in Redis - # The sub-operator's get_status doesn't return operator_type, - # so if the full sandbox_info is overwritten, operator_type would be lost. - redis_data_after = await fake_redis_provider.json_get(alive_sandbox_key("cycle-test"), "$") - # This test verifies the CURRENT behavior. If operator_type is missing here, - # it means get_status overwrites it and we have a bug. - # - # In the current design, sub-operators (RayOperator, K8sOperator) merge - # redis_info with fresh status via redis_info.update(sandbox_info), which - # preserves operator_type because the fresh status dict doesn't contain it. - # However, SandboxManager.get_status() does a full json_set with the - # returned sandbox_info. If the sub-operator returns a dict without - # operator_type, it WILL be lost. - # - # Let's check what actually happens: - has_operator_type = "operator_type" in redis_data_after[0] - - if not has_operator_type: - # This means the current flow DOES lose operator_type. - # We need to verify this is the case and document it. - pytest.fail( - "operator_type was lost from Redis after get_status! " - "The sub-operator's get_status() did not include operator_type, " - "and SandboxManager overwrote Redis with the incomplete data." - ) - - @pytest.mark.asyncio - async def test_ray_operator_get_status_preserves_operator_type_via_redis_merge(self, fake_redis_provider): - """Simulate RayOperator.get_status() redis merge path. - - RayOperator.get_status() (non-rocklet path) does: - redis_info = await self.get_sandbox_info_from_redis(sandbox_id) - if redis_info: - redis_info.update(sandbox_info) # sandbox_info has no operator_type - return redis_info # redis_info still has operator_type - - This test verifies that the merge preserves operator_type. - """ - # Simulate what's in Redis (with operator_type) - redis_sandbox_info = _make_sandbox_info( - sandbox_id="ray-merge-test", - operator_type="ray", - user_id="user-1", - experiment_id="exp-1", - ) + redis_sandbox_info.update(actor_sandbox_info) - # Simulate what RayOperator gets from the actor (no operator_type) - actor_sandbox_info = _make_sandbox_info( - sandbox_id="ray-merge-test", - state=State.RUNNING, - ) - # Actor info typically doesn't have operator_type - assert "operator_type" not in actor_sandbox_info - - # Simulate the merge: redis_info.update(sandbox_info) - redis_sandbox_info.update(actor_sandbox_info) - - # operator_type should survive because actor_sandbox_info doesn't have it - assert redis_sandbox_info.get("operator_type") == "ray" - assert redis_sandbox_info.get("state") == State.RUNNING - - @pytest.mark.asyncio - async def test_k8s_operator_get_status_preserves_operator_type_via_redis_merge(self, fake_redis_provider): - """Simulate K8sOperator.get_status() redis merge path. - - K8sOperator.get_status() does: - sandbox_info = await self._provider.get_status(sandbox_id) - if self._redis_provider: - redis_info = await self._get_sandbox_info_from_redis(sandbox_id) - if redis_info: - redis_info.update(sandbox_info) - return redis_info - - This test verifies that the merge preserves operator_type. - """ - # Simulate what's in Redis (with operator_type) - redis_sandbox_info = _make_sandbox_info( - sandbox_id="k8s-merge-test", - operator_type="k8s", - user_id="user-1", - ) + assert redis_sandbox_info.get("operator_type") == "ray" + assert redis_sandbox_info.get("state") == State.RUNNING - # Simulate what K8s provider returns (no operator_type) - provider_sandbox_info: SandboxInfo = { - "sandbox_id": "k8s-merge-test", - "host_ip": "10.0.0.2", - "state": State.RUNNING, - "phases": {}, - "port_mapping": {8000: 30001}, - } - assert "operator_type" not in provider_sandbox_info - - # Simulate the merge: redis_info.update(sandbox_info) - redis_sandbox_info.update(provider_sandbox_info) - - # operator_type should survive - assert redis_sandbox_info.get("operator_type") == "k8s" - assert redis_sandbox_info.get("state") == State.RUNNING - assert redis_sandbox_info.get("host_ip") == "10.0.0.2" - - @pytest.mark.asyncio - async def test_sandbox_manager_get_status_preserves_operator_type(self, fake_redis_provider): - """End-to-end: SandboxManager.get_status() must preserve operator_type in Redis. - - SandboxManager.get_status() does: - sandbox_info = await self._operator.get_status(sandbox_id) - await self._redis_provider.json_set(alive_sandbox_key(sandbox_id), "$", sandbox_info) - - If the operator returns sandbox_info WITH operator_type (because sub-operators - merge from Redis), then the json_set will preserve it. - """ - ray_op = _make_mock_operator("ray") - k8s_op = _make_mock_operator("k8s") - - composite = CompositeOperator( - operators={"ray": ray_op, "k8s": k8s_op}, - default_operator_type="ray", - ) - composite.set_redis_provider(fake_redis_provider) - # Pre-populate Redis with sandbox info including operator_type - initial_info = _make_sandbox_info( - sandbox_id="e2e-test", - operator_type="k8s", - user_id="user-1", - ) - await fake_redis_provider.json_set(alive_sandbox_key("e2e-test"), "$", initial_info) - - # K8sOperator.get_status() merges redis_info with provider status, - # so the returned dict should still contain operator_type - merged_status = _make_sandbox_info( - sandbox_id="e2e-test", - operator_type="k8s", # preserved from Redis merge - user_id="user-1", - state=State.RUNNING, - host_ip="10.0.0.5", - ) - k8s_op.get_status.return_value = merged_status +@pytest.mark.asyncio +async def test_k8s_operator_get_status_preserves_operator_type_via_redis_merge(): + """Simulate K8sOperator.get_status() redis merge path. - # CompositeOperator.get_status routes to k8s_op - result = await composite.get_status("e2e-test") + K8sOperator.get_status() does: + sandbox_info = await self._provider.get_status(sandbox_id) + if self._redis_provider: + redis_info = await self._get_sandbox_info_from_redis(sandbox_id) + if redis_info: + redis_info.update(sandbox_info) + return redis_info - # Simulate SandboxManager writing back to Redis - await fake_redis_provider.json_set(alive_sandbox_key("e2e-test"), "$", result) + The merge preserves operator_type because provider_sandbox_info doesn't contain it. + """ + redis_sandbox_info = _make_sandbox_info( + sandbox_id="k8s-merge-test", + operator_type="k8s", + user_id="user-1", + ) + + provider_sandbox_info: SandboxInfo = { + "sandbox_id": "k8s-merge-test", + "host_ip": "10.0.0.2", + "state": State.RUNNING, + "phases": {}, + "port_mapping": {8000: 30001}, + } + assert "operator_type" not in provider_sandbox_info - # Verify operator_type is preserved - final_redis = await fake_redis_provider.json_get(alive_sandbox_key("e2e-test"), "$") - assert final_redis[0]["operator_type"] == "k8s" - assert final_redis[0]["state"] == State.RUNNING + redis_sandbox_info.update(provider_sandbox_info) + assert redis_sandbox_info.get("operator_type") == "k8s" + assert redis_sandbox_info.get("state") == State.RUNNING + assert redis_sandbox_info.get("host_ip") == "10.0.0.2" -# =========================================================================== -# 8. OperatorFactory.create_composite_operator -# =========================================================================== +@pytest.mark.asyncio +async def test_sandbox_manager_get_status_preserves_operator_type(fake_redis_provider): + """End-to-end: SandboxManager.get_status() must preserve operator_type in Redis. -class TestOperatorFactoryCreateComposite: - """Test OperatorFactory.create_composite_operator method.""" + If the operator returns sandbox_info WITH operator_type (because sub-operators + merge from Redis), then the json_set will preserve it. + """ + ray_op = _make_mock_operator("ray") + k8s_op = _make_mock_operator("k8s") - def test_create_composite_operator_single_type(self): - """create_composite_operator with a single operator type.""" - from rock.sandbox.operator.factory import OperatorContext, OperatorFactory + composite = CompositeOperator( + operators={"ray": ray_op, "k8s": k8s_op}, + default_operator_type="ray", + ) + composite.set_redis_provider(fake_redis_provider) - runtime_config = RuntimeConfig(operator_types=["ray"]) - ray_service = MagicMock() + initial_info = _make_sandbox_info( + sandbox_id="e2e-test", + operator_type="k8s", + user_id="user-1", + ) + await fake_redis_provider.json_set(alive_sandbox_key("e2e-test"), "$", initial_info) - context = OperatorContext( - runtime_config=runtime_config, - ray_service=ray_service, - ) + merged_status = _make_sandbox_info( + sandbox_id="e2e-test", + operator_type="k8s", + user_id="user-1", + state=State.RUNNING, + host_ip="10.0.0.5", + ) + k8s_op.get_status.return_value = merged_status - with patch("rock.sandbox.operator.factory.OperatorFactory._create_single_operator") as mock_create: - mock_ray_op = _make_mock_operator("ray") - mock_create.return_value = mock_ray_op + result = await composite.get_status("e2e-test") - composite = OperatorFactory.create_composite_operator(context) + await fake_redis_provider.json_set(alive_sandbox_key("e2e-test"), "$", result) - assert isinstance(composite, CompositeOperator) - assert composite._default_operator_type == "ray" - assert "ray" in composite._operators + final_redis = await fake_redis_provider.json_get(alive_sandbox_key("e2e-test"), "$") + assert final_redis[0]["operator_type"] == "k8s" + assert final_redis[0]["state"] == State.RUNNING - def test_create_composite_operator_multiple_types(self): - """create_composite_operator with multiple operator types.""" - from rock.sandbox.operator.factory import OperatorContext, OperatorFactory - runtime_config = RuntimeConfig(operator_types=["ray", "k8s"]) - ray_service = MagicMock() +def test_create_composite_operator_single_type(): + """create_composite_operator with a single operator type.""" + from rock.sandbox.operator.factory import OperatorContext, OperatorFactory - context = OperatorContext( - runtime_config=runtime_config, - ray_service=ray_service, - ) + runtime_config = RuntimeConfig(operator_types=["ray"]) + ray_service = MagicMock() - call_count = 0 + context = OperatorContext( + runtime_config=runtime_config, + ray_service=ray_service, + ) - def side_effect(op_type, ctx): - nonlocal call_count - call_count += 1 - return _make_mock_operator(op_type) + with patch("rock.sandbox.operator.factory.OperatorFactory._create_single_operator") as mock_create: + mock_ray_op = _make_mock_operator("ray") + mock_create.return_value = mock_ray_op - with patch( - "rock.sandbox.operator.factory.OperatorFactory._create_single_operator", - side_effect=side_effect, - ): - composite = OperatorFactory.create_composite_operator(context) + composite = OperatorFactory.create_composite_operator(context) - assert isinstance(composite, CompositeOperator) - assert composite._default_operator_type == "ray" - assert "ray" in composite._operators - assert "k8s" in composite._operators - assert call_count == 2 + assert isinstance(composite, CompositeOperator) + assert composite._default_operator_type == "ray" + assert "ray" in composite._operators -# =========================================================================== -# 9. SandboxStartRequest and DockerDeploymentConfig operator_type field -# =========================================================================== +def test_create_composite_operator_multiple_types(): + """create_composite_operator with multiple operator types.""" + from rock.sandbox.operator.factory import OperatorContext, OperatorFactory + runtime_config = RuntimeConfig(operator_types=["ray", "k8s"]) + ray_service = MagicMock() -class TestOperatorTypeFieldInModels: - """Test that operator_type field exists and works in request/config models.""" + context = OperatorContext( + runtime_config=runtime_config, + ray_service=ray_service, + ) - def test_sandbox_start_request_has_operator_type(self): - request = SandboxStartRequest( - image="python:3.11", - operator_type="k8s", - ) - assert request.operator_type == "k8s" + call_count = 0 - def test_sandbox_start_request_operator_type_default_none(self): - request = SandboxStartRequest(image="python:3.11") - assert request.operator_type is None + def side_effect(op_type, ctx): + nonlocal call_count + call_count += 1 + return _make_mock_operator(op_type) - def test_docker_deployment_config_has_operator_type(self): - config = DockerDeploymentConfig( - image="python:3.11", - container_name="test", - operator_type="ray", - ) - assert config.operator_type == "ray" + with patch( + "rock.sandbox.operator.factory.OperatorFactory._create_single_operator", + side_effect=side_effect, + ): + composite = OperatorFactory.create_composite_operator(context) - def test_docker_deployment_config_operator_type_default_none(self): - config = DockerDeploymentConfig( - image="python:3.11", - container_name="test", - ) - assert config.operator_type is None - - def test_docker_deployment_config_from_request_preserves_operator_type(self): - """DockerDeploymentConfig.from_request should carry over operator_type.""" - request = SandboxStartRequest( - image="python:3.11", - sandbox_id="test-sandbox", - operator_type="k8s", - ) - config = DockerDeploymentConfig.from_request(request) - assert config.operator_type == "k8s" - assert config.container_name == "test-sandbox" + assert isinstance(composite, CompositeOperator) + assert composite._default_operator_type == "ray" + assert "ray" in composite._operators + assert "k8s" in composite._operators + assert call_count == 2 + + +def test_sandbox_start_request_has_operator_type(): + """SandboxStartRequest should accept and store operator_type.""" + request = SandboxStartRequest( + image="python:3.11", + operator_type="k8s", + ) + assert request.operator_type == "k8s" + + +def test_sandbox_start_request_operator_type_default_none(): + """SandboxStartRequest.operator_type should default to None.""" + request = SandboxStartRequest(image="python:3.11") + assert request.operator_type is None + + +def test_docker_deployment_config_has_operator_type(): + """DockerDeploymentConfig should accept and store operator_type.""" + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test", + operator_type="ray", + ) + assert config.operator_type == "ray" + + +def test_docker_deployment_config_operator_type_default_none(): + """DockerDeploymentConfig.operator_type should default to None.""" + config = DockerDeploymentConfig( + image="python:3.11", + container_name="test", + ) + assert config.operator_type is None + + +def test_docker_deployment_config_from_request_preserves_operator_type(): + """DockerDeploymentConfig.from_request should carry over operator_type.""" + request = SandboxStartRequest( + image="python:3.11", + sandbox_id="test-sandbox", + operator_type="k8s", + ) + config = DockerDeploymentConfig.from_request(request) + assert config.operator_type == "k8s" + assert config.container_name == "test-sandbox" diff --git a/tests/unit/test_sdk_operator_type.py b/tests/unit/test_sdk_operator_type.py deleted file mode 100644 index 359ed4118..000000000 --- a/tests/unit/test_sdk_operator_type.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Unit tests for SDK operator_type support. - -Covers: -- SandboxConfig.operator_type field (default, explicit value) -- Sandbox.__init__ stores operator_type from config -- Sandbox.start() includes operator_type in the request payload -- SandboxGroup propagates operator_type to all child sandboxes -""" - -from unittest.mock import AsyncMock, patch - -import pytest - -from rock.sdk.sandbox.client import Sandbox, SandboxGroup -from rock.sdk.sandbox.config import SandboxConfig, SandboxGroupConfig - -# =========================================================================== -# 1. SandboxConfig operator_type field -# =========================================================================== - - -class TestSandboxConfigOperatorType: - """Test SandboxConfig.operator_type field behavior.""" - - def test_operator_type_default_is_ray(self): - """operator_type should default to 'ray' when not specified.""" - config = SandboxConfig() - assert config.operator_type == "ray" - - def test_operator_type_set_explicitly(self): - """operator_type should be stored when explicitly set.""" - config = SandboxConfig(operator_type="k8s") - assert config.operator_type == "k8s" - - def test_operator_type_ray(self): - """operator_type should accept 'ray' value.""" - config = SandboxConfig(operator_type="ray") - assert config.operator_type == "ray" - - def test_operator_type_with_other_fields(self): - """operator_type should coexist with other config fields.""" - config = SandboxConfig( - image="ubuntu:22.04", - memory="16g", - cpus=4, - cluster="us-east", - operator_type="k8s", - ) - assert config.operator_type == "k8s" - assert config.image == "ubuntu:22.04" - assert config.memory == "16g" - assert config.cpus == 4 - assert config.cluster == "us-east" - - def test_operator_type_serialization(self): - """operator_type should appear in model_dump output.""" - config = SandboxConfig(operator_type="ray") - dumped = config.model_dump() - assert "operator_type" in dumped - assert dumped["operator_type"] == "ray" - - def test_operator_type_default_serialization(self): - """Default operator_type='ray' should appear in model_dump output.""" - config = SandboxConfig() - dumped = config.model_dump() - assert "operator_type" in dumped - assert dumped["operator_type"] == "ray" - - -# =========================================================================== -# 2. Sandbox.__init__ with operator_type -# =========================================================================== - - -class TestSandboxInitOperatorType: - """Test that Sandbox stores operator_type from SandboxConfig.""" - - def test_sandbox_stores_operator_type_from_config(self): - """Sandbox should store the config with operator_type.""" - config = SandboxConfig(operator_type="k8s") - sandbox = Sandbox(config) - assert sandbox.config.operator_type == "k8s" - - def test_sandbox_stores_default_operator_type(self): - """Sandbox should store default operator_type='ray' when not specified.""" - config = SandboxConfig() - sandbox = Sandbox(config) - assert sandbox.config.operator_type == "ray" - - -# =========================================================================== -# 3. Sandbox.start() includes operator_type in request -# =========================================================================== - - -class TestSandboxStartOperatorType: - """Test that Sandbox.start() sends operator_type in the POST payload.""" - - @pytest.mark.asyncio - async def test_start_sends_operator_type_in_payload(self): - """start() should include operator_type in the request data.""" - config = SandboxConfig(operator_type="k8s", startup_timeout=5) - sandbox = Sandbox(config) - - mock_response = { - "status": "Success", - "result": { - "sandbox_id": "test-sandbox-001", - "host_name": "test-host", - "host_ip": "10.0.0.1", - }, - } - - # Mock get_status to return alive immediately - mock_status = AsyncMock() - mock_status.is_alive = True - - with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( - sandbox, "get_status", return_value=mock_status - ): - mock_post.return_value = mock_response - await sandbox.start() - - # Verify the POST was called with operator_type in data - mock_post.assert_called_once() - call_args = mock_post.call_args - posted_data = call_args[0][2] # third positional arg is data - assert "operator_type" in posted_data - assert posted_data["operator_type"] == "k8s" - - @pytest.mark.asyncio - async def test_start_sends_default_operator_type_when_not_set(self): - """start() should send operator_type='ray' when using default config.""" - config = SandboxConfig(startup_timeout=5) - sandbox = Sandbox(config) - - mock_response = { - "status": "Success", - "result": { - "sandbox_id": "test-sandbox-002", - "host_name": "test-host", - "host_ip": "10.0.0.1", - }, - } - - mock_status = AsyncMock() - mock_status.is_alive = True - - with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( - sandbox, "get_status", return_value=mock_status - ): - mock_post.return_value = mock_response - await sandbox.start() - - call_args = mock_post.call_args - posted_data = call_args[0][2] - assert "operator_type" in posted_data - assert posted_data["operator_type"] == "ray" - - @pytest.mark.asyncio - async def test_start_sends_ray_operator_type(self): - """start() should correctly send operator_type='ray'.""" - config = SandboxConfig(operator_type="ray", startup_timeout=5) - sandbox = Sandbox(config) - - mock_response = { - "status": "Success", - "result": { - "sandbox_id": "test-sandbox-003", - "host_name": "test-host", - "host_ip": "10.0.0.1", - }, - } - - mock_status = AsyncMock() - mock_status.is_alive = True - - with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( - sandbox, "get_status", return_value=mock_status - ): - mock_post.return_value = mock_response - await sandbox.start() - - call_args = mock_post.call_args - posted_data = call_args[0][2] - assert posted_data["operator_type"] == "ray" - - @pytest.mark.asyncio - async def test_start_payload_contains_all_expected_fields(self): - """start() payload should contain operator_type alongside all other fields.""" - config = SandboxConfig( - image="ubuntu:22.04", - memory="16g", - cpus=4, - operator_type="k8s", - startup_timeout=5, - ) - sandbox = Sandbox(config) - - mock_response = { - "status": "Success", - "result": { - "sandbox_id": "test-sandbox-004", - "host_name": "test-host", - "host_ip": "10.0.0.1", - }, - } - - mock_status = AsyncMock() - mock_status.is_alive = True - - with patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock) as mock_post, patch.object( - sandbox, "get_status", return_value=mock_status - ): - mock_post.return_value = mock_response - await sandbox.start() - - call_args = mock_post.call_args - posted_data = call_args[0][2] - - # Verify all expected fields are present - assert posted_data["image"] == "ubuntu:22.04" - assert posted_data["memory"] == "16g" - assert posted_data["cpus"] == 4 - assert posted_data["operator_type"] == "k8s" - assert "use_kata_runtime" in posted_data - assert "registry_username" in posted_data - assert "registry_password" in posted_data - - -# =========================================================================== -# 4. SandboxGroup propagates operator_type -# =========================================================================== - - -class TestSandboxGroupOperatorType: - """Test that SandboxGroup propagates operator_type to child sandboxes.""" - - def test_group_propagates_operator_type_to_children(self): - """All sandboxes in a group should inherit operator_type from config.""" - config = SandboxGroupConfig( - size=3, - operator_type="k8s", - ) - group = SandboxGroup(config) - - assert len(group.sandbox_list) == 3 - for sandbox in group.sandbox_list: - assert sandbox.config.operator_type == "k8s" - - def test_group_propagates_default_operator_type(self): - """All sandboxes in a group should have default operator_type='ray' when not set.""" - config = SandboxGroupConfig(size=2) - group = SandboxGroup(config) - - for sandbox in group.sandbox_list: - assert sandbox.config.operator_type == "ray" - - -# =========================================================================== -# 5. SandboxGroupConfig inherits operator_type -# =========================================================================== - - -class TestSandboxGroupConfigOperatorType: - """Test SandboxGroupConfig inherits operator_type from SandboxConfig.""" - - def test_group_config_has_operator_type(self): - """SandboxGroupConfig should support operator_type field.""" - config = SandboxGroupConfig(operator_type="ray", size=2) - assert config.operator_type == "ray" - - def test_group_config_operator_type_default_ray(self): - """SandboxGroupConfig.operator_type should default to 'ray'.""" - config = SandboxGroupConfig(size=2) - assert config.operator_type == "ray"