diff --git a/docs/_specs/websocket-tcp-proxy/plan.md b/docs/_specs/websocket-tcp-proxy/plan.md new file mode 100644 index 000000000..f54ff89d8 --- /dev/null +++ b/docs/_specs/websocket-tcp-proxy/plan.md @@ -0,0 +1,118 @@ +# WebSocket TCP 端口代理服务 - 实现计划 + +## 实现步骤 + +### 步骤 1: 在 rocklet 中添加端口转发端点 + +**文件**: `rock/rocklet/local_api.py` + +添加 WebSocket 端点: +```python +@sandbox_proxy_router.websocket("/portforward") +async def portforward(websocket: WebSocket, port: int): + """ + 容器内 TCP 端口代理端点。 + + 流程: + 1. 接受 WebSocket 连接 + 2. 验证端口 (1024-65535, 排除 22) + 3. 连接到 127.0.0.1:{port} + 4. 双向转发二进制数据 + """ +``` + +核心实现: +- 使用 `asyncio.open_connection("127.0.0.1", port)` 连接本地 TCP 端口 +- 创建两个异步任务进行双向数据转发 +- WebSocket 使用 `receive_bytes()` 和 `send_bytes()` 处理二进制数据 + +### 步骤 2: 在 SandboxProxyService 中添加辅助方法 + +**文件**: `rock/sandbox/service/sandbox_proxy_service.py` + +添加方法: +```python +def _validate_port(self, port: int) -> tuple[bool, str | None]: + """验证端口是否在允许范围内。""" + +def _get_rocklet_portforward_url(self, sandbox_status_dict: dict, port: int) -> str: + """获取 rocklet portforward 端点的 WebSocket URL。""" +``` + +URL 格式:`ws://{host_ip}:{mapped_port}/portforward?port={target_port}` + +### 步骤 3: 在 SandboxProxyService 中添加代理方法 + +**文件**: `rock/sandbox/service/sandbox_proxy_service.py` + +添加方法: +```python +async def websocket_to_tcp_proxy( + self, + client_websocket: WebSocket, + sandbox_id: str, + port: int, + tcp_connect_timeout: float = 10.0, + idle_timeout: float = 300.0, +) -> None: + """ + 将客户端 WebSocket 连接代理到 rocklet 的 portforward 端点。 + + 流程: + 1. 验证端口 + 2. 获取沙箱状态 + 3. 构建 rocklet portforward URL + 4. 连接到 rocklet WebSocket + 5. 双向转发二进制数据 + """ +``` + +核心实现: +- 使用 `websockets.connect()` 连接 rocklet +- 复用现有的 WebSocket 双向转发模式 + +### 步骤 4: 添加外部 WebSocket 路由端点 + +**文件**: `rock/admin/entrypoints/sandbox_proxy_api.py` + +添加路由: +```python +@sandbox_proxy_router.websocket("/sandboxes/{id}/portforward") +async def portforward(websocket: WebSocket, id: str, port: int): + """ + 外部 WebSocket TCP 端口代理端点。 + """ +``` + +## 架构说明 + +``` +客户端 ──WebSocket──▶ Proxy服务 ──WebSocket──▶ rocklet ──TCP──▶ 目标端口 + │ │ + 外部代理层 内部代理层 + (sandbox_proxy_api.py) (local_api.py) +``` + +**为什么需要两层?** + +Docker 容器只暴露预定义端口(如 8080),无法直接访问容器内动态启动的服务端口。rocklet 运行在容器内,可以直接访问 `127.0.0.1:{任意端口}`。 + +## 依赖 + +无需新增外部依赖: +- `asyncio` - TCP 连接和异步任务 +- `websockets` - WebSocket 客户端(已依赖) + +## 测试计划 + +### 单元测试 +1. 端口验证逻辑 (`_validate_port`) +2. rocklet portforward URL 构建 (`_get_rocklet_portforward_url`) +3. rocklet 端点存在性验证 + +### 集成测试 +1. 正常连接和断开 +2. 双向二进制数据传输 +3. 端口限制验证 +4. 超时处理 +5. 多客户端并发连接 \ No newline at end of file diff --git a/docs/_specs/websocket-tcp-proxy/spec.md b/docs/_specs/websocket-tcp-proxy/spec.md new file mode 100644 index 000000000..10a9ac30a --- /dev/null +++ b/docs/_specs/websocket-tcp-proxy/spec.md @@ -0,0 +1,118 @@ +# WebSocket TCP 端口代理服务 + +## 1. 概述 + +在 ROCK 沙箱代理服务层实现 **WebSocket → TCP** 端口代理,允许客户端通过 WebSocket 连接访问沙箱容器内监听的 TCP 端口。 + +## 2. 架构设计 + +### 2.1 架构图 + +``` +┌─────────────────┐ ┌─────────────────┐ ┌─────────────────────────────┐ +│ WebSocket 客户端 │────▶│ Proxy 服务 │────▶│ rocklet (容器内 8080 端口) │ +└─────────────────┘ └─────────────────┘ └─────────────────────────────┘ + /portforward?port=X │ + │ TCP 连接 + ▼ + ┌─────────────────┐ + │ 沙箱内 TCP 端口 │ + │ (如 3000, 8080) │ + └─────────────────┘ +``` + +### 2.2 两层代理架构 + +| 层级 | 组件 | 职责 | +|-----|------|------| +| **外部代理层** | `sandbox_proxy_service.py` | 接收客户端 WebSocket 连接,转发到 rocklet | +| **内部代理层** | `rocklet/local_api.py` | 在容器内连接目标 TCP 端口,实现实际代理 | + +### 2.3 为什么需要两层? + +Docker 容器默认只暴露预定义端口(如 8080),无法直接从外部访问容器内动态启动的服务端口。通过 rocklet 层代理,可以访问容器内任意 TCP 端口。 + +## 3. 功能规格 + +### 3.1 外部 API(客户端调用) + +| 项目 | 规格 | +|-----|------| +| **端点路径** | `WS /sandboxes/{sandbox_id}/portforward?port={port}` | +| **sandbox_id** | 沙箱实例的唯一标识符 | +| **port** | Query 参数,指定要代理的目标 TCP 端口号 | + +**示例:** +``` +ws://localhost:8080/sandboxes/abc123/portforward?port=3000 +``` + +### 3.2 内部 API(rocklet 端点) + +| 项目 | 规格 | +|-----|------| +| **端点路径** | `WS /portforward?port={port}` | +| **port** | Query 参数,指定容器内的目标 TCP 端口号 | +| **监听端口** | rocklet 监听 `Port.PROXY` (22555),通过 Docker 映射到宿主机随机端口 | + +### 3.3 数据传输 + +| 项目 | 规格 | +|-----|------| +| **协议** | WebSocket | +| **数据格式** | 二进制消息(Binary Frame) | +| **转发方式** | 原样透传,不修改数据内容 | + +### 3.4 端口安全限制 + +| 项目 | 规格 | +|-----|------| +| **允许范围** | 1024-65535 | +| **排除端口** | 22 (SSH) | +| **拒绝行为** | WebSocket 关闭码 1008 (Policy Violation) | + +### 3.5 连接生命周期 + +| 场景 | 行为 | +|-----|------| +| **连接建立** | 外部代理 → rocklet → TCP 端口,逐层建立连接 | +| **一对一映射** | 每个 WebSocket 连接对应一个独立的 TCP 连接 | +| **多客户端** | 多个 WebSocket 客户端连接同一端口,各自创建独立的 TCP 连接 | +| **TCP 断开** | 关闭 rocklet WebSocket → 关闭客户端 WebSocket | +| **客户端断开** | 关闭 rocklet WebSocket → 关闭 TCP 连接 | + +### 3.6 超时配置 + +| 超时类型 | 时长 | 说明 | +|---------|------|------| +| **连接超时** | 10 秒 | 建立到 rocklet 或 TCP 端口的连接最大等待时间 | +| **空闲超时** | 300 秒 | 连接无数据传输后自动关闭 | + +### 3.7 错误处理 + +| 错误场景 | 处理方式 | +|---------|---------| +| **端口不在允许范围** | WebSocket 关闭码 1008,提示端口不允许 | +| **端口被排除(如 22)** | WebSocket 关闭码 1008,提示端口不允许 | +| **沙箱不存在** | WebSocket 关闭码 1011,提示沙箱未启动 | +| **TCP 连接超时** | WebSocket 关闭码 1011,提示连接超时 | +| **TCP 连接被拒绝** | WebSocket 关闭码 1011,提示连接失败 | +| **传输中断** | WebSocket 关闭码 1011 | + +### 3.8 认证授权 + +不需要额外的认证/授权机制。 + +## 4. 实现位置 + +| 文件 | 修改内容 | +|-----|---------| +| `rock/rocklet/local_api.py` | 添加 `/portforward` WebSocket 端点,实现容器内 TCP 代理 | +| `rock/sandbox/service/sandbox_proxy_service.py` | 添加 `websocket_to_tcp_proxy()` 方法和 `_get_rocklet_portforward_url()` 方法 | +| `rock/admin/entrypoints/sandbox_proxy_api.py` | 添加外部 WebSocket 路由端点 | + +## 5. 参考设计 + +参考 Kubernetes port-forward 的 WebSocket 实现: +- URL: `/api/v1/namespaces/{namespace}/pods/{name}/portforward?ports={port}` +- 数据流:WebSocket 双向透传 TCP 数据 \ No newline at end of file diff --git a/rock/admin/entrypoints/sandbox_proxy_api.py b/rock/admin/entrypoints/sandbox_proxy_api.py index 2770fa7d2..3ca1e7f64 100644 --- a/rock/admin/entrypoints/sandbox_proxy_api.py +++ b/rock/admin/entrypoints/sandbox_proxy_api.py @@ -1,5 +1,4 @@ import asyncio -import logging from typing import Any from fastapi import APIRouter, Body, File, Form, Request, UploadFile, WebSocket, WebSocketDisconnect @@ -26,9 +25,12 @@ SandboxWriteFileRequest, ) from rock.admin.proto.response import BatchSandboxStatusResponse, SandboxListResponse +from rock.logger import init_logger from rock.sandbox.service.sandbox_proxy_service import SandboxProxyService from rock.utils import handle_exceptions +logger = init_logger(__name__) + sandbox_proxy_router = APIRouter() sandbox_proxy_service: SandboxProxyService @@ -114,13 +116,57 @@ async def list_sandboxes(request: Request) -> RockResponse[SandboxListResponse]: async def websocket_proxy(websocket: WebSocket, id: str, path: str = ""): await websocket.accept() sandbox_id = id - logging.info(f"Client connected to WebSocket proxy: {sandbox_id}, path: {path}") + logger.info(f"Client connected to WebSocket proxy: {sandbox_id}, path: {path}") try: await sandbox_proxy_service.websocket_proxy(websocket, sandbox_id, path) except WebSocketDisconnect: - logging.info(f"Client disconnected from WebSocket proxy: {sandbox_id}") + logger.info(f"Client disconnected from WebSocket proxy: {sandbox_id}") + except Exception as e: + logger.error(f"WebSocket proxy error: {e}") + await websocket.close(code=1011, reason=f"Proxy error: {str(e)}") + + +@sandbox_proxy_router.websocket("/sandboxes/{id}/portforward") +async def portforward(websocket: WebSocket, id: str, port: int): + """ + WebSocket TCP port forwarding endpoint. + + Proxies a WebSocket connection to a TCP port inside the sandbox. + + Args: + id: The sandbox identifier. + port: The target TCP port inside the sandbox. + + Errors: + - Port validation failure: Connection closed with error + - Sandbox not found: Connection closed with error + - TCP connection failure: Connection closed with error + """ + sandbox_id = id + client_host = websocket.client.host if websocket.client else "unknown" + client_port = websocket.client.port if websocket.client else "unknown" + logger.info( + f"[Portforward] Request received: sandbox={sandbox_id}, target_port={port}, " + f"client={client_host}:{client_port}, path={websocket.url.path}" + ) + + try: + logger.info(f"[Portforward] Accepting WebSocket connection: sandbox={sandbox_id}, target_port={port}") + await websocket.accept() + logger.info(f"[Portforward] WebSocket accepted, calling proxy service: sandbox={sandbox_id}, target_port={port}") + await sandbox_proxy_service.websocket_to_tcp_proxy(websocket, sandbox_id, port) + logger.info(f"[Portforward] Proxy service completed: sandbox={sandbox_id}, target_port={port}") + except ValueError as e: + logger.warning(f"[Portforward] Validation failed: sandbox={sandbox_id}, target_port={port}, error={e}") + await websocket.close(code=1008, reason=str(e)) + except WebSocketDisconnect as e: + logger.info(f"[Portforward] Client disconnected: sandbox={sandbox_id}, target_port={port}, code={e.code}") except Exception as e: - logging.error(f"WebSocket proxy error: {e}") + logger.error( + f"[Portforward] Unexpected error: sandbox={sandbox_id}, target_port={port}, " + f"error_type={type(e).__name__}, error={e}", + exc_info=True + ) await websocket.close(code=1011, reason=f"Proxy error: {str(e)}") diff --git a/rock/common/port_validation.py b/rock/common/port_validation.py new file mode 100644 index 000000000..17cb1fbc2 --- /dev/null +++ b/rock/common/port_validation.py @@ -0,0 +1,40 @@ +"""Port validation utilities for port forwarding.""" +from rock.logger import init_logger + +logger = init_logger(__name__) + +# Port forwarding constants +PORT_FORWARD_MIN_PORT = 1024 +PORT_FORWARD_MAX_PORT = 65535 +PORT_FORWARD_EXCLUDED_PORTS = {22} # SSH port + + +def validate_port_forward_port(port: int) -> tuple[bool, str | None]: + """Validate if the port is allowed for port forwarding. + + Args: + port: The port number to validate. + + Returns: + A tuple of (is_valid, error_message). + If valid, error_message is None. + """ + logger.debug( + f"[Portforward] Validating port: port={port}, " + f"min={PORT_FORWARD_MIN_PORT}, max={PORT_FORWARD_MAX_PORT}, " + f"excluded={PORT_FORWARD_EXCLUDED_PORTS}" + ) + if port < PORT_FORWARD_MIN_PORT: + error_msg = f"Port {port} is below minimum allowed port {PORT_FORWARD_MIN_PORT}" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + if port > PORT_FORWARD_MAX_PORT: + error_msg = f"Port {port} is above maximum allowed port {PORT_FORWARD_MAX_PORT}" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + if port in PORT_FORWARD_EXCLUDED_PORTS: + error_msg = f"Port {port} is not allowed for port forwarding" + logger.warning(f"[Portforward] Port validation failed: {error_msg}") + return False, error_msg + logger.debug(f"[Portforward] Port validation passed: port={port}") + return True, None diff --git a/rock/rocklet/local_api.py b/rock/rocklet/local_api.py index 7513aad7d..97ebd6d02 100644 --- a/rock/rocklet/local_api.py +++ b/rock/rocklet/local_api.py @@ -1,9 +1,10 @@ +import asyncio import shutil import tempfile import zipfile from pathlib import Path -from fastapi import APIRouter, File, Form, UploadFile +from fastapi import APIRouter, File, Form, UploadFile, WebSocket, WebSocketDisconnect from rock.actions import ( CloseResponse, @@ -24,11 +25,19 @@ from rock.admin.proto.request import SandboxCreateSessionRequest as CreateSessionRequest from rock.admin.proto.request import SandboxReadFileRequest as ReadFileRequest from rock.admin.proto.request import SandboxWriteFileRequest as WriteFileRequest +from rock.common.port_validation import validate_port_forward_port +from rock.logger import init_logger from rock.rocklet.local_sandbox import LocalSandboxRuntime from rock.utils import get_executor +logger = init_logger(__name__) + local_router = APIRouter() +# Timeouts for port forwarding +TCP_CONNECT_TIMEOUT = 10 # seconds +IDLE_TIMEOUT = 300 # seconds + runtime = LocalSandboxRuntime(executor=get_executor()) @@ -130,3 +139,153 @@ async def env_close(request: EnvCloseRequest) -> EnvCloseResponse: @local_router.post("/env/list") async def env_list() -> EnvListResponse: return runtime.env_list() + + +@local_router.websocket("/portforward") +async def portforward(websocket: WebSocket, port: int): + """WebSocket endpoint for TCP port forwarding. + + This endpoint proxies WebSocket connections to a local TCP port, + allowing external clients to access TCP services running inside + the sandbox container. + + Args: + websocket: WebSocket connection from client + port: Target TCP port to connect to (query parameter) + """ + client_host = websocket.client.host if websocket.client else "unknown" + client_port = websocket.client.port if websocket.client else "unknown" + logger.info( + f"[Portforward] Request received: target_port={port}, " + f"client={client_host}:{client_port}, path={websocket.url.path}" + ) + + logger.info(f"[Portforward] Accepting WebSocket: target_port={port}") + await websocket.accept() + logger.info(f"[Portforward] WebSocket accepted: target_port={port}") + + # Validate port + logger.debug(f"[Portforward] Validating port: target_port={port}") + is_valid, error_message = validate_port_forward_port(port) + if not is_valid: + logger.warning(f"[Portforward] Port validation failed: target_port={port}, reason={error_message}") + await websocket.close(code=1008, reason=error_message) + return + logger.info(f"[Portforward] Port validation passed: target_port={port}") + + logger.info( + f"[Portforward] Connecting to local TCP: target_port={port}, " + f"address=127.0.0.1:{port}, timeout={TCP_CONNECT_TIMEOUT}s" + ) + + try: + # Connect to local TCP port + reader, writer = await asyncio.wait_for( + asyncio.open_connection("127.0.0.1", port), + timeout=TCP_CONNECT_TIMEOUT + ) + logger.info( + f"[Portforward] TCP connection established: target_port={port}, " + f"local_addr={writer.get_extra_info('sockname')}" + ) + except asyncio.TimeoutError: + logger.error( + f"[Portforward] TCP connection timeout: target_port={port}, " + f"timeout={TCP_CONNECT_TIMEOUT}s" + ) + await websocket.close(code=1011, reason=f"Connection to port {port} timed out") + return + except OSError as e: + logger.error( + f"[Portforward] TCP connection failed: target_port={port}, " + f"error_type={type(e).__name__}, errno={e.errno}, error={e}" + ) + await websocket.close(code=1011, reason=f"Failed to connect to port {port}: {e}") + return + except Exception as e: + logger.error( + f"[Portforward] Unexpected TCP error: target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + await websocket.close(code=1011, reason=f"Unexpected error: {e}") + return + + logger.info(f"[Portforward] Starting bidirectional forwarding: target_port={port}") + + ws_to_tcp_bytes = 0 + ws_to_tcp_msgs = 0 + tcp_to_ws_bytes = 0 + tcp_to_ws_msgs = 0 + + async def ws_to_tcp(): + """Forward messages from WebSocket to TCP.""" + nonlocal ws_to_tcp_bytes, ws_to_tcp_msgs + try: + while True: + data = await websocket.receive_bytes() + ws_to_tcp_msgs += 1 + ws_to_tcp_bytes += len(data) + writer.write(data) + await writer.drain() + logger.debug( + f"[Portforward] ws->tcp: target_port={port}, " + f"bytes={len(data)}, total_msgs={ws_to_tcp_msgs}, total_bytes={ws_to_tcp_bytes}" + ) + except WebSocketDisconnect as e: + logger.info( + f"[Portforward] ws->tcp: client disconnected: target_port={port}, code={e.code}" + ) + except Exception as e: + logger.debug( + f"[Portforward] ws->tcp error: target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + finally: + writer.close() + + async def tcp_to_ws(): + """Forward data from TCP to WebSocket.""" + nonlocal tcp_to_ws_bytes, tcp_to_ws_msgs + try: + while True: + data = await reader.read(4096) + if not data: + logger.info(f"[Portforward] tcp->ws: TCP connection closed by peer: target_port={port}") + break + tcp_to_ws_msgs += 1 + tcp_to_ws_bytes += len(data) + await websocket.send_bytes(data) + logger.debug( + f"[Portforward] tcp->ws: target_port={port}, " + f"bytes={len(data)}, total_msgs={tcp_to_ws_msgs}, total_bytes={tcp_to_ws_bytes}" + ) + except Exception as e: + logger.debug( + f"[Portforward] tcp->ws error: target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + finally: + try: + await websocket.close() + except Exception: + pass + + # Run both directions concurrently + try: + await asyncio.gather(ws_to_tcp(), tcp_to_ws()) + except Exception as e: + logger.debug( + f"[Portforward] Forwarding error: target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + finally: + writer.close() + try: + await writer.wait_closed() + except Exception: + pass + logger.info( + f"[Portforward] Connection closed: target_port={port}, " + f"ws->tcp: {ws_to_tcp_msgs} msgs, {ws_to_tcp_bytes} bytes, " + f"tcp->ws: {tcp_to_ws_msgs} msgs, {tcp_to_ws_bytes} bytes" + ) diff --git a/rock/sandbox/service/sandbox_proxy_service.py b/rock/sandbox/service/sandbox_proxy_service.py index 83560c0fb..5b1ef0ec9 100644 --- a/rock/sandbox/service/sandbox_proxy_service.py +++ b/rock/sandbox/service/sandbox_proxy_service.py @@ -38,6 +38,7 @@ from rock.config import OssConfig, ProxyServiceConfig, RockConfig from rock.deployments.constants import Port from rock.deployments.status import ServiceStatus +from rock.common.port_validation import validate_port_forward_port from rock.logger import init_logger from rock.sdk.common.exceptions import BadRequestRockError from rock.utils import EAGLE_EYE_TRACE_ID, trace_id_ctx_var @@ -232,6 +233,342 @@ async def websocket_proxy(self, client_websocket, sandbox_id: str, target_path: logger.error(f"WebSocket proxy error: {e}") await client_websocket.close(code=1011, reason=f"Proxy error: {str(e)}") + async def websocket_to_tcp_proxy( + self, + client_websocket, + sandbox_id: str, + port: int, + tcp_connect_timeout: float = 10.0, + idle_timeout: float = 300.0, + ) -> None: + """ + Proxy WebSocket connection to a TCP port inside the sandbox. + + This method forwards the connection to rocklet's /portforward endpoint, + which then connects to the actual TCP port inside the container. + + Args: + client_websocket: The WebSocket connection from the client. + sandbox_id: The sandbox identifier. + port: The target TCP port inside the sandbox. + tcp_connect_timeout: Timeout for establishing connection (default: 10s). + idle_timeout: Timeout for idle connection (default: 300s). + + Raises: + ValueError: If port validation fails. + Exception: If sandbox is not found or connection fails. + """ + logger.info(f"[Portforward] Starting proxy: sandbox={sandbox_id}, target_port={port}") + + # Validate port + logger.debug(f"[Portforward] Validating port: sandbox={sandbox_id}, target_port={port}") + is_valid, error_msg = validate_port_forward_port(port) + if not is_valid: + logger.warning(f"[Portforward] Port validation failed: sandbox={sandbox_id}, target_port={port}, reason={error_msg}") + raise ValueError(error_msg) + logger.info(f"[Portforward] Port validation passed: sandbox={sandbox_id}, target_port={port}") + + # Get sandbox status and rocklet portforward URL + logger.info(f"[Portforward] Fetching sandbox status: sandbox={sandbox_id}, target_port={port}") + try: + status_dicts = await self.get_service_status(sandbox_id) + logger.info( + f"[Portforward] Sandbox status retrieved: sandbox={sandbox_id}, target_port={port}, " + f"host_ip={status_dicts[0].get('host_ip')}, status_keys={list(status_dicts[0].keys())}" + ) + except Exception as e: + logger.error( + f"[Portforward] Failed to get sandbox status: sandbox={sandbox_id}, target_port={port}, " + f"error_type={type(e).__name__}, error={e}" + ) + raise + + target_url = self._get_rocklet_portforward_url(status_dicts[0], port) + logger.info( + f"[Portforward] Rocklet URL built: sandbox={sandbox_id}, target_port={port}, url={target_url}" + ) + + try: + # Connect to rocklet's portforward WebSocket endpoint + logger.info( + f"[Portforward] Connecting to rocklet: sandbox={sandbox_id}, target_port={port}, " + f"timeout={tcp_connect_timeout}s" + ) + async with websockets.connect( + target_url, + ping_interval=None, + ping_timeout=None, + open_timeout=tcp_connect_timeout, + ) as target_websocket: + logger.info( + f"[Portforward] Rocklet connection established: sandbox={sandbox_id}, target_port={port}" + ) + + # Create bidirectional forwarding tasks with close info tracking + logger.info( + f"[Portforward] Starting bidirectional forwarding: sandbox={sandbox_id}, target_port={port}" + ) + + # Track close frame info from rocklet + rocklet_close_info = {} + + client_to_target = asyncio.create_task( + self._forward_portforward_messages( + client_websocket, target_websocket, "client->rocklet", idle_timeout + ) + ) + target_to_client = asyncio.create_task( + self._forward_portforward_messages( + target_websocket, client_websocket, "rocklet->client", idle_timeout, + close_info=rocklet_close_info + ) + ) + + # Wait for any task to complete + done, pending = await asyncio.wait( + [client_to_target, target_to_client], return_when=asyncio.FIRST_COMPLETED + ) + + logger.info( + f"[Portforward] Forwarding task completed: sandbox={sandbox_id}, target_port={port}, " + f"completed={[t.get_name() for t in done]}" + ) + + # Cancel unfinished tasks + for task in pending: + logger.debug( + f"[Portforward] Cancelling pending task: sandbox={sandbox_id}, target_port={port}, " + f"task={task.get_name()}" + ) + task.cancel() + try: + await task # Wait for cancellation to complete + except asyncio.CancelledError: + pass + + # If rocklet closed the connection, forward close frame to client + if rocklet_close_info.get('direction') == 'rocklet->client': + close_code = rocklet_close_info.get('code', 1000) + close_reason = rocklet_close_info.get('reason', '') + logger.info( + f"[Portforward] Forwarding close to client: sandbox={sandbox_id}, target_port={port}, " + f"code={close_code}, reason={close_reason}" + ) + try: + await client_websocket.close(code=close_code, reason=close_reason) + except Exception as e: + logger.debug(f"[Portforward] Error closing client connection: {e}") + + except asyncio.TimeoutError as e: + logger.error( + f"[Portforward] Connection timeout: sandbox={sandbox_id}, target_port={port}, " + f"url={target_url}, timeout={tcp_connect_timeout}s" + ) + await client_websocket.close(code=1011, reason=f"Connection timeout to rocklet: {target_url}") + except websockets.exceptions.ConnectionClosed as e: + logger.warning( + f"[Portforward] Rocklet connection closed: sandbox={sandbox_id}, target_port={port}, " + f"code={e.code}, reason={e.reason}" + ) + await client_websocket.close(code=e.code, reason=e.reason or "") + except Exception as e: + logger.error( + f"[Portforward] Proxy error: sandbox={sandbox_id}, target_port={port}, " + f"error_type={type(e).__name__}, error={e}", + exc_info=True + ) + await client_websocket.close(code=1011, reason=f"Proxy error: {str(e)}") + + async def _forward_portforward_messages( + self, + source, + destination, + direction: str, + idle_timeout: float, + close_info: dict | None = None, + ): + """Forward binary messages between WebSocket connections with idle timeout. + + Handles both FastAPI WebSocket and websockets library ClientConnection objects: + - FastAPI WebSocket: receive_bytes(), send_bytes() + - websockets ClientConnection: recv(), send() + + When source connection closes, captures close frame info for forwarding. + + Args: + source: Source WebSocket connection + destination: Destination WebSocket connection + direction: Direction string for logging (e.g., "client->rocklet") + idle_timeout: Idle timeout in seconds + close_info: Optional dict to store close frame info (code, reason) + """ + logger.debug(f"[Portforward] Starting message forwarder: direction={direction}, idle_timeout={idle_timeout}s") + bytes_transferred = 0 + message_count = 0 + + # Detect the type of WebSocket objects + # FastAPI WebSocket has 'receive_bytes' method + # websockets ClientConnection has 'recv' method + source_is_fastapi = hasattr(source, 'receive_bytes') + dest_is_fastapi = hasattr(destination, 'send_bytes') + + logger.debug( + f"[Portforward] Connection types: direction={direction}, " + f"source_is_fastapi={source_is_fastapi}, dest_is_fastapi={dest_is_fastapi}" + ) + + try: + while True: + try: + # Receive data based on source type + if source_is_fastapi: + data = await asyncio.wait_for( + source.receive_bytes(), + timeout=idle_timeout, + ) + else: + # websockets library returns bytes or str + data = await asyncio.wait_for( + source.recv(), + timeout=idle_timeout, + ) + # Convert str to bytes if needed + if isinstance(data, str): + data = data.encode('utf-8') + + message_count += 1 + bytes_transferred += len(data) + + # Send data based on destination type + if dest_is_fastapi: + await destination.send_bytes(data) + else: + await destination.send(data) + + logger.debug( + f"[Portforward] Forwarded message: direction={direction}, " + f"msg_num={message_count}, bytes={len(data)}, total_bytes={bytes_transferred}" + ) + except asyncio.TimeoutError: + logger.info( + f"[Portforward] Idle timeout: direction={direction}, " + f"total_messages={message_count}, total_bytes={bytes_transferred}" + ) + break + except websockets.exceptions.ConnectionClosed as e: + # Capture close frame info for forwarding + close_code = e.code + close_reason = e.reason or "" + logger.info( + f"[Portforward] Source connection closed: direction={direction}, " + f"code={close_code}, reason={close_reason}, " + f"total_messages={message_count}, total_bytes={bytes_transferred}" + ) + if close_info is not None: + close_info['code'] = close_code + close_info['reason'] = close_reason + close_info['direction'] = direction + except Exception as e: + logger.info( + f"[Portforward] Forwarder stopped: direction={direction}, " + f"error_type={type(e).__name__}, error={e}, " + f"total_messages={message_count}, total_bytes={bytes_transferred}" + ) + + def _get_tcp_target_address(self, sandbox_status_dict: dict, port: int) -> tuple[str, int]: + """ + Get the target TCP address from sandbox status. + + Args: + sandbox_status_dict: The sandbox status dictionary. + port: The target port inside the sandbox. + + Returns: + A tuple of (host_ip, mapped_port). + """ + host_ip = sandbox_status_dict.get("host_ip") + service_status = ServiceStatus.from_dict(sandbox_status_dict) + # Use SERVER port mapping to access the sandbox + # The port inside the container is mapped to a host port + mapped_port = service_status.get_mapped_port(Port.SERVER) + # For now, we use the port as-is since we're connecting to the sandbox's network + # In a real scenario, we might need to use the mapped port or connect through the container network + return host_ip, port + + def _get_rocklet_portforward_url(self, sandbox_status_dict: dict, port: int) -> str: + """ + Get the WebSocket URL for rocklet's portforward endpoint. + + Args: + sandbox_status_dict: The sandbox status dictionary. + port: The target TCP port inside the sandbox. + + Returns: + WebSocket URL for rocklet's portforward endpoint. + """ + host_ip = sandbox_status_dict.get("host_ip") + service_status = ServiceStatus.from_dict(sandbox_status_dict) + # Use PROXY port mapping to access the rocklet service + # rocklet listens on Port.PROXY (22555) inside the container + mapped_port = service_status.get_mapped_port(Port.PROXY) + + logger.info( + f"[Portforward] Building rocklet URL: host_ip={host_ip}, " + f"container_port={Port.PROXY.value}, mapped_port={mapped_port}, target_port={port}" + ) + + if not host_ip: + logger.error(f"[Portforward] Missing host_ip in sandbox status: keys={list(sandbox_status_dict.keys())}") + if not mapped_port: + logger.error( + f"[Portforward] Missing mapped port for PROXY: " + f"available_ports={service_status.ports if hasattr(service_status, 'ports') else 'unknown'}" + ) + + url = f"ws://{host_ip}:{mapped_port}/portforward?port={port}" + logger.info(f"[Portforward] Generated rocklet URL: {url}") + return url + + async def _forward_websocket_to_tcp(self, websocket, writer, direction: str, idle_timeout: float): + """Forward data from WebSocket to TCP connection.""" + try: + while True: + try: + # Wait for data with idle timeout + data = await asyncio.wait_for( + websocket.receive_bytes(), + timeout=idle_timeout, + ) + writer.write(data) + await writer.drain() + logger.debug(f"Forwarded {len(data)} bytes {direction}") + except asyncio.TimeoutError: + logger.info(f"Idle timeout reached for {direction}") + break + except Exception as e: + logger.info(f"Connection closed in {direction}: {e}") + + async def _forward_tcp_to_websocket(self, reader, websocket, direction: str, idle_timeout: float): + """Forward data from TCP connection to WebSocket.""" + try: + while True: + try: + # Wait for data with idle timeout + data = await asyncio.wait_for( + reader.read(4096), + timeout=idle_timeout, + ) + if not data: + logger.info(f"TCP connection closed for {direction}") + break + await websocket.send_bytes(data) + logger.debug(f"Forwarded {len(data)} bytes {direction}") + except asyncio.TimeoutError: + logger.info(f"Idle timeout reached for {direction}") + break + except Exception as e: + logger.info(f"Connection closed in {direction}: {e}") + async def get_service_status(self, sandbox_id: str): sandbox_status_dicts = await self._redis_provider.json_get(alive_sandbox_key(sandbox_id), "$") if not sandbox_status_dicts or sandbox_status_dicts[0].get("host_ip") is None: diff --git a/tests/unit/common/test_port_validation.py b/tests/unit/common/test_port_validation.py new file mode 100644 index 000000000..b490da5b2 --- /dev/null +++ b/tests/unit/common/test_port_validation.py @@ -0,0 +1,80 @@ +"""Tests for common port validation utilities.""" +import pytest + +from rock.common.port_validation import ( + PORT_FORWARD_MIN_PORT, + PORT_FORWARD_MAX_PORT, + PORT_FORWARD_EXCLUDED_PORTS, + validate_port_forward_port, +) + + +class TestPortValidationConstants: + """Tests for port validation constants.""" + + def test_min_port_is_1024(self): + """Minimum allowed port should be 1024.""" + assert PORT_FORWARD_MIN_PORT == 1024 + + def test_max_port_is_65535(self): + """Maximum allowed port should be 65535.""" + assert PORT_FORWARD_MAX_PORT == 65535 + + def test_excluded_ports_contains_22(self): + """SSH port 22 should be excluded.""" + assert 22 in PORT_FORWARD_EXCLUDED_PORTS + + +class TestValidatePortForwardPort: + """Tests for validate_port_forward_port function.""" + + def test_accepts_valid_port_in_range(self): + """Valid port in allowed range should pass validation.""" + is_valid, error = validate_port_forward_port(8080) + assert is_valid is True + assert error is None + + def test_accepts_min_allowed_port(self): + """Minimum allowed port (1024) should pass validation.""" + is_valid, error = validate_port_forward_port(1024) + assert is_valid is True + assert error is None + + def test_accepts_max_allowed_port(self): + """Maximum allowed port (65535) should pass validation.""" + is_valid, error = validate_port_forward_port(65535) + assert is_valid is True + assert error is None + + def test_rejects_port_below_1024(self): + """Port below 1024 should be rejected.""" + is_valid, error = validate_port_forward_port(1023) + assert is_valid is False + assert error is not None + assert "1024" in error + + def test_rejects_port_above_65535(self): + """Port above 65535 should be rejected.""" + is_valid, error = validate_port_forward_port(65536) + assert is_valid is False + assert error is not None + assert "65535" in error + + def test_rejects_port_22(self): + """SSH port 22 should be rejected.""" + is_valid, error = validate_port_forward_port(22) + assert is_valid is False + assert error is not None + assert "22" in error + + def test_rejects_zero_port(self): + """Port 0 should be rejected.""" + is_valid, error = validate_port_forward_port(0) + assert is_valid is False + assert error is not None + + def test_rejects_negative_port(self): + """Negative port should be rejected.""" + is_valid, error = validate_port_forward_port(-1) + assert is_valid is False + assert error is not None diff --git a/tests/unit/rocklet/test_portforward.py b/tests/unit/rocklet/test_portforward.py new file mode 100644 index 000000000..d098153b2 --- /dev/null +++ b/tests/unit/rocklet/test_portforward.py @@ -0,0 +1,22 @@ +"""Tests for rocklet portforward WebSocket endpoint.""" +import pytest +from fastapi import FastAPI + +from rock.rocklet.local_api import local_router + + +@pytest.fixture +def app(): + """Create test FastAPI app with local_router.""" + app = FastAPI() + app.include_router(local_router, tags=["local"]) + return app + + +class TestPortForwardEndpoint: + """Tests for /portforward WebSocket endpoint.""" + + def test_portforward_endpoint_exists(self, app: FastAPI): + """Test that /portforward endpoint is registered.""" + routes = [route.path for route in app.routes] + assert "/portforward" in routes, "Portforward endpoint should be registered" \ No newline at end of file diff --git a/tests/unit/sandbox/test_websocket_tcp_proxy.py b/tests/unit/sandbox/test_websocket_tcp_proxy.py new file mode 100644 index 000000000..cae9f9570 --- /dev/null +++ b/tests/unit/sandbox/test_websocket_tcp_proxy.py @@ -0,0 +1,168 @@ +"""Tests for WebSocket TCP port forwarding proxy.""" +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI, WebSocket +from fastapi.testclient import TestClient + +from rock.admin.entrypoints.sandbox_proxy_api import sandbox_proxy_router, set_sandbox_proxy_service +from rock.common.port_validation import validate_port_forward_port + + +class TestValidatePort: + """Tests for port validation function.""" + + def test_validate_port_accepts_valid_port_in_range(self): + """Valid port in allowed range should pass validation.""" + is_valid, error = validate_port_forward_port(8080) + assert is_valid is True + assert error is None + + def test_validate_port_rejects_port_below_1024(self): + """Port below 1024 should be rejected.""" + is_valid, error = validate_port_forward_port(1023) + assert is_valid is False + assert "1024" in error + + def test_validate_port_rejects_port_22(self): + """SSH port 22 should be rejected even though it's in the excluded list.""" + is_valid, error = validate_port_forward_port(22) + assert is_valid is False + assert "22" in error + + def test_validate_port_rejects_port_above_65535(self): + """Port above 65535 should be rejected.""" + is_valid, error = validate_port_forward_port(65536) + assert is_valid is False + assert "65535" in error + + def test_validate_port_accepts_min_allowed_port(self): + """Minimum allowed port (1024) should pass validation.""" + is_valid, error = validate_port_forward_port(1024) + assert is_valid is True + assert error is None + + def test_validate_port_accepts_max_allowed_port(self): + """Maximum allowed port (65535) should pass validation.""" + is_valid, error = validate_port_forward_port(65535) + assert is_valid is True + assert error is None + + def test_validate_port_rejects_zero_port(self): + """Port 0 should be rejected.""" + is_valid, error = validate_port_forward_port(0) + assert is_valid is False + assert error is not None + + def test_validate_port_rejects_negative_port(self): + """Negative port should be rejected.""" + is_valid, error = validate_port_forward_port(-1) + assert is_valid is False + assert error is not None + + +class TestGetRockletPortforwardUrl: + """Tests for _get_rocklet_portforward_url method.""" + + def test_get_rocklet_portforward_url_returns_correct_format(self, sandbox_proxy_service): + """Should return correct WebSocket URL for rocklet portforward endpoint.""" + mock_status = { + "host_ip": "192.168.1.100", + "phases": { + "image_pull": {"status": "running", "message": "done"}, + "docker_run": {"status": "running", "message": "done"}, + }, + "port_mapping": {22555: 32768}, # PROXY port mapping (rocklet listens on 22555) + } + url = sandbox_proxy_service._get_rocklet_portforward_url(mock_status, 9000) + # Should connect to rocklet's PROXY port (22555 mapped to 32768) with /portforward path + assert url == "ws://192.168.1.100:32768/portforward?port=9000" + + +class TestWebsocketToTcpProxy: + """Tests for websocket_to_tcp_proxy method.""" + + @pytest.mark.asyncio + async def test_websocket_to_tcp_proxy_validates_port(self, sandbox_proxy_service): + """Should raise error for invalid port before attempting connection.""" + mock_websocket = MagicMock(spec=WebSocket) + + with pytest.raises(ValueError) as exc_info: + await sandbox_proxy_service.websocket_to_tcp_proxy( + client_websocket=mock_websocket, + sandbox_id="test-sandbox", + port=22, # Invalid port (SSH) + ) + assert "22" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_websocket_to_tcp_proxy_raises_for_nonexistent_sandbox(self, sandbox_proxy_service): + """Should raise error when sandbox does not exist.""" + mock_websocket = MagicMock(spec=WebSocket) + + with patch.object( + sandbox_proxy_service, "get_service_status", side_effect=Exception("sandbox not-started not started") + ): + with pytest.raises(Exception) as exc_info: + await sandbox_proxy_service.websocket_to_tcp_proxy( + client_websocket=mock_websocket, + sandbox_id="nonexistent-sandbox", + port=8080, + ) + assert "not started" in str(exc_info.value) + + +class TestPortForwardRoute: + """Tests for the portforward WebSocket route.""" + + @pytest.fixture + def app(self): + """Create a FastAPI app with the sandbox_proxy_router.""" + mock_service = MagicMock() + mock_service.websocket_to_tcp_proxy = AsyncMock() + set_sandbox_proxy_service(mock_service) + + app = FastAPI() + app.include_router(sandbox_proxy_router) + return app, mock_service + + def test_portforward_route_valid_port_calls_service(self, app): + """Should call websocket_to_tcp_proxy with correct parameters.""" + app, mock_service = app + client = TestClient(app) + + # Mock the service to complete immediately + mock_service.websocket_to_tcp_proxy.return_value = None + + # The WebSocket connection will be accepted and the service called + # Since the mock returns immediately, the connection should close + try: + with client.websocket_connect("/sandboxes/test-sandbox/portforward?port=8080") as websocket: + # Connection established, service should be called + pass + except Exception: + # Expected: connection may close due to mock behavior + pass + + # Verify the service was called with correct parameters + mock_service.websocket_to_tcp_proxy.assert_called_once() + call_args = mock_service.websocket_to_tcp_proxy.call_args + # sandbox_id is the second positional argument + assert call_args[0][1] == "test-sandbox" + # port is the third positional argument + assert call_args[0][2] == 8080 + + def test_portforward_route_invalid_port_rejected(self, app): + """Should reject invalid port (port 22) with close code 1008.""" + app, mock_service = app + mock_service.websocket_to_tcp_proxy.side_effect = ValueError("Port 22 is not allowed for port forwarding") + + client = TestClient(app) + + # The WebSocket connection will be closed with code 1008 + with client.websocket_connect("/sandboxes/test-sandbox/portforward?port=22") as websocket: + # Connection should be closed by server + pass + + # The service should have been called + mock_service.websocket_to_tcp_proxy.assert_called_once() \ No newline at end of file