Skip to content

Commit 5377f61

Browse files
committed
feat: add AuthOptionsMiddleware for handling OPTIONS requests
Introduces AuthOptionsMiddleware to inform clients of user capabilities based on their authentication status. This middleware processes OPTIONS requests, checks user permissions against defined endpoint requirements, and constructs appropriate CORS headers for responses. Additionally, it integrates the new middleware into the FastAPI application. Relates to: - opengeospatial/ogcapi-features#1005 - stac-api-extensions/transaction#15
1 parent 1a75550 commit 5377f61

File tree

4 files changed

+97
-1
lines changed

4 files changed

+97
-1
lines changed

src/stac_auth_proxy/app.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
AddProcessTimeHeaderMiddleware,
1919
ApplyCql2FilterMiddleware,
2020
AuthenticationExtensionMiddleware,
21+
AuthOptionsMiddleware,
2122
BuildCql2FilterMiddleware,
2223
EnforceAuthMiddleware,
2324
OpenApiMiddleware,
@@ -149,6 +150,13 @@ async def lifespan(app: FastAPI):
149150
AddProcessTimeHeaderMiddleware,
150151
)
151152

153+
app.add_middleware(
154+
AuthOptionsMiddleware,
155+
public_endpoints=settings.public_endpoints,
156+
private_endpoints=settings.private_endpoints,
157+
default_public=settings.default_public,
158+
)
159+
152160
app.add_middleware(
153161
EnforceAuthMiddleware,
154162
public_endpoints=settings.public_endpoints,
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
"""Middleware to remove ROOT_PATH from incoming requests and update links in responses."""
2+
3+
import logging
4+
from dataclasses import dataclass
5+
6+
from starlette.requests import Request
7+
from starlette.responses import Response
8+
from starlette.types import ASGIApp, Receive, Scope, Send
9+
10+
from ..config import EndpointMethods
11+
from ..utils.requests import find_match
12+
13+
logger = logging.getLogger(__name__)
14+
15+
16+
@dataclass
17+
class AuthOptionsMiddleware:
18+
"""Middleware to enform client of users capabilities in response to OPTIONS request."""
19+
20+
app: ASGIApp
21+
private_endpoints: EndpointMethods
22+
public_endpoints: EndpointMethods
23+
default_public: bool
24+
state_key: str = "payload"
25+
26+
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
27+
"""Check capabilities of the user."""
28+
print("HERE")
29+
if scope["type"] != "http":
30+
return await self.app(scope, receive, send)
31+
32+
if scope["method"] != "OPTIONS":
33+
return await self.app(scope, receive, send)
34+
35+
# Get endpoint requirements
36+
methods = ["GET", "POST", "PUT", "PATCH", "DELETE"]
37+
method_requirements = {}
38+
for method in methods:
39+
match = find_match(
40+
path=scope["path"],
41+
method=method,
42+
private_endpoints=self.private_endpoints,
43+
public_endpoints=self.public_endpoints,
44+
default_public=self.default_public,
45+
)
46+
method_requirements[method] = match
47+
48+
# Get user (maybe)
49+
request = Request(scope)
50+
assert hasattr(
51+
request.state, self.state_key
52+
), "Auth Payload not set in request state. Is state_key set correctly? Does the EnforceAuthMiddleware run before this middleware?"
53+
user = getattr(request.state, self.state_key, None)
54+
user_scopes = user.get("scope", "").split(" ") if user else []
55+
56+
# Get user capabilities
57+
valid_methods = []
58+
for method, match in method_requirements.items():
59+
# Is public
60+
if not match.is_private:
61+
valid_methods.append(method)
62+
continue
63+
64+
# Is private and user has all required scopes
65+
if user and all(scope in user_scopes for scope in match.required_scopes):
66+
valid_methods.append(method)
67+
continue
68+
69+
# Construct response
70+
headers = {
71+
"Allow": ", ".join(valid_methods),
72+
"Access-Control-Allow-Methods": ", ".join(valid_methods),
73+
"Access-Control-Allow-Headers": "Authorization, Content-Type",
74+
"Access-Control-Max-Age": "86400", # 24 hours
75+
}
76+
77+
# Add CORS origin if provided in request
78+
origin = request.headers.get("Origin")
79+
if origin:
80+
headers["Access-Control-Allow-Origin"] = origin
81+
82+
response = Response(
83+
content="",
84+
status_code=204,
85+
headers=headers,
86+
)
87+
return await response(scope, receive, send)

src/stac_auth_proxy/middleware/EnforceAuthMiddleware.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,6 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
9898
auto_error=match.is_private,
9999
required_scopes=match.required_scopes,
100100
)
101-
102101
except HTTPException as e:
103102
response = JSONResponse({"detail": e.detail}, status_code=e.status_code)
104103
return await response(scope, receive, send)

src/stac_auth_proxy/middleware/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .AddProcessTimeHeaderMiddleware import AddProcessTimeHeaderMiddleware
44
from .ApplyCql2FilterMiddleware import ApplyCql2FilterMiddleware
55
from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware
6+
from .AuthOptionsMiddleware import AuthOptionsMiddleware
67
from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware
78
from .EnforceAuthMiddleware import EnforceAuthMiddleware
89
from .ProcessLinksMiddleware import ProcessLinksMiddleware
@@ -14,6 +15,7 @@
1415
"ApplyCql2FilterMiddleware",
1516
"AuthenticationExtensionMiddleware",
1617
"BuildCql2FilterMiddleware",
18+
"AuthOptionsMiddleware",
1719
"EnforceAuthMiddleware",
1820
"ProcessLinksMiddleware",
1921
"RemoveRootPathMiddleware",

0 commit comments

Comments
 (0)