1
1
"""Middleware to apply CQL2 filters."""
2
2
3
3
import json
4
- from dataclasses import dataclass
4
+ import re
5
+ from dataclasses import dataclass , field
6
+ from functools import partial
5
7
from logging import getLogger
8
+ from typing import Callable , Optional
6
9
10
+ from cql2 import Expr
11
+ from starlette .datastructures import MutableHeaders , State
7
12
from starlette .requests import Request
8
13
from starlette .types import ASGIApp , Message , Receive , Scope , Send
9
14
@@ -17,7 +22,6 @@ class ApplyCql2FilterMiddleware:
17
22
"""Middleware to apply the Cql2Filter to the request."""
18
23
19
24
app : ASGIApp
20
-
21
25
state_key : str = "cql2_filter"
22
26
23
27
async def __call__ (self , scope : Scope , receive : Receive , send : Send ) -> None :
@@ -27,34 +31,123 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
27
31
28
32
request = Request (scope )
29
33
30
- if request .method == "GET" :
31
- cql2_filter = getattr (request .state , self .state_key , None )
32
- if cql2_filter :
33
- scope ["query_string" ] = filters .append_qs_filter (
34
- request .url .query , cql2_filter
35
- )
34
+ get_cql2_filter : Callable [[], Optional [Expr ]] = partial (
35
+ getattr , request .state , self .state_key , None
36
+ )
37
+
38
+ # Handle POST, PUT, PATCH
39
+ if request .method in ["POST" , "PUT" , "PATCH" ]:
40
+ return await self .app (
41
+ scope ,
42
+ Cql2RequestBodyAugmentor (
43
+ receive = receive ,
44
+ state = request .state ,
45
+ get_cql2_filter = get_cql2_filter ,
46
+ ),
47
+ send ,
48
+ )
49
+
50
+ cql2_filter = get_cql2_filter ()
51
+ if not cql2_filter :
36
52
return await self .app (scope , receive , send )
37
53
38
- elif request .method in ["POST" , "PUT" , "PATCH" ]:
39
-
40
- async def receive_and_apply_filter () -> Message :
41
- message = await receive ()
42
- if message ["type" ] != "http.request" :
43
- return message
44
-
45
- cql2_filter = getattr (request .state , self .state_key , None )
46
- if cql2_filter :
47
- try :
48
- body = json .loads (message .get ("body" , b"{}" ))
49
- except json .JSONDecodeError as e :
50
- logger .warning ("Failed to parse request body as JSON" )
51
- # TODO: Return a 400 error
52
- raise e
54
+ if re .match (r"^/collections/([^/]+)/items/([^/]+)$" , request .url .path ):
55
+ return await self .app (
56
+ scope ,
57
+ receive ,
58
+ Cql2ResponseBodyValidator (cql2_filter = cql2_filter , send = send ),
59
+ )
53
60
54
- new_body = filters .append_body_filter (body , cql2_filter )
55
- message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
56
- return message
61
+ scope ["query_string" ] = filters .append_qs_filter (request .url .query , cql2_filter )
62
+ return await self .app (scope , receive , send )
57
63
58
- return await self .app (scope , receive_and_apply_filter , send )
59
64
60
- return await self .app (scope , receive , send )
65
+ @dataclass (frozen = True )
66
+ class Cql2RequestBodyAugmentor :
67
+ """Handler to augment the request body with a CQL2 filter."""
68
+
69
+ receive : Receive
70
+ state : State
71
+ get_cql2_filter : Callable [[], Optional [Expr ]]
72
+
73
+ async def __call__ (self ) -> Message :
74
+ """Process a request body and augment with a CQL2 filter if available."""
75
+ message = await self .receive ()
76
+ if message ["type" ] != "http.request" :
77
+ return message
78
+
79
+ # NOTE: Can only get cql2 filter _after_ calling self.receive()
80
+ cql2_filter = self .get_cql2_filter ()
81
+ if not cql2_filter :
82
+ return message
83
+
84
+ try :
85
+ body = json .loads (message .get ("body" , b"{}" ))
86
+ except json .JSONDecodeError as e :
87
+ logger .warning ("Failed to parse request body as JSON" )
88
+ # TODO: Return a 400 error
89
+ raise e
90
+
91
+ new_body = filters .append_body_filter (body , cql2_filter )
92
+ message ["body" ] = json .dumps (new_body ).encode ("utf-8" )
93
+ return message
94
+
95
+
96
+ @dataclass
97
+ class Cql2ResponseBodyValidator :
98
+ """Handler to validate response body with CQL2."""
99
+
100
+ send : Send
101
+ cql2_filter : Expr
102
+ initial_message : Optional [Message ] = field (init = False )
103
+ body : bytes = field (init = False , default_factory = bytes )
104
+
105
+ async def __call__ (self , message : Message ) -> None :
106
+ """Process a response message and apply filtering if needed."""
107
+ if message ["type" ] == "http.response.start" :
108
+ self .initial_message = message
109
+ return
110
+
111
+ if message ["type" ] == "http.response.body" :
112
+ assert self .initial_message , "Initial message not set"
113
+
114
+ self .body += message ["body" ]
115
+ if message .get ("more_body" ):
116
+ return
117
+
118
+ try :
119
+ body_json = json .loads (self .body )
120
+ except json .JSONDecodeError :
121
+ logger .warning ("Failed to parse response body as JSON" )
122
+ await self ._send_error_response (502 , "Not found" )
123
+ return
124
+
125
+ logger .debug (
126
+ "Applying %s filter to %s" , self .cql2_filter .to_text (), body_json
127
+ )
128
+ if self .cql2_filter .matches (body_json ):
129
+ await self .send (self .initial_message )
130
+ return await self .send (
131
+ {
132
+ "type" : "http.response.body" ,
133
+ "body" : json .dumps (body_json ).encode ("utf-8" ),
134
+ "more_body" : False ,
135
+ }
136
+ )
137
+ return await self ._send_error_response (404 , "Not found" )
138
+
139
+ async def _send_error_response (self , status : int , message : str ) -> None :
140
+ """Send an error response with the given status and message."""
141
+ assert self .initial_message , "Initial message not set"
142
+ error_body = json .dumps ({"message" : message }).encode ("utf-8" )
143
+ headers = MutableHeaders (scope = self .initial_message )
144
+ headers ["content-length" ] = str (len (error_body ))
145
+ self .initial_message ["status" ] = status
146
+ await self .send (self .initial_message )
147
+ await self .send (
148
+ {
149
+ "type" : "http.response.body" ,
150
+ "body" : error_body ,
151
+ "more_body" : False ,
152
+ }
153
+ )
0 commit comments