diff --git a/mypy_django_plugin/lib/helpers.py b/mypy_django_plugin/lib/helpers.py index 07103162a..a89838f90 100644 --- a/mypy_django_plugin/lib/helpers.py +++ b/mypy_django_plugin/lib/helpers.py @@ -549,6 +549,9 @@ def resolve_string_attribute_value(attr_expr: Expression, django_context: "Djang if isinstance(attr_expr, StrExpr): return attr_expr.value + if isinstance(attr_expr, NameExpr) and isinstance(attr_expr.node, Var) and attr_expr.node.type is not None: + return get_literal_str_type(attr_expr.node.type) + # support extracting from settings, in general case it's unresolvable yet if isinstance(attr_expr, MemberExpr): member_name = attr_expr.name diff --git a/mypy_django_plugin/main.py b/mypy_django_plugin/main.py index 3d9ca48f1..b65efdfeb 100644 --- a/mypy_django_plugin/main.py +++ b/mypy_django_plugin/main.py @@ -169,6 +169,12 @@ def manager_and_queryset_method_hooks(self) -> dict[str, Callable[[MethodContext querysets.extract_prefetch_related_annotations, django_context=self.django_context ), "select_related": partial(querysets.validate_select_related, django_context=self.django_context), + "bulk_update": partial( + querysets.validate_bulk_update, django_context=self.django_context, method="bulk_update" + ), + "abulk_update": partial( + querysets.validate_bulk_update, django_context=self.django_context, method="abulk_update" + ), } def get_method_hook(self, fullname: str) -> Callable[[MethodContext], MypyType] | None: diff --git a/mypy_django_plugin/transformers/querysets.py b/mypy_django_plugin/transformers/querysets.py index 49771f242..826631b63 100644 --- a/mypy_django_plugin/transformers/querysets.py +++ b/mypy_django_plugin/transformers/querysets.py @@ -1,6 +1,7 @@ from collections.abc import Sequence +from typing import Literal -from django.core.exceptions import FieldError +from django.core.exceptions import FieldDoesNotExist, FieldError from django.db.models.base import Model from django.db.models.fields.related import RelatedField from django.db.models.fields.related_descriptors import ( @@ -12,7 +13,7 @@ from django.db.models.fields.reverse_related import ForeignObjectRel from mypy.checker import TypeChecker from mypy.errorcodes import NO_REDEF -from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression +from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, ARG_STAR, CallExpr, Expression, ListExpr, SetExpr, TupleExpr from mypy.plugin import FunctionContext, MethodContext from mypy.types import AnyType, Instance, LiteralType, ProperType, TupleType, TypedDictType, TypeOfAny, get_proper_type from mypy.types import Type as MypyType @@ -695,3 +696,58 @@ def validate_select_related(ctx: MethodContext, django_context: DjangoContext) - _validate_select_related_lookup(ctx, django_context, django_model.cls, lookup_value) return ctx.default_return_type + + +def _validate_bulk_update_field( + ctx: MethodContext, model_cls: type[Model], field_name: str, method: Literal["bulk_update", "abulk_update"] +) -> bool: + opts = model_cls._meta + try: + field = opts.get_field(field_name) + except FieldDoesNotExist as e: + ctx.api.fail(str(e), ctx.context) + return False + + if not field.concrete or field.many_to_many: + ctx.api.fail(f'"{method}()" can only be used with concrete fields. Got "{field_name}"', ctx.context) + return False + + all_pk_fields = set(opts.pk_fields) + for parent in opts.all_parents: + all_pk_fields.update(parent._meta.pk_fields) + + if field in all_pk_fields: + ctx.api.fail(f'"{method}()" cannot be used with primary key fields. Got "{field_name}"', ctx.context) + return False + + return True + + +def validate_bulk_update( + ctx: MethodContext, django_context: DjangoContext, method: Literal["bulk_update", "abulk_update"] +) -> MypyType: + """ + Type check the `fields` argument passed to `QuerySet.bulk_update(...)`. + + Extracted and adapted from `django.db.models.query.QuerySet.bulk_update` + Mirrors tests from `django/tests/queries/test_bulk_update.py` + """ + if not ( + isinstance(ctx.type, Instance) + and (django_model := helpers.get_model_info_from_qs_ctx(ctx, django_context)) is not None + and len(ctx.args) >= 2 + and ctx.args[1] + and isinstance((fields_args := ctx.args[1][0]), (ListExpr, TupleExpr, SetExpr)) + ): + return ctx.default_return_type + + if len(fields_args.items) == 0: + ctx.api.fail(f'Field names must be given to "{method}()"', ctx.context) + return ctx.default_return_type + + for field_arg in fields_args.items: + field_name = helpers.resolve_string_attribute_value(field_arg, django_context) + if field_name is not None: + _validate_bulk_update_field(ctx, django_model.cls, field_name, method) + + return ctx.default_return_type diff --git a/tests/typecheck/managers/querysets/test_bulk_update.yml b/tests/typecheck/managers/querysets/test_bulk_update.yml new file mode 100644 index 000000000..17b74bfdc --- /dev/null +++ b/tests/typecheck/managers/querysets/test_bulk_update.yml @@ -0,0 +1,224 @@ +- case: bulk_update_valid_fields + installed_apps: + - myapp + main: | + from myapp.models import Article, Author, Category + + # Valid single field updates + articles = Article.objects.all() + Article.objects.bulk_update(articles, ["title"]) + Article.objects.bulk_update(articles, ["content"]) + Article.objects.bulk_update(articles, ["published"]) + + # Valid multiple field updates + Article.objects.bulk_update(articles, ("title", "content")) + Article.objects.bulk_update(articles, {"title", "content", "published"}) + + # Valid foreign key field updates (by field name and attname) + Article.objects.bulk_update(articles, ["author"]) + Article.objects.bulk_update(articles, ["author_id"]) + Article.objects.bulk_update(articles, ["category"]) + Article.objects.bulk_update(articles, ["category_id"]) + + # Valid updates on different models + authors = Author.objects.all() + Author.objects.bulk_update(authors, ["name"]) + Author.objects.bulk_update(authors, ["email"]) + + categories = Category.objects.all() + Category.objects.bulk_update(categories, ["name"]) + Category.objects.bulk_update(categories, ["parent"]) + Category.objects.bulk_update(categories, ["parent_id"]) + + # Variables containing field names + field_name = "title" + Article.objects.bulk_update(articles, [field_name]) + + # Dynamic field lists + def get_fields() -> list[str]: + return ["title"] + + Article.objects.bulk_update(articles, get_fields()) + + async def test_async_bulk_update() -> None: + # Valid single field updates + articles = Article.objects.all() + await Article.objects.abulk_update(articles, {"published"}) + + # Valid multiple field updates + await Article.objects.abulk_update(articles, ("title", "content", "published")) + + # Valid foreign key field updates (by field name and attname) + await Article.objects.abulk_update(articles, ["category_id"]) + + # Valid updates on different models + authors = Author.objects.all() + await Author.objects.abulk_update(authors, ["email"]) + + categories = Category.objects.all() + await Category.objects.abulk_update(categories, ["parent_id"]) + + # Variables containing field names + field_name = "title" + await Article.objects.abulk_update(articles, [field_name]) + + # Dynamic field lists + def get_fields() -> list[str]: + return ["title"] + + await Article.objects.abulk_update(articles, get_fields()) + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + + class Category(models.Model): + name = models.CharField(max_length=100) + parent = models.ForeignKey('self', on_delete=models.CASCADE, null=True, blank=True) + + class Author(models.Model): + name = models.CharField(max_length=100) + email = models.EmailField() + + class Tag(models.Model): + name = models.CharField(max_length=50) + + class Article(models.Model): + title = models.CharField(max_length=200) + content = models.TextField() + published = models.BooleanField(default=False) + author = models.ForeignKey(Author, on_delete=models.CASCADE) + category = models.ForeignKey(Category, on_delete=models.CASCADE) + tags = models.ManyToManyField(Tag) + + + +- case: bulk_update_invalid_fields + installed_apps: + - myapp + main: | + from myapp.models import Article, Author, Category + from typing import Literal + + articles = Article.objects.all() + + # Empty fields list + Article.objects.bulk_update() # E: Missing positional arguments "objs", "fields" in call to "bulk_update" of "Manager" [call-arg] + Article.objects.bulk_update(articles, []) # E: Field names must be given to "bulk_update()" [misc] + + # Invalid field names (Django's FieldError) + Article.objects.bulk_update(articles, ["nonexistent"]) # E: Article has no field named 'nonexistent' [misc] + Article.objects.bulk_update(articles, ["invalid_field"]) # E: Article has no field named 'invalid_field' [misc] + + # Cannot update primary key fields + Article.objects.bulk_update(articles, ["id"]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc] + + # Mixed valid and invalid fields + Article.objects.bulk_update(articles, {"title", "nonexistent"}) # E: Article has no field named 'nonexistent' [misc] + Article.objects.bulk_update(articles, ("id", "title")) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc] + + # Whitespace-only field names + Article.objects.bulk_update(articles, [""]) # E: Article has no field named '' [misc] + Article.objects.bulk_update(articles, [" "]) # E: Article has no field named ' ' [misc] + + # ManyToMany is not a concrete updatable field + Article.objects.bulk_update(articles, {"tags"}) # E: "bulk_update()" can only be used with concrete fields. Got "tags" [misc] + + # Nested field lookups are not supported + Article.objects.bulk_update(articles, ["author__name"]) # E: Article has no field named 'author__name' [misc] + Article.objects.bulk_update(articles, ["category__parent__name"]) # E: Article has no field named 'category__parent__name' [misc] + + # Multiple invalid fields + Article.objects.bulk_update(articles, ["nonexistent1", "nonexistent2"]) # E: Article has no field named 'nonexistent1' [misc] # E: Article has no field named 'nonexistent2' [misc] + + # Primary key with valid fields + Article.objects.bulk_update(articles, ["title", "id", "content"]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc] + + # Literal type variables are validated + invalid_field: Literal["nonexistent"] = "nonexistent" + Article.objects.bulk_update(articles, [invalid_field]) # E: Article has no field named 'nonexistent' [misc] + + pk_field: Literal["id"] = "id" + Article.objects.bulk_update(articles, [pk_field]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc] + + # Test with different models + authors = Author.objects.all() + Author.objects.bulk_update(authors, ["id"]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc] + Author.objects.bulk_update(authors, ["invalid"]) # E: Author has no field named 'invalid' [misc] + + categories = Category.objects.all() + Category.objects.bulk_update(categories, ["id"]) # E: "bulk_update()" cannot be used with primary key fields. Got "id" [misc] + Category.objects.bulk_update(categories, ["invalid"]) # E: Category has no field named 'invalid' [misc] + + # Async version + async def test_async_bulk_update_invalid() -> None: + articles = Article.objects.all() + + # Empty fields list + await Article.objects.abulk_update() # E: Missing positional arguments "objs", "fields" in call to "abulk_update" of "Manager" [call-arg] + await Article.objects.abulk_update(articles, []) # E: Field names must be given to "abulk_update()" [misc] + + # Invalid field names (Django's FieldError) + await Article.objects.abulk_update(articles, ["invalid_field"]) # E: Article has no field named 'invalid_field' [misc] + + # Cannot update primary key fields + await Article.objects.abulk_update(articles, ["id"]) # E: "abulk_update()" cannot be used with primary key fields. Got "id" [misc] + + # Mixed valid and invalid fields + await Article.objects.abulk_update(articles, ["id", "title"]) # E: "abulk_update()" cannot be used with primary key fields. Got "id" [misc] + + # Whitespace-only field names + await Article.objects.abulk_update(articles, [" "]) # E: Article has no field named ' ' [misc] + + # ManyToMany is not a concrete updatable field + await Article.objects.abulk_update(articles, ["tags"]) # E: "abulk_update()" can only be used with concrete fields. Got "tags" [misc] + + # Nested field lookups are not supported + await Article.objects.abulk_update(articles, ["author__name"]) # E: Article has no field named 'author__name' [misc] + await Article.objects.abulk_update(articles, ["category__parent__name"]) # E: Article has no field named 'category__parent__name' [misc] + + # Multiple invalid fields + await Article.objects.abulk_update(articles, ("nonexistent1", "nonexistent2")) # E: Article has no field named 'nonexistent1' [misc] # E: Article has no field named 'nonexistent2' [misc] + + # Primary key with valid fields + await Article.objects.abulk_update(articles, ["title", "id", "content"]) # E: "abulk_update()" cannot be used with primary key fields. Got "id" [misc] + + # Literal type variables are validated + invalid_field: Literal["nonexistent"] = "nonexistent" + await Article.objects.abulk_update(articles, {invalid_field}) # E: Article has no field named 'nonexistent' [misc] + + pk_field: Literal["id"] = "id" + await Article.objects.abulk_update(articles, [pk_field]) # E: "abulk_update()" cannot be used with primary key fields. Got "id" [misc] + + # Test with different models + authors = Author.objects.all() + await Author.objects.abulk_update(authors, ["invalid"]) # E: Author has no field named 'invalid' [misc] + + categories = Category.objects.all() + await Category.objects.abulk_update(categories, ["invalid"]) # E: Category has no field named 'invalid' [misc] + + files: + - path: myapp/__init__.py + - path: myapp/models.py + content: | + from django.db import models + + class Category(models.Model): + name = models.CharField(max_length=100) + parent = models.ForeignKey('self', on_delete=models.CASCADE, null=True, blank=True) + + class Author(models.Model): + name = models.CharField(max_length=100) + email = models.EmailField() + + class Article(models.Model): + title = models.CharField(max_length=200) + content = models.TextField() + published = models.BooleanField(default=False) + author = models.ForeignKey(Author, on_delete=models.CASCADE) + category = models.ForeignKey(Category, on_delete=models.CASCADE) + tags = models.ManyToManyField('myapp.Tag') + + class Tag(models.Model): + name = models.CharField(max_length=50)