Skip to content

Commit 6b20c06

Browse files
committed
WIP.
1 parent b81c822 commit 6b20c06

File tree

11 files changed

+170
-89
lines changed

11 files changed

+170
-89
lines changed

django_mongodb_backend/base.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,17 @@ def _isnull_operator_match(a, b):
142142
"iregex": lambda a, b: regex_expr(a, b, insensitive=True),
143143
}
144144

145+
def range_match(a, b):
146+
## TODO: MAKE A TEST TO TEST WHEN BOTH ENDS ARE NONE. WHAT SHALL I RETURN?
147+
conditions = []
148+
if b[0] is not None:
149+
conditions.append({a: {"$gte": b[0]}})
150+
if b[1] is not None:
151+
conditions.append({a: {"$lte": b[1]}})
152+
if not conditions:
153+
return {"$literal": True}
154+
return {"$and": conditions}
155+
145156
mongo_operators_match = {
146157
"exact": lambda a, b: {a: b},
147158
"gt": lambda a, b: {a: {"$gt": b}},
@@ -156,12 +167,7 @@ def _isnull_operator_match(a, b):
156167
},
157168
"in": lambda a, b: {a: {"$in": list(b)}},
158169
"isnull": _isnull_operator_match,
159-
"range": lambda a, b: {
160-
"$and": [
161-
{"$or": [DatabaseWrapper._isnull_operator_match(b[0], True), {a: {"$gte": b[0]}}]},
162-
{"$or": [DatabaseWrapper._isnull_operator_match(b[1], True), {a: {"$lte": b[1]}}]},
163-
]
164-
},
170+
"range": range_match,
165171
"iexact": lambda a, b: regex_match(a, f"^{b}$", insensitive=True),
166172
"startswith": lambda a, b: regex_match(a, f"^{b}"),
167173
"istartswith": lambda a, b: regex_match(a, f"^{b}", insensitive=True),

django_mongodb_backend/compiler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -707,16 +707,16 @@ def get_project_fields(self, columns=None, ordering=None, force_expression=False
707707
# For brevity/simplicity, project {"field_name": 1}
708708
# instead of {"field_name": "$field_name"}.
709709
if isinstance(expr, Col) and name == expr.target.column and not force_expression
710-
else expr.as_mql(self, self.connection, as_expr=force_expression)
710+
else expr.as_mql(self, self.connection, as_path=False)
711711
)
712712
except EmptyResultSet:
713713
empty_result_set_value = getattr(expr, "empty_result_set_value", NotImplemented)
714714
value = (
715715
False if empty_result_set_value is NotImplemented else empty_result_set_value
716716
)
717-
fields[collection][name] = Value(value).as_mql(self, self.connection)
717+
fields[collection][name] = Value(value).as_mql(self, self.connection, as_path=False)
718718
except FullResultSet:
719-
fields[collection][name] = Value(True).as_mql(self, self.connection)
719+
fields[collection][name] = Value(True).as_mql(self, self.connection, as_path=False)
720720
# Annotations (stored in None) and the main collection's fields
721721
# should appear in the top-level of the fields dict.
722722
fields.update(fields.pop(None, {}))

django_mongodb_backend/expressions/builtins.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def case(self, compiler, connection, **extra): # noqa: ARG001
3434
for case in self.cases:
3535
case_mql = {}
3636
try:
37-
case_mql["case"] = case.as_mql(compiler, connection, as_expr=True)
37+
case_mql["case"] = case.as_mql(compiler, connection, as_path=False)
3838
except EmptyResultSet:
3939
continue
4040
except FullResultSet:
@@ -54,7 +54,7 @@ def case(self, compiler, connection, **extra): # noqa: ARG001
5454
}
5555

5656

57-
def col(self, compiler, connection, as_path=False, as_expr=None): # noqa: ARG001
57+
def col(self, compiler, connection, as_path=False): # noqa: ARG001
5858
# If the column is part of a subquery and belongs to one of the parent
5959
# queries, it will be stored for reference using $let in a $lookup stage.
6060
# If the query is built with `alias_cols=False`, treat the column as
@@ -72,16 +72,16 @@ def col(self, compiler, connection, as_path=False, as_expr=None): # noqa: ARG00
7272
# Add the column's collection's alias for columns in joined collections.
7373
has_alias = self.alias and self.alias != compiler.collection_name
7474
prefix = f"{self.alias}." if has_alias else ""
75-
if not as_path or as_expr:
75+
if not as_path:
7676
prefix = f"${prefix}"
7777
return f"{prefix}{self.target.column}"
7878

7979

80-
def col_pairs(self, compiler, connection):
80+
def col_pairs(self, compiler, connection, as_path=False):
8181
cols = self.get_cols()
8282
if len(cols) > 1:
8383
raise NotSupportedError("ColPairs is not supported.")
84-
return cols[0].as_mql(compiler, connection)
84+
return cols[0].as_mql(compiler, connection, as_path=as_path)
8585

8686

8787
def combined_expression(self, compiler, connection, **extra):
@@ -96,15 +96,15 @@ def expression_wrapper(self, compiler, connection, **extra):
9696
return self.expression.as_mql(compiler, connection, **extra)
9797

9898

99-
def negated_expression(self, compiler, connection):
100-
return {"$not": expression_wrapper(self, compiler, connection)}
99+
def negated_expression(self, compiler, connection, **extra):
100+
return {"$not": expression_wrapper(self, compiler, connection, **extra)}
101101

102102

103103
def order_by(self, compiler, connection):
104104
return self.expression.as_mql(compiler, connection)
105105

106106

107-
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False, as_expr=None):
107+
def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False):
108108
subquery_compiler = self.get_compiler(connection=connection)
109109
subquery_compiler.pre_sql_setup(with_col_aliases=False)
110110
field_name, expr = subquery_compiler.columns[0]
@@ -146,7 +146,7 @@ def query(self, compiler, connection, get_wrapping_pipeline=None, as_path=False,
146146
# Erase project_fields since the required value is projected above.
147147
subquery.project_fields = None
148148
compiler.subqueries.append(subquery)
149-
if as_path and not as_expr:
149+
if as_path:
150150
return f"{table_output}.{field_name}"
151151
return f"${table_output}.{field_name}"
152152

@@ -201,8 +201,10 @@ def when(self, compiler, connection, **extra):
201201
return self.condition.as_mql(compiler, connection, **extra)
202202

203203

204-
def value(self, compiler, connection, **extra): # noqa: ARG001
204+
def value(self, compiler, connection, as_path=False): # noqa: ARG001
205205
value = self.value
206+
if as_path:
207+
return value
206208
if isinstance(value, (list, int)):
207209
# Wrap lists & numbers in $literal to prevent ambiguity when Value
208210
# appears in $project.

django_mongodb_backend/expressions/search.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -933,10 +933,12 @@ def __str__(self):
933933
def __repr__(self):
934934
return f"SearchText({self.lhs}, {self.rhs})"
935935

936-
def as_mql(self, compiler, connection):
937-
lhs_mql = process_lhs(self, compiler, connection)
938-
value = process_rhs(self, compiler, connection)
939-
return {"$gte": [lhs_mql, value]}
936+
def as_mql(self, compiler, connection, as_path=False):
937+
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
938+
value = process_rhs(self, compiler, connection, as_path=as_path)
939+
if as_path:
940+
return {lhs_mql: {"$gte": value}}
941+
return {"$expr": {"$gte": [lhs_mql, value]}}
940942

941943

942944
CharField.register_lookup(SearchTextLookup)

django_mongodb_backend/fields/array.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from django.utils.translation import gettext_lazy as _
88

99
from ..forms import SimpleArrayField
10+
from ..lookups import is_constant_value, is_simple_column
1011
from ..query_utils import process_lhs, process_rhs
1112
from ..utils import prefix_validation_error
1213
from ..validators import ArrayMaxLengthValidator, LengthValidator
@@ -251,25 +252,38 @@ def __init__(self, lhs, rhs):
251252
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
252253
lookup_name = "contains"
253254

254-
def as_mql(self, compiler, connection):
255-
lhs_mql = process_lhs(self, compiler, connection)
256-
value = process_rhs(self, compiler, connection)
255+
def as_mql(self, compiler, connection, as_path=False):
256+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
257+
value = process_rhs(self, compiler, connection, as_path=False)
257258
return {
258-
"$and": [
259-
{"$ne": [lhs_mql, None]},
260-
{"$ne": [value, None]},
261-
{"$setIsSubset": [value, lhs_mql]},
262-
]
259+
"$expr": {
260+
"$and": [
261+
{"$ne": [lhs_mql, None]},
262+
{"$ne": [value, None]},
263+
{"$setIsSubset": [value, lhs_mql]},
264+
]
265+
}
263266
}
264267

265268

266269
@ArrayField.register_lookup
267270
class ArrayContainedBy(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
268271
lookup_name = "contained_by"
269272

270-
def as_mql(self, compiler, connection):
271-
lhs_mql = process_lhs(self, compiler, connection)
272-
value = process_rhs(self, compiler, connection)
273+
def as_mql(self, compiler, connection, as_path=False):
274+
if as_path and is_simple_column(self.lhs) and is_constant_value(self.rhs):
275+
lhs_mql = process_lhs(self, compiler, connection, as_path=as_path)
276+
value = process_rhs(self, compiler, connection, as_path=as_path)
277+
return {
278+
"$and": [
279+
# {lhs_mql: {"$ne": None}},
280+
{value: {"$ne": None}},
281+
{lhs_mql: {"$all": value}},
282+
]
283+
}
284+
285+
lhs_mql = process_lhs(self, compiler, connection, as_path=False)
286+
value = process_rhs(self, compiler, connection, as_path=False)
273287
return {
274288
"$and": [
275289
{"$ne": [lhs_mql, None]},

0 commit comments

Comments
 (0)