|
| 1 | +"""Middleware to remove ROOT_PATH from incoming requests and update links in responses.""" |
| 2 | + |
| 3 | +import logging |
| 4 | +from dataclasses import dataclass |
| 5 | + |
| 6 | +from starlette.requests import Request |
| 7 | +from starlette.responses import Response |
| 8 | +from starlette.types import ASGIApp, Receive, Scope, Send |
| 9 | + |
| 10 | +from ..config import EndpointMethods |
| 11 | +from ..utils.requests import find_match |
| 12 | + |
| 13 | +logger = logging.getLogger(__name__) |
| 14 | + |
| 15 | + |
| 16 | +@dataclass |
| 17 | +class AuthOptionsMiddleware: |
| 18 | + """Middleware to enform client of users capabilities in response to OPTIONS request.""" |
| 19 | + |
| 20 | + app: ASGIApp |
| 21 | + private_endpoints: EndpointMethods |
| 22 | + public_endpoints: EndpointMethods |
| 23 | + default_public: bool |
| 24 | + state_key: str = "payload" |
| 25 | + |
| 26 | + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: |
| 27 | + """Check capabilities of the user.""" |
| 28 | + print("HERE") |
| 29 | + if scope["type"] != "http": |
| 30 | + return await self.app(scope, receive, send) |
| 31 | + |
| 32 | + if scope["method"] != "OPTIONS": |
| 33 | + return await self.app(scope, receive, send) |
| 34 | + |
| 35 | + # Get endpoint requirements |
| 36 | + methods = ["GET", "POST", "PUT", "PATCH", "DELETE"] |
| 37 | + method_requirements = {} |
| 38 | + for method in methods: |
| 39 | + match = find_match( |
| 40 | + path=scope["path"], |
| 41 | + method=method, |
| 42 | + private_endpoints=self.private_endpoints, |
| 43 | + public_endpoints=self.public_endpoints, |
| 44 | + default_public=self.default_public, |
| 45 | + ) |
| 46 | + method_requirements[method] = match |
| 47 | + |
| 48 | + # Get user (maybe) |
| 49 | + request = Request(scope) |
| 50 | + assert hasattr( |
| 51 | + request.state, self.state_key |
| 52 | + ), "Auth Payload not set in request state. Is state_key set correctly? Does the EnforceAuthMiddleware run before this middleware?" |
| 53 | + user = getattr(request.state, self.state_key, None) |
| 54 | + user_scopes = user.get("scope", "").split(" ") if user else [] |
| 55 | + |
| 56 | + # Get user capabilities |
| 57 | + valid_methods = [] |
| 58 | + for method, match in method_requirements.items(): |
| 59 | + # Is public |
| 60 | + if not match.is_private: |
| 61 | + valid_methods.append(method) |
| 62 | + continue |
| 63 | + |
| 64 | + # Is private and user has all required scopes |
| 65 | + if user and all(scope in user_scopes for scope in match.required_scopes): |
| 66 | + valid_methods.append(method) |
| 67 | + continue |
| 68 | + |
| 69 | + # Construct response |
| 70 | + headers = { |
| 71 | + "Allow": ", ".join(valid_methods), |
| 72 | + "Access-Control-Allow-Methods": ", ".join(valid_methods), |
| 73 | + "Access-Control-Allow-Headers": "Authorization, Content-Type", |
| 74 | + "Access-Control-Max-Age": "86400", # 24 hours |
| 75 | + } |
| 76 | + |
| 77 | + # Add CORS origin if provided in request |
| 78 | + origin = request.headers.get("Origin") |
| 79 | + if origin: |
| 80 | + headers["Access-Control-Allow-Origin"] = origin |
| 81 | + |
| 82 | + response = Response( |
| 83 | + content="", |
| 84 | + status_code=204, |
| 85 | + headers=headers, |
| 86 | + ) |
| 87 | + return await response(scope, receive, send) |
0 commit comments