1313import urllib .parse
1414from typing import Iterable , List , Optional , Union
1515
16+ import outcome
1617import trio
1718import trio .abc
1819from wsproto import ConnectionType , WSConnection
3536 # pylint doesn't care about the version_info check, so need to ignore the warning
3637 from exceptiongroup import BaseExceptionGroup # pylint: disable=redefined-builtin
3738
38- _TRIO_MULTI_ERROR = tuple (map (int , trio .__version__ .split ('.' )[:2 ])) < (0 , 22 )
39+ _IS_TRIO_MULTI_ERROR = tuple (map (int , trio .__version__ .split ('.' )[:2 ])) < (0 , 22 )
40+
41+ if _IS_TRIO_MULTI_ERROR :
42+ _TRIO_EXC_GROUP_TYPE = trio .MultiError # type: ignore[attr-defined] # pylint: disable=no-member
43+ else :
44+ _TRIO_EXC_GROUP_TYPE = BaseExceptionGroup # pylint: disable=possibly-used-before-assignment
3945
4046CONN_TIMEOUT = 60 # default connect & disconnect timeout, in seconds
4147MESSAGE_QUEUE_SIZE = 1
4450logger = logging .getLogger ('trio-websocket' )
4551
4652
53+ class TrioWebsocketInternalError (Exception ):
54+ """Raised as a fallback when open_websocket is unable to unwind an exceptiongroup
55+ into a single preferred exception. This should never happen, if it does then
56+ underlying assumptions about the internal code are incorrect.
57+ """
58+
59+
4760def _ignore_cancel (exc ):
4861 return None if isinstance (exc , trio .Cancelled ) else exc
4962
@@ -70,7 +83,7 @@ def __exit__(self, ty, value, tb):
7083 if value is None or not self ._armed :
7184 return False
7285
73- if _TRIO_MULTI_ERROR : # pragma: no cover
86+ if _IS_TRIO_MULTI_ERROR : # pragma: no cover
7487 filtered_exception = trio .MultiError .filter (_ignore_cancel , value ) # pylint: disable=no-member
7588 elif isinstance (value , BaseExceptionGroup ): # pylint: disable=possibly-used-before-assignment
7689 filtered_exception = value .subgroup (lambda exc : not isinstance (exc , trio .Cancelled ))
@@ -125,10 +138,33 @@ async def open_websocket(
125138 client-side timeout (:exc:`ConnectionTimeout`, :exc:`DisconnectionTimeout`),
126139 or server rejection (:exc:`ConnectionRejected`) during handshakes.
127140 '''
128- async with trio .open_nursery () as new_nursery :
141+
142+ # This context manager tries very very hard not to raise an exceptiongroup
143+ # in order to be as transparent as possible for the end user.
144+ # In the trivial case, this means that if user code inside the cm raises
145+ # we make sure that it doesn't get wrapped.
146+
147+ # If opening the connection fails, then we will raise that exception. User
148+ # code is never executed, so we will never have multiple exceptions.
149+
150+ # After opening the connection, we spawn _reader_task in the background and
151+ # yield to user code. If only one of those raise a non-cancelled exception
152+ # we will raise that non-cancelled exception.
153+ # If we get multiple cancelled, we raise the user's cancelled.
154+ # If both raise exceptions, we raise the user code's exception with the entire
155+ # exception group as the __cause__.
156+ # If we somehow get multiple exceptions, but no user exception, then we raise
157+ # TrioWebsocketInternalError.
158+
159+ # If closing the connection fails, then that will be raised as the top
160+ # exception in the last `finally`. If we encountered exceptions in user code
161+ # or in reader task then they will be set as the `__cause__`.
162+
163+
164+ async def _open_connection (nursery : trio .Nursery ) -> WebSocketConnection :
129165 try :
130166 with trio .fail_after (connect_timeout ):
131- connection = await connect_websocket (new_nursery , host , port ,
167+ return await connect_websocket (nursery , host , port ,
132168 resource , use_ssl = use_ssl , subprotocols = subprotocols ,
133169 extra_headers = extra_headers ,
134170 message_queue_size = message_queue_size ,
@@ -137,14 +173,85 @@ async def open_websocket(
137173 raise ConnectionTimeout from None
138174 except OSError as e :
139175 raise HandshakeError from e
176+
177+ async def _close_connection (connection : WebSocketConnection ) -> None :
140178 try :
141- yield connection
142- finally :
143- try :
144- with trio .fail_after (disconnect_timeout ):
145- await connection .aclose ()
146- except trio .TooSlowError :
147- raise DisconnectionTimeout from None
179+ with trio .fail_after (disconnect_timeout ):
180+ await connection .aclose ()
181+ except trio .TooSlowError :
182+ raise DisconnectionTimeout from None
183+
184+ connection : WebSocketConnection | None = None
185+ close_result : outcome .Maybe [None ] | None = None
186+ user_error = None
187+
188+ try :
189+ async with trio .open_nursery () as new_nursery :
190+ result = await outcome .acapture (_open_connection , new_nursery )
191+
192+ if isinstance (result , outcome .Value ):
193+ connection = result .unwrap ()
194+ try :
195+ yield connection
196+ except BaseException as e :
197+ user_error = e
198+ raise
199+ finally :
200+ close_result = await outcome .acapture (_close_connection , connection )
201+ # This exception handler should only be entered if either:
202+ # 1. The _reader_task started in connect_websocket raises
203+ # 2. User code raises an exception
204+ # I.e. open/close_connection are not included
205+ except _TRIO_EXC_GROUP_TYPE as e :
206+ # user_error, or exception bubbling up from _reader_task
207+ if len (e .exceptions ) == 1 :
208+ raise e .exceptions [0 ]
209+
210+ # contains at most 1 non-cancelled exceptions
211+ exception_to_raise : BaseException | None = None
212+ for sub_exc in e .exceptions :
213+ if not isinstance (sub_exc , trio .Cancelled ):
214+ if exception_to_raise is not None :
215+ # multiple non-cancelled
216+ break
217+ exception_to_raise = sub_exc
218+ else :
219+ if exception_to_raise is None :
220+ # all exceptions are cancelled
221+ # prefer raising the one from the user, for traceback reasons
222+ if user_error is not None :
223+ # no reason to raise from e, just to include a bunch of extra
224+ # cancelleds.
225+ raise user_error # pylint: disable=raise-missing-from
226+ # multiple internal Cancelled is not possible afaik
227+ raise e .exceptions [0 ] # pragma: no cover # pylint: disable=raise-missing-from
228+ raise exception_to_raise
229+
230+ # if we have any KeyboardInterrupt in the group, make sure to raise it.
231+ for sub_exc in e .exceptions :
232+ if isinstance (sub_exc , KeyboardInterrupt ):
233+ raise sub_exc from e
234+
235+ # Both user code and internal code raised non-cancelled exceptions.
236+ # We "hide" the internal exception(s) in the __cause__ and surface
237+ # the user_error.
238+ if user_error is not None :
239+ raise user_error from e
240+
241+ raise TrioWebsocketInternalError (
242+ "The trio-websocket API is not expected to raise multiple exceptions. "
243+ "Please report this as a bug to "
244+ "https://github.com/python-trio/trio-websocket"
245+ ) from e # pragma: no cover
246+
247+ finally :
248+ if close_result is not None :
249+ close_result .unwrap ()
250+
251+
252+ # error setting up, unwrap that exception
253+ if connection is None :
254+ result .unwrap ()
148255
149256
150257async def connect_websocket (nursery , host , port , resource , * , use_ssl ,
0 commit comments