diff --git a/django_mongodb_backend/expressions/builtins.py b/django_mongodb_backend/expressions/builtins.py index 0bc939350..9096a20b9 100644 --- a/django_mongodb_backend/expressions/builtins.py +++ b/django_mongodb_backend/expressions/builtins.py @@ -23,6 +23,7 @@ Value, When, ) +from django.db.models.indexes import IndexExpression from django.db.models.sql import Query from ..query_utils import process_lhs @@ -95,12 +96,25 @@ def expression_wrapper(self, compiler, connection): return self.expression.as_mql(compiler, connection) +def index_expression(self, compiler, connection): + result = [] + for expr in self.get_source_expressions(): + if expr is None: + continue + for sub_expr in expr.get_source_expressions(): + try: + result.append(sub_expr.as_mql(compiler, connection, as_path=True)) + except FullResultSet: + result.append(Value(True).as_mql(compiler, connection)) + return result + + def negated_expression(self, compiler, connection): return {"$not": expression_wrapper(self, compiler, connection)} -def order_by(self, compiler, connection): - return self.expression.as_mql(compiler, connection) +def order_by(self, compiler, connection, **extra_args): + return self.expression.as_mql(compiler, connection, **extra_args) def query(self, compiler, connection, get_wrapping_pipeline=None): @@ -217,6 +231,7 @@ def register_expressions(): Exists.as_mql = exists ExpressionList.as_mql = process_lhs ExpressionWrapper.as_mql = expression_wrapper + IndexExpression.as_mql = index_expression NegatedExpression.as_mql = negated_expression OrderBy.as_mql = order_by Query.as_mql = query diff --git a/django_mongodb_backend/features.py b/django_mongodb_backend/features.py index 7e29c5003..47c93915d 100644 --- a/django_mongodb_backend/features.py +++ b/django_mongodb_backend/features.py @@ -102,6 +102,9 @@ class DatabaseFeatures(GISFeatures, BaseDatabaseFeatures): "model_fields.test_jsonfield.TestSaveLoad.test_bulk_update_custom_get_prep_value", # To debug: https://github.com/mongodb/django-mongodb-backend/issues/362 "constraints.tests.UniqueConstraintTests.test_validate_case_when", + # Simple expression index are supported + "schema.tests.SchemaTests.test_func_unique_constraint_unsupported", + "schema.tests.SchemaTests.test_func_index_unsupported", } # $bitAnd, #bitOr, and $bitXor are new in MongoDB 6.3. _django_test_expected_failures_bitwise = { diff --git a/django_mongodb_backend/fields/embedded_model.py b/django_mongodb_backend/fields/embedded_model.py index 951632363..2f010dcff 100644 --- a/django_mongodb_backend/fields/embedded_model.py +++ b/django_mongodb_backend/fields/embedded_model.py @@ -203,6 +203,9 @@ def as_mql(self, compiler, connection, as_path=False): def output_field(self): return self.ref_field + def db_type(self, connection): + return self.output_field.db_type(connection) + class KeyTransformFactory: def __init__(self, key_name, ref_field): diff --git a/django_mongodb_backend/indexes.py b/django_mongodb_backend/indexes.py index 99c1ef5f3..065b17b11 100644 --- a/django_mongodb_backend/indexes.py +++ b/django_mongodb_backend/indexes.py @@ -4,6 +4,7 @@ from django.core.checks import Error, Warning from django.db import NotSupportedError from django.db.models import FloatField, Index, IntegerField +from django.db.models.expressions import OrderBy from django.db.models.lookups import BuiltinLookup from django.db.models.sql.query import Query from django.db.models.sql.where import AND, XOR, WhereNode @@ -46,10 +47,30 @@ def builtin_lookup_idx(self, compiler, connection): def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False, column_prefix=""): """Return a pymongo IndexModel for this Django Index.""" + filter_expression = defaultdict(dict) + expressions_fields = [] if self.contains_expressions: - return None + query = Query(model=model, alias_cols=False) + compiler = query.get_compiler(connection=schema_editor.connection) + for expression in self.expressions: + query = Query(model=model, alias_cols=False) + field_ = expression.resolve_expression(query) + column = field_.as_mql(compiler, schema_editor.connection, as_path=True) + db_type = ( + field_.expression.db_type(schema_editor.connection) + if isinstance(field_, OrderBy) + else field_.output_field.db_type(schema_editor.connection) + ) + if unique: + filter_expression[column].update({"$type": db_type}) + order = ( + DESCENDING + if isinstance(expression, OrderBy) and expression.descending + else ASCENDING + ) + expressions_fields.append((column, order)) + kwargs = {} - filter_expression = defaultdict(dict) if self.condition: filter_expression.update(self._get_condition_mql(model, schema_editor)) if unique: @@ -80,7 +101,7 @@ def get_pymongo_index_model(self, model, schema_editor, field=None, unique=False for field_name, order in self.fields_orders ] ) - return IndexModel(index_orders, name=self.name, **kwargs) + return IndexModel(expressions_fields + index_orders, name=self.name, **kwargs) def where_node_idx(self, compiler, connection): diff --git a/django_mongodb_backend/schema.py b/django_mongodb_backend/schema.py index 9472db962..27b2bd0c4 100644 --- a/django_mongodb_backend/schema.py +++ b/django_mongodb_backend/schema.py @@ -1,5 +1,6 @@ from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.models import Index, UniqueConstraint +from django.db.models.expressions import F, OrderBy from pymongo.operations import SearchIndexModel from django_mongodb_backend.indexes import SearchIndex @@ -345,6 +346,36 @@ def _remove_field_index(self, model, field, column_prefix=""): ) collection.drop_index(index_names[0]) + def _check_supported_expressions(self, expressions): + for expression in expressions: + expression = expression.expression if isinstance(expression, OrderBy) else expression + if not isinstance(expression, F): + return False + return True + + def _unique_supported( + self, + condition=None, + deferrable=None, + include=None, + expressions=None, + nulls_distinct=None, + ): + return ( + (not condition or self.connection.features.supports_partial_indexes) + and (not deferrable or self.connection.features.supports_deferrable_unique_constraints) + and (not include or self.connection.features.supports_covering_indexes) + and ( + not expressions + or self._check_supported_expressions(expressions) + or self.connection.features.supports_expression_indexes + ) + and ( + nulls_distinct is None + or self.connection.features.supports_nulls_distinct_unique_constraints + ) + ) + @ignore_embedded_models def add_constraint(self, model, constraint, field=None, column_prefix="", parent_model=None): if isinstance(constraint, UniqueConstraint) and self._unique_supported( @@ -355,6 +386,7 @@ def add_constraint(self, model, constraint, field=None, column_prefix="", parent nulls_distinct=constraint.nulls_distinct, ): idx = Index( + *constraint.expressions, fields=constraint.fields, name=constraint.name, condition=constraint.condition, @@ -385,6 +417,7 @@ def remove_constraint(self, model, constraint): nulls_distinct=constraint.nulls_distinct, ): idx = Index( + *constraint.expressions, fields=constraint.fields, name=constraint.name, condition=constraint.condition, diff --git a/tests/schema_/test_embedded_model.py b/tests/schema_/test_embedded_model.py index c6c926031..5497b5bef 100644 --- a/tests/schema_/test_embedded_model.py +++ b/tests/schema_/test_embedded_model.py @@ -1,6 +1,7 @@ import itertools from django.db import connection, models +from django.db.models.expressions import F from django.test import TransactionTestCase, skipUnlessDBFeature from django.test.utils import isolate_apps @@ -519,6 +520,167 @@ class Meta: self.assertTableNotExists(Author) +class EmbeddedModelsTopLevelIndexTest(TestMixin, TransactionTestCase): + @isolate_apps("schema_") + def test_unique_together(self): + """Meta.unique_together defined at the top-level for embedded fields.""" + + class Address(EmbeddedModel): + unique_together_one = models.CharField(max_length=10) + unique_together_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Author(EmbeddedModel): + address = EmbeddedModelField(Address) + unique_together_three = models.CharField(max_length=10) + unique_together_four = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint( + F("author__unique_together_three").asc(), + F("author__unique_together_four").desc(), + name="unique_together_34", + ), + ( + models.UniqueConstraint( + F("author__address__unique_together_one"), + F("author__address__unique_together_two").asc(), + name="unique_together_12", + ) + ), + ] + + with connection.schema_editor() as editor: + editor.create_model(Book) + self.assertTableExists(Book) + # Embedded uniques are created from top-level definition. + self.assertEqual( + self.get_constraints_for_columns( + Book, ["author.unique_together_three", "author.unique_together_four"] + ), + ["unique_together_34"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_together_one", "author.address.unique_together_two"], + ), + ["unique_together_12"], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_add_remove_field_indexes(self): + """AddField/RemoveField + EmbeddedModelField + Meta.indexes at top-level.""" + + class Address(EmbeddedModel): + indexed_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Author(EmbeddedModel): + address = EmbeddedModelField(Address) + indexed_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + indexes = [ + models.Index(F("author__indexed_two").asc(), name="indexed_two"), + models.Index(F("author__address__indexed_one").asc(), name="indexed_one"), + ] + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded indexes are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.indexed_two"]), + ["indexed_two"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.indexed_one"], + ), + ["indexed_one"], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + @isolate_apps("schema_") + def test_add_remove_field_constraints(self): + """AddField/RemoveField + EmbeddedModelField + Meta.constraints at top-level.""" + + class Address(EmbeddedModel): + unique_constraint_one = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Author(EmbeddedModel): + address = EmbeddedModelField(Address) + unique_constraint_two = models.CharField(max_length=10) + + class Meta: + app_label = "schema_" + + class Book(models.Model): + author = EmbeddedModelField(Author) + + class Meta: + app_label = "schema_" + constraints = [ + models.UniqueConstraint(F("author__unique_constraint_two"), name="unique_two"), + models.UniqueConstraint( + F("author__address__unique_constraint_one"), name="unique_one" + ), + ] + + new_field = EmbeddedModelField(Author) + new_field.set_attributes_from_name("author") + + with connection.schema_editor() as editor: + # Create the table and add the field. + editor.create_model(Book) + editor.add_field(Book, new_field) + # Embedded constraints are created. + self.assertEqual( + self.get_constraints_for_columns(Book, ["author.unique_constraint_two"]), + ["unique_two"], + ) + self.assertEqual( + self.get_constraints_for_columns( + Book, + ["author.address.unique_constraint_one"], + ), + ["unique_one"], + ) + editor.delete_model(Book) + self.assertTableNotExists(Book) + + class EmbeddedModelsIgnoredTests(TestMixin, TransactionTestCase): def test_embedded_not_created(self): """create_model() and delete_model() ignore EmbeddedModel."""