Skip to content

Commit de6a946

Browse files
authored
feat: support custom OpenAPI auth scheme (#51)
Currently, we inject an OIDC auth scheme into the OpenAPI spec. However, users may want to run the STAC Auth Proxy to apply auth for tokens that are generated elsewhere (e.g. tokens that can be validated with a JWKS, but are not generated via the `/token` or `/authorization` endpoint). As such, this PR enables the manual override of the auth scheme that we inject into the OpenAPI doc, configurable via env vars.
1 parent 2d04011 commit de6a946

File tree

5 files changed

+61
-7
lines changed

5 files changed

+61
-7
lines changed

README.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,19 @@ The application is configurable via environment variables.
123123
- **Type:** boolean
124124
- **Required:** No, defaults to `true`
125125
- **Example:** `false`, `1`, `True`
126+
- OpenAPI
126127
- **`OPENAPI_SPEC_ENDPOINT`**, path of OpenAPI specification, used for augmenting spec response with auth configuration
127128
- **Type:** string or null
128129
- **Required:** No, defaults to `null` (disabled)
129130
- **Example:** `/api`
131+
- **`OPENAPI_AUTH_SCHEME_NAME`**, name of the auth scheme to use in the OpenAPI spec
132+
- **Type:** string
133+
- **Required:** No, defaults to `oidcAuth`
134+
- **Example:** `jwtAuth`
135+
- **`OPENAPI_AUTH_SCHEME_OVERRIDE`**, override for the auth scheme in the OpenAPI spec
136+
- **Type:** JSON object
137+
- **Required:** No, defaults to `null` (disabled)
138+
- **Example:** `{"type": "http", "scheme": "bearer", "bearerFormat": "JWT", "description": "Paste your raw JWT here. This API uses Bearer token authorization.\n"}`
130139
- Filtering
131140
- **`ITEMS_FILTER_CLS`**, CQL2 expression generator for item-level filtering
132141
- **Type:** JSON object with class configuration
@@ -139,7 +148,7 @@ The application is configurable via environment variables.
139148
- **`ITEMS_FILTER_KWARGS`**, Keyword arguments for CQL2 expression generator
140149
- **Type:** Dictionary of keyword arguments used to initialize the class
141150
- **Required:** No, defaults to `{}`
142-
- **Example:** `{ "field_name": "properties.organization" }`
151+
- **Example:** `{"field_name": "properties.organization"}`
143152

144153
### Customization
145154

src/stac_auth_proxy/app.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ async def lifespan(app: FastAPI):
103103
public_endpoints=settings.public_endpoints,
104104
private_endpoints=settings.private_endpoints,
105105
default_public=settings.default_public,
106+
auth_scheme_name=settings.openapi_auth_scheme_name,
107+
auth_scheme_override=settings.openapi_auth_scheme_override,
106108
)
107109

108110
if settings.items_filter:

src/stac_auth_proxy/config.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,17 @@ class Settings(BaseSettings):
3838
oidc_discovery_url: HttpUrl
3939
oidc_discovery_internal_url: HttpUrl
4040

41+
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
4142
wait_for_upstream: bool = True
4243
check_conformance: bool = True
4344
enable_compression: bool = True
44-
enable_authentication_extension: bool = True
45-
healthz_prefix: str = Field(pattern=_PREFIX_PATTERN, default="/healthz")
45+
4646
openapi_spec_endpoint: Optional[str] = Field(pattern=_PREFIX_PATTERN, default=None)
47+
openapi_auth_scheme_name: str = "oidcAuth"
48+
openapi_auth_scheme_override: Optional[dict] = None
4749

4850
# Auth
51+
enable_authentication_extension: bool = True
4952
default_public: bool = False
5053
public_endpoints: EndpointMethodsNoScope = {
5154
r"^/api.html$": ["GET"],

src/stac_auth_proxy/middleware/UpdateOpenApiMiddleware.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import re
44
from dataclasses import dataclass
5-
from typing import Any
5+
from typing import Any, Optional
66

77
from starlette.datastructures import Headers
88
from starlette.requests import Request
@@ -23,7 +23,8 @@ class OpenApiMiddleware(JsonResponseMiddleware):
2323
private_endpoints: EndpointMethods
2424
public_endpoints: EndpointMethods
2525
default_public: bool
26-
oidc_auth_scheme_name: str = "oidcAuth"
26+
auth_scheme_name: str = "oidcAuth"
27+
auth_scheme_override: Optional[dict] = None
2728

2829
json_content_type_expr: str = r"application/(vnd\.oai\.openapi\+json?|json)"
2930

@@ -47,7 +48,7 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
4748
"""Augment the OpenAPI spec with auth information."""
4849
components = data.setdefault("components", {})
4950
securitySchemes = components.setdefault("securitySchemes", {})
50-
securitySchemes[self.oidc_auth_scheme_name] = {
51+
securitySchemes[self.auth_scheme_name] = self.auth_scheme_override or {
5152
"type": "openIdConnect",
5253
"openIdConnectUrl": self.oidc_config_url,
5354
}
@@ -62,6 +63,6 @@ def transform_json(self, data: dict[str, Any], request: Request) -> dict[str, An
6263
)
6364
if match.is_private:
6465
config.setdefault("security", []).append(
65-
{self.oidc_auth_scheme_name: match.required_scopes}
66+
{self.auth_scheme_name: match.required_scopes}
6667
)
6768
return data

tests/test_openapi.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,42 @@ def test_oidc_in_openapi_spec_public_endpoints(
151151
assert any(
152152
method.casefold() == m.casefold() for m in expected_auth[path]
153153
)
154+
155+
156+
def test_auth_scheme_name_override(source_api: FastAPI, source_api_server: str):
157+
"""When auth_scheme_name is overridden, the OpenAPI spec uses the custom name."""
158+
custom_name = "customAuth"
159+
app = app_factory(
160+
upstream_url=source_api_server,
161+
openapi_spec_endpoint=source_api.openapi_url,
162+
openapi_auth_scheme_name=custom_name,
163+
)
164+
client = TestClient(app)
165+
response = client.get(source_api.openapi_url)
166+
assert response.status_code == 200
167+
openapi = response.json()
168+
security_schemes = openapi.get("components", {}).get("securitySchemes", {})
169+
assert custom_name in security_schemes
170+
assert "oidcAuth" not in security_schemes
171+
172+
173+
def test_auth_scheme_override(source_api: FastAPI, source_api_server: str):
174+
"""When auth_scheme_override is provided, the OpenAPI spec uses the custom scheme."""
175+
custom_scheme = {
176+
"type": "http",
177+
"scheme": "bearer",
178+
"bearerFormat": "JWT",
179+
"description": "Custom JWT authentication",
180+
}
181+
app = app_factory(
182+
upstream_url=source_api_server,
183+
openapi_spec_endpoint=source_api.openapi_url,
184+
openapi_auth_scheme_override=custom_scheme,
185+
)
186+
client = TestClient(app)
187+
response = client.get(source_api.openapi_url)
188+
assert response.status_code == 200
189+
openapi = response.json()
190+
security_schemes = openapi.get("components", {}).get("securitySchemes", {})
191+
assert "oidcAuth" in security_schemes
192+
assert security_schemes["oidcAuth"] == custom_scheme

0 commit comments

Comments
 (0)