@@ -35,7 +35,7 @@ def __init__(self, token: str, url: str = TRANSPORT_DEFAULT_WS_URL):
35
35
self ._recv_task : Optional [asyncio .Task [None ]] = None
36
36
self ._closed = False
37
37
38
- async def connect (self ) -> dict [str , Any ]:
38
+ async def connect (self , throw_error : bool = True ) -> dict [str , Any ]:
39
39
self ._ws = await websockets .connect (
40
40
self ._url ,
41
41
extra_headers = {
@@ -49,7 +49,7 @@ async def connect(self) -> dict[str, Any]:
49
49
hello_msg = cast (HelloMessage , hello )
50
50
self ._closed = False
51
51
# Start background message processor
52
- self ._recv_task = asyncio .create_task (self ._background_recv ())
52
+ self ._recv_task = asyncio .create_task (self ._background_recv (throw_error ))
53
53
return {"version" : hello_msg ["appVersion" ]}
54
54
55
55
async def close (self ) -> None :
@@ -109,7 +109,7 @@ async def request(self, command: str, params: dict[str, Any]) -> ResponseMessage
109
109
finally :
110
110
del self ._response_futures [msg_id ]
111
111
112
- async def _background_recv (self ) -> None :
112
+ async def _background_recv (self , throw_error : bool = True ) -> None :
113
113
try :
114
114
while not self ._closed and self ._ws is not None :
115
115
msg : IncomingMessage = await self ._recv ()
@@ -127,6 +127,16 @@ async def _background_recv(self) -> None:
127
127
except Exception as e :
128
128
warnings .warn (f"Background recv error: { e } " , RuntimeWarning )
129
129
130
+ if throw_error :
131
+ self ._closed = True
132
+ # Cancel all pending response futures
133
+ for future in self ._response_futures .values ():
134
+ if not future .done ():
135
+ future .set_exception (e )
136
+ if self ._ws :
137
+ await self ._ws .close ()
138
+ raise
139
+
130
140
async def _recv (self ) -> IncomingMessage :
131
141
if self ._ws is None :
132
142
raise WokwiError ("Not connected" )
0 commit comments