diff --git a/mypy/checker.py b/mypy/checker.py index f90fc4be41f4..fc56d446f6d2 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -1603,6 +1603,34 @@ def check_func_def( self.binder = old_binder + # Restore the original generic signature for functions with constrained TypeVars. + # When TypeVars have value restrictions (e.g., TypeVar("T", str, bytes)), we expand + # them into concrete variants for checking. However, for decorator processing, we need + # to preserve the original polymorphic signature so that decorators can properly infer + # the constrained TypeVar instead of collapsing to the first variant. + if original_typ.variables and len(expanded) > 1: + has_constrained_tvars = any( + isinstance(tv, TypeVarType) and tv.values for tv in original_typ.variables + ) + if has_constrained_tvars: + restored_typ = original_typ + # If @types.coroutine or @asyncio.coroutine was applied, we need to preserve + # the AwaitableGenerator wrapper while also restoring the polymorphic signature. + if defn.is_awaitable_coroutine: + t = original_typ.ret_type + c = defn.is_coroutine + ty = self.get_generator_yield_type(t, c) + tc = self.get_generator_receive_type(t, c) + if c: + tr = self.get_coroutine_return_type(t) + else: + tr = self.get_generator_return_type(t, c) + ret_type = self.named_generic_type( + "typing.AwaitableGenerator", [ty, tc, tr, t] + ) + restored_typ = original_typ.copy_modified(ret_type=ret_type) + defn.type = restored_typ + def require_correct_self_argument(self, func: Type, defn: FuncDef) -> bool: func = get_proper_type(func) if not isinstance(func, CallableType): @@ -5477,6 +5505,26 @@ def visit_decorator(self, e: Decorator) -> None: def visit_decorator_inner( self, e: Decorator, allow_empty: bool = False, skip_first_item: bool = False ) -> None: + def build_typevar_map(dec: CallableType, sig: CallableType) -> dict[TypeVarId, Type]: + """Build TypeVar substitution map for matching constrained TypeVars.""" + result: dict[TypeVarId, Type] = {} + used_sig_tvs: set[TypeVarId] = set() + for dec_tv in dec.variables: + if not (isinstance(dec_tv, TypeVarType) and dec_tv.values): + continue + dec_constraints = frozenset(get_proper_type(v) for v in dec_tv.values) + for sig_tv in sig.variables: + if sig_tv.id in used_sig_tvs: + continue + if not (isinstance(sig_tv, TypeVarType) and sig_tv.values): + continue + sig_constraints = frozenset(get_proper_type(v) for v in sig_tv.values) + if dec_constraints == sig_constraints: + result[sig_tv.id] = dec_tv + used_sig_tvs.add(sig_tv.id) + break + return result + if self.recurse_into_functions: with self.tscope.function_scope(e.func): self.check_func_item(e.func, name=e.func.name, allow_empty=allow_empty) @@ -5508,6 +5556,18 @@ def visit_decorator_inner( object_type = self.lookup_type(d.expr) fullname = self.expr_checker.method_fullname(object_type, d.name) self.check_for_untyped_decorator(e.func, dec, d) + + # Before checking compatibility, unify matching constrained TypeVars + # between the decorator and decorated function to avoid spurious errors + adjusted_sig = sig + dec_proper = get_proper_type(dec) + sig_proper = get_proper_type(sig) + if isinstance(dec_proper, CallableType) and isinstance(sig_proper, CallableType): + typevar_map = build_typevar_map(dec_proper, sig_proper) + if typevar_map: + adjusted_sig = expand_type(sig, typevar_map) + + temp = self.temp_node(adjusted_sig, context=d) sig, t2 = self.expr_checker.check_call( dec, [temp], [nodes.ARG_POS], e, callable_name=fullname, object_type=object_type ) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 9990caaeb7a1..11398a0d548a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1461,6 +1461,82 @@ def handle_decorator_overload_call( return None return Overloaded(result), Overloaded(inferred_args) + def is_overload_decorator_constrained_call( + self, callee: Overloaded, args: list[Expression] + ) -> CallableType | None: + """Check if an overloaded callable is applied to a constrained TypeVar argument. + + Returns the constrained CallableType if detected, None otherwise. + This handles the reverse case of is_generic_decorator_overload_call: + instead of a generic decorator applied to an overloaded function, + this detects an overloaded decorator applied to a function with constrained TypeVars. + """ + if len(args) != 1: + return None # Only handle single-argument case (decorator pattern) + + # Check if overload items look like decorator signatures (Callable -> Callable) + for item in callee.items: + if len(item.arg_types) != 1: + return None + if not isinstance(get_proper_type(item.arg_types[0]), CallableType): + return None + if not isinstance(get_proper_type(item.ret_type), CallableType): + return None + + # Get the argument type + with self.chk.local_type_map: + with self.msg.filter_errors(): + arg_type = get_proper_type(self.accept(args[0], type_context=None)) + + if isinstance(arg_type, CallableType) and self.has_constrained_typevars(arg_type): + return arg_type + return None + + def has_constrained_typevars(self, t: Type) -> bool: + """Check if a type is a CallableType with constrained TypeVars.""" + t = get_proper_type(t) + if not isinstance(t, CallableType): + return False + return any(isinstance(tv, TypeVarType) and tv.values for tv in t.variables) + + def handle_overload_decorator_constrained_call( + self, callee: Overloaded, constrained_arg: CallableType, ctx: Context + ) -> tuple[Type, Type] | None: + """Handle application of overloaded decorator to constrained TypeVar function. + + Expands the constrained function to its variants, matches each against + the overload, and combines results into an Overloaded type. + """ + # Expand the constrained callable into its variants + variants = mypy.checker.expand_callable_variants(constrained_arg) + if len(variants) <= 1: + return None # No expansion needed, fall back to normal handling + + results: list[CallableType] = [] + inferred_args: list[CallableType] = [] + + for variant in variants: + arg = TempNode(typ=variant) + # Try matching this variant against the overloaded decorator + with self.msg.filter_errors() as err: + result, _ = self.check_call(callee, [arg], [ARG_POS], ctx) + if err.has_new_errors(): + # This variant doesn't match any overload + continue + p_result = get_proper_type(result) + if isinstance(p_result, CallableType): + results.append(p_result) + inferred_args.append(variant) + + if not results: + return None + + if len(results) == 1: + return results[0], inferred_args[0] + + # Combine results into an Overloaded type + return Overloaded(results), Overloaded(inferred_args) + def check_call_expr_with_callee_type( self, callee_type: Type, @@ -1596,6 +1672,14 @@ def check_call( object_type, ) elif isinstance(callee, Overloaded): + # Check for constrained TypeVar argument (reverse of decorator overload call) + constrained_arg = self.is_overload_decorator_constrained_call(callee, args) + if constrained_arg is not None: + result = self.handle_overload_decorator_constrained_call( + callee, constrained_arg, context + ) + if result is not None: + return result return self.check_overload_call( callee, args, arg_kinds, arg_names, callable_name, object_type, context ) @@ -2082,6 +2166,25 @@ def infer_function_type_arguments( Return a derived callable type that has the arguments applied. """ + + def find_matching_typevar( + arg_tv: TypeVarType, callee_vars: Sequence[TypeVarLikeType], used: set[int] + ) -> int | None: + """Find index of matching constrained TypeVar in callee, or None.""" + arg_constraints = frozenset(get_proper_type(v) for v in arg_tv.values) + for idx, cv in enumerate(callee_vars): + if idx in used: + continue + if not isinstance(cv, TypeVarType) or not cv.values: + continue + if frozenset(get_proper_type(v) for v in cv.values) == arg_constraints: + return idx + return None + + # Save the original callee_type variables before inference modifies them + original_callee_variables = list(callee_type.variables) + constrained_typevar_map: dict[int, TypeVarType] = {} + if self.chk.in_checked_function(): # Disable type errors during type inference. There may be errors # due to partial available context information at this time, but @@ -2113,6 +2216,27 @@ def infer_function_type_arguments( strict=self.chk.in_checked_function(), ) + # Preserve constrained TypeVars when passing functions with constrained TypeVars + # to decorators. When the argument has constrained TypeVars, save the decorator's + # corresponding TypeVar (from original_callee_variables) to apply at the end. + used_callee_var_indices: set[int] = set() + for arg_type in arg_types: + arg_type_proper = get_proper_type(arg_type) + if not isinstance(arg_type_proper, CallableType): + continue + for arg_typevar in arg_type_proper.variables: + if not (isinstance(arg_typevar, TypeVarType) and arg_typevar.values): + continue + # Found a constrained TypeVar in the argument - find matching decorator TypeVar + match_idx = find_matching_typevar( + arg_typevar, original_callee_variables, used_callee_var_indices + ) + if match_idx is not None: + callee_var = original_callee_variables[match_idx] + assert isinstance(callee_var, TypeVarType) + constrained_typevar_map[match_idx] = callee_var + used_callee_var_indices.add(match_idx) + if 2 in arg_pass_nums: # Second pass of type inference. (callee_type, inferred_args) = self.infer_function_type_arguments_pass2( @@ -2203,8 +2327,91 @@ def infer_function_type_arguments( # In dynamically typed functions use implicit 'Any' types for # type variables. inferred_args = [AnyType(TypeOfAny.unannotated)] * len(callee_type.variables) + + # Apply constrained TypeVars saved earlier + for callee_var_idx, constrained_tv in constrained_typevar_map.items(): + if callee_var_idx >= len(inferred_args): + continue + inferred = inferred_args[callee_var_idx] + # Replace if inference failed + if inferred is None or isinstance(get_proper_type(inferred), UninhabitedType): + inferred_args[callee_var_idx] = constrained_tv + continue + # Replace if collapsed to a single constraint + if not constrained_tv.values: + continue + inferred_proper = get_proper_type(inferred) + if any(inferred_proper == get_proper_type(c) for c in constrained_tv.values): + inferred_args[callee_var_idx] = constrained_tv + + # If we have constrained TypeVars, try polymorphic application + has_constrained = any(isinstance(a, TypeVarType) and a.values for a in inferred_args if a) + if constrained_typevar_map and has_constrained: + free_tvars = [a for a in inferred_args if isinstance(a, TypeVarType)] + poly_callee = self.apply_generic_arguments(callee_type, inferred_args, context) + poly_result = applytype.apply_poly(poly_callee, free_tvars) + if poly_result is not None: + poly_result = self._finalize_poly_result( + poly_result, arg_types, free_tvars, constrained_typevar_map + ) + freeze_all_type_vars(poly_result) + return poly_result + return self.apply_inferred_arguments(callee_type, inferred_args, context) + def _finalize_poly_result( + self, + poly_result: CallableType, + arg_types: list[Type], + free_tvars: list[TypeVarType], + constrained_typevar_map: dict[int, TypeVarType], + ) -> CallableType: + """Finalize polymorphic result by preserving arg names and substituting TypeVars.""" + ret_type = poly_result.ret_type + if not isinstance(ret_type, CallableType): + return poly_result + + first_arg_type = get_proper_type(arg_types[0]) if arg_types else None + if not isinstance(first_arg_type, CallableType): + return poly_result + + # Copy arg_names if arity matches + if first_arg_type.arg_names and len(ret_type.arg_types) == len(first_arg_type.arg_types): + ret_type = ret_type.copy_modified(arg_names=first_arg_type.arg_names) + + # Substitute decorated function's TypeVars with decorator's TypeVars + if first_arg_type.variables: + variables_map = self._build_typevar_substitution_map( + first_arg_type, free_tvars, constrained_typevar_map + ) + if variables_map: + ret_type = expand_type(ret_type, variables_map) + + return poly_result.copy_modified(ret_type=ret_type) + + def _build_typevar_substitution_map( + self, + first_arg_type: CallableType, + free_tvars: list[TypeVarType], + constrained_typevar_map: dict[int, TypeVarType], + ) -> dict[TypeVarId, Type]: + """Build substitution map from decorated function's TypeVars to decorator's TypeVars.""" + variables_map: dict[TypeVarId, Type] = {} + for callee_idx in constrained_typevar_map: + if callee_idx >= len(free_tvars): + continue + decorator_typevar = free_tvars[callee_idx] + if not decorator_typevar.values: + continue + decorator_constraints = frozenset(get_proper_type(v) for v in decorator_typevar.values) + for func_tv in first_arg_type.variables: + if not (isinstance(func_tv, TypeVarType) and func_tv.values): + continue + func_constraints = frozenset(get_proper_type(v) for v in func_tv.values) + if func_constraints == decorator_constraints: + variables_map[func_tv.id] = decorator_typevar + return variables_map + def infer_function_type_arguments_pass2( self, callee_type: CallableType, diff --git a/test-data/unit/check-async-await.test b/test-data/unit/check-async-await.test index e4bd4568f8c8..6b762a8437b9 100644 --- a/test-data/unit/check-async-await.test +++ b/test-data/unit/check-async-await.test @@ -1065,6 +1065,26 @@ reveal_type(coro) # N: Revealed type is "def () -> typing.AwaitableGenerator[bu [builtins fixtures/async_await.pyi] [typing fixtures/typing-async.pyi] +[case testCoroutineDecoratorWithConstrainedTypeVar] +# Test that @types.coroutine preserves both the AwaitableGenerator wrapper +# and the constrained TypeVar polymorphic signature. +from typing import Generator, TypeVar +from types import coroutine + +T = TypeVar("T", int, str) + +@coroutine +def process(data: T) -> Generator[T, None, T]: + yield data + return data + +# Expected: AwaitableGenerator with polymorphic T preserved +reveal_type(process) # N: Revealed type is "def [T in (builtins.int, builtins.str)] (data: T`-1) -> typing.AwaitableGenerator[T`-1, None, T`-1, typing.Generator[T`-1, None, T`-1]]" +reveal_type(process(1)) # N: Revealed type is "typing.AwaitableGenerator[builtins.int, None, builtins.int, typing.Generator[builtins.int, None, builtins.int]]" +reveal_type(process("hello")) # N: Revealed type is "typing.AwaitableGenerator[builtins.str, None, builtins.str, typing.Generator[builtins.str, None, builtins.str]]" +[builtins fixtures/async_await.pyi] +[typing fixtures/typing-async.pyi] + [case asyncIteratorInProtocol] from typing import AsyncIterator, Protocol diff --git a/test-data/unit/check-typevar-values.test b/test-data/unit/check-typevar-values.test index 0e8b3f74f2a7..6c135c4b3231 100644 --- a/test-data/unit/check-typevar-values.test +++ b/test-data/unit/check-typevar-values.test @@ -744,3 +744,97 @@ def fn(w: W) -> W: reveal_type(w) # N: Revealed type is "builtins.int" return w [builtins fixtures/isinstance.pyi] + +[case testDecoratorWithConstrainedTypeVarPreservesVariants] +# Test that decorators with constrained TypeVars preserve the polymorphic signature +# when applied using @decorator syntax. The decorator should maintain all constraint +# values rather than collapsing to a single concrete type. +from typing import TypeVar, Callable + +T = TypeVar("T", str, bytes) + +def decorator(fn: Callable[[T], T]) -> Callable[[T], T]: + def wrapped(data: T) -> T: + return fn(data) + return wrapped + +@decorator +def process(data: T) -> T: + return data + +# Expected: Polymorphic signature with both str and bytes constraints preserved +reveal_type(process) # N: Revealed type is "def [T in (builtins.str, builtins.bytes)] (data: T`-1) -> T`-1" +reveal_type(process("hello")) # N: Revealed type is "builtins.str" +reveal_type(process(b"world")) # N: Revealed type is "builtins.bytes" +process(123) # E: Value of type variable "T" of "process" cannot be "int" +[builtins fixtures/tuple.pyi] + +[case testDecoratorWithConstrainedTypeVarMultipleParams] +# Test with multiple constrained TypeVars - both should be preserved +from typing import TypeVar, Callable + +T = TypeVar("T", int, str) +U = TypeVar("U", float, bool) + +def decorator(fn: Callable[[T, U], T]) -> Callable[[T, U], T]: + return fn + +@decorator +def combine(x: T, y: U) -> T: + return x + +# Expected: Polymorphic signature with all constraints preserved +reveal_type(combine) # N: Revealed type is "def [T in (builtins.int, builtins.str), U in (builtins.float, builtins.bool)] (x: T`-1, y: U`-2) -> T`-1" +reveal_type(combine(1, 1.5)) # N: Revealed type is "builtins.int" +reveal_type(combine("a", True)) # N: Revealed type is "builtins.str" +combine(1, "bad") # E: Value of type variable "U" of "combine" cannot be "str" +[builtins fixtures/tuple.pyi] + +[case testTwoTypeVarsWithSameConstraintsNotMixedUp] +# Test that two TypeVars with identical constraints don't get mixed up. +# T and U both have (str, bytes) constraints but should remain distinct. +from typing import TypeVar, Callable + +T = TypeVar("T", str, bytes) +U = TypeVar("U", str, bytes) + +def decorator(fn: Callable[[T, U], tuple[T, U]]) -> Callable[[T, U], tuple[T, U]]: + return fn + +@decorator +def combine(x: T, y: U) -> tuple[T, U]: + return (x, y) + +# Expected: Both T and U are preserved as separate TypeVars +reveal_type(combine) # N: Revealed type is "def [T in (builtins.str, builtins.bytes), U in (builtins.str, builtins.bytes)] (x: T`-1, y: U`-2) -> tuple[T`-1, U`-2]" +# Both arguments can be different types +reveal_type(combine("hello", b"world")) # N: Revealed type is "tuple[builtins.str, builtins.bytes]" +reveal_type(combine(b"hello", "world")) # N: Revealed type is "tuple[builtins.bytes, builtins.str]" +# Same type for both is also valid +reveal_type(combine("a", "b")) # N: Revealed type is "tuple[builtins.str, builtins.str]" +[builtins fixtures/tuple.pyi] + +[case testOverloadedDecoratorWithConstrainedTypeVar] +# Test that overloaded decorators work correctly with constrained TypeVar functions. +# The decorator is overloaded for str and bytes, and the function has a constrained TypeVar. +from typing import TypeVar, Callable, overload + +T = TypeVar("T", str, bytes) + +@overload +def decorator(fn: Callable[[str], str]) -> Callable[[str], str]: ... +@overload +def decorator(fn: Callable[[bytes], bytes]) -> Callable[[bytes], bytes]: ... +def decorator(fn: Callable[[T], T]) -> Callable[[T], T]: + return fn + +@decorator +def process(data: T) -> T: + return data + +# Expected: Overloaded type with both str and bytes variants +reveal_type(process) # N: Revealed type is "Overload(def (builtins.str) -> builtins.str, def (builtins.bytes) -> builtins.bytes)" +reveal_type(process("hello")) # N: Revealed type is "builtins.str" +reveal_type(process(b"world")) # N: Revealed type is "builtins.bytes" +[builtins fixtures/tuple.pyi] +