Skip to content

Commit 3232573

Browse files
committed
Buildout filter for item read (#45)
1 parent 9d8599f commit 3232573

File tree

8 files changed

+351
-196
lines changed

8 files changed

+351
-196
lines changed

README.md

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -191,21 +191,21 @@ If enabled, filters are intended to be applied to the following endpoints:
191191
- **Action:** Read Item
192192
- **Applied Filter:** `ITEMS_FILTER`
193193
- **Strategy:** Append body with generated CQL2 query.
194-
- `GET /collections/{collection_id}`
195-
- **Supported:** ❌[^23]
196-
- **Action:** Read Collection
197-
- **Applied Filter:** `COLLECTIONS_FILTER`
198-
- **Strategy:** Append query params with generated CQL2 query.
199194
- `GET /collections/{collection_id}/items`
200195
- **Supported:** ✅
201196
- **Action:** Read Item
202197
- **Applied Filter:** `ITEMS_FILTER`
203198
- **Strategy:** Append query params with generated CQL2 query.
204199
- `GET /collections/{collection_id}/items/{item_id}`
205-
- **Supported:** ❌[^25]
200+
- **Supported:**
206201
- **Action:** Read Item
207202
- **Applied Filter:** `ITEMS_FILTER`
208203
- **Strategy:** Validate response against CQL2 query.
204+
- `GET /collections/{collection_id}`
205+
- **Supported:** ❌[^23]
206+
- **Action:** Read Collection
207+
- **Applied Filter:** `COLLECTIONS_FILTER`
208+
- **Strategy:** Append query params with generated CQL2 query.
209209
- `POST /collections/`
210210
- **Supported:** ❌[^22]
211211
- **Action:** Create Collection
@@ -257,6 +257,5 @@ sequenceDiagram
257257
[^21]: https://github.com/developmentseed/stac-auth-proxy/issues/21
258258
[^22]: https://github.com/developmentseed/stac-auth-proxy/issues/22
259259
[^23]: https://github.com/developmentseed/stac-auth-proxy/issues/23
260-
[^25]: https://github.com/developmentseed/stac-auth-proxy/issues/25
261260
[^30]: https://github.com/developmentseed/stac-auth-proxy/issues/30
262261
[^37]: https://github.com/developmentseed/stac-auth-proxy/issues/37

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ classifiers = [
88
dependencies = [
99
"authlib>=1.3.2",
1010
"brotli>=1.1.0",
11-
"cql2>=0.3.5",
11+
"cql2>=0.3.6",
1212
"fastapi>=0.115.5",
1313
"httpx[http2]>=0.28.0",
1414
"jinja2>=3.1.4",
Lines changed: 121 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
"""Middleware to apply CQL2 filters."""
22

33
import json
4-
from dataclasses import dataclass
4+
import re
5+
from dataclasses import dataclass, field
6+
from functools import partial
57
from logging import getLogger
8+
from typing import Callable, Optional
69

10+
from cql2 import Expr
11+
from starlette.datastructures import MutableHeaders, State
712
from starlette.requests import Request
813
from starlette.types import ASGIApp, Message, Receive, Scope, Send
914

@@ -17,7 +22,6 @@ class ApplyCql2FilterMiddleware:
1722
"""Middleware to apply the Cql2Filter to the request."""
1823

1924
app: ASGIApp
20-
2125
state_key: str = "cql2_filter"
2226

2327
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
@@ -27,34 +31,123 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
2731

2832
request = Request(scope)
2933

30-
if request.method == "GET":
31-
cql2_filter = getattr(request.state, self.state_key, None)
32-
if cql2_filter:
33-
scope["query_string"] = filters.append_qs_filter(
34-
request.url.query, cql2_filter
35-
)
34+
get_cql2_filter: Callable[[], Optional[Expr]] = partial(
35+
getattr, request.state, self.state_key, None
36+
)
37+
38+
# Handle POST, PUT, PATCH
39+
if request.method in ["POST", "PUT", "PATCH"]:
40+
return await self.app(
41+
scope,
42+
Cql2RequestBodyAugmentor(
43+
receive=receive,
44+
state=request.state,
45+
get_cql2_filter=get_cql2_filter,
46+
),
47+
send,
48+
)
49+
50+
cql2_filter = get_cql2_filter()
51+
if not cql2_filter:
3652
return await self.app(scope, receive, send)
3753

38-
elif request.method in ["POST", "PUT", "PATCH"]:
39-
40-
async def receive_and_apply_filter() -> Message:
41-
message = await receive()
42-
if message["type"] != "http.request":
43-
return message
44-
45-
cql2_filter = getattr(request.state, self.state_key, None)
46-
if cql2_filter:
47-
try:
48-
body = json.loads(message.get("body", b"{}"))
49-
except json.JSONDecodeError as e:
50-
logger.warning("Failed to parse request body as JSON")
51-
# TODO: Return a 400 error
52-
raise e
54+
if re.match(r"^/collections/([^/]+)/items/([^/]+)$", request.url.path):
55+
return await self.app(
56+
scope,
57+
receive,
58+
Cql2ResponseBodyValidator(cql2_filter=cql2_filter, send=send),
59+
)
5360

54-
new_body = filters.append_body_filter(body, cql2_filter)
55-
message["body"] = json.dumps(new_body).encode("utf-8")
56-
return message
61+
scope["query_string"] = filters.append_qs_filter(request.url.query, cql2_filter)
62+
return await self.app(scope, receive, send)
5763

58-
return await self.app(scope, receive_and_apply_filter, send)
5964

60-
return await self.app(scope, receive, send)
65+
@dataclass(frozen=True)
66+
class Cql2RequestBodyAugmentor:
67+
"""Handler to augment the request body with a CQL2 filter."""
68+
69+
receive: Receive
70+
state: State
71+
get_cql2_filter: Callable[[], Optional[Expr]]
72+
73+
async def __call__(self) -> Message:
74+
"""Process a request body and augment with a CQL2 filter if available."""
75+
message = await self.receive()
76+
if message["type"] != "http.request":
77+
return message
78+
79+
# NOTE: Can only get cql2 filter _after_ calling self.receive()
80+
cql2_filter = self.get_cql2_filter()
81+
if not cql2_filter:
82+
return message
83+
84+
try:
85+
body = json.loads(message.get("body", b"{}"))
86+
except json.JSONDecodeError as e:
87+
logger.warning("Failed to parse request body as JSON")
88+
# TODO: Return a 400 error
89+
raise e
90+
91+
new_body = filters.append_body_filter(body, cql2_filter)
92+
message["body"] = json.dumps(new_body).encode("utf-8")
93+
return message
94+
95+
96+
@dataclass
97+
class Cql2ResponseBodyValidator:
98+
"""Handler to validate response body with CQL2."""
99+
100+
send: Send
101+
cql2_filter: Expr
102+
initial_message: Optional[Message] = field(init=False)
103+
body: bytes = field(init=False, default_factory=bytes)
104+
105+
async def __call__(self, message: Message) -> None:
106+
"""Process a response message and apply filtering if needed."""
107+
if message["type"] == "http.response.start":
108+
self.initial_message = message
109+
return
110+
111+
if message["type"] == "http.response.body":
112+
assert self.initial_message, "Initial message not set"
113+
114+
self.body += message["body"]
115+
if message.get("more_body"):
116+
return
117+
118+
try:
119+
body_json = json.loads(self.body)
120+
except json.JSONDecodeError:
121+
logger.warning("Failed to parse response body as JSON")
122+
await self._send_error_response(502, "Not found")
123+
return
124+
125+
logger.debug(
126+
"Applying %s filter to %s", self.cql2_filter.to_text(), body_json
127+
)
128+
if self.cql2_filter.matches(body_json):
129+
await self.send(self.initial_message)
130+
return await self.send(
131+
{
132+
"type": "http.response.body",
133+
"body": json.dumps(body_json).encode("utf-8"),
134+
"more_body": False,
135+
}
136+
)
137+
return await self._send_error_response(404, "Not found")
138+
139+
async def _send_error_response(self, status: int, message: str) -> None:
140+
"""Send an error response with the given status and message."""
141+
assert self.initial_message, "Initial message not set"
142+
error_body = json.dumps({"message": message}).encode("utf-8")
143+
headers = MutableHeaders(scope=self.initial_message)
144+
headers["content-length"] = str(len(error_body))
145+
self.initial_message["status"] = status
146+
await self.send(self.initial_message)
147+
await self.send(
148+
{
149+
"type": "http.response.body",
150+
"body": error_body,
151+
"more_body": False,
152+
}
153+
)

src/stac_auth_proxy/middleware/BuildCql2FilterMiddleware.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
"""Middleware to build the Cql2Filter."""
22

33
import json
4+
import re
45
from dataclasses import dataclass
56
from typing import Callable, Optional
67

78
from cql2 import Expr
89
from starlette.requests import Request
910
from starlette.types import ASGIApp, Message, Receive, Scope, Send
1011

11-
from ..utils import filters, requests
12+
from ..utils import requests
1213

1314

1415
@dataclass(frozen=True)
@@ -78,11 +79,10 @@ async def receive_build_filter() -> Message:
7879
def _get_filter(self, path: str) -> Optional[Callable[..., Expr]]:
7980
"""Get the CQL2 filter builder for the given path."""
8081
endpoint_filters = [
81-
(filters.is_collection_endpoint, self.collections_filter),
82-
(filters.is_item_endpoint, self.items_filter),
83-
(filters.is_search_endpoint, self.items_filter),
82+
(r"^/collections(/[^/]+)?$", self.collections_filter),
83+
(r"^(/collections/([^/]+)/items(/[^/]+)?$|/search$)", self.items_filter),
8484
]
85-
for check, builder in endpoint_filters:
86-
if check(path):
85+
for expr, builder in endpoint_filters:
86+
if re.match(expr, path):
8787
return builder
8888
return None

src/stac_auth_proxy/utils/filters.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Utility functions."""
22

33
import json
4-
import re
54
from typing import Optional
65
from urllib.parse import parse_qs
76

@@ -32,23 +31,6 @@ def append_body_filter(
3231
}
3332

3433

35-
def is_collection_endpoint(path: str) -> bool:
36-
"""Check if the path is a collection endpoint."""
37-
# TODO: Expand this to cover all cases where a collection filter should be applied
38-
return path == "/collections"
39-
40-
41-
def is_item_endpoint(path: str) -> bool:
42-
"""Check if the path is an item endpoint."""
43-
# TODO: Expand this to cover all cases where an item filter should be applied
44-
return bool(re.compile(r"^(/collections/([^/]+)/items$|/search)").match(path))
45-
46-
47-
def is_search_endpoint(path: str) -> bool:
48-
"""Check if the path is a search endpoint."""
49-
return path == "/search"
50-
51-
5234
def dict_to_query_string(params: dict) -> str:
5335
"""
5436
Convert a dictionary to a query string. Dict values are converted to JSON strings,

0 commit comments

Comments
 (0)