From e164ec9ffb768c99330645a982db4570238bdd3d Mon Sep 17 00:00:00 2001 From: "kezhong.ykz" Date: Wed, 4 Mar 2026 16:59:20 +0800 Subject: [PATCH 01/11] docs: fix zh release note index.md (#573) --- .../version-1.3.x/Release Notes/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/version-1.3.x/Release Notes/index.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/version-1.3.x/Release Notes/index.md index 474503cad..56dfec168 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/version-1.3.x/Release Notes/index.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/version-1.3.x/Release Notes/index.md @@ -2,4 +2,4 @@ sidebar_position: 1 --- # 版本说明 -* [release v1.2.0](v1.2.0.md) \ No newline at end of file +* [release v1.3.0](v1.3.0.md) \ No newline at end of file From c364bcf77fcf053f3fd60adc407ec00ed611e0a7 Mon Sep 17 00:00:00 2001 From: Fangwen DAI <34794181+FangwenDave@users.noreply.github.com> Date: Thu, 5 Mar 2026 10:17:03 +0800 Subject: [PATCH 02/11] feat: support create standard spec sandbox #538 (#539) (#571) * support create standard_spec sandbox * rename enforce_standard_spec to use_standard_spec_only --- rock/config.py | 1 + rock/sandbox/sandbox_manager.py | 8 +++ tests/unit/sandbox/test_sandbox_manager.py | 82 ++++++++++++++++++++++ 3 files changed, 91 insertions(+) diff --git a/rock/config.py b/rock/config.py index 6066ad80d..4a28e0ccd 100644 --- a/rock/config.py +++ b/rock/config.py @@ -130,6 +130,7 @@ class RuntimeConfig: operator_type: str = "ray" standard_spec: StandardSpec = field(default_factory=StandardSpec) max_allowed_spec: StandardSpec = field(default_factory=lambda: StandardSpec(cpus=16, memory="64g")) + use_standard_spec_only: bool = False metrics_endpoint: str = "" user_defined_tags: dict = field(default_factory=dict) diff --git a/rock/sandbox/sandbox_manager.py b/rock/sandbox/sandbox_manager.py index 817f9126f..ba8417523 100644 --- 
a/rock/sandbox/sandbox_manager.py +++ b/rock/sandbox/sandbox_manager.py @@ -117,6 +117,14 @@ async def start_async( docker_deployment_config: DockerDeploymentConfig = await self.deployment_manager.init_config(config) sandbox_id = docker_deployment_config.container_name + if self.rock_config.runtime.use_standard_spec_only: + logger.info( + f"[{sandbox_id}] Using standard spec only: " + f"cpus={self.rock_config.runtime.standard_spec.cpus}, " + f"memory={self.rock_config.runtime.standard_spec.memory}" + ) + docker_deployment_config.cpus = self.rock_config.runtime.standard_spec.cpus + docker_deployment_config.memory = self.rock_config.runtime.standard_spec.memory sandbox_info: SandboxInfo = await self._operator.submit(docker_deployment_config, user_info) stop_time = str(int(time.time()) + docker_deployment_config.auto_clear_time * 60) auto_clear_time_dict = { diff --git a/tests/unit/sandbox/test_sandbox_manager.py b/tests/unit/sandbox/test_sandbox_manager.py index 8640f53f3..fb69a486f 100644 --- a/tests/unit/sandbox/test_sandbox_manager.py +++ b/tests/unit/sandbox/test_sandbox_manager.py @@ -189,3 +189,85 @@ async def test_get_actor_not_exist_raises_value_error(sandbox_manager): actor_name = sandbox_manager.deployment_manager.get_actor_name(sandbox_id) await sandbox_manager._ray_service.async_ray_get_actor(actor_name) assert exc_info.type == ValueError + + +@pytest.mark.need_ray +@pytest.mark.asyncio +async def test_use_standard_spec_only(sandbox_manager): + """Test that use_standard_spec_only forces sandbox to use standard spec.""" + # Enable use_standard_spec_only + sandbox_manager.rock_config.runtime.use_standard_spec_only = True + sandbox_manager.rock_config.runtime.standard_spec.cpus = 1 + sandbox_manager.rock_config.runtime.standard_spec.memory = "2g" + + # Try to create sandbox with different specs + config = DockerDeploymentConfig(cpus=4, memory="16g") + + response = await sandbox_manager.start_async(config) + sandbox_id = response.sandbox_id + + try: + # Wait 
for sandbox to be alive + await check_sandbox_status_until_alive(sandbox_manager, sandbox_id) + + # Get the actual deployment config used + actor_name = sandbox_manager.deployment_manager.get_actor_name(sandbox_id) + sandbox_actor = await sandbox_manager._ray_service.async_ray_get_actor(actor_name) + + # Get sandbox info to verify the actual specs + sandbox_info = await sandbox_manager._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote()) + + # Verify that standard spec was applied (not the requested 4 CPUs and 16g) + assert sandbox_info["cpus"] == 1, f"Expected cpus=1, but got {sandbox_info['cpus']}" + assert sandbox_info["memory"] == "2g", f"Expected memory='2g', but got {sandbox_info['memory']}" + + # Also verify sandbox is alive + is_alive = await sandbox_manager._is_actor_alive(sandbox_id) + assert is_alive + + finally: + # Cleanup + await sandbox_manager.stop(sandbox_id) + # Reset the flag + sandbox_manager.rock_config.runtime.use_standard_spec_only = False + + +@pytest.mark.need_ray +@pytest.mark.asyncio +async def test_use_standard_spec_only_disabled(sandbox_manager): + """Test that sandbox uses requested spec when use_standard_spec_only is disabled.""" + # Ensure use_standard_spec_only is disabled + sandbox_manager.rock_config.runtime.use_standard_spec_only = False + + # Create sandbox with custom specs (within allowed limits) + requested_cpus = 1 + requested_memory = "4g" + config = DockerDeploymentConfig(cpus=requested_cpus, memory=requested_memory) + + response = await sandbox_manager.start_async(config) + sandbox_id = response.sandbox_id + + try: + # Wait for sandbox to be alive + await check_sandbox_status_until_alive(sandbox_manager, sandbox_id) + + # Get the actual deployment config used + actor_name = sandbox_manager.deployment_manager.get_actor_name(sandbox_id) + sandbox_actor = await sandbox_manager._ray_service.async_ray_get_actor(actor_name) + + # Get sandbox info to verify the actual specs + sandbox_info = await 
sandbox_manager._ray_service.async_ray_get(sandbox_actor.sandbox_info.remote()) + + # Verify that requested spec was used (not standard spec) + assert sandbox_info["cpus"] == requested_cpus, f"Expected cpus={requested_cpus}, but got {sandbox_info['cpus']}" + assert ( + sandbox_info["memory"] == requested_memory + ), f"Expected memory='{requested_memory}', but got {sandbox_info['memory']}" + + # Also verify sandbox is alive + is_alive = await sandbox_manager._is_actor_alive(sandbox_id) + assert is_alive + + finally: + # Cleanup + await sandbox_manager.stop(sandbox_id) From e35aad82dcb86013eae2ec5e24884092d6850fd0 Mon Sep 17 00:00:00 2001 From: guoj14 Date: Thu, 5 Mar 2026 18:23:25 +0800 Subject: [PATCH 03/11] feat: update the parameter passing in the CI trigger script and result retrieval script (#581) * feat: update the parameter passing in the CI trigger script and result retrieval script * feat: update the parameter passing in the CI trigger script and result retrieval script --- .github/scripts/get-CI-result.sh | 7 ++++--- .github/scripts/trigger-CI.sh | 1 + .github/scripts/trigger-docs.sh | 1 + .github/workflows/CI-request-trigger.yml | 2 +- 4 files changed, 7 insertions(+), 4 deletions(-) diff --git a/.github/scripts/get-CI-result.sh b/.github/scripts/get-CI-result.sh index d10b7301c..19ee8a0c0 100644 --- a/.github/scripts/get-CI-result.sh +++ b/.github/scripts/get-CI-result.sh @@ -1,13 +1,14 @@ #!/bin/bash # 检查参数 -if [ $# -ne 2 ]; then - echo "Usage: $0 " +if [ $# -ne 3 ]; then + echo "Usage: $0 " exit 1 fi COMMIT_ID=$1 SECURITY=$2 +REPOSITORY=$3 # 设置最大等待时间 MAX_WAIT_TIME=7200 @@ -18,7 +19,7 @@ while true; do response=$(curl -s -H "Content-Type: application/json" \ -H "Authorization: Basic ${SECURITY}" \ - -d "{\"type\": \"RETRIEVE-TASK-STATUS\", \"commitId\": \"${COMMIT_ID}\"}" "http://get-tasend-back-twkvcdsbpj.cn-hangzhou.fcapp.run") + -d "{\"type\": \"RETRIEVE-TASK-STATUS\", \"repositoryUrl\": \"${REPOSITORY}\", \"commitId\": \"${COMMIT_ID}\"}" 
"http://get-tasend-back-twkvcdsbpj.cn-hangzhou.fcapp.run") echo "Response: $response" # 检查curl是否成功 diff --git a/.github/scripts/trigger-CI.sh b/.github/scripts/trigger-CI.sh index 42d79e10d..b6892be95 100644 --- a/.github/scripts/trigger-CI.sh +++ b/.github/scripts/trigger-CI.sh @@ -38,6 +38,7 @@ curl -v -H "Content-Type: application/json" \ \"type\": \"CREATE-TASK\", \"commitId\": \"${COMMIT_ID}\", \"repositoryUrl\": \"${REPO_URL}\", + \"prId\": \"${GITHUB_PR_ID}\", \"aone\": { \"projectId\": \"${PROJECT_ID}\", \"pipelineId\": \"${PIPELINE_ID}\"}, \"newBranch\": { \"name\": \"${BRANCH_NAME}\", \"ref\": \"${BRANCH_REF}\" }, \"params\": {\"cancel-in-progress\": \"${CANCEL_IN_PROGRESS}\", \"github_commit\":\"${GITHUB_COMMIT_ID}\", \"github_source_repo\": \"${GITHUB_SOURCE_REPO}\"} diff --git a/.github/scripts/trigger-docs.sh b/.github/scripts/trigger-docs.sh index beb7c0df9..00832ae86 100644 --- a/.github/scripts/trigger-docs.sh +++ b/.github/scripts/trigger-docs.sh @@ -41,6 +41,7 @@ curl -v -H "Content-Type: application/json" \ \"type\": \"CREATE-TASK\", \"commitId\": \"${COMMIT_ID}\", \"repositoryUrl\": \"${REPO_URL}\", + \"prId\": \"${GITHUB_PR_ID}\", \"aone\": { \"projectId\": \"${PROJECT_ID}\", \"pipelineId\": \"${PIPELINE_ID}\"}, \"newBranch\": { \"name\": \"${BRANCH_NAME}\", \"ref\": \"${BRANCH_REF}\" }, \"params\": {\"cancel-in-progress\": \"${CANCEL_IN_PROGRESS}\", \"github_commit\":\"${GITHUB_COMMIT_ID}\", \"github_source_repo\": \"${GITHUB_SOURCE_REPO}\", \"checkout_submodules\": \"${CHECKOUT_SUBMODULES}\", \"checkout_username\": \"${CHECK_USER_NAME}\", \"checkout_token\": \"${CHECK_TOKEN}\"} diff --git a/.github/workflows/CI-request-trigger.yml b/.github/workflows/CI-request-trigger.yml index da42050fe..319c205d2 100644 --- a/.github/workflows/CI-request-trigger.yml +++ b/.github/workflows/CI-request-trigger.yml @@ -46,4 +46,4 @@ jobs: run: | COMMIT_ID=$([ "${{ github.event_name }}" == "pull_request" ] && echo "${{ github.event.pull_request.head.sha }}" 
|| echo "${{ github.sha }}") echo "Using Commit ID: $COMMIT_ID" - ./get-CI-result.sh "$COMMIT_ID" "${{ secrets.CI_SECRET }}" \ No newline at end of file + ./get-CI-result.sh "$COMMIT_ID" "${{ secrets.CI_SECRET }}" "${{ github.repository }}" \ No newline at end of file From ee9f9f80de83ca8a6773ee4b5adf363c0bf2ecc7 Mon Sep 17 00:00:00 2001 From: Timandes White Date: Sat, 28 Feb 2026 10:17:15 +0800 Subject: [PATCH 04/11] feat: add WebSocket TCP port forwarding for sandbox containers - Add /sandboxes/{id}/portforward WebSocket endpoint in proxy layer - Add /portforward WebSocket endpoint in rocklet for internal TCP proxy - Support port range 1024-65535, excluding port 22 (SSH) - Implement two-layer architecture: external proxy -> rocklet -> TCP port - Add comprehensive logging with target_port in all messages - Handle both FastAPI WebSocket and websockets library APIs - Add unit tests for port validation, URL building, and route handling --- .iflow/specs/websocket-tcp-proxy/plan.md | 118 +++++++ .iflow/specs/websocket-tcp-proxy/spec.md | 118 +++++++ rock/admin/entrypoints/sandbox_proxy_api.py | 54 ++- rock/rocklet/local_api.py | 192 +++++++++- rock/sandbox/service/sandbox_proxy_service.py | 327 ++++++++++++++++++ tests/unit/rocklet/test_portforward.py | 22 ++ .../unit/sandbox/test_websocket_tcp_proxy.py | 167 +++++++++ 7 files changed, 993 insertions(+), 5 deletions(-) create mode 100644 .iflow/specs/websocket-tcp-proxy/plan.md create mode 100644 .iflow/specs/websocket-tcp-proxy/spec.md create mode 100644 tests/unit/rocklet/test_portforward.py create mode 100644 tests/unit/sandbox/test_websocket_tcp_proxy.py diff --git a/.iflow/specs/websocket-tcp-proxy/plan.md b/.iflow/specs/websocket-tcp-proxy/plan.md new file mode 100644 index 000000000..f54ff89d8 --- /dev/null +++ b/.iflow/specs/websocket-tcp-proxy/plan.md @@ -0,0 +1,118 @@ +# WebSocket TCP 端口代理服务 - 实现计划 + +## 实现步骤 + +### 步骤 1: 在 rocklet 中添加端口转发端点 + +**文件**: `rock/rocklet/local_api.py` + +添加 WebSocket 端点: 
+```python +@sandbox_proxy_router.websocket("/portforward") +async def portforward(websocket: WebSocket, port: int): + """ + 容器内 TCP 端口代理端点。 + + 流程: + 1. 接受 WebSocket 连接 + 2. 验证端口 (1024-65535, 排除 22) + 3. 连接到 127.0.0.1:{port} + 4. 双向转发二进制数据 + """ +``` + +核心实现: +- 使用 `asyncio.open_connection("127.0.0.1", port)` 连接本地 TCP 端口 +- 创建两个异步任务进行双向数据转发 +- WebSocket 使用 `receive_bytes()` 和 `send_bytes()` 处理二进制数据 + +### 步骤 2: 在 SandboxProxyService 中添加辅助方法 + +**文件**: `rock/sandbox/service/sandbox_proxy_service.py` + +添加方法: +```python +def _validate_port(self, port: int) -> tuple[bool, str | None]: + """验证端口是否在允许范围内。""" + +def _get_rocklet_portforward_url(self, sandbox_status_dict: dict, port: int) -> str: + """获取 rocklet portforward 端点的 WebSocket URL。""" +``` + +URL 格式:`ws://{host_ip}:{mapped_port}/portforward?port={target_port}` + +### 步骤 3: 在 SandboxProxyService 中添加代理方法 + +**文件**: `rock/sandbox/service/sandbox_proxy_service.py` + +添加方法: +```python +async def websocket_to_tcp_proxy( + self, + client_websocket: WebSocket, + sandbox_id: str, + port: int, + tcp_connect_timeout: float = 10.0, + idle_timeout: float = 300.0, +) -> None: + """ + 将客户端 WebSocket 连接代理到 rocklet 的 portforward 端点。 + + 流程: + 1. 验证端口 + 2. 获取沙箱状态 + 3. 构建 rocklet portforward URL + 4. 连接到 rocklet WebSocket + 5. 
双向转发二进制数据 + """ +``` + +核心实现: +- 使用 `websockets.connect()` 连接 rocklet +- 复用现有的 WebSocket 双向转发模式 + +### 步骤 4: 添加外部 WebSocket 路由端点 + +**文件**: `rock/admin/entrypoints/sandbox_proxy_api.py` + +添加路由: +```python +@sandbox_proxy_router.websocket("/sandboxes/{id}/portforward") +async def portforward(websocket: WebSocket, id: str, port: int): + """ + 外部 WebSocket TCP 端口代理端点。 + """ +``` + +## 架构说明 + +``` +客户端 ──WebSocket──▶ Proxy服务 ──WebSocket──▶ rocklet ──TCP──▶ 目标端口 + │ │ + 外部代理层 内部代理层 + (sandbox_proxy_api.py) (local_api.py) +``` + +**为什么需要两层?** + +Docker 容器只暴露预定义端口(如 8080),无法直接访问容器内动态启动的服务端口。rocklet 运行在容器内,可以直接访问 `127.0.0.1:{任意端口}`。 + +## 依赖 + +无需新增外部依赖: +- `asyncio` - TCP 连接和异步任务 +- `websockets` - WebSocket 客户端(已依赖) + +## 测试计划 + +### 单元测试 +1. 端口验证逻辑 (`_validate_port`) +2. rocklet portforward URL 构建 (`_get_rocklet_portforward_url`) +3. rocklet 端点存在性验证 + +### 集成测试 +1. 正常连接和断开 +2. 双向二进制数据传输 +3. 端口限制验证 +4. 超时处理 +5. 多客户端并发连接 \ No newline at end of file diff --git a/.iflow/specs/websocket-tcp-proxy/spec.md b/.iflow/specs/websocket-tcp-proxy/spec.md new file mode 100644 index 000000000..10a9ac30a --- /dev/null +++ b/.iflow/specs/websocket-tcp-proxy/spec.md @@ -0,0 +1,118 @@ +# WebSocket TCP 端口代理服务 + +## 1. 概述 + +在 ROCK 沙箱代理服务层实现 **WebSocket → TCP** 端口代理,允许客户端通过 WebSocket 连接访问沙箱容器内监听的 TCP 端口。 + +## 2. 架构设计 + +### 2.1 架构图 + +``` +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────────────────┐ +│ WebSocket 客户端 │────▶│ Proxy 服务 │────▶│ rocklet (容器内 8080 端口) │ +└─────────────────┘ └─────────────────┘ └─────────────────────────────┘ + /portforward?port=X │ + │ TCP 连接 + ▼ + ┌─────────────────┐ + │ 沙箱内 TCP 端口 │ + │ (如 3000, 8080) │ + └─────────────────┘ +``` + +### 2.2 两层代理架构 + +| 层级 | 组件 | 职责 | +|-----|------|------| +| **外部代理层** | `sandbox_proxy_service.py` | 接收客户端 WebSocket 连接,转发到 rocklet | +| **内部代理层** | `rocklet/local_api.py` | 在容器内连接目标 TCP 端口,实现实际代理 | + +### 2.3 为什么需要两层? + +Docker 容器默认只暴露预定义端口(如 8080),无法直接从外部访问容器内动态启动的服务端口。通过 rocklet 层代理,可以访问容器内任意 TCP 端口。 + +## 3. 
功能规格 + +### 3.1 外部 API(客户端调用) + +| 项目 | 规格 | +|-----|------| +| **端点路径** | `WS /sandboxes/{sandbox_id}/portforward?port={port}` | +| **sandbox_id** | 沙箱实例的唯一标识符 | +| **port** | Query 参数,指定要代理的目标 TCP 端口号 | + +**示例:** +``` +ws://localhost:8080/sandboxes/abc123/portforward?port=3000 +``` + +### 3.2 内部 API(rocklet 端点) + +| 项目 | 规格 | +|-----|------| +| **端点路径** | `WS /portforward?port={port}` | +| **port** | Query 参数,指定容器内的目标 TCP 端口号 | +| **监听端口** | rocklet 监听 `Port.PROXY` (22555),通过 Docker 映射到宿主机随机端口 | + +### 3.3 数据传输 + +| 项目 | 规格 | +|-----|------| +| **协议** | WebSocket | +| **数据格式** | 二进制消息(Binary Frame) | +| **转发方式** | 原样透传,不修改数据内容 | + +### 3.4 端口安全限制 + +| 项目 | 规格 | +|-----|------| +| **允许范围** | 1024-65535 | +| **排除端口** | 22 (SSH) | +| **拒绝行为** | WebSocket 关闭码 1008 (Policy Violation) | + +### 3.5 连接生命周期 + +| 场景 | 行为 | +|-----|------| +| **连接建立** | 外部代理 → rocklet → TCP 端口,逐层建立连接 | +| **一对一映射** | 每个 WebSocket 连接对应一个独立的 TCP 连接 | +| **多客户端** | 多个 WebSocket 客户端连接同一端口,各自创建独立的 TCP 连接 | +| **TCP 断开** | 关闭 rocklet WebSocket → 关闭客户端 WebSocket | +| **客户端断开** | 关闭 rocklet WebSocket → 关闭 TCP 连接 | + +### 3.6 超时配置 + +| 超时类型 | 时长 | 说明 | +|---------|------|------| +| **连接超时** | 10 秒 | 建立到 rocklet 或 TCP 端口的连接最大等待时间 | +| **空闲超时** | 300 秒 | 连接无数据传输后自动关闭 | + +### 3.7 错误处理 + +| 错误场景 | 处理方式 | +|---------|---------| +| **端口不在允许范围** | WebSocket 关闭码 1008,提示端口不允许 | +| **端口被排除(如 22)** | WebSocket 关闭码 1008,提示端口不允许 | +| **沙箱不存在** | WebSocket 关闭码 1011,提示沙箱未启动 | +| **TCP 连接超时** | WebSocket 关闭码 1011,提示连接超时 | +| **TCP 连接被拒绝** | WebSocket 关闭码 1011,提示连接失败 | +| **传输中断** | WebSocket 关闭码 1011 | + +### 3.8 认证授权 + +不需要额外的认证/授权机制。 + +## 4. 实现位置 + +| 文件 | 修改内容 | +|-----|---------| +| `rock/rocklet/local_api.py` | 添加 `/portforward` WebSocket 端点,实现容器内 TCP 代理 | +| `rock/sandbox/service/sandbox_proxy_service.py` | 添加 `websocket_to_tcp_proxy()` 方法和 `_get_rocklet_portforward_url()` 方法 | +| `rock/admin/entrypoints/sandbox_proxy_api.py` | 添加外部 WebSocket 路由端点 | + +## 5. 
参考设计 + +参考 Kubernetes port-forward 的 WebSocket 实现: +- URL: `/api/v1/namespaces/{namespace}/pods/{name}/portforward?ports={port}` +- 数据流:WebSocket 双向透传 TCP 数据 \ No newline at end of file diff --git a/rock/admin/entrypoints/sandbox_proxy_api.py b/rock/admin/entrypoints/sandbox_proxy_api.py index 2770fa7d2..3ca1e7f64 100644 --- a/rock/admin/entrypoints/sandbox_proxy_api.py +++ b/rock/admin/entrypoints/sandbox_proxy_api.py @@ -1,5 +1,4 @@ import asyncio -import logging from typing import Any from fastapi import APIRouter, Body, File, Form, Request, UploadFile, WebSocket, WebSocketDisconnect @@ -26,9 +25,12 @@ SandboxWriteFileRequest, ) from rock.admin.proto.response import BatchSandboxStatusResponse, SandboxListResponse +from rock.logger import init_logger from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService from rock.utils import handle_exceptions +logger = init_logger(__name__) + sandbox_proxy_router = APIRouter() sandbox_proxy_service: SandboxProxyService @@ -114,13 +116,57 @@ async def list_sandboxes(request: Request) -> RockResponse[SandboxListResponse]: async def websocket_proxy(websocket: WebSocket, id: str, path: str = ""): await websocket.accept() sandbox_id = id - logging.info(f"Client connected to WebSocket proxy: {sandbox_id}, path: {path}") + logger.info(f"Client connected to WebSocket proxy: {sandbox_id}, path: {path}") try: await sandbox_proxy_service.websocket_proxy(websocket, sandbox_id, path) except WebSocketDisconnect: - logging.info(f"Client disconnected from WebSocket proxy: {sandbox_id}") + logger.info(f"Client disconnected from WebSocket proxy: {sandbox_id}") + except Exception as e: + logger.error(f"WebSocket proxy error: {e}") + await websocket.close(code=1011, reason=f"Proxy error: {str(e)}") + + +@sandbox_proxy_router.websocket("/sandboxes/{id}/portforward") +async def portforward(websocket: WebSocket, id: str, port: int): + """ + WebSocket TCP port forwarding endpoint. 
+ + Proxies a WebSocket connection to a TCP port inside the sandbox. + + Args: + id: The sandbox identifier. + port: The target TCP port inside the sandbox. + + Errors: + - Port validation failure: Connection closed with error + - Sandbox not found: Connection closed with error + - TCP connection failure: Connection closed with error + """ + sandbox_id = id + client_host = websocket.client.host if websocket.client else "unknown" + client_port = websocket.client.port if websocket.client else "unknown" + logger.info( + f"[Portforward] Request received: sandbox={sandbox_id}, target_port={port}, " + f"client={client_host}:{client_port}, path={websocket.url.path}" + ) + + try: + logger.info(f"[Portforward] Accepting WebSocket connection: sandbox={sandbox_id}, target_port={port}") + await websocket.accept() + logger.info(f"[Portforward] WebSocket accepted, calling proxy service: sandbox={sandbox_id}, target_port={port}") + await sandbox_proxy_service.websocket_to_tcp_proxy(websocket, sandbox_id, port) + logger.info(f"[Portforward] Proxy service completed: sandbox={sandbox_id}, target_port={port}") + except ValueError as e: + logger.warning(f"[Portforward] Validation failed: sandbox={sandbox_id}, target_port={port}, error={e}") + await websocket.close(code=1008, reason=str(e)) + except WebSocketDisconnect as e: + logger.info(f"[Portforward] Client disconnected: sandbox={sandbox_id}, target_port={port}, code={e.code}") except Exception as e: - logging.error(f"WebSocket proxy error: {e}") + logger.error( + f"[Portforward] Unexpected error: sandbox={sandbox_id}, target_port={port}, " + f"error_type={type(e).__name__}, error={e}", + exc_info=True + ) await websocket.close(code=1011, reason=f"Proxy error: {str(e)}") diff --git a/rock/rocklet/local_api.py b/rock/rocklet/local_api.py index 7513aad7d..61a29096a 100644 --- a/rock/rocklet/local_api.py +++ b/rock/rocklet/local_api.py @@ -1,9 +1,10 @@ +import asyncio import shutil import tempfile import zipfile from pathlib import 
Path -from fastapi import APIRouter, File, Form, UploadFile +from fastapi import APIRouter, File, Form, UploadFile, WebSocket, WebSocketDisconnect from rock.actions import ( CloseResponse, @@ -24,11 +25,21 @@ from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest +from rock.logger import init_logger from rock.rocklet.local_sandbox import LocalSandboxRuntime from rock.utils import get_executor +logger = init_logger(__name__) + local_router = APIRouter() +# Constants for port forwarding +MIN_ALLOWED_PORT = 1024 +MAX_ALLOWED_PORT = 65535 +FORBIDDEN_PORTS = {22} # SSH port is not allowed +TCP_CONNECT_TIMEOUT = 10 # seconds +IDLE_TIMEOUT = 300 # seconds + runtime = LocalSandboxRuntime(executor=get_executor()) @@ -130,3 +141,182 @@ async def env_close(request: EnvCloseRequest) -> EnvCloseResponse: @local_router.post("/env/list") async def env_list() -> EnvListResponse: return runtime.env_list() + + +def _validate_port(port: int) -> tuple[bool, str]: + """Validate port number for port forwarding. + + Args: + port: Port number to validate + + Returns: + Tuple of (is_valid, error_message) + """ + logger.debug( + f"[Portforward] Validating port: port={port}, " + f"min={MIN_ALLOWED_PORT}, max={MAX_ALLOWED_PORT}, forbidden={FORBIDDEN_PORTS}" + ) + if port < MIN_ALLOWED_PORT: + error_msg = f"Port {port} is not allowed. Minimum allowed port is {MIN_ALLOWED_PORT}" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + if port > MAX_ALLOWED_PORT: + error_msg = f"Port {port} is not allowed. 
Maximum allowed port is {MAX_ALLOWED_PORT}" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + if port in FORBIDDEN_PORTS: + error_msg = f"Port {port} is not allowed for port forwarding" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + logger.debug(f"[Portforward] Port validation passed: port={port}") + return True, "" + + +@local_router.websocket("/portforward") +async def portforward(websocket: WebSocket, port: int): + """WebSocket endpoint for TCP port forwarding. + + This endpoint proxies WebSocket connections to a local TCP port, + allowing external clients to access TCP services running inside + the sandbox container. + + Args: + websocket: WebSocket connection from client + port: Target TCP port to connect to (query parameter) + """ + client_host = websocket.client.host if websocket.client else "unknown" + client_port = websocket.client.port if websocket.client else "unknown" + logger.info( + f"[Portforward] Request received: target_port={port}, " + f"client={client_host}:{client_port}, path={websocket.url.path}" + ) + + logger.info(f"[Portforward] Accepting WebSocket: target_port={port}") + await websocket.accept() + logger.info(f"[Portforward] WebSocket accepted: target_port={port}") + + # Validate port + logger.debug(f"[Portforward] Validating port: target_port={port}") + is_valid, error_message = _validate_port(port) + if not is_valid: + logger.warning(f"[Portforward] Port validation failed: target_port={port}, reason={error_message}") + await websocket.close(code=1008, reason=error_message) + return + logger.info(f"[Portforward] Port validation passed: target_port={port}") + + logger.info( + f"[Portforward] Connecting to local TCP: target_port={port}, " + f"address=127.0.0.1:{port}, timeout={TCP_CONNECT_TIMEOUT}s" + ) + + try: + # Connect to local TCP port + reader, writer = await asyncio.wait_for( + asyncio.open_connection("127.0.0.1", port), + 
timeout=TCP_CONNECT_TIMEOUT + ) + logger.info( + f"[Portforward] TCP connection established: target_port={port}, " + f"local_addr={writer.get_extra_info('sockname')}" + ) + except asyncio.TimeoutError: + logger.error( + f"[Portforward] TCP connection timeout: target_port={port}, " + f"timeout={TCP_CONNECT_TIMEOUT}s" + ) + await websocket.close(code=1011, reason=f"Connection to port {port} timed out") + return + except OSError as e: + logger.error( + f"[Portforward] TCP connection failed: target_port={port}, " + f"error_type={type(e).__name__}, errno={e.errno}, error={e}" + ) + await websocket.close(code=1011, reason=f"Failed to connect to port {port}: {e}") + return + except Exception as e: + logger.error( + f"[Portforward] Unexpected TCP error: target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + await websocket.close(code=1011, reason=f"Unexpected error: {e}") + return + + logger.info(f"[Portforward] Starting bidirectional forwarding: target_port={port}") + + ws_to_tcp_bytes = 0 + ws_to_tcp_msgs = 0 + tcp_to_ws_bytes = 0 + tcp_to_ws_msgs = 0 + + async def ws_to_tcp(): + """Forward messages from WebSocket to TCP.""" + nonlocal ws_to_tcp_bytes, ws_to_tcp_msgs + try: + while True: + data = await websocket.receive_bytes() + ws_to_tcp_msgs += 1 + ws_to_tcp_bytes += len(data) + writer.write(data) + await writer.drain() + logger.debug( + f"[Portforward] ws->tcp: target_port={port}, " + f"bytes={len(data)}, total_msgs={ws_to_tcp_msgs}, total_bytes={ws_to_tcp_bytes}" + ) + except WebSocketDisconnect as e: + logger.info( + f"[Portforward] ws->tcp: client disconnected: target_port={port}, code={e.code}" + ) + except Exception as e: + logger.debug( + f"[Portforward] ws->tcp error: target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + finally: + writer.close() + + async def tcp_to_ws(): + """Forward data from TCP to WebSocket.""" + nonlocal tcp_to_ws_bytes, tcp_to_ws_msgs + try: + while True: + data = await reader.read(4096) + if not 
data: + logger.info(f"[Portforward] tcp->ws: TCP connection closed by peer: target_port={port}") + break + tcp_to_ws_msgs += 1 + tcp_to_ws_bytes += len(data) + await websocket.send_bytes(data) + logger.debug( + f"[Portforward] tcp->ws: target_port={port}, " + f"bytes={len(data)}, total_msgs={tcp_to_ws_msgs}, total_bytes={tcp_to_ws_bytes}" + ) + except Exception as e: + logger.debug( + f"[Portforward] tcp->ws error: target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + finally: + try: + await websocket.close() + except Exception: + pass + + # Run both directions concurrently + try: + await asyncio.gather(ws_to_tcp(), tcp_to_ws()) + except Exception as e: + logger.debug( + f"[Portforward] Forwarding error: target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + logger.info( + f"[Portforward] Connection closed: target_port={port}, " + f"ws->tcp: {ws_to_tcp_msgs} msgs, {ws_to_tcp_bytes} bytes, " + f"tcp->ws: {tcp_to_ws_msgs} msgs, {tcp_to_ws_bytes} bytes" + ) diff --git a/rock/sandbox/service/sandbox_proxy_service.py b/rock/sandbox/service/sandbox_proxy_service.py index 250c5ce25..8421196d6 100644 --- a/rock/sandbox/service/sandbox_proxy_service.py +++ b/rock/sandbox/service/sandbox_proxy_service.py @@ -78,6 +78,42 @@ def __init__(self, rock_config: RockConfig, redis_provider: RedisProvider | None self._batch_get_status_max_count = rock_config.proxy_service.batch_get_status_max_count + # Port forwarding constants + PORT_FORWARD_MIN_PORT = 1024 + PORT_FORWARD_MAX_PORT = 65535 + PORT_FORWARD_EXCLUDED_PORTS = {22} # SSH port + + def _validate_port(self, port: int) -> tuple[bool, str | None]: + """ + Validate if the port is allowed for port forwarding. + + Args: + port: The port number to validate. + + Returns: + A tuple of (is_valid, error_message). + If valid, error_message is None. 
+ """ + logger.debug( + f"[Portforward] Validating port: port={port}, " + f"min={self.PORT_FORWARD_MIN_PORT}, max={self.PORT_FORWARD_MAX_PORT}, " + f"excluded={self.PORT_FORWARD_EXCLUDED_PORTS}" + ) + if port < self.PORT_FORWARD_MIN_PORT: + error_msg = f"Port {port} is below minimum allowed port {self.PORT_FORWARD_MIN_PORT}" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + if port > self.PORT_FORWARD_MAX_PORT: + error_msg = f"Port {port} is above maximum allowed port {self.PORT_FORWARD_MAX_PORT}" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + if port in self.PORT_FORWARD_EXCLUDED_PORTS: + error_msg = f"Port {port} is not allowed for port forwarding" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + logger.debug(f"[Portforward] Port validation passed: port={port}") + return True, None + @monitor_sandbox_operation() async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse: sandbox_id = request.sandbox_id @@ -233,6 +269,297 @@ async def websocket_proxy(self, client_websocket, sandbox_id: str, target_path: logger.error(f"WebSocket proxy error: {e}") await client_websocket.close(code=1011, reason=f"Proxy error: {str(e)}") + async def websocket_to_tcp_proxy( + self, + client_websocket, + sandbox_id: str, + port: int, + tcp_connect_timeout: float = 10.0, + idle_timeout: float = 300.0, + ) -> None: + """ + Proxy WebSocket connection to a TCP port inside the sandbox. + + This method forwards the connection to rocklet's /portforward endpoint, + which then connects to the actual TCP port inside the container. + + Args: + client_websocket: The WebSocket connection from the client. + sandbox_id: The sandbox identifier. + port: The target TCP port inside the sandbox. + tcp_connect_timeout: Timeout for establishing connection (default: 10s). 
+ idle_timeout: Timeout for idle connection (default: 300s). + + Raises: + ValueError: If port validation fails. + Exception: If sandbox is not found or connection fails. + """ + logger.info(f"[Portforward] Starting proxy: sandbox={sandbox_id}, target_port={port}") + + # Validate port + logger.debug(f"[Portforward] Validating port: sandbox={sandbox_id}, target_port={port}") + is_valid, error_msg = self._validate_port(port) + if not is_valid: + logger.warning(f"[Portforward] Port validation failed: sandbox={sandbox_id}, target_port={port}, reason={error_msg}") + raise ValueError(error_msg) + logger.info(f"[Portforward] Port validation passed: sandbox={sandbox_id}, target_port={port}") + + # Get sandbox status and rocklet portforward URL + logger.info(f"[Portforward] Fetching sandbox status: sandbox={sandbox_id}, target_port={port}") + try: + status_dicts = await self.get_service_status(sandbox_id) + logger.info( + f"[Portforward] Sandbox status retrieved: sandbox={sandbox_id}, target_port={port}, " + f"host_ip={status_dicts[0].get('host_ip')}, status_keys={list(status_dicts[0].keys())}" + ) + except Exception as e: + logger.error( + f"[Portforward] Failed to get sandbox status: sandbox={sandbox_id}, target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + raise + + target_url = self._get_rocklet_portforward_url(status_dicts[0], port) + logger.info( + f"[Portforward] Rocklet URL built: sandbox={sandbox_id}, target_port={port}, url={target_url}" + ) + + try: + # Connect to rocklet's portforward WebSocket endpoint + logger.info( + f"[Portforward] Connecting to rocklet: sandbox={sandbox_id}, target_port={port}, " + f"timeout={tcp_connect_timeout}s" + ) + async with websockets.connect( + target_url, + ping_interval=None, + ping_timeout=None, + open_timeout=tcp_connect_timeout, + ) as target_websocket: + logger.info( + f"[Portforward] Rocklet connection established: sandbox={sandbox_id}, target_port={port}" + ) + + # Create bidirectional forwarding tasks 
+ logger.info( + f"[Portforward] Starting bidirectional forwarding: sandbox={sandbox_id}, target_port={port}" + ) + client_to_target = asyncio.create_task( + self._forward_portforward_messages( + client_websocket, target_websocket, "client->rocklet", idle_timeout + ) + ) + target_to_client = asyncio.create_task( + self._forward_portforward_messages( + target_websocket, client_websocket, "rocklet->client", idle_timeout + ) + ) + + # Wait for any task to complete + done, pending = await asyncio.wait( + [client_to_target, target_to_client], return_when=asyncio.FIRST_COMPLETED + ) + + logger.info( + f"[Portforward] Forwarding task completed: sandbox={sandbox_id}, target_port={port}, " + f"completed={[t.get_name() for t in done]}" + ) + + # Cancel unfinished tasks + for task in pending: + logger.debug( + f"[Portforward] Cancelling pending task: sandbox={sandbox_id}, target_port={port}, " + f"task={task.get_name()}" + ) + task.cancel() + + except asyncio.TimeoutError as e: + logger.error( + f"[Portforward] Connection timeout: sandbox={sandbox_id}, target_port={port}, " + f"url={target_url}, timeout={tcp_connect_timeout}s" + ) + await client_websocket.close(code=1011, reason=f"Connection timeout to rocklet: {target_url}") + except websockets.exceptions.ConnectionClosed as e: + logger.warning( + f"[Portforward] Rocklet connection closed: sandbox={sandbox_id}, target_port={port}, " + f"code={e.code}, reason={e.reason}" + ) + await client_websocket.close(code=1011, reason=f"Rocklet connection closed: {e.reason}") + except Exception as e: + logger.error( + f"[Portforward] Proxy error: sandbox={sandbox_id}, target_port={port}, " + f"error_type={type(e).__name__}, error={e}", + exc_info=True + ) + await client_websocket.close(code=1011, reason=f"Proxy error: {str(e)}") + + async def _forward_portforward_messages( + self, + source, + destination, + direction: str, + idle_timeout: float, + ): + """Forward binary messages between WebSocket connections with idle timeout. 
+ + Handles both FastAPI WebSocket and websockets library ClientConnection objects: + - FastAPI WebSocket: receive_bytes(), send_bytes() + - websockets ClientConnection: recv(), send() + """ + logger.debug(f"[Portforward] Starting message forwarder: direction={direction}, idle_timeout={idle_timeout}s") + bytes_transferred = 0 + message_count = 0 + + # Detect the type of WebSocket objects + # FastAPI WebSocket has 'receive_bytes' method + # websockets ClientConnection has 'recv' method + source_is_fastapi = hasattr(source, 'receive_bytes') + dest_is_fastapi = hasattr(destination, 'send_bytes') + + logger.debug( + f"[Portforward] Connection types: direction={direction}, " + f"source_is_fastapi={source_is_fastapi}, dest_is_fastapi={dest_is_fastapi}" + ) + + try: + while True: + try: + # Receive data based on source type + if source_is_fastapi: + data = await asyncio.wait_for( + source.receive_bytes(), + timeout=idle_timeout, + ) + else: + # websockets library returns bytes or str + data = await asyncio.wait_for( + source.recv(), + timeout=idle_timeout, + ) + # Convert str to bytes if needed + if isinstance(data, str): + data = data.encode('utf-8') + + message_count += 1 + bytes_transferred += len(data) + + # Send data based on destination type + if dest_is_fastapi: + await destination.send_bytes(data) + else: + await destination.send(data) + + logger.debug( + f"[Portforward] Forwarded message: direction={direction}, " + f"msg_num={message_count}, bytes={len(data)}, total_bytes={bytes_transferred}" + ) + except asyncio.TimeoutError: + logger.info( + f"[Portforward] Idle timeout: direction={direction}, " + f"total_messages={message_count}, total_bytes={bytes_transferred}" + ) + break + except Exception as e: + logger.info( + f"[Portforward] Forwarder stopped: direction={direction}, " + f"error_type={type(e).__name__}, error={e}, " + f"total_messages={message_count}, total_bytes={bytes_transferred}" + ) + + def _get_tcp_target_address(self, sandbox_status_dict: dict, 
port: int) -> tuple[str, int]: + """ + Get the target TCP address from sandbox status. + + Args: + sandbox_status_dict: The sandbox status dictionary. + port: The target port inside the sandbox. + + Returns: + A tuple of (host_ip, mapped_port). + """ + host_ip = sandbox_status_dict.get("host_ip") + service_status = ServiceStatus.from_dict(sandbox_status_dict) + # Use SERVER port mapping to access the sandbox + # The port inside the container is mapped to a host port + mapped_port = service_status.get_mapped_port(Port.SERVER) + # For now, we use the port as-is since we're connecting to the sandbox's network + # In a real scenario, we might need to use the mapped port or connect through the container network + return host_ip, port + + def _get_rocklet_portforward_url(self, sandbox_status_dict: dict, port: int) -> str: + """ + Get the WebSocket URL for rocklet's portforward endpoint. + + Args: + sandbox_status_dict: The sandbox status dictionary. + port: The target TCP port inside the sandbox. + + Returns: + WebSocket URL for rocklet's portforward endpoint. 
+ """ + host_ip = sandbox_status_dict.get("host_ip") + service_status = ServiceStatus.from_dict(sandbox_status_dict) + # Use PROXY port mapping to access the rocklet service + # rocklet listens on Port.PROXY (22555) inside the container + mapped_port = service_status.get_mapped_port(Port.PROXY) + + logger.info( + f"[Portforward] Building rocklet URL: host_ip={host_ip}, " + f"container_port={Port.PROXY.value}, mapped_port={mapped_port}, target_port={port}" + ) + + if not host_ip: + logger.error(f"[Portforward] Missing host_ip in sandbox status: keys={list(sandbox_status_dict.keys())}") + if not mapped_port: + logger.error( + f"[Portforward] Missing mapped port for PROXY: " + f"available_ports={service_status.ports if hasattr(service_status, 'ports') else 'unknown'}" + ) + + url = f"ws://{host_ip}:{mapped_port}/portforward?port={port}" + logger.info(f"[Portforward] Generated rocklet URL: {url}") + return url + + async def _forward_websocket_to_tcp(self, websocket, writer, direction: str, idle_timeout: float): + """Forward data from WebSocket to TCP connection.""" + try: + while True: + try: + # Wait for data with idle timeout + data = await asyncio.wait_for( + websocket.receive_bytes(), + timeout=idle_timeout, + ) + writer.write(data) + await writer.drain() + logger.debug(f"Forwarded {len(data)} bytes {direction}") + except asyncio.TimeoutError: + logger.info(f"Idle timeout reached for {direction}") + break + except Exception as e: + logger.info(f"Connection closed in {direction}: {e}") + + async def _forward_tcp_to_websocket(self, reader, websocket, direction: str, idle_timeout: float): + """Forward data from TCP connection to WebSocket.""" + try: + while True: + try: + # Wait for data with idle timeout + data = await asyncio.wait_for( + reader.read(4096), + timeout=idle_timeout, + ) + if not data: + logger.info(f"TCP connection closed for {direction}") + break + await websocket.send_bytes(data) + logger.debug(f"Forwarded {len(data)} bytes {direction}") + except 
asyncio.TimeoutError: + logger.info(f"Idle timeout reached for {direction}") + break + except Exception as e: + logger.info(f"Connection closed in {direction}: {e}") + async def get_service_status(self, sandbox_id: str): sandbox_status_dicts = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$") if not sandbox_status_dicts or sandbox_status_dicts[0].get("host_ip") is None: diff --git a/tests/unit/rocklet/test_portforward.py b/tests/unit/rocklet/test_portforward.py new file mode 100644 index 000000000..d098153b2 --- /dev/null +++ b/tests/unit/rocklet/test_portforward.py @@ -0,0 +1,22 @@ +"""Tests for rocklet portforward WebSocket endpoint.""" +import pytest +from fastapi import FastAPI + +from rock.rocklet.local_api import local_router + + +@pytest.fixture +def app(): + """Create test FastAPI app with local_router.""" + app = FastAPI() + app.include_router(local_router, tags=["local"]) + return app + + +class TestPortForwardEndpoint: + """Tests for /portforward WebSocket endpoint.""" + + def test_portforward_endpoint_exists(self, app: FastAPI): + """Test that /portforward endpoint is registered.""" + routes = [route.path for route in app.routes] + assert "/portforward" in routes, "Portforward endpoint should be registered" \ No newline at end of file diff --git a/tests/unit/sandbox/test_websocket_tcp_proxy.py b/tests/unit/sandbox/test_websocket_tcp_proxy.py new file mode 100644 index 000000000..33b7261eb --- /dev/null +++ b/tests/unit/sandbox/test_websocket_tcp_proxy.py @@ -0,0 +1,167 @@ +"""Tests for WebSocket TCP port forwarding proxy.""" +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI, WebSocket +from fastapi.testclient import TestClient + +from rock.admin.entrypoints.sandbox_proxy_api import sandbox_proxy_router, set_sandbox_proxy_service + + +class TestValidatePort: + """Tests for _validate_port method.""" + + def test_validate_port_accepts_valid_port_in_range(self, sandbox_proxy_service): 
+ """Valid port in allowed range should pass validation.""" + is_valid, error = sandbox_proxy_service._validate_port(8080) + assert is_valid is True + assert error is None + + def test_validate_port_rejects_port_below_1024(self, sandbox_proxy_service): + """Port below 1024 should be rejected.""" + is_valid, error = sandbox_proxy_service._validate_port(1023) + assert is_valid is False + assert "1024" in error + + def test_validate_port_rejects_port_22(self, sandbox_proxy_service): + """SSH port 22 should be rejected even though it's in the excluded list.""" + is_valid, error = sandbox_proxy_service._validate_port(22) + assert is_valid is False + assert "22" in error + + def test_validate_port_rejects_port_above_65535(self, sandbox_proxy_service): + """Port above 65535 should be rejected.""" + is_valid, error = sandbox_proxy_service._validate_port(65536) + assert is_valid is False + assert "65535" in error + + def test_validate_port_accepts_min_allowed_port(self, sandbox_proxy_service): + """Minimum allowed port (1024) should pass validation.""" + is_valid, error = sandbox_proxy_service._validate_port(1024) + assert is_valid is True + assert error is None + + def test_validate_port_accepts_max_allowed_port(self, sandbox_proxy_service): + """Maximum allowed port (65535) should pass validation.""" + is_valid, error = sandbox_proxy_service._validate_port(65535) + assert is_valid is True + assert error is None + + def test_validate_port_rejects_zero_port(self, sandbox_proxy_service): + """Port 0 should be rejected.""" + is_valid, error = sandbox_proxy_service._validate_port(0) + assert is_valid is False + assert error is not None + + def test_validate_port_rejects_negative_port(self, sandbox_proxy_service): + """Negative port should be rejected.""" + is_valid, error = sandbox_proxy_service._validate_port(-1) + assert is_valid is False + assert error is not None + + +class TestGetRockletPortforwardUrl: + """Tests for _get_rocklet_portforward_url method.""" + + def 
test_get_rocklet_portforward_url_returns_correct_format(self, sandbox_proxy_service): + """Should return correct WebSocket URL for rocklet portforward endpoint.""" + mock_status = { + "host_ip": "192.168.1.100", + "phases": { + "image_pull": {"status": "running", "message": "done"}, + "docker_run": {"status": "running", "message": "done"}, + }, + "port_mapping": {22555: 32768}, # PROXY port mapping (rocklet listens on 22555) + } + url = sandbox_proxy_service._get_rocklet_portforward_url(mock_status, 9000) + # Should connect to rocklet's PROXY port (22555 mapped to 32768) with /portforward path + assert url == "ws://192.168.1.100:32768/portforward?port=9000" + + +class TestWebsocketToTcpProxy: + """Tests for websocket_to_tcp_proxy method.""" + + @pytest.mark.asyncio + async def test_websocket_to_tcp_proxy_validates_port(self, sandbox_proxy_service): + """Should raise error for invalid port before attempting connection.""" + mock_websocket = MagicMock(spec=WebSocket) + + with pytest.raises(ValueError) as exc_info: + await sandbox_proxy_service.websocket_to_tcp_proxy( + client_websocket=mock_websocket, + sandbox_id="test-sandbox", + port=22, # Invalid port (SSH) + ) + assert "22" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_websocket_to_tcp_proxy_raises_for_nonexistent_sandbox(self, sandbox_proxy_service): + """Should raise error when sandbox does not exist.""" + mock_websocket = MagicMock(spec=WebSocket) + + with patch.object( + sandbox_proxy_service, "get_service_status", side_effect=Exception("sandbox not-started not started") + ): + with pytest.raises(Exception) as exc_info: + await sandbox_proxy_service.websocket_to_tcp_proxy( + client_websocket=mock_websocket, + sandbox_id="nonexistent-sandbox", + port=8080, + ) + assert "not started" in str(exc_info.value) + + +class TestPortForwardRoute: + """Tests for the portforward WebSocket route.""" + + @pytest.fixture + def app(self): + """Create a FastAPI app with the sandbox_proxy_router.""" + 
mock_service = MagicMock() + mock_service.websocket_to_tcp_proxy = AsyncMock() + set_sandbox_proxy_service(mock_service) + + app = FastAPI() + app.include_router(sandbox_proxy_router) + return app, mock_service + + def test_portforward_route_valid_port_calls_service(self, app): + """Should call websocket_to_tcp_proxy with correct parameters.""" + app, mock_service = app + client = TestClient(app) + + # Mock the service to complete immediately + mock_service.websocket_to_tcp_proxy.return_value = None + + # The WebSocket connection will be accepted and the service called + # Since the mock returns immediately, the connection should close + try: + with client.websocket_connect("/sandboxes/test-sandbox/portforward?port=8080") as websocket: + # Connection established, service should be called + pass + except Exception: + # Expected: connection may close due to mock behavior + pass + + # Verify the service was called with correct parameters + mock_service.websocket_to_tcp_proxy.assert_called_once() + call_args = mock_service.websocket_to_tcp_proxy.call_args + # sandbox_id is the second positional argument + assert call_args[0][1] == "test-sandbox" + # port is the third positional argument + assert call_args[0][2] == 8080 + + def test_portforward_route_invalid_port_rejected(self, app): + """Should reject invalid port (port 22) with close code 1008.""" + app, mock_service = app + mock_service.websocket_to_tcp_proxy.side_effect = ValueError("Port 22 is not allowed for port forwarding") + + client = TestClient(app) + + # The WebSocket connection will be closed with code 1008 + with client.websocket_connect("/sandboxes/test-sandbox/portforward?port=22") as websocket: + # Connection should be closed by server + pass + + # The service should have been called + mock_service.websocket_to_tcp_proxy.assert_called_once() \ No newline at end of file From 45ef44d4467b315009aaf5fcb9fb44abec4f9d22 Mon Sep 17 00:00:00 2001 From: Timandes White Date: Sat, 28 Feb 2026 10:40:37 +0800 
Subject: [PATCH 05/11] fix: forward WebSocket close frame from rocklet to client - Capture close frame (code, reason) when rocklet closes connection - Forward close frame to client instead of silently dropping - Properly wait for cancelled tasks to complete - Improve logging for connection close events --- rock/sandbox/service/sandbox_proxy_service.py | 51 +++++++++++++++++-- 1 file changed, 48 insertions(+), 3 deletions(-) diff --git a/rock/sandbox/service/sandbox_proxy_service.py b/rock/sandbox/service/sandbox_proxy_service.py index 8421196d6..af8fb92cd 100644 --- a/rock/sandbox/service/sandbox_proxy_service.py +++ b/rock/sandbox/service/sandbox_proxy_service.py @@ -340,10 +340,14 @@ async def websocket_to_tcp_proxy( f"[Portforward] Rocklet connection established: sandbox={sandbox_id}, target_port={port}" ) - # Create bidirectional forwarding tasks + # Create bidirectional forwarding tasks with close info tracking logger.info( f"[Portforward] Starting bidirectional forwarding: sandbox={sandbox_id}, target_port={port}" ) + + # Track close frame info from rocklet + rocklet_close_info = {} + client_to_target = asyncio.create_task( self._forward_portforward_messages( client_websocket, target_websocket, "client->rocklet", idle_timeout @@ -351,7 +355,8 @@ async def websocket_to_tcp_proxy( ) target_to_client = asyncio.create_task( self._forward_portforward_messages( - target_websocket, client_websocket, "rocklet->client", idle_timeout + target_websocket, client_websocket, "rocklet->client", idle_timeout, + close_info=rocklet_close_info ) ) @@ -372,6 +377,23 @@ async def websocket_to_tcp_proxy( f"task={task.get_name()}" ) task.cancel() + try: + await task # Wait for cancellation to complete + except asyncio.CancelledError: + pass + + # If rocklet closed the connection, forward close frame to client + if rocklet_close_info.get('direction') == 'rocklet->client': + close_code = rocklet_close_info.get('code', 1000) + close_reason = rocklet_close_info.get('reason', '') + 
logger.info( + f"[Portforward] Forwarding close to client: sandbox={sandbox_id}, target_port={port}, " + f"code={close_code}, reason={close_reason}" + ) + try: + await client_websocket.close(code=close_code, reason=close_reason) + except Exception as e: + logger.debug(f"[Portforward] Error closing client connection: {e}") except asyncio.TimeoutError as e: logger.error( @@ -384,7 +406,7 @@ async def websocket_to_tcp_proxy( f"[Portforward] Rocklet connection closed: sandbox={sandbox_id}, target_port={port}, " f"code={e.code}, reason={e.reason}" ) - await client_websocket.close(code=1011, reason=f"Rocklet connection closed: {e.reason}") + await client_websocket.close(code=e.code, reason=e.reason or "") except Exception as e: logger.error( f"[Portforward] Proxy error: sandbox={sandbox_id}, target_port={port}, " @@ -399,12 +421,22 @@ async def _forward_portforward_messages( destination, direction: str, idle_timeout: float, + close_info: dict | None = None, ): """Forward binary messages between WebSocket connections with idle timeout. Handles both FastAPI WebSocket and websockets library ClientConnection objects: - FastAPI WebSocket: receive_bytes(), send_bytes() - websockets ClientConnection: recv(), send() + + When source connection closes, captures close frame info for forwarding. 
+ + Args: + source: Source WebSocket connection + destination: Destination WebSocket connection + direction: Direction string for logging (e.g., "client->rocklet") + idle_timeout: Idle timeout in seconds + close_info: Optional dict to store close frame info (code, reason) """ logger.debug(f"[Portforward] Starting message forwarder: direction={direction}, idle_timeout={idle_timeout}s") bytes_transferred = 0 @@ -459,6 +491,19 @@ async def _forward_portforward_messages( f"total_messages={message_count}, total_bytes={bytes_transferred}" ) break + except websockets.exceptions.ConnectionClosed as e: + # Capture close frame info for forwarding + close_code = e.code + close_reason = e.reason or "" + logger.info( + f"[Portforward] Source connection closed: direction={direction}, " + f"code={close_code}, reason={close_reason}, " + f"total_messages={message_count}, total_bytes={bytes_transferred}" + ) + if close_info is not None: + close_info['code'] = close_code + close_info['reason'] = close_reason + close_info['direction'] = direction except Exception as e: logger.info( f"[Portforward] Forwarder stopped: direction={direction}, " From f99fd3f7882c424aaeb43f33bea4ec3db11eb951 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9A=86=E5=AE=87?= Date: Fri, 6 Mar 2026 12:21:26 +0800 Subject: [PATCH 06/11] refactor: extract port validation to shared module Address PR #523 review feedback - duplicate _validate_port function exists in both sandbox_proxy_service.py and local_api.py. 
- Add rock/common/port_validation.py with shared validate_port_forward_port() - Update sandbox_proxy_service.py to use shared function - Update local_api.py to use shared function - Add unit tests for the shared module --- rock/common/port_validation.py | 40 ++++++++++ rock/rocklet/local_api.py | 37 +-------- rock/sandbox/service/sandbox_proxy_service.py | 39 +-------- tests/unit/common/test_port_validation.py | 80 +++++++++++++++++++ .../unit/sandbox/test_websocket_tcp_proxy.py | 35 ++++---- 5 files changed, 143 insertions(+), 88 deletions(-) create mode 100644 rock/common/port_validation.py create mode 100644 tests/unit/common/test_port_validation.py diff --git a/rock/common/port_validation.py b/rock/common/port_validation.py new file mode 100644 index 000000000..17cb1fbc2 --- /dev/null +++ b/rock/common/port_validation.py @@ -0,0 +1,40 @@ +"""Port validation utilities for port forwarding.""" +from rock.logger import init_logger + +logger = init_logger(__name__) + +# Port forwarding constants +PORT_FORWARD_MIN_PORT = 1024 +PORT_FORWARD_MAX_PORT = 65535 +PORT_FORWARD_EXCLUDED_PORTS = {22} # SSH port + + +def validate_port_forward_port(port: int) -> tuple[bool, str | None]: + """Validate if the port is allowed for port forwarding. + + Args: + port: The port number to validate. + + Returns: + A tuple of (is_valid, error_message). + If valid, error_message is None. 
+ """ + logger.debug( + f"[Portforward] Validating port: port={port}, " + f"min={PORT_FORWARD_MIN_PORT}, max={PORT_FORWARD_MAX_PORT}, " + f"excluded={PORT_FORWARD_EXCLUDED_PORTS}" + ) + if port < PORT_FORWARD_MIN_PORT: + error_msg = f"Port {port} is below minimum allowed port {PORT_FORWARD_MIN_PORT}" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + if port > PORT_FORWARD_MAX_PORT: + error_msg = f"Port {port} is above maximum allowed port {PORT_FORWARD_MAX_PORT}" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + if port in PORT_FORWARD_EXCLUDED_PORTS: + error_msg = f"Port {port} is not allowed for port forwarding" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + logger.debug(f"[Portforward] Port validation passed: port={port}") + return True, None diff --git a/rock/rocklet/local_api.py b/rock/rocklet/local_api.py index 61a29096a..97ebd6d02 100644 --- a/rock/rocklet/local_api.py +++ b/rock/rocklet/local_api.py @@ -25,6 +25,7 @@ from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest +from rock.common.port_validation import validate_port_forward_port from rock.logger import init_logger from rock.rocklet.local_sandbox import LocalSandboxRuntime from rock.utils import get_executor @@ -33,10 +34,7 @@ local_router = APIRouter() -# Constants for port forwarding -MIN_ALLOWED_PORT = 1024 -MAX_ALLOWED_PORT = 65535 -FORBIDDEN_PORTS = {22} # SSH port is not allowed +# Timeouts for port forwarding TCP_CONNECT_TIMEOUT = 10 # seconds IDLE_TIMEOUT = 300 # seconds @@ -143,35 +141,6 @@ async def env_list() -> EnvListResponse: return runtime.env_list() -def _validate_port(port: int) -> tuple[bool, str]: - """Validate port number for port forwarding. 
- - Args: - port: Port number to validate - - Returns: - Tuple of (is_valid, error_message) - """ - logger.debug( - f"[Portforward] Validating port: port={port}, " - f"min={MIN_ALLOWED_PORT}, max={MAX_ALLOWED_PORT}, forbidden={FORBIDDEN_PORTS}" - ) - if port < MIN_ALLOWED_PORT: - error_msg = f"Port {port} is not allowed. Minimum allowed port is {MIN_ALLOWED_PORT}" - logger.warning(f"[Portforward] Port validation failed: {error_msg}") - return False, error_msg - if port > MAX_ALLOWED_PORT: - error_msg = f"Port {port} is not allowed. Maximum allowed port is {MAX_ALLOWED_PORT}" - logger.warning(f"[Portforward] Port validation failed: {error_msg}") - return False, error_msg - if port in FORBIDDEN_PORTS: - error_msg = f"Port {port} is not allowed for port forwarding" - logger.warning(f"[Portforward] Port validation failed: {error_msg}") - return False, error_msg - logger.debug(f"[Portforward] Port validation passed: port={port}") - return True, "" - - @local_router.websocket("/portforward") async def portforward(websocket: WebSocket, port: int): """WebSocket endpoint for TCP port forwarding. 
@@ -197,7 +166,7 @@ async def portforward(websocket: WebSocket, port: int): # Validate port logger.debug(f"[Portforward] Validating port: target_port={port}") - is_valid, error_message = _validate_port(port) + is_valid, error_message = validate_port_forward_port(port) if not is_valid: logger.warning(f"[Portforward] Port validation failed: target_port={port}, reason={error_message}") await websocket.close(code=1008, reason=error_message) diff --git a/rock/sandbox/service/sandbox_proxy_service.py b/rock/sandbox/service/sandbox_proxy_service.py index af8fb92cd..50ce31d8f 100644 --- a/rock/sandbox/service/sandbox_proxy_service.py +++ b/rock/sandbox/service/sandbox_proxy_service.py @@ -38,6 +38,7 @@ from rock.config import OssConfig, ProxyServiceConfig, RockConfig from rock.deployments.constants import Port from rock.deployments.status import ServiceStatus +from rock.common.port_validation import validate_port_forward_port from rock.logger import init_logger from rock.sdk.common.exceptions import BadRequestRockError from rock.utils import EAGLE_EYE_TRACE_ID, trace_id_ctx_var @@ -78,42 +79,6 @@ def __init__(self, rock_config: RockConfig, redis_provider: RedisProvider | None self._batch_get_status_max_count = rock_config.proxy_service.batch_get_status_max_count - # Port forwarding constants - PORT_FORWARD_MIN_PORT = 1024 - PORT_FORWARD_MAX_PORT = 65535 - PORT_FORWARD_EXCLUDED_PORTS = {22} # SSH port - - def _validate_port(self, port: int) -> tuple[bool, str | None]: - """ - Validate if the port is allowed for port forwarding. - - Args: - port: The port number to validate. - - Returns: - A tuple of (is_valid, error_message). - If valid, error_message is None. 
- """ - logger.debug( - f"[Portforward] Validating port: port={port}, " - f"min={self.PORT_FORWARD_MIN_PORT}, max={self.PORT_FORWARD_MAX_PORT}, " - f"excluded={self.PORT_FORWARD_EXCLUDED_PORTS}" - ) - if port < self.PORT_FORWARD_MIN_PORT: - error_msg = f"Port {port} is below minimum allowed port {self.PORT_FORWARD_MIN_PORT}" - logger.warning(f"[Portforward] Port validation failed: {error_msg}") - return False, error_msg - if port > self.PORT_FORWARD_MAX_PORT: - error_msg = f"Port {port} is above maximum allowed port {self.PORT_FORWARD_MAX_PORT}" - logger.warning(f"[Portforward] Port validation failed: {error_msg}") - return False, error_msg - if port in self.PORT_FORWARD_EXCLUDED_PORTS: - error_msg = f"Port {port} is not allowed for port forwarding" - logger.warning(f"[Portforward] Port validation failed: {error_msg}") - return False, error_msg - logger.debug(f"[Portforward] Port validation passed: port={port}") - return True, None - @monitor_sandbox_operation() async def create_session(self, request: CreateSessionRequest) -> CreateBashSessionResponse: sandbox_id = request.sandbox_id @@ -298,7 +263,7 @@ async def websocket_to_tcp_proxy( # Validate port logger.debug(f"[Portforward] Validating port: sandbox={sandbox_id}, target_port={port}") - is_valid, error_msg = self._validate_port(port) + is_valid, error_msg = validate_port_forward_port(port) if not is_valid: logger.warning(f"[Portforward] Port validation failed: sandbox={sandbox_id}, target_port={port}, reason={error_msg}") raise ValueError(error_msg) diff --git a/tests/unit/common/test_port_validation.py b/tests/unit/common/test_port_validation.py new file mode 100644 index 000000000..b490da5b2 --- /dev/null +++ b/tests/unit/common/test_port_validation.py @@ -0,0 +1,80 @@ +"""Tests for common port validation utilities.""" +import pytest + +from rock.common.port_validation import ( + PORT_FORWARD_MIN_PORT, + PORT_FORWARD_MAX_PORT, + PORT_FORWARD_EXCLUDED_PORTS, + validate_port_forward_port, +) + + +class 
TestPortValidationConstants: + """Tests for port validation constants.""" + + def test_min_port_is_1024(self): + """Minimum allowed port should be 1024.""" + assert PORT_FORWARD_MIN_PORT == 1024 + + def test_max_port_is_65535(self): + """Maximum allowed port should be 65535.""" + assert PORT_FORWARD_MAX_PORT == 65535 + + def test_excluded_ports_contains_22(self): + """SSH port 22 should be excluded.""" + assert 22 in PORT_FORWARD_EXCLUDED_PORTS + + +class TestValidatePortForwardPort: + """Tests for validate_port_forward_port function.""" + + def test_accepts_valid_port_in_range(self): + """Valid port in allowed range should pass validation.""" + is_valid, error = validate_port_forward_port(8080) + assert is_valid is True + assert error is None + + def test_accepts_min_allowed_port(self): + """Minimum allowed port (1024) should pass validation.""" + is_valid, error = validate_port_forward_port(1024) + assert is_valid is True + assert error is None + + def test_accepts_max_allowed_port(self): + """Maximum allowed port (65535) should pass validation.""" + is_valid, error = validate_port_forward_port(65535) + assert is_valid is True + assert error is None + + def test_rejects_port_below_1024(self): + """Port below 1024 should be rejected.""" + is_valid, error = validate_port_forward_port(1023) + assert is_valid is False + assert error is not None + assert "1024" in error + + def test_rejects_port_above_65535(self): + """Port above 65535 should be rejected.""" + is_valid, error = validate_port_forward_port(65536) + assert is_valid is False + assert error is not None + assert "65535" in error + + def test_rejects_port_22(self): + """SSH port 22 should be rejected.""" + is_valid, error = validate_port_forward_port(22) + assert is_valid is False + assert error is not None + assert "22" in error + + def test_rejects_zero_port(self): + """Port 0 should be rejected.""" + is_valid, error = validate_port_forward_port(0) + assert is_valid is False + assert error is not None + + 
def test_rejects_negative_port(self): + """Negative port should be rejected.""" + is_valid, error = validate_port_forward_port(-1) + assert is_valid is False + assert error is not None diff --git a/tests/unit/sandbox/test_websocket_tcp_proxy.py b/tests/unit/sandbox/test_websocket_tcp_proxy.py index 33b7261eb..cae9f9570 100644 --- a/tests/unit/sandbox/test_websocket_tcp_proxy.py +++ b/tests/unit/sandbox/test_websocket_tcp_proxy.py @@ -6,56 +6,57 @@ from fastapi.testclient import TestClient from rock.admin.entrypoints.sandbox_proxy_api import sandbox_proxy_router, set_sandbox_proxy_service +from rock.common.port_validation import validate_port_forward_port class TestValidatePort: - """Tests for _validate_port method.""" + """Tests for port validation function.""" - def test_validate_port_accepts_valid_port_in_range(self, sandbox_proxy_service): + def test_validate_port_accepts_valid_port_in_range(self): """Valid port in allowed range should pass validation.""" - is_valid, error = sandbox_proxy_service._validate_port(8080) + is_valid, error = validate_port_forward_port(8080) assert is_valid is True assert error is None - def test_validate_port_rejects_port_below_1024(self, sandbox_proxy_service): + def test_validate_port_rejects_port_below_1024(self): """Port below 1024 should be rejected.""" - is_valid, error = sandbox_proxy_service._validate_port(1023) + is_valid, error = validate_port_forward_port(1023) assert is_valid is False assert "1024" in error - def test_validate_port_rejects_port_22(self, sandbox_proxy_service): + def test_validate_port_rejects_port_22(self): """SSH port 22 should be rejected even though it's in the excluded list.""" - is_valid, error = sandbox_proxy_service._validate_port(22) + is_valid, error = validate_port_forward_port(22) assert is_valid is False assert "22" in error - def test_validate_port_rejects_port_above_65535(self, sandbox_proxy_service): + def test_validate_port_rejects_port_above_65535(self): """Port above 65535 should be 
rejected.""" - is_valid, error = sandbox_proxy_service._validate_port(65536) + is_valid, error = validate_port_forward_port(65536) assert is_valid is False assert "65535" in error - def test_validate_port_accepts_min_allowed_port(self, sandbox_proxy_service): + def test_validate_port_accepts_min_allowed_port(self): """Minimum allowed port (1024) should pass validation.""" - is_valid, error = sandbox_proxy_service._validate_port(1024) + is_valid, error = validate_port_forward_port(1024) assert is_valid is True assert error is None - def test_validate_port_accepts_max_allowed_port(self, sandbox_proxy_service): + def test_validate_port_accepts_max_allowed_port(self): """Maximum allowed port (65535) should pass validation.""" - is_valid, error = sandbox_proxy_service._validate_port(65535) + is_valid, error = validate_port_forward_port(65535) assert is_valid is True assert error is None - def test_validate_port_rejects_zero_port(self, sandbox_proxy_service): + def test_validate_port_rejects_zero_port(self): """Port 0 should be rejected.""" - is_valid, error = sandbox_proxy_service._validate_port(0) + is_valid, error = validate_port_forward_port(0) assert is_valid is False assert error is not None - def test_validate_port_rejects_negative_port(self, sandbox_proxy_service): + def test_validate_port_rejects_negative_port(self): """Negative port should be rejected.""" - is_valid, error = sandbox_proxy_service._validate_port(-1) + is_valid, error = validate_port_forward_port(-1) assert is_valid is False assert error is not None From de9ce6b8fa25e20dbb26fddf1b93529a597b7f36 Mon Sep 17 00:00:00 2001 From: Timandes White Date: Fri, 6 Mar 2026 18:06:56 +0800 Subject: [PATCH 07/11] chore: move specs to docs/_specs to hide from Docusaurus --- {.iflow/specs => docs/_specs}/websocket-tcp-proxy/plan.md | 0 {.iflow/specs => docs/_specs}/websocket-tcp-proxy/spec.md | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename {.iflow/specs => docs/_specs}/websocket-tcp-proxy/plan.md 
(100%) rename {.iflow/specs => docs/_specs}/websocket-tcp-proxy/spec.md (100%) diff --git a/.iflow/specs/websocket-tcp-proxy/plan.md b/docs/_specs/websocket-tcp-proxy/plan.md similarity index 100% rename from .iflow/specs/websocket-tcp-proxy/plan.md rename to docs/_specs/websocket-tcp-proxy/plan.md diff --git a/.iflow/specs/websocket-tcp-proxy/spec.md b/docs/_specs/websocket-tcp-proxy/spec.md similarity index 100% rename from .iflow/specs/websocket-tcp-proxy/spec.md rename to docs/_specs/websocket-tcp-proxy/spec.md From b3b77540b057b806e86dacc0f943d7a8890eb759 Mon Sep 17 00:00:00 2001 From: ShixinPeng <58160712+BCeZn@users.noreply.github.com> Date: Mon, 9 Mar 2026 11:28:25 +0800 Subject: [PATCH 08/11] ci: remove CI request trigger workflow file (#588) --- .github/workflows/CI-request-trigger.yml | 49 ------------------------ 1 file changed, 49 deletions(-) delete mode 100644 .github/workflows/CI-request-trigger.yml diff --git a/.github/workflows/CI-request-trigger.yml b/.github/workflows/CI-request-trigger.yml deleted file mode 100644 index 319c205d2..000000000 --- a/.github/workflows/CI-request-trigger.yml +++ /dev/null @@ -1,49 +0,0 @@ -# This is a basic workflow to help you get started with Actions - -name: CI Request Trigger - -# Controls when the workflow will run -on: - pull_request: - branches: [ "master" ] - - # Allows you to run this workflow manually from the Actions tab - workflow_dispatch: - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number }} - cancel-in-progress: true - -# A workflow run is made up of one or more jobs that can run sequentially or in parallel -jobs: - # This workflow contains a single job called "build" - build: - # The type of runner that the job will run on - runs-on: ubuntu-latest - # work on CI script dir - defaults: - run: - working-directory: .github/scripts - # Steps represent a sequence of tasks that will be executed as part of the job - steps: - # Checks-out your repository under 
$GITHUB_WORKSPACE, so your job can access it - - uses: actions/checkout@v4 - - # Runs trigger CI - - name: Make the script files executable - run: chmod +x trigger-CI.sh get-CI-result.sh - - name: trigger a CI - run: | - COMMIT_ID=$([ "${{ github.event_name }}" == "pull_request" ] && echo "${{ github.event.pull_request.head.sha }}" || echo "${{ github.sha }}") - echo "Using Commit ID: $COMMIT_ID" - echo "$GITHUB_REF" - PR_ID=$(echo "$GITHUB_REF" | sed 's@refs/pull/\([0-9]\+\)/.*@\1@') - echo "PR ID is $PR_ID" - ./trigger-CI.sh "$COMMIT_ID" "${{ secrets.CI_SECRET }}" "${{ github.event.pull_request.head.repo.clone_url }}" "$PR_ID" - - # Runs get CI result - - name: Get CI result - run: | - COMMIT_ID=$([ "${{ github.event_name }}" == "pull_request" ] && echo "${{ github.event.pull_request.head.sha }}" || echo "${{ github.sha }}") - echo "Using Commit ID: $COMMIT_ID" - ./get-CI-result.sh "$COMMIT_ID" "${{ secrets.CI_SECRET }}" "${{ github.repository }}" \ No newline at end of file From bc80ba5d1ff49f8621e9b05721141574b3bd50e6 Mon Sep 17 00:00:00 2001 From: Fangwen DAI <34794181+FangwenDave@users.noreply.github.com> Date: Mon, 9 Mar 2026 22:05:21 +0800 Subject: [PATCH 09/11] feat: add lifespan rt metrics (#590) --- rock/sandbox/base_actor.py | 33 ++++++---- tests/unit/test_base_actor.py | 111 ++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+), 12 deletions(-) diff --git a/rock/sandbox/base_actor.py b/rock/sandbox/base_actor.py index 32712b322..3a49ab61e 100644 --- a/rock/sandbox/base_actor.py +++ b/rock/sandbox/base_actor.py @@ -37,6 +37,7 @@ class BaseActor: _namespace = "default" _metrics_endpoint = "" _user_defined_tags: dict = {} + _created_time: float = None def __init__( self, @@ -48,6 +49,7 @@ def __init__( self._gauges: dict[str, _Gauge] = {} if isinstance(config, DockerDeploymentConfig) and config.auto_clear_time: self._auto_clear_time_in_minutes = config.auto_clear_time + self._created_time = time.monotonic() self._stop_time = 
datetime.datetime.now() + datetime.timedelta(minutes=self._auto_clear_time_in_minutes) # Initialize the user and environment info - can be overridden by subclasses self._role = "test" @@ -103,6 +105,9 @@ def _init_monitor(self): self._gauges["net"] = self.meter.create_gauge( name="xrl_gateway.system.network", description="Network Usage", unit="1" ) + self._gauges["rt"] = self.meter.create_gauge( + name="xrl_gateway.system.lifespan_rt", description="Life Span Rt", unit="1" + ) async def _setup_monitor(self): if not env_vars.ROCK_MONITOR_ENABLE: @@ -152,19 +157,20 @@ async def _collect_sandbox_metrics(self, sandbox_id: str): return logger.debug(f"sandbox [{sandbox_id}] metrics = {metrics}") + attributes = { + "sandbox_id": sandbox_id, + "env": self._env, + "role": self._role, + "host": self.host, + "ip": self._ip, + "user_id": self._user_id, + "experiment_id": self._experiment_id, + "namespace": self._namespace, + } + if self._user_defined_tags is not None: + attributes.update(self._user_defined_tags) + if metrics.get("cpu") is not None: - attributes = { - "sandbox_id": sandbox_id, - "env": self._env, - "role": self._role, - "host": self.host, - "ip": self._ip, - } - if self._user_defined_tags is not None: - attributes.update(self._user_defined_tags) - attributes["user_id"] = self._user_id - attributes["experiment_id"] = self._experiment_id - attributes["namespace"] = self._namespace self._gauges["cpu"].set(metrics["cpu"], attributes=attributes) self._gauges["mem"].set(metrics["mem"], attributes=attributes) self._gauges["disk"].set(metrics["disk"], attributes=attributes) @@ -173,6 +179,9 @@ async def _collect_sandbox_metrics(self, sandbox_id: str): logger.debug(f"Successfully reported metrics for sandbox: {sandbox_id}") else: logger.warning(f"No metrics returned for sandbox: {sandbox_id}") + + life_span_rt = time.monotonic() - self._created_time + self._gauges["rt"].set(life_span_rt, attributes=attributes) single_sandbox_report_rt = time.perf_counter() - start 
logger.debug(f"Single sandbox report rt:{single_sandbox_report_rt:.4f}s") diff --git a/tests/unit/test_base_actor.py b/tests/unit/test_base_actor.py index daa0aa4d9..765686f1d 100644 --- a/tests/unit/test_base_actor.py +++ b/tests/unit/test_base_actor.py @@ -1,8 +1,12 @@ +import datetime +from unittest.mock import MagicMock + import pytest import ray from rock.deployments.config import LocalDeploymentConfig from rock.logger import init_logger +from rock.sandbox.base_actor import BaseActor from rock.sandbox.sandbox_actor import SandboxActor logger = init_logger(__name__) @@ -131,3 +135,110 @@ async def test_user_defined_tags_with_empty_dict(ray_init_shutdown): logger.info(f"Empty dict set successfully: {result}") finally: ray.kill(sandbox_actor) + + +class ConcreteBaseActor(BaseActor): + """Minimal concrete subclass used only for unit testing BaseActor.""" + + async def get_sandbox_statistics(self): + return {"cpu": 10.0, "mem": 20.0, "disk": 30.0, "net": 40.0} + + +def _make_actor() -> ConcreteBaseActor: + """Create a ConcreteBaseActor with lightweight mocked dependencies.""" + config = MagicMock() + config.container_name = "test-container" + config.auto_clear_time = None # skip DockerDeploymentConfig branch + + deployment = MagicMock() + deployment.__class__ = object # make isinstance(deployment, DockerDeployment) return False + + actor = ConcreteBaseActor(config, deployment) + actor.host = "127.0.0.1" + # Pre-populate all gauges with mocks so tests can override selectively + for key in ("cpu", "mem", "disk", "net", "rt"): + actor._gauges[key] = MagicMock() + return actor + + +@pytest.mark.asyncio +async def test_life_span_rt_gauge_is_set_during_metrics_collection(): + """life_span_rt gauge must be set with the elapsed timedelta after collection.""" + actor = _make_actor() + mock_rt_gauge = MagicMock() + actor._gauges["rt"] = mock_rt_gauge + + await actor._collect_sandbox_metrics("test-container") + + assert mock_rt_gauge.set.called, "life_span_rt gauge.set() was 
never called"
+    life_span_rt_value = mock_rt_gauge.set.call_args[0][0]
+    assert isinstance(life_span_rt_value, float), f"Expected float, got {type(life_span_rt_value)}"
+    assert life_span_rt_value >= 0, "life_span_rt must be non-negative"
+
+
+@pytest.mark.asyncio
+async def test_life_span_rt_increases_over_time():
+    """life_span_rt reported on a second call must be >= the first call's value."""
+    actor = _make_actor()
+    mock_rt_gauge = MagicMock()
+    actor._gauges["rt"] = mock_rt_gauge
+
+    await actor._collect_sandbox_metrics("test-container")
+    first_rt: float = mock_rt_gauge.set.call_args[0][0]
+
+    await actor._collect_sandbox_metrics("test-container")
+    second_rt: float = mock_rt_gauge.set.call_args[0][0]
+
+    assert second_rt >= first_rt, f"life_span_rt should be non-decreasing: first={first_rt}, second={second_rt}"
+
+
+@pytest.mark.asyncio
+async def test_life_span_rt_attributes_contain_expected_keys():
+    """Attributes passed to life_span_rt gauge must include all standard dimension keys."""
+    actor = _make_actor()
+    actor._env = "prod"
+    actor._role = "worker"
+    actor._user_id = "user-42"
+    actor._experiment_id = "exp-7"
+    actor._namespace = "ns-test"
+    actor.host = "10.0.0.1"
+
+    mock_rt_gauge = MagicMock()
+    actor._gauges["rt"] = mock_rt_gauge
+
+    await actor._collect_sandbox_metrics("test-container")
+
+    attributes = mock_rt_gauge.set.call_args[1]["attributes"]
+    expected_keys = {"sandbox_id", "env", "role", "host", "ip", "user_id", "experiment_id", "namespace"}
+    assert expected_keys.issubset(attributes.keys()), f"Missing attribute keys: {expected_keys - attributes.keys()}"
+    assert attributes["env"] == "prod"
+    assert attributes["role"] == "worker"
+    assert attributes["user_id"] == "user-42"
+    assert attributes["experiment_id"] == "exp-7"
+    assert attributes["namespace"] == "ns-test"
+
+
+@pytest.mark.asyncio
+async def test_life_span_rt_set_even_when_no_cpu_metrics():
+    """life_span_rt must be reported even when 
get_sandbox_statistics returns no cpu data.""" + + class NoCpuActor(BaseActor): + async def get_sandbox_statistics(self): + return {} # cpu key absent + + config = MagicMock() + config.container_name = "no-cpu-container" + config.auto_clear_time = None + deployment = MagicMock() + deployment.__class__ = object + + actor = NoCpuActor(config, deployment) + actor.host = "127.0.0.1" + for key in ("cpu", "mem", "disk", "net", "rt"): + actor._gauges[key] = MagicMock() + + mock_rt_gauge = actor._gauges["rt"] + + await actor._collect_sandbox_metrics("no-cpu-container") + + assert mock_rt_gauge.set.called, "life_span_rt gauge.set() must be called even when cpu metrics are absent" From db878f122ef5a4b6b2803a8bb46fc15fe9802d58 Mon Sep 17 00:00:00 2001 From: Issac-Newton <1556820213@qq.com> Date: Mon, 9 Mar 2026 16:47:33 +0800 Subject: [PATCH 10/11] add enable_remove_container config --- rock/config.py | 1 + rock/deployments/manager.py | 1 + 2 files changed, 2 insertions(+) diff --git a/rock/config.py b/rock/config.py index 4a28e0ccd..23b5852ef 100644 --- a/rock/config.py +++ b/rock/config.py @@ -50,6 +50,7 @@ class SandboxConfig: actor_resource: str = "" actor_resource_num: float = 0.0 gateway_num: int = 1 + enable_remove_container: bool = True @dataclass diff --git a/rock/deployments/manager.py b/rock/deployments/manager.py index 858f46a85..b00cbdc76 100644 --- a/rock/deployments/manager.py +++ b/rock/deployments/manager.py @@ -40,6 +40,7 @@ async def init_config(self, config: DeploymentConfig) -> DockerDeploymentConfig: await self.rock_config.update() docker_deployment_config.actor_resource = self.rock_config.sandbox_config.actor_resource docker_deployment_config.actor_resource_num = self.rock_config.sandbox_config.actor_resource_num + docker_deployment_config.remove_container = self.rock_config.sandbox_config.enable_remove_container return docker_deployment_config def get_deployment(self, config: DeploymentConfig) -> AbstractDeployment: From 
fe1c7f060b11d412bc5ef9d6ca12296d420bda37 Mon Sep 17 00:00:00 2001 From: Issac-Newton <1556820213@qq.com> Date: Tue, 10 Mar 2026 17:59:48 +0800 Subject: [PATCH 11/11] rename: enable_remove_container to remove_container_enabled --- rock/config.py | 2 +- rock/deployments/manager.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/rock/config.py b/rock/config.py index 23b5852ef..b83154b32 100644 --- a/rock/config.py +++ b/rock/config.py @@ -50,7 +50,7 @@ class SandboxConfig: actor_resource: str = "" actor_resource_num: float = 0.0 gateway_num: int = 1 - enable_remove_container: bool = True + remove_container_enabled: bool = True @dataclass diff --git a/rock/deployments/manager.py b/rock/deployments/manager.py index b00cbdc76..d64820ee6 100644 --- a/rock/deployments/manager.py +++ b/rock/deployments/manager.py @@ -40,7 +40,7 @@ async def init_config(self, config: DeploymentConfig) -> DockerDeploymentConfig: await self.rock_config.update() docker_deployment_config.actor_resource = self.rock_config.sandbox_config.actor_resource docker_deployment_config.actor_resource_num = self.rock_config.sandbox_config.actor_resource_num - docker_deployment_config.remove_container = self.rock_config.sandbox_config.enable_remove_container + docker_deployment_config.remove_container = self.rock_config.sandbox_config.remove_container_enabled return docker_deployment_config def get_deployment(self, config: DeploymentConfig) -> AbstractDeployment: