Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion rock/sdk/sandbox/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ async def start(self):
self._host_ip = response.get("result").get("host_ip")

start_time = time.time()
poll_count = 0
while time.time() - start_time < self.config.startup_timeout:
sandbox_info = await self.get_status()
logging.debug(f"Get status response: {sandbox_info}")
Expand All @@ -200,11 +201,39 @@ async def start(self):
error_msg = await self._parse_error_message_from_status(sandbox_info.status)
if error_msg:
raise InternalServerRockError(f"Failed to start sandbox because {error_msg}, sandbox: {str(self)}")
await asyncio.sleep(3)
poll_count += 1
interval = self._calculate_poll_interval(poll_count, enable_backoff=True)
await asyncio.sleep(interval)
raise InternalServerRockError(
f"Failed to start sandbox within {self.config.startup_timeout}s, sandbox: {str(self)}"
)

@staticmethod
def _calculate_poll_interval(
poll_count: int,
enable_backoff: bool = True,
base_interval: int = 3,
max_interval: int = 15,
backoff_threshold: int = 5,
backoff_step: int = 2,
) -> int:
"""Calculate the polling interval with optional gradual backoff.

Args:
poll_count: Current poll iteration number (1-based).
enable_backoff: Whether to enable gradual backoff after the threshold.
base_interval: Base polling interval in seconds.
max_interval: Maximum polling interval in seconds.
backoff_threshold: Number of polls after which backoff begins.
backoff_step: Seconds to add per poll beyond the threshold.

Returns:
The polling interval in seconds.
"""
if not enable_backoff or poll_count < backoff_threshold:
return base_interval
return min(base_interval + (poll_count - backoff_threshold + 1) * backoff_step, max_interval)

async def is_alive(self) -> IsAliveResponse:
try:
status_response = await self.get_status()
Expand Down
152 changes: 152 additions & 0 deletions tests/unit/sdk/test_client_start_backoff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
from unittest.mock import AsyncMock, patch

import pytest

from rock.actions.sandbox.response import SandboxStatusResponse
from rock.sdk.common.exceptions import InternalServerRockError
from rock.sdk.sandbox.client import Sandbox
from rock.sdk.sandbox.config import SandboxConfig

_START_POST_RESPONSE = {
"status": "Success",
"result": {"sandbox_id": "test-sandbox-id", "host_name": "host1", "host_ip": "1.2.3.4"},
}


def _create_sandbox() -> Sandbox:
config = SandboxConfig(image="python:3.11", startup_timeout=300, base_url="http://localhost:8080")
return Sandbox(config)


def _make_status(is_alive: bool, status: dict | None = None) -> SandboxStatusResponse:
return SandboxStatusResponse(sandbox_id="test-sandbox-id", status=status or {}, is_alive=is_alive)


async def _run_start_with_polls(alive_after: int) -> list[int]:
"""Run sandbox.start() where get_status returns alive after `alive_after` polls. Returns recorded sleep intervals."""
sandbox = _create_sandbox()
not_alive = _make_status(is_alive=False)
alive = _make_status(is_alive=True)
call_count = 0

async def mock_get_status():
nonlocal call_count
call_count += 1
return alive if call_count >= alive_after else not_alive

sleep_intervals = []

async def mock_sleep(seconds):
sleep_intervals.append(seconds)

with (
patch.object(sandbox, "get_status", side_effect=mock_get_status),
patch.object(sandbox, "_parse_error_message_from_status", new_callable=AsyncMock, return_value=None),
patch("rock.sdk.sandbox.client.asyncio.sleep", side_effect=mock_sleep),
patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock, return_value=_START_POST_RESPONSE),
):
await sandbox.start()

return sleep_intervals


# --- _calculate_poll_interval tests ---


def test_interval_returns_base_when_backoff_disabled():
for poll_count in range(1, 20):
assert Sandbox._calculate_poll_interval(poll_count, enable_backoff=False) == 3


def test_interval_returns_base_before_threshold():
for poll_count in range(1, 5):
assert Sandbox._calculate_poll_interval(poll_count, enable_backoff=True) == 3


def test_interval_backoff_starts_at_threshold():
assert Sandbox._calculate_poll_interval(5, enable_backoff=True) == 5


def test_interval_increases_gradually():
expected = {5: 5, 6: 7, 7: 9, 8: 11, 9: 13}
for poll_count, expected_interval in expected.items():
assert Sandbox._calculate_poll_interval(poll_count, enable_backoff=True) == expected_interval


def test_interval_caps_at_max():
for poll_count in range(10, 50):
assert Sandbox._calculate_poll_interval(poll_count, enable_backoff=True) <= 15


def test_interval_custom_parameters():
result = Sandbox._calculate_poll_interval(
poll_count=8,
enable_backoff=True,
base_interval=5,
max_interval=20,
backoff_threshold=3,
backoff_step=3,
)
assert result == 20


def test_interval_exact_sequence():
expected = [3, 3, 3, 3, 5, 7, 9, 11, 13, 15, 15, 15]
actual = [Sandbox._calculate_poll_interval(i, enable_backoff=True) for i in range(1, 13)]
assert actual == expected


# --- start() integration tests ---


@pytest.mark.asyncio
async def test_start_succeeds_on_first_poll():
sandbox = _create_sandbox()
alive = _make_status(is_alive=True)

with (
patch.object(sandbox, "get_status", new_callable=AsyncMock, return_value=alive),
patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock, return_value=_START_POST_RESPONSE),
):
await sandbox.start()

assert sandbox.sandbox_id == "test-sandbox-id"


@pytest.mark.asyncio
async def test_start_intervals_before_threshold():
sleep_intervals = await _run_start_with_polls(alive_after=4)
assert sleep_intervals == [3, 3, 3]


@pytest.mark.asyncio
async def test_start_intervals_with_backoff():
sleep_intervals = await _run_start_with_polls(alive_after=9)
assert sleep_intervals == [3, 3, 3, 3, 5, 7, 9, 11]


@pytest.mark.asyncio
async def test_start_intervals_cap_at_max():
sleep_intervals = await _run_start_with_polls(alive_after=15)
assert all(interval <= 15 for interval in sleep_intervals)
assert 15 in sleep_intervals


@pytest.mark.asyncio
async def test_start_intervals_full_sequence():
sleep_intervals = await _run_start_with_polls(alive_after=13)
assert sleep_intervals == [3, 3, 3, 3, 5, 7, 9, 11, 13, 15, 15, 15]


@pytest.mark.asyncio
async def test_start_raises_on_error_status():
sandbox = _create_sandbox()
failed_status = {"build": {"status": "failed", "message": "image pull failed"}}
not_alive_with_error = _make_status(is_alive=False, status=failed_status)

with (
patch.object(sandbox, "get_status", new_callable=AsyncMock, return_value=not_alive_with_error),
patch("rock.utils.http.HttpUtils.post", new_callable=AsyncMock, return_value=_START_POST_RESPONSE),
):
with pytest.raises(InternalServerRockError, match="image pull failed"):
await sandbox.start()
Loading