|
1 | 1 | import unittest |
2 | | -from unittest.mock import patch |
3 | | - |
4 | 2 | from fastapi import FastAPI, Request |
5 | 3 | from fastapi.testclient import TestClient |
6 | 4 | from fastapi.websockets import WebSocket |
| 5 | +from unittest.mock import AsyncMock |
| 6 | +from unittest.mock import patch |
7 | 7 | from uuid import uuid4 |
8 | 8 |
|
9 | 9 | from mauth_client.authenticator import LocalAuthenticator |
@@ -220,3 +220,61 @@ def is_authentic_effect(self): |
220 | 220 | self.client.get("/sub_app/path") |
221 | 221 |
|
222 | 222 | self.assertEqual(request_url, "/sub_app/path") |
| 223 | + |
| 224 | + |
| 225 | +class TestMAuthASGIMiddlewareInLongLivedConnections(unittest.IsolatedAsyncioTestCase): |
| 226 | + def setUp(self): |
| 227 | + self.app = FastAPI() |
| 228 | + Config.APP_UUID = str(uuid4()) |
| 229 | + Config.MAUTH_URL = "https://mauth.com" |
| 230 | + Config.MAUTH_API_VERSION = "v1" |
| 231 | + Config.PRIVATE_KEY = "key" |
| 232 | + |
| 233 | + @patch.object(LocalAuthenticator, "is_authentic") |
| 234 | + async def test_fake_receive_delegates_to_original_after_body_consumed(self, is_authentic_mock): |
| 235 | + """Test that after body events are consumed, _fake_receive delegates to original receive""" |
| 236 | + is_authentic_mock.return_value = (True, 200, "") |
| 237 | + |
| 238 | + # Track that original receive was called after body events exhausted |
| 239 | + call_order = [] |
| 240 | + |
| 241 | + async def mock_app(scope, receive, send): |
| 242 | + # First receive should get body event |
| 243 | + event1 = await receive() |
| 244 | + call_order.append(("body", event1["type"])) |
| 245 | + |
| 246 | + # Second receive should delegate to original receive |
| 247 | + event2 = await receive() |
| 248 | + call_order.append(("disconnect", event2["type"])) |
| 249 | + |
| 250 | + await send({"type": "http.response.start", "status": 200, "headers": []}) |
| 251 | + await send({"type": "http.response.body", "body": b""}) |
| 252 | + |
| 253 | + middleware = MAuthASGIMiddleware(mock_app) |
| 254 | + |
| 255 | + # Mock receive that returns body then disconnect |
| 256 | + receive_calls = 0 |
| 257 | + |
| 258 | + async def mock_receive(): |
| 259 | + nonlocal receive_calls |
| 260 | + receive_calls += 1 |
| 261 | + if receive_calls == 1: |
| 262 | + return {"type": "http.request", "body": b"test", "more_body": False} |
| 263 | + return {"type": "http.disconnect"} |
| 264 | + |
| 265 | + send_mock = AsyncMock() |
| 266 | + scope = { |
| 267 | + "type": "http", |
| 268 | + "method": "POST", |
| 269 | + "path": "/test", |
| 270 | + "query_string": b"", |
| 271 | + "headers": [] |
| 272 | + } |
| 273 | + |
| 274 | + await middleware(scope, mock_receive, send_mock) |
| 275 | + |
| 276 | + # Verify events were received in correct order |
| 277 | + self.assertEqual(len(call_order), 2) |
| 278 | + self.assertEqual(call_order[0], ("body", "http.request")) |
| 279 | + self.assertEqual(call_order[1], ("disconnect", "http.disconnect")) |
| 280 | + self.assertEqual(receive_calls, 2) # Called once for auth, once from app |
0 commit comments