Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/curly-pumpkins-kick.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@e2b/code-interpreter-template': patch
---

Add retry with WebSocket reconnect when sending an execution request fails
5 changes: 5 additions & 0 deletions .changeset/wicked-mirrors-punch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@e2b/code-interpreter-python': patch
---

Fix issue with `secure=False`: omit the `X-Access-Token` header when no envd access token is set
12 changes: 10 additions & 2 deletions python/e2b_code_interpreter/code_interpreter_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,10 @@ async def run_code(
request_timeout = request_timeout or self.connection_config.request_timeout
context_id = context.id if context else None

headers: Dict[str, str] = {}
if self._envd_access_token:
headers = {"X-Access-Token": self._envd_access_token}

try:
async with self._client.stream(
"POST",
Expand All @@ -201,7 +205,7 @@ async def run_code(
"language": language,
"env_vars": envs,
},
headers={"X-Access-Token": self._envd_access_token},
headers=headers,
timeout=(request_timeout, timeout, request_timeout, request_timeout),
) as response:
err = await aextract_exception(response)
Expand Down Expand Up @@ -249,10 +253,14 @@ async def create_code_context(
if cwd:
data["cwd"] = cwd

headers: Dict[str, str] = {}
if self._envd_access_token:
headers = {"X-Access-Token": self._envd_access_token}

try:
response = await self._client.post(
f"{self._jupyter_url}/contexts",
headers={"X-Access-Token": self._envd_access_token},
headers=headers,
json=data,
timeout=request_timeout or self.connection_config.request_timeout,
)
Expand Down
12 changes: 10 additions & 2 deletions python/e2b_code_interpreter/code_interpreter_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def run_code(
request_timeout = request_timeout or self.connection_config.request_timeout
context_id = context.id if context else None

headers: Dict[str, str] = {}
if self._envd_access_token:
headers = {"X-Access-Token": self._envd_access_token}

try:
with self._client.stream(
"POST",
Expand All @@ -198,7 +202,7 @@ def run_code(
"language": language,
"env_vars": envs,
},
headers={"X-Access-Token": self._envd_access_token},
headers=headers,
timeout=(request_timeout, timeout, request_timeout, request_timeout),
) as response:
err = extract_exception(response)
Expand Down Expand Up @@ -246,11 +250,15 @@ def create_code_context(
if cwd:
data["cwd"] = cwd

headers: Dict[str, str] = {}
if self._envd_access_token:
headers = {"X-Access-Token": self._envd_access_token}

try:
response = self._client.post(
f"{self._jupyter_url}/contexts",
json=data,
headers={"X-Access-Token": self._envd_access_token},
headers=headers,
timeout=request_timeout or self.connection_config.request_timeout,
)

Expand Down
69 changes: 61 additions & 8 deletions template/server/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
)
from pydantic import StrictStr
from websockets.client import WebSocketClientProtocol, connect
from websockets.exceptions import (
ConnectionClosedError,
WebSocketException,
)

from api.models.error import Error
from api.models.logs import Stdout, Stderr
Expand All @@ -27,6 +31,9 @@

logger = logging.getLogger(__name__)

# Number of reconnect-and-resend attempts made after the initial send of an
# execution request fails (the send loop runs 1 + MAX_RECONNECT_RETRIES times).
MAX_RECONNECT_RETRIES = 3
# Passed as `ping_timeout` when opening the WebSocket connection
# (presumably seconds, per the websockets library — confirm).
PING_TIMEOUT = 30


class Execution:
def __init__(self, in_background: bool = False):
Expand Down Expand Up @@ -61,6 +68,15 @@ def __init__(self, context_id: str, session_id: str, language: str, cwd: str):
self._executions: Dict[str, Execution] = {}
self._lock = asyncio.Lock()

async def reconnect(self):
    """Drop the current WebSocket session and establish a new one.

    Closes the existing socket (if there is one) with the reason
    "Reconnecting", waits for the receive task (if there is one) to
    complete, and then calls ``connect()`` again.
    """
    current_ws = self._ws
    if current_ws is not None:
        await current_ws.close(reason="Reconnecting")

    # Looked up only after the close has completed, so the wait applies to
    # whatever receive task is registered at that point.
    pending_receiver = self._receive_task
    if pending_receiver is not None:
        await pending_receiver

    await self.connect()

async def connect(self):
logger.debug(f"WebSocket connecting to {self.url}")

Expand All @@ -69,6 +85,7 @@ async def connect(self):

self._ws = await connect(
self.url,
ping_timeout=PING_TIMEOUT,
max_size=None,
max_queue=None,
logger=ws_logger,
Expand Down Expand Up @@ -274,9 +291,6 @@ async def execute(
env_vars: Dict[StrictStr, str],
access_token: str,
):
message_id = str(uuid.uuid4())
self._executions[message_id] = Execution()

if self._ws is None:
raise Exception("WebSocket not connected")

Expand Down Expand Up @@ -313,13 +327,40 @@ async def execute(
)
complete_code = f"{indented_env_code}\n{complete_code}"

logger.info(
f"Sending code for the execution ({message_id}): {complete_code}"
)
request = self._get_execute_request(message_id, complete_code, False)
message_id = str(uuid.uuid4())
execution = Execution()
self._executions[message_id] = execution

# Send the code for execution
await self._ws.send(request)
# Initial request and retries
for i in range(1 + MAX_RECONNECT_RETRIES):
try:
logger.info(
f"Sending code for the execution ({message_id}): {complete_code}"
)
request = self._get_execute_request(
message_id, complete_code, False
)
await self._ws.send(request)
break
except (ConnectionClosedError, WebSocketException) as e:
# Keep the last result, even if error
if i < MAX_RECONNECT_RETRIES:
logger.warning(
f"WebSocket connection lost while sending execution request, {i + 1}. reconnecting...: {str(e)}"
)
await self.reconnect()
else:
# The retry didn't help, request wasn't sent successfully
logger.error("Failed to send execution request")
await execution.queue.put(
Error(
name="WebSocketError",
value="Failed to send execution request",
traceback="",
)
)
await execution.queue.put(UnexpectedEndOfExecution())

# Stream the results
async for item in self._wait_for_result(message_id):
Expand All @@ -343,6 +384,18 @@ async def _receive_message(self):
await self._process_message(json.loads(message))
except Exception as e:
logger.error(f"WebSocket received error while receiving messages: {str(e)}")
finally:
# To prevent an infinite hang, we need to cancel all ongoing executions, as we could have lost results during the reconnect.
# Thanks to the locking, there can be either no ongoing execution or just one.
for key, execution in self._executions.items():
await execution.queue.put(
Error(
name="WebSocketError",
value="The connections was lost, rerun the code to get the results",
traceback="",
)
)
await execution.queue.put(UnexpectedEndOfExecution())

async def _process_message(self, data: dict):
"""
Expand Down