39
39
Upper ,
40
40
)
41
41
42
+ from .lookups import is_constant_value
42
43
from .query_utils import process_lhs
43
44
44
45
MONGO_OPERATORS = {
@@ -84,18 +85,18 @@ def cast(self, compiler, connection, **extra): # noqa: ARG001
84
85
return lhs_mql
85
86
86
87
87
- def concat (self , compiler , connection ):
88
- return self .get_source_expressions ()[0 ].as_mql (compiler , connection )
88
+ def concat (self , compiler , connection , as_path = False ):
89
+ return self .get_source_expressions ()[0 ].as_mql (compiler , connection , as_path = as_path )
89
90
90
91
91
- def concat_pair (self , compiler , connection ):
92
+ def concat_pair (self , compiler , connection , as_path = False ): # noqa: ARG001
92
93
# null on either side results in null for expression, wrap with coalesce.
93
94
coalesced = self .coalesce ()
94
- return super (ConcatPair , coalesced ).as_mql (compiler , connection )
95
+ return super (ConcatPair , coalesced ).as_mql (compiler , connection , as_path = False )
95
96
96
97
97
- def cot (self , compiler , connection ):
98
- lhs_mql = process_lhs (self , compiler , connection )
98
+ def cot (self , compiler , connection , as_path = False ): # noqa: ARG001
99
+ lhs_mql = process_lhs (self , compiler , connection , as_path = False )
99
100
return {"$divide" : [1 , {"$tan" : lhs_mql }]}
100
101
101
102
@@ -117,8 +118,8 @@ def func(self, compiler, connection, **extra): # noqa: ARG001
117
118
return {f"${ operator } " : lhs_mql }
118
119
119
120
120
- def left (self , compiler , connection ):
121
- return self .get_substr ().as_mql (compiler , connection )
121
+ def left (self , compiler , connection , as_path = False ): # noqa: ARG001
122
+ return self .get_substr ().as_mql (compiler , connection , as_path = False )
122
123
123
124
124
125
def length (self , compiler , connection , as_path = False ): # noqa: ARG001
@@ -127,28 +128,35 @@ def length(self, compiler, connection, as_path=False): # noqa: ARG001
127
128
return {"$cond" : {"if" : {"$eq" : [lhs_mql , None ]}, "then" : None , "else" : {"$strLenCP" : lhs_mql }}}
128
129
129
130
130
- def log (self , compiler , connection ):
131
+ def log (self , compiler , connection , as_path = False ): # noqa: ARG001
131
132
# This function is usually log(base, num) but on MongoDB it's log(num, base).
132
133
clone = self .copy ()
133
134
clone .set_source_expressions (self .get_source_expressions ()[::- 1 ])
134
135
return func (clone , compiler , connection )
135
136
136
137
137
- def now (self , compiler , connection ): # noqa: ARG001
138
+ def now (self , compiler , connection , as_path = False ): # noqa: ARG001
138
139
return "$$NOW"
139
140
140
141
141
- def null_if (self , compiler , connection ):
142
+ def null_if (self , compiler , connection , as_path = False ): # noqa: ARG001
142
143
"""Return None if expr1==expr2 else expr1."""
143
- expr1 , expr2 = (expr .as_mql (compiler , connection ) for expr in self .get_source_expressions ())
144
+ expr1 , expr2 = (
145
+ expr .as_mql (compiler , connection , as_path = False ) for expr in self .get_source_expressions ()
146
+ )
144
147
return {"$cond" : {"if" : {"$eq" : [expr1 , expr2 ]}, "then" : None , "else" : expr1 }}
145
148
146
149
147
150
def preserve_null (operator ):
148
151
# If the argument is null, the function should return null, not
149
152
# $toLower/Upper's behavior of returning an empty string.
150
- def wrapped (self , compiler , connection ):
151
- lhs_mql = process_lhs (self , compiler , connection )
153
+ def wrapped (self , compiler , connection , as_path = False ):
154
+ if is_constant_value (self .lhs ) and as_path :
155
+ if self .lhs is None :
156
+ return None
157
+ lhs_mql = process_lhs (self , compiler , connection , as_path = True )
158
+ return lhs_mql .upper ()
159
+ lhs_mql = process_lhs (self , compiler , connection , as_path = False )
152
160
return {
153
161
"$expr" : {
154
162
"$cond" : {
@@ -162,24 +170,29 @@ def wrapped(self, compiler, connection):
162
170
return wrapped
163
171
164
172
165
- def replace (self , compiler , connection ):
166
- expression , text , replacement = process_lhs (self , compiler , connection )
173
+ def replace (self , compiler , connection , as_path = False ):
174
+ expression , text , replacement = process_lhs (self , compiler , connection , as_path = as_path )
167
175
return {"$replaceAll" : {"input" : expression , "find" : text , "replacement" : replacement }}
168
176
169
177
170
- def round_ (self , compiler , connection ):
178
+ def round_ (self , compiler , connection , as_path = False ): # noqa: ARG001
171
179
# Round needs its own function because it's a special case that inherits
172
180
# from Transform but has two arguments.
173
- return {"$round" : [expr .as_mql (compiler , connection ) for expr in self .get_source_expressions ()]}
181
+ return {
182
+ "$round" : [
183
+ expr .as_mql (compiler , connection , as_path = False )
184
+ for expr in self .get_source_expressions ()
185
+ ]
186
+ }
174
187
175
188
176
- def str_index (self , compiler , connection ):
189
+ def str_index (self , compiler , connection , as_path = False ): # noqa: ARG001
177
190
lhs = process_lhs (self , compiler , connection )
178
191
# StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB.
179
192
return {"$add" : [{"$indexOfCP" : lhs }, 1 ]}
180
193
181
194
182
- def substr (self , compiler , connection , ** extra ): # noqa: ARG001
195
+ def substr (self , compiler , connection , as_path = False ): # noqa: ARG001
183
196
lhs = process_lhs (self , compiler , connection )
184
197
# The starting index is zero-indexed on MongoDB rather than one-indexed.
185
198
lhs [1 ] = {"$add" : [lhs [1 ], - 1 ]}
@@ -191,14 +204,14 @@ def substr(self, compiler, connection, **extra): # noqa: ARG001
191
204
192
205
193
206
def trim (operator ):
194
- def wrapped (self , compiler , connection ):
207
+ def wrapped (self , compiler , connection , as_path = False ): # noqa: ARG001
195
208
lhs = process_lhs (self , compiler , connection )
196
209
return {f"${ operator } " : {"input" : lhs }}
197
210
198
211
return wrapped
199
212
200
213
201
- def trunc (self , compiler , connection , ** extra ): # noqa: ARG001
214
+ def trunc (self , compiler , connection , as_path = False ): # noqa: ARG001
202
215
lhs_mql = process_lhs (self , compiler , connection )
203
216
lhs_mql = {"date" : lhs_mql , "unit" : self .kind , "startOfWeek" : "mon" }
204
217
if timezone := self .get_tzname ():
@@ -257,11 +270,11 @@ def trunc_date(self, compiler, connection, **extra): # noqa: ARG001
257
270
}
258
271
259
272
260
- def trunc_time (self , compiler , connection ):
273
+ def trunc_time (self , compiler , connection , as_path = False ): # noqa: ARG001
261
274
tzname = self .get_tzname ()
262
275
if tzname and tzname != "UTC" :
263
276
raise NotSupportedError (f"TruncTime with tzinfo ({ tzname } ) isn't supported on MongoDB." )
264
- lhs_mql = process_lhs (self , compiler , connection )
277
+ lhs_mql = process_lhs (self , compiler , connection , as_path = False )
265
278
return {
266
279
"$dateFromString" : {
267
280
"dateString" : {
0 commit comments