Skip to content

Commit 34ba6b7

Browse files
authored
feat: configurable root_path (#50)
Enable the proxy to be served from a configurable non-root path (e.g. `/stac`). This is entirely independent of the path from which the upstream API is served. This involves:

* Processing the `href` of any link in the response matching the upstream host, removing the `upstream_url` path and adding the `root_path` if it exists
* Removing the `root_path` from the request before it is sent to the upstream server
* Adding a `servers` field in the OpenAPI spec with `root_path` if it exists
1 parent 2961a48 commit 34ba6b7

11 files changed

+488
-18
lines changed

src/stac_auth_proxy/app.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
BuildCql2FilterMiddleware,
2222
EnforceAuthMiddleware,
2323
OpenApiMiddleware,
24+
ProcessLinksMiddleware,
25+
RemoveRootPathMiddleware,
2426
)
2527
from .utils.lifespan import check_conformance, check_server_health
2628

@@ -67,11 +69,15 @@ async def lifespan(app: FastAPI):
6769
app = FastAPI(
6870
openapi_url=None, # Disable OpenAPI schema endpoint, we want to serve upstream's schema
6971
lifespan=lifespan,
72+
root_path=settings.root_path,
7073
)
74+
if app.root_path:
75+
logger.debug("Mounted app at %s", app.root_path)
7176

7277
#
7378
# Handlers (place catch-all proxy handler last)
7479
#
80+
7581
if settings.healthz_prefix:
7682
app.include_router(
7783
HealthzHandler(upstream_url=str(settings.upstream_url)).router,
@@ -90,6 +96,7 @@ async def lifespan(app: FastAPI):
9096
#
9197
# Middleware (order is important, last added = first to run)
9298
#
99+
93100
if settings.enable_authentication_extension:
94101
app.add_middleware(
95102
AuthenticationExtensionMiddleware,
@@ -106,6 +113,7 @@ async def lifespan(app: FastAPI):
106113
public_endpoints=settings.public_endpoints,
107114
private_endpoints=settings.private_endpoints,
108115
default_public=settings.default_public,
116+
root_path=settings.root_path,
109117
auth_scheme_name=settings.openapi_auth_scheme_name,
110118
auth_scheme_override=settings.openapi_auth_scheme_override,
111119
)
@@ -119,11 +127,6 @@ async def lifespan(app: FastAPI):
119127
items_filter=settings.items_filter(),
120128
)
121129

122-
if settings.enable_compression:
123-
app.add_middleware(
124-
CompressionMiddleware,
125-
)
126-
127130
app.add_middleware(
128131
AddProcessTimeHeaderMiddleware,
129132
)
@@ -136,4 +139,22 @@ async def lifespan(app: FastAPI):
136139
oidc_config_url=settings.oidc_discovery_internal_url,
137140
)
138141

142+
if settings.root_path or settings.upstream_url.path != "/":
143+
app.add_middleware(
144+
ProcessLinksMiddleware,
145+
upstream_url=str(settings.upstream_url),
146+
root_path=settings.root_path,
147+
)
148+
149+
if settings.root_path:
150+
app.add_middleware(
151+
RemoveRootPathMiddleware,
152+
root_path=settings.root_path,
153+
)
154+
155+
if settings.enable_compression:
156+
app.add_middleware(
157+
CompressionMiddleware,
158+
)
159+
139160
return app

src/stac_auth_proxy/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ class Settings(BaseSettings):
3838
oidc_discovery_url: HttpUrl
3939
oidc_discovery_internal_url: HttpUrl
4040

41+
root_path: str = ""
4142
override_host: bool = True
4243
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
4344
wait_for_upstream: bool = True

src/stac_auth_proxy/middleware/AuthenticationExtensionMiddleware.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import logging
44
import re
55
from dataclasses import dataclass, field
6-
from itertools import chain
76
from typing import Any
87
from urllib.parse import urlparse
98

@@ -14,6 +13,7 @@
1413
from ..config import EndpointMethods
1514
from ..utils.middleware import JsonResponseMiddleware
1615
from ..utils.requests import find_match
16+
from ..utils.stac import get_links
1717

1818
logger = logging.getLogger(__name__)
1919

@@ -101,18 +101,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
101101
# auth:refs
102102
# ---
103103
# Annotate links with "auth:refs": [auth_scheme]
104-
links = chain(
105-
# Item/Collection
106-
data.get("links", []),
107-
# Collections/Items/Search
108-
(
109-
link
110-
for prop in ["features", "collections"]
111-
for object_with_links in data.get(prop, [])
112-
for link in object_with_links.get("links", [])
113-
),
114-
)
115-
for link in links:
104+
for link in get_links(data):
116105
if "href" not in link:
117106
logger.warning("Link %s has no href", link)
118107
continue
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
"""Middleware to update links in responses, removing the upstream path and adding the application root path."""
2+
3+
import logging
4+
import re
5+
from dataclasses import dataclass
6+
from typing import Any, Optional
7+
from urllib.parse import urlparse, urlunparse
8+
9+
from starlette.datastructures import Headers
10+
from starlette.requests import Request
11+
from starlette.types import ASGIApp, Scope
12+
13+
from ..utils.middleware import JsonResponseMiddleware
14+
from ..utils.stac import get_links
15+
16+
logger = logging.getLogger(__name__)
17+
18+
19+
@dataclass
class ProcessLinksMiddleware(JsonResponseMiddleware):
    """
    Middleware to update links in responses, removing the upstream_url path and adding
    the root_path if it exists.

    Only links whose network location equals the ``Host`` header of the incoming
    request are rewritten; every other link is left untouched.
    """

    app: ASGIApp
    upstream_url: str
    root_path: Optional[str] = None

    # Matches both "application/json" and "application/geo+json".
    json_content_type_expr: str = r"application/(geo\+)?json"

    def should_transform_response(self, request: Request, scope: Scope) -> bool:
        """Only transform responses with JSON content type."""
        content_type = Headers(scope=scope).get("content-type", "")
        return bool(re.match(self.json_content_type_expr, content_type))

    def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
        """
        Update links in the response to drop the upstream path and include root_path.

        For each link returned by ``get_links(data)`` whose host matches the request's
        ``Host`` header:

        1. the path component of ``upstream_url`` is removed from the start of the
           link path, but only when the link path really lives under it;
        2. ``root_path`` (if set) is prepended.

        Links are modified in place and ``data`` is returned. A link that cannot be
        processed is logged and left as it was.
        """
        # Loop invariants. Trailing slashes are dropped so that "/api/" and "/api"
        # behave the same and the rewritten path always keeps its leading slash.
        upstream_path = urlparse(self.upstream_url).path.rstrip("/")
        root_path = (self.root_path or "").rstrip("/")
        request_host = request.headers.get("host")

        for link in get_links(data):
            href = link.get("href")
            if not href:
                continue

            try:
                parsed_link = urlparse(href)

                # Ignore links that are not for this proxy
                if parsed_link.netloc != request_host:
                    continue

                path = parsed_link.path

                # Remove the upstream_url path from the link, but only on a whole
                # path-segment match; blindly slicing would corrupt links that do
                # not live under the upstream path (e.g. "/other" -> "er").
                if upstream_path and (
                    path == upstream_path or path.startswith(f"{upstream_path}/")
                ):
                    path = path[len(upstream_path) :]

                # Add the root_path to the link if it exists
                if root_path:
                    path = f"{root_path}{path}"

                link["href"] = urlunparse(parsed_link._replace(path=path))
            except Exception as e:
                # Best effort: one malformed link must not fail the whole response.
                logger.error(
                    "Failed to parse link href %r, (ignoring): %s", href, str(e)
                )

        return data
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
"""Middleware to remove ROOT_PATH from incoming requests before they are passed on to the upstream server."""
2+
3+
import logging
4+
from dataclasses import dataclass
5+
6+
from starlette.responses import Response
7+
from starlette.types import ASGIApp, Receive, Scope, Send
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
@dataclass
class RemoveRootPathMiddleware:
    """
    Middleware to remove the root path of the request before it is sent to the upstream
    server.

    IMPORTANT: This middleware must be placed early in the middleware chain (ie late in
    the order of declaration) so that it trims the root_path from the request path before
    any middleware that may need to use the request path (e.g. EnforceAuthMiddleware).
    """

    app: ASGIApp
    root_path: str

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """
        Remove ROOT_PATH from the request path if it exists.

        Non-HTTP scopes are passed through untouched. An HTTP request whose path is
        not located under ``root_path`` is answered with a 404 and never reaches the
        wrapped app. Otherwise ``scope["path"]`` is rewritten without the prefix
        (falling back to "/" when nothing remains) before the wrapped app is called.
        """
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        path = scope["path"]
        # Ignore a trailing slash on the configured root path so that "/stac/" and
        # "/stac" behave the same and the trimmed path keeps its leading slash.
        prefix = self.root_path.rstrip("/")

        # If root_path is set and the path is not under it, return 404. The match is
        # on whole path segments: "/stacfoo" is NOT under "/stac".
        if prefix and path != prefix and not path.startswith(f"{prefix}/"):
            response = Response("Not Found", status_code=404)
            logger.error(
                "Root path %r not found in path %r", self.root_path, path
            )
            await response(scope, receive, send)
            return

        # NOTE(review): raw_path is set from the full, untrimmed (and already decoded)
        # path, exactly as before -- confirm whether downstream consumers expect the
        # prefix to be kept here.
        scope["raw_path"] = path.encode()
        # Remove root_path from the start of the path
        scope["path"] = path[len(prefix) :] or "/"

        return await self.app(scope, receive, send)

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class OpenApiMiddleware(JsonResponseMiddleware):
2323
private_endpoints: EndpointMethods
2424
public_endpoints: EndpointMethods
2525
default_public: bool
26+
root_path: str = ""
2627
auth_scheme_name: str = "oidcAuth"
2728
auth_scheme_override: Optional[dict] = None
2829

@@ -46,12 +47,19 @@ def should_transform_response(self, request: Request, scope: Scope) -> bool:
4647

4748
def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, Any]:
4849
"""Augment the OpenAPI spec with auth information."""
50+
# Add servers field with root path if root_path is set
51+
if self.root_path:
52+
data["servers"] = [{"url": self.root_path}]
53+
54+
# Add security scheme
4955
components = data.setdefault("components", {})
5056
securitySchemes = components.setdefault("securitySchemes", {})
5157
securitySchemes[self.auth_scheme_name] = self.auth_scheme_override or {
5258
"type": "openIdConnect",
5359
"openIdConnectUrl": self.oidc_config_url,
5460
}
61+
62+
# Add security to private endpoints
5563
for path, method_config in data["paths"].items():
5664
for method, config in method_config.items():
5765
match = find_match(

src/stac_auth_proxy/middleware/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from .AuthenticationExtensionMiddleware import AuthenticationExtensionMiddleware
66
from .BuildCql2FilterMiddleware import BuildCql2FilterMiddleware
77
from .EnforceAuthMiddleware import EnforceAuthMiddleware
8+
from .ProcessLinksMiddleware import ProcessLinksMiddleware
9+
from .RemoveRootPathMiddleware import RemoveRootPathMiddleware
810
from .UpdateOpenApiMiddleware import OpenApiMiddleware
911

1012
__all__ = [
@@ -13,5 +15,7 @@
1315
"AuthenticationExtensionMiddleware",
1416
"BuildCql2FilterMiddleware",
1517
"EnforceAuthMiddleware",
18+
"ProcessLinksMiddleware",
19+
"RemoveRootPathMiddleware",
1620
"OpenApiMiddleware",
1721
]

src/stac_auth_proxy/utils/stac.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
"""STAC-specific utilities."""
2+
3+
from itertools import chain
4+
5+
6+
def get_links(data: dict) -> chain[dict]:
    """
    Get all links from a STAC response.

    Yields, lazily and in this order:

    1. the top-level ``links`` of ``data`` (Item/Collection responses);
    2. the ``links`` of every object under ``features`` and then ``collections``
       (Collections/Items/Search responses).

    The yielded objects are the link dicts held by ``data`` itself, so callers can
    modify them in place. Keys that are missing or explicitly ``null`` are treated
    as empty.
    """
    nested_links = (
        link
        for prop in ("features", "collections")
        # `or []` also covers an explicit JSON null, which `.get(key, [])` does not.
        for object_with_links in data.get(prop) or []
        for link in object_with_links.get("links") or []
    )
    return chain(
        # Item/Collection
        data.get("links") or [],
        # Collections/Items/Search
        nested_links,
    )

tests/test_openapi.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,3 +190,33 @@ def test_auth_scheme_override(source_api: FastAPI, source_api_server: str):
190190
security_schemes = openapi.get("components", {}).get("securitySchemes", {})
191191
assert "oidcAuth" in security_schemes
192192
assert security_schemes["oidcAuth"] == custom_scheme
193+
194+
195+
def test_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: str):
    """A configured root_path is advertised through the OpenAPI ``servers`` field."""
    prefix = "/api/v1"
    proxy_app = app_factory(
        upstream_url=source_api_server,
        openapi_spec_endpoint=source_api.openapi_url,
        root_path=prefix,
    )

    # The spec must be requested under the root path, since the proxy is mounted there.
    spec_response = TestClient(proxy_app).get(prefix + source_api.openapi_url)
    assert spec_response.status_code == 200

    spec = spec_response.json()
    assert "servers" in spec
    assert spec["servers"] == [{"url": prefix}]
209+
210+
211+
def test_no_root_path_in_openapi_spec(source_api: FastAPI, source_api_server: str):
    """Without a root_path, no ``servers`` field is added to the OpenAPI spec."""
    proxy_app = app_factory(
        upstream_url=source_api_server,
        openapi_spec_endpoint=source_api.openapi_url,
        root_path="",  # Empty string means no root path
    )

    spec_response = TestClient(proxy_app).get(source_api.openapi_url)
    assert spec_response.status_code == 200

    spec = spec_response.json()
    assert "servers" not in spec

0 commit comments

Comments
 (0)