Skip to content

Commit 2cd2356

Browse files
committed
Fix ASGI event handling for long-lived connections
After body events are consumed for authentication, the middleware's _fake_receive function now delegates to the original receive callable instead of returning None. This allows downstream applications to properly receive lifecycle events like http.disconnect, enabling proper cleanup for SSE connections, streaming responses, and other long-lived HTTP connections. Adds test to verify that _fake_receive correctly delegates to original receive after body events are exhausted.
1 parent 21579c9 commit 2cd2356

File tree

2 files changed

+75
-9
lines changed

2 files changed

+75
-9
lines changed

mauth_client/middlewares/asgi.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ async def __call__(
6262
scope_copy[ENV_APP_UUID] = signed.app_uuid
6363
scope_copy[ENV_AUTHENTIC] = True
6464
scope_copy[ENV_PROTOCOL_VERSION] = signed.protocol_version()
65-
await self.app(scope_copy, self._fake_receive(events), send)
65+
await self.app(scope_copy, self._fake_receive(events, receive), send)
6666
else:
6767
await self._send_response(send, status, message)
6868

@@ -100,18 +100,26 @@ async def _send_response(self, send: ASGISendCallable, status: int, msg: str) ->
100100
"body": json.dumps(body).encode("utf-8"),
101101
})
102102

103-
def _fake_receive(self, events: List[ASGIReceiveEvent]) -> ASGIReceiveCallable:
103+
def _fake_receive(self, events: List[ASGIReceiveEvent],
104+
original_receive: ASGIReceiveCallable) -> ASGIReceiveCallable:
104105
"""
105-
Create a fake, async receive function using an iterator of the events
106-
we've already read. This will be passed to downstream middlewares/apps
107-
instead of the usual receive fn, so that they can also "receive" the
108-
body events.
106+
Create a fake receive function that replays cached body events.
107+
108+
After the middleware consumes request body events for authentication,
109+
this allows downstream apps to also "receive" those events. Once all
110+
cached events are exhausted, delegates to the original receive to
111+
properly forward lifecycle events (like http.disconnect).
112+
113+
This is essential for long-lived connections (SSE, streaming responses)
114+
that need to detect client disconnects.
109115
"""
110116
events_iter = iter(events)
111117

112118
async def _receive() -> ASGIReceiveEvent:
113119
try:
114120
return next(events_iter)
115121
except StopIteration:
116-
pass
122+
# After body events are consumed, delegate to original receive
123+
# This allows proper handling of disconnects for SSE connections
124+
return await original_receive()
117125
return _receive

tests/middlewares/asgi_test.py

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import unittest
2-
from unittest.mock import patch
3-
42
from fastapi import FastAPI, Request
53
from fastapi.testclient import TestClient
64
from fastapi.websockets import WebSocket
5+
from unittest.mock import AsyncMock
6+
from unittest.mock import patch
77
from uuid import uuid4
88

99
from mauth_client.authenticator import LocalAuthenticator
@@ -220,3 +220,61 @@ def is_authentic_effect(self):
220220
self.client.get("/sub_app/path")
221221

222222
self.assertEqual(request_url, "/sub_app/path")
223+
224+
225+
class TestMAuthASGIMiddlewareInLongLivedConnections(unittest.IsolatedAsyncioTestCase):
226+
def setUp(self):
227+
self.app = FastAPI()
228+
Config.APP_UUID = str(uuid4())
229+
Config.MAUTH_URL = "https://mauth.com"
230+
Config.MAUTH_API_VERSION = "v1"
231+
Config.PRIVATE_KEY = "key"
232+
233+
@patch.object(LocalAuthenticator, "is_authentic")
234+
async def test_fake_receive_delegates_to_original_after_body_consumed(self, is_authentic_mock):
235+
"""Test that after body events are consumed, _fake_receive delegates to original receive"""
236+
is_authentic_mock.return_value = (True, 200, "")
237+
238+
# Track that original receive was called after body events exhausted
239+
call_order = []
240+
241+
async def mock_app(scope, receive, send):
242+
# First receive should get body event
243+
event1 = await receive()
244+
call_order.append(("body", event1["type"]))
245+
246+
# Second receive should delegate to original receive
247+
event2 = await receive()
248+
call_order.append(("disconnect", event2["type"]))
249+
250+
await send({"type": "http.response.start", "status": 200, "headers": []})
251+
await send({"type": "http.response.body", "body": b""})
252+
253+
middleware = MAuthASGIMiddleware(mock_app)
254+
255+
# Mock receive that returns body then disconnect
256+
receive_calls = 0
257+
258+
async def mock_receive():
259+
nonlocal receive_calls
260+
receive_calls += 1
261+
if receive_calls == 1:
262+
return {"type": "http.request", "body": b"test", "more_body": False}
263+
return {"type": "http.disconnect"}
264+
265+
send_mock = AsyncMock()
266+
scope = {
267+
"type": "http",
268+
"method": "POST",
269+
"path": "/test",
270+
"query_string": b"",
271+
"headers": []
272+
}
273+
274+
await middleware(scope, mock_receive, send_mock)
275+
276+
# Verify events were received in correct order
277+
self.assertEqual(len(call_order), 2)
278+
self.assertEqual(call_order[0], ("body", "http.request"))
279+
self.assertEqual(call_order[1], ("disconnect", "http.disconnect"))
280+
self.assertEqual(receive_calls, 2) # Called once for auth, once from app

0 commit comments

Comments
 (0)