Skip to content

Commit 04f6307

Browse files
committed
WIP.
1 parent 5b74b5c commit 04f6307

File tree

11 files changed

+251
-201
lines changed

11 files changed

+251
-201
lines changed

django_mongodb_backend/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def _isnull_operator_match(a, b):
123123
"lte": lambda a, b: {
124124
"$and": [{"$lte": [a, b]}, DatabaseWrapper._isnull_operator(a, False)]
125125
},
126-
"in": lambda a, b: {"$in": [a, b]},
126+
"in": lambda a, b: {"$in": (a, b)},
127127
"isnull": _isnull_operator,
128128
"range": lambda a, b: {
129129
"$and": [
@@ -165,7 +165,7 @@ def range_match(a, b):
165165
"lte": lambda a, b: {
166166
"$and": [{a: {"$lte": b}}, DatabaseWrapper._isnull_operator_match(a, False)]
167167
},
168-
"in": lambda a, b: {a: {"$in": list(b)}},
168+
"in": lambda a, b: {a: {"$in": tuple(b)}},
169169
"isnull": _isnull_operator_match,
170170
"range": range_match,
171171
"iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True),

django_mongodb_backend/expressions/builtins.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -203,9 +203,7 @@ def when(self, compiler, connection, **extra):
203203

204204
def value(self, compiler, connection, as_path=False): # noqa: ARG001
205205
value = self.value
206-
if as_path:
207-
return value
208-
if isinstance(value, (list, int)):
206+
if isinstance(value, (list, int)) and not as_path:
209207
# Wrap lists & numbers in $literal to prevent ambiguity when Value
210208
# appears in $project.
211209
return {"$literal": value}

django_mongodb_backend/features.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,6 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures):
9090
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
9191
# GenericRelation.value_to_string() assumes integer pk.
9292
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
93-
# icontains doesn't work on ArrayField:
94-
# Unsupported conversion from array to string in $convert
95-
"model_fields_.test_arrayfield.QueryingTests.test_icontains",
9693
# ArrayField's contained_by lookup crashes with Exists: "both operands "
9794
# of $setIsSubset must be arrays. Second argument is of type: null"
9895
# https://jira.mongodb.org/browse/SERVER-99186

django_mongodb_backend/fields/array.py

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,11 @@ def formfield(self, **kwargs):
231231

232232

233233
class Array(Func):
234-
def as_mql(self, compiler, connection):
235-
return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]
234+
def as_mql(self, compiler, connection, as_path=False):
235+
return [
236+
expr.as_mql(compiler, connection, as_path=as_path)
237+
for expr in self.get_source_expressions()
238+
]
236239

237240

238241
class ArrayRHSMixin:
@@ -253,6 +256,12 @@ class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
253256
lookup_name = "contains"
254257

255258
def as_mql(self, compiler, connection, as_path=False):
259+
if as_path and is_simple_column(self.lhs) and is_constant_value(self.rhs):
260+
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
261+
value = process_rhs(self, compiler, connection, as_path=as_path)
262+
if value is None:
263+
return False
264+
return {lhs_mql: {"$all": value}}
256265
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
257266
value = process_rhs(self, compiler, connection, as_path=False)
258267
return {
@@ -271,25 +280,16 @@ class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
271280
lookup_name = "contained_by"
272281

273282
def as_mql(self, compiler, connection, as_path=False):
274-
if as_path and is_simple_column(self.lhs) and is_constant_value(self.rhs):
275-
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
276-
value = process_rhs(self, compiler, connection, as_path=as_path)
277-
return {
278-
"$and": [
279-
# {lhs_mql: {"$ne": None}},
280-
{value: {"$ne": None}},
281-
{lhs_mql: {"$all": value}},
282-
]
283-
}
284-
285283
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
286284
value = process_rhs(self, compiler, connection, as_path=False)
287285
return {
288-
"$and": [
289-
{"$ne": [lhs_mql, None]},
290-
{"$ne": [value, None]},
291-
{"$setIsSubset": [lhs_mql, value]},
292-
]
286+
"$expr": {
287+
"$and": [
288+
{"$ne": [lhs_mql, None]},
289+
{"$ne": [value, None]},
290+
{"$setIsSubset": [lhs_mql, value]},
291+
]
292+
}
293293
}
294294

295295

@@ -337,11 +337,21 @@ def get_subquery_wrapping_pipeline(self, compiler, connection, field_name, expr)
337337
},
338338
]
339339

340-
def as_mql(self, compiler, connection):
341-
lhs_mql = process_lhs(self, compiler, connection)
342-
value = process_rhs(self, compiler, connection)
340+
def as_mql(self, compiler, connection, as_path=False):
341+
if as_path and is_simple_column(self.lhs) and is_constant_value(self.rhs):
342+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
343+
value = process_rhs(self, compiler, connection, as_path=True)
344+
return {lhs_mql: {"$in": value}}
345+
346+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
347+
value = process_rhs(self, compiler, connection, as_path=False)
343348
return {
344-
"$and": [{"$ne": [lhs_mql, None]}, {"$size": {"$setIntersection": [value, lhs_mql]}}]
349+
"$expr": {
350+
"$and": [
351+
{"$ne": [lhs_mql, None]},
352+
{"$size": {"$setIntersection": [value, lhs_mql]}},
353+
]
354+
}
345355
}
346356

347357

@@ -350,8 +360,8 @@ class ArrayLenTransform(Transform):
350360
lookup_name = "len"
351361
output_field = IntegerField()
352362

353-
def as_mql(self, compiler, connection):
354-
lhs_mql = process_lhs(self, compiler, connection)
363+
def as_mql(self, compiler, connection, as_path=False):
364+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
355365
return {"$cond": {"if": {"$isArray": lhs_mql}, "then": {"$size": lhs_mql}, "else": None}}
356366

357367

@@ -377,8 +387,10 @@ def __init__(self, index, base_field, *args, **kwargs):
377387
self.index = index
378388
self.base_field = base_field
379389

380-
def as_mql(self, compiler, connection):
381-
lhs_mql = process_lhs(self, compiler, connection)
390+
def as_mql(self, compiler, connection, as_path=False):
391+
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
392+
if as_path:
393+
return f"{lhs_mql}.{self.index}"
382394
return {"$arrayElemAt": [lhs_mql, self.index]}
383395

384396
@property
@@ -401,7 +413,7 @@ def __init__(self, start, end, *args, **kwargs):
401413
self.start = start
402414
self.end = end
403415

404-
def as_mql(self, compiler, connection):
416+
def as_mql(self, compiler, connection, as_path=False):
405417
lhs_mql = process_lhs(self, compiler, connection)
406418
return {"$slice": [lhs_mql, self.start, self.end]}
407419

django_mongodb_backend/fields/embedded_model_array.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _get_lookup(self, lookup_name):
7575
return lookup
7676

7777
class EmbeddedModelArrayFieldLookups(Lookup):
78-
def as_mql(self, compiler, connection):
78+
def as_mql(self, compiler, connection, as_path=False):
7979
raise ValueError(
8080
"Lookups aren't supported on EmbeddedModelArrayField. "
8181
"Try querying one of its embedded fields instead."
@@ -114,7 +114,7 @@ def get_lookup(self, name):
114114

115115

116116
class EmbeddedModelArrayFieldBuiltinLookup(Lookup):
117-
def process_rhs(self, compiler, connection):
117+
def process_rhs(self, compiler, connection, as_path=False):
118118
value = self.rhs
119119
if not self.get_db_prep_lookup_value_is_iterable:
120120
value = [value]
@@ -128,17 +128,17 @@ def process_rhs(self, compiler, connection):
128128
for v in value
129129
]
130130

131-
def as_mql(self, compiler, connection):
131+
def as_mql(self, compiler, connection, as_path=False):
132132
# Querying a subfield within the array elements (via nested
133133
# KeyTransform). Replicate MongoDB's implicit ANY-match by mapping over
134134
# the array and applying $in on the subfield.
135135
lhs_mql = process_lhs(self, compiler, connection)
136136
inner_lhs_mql = lhs_mql["$ifNull"][0]["$map"]["in"]
137137
values = process_rhs(self, compiler, connection)
138-
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators[self.lookup_name](
138+
lhs_mql["$ifNull"][0]["$map"]["in"] = connection.mongo_operators_expr[self.lookup_name](
139139
inner_lhs_mql, values
140140
)
141-
return {"$anyElementTrue": lhs_mql}
141+
return {"$expr": {"$anyElementTrue": lhs_mql}}
142142

143143

144144
@_EmbeddedModelArrayOutputField.register_lookup
@@ -275,7 +275,11 @@ def get_transform(self, name):
275275
f"{suggestion}"
276276
)
277277

278-
def as_mql(self, compiler, connection):
278+
def as_mql(self, compiler, connection, as_path=False):
279+
if as_path:
280+
inner_lhs_mql = self._lhs.as_mql(compiler, connection, as_path=True)
281+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
282+
return f"{inner_lhs_mql}.{lhs_mql}"
279283
inner_lhs_mql = self._lhs.as_mql(compiler, connection)
280284
lhs_mql = process_lhs(self, compiler, connection)
281285
return {

django_mongodb_backend/fields/polymorphic_embedded_model_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _get_lookup(self, lookup_name):
7070
return lookup
7171

7272
class EmbeddedModelArrayFieldLookups(Lookup):
73-
def as_mql(self, compiler, connection):
73+
def as_mql(self, compiler, connection, as_path=False):
7474
raise ValueError(
7575
"Lookups aren't supported on PolymorphicEmbeddedModelArrayField. "
7676
"Try querying one of its embedded fields instead."

django_mongodb_backend/functions.py

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
Upper,
4040
)
4141

42+
from .lookups import is_constant_value
4243
from .query_utils import process_lhs
4344

4445
MONGO_OPERATORS = {
@@ -84,18 +85,18 @@ def cast(self, compiler, connection, **extra): # noqa: ARG001
8485
return lhs_mql
8586

8687

87-
def concat(self, compiler, connection):
88-
return self.get_source_expressions()[0].as_mql(compiler, connection)
88+
def concat(self, compiler, connection, as_path=False):
89+
return self.get_source_expressions()[0].as_mql(compiler, connection, as_path=as_path)
8990

9091

91-
def concat_pair(self, compiler, connection):
92+
def concat_pair(self, compiler, connection, as_path=False): # noqa: ARG001
9293
# null on either side results in null for expression, wrap with coalesce.
9394
coalesced = self.coalesce()
94-
return super(ConcatPair, coalesced).as_mql(compiler, connection)
95+
return super(ConcatPair, coalesced).as_mql(compiler, connection, as_path=False)
9596

9697

97-
def cot(self, compiler, connection):
98-
lhs_mql = process_lhs(self, compiler, connection)
98+
def cot(self, compiler, connection, as_path=False): # noqa: ARG001
99+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
99100
return {"$divide": [1, {"$tan": lhs_mql}]}
100101

101102

@@ -117,8 +118,8 @@ def func(self, compiler, connection, **extra): # noqa: ARG001
117118
return {f"${operator}": lhs_mql}
118119

119120

120-
def left(self, compiler, connection):
121-
return self.get_substr().as_mql(compiler, connection)
121+
def left(self, compiler, connection, as_path=False): # noqa: ARG001
122+
return self.get_substr().as_mql(compiler, connection, as_path=False)
122123

123124

124125
def length(self, compiler, connection, as_path=False): # noqa: ARG001
@@ -127,28 +128,35 @@ def length(self, compiler, connection, as_path=False): # noqa: ARG001
127128
return {"$cond": {"if": {"$eq": [lhs_mql, None]}, "then": None, "else": {"$strLenCP": lhs_mql}}}
128129

129130

130-
def log(self, compiler, connection):
131+
def log(self, compiler, connection, as_path=False): # noqa: ARG001
131132
# This function is usually log(base, num) but on MongoDB it's log(num, base).
132133
clone = self.copy()
133134
clone.set_source_expressions(self.get_source_expressions()[::-1])
134135
return func(clone, compiler, connection)
135136

136137

137-
def now(self, compiler, connection): # noqa: ARG001
138+
def now(self, compiler, connection, as_path=False): # noqa: ARG001
138139
return "$$NOW"
139140

140141

141-
def null_if(self, compiler, connection):
142+
def null_if(self, compiler, connection, as_path=False): # noqa: ARG001
142143
"""Return None if expr1==expr2 else expr1."""
143-
expr1, expr2 = (expr.as_mql(compiler, connection) for expr in self.get_source_expressions())
144+
expr1, expr2 = (
145+
expr.as_mql(compiler, connection, as_path=False) for expr in self.get_source_expressions()
146+
)
144147
return {"$cond": {"if": {"$eq": [expr1, expr2]}, "then": None, "else": expr1}}
145148

146149

147150
def preserve_null(operator):
148151
# If the argument is null, the function should return null, not
149152
# $toLower/Upper's behavior of returning an empty string.
150-
def wrapped(self, compiler, connection):
151-
lhs_mql = process_lhs(self, compiler, connection)
153+
def wrapped(self, compiler, connection, as_path=False):
154+
if is_constant_value(self.lhs) and as_path:
155+
if self.lhs is None:
156+
return None
157+
lhs_mql = process_lhs(self, compiler, connection, as_path=True)
158+
return lhs_mql.upper()
159+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
152160
return {
153161
"$expr": {
154162
"$cond": {
@@ -162,24 +170,29 @@ def wrapped(self, compiler, connection):
162170
return wrapped
163171

164172

165-
def replace(self, compiler, connection):
166-
expression, text, replacement = process_lhs(self, compiler, connection)
173+
def replace(self, compiler, connection, as_path=False):
174+
expression, text, replacement = process_lhs(self, compiler, connection, as_path=as_path)
167175
return {"$replaceAll": {"input": expression, "find": text, "replacement": replacement}}
168176

169177

170-
def round_(self, compiler, connection):
178+
def round_(self, compiler, connection, as_path=False): # noqa: ARG001
171179
# Round needs its own function because it's a special case that inherits
172180
# from Transform but has two arguments.
173-
return {"$round": [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]}
181+
return {
182+
"$round": [
183+
expr.as_mql(compiler, connection, as_path=False)
184+
for expr in self.get_source_expressions()
185+
]
186+
}
174187

175188

176-
def str_index(self, compiler, connection):
189+
def str_index(self, compiler, connection, as_path=False): # noqa: ARG001
177190
lhs = process_lhs(self, compiler, connection)
178191
# StrIndex should be 0-indexed (not found) but it's -1-indexed on MongoDB.
179192
return {"$add": [{"$indexOfCP": lhs}, 1]}
180193

181194

182-
def substr(self, compiler, connection, **extra): # noqa: ARG001
195+
def substr(self, compiler, connection, as_path=False): # noqa: ARG001
183196
lhs = process_lhs(self, compiler, connection)
184197
# The starting index is zero-indexed on MongoDB rather than one-indexed.
185198
lhs[1] = {"$add": [lhs[1], -1]}
@@ -191,14 +204,14 @@ def substr(self, compiler, connection, **extra): # noqa: ARG001
191204

192205

193206
def trim(operator):
194-
def wrapped(self, compiler, connection):
207+
def wrapped(self, compiler, connection, as_path=False): # noqa: ARG001
195208
lhs = process_lhs(self, compiler, connection)
196209
return {f"${operator}": {"input": lhs}}
197210

198211
return wrapped
199212

200213

201-
def trunc(self, compiler, connection, **extra): # noqa: ARG001
214+
def trunc(self, compiler, connection, as_path=False): # noqa: ARG001
202215
lhs_mql = process_lhs(self, compiler, connection)
203216
lhs_mql = {"date": lhs_mql, "unit": self.kind, "startOfWeek": "mon"}
204217
if timezone := self.get_tzname():
@@ -257,11 +270,11 @@ def trunc_date(self, compiler, connection, **extra): # noqa: ARG001
257270
}
258271

259272

260-
def trunc_time(self, compiler, connection):
273+
def trunc_time(self, compiler, connection, as_path=False): # noqa: ARG001
261274
tzname = self.get_tzname()
262275
if tzname and tzname != "UTC":
263276
raise NotSupportedError(f"TruncTime with tzinfo ({tzname}) isn't supported on MongoDB.")
264-
lhs_mql = process_lhs(self, compiler, connection)
277+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
265278
return {
266279
"$dateFromString": {
267280
"dateString": {

0 commit comments

Comments
 (0)