Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 60 additions & 0 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1603,6 +1603,34 @@ def check_func_def(

self.binder = old_binder

# Restore the original generic signature for functions with constrained TypeVars.
# When TypeVars have value restrictions (e.g., TypeVar("T", str, bytes)), we expand
# them into concrete variants for checking. However, for decorator processing, we need
# to preserve the original polymorphic signature so that decorators can properly infer
# the constrained TypeVar instead of collapsing to the first variant.
if original_typ.variables and len(expanded) > 1:
has_constrained_tvars = any(
isinstance(tv, TypeVarType) and tv.values for tv in original_typ.variables
)
if has_constrained_tvars:
restored_typ = original_typ
# If @types.coroutine or @asyncio.coroutine was applied, we need to preserve
# the AwaitableGenerator wrapper while also restoring the polymorphic signature.
if defn.is_awaitable_coroutine:
t = original_typ.ret_type
c = defn.is_coroutine
ty = self.get_generator_yield_type(t, c)
tc = self.get_generator_receive_type(t, c)
if c:
tr = self.get_coroutine_return_type(t)
else:
tr = self.get_generator_return_type(t, c)
ret_type = self.named_generic_type(
"typing.AwaitableGenerator", [ty, tc, tr, t]
)
restored_typ = original_typ.copy_modified(ret_type=ret_type)
defn.type = restored_typ

def require_correct_self_argument(self, func: Type, defn: FuncDef) -> bool:
func = get_proper_type(func)
if not isinstance(func, CallableType):
Expand Down Expand Up @@ -5477,6 +5505,26 @@ def visit_decorator(self, e: Decorator) -> None:
def visit_decorator_inner(
self, e: Decorator, allow_empty: bool = False, skip_first_item: bool = False
) -> None:
def build_typevar_map(dec: CallableType, sig: CallableType) -> dict[TypeVarId, Type]:
"""Build TypeVar substitution map for matching constrained TypeVars."""
result: dict[TypeVarId, Type] = {}
used_sig_tvs: set[TypeVarId] = set()
for dec_tv in dec.variables:
if not (isinstance(dec_tv, TypeVarType) and dec_tv.values):
continue
dec_constraints = frozenset(get_proper_type(v) for v in dec_tv.values)
for sig_tv in sig.variables:
if sig_tv.id in used_sig_tvs:
continue
if not (isinstance(sig_tv, TypeVarType) and sig_tv.values):
continue
sig_constraints = frozenset(get_proper_type(v) for v in sig_tv.values)
if dec_constraints == sig_constraints:
result[sig_tv.id] = dec_tv
used_sig_tvs.add(sig_tv.id)
break
return result

if self.recurse_into_functions:
with self.tscope.function_scope(e.func):
self.check_func_item(e.func, name=e.func.name, allow_empty=allow_empty)
Expand Down Expand Up @@ -5508,6 +5556,18 @@ def visit_decorator_inner(
object_type = self.lookup_type(d.expr)
fullname = self.expr_checker.method_fullname(object_type, d.name)
self.check_for_untyped_decorator(e.func, dec, d)

# Before checking compatibility, unify matching constrained TypeVars
# between the decorator and decorated function to avoid spurious errors
adjusted_sig = sig
dec_proper = get_proper_type(dec)
sig_proper = get_proper_type(sig)
if isinstance(dec_proper, CallableType) and isinstance(sig_proper, CallableType):
typevar_map = build_typevar_map(dec_proper, sig_proper)
if typevar_map:
adjusted_sig = expand_type(sig, typevar_map)

temp = self.temp_node(adjusted_sig, context=d)
sig, t2 = self.expr_checker.check_call(
dec, [temp], [nodes.ARG_POS], e, callable_name=fullname, object_type=object_type
)
Expand Down
207 changes: 207 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,6 +1461,82 @@ def handle_decorator_overload_call(
return None
return Overloaded(result), Overloaded(inferred_args)

def is_overload_decorator_constrained_call(
self, callee: Overloaded, args: list[Expression]
) -> CallableType | None:
"""Check if an overloaded callable is applied to a constrained TypeVar argument.

Returns the constrained CallableType if detected, None otherwise.
This handles the reverse case of is_generic_decorator_overload_call:
instead of a generic decorator applied to an overloaded function,
this detects an overloaded decorator applied to a function with constrained TypeVars.
"""
if len(args) != 1:
return None # Only handle single-argument case (decorator pattern)

# Check if overload items look like decorator signatures (Callable -> Callable)
for item in callee.items:
if len(item.arg_types) != 1:
return None
if not isinstance(get_proper_type(item.arg_types[0]), CallableType):
return None
if not isinstance(get_proper_type(item.ret_type), CallableType):
return None

# Get the argument type
with self.chk.local_type_map:
with self.msg.filter_errors():
arg_type = get_proper_type(self.accept(args[0], type_context=None))

if isinstance(arg_type, CallableType) and self.has_constrained_typevars(arg_type):
return arg_type
return None

def has_constrained_typevars(self, t: Type) -> bool:
"""Check if a type is a CallableType with constrained TypeVars."""
t = get_proper_type(t)
if not isinstance(t, CallableType):
return False
return any(isinstance(tv, TypeVarType) and tv.values for tv in t.variables)

def handle_overload_decorator_constrained_call(
self, callee: Overloaded, constrained_arg: CallableType, ctx: Context
) -> tuple[Type, Type] | None:
"""Handle application of overloaded decorator to constrained TypeVar function.

Expands the constrained function to its variants, matches each against
the overload, and combines results into an Overloaded type.
"""
# Expand the constrained callable into its variants
variants = mypy.checker.expand_callable_variants(constrained_arg)
if len(variants) <= 1:
return None # No expansion needed, fall back to normal handling

results: list[CallableType] = []
inferred_args: list[CallableType] = []

for variant in variants:
arg = TempNode(typ=variant)
# Try matching this variant against the overloaded decorator
with self.msg.filter_errors() as err:
result, _ = self.check_call(callee, [arg], [ARG_POS], ctx)
if err.has_new_errors():
# This variant doesn't match any overload
continue
p_result = get_proper_type(result)
if isinstance(p_result, CallableType):
results.append(p_result)
inferred_args.append(variant)

if not results:
return None

if len(results) == 1:
return results[0], inferred_args[0]

# Combine results into an Overloaded type
return Overloaded(results), Overloaded(inferred_args)

def check_call_expr_with_callee_type(
self,
callee_type: Type,
Expand Down Expand Up @@ -1596,6 +1672,14 @@ def check_call(
object_type,
)
elif isinstance(callee, Overloaded):
# Check for constrained TypeVar argument (reverse of decorator overload call)
constrained_arg = self.is_overload_decorator_constrained_call(callee, args)
if constrained_arg is not None:
result = self.handle_overload_decorator_constrained_call(
callee, constrained_arg, context
)
if result is not None:
return result
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
Expand Down Expand Up @@ -2082,6 +2166,25 @@ def infer_function_type_arguments(

Return a derived callable type that has the arguments applied.
"""

def find_matching_typevar(
arg_tv: TypeVarType, callee_vars: Sequence[TypeVarLikeType], used: set[int]
) -> int | None:
"""Find index of matching constrained TypeVar in callee, or None."""
arg_constraints = frozenset(get_proper_type(v) for v in arg_tv.values)
for idx, cv in enumerate(callee_vars):
if idx in used:
continue
if not isinstance(cv, TypeVarType) or not cv.values:
continue
if frozenset(get_proper_type(v) for v in cv.values) == arg_constraints:
return idx
return None

# Save the original callee_type variables before inference modifies them
original_callee_variables = list(callee_type.variables)
constrained_typevar_map: dict[int, TypeVarType] = {}

if self.chk.in_checked_function():
# Disable type errors during type inference. There may be errors
# due to partial available context information at this time, but
Expand Down Expand Up @@ -2113,6 +2216,27 @@ def infer_function_type_arguments(
strict=self.chk.in_checked_function(),
)

# Preserve constrained TypeVars when passing functions with constrained TypeVars
# to decorators. When the argument has constrained TypeVars, save the decorator's
# corresponding TypeVar (from original_callee_variables) to apply at the end.
used_callee_var_indices: set[int] = set()
for arg_type in arg_types:
arg_type_proper = get_proper_type(arg_type)
if not isinstance(arg_type_proper, CallableType):
continue
for arg_typevar in arg_type_proper.variables:
if not (isinstance(arg_typevar, TypeVarType) and arg_typevar.values):
continue
# Found a constrained TypeVar in the argument - find matching decorator TypeVar
match_idx = find_matching_typevar(
arg_typevar, original_callee_variables, used_callee_var_indices
)
if match_idx is not None:
callee_var = original_callee_variables[match_idx]
assert isinstance(callee_var, TypeVarType)
constrained_typevar_map[match_idx] = callee_var
used_callee_var_indices.add(match_idx)

if 2 in arg_pass_nums:
# Second pass of type inference.
(callee_type, inferred_args) = self.infer_function_type_arguments_pass2(
Expand Down Expand Up @@ -2203,8 +2327,91 @@ def infer_function_type_arguments(
# In dynamically typed functions use implicit 'Any' types for
# type variables.
inferred_args = [AnyType(TypeOfAny.unannotated)] * len(callee_type.variables)

# Apply constrained TypeVars saved earlier
for callee_var_idx, constrained_tv in constrained_typevar_map.items():
if callee_var_idx >= len(inferred_args):
continue
inferred = inferred_args[callee_var_idx]
# Replace if inference failed
if inferred is None or isinstance(get_proper_type(inferred), UninhabitedType):
inferred_args[callee_var_idx] = constrained_tv
continue
# Replace if collapsed to a single constraint
if not constrained_tv.values:
continue
inferred_proper = get_proper_type(inferred)
if any(inferred_proper == get_proper_type(c) for c in constrained_tv.values):
inferred_args[callee_var_idx] = constrained_tv

# If we have constrained TypeVars, try polymorphic application
has_constrained = any(isinstance(a, TypeVarType) and a.values for a in inferred_args if a)
if constrained_typevar_map and has_constrained:
free_tvars = [a for a in inferred_args if isinstance(a, TypeVarType)]
poly_callee = self.apply_generic_arguments(callee_type, inferred_args, context)
poly_result = applytype.apply_poly(poly_callee, free_tvars)
if poly_result is not None:
poly_result = self._finalize_poly_result(
poly_result, arg_types, free_tvars, constrained_typevar_map
)
freeze_all_type_vars(poly_result)
return poly_result

return self.apply_inferred_arguments(callee_type, inferred_args, context)

def _finalize_poly_result(
self,
poly_result: CallableType,
arg_types: list[Type],
free_tvars: list[TypeVarType],
constrained_typevar_map: dict[int, TypeVarType],
) -> CallableType:
"""Finalize polymorphic result by preserving arg names and substituting TypeVars."""
ret_type = poly_result.ret_type
if not isinstance(ret_type, CallableType):
return poly_result

first_arg_type = get_proper_type(arg_types[0]) if arg_types else None
if not isinstance(first_arg_type, CallableType):
return poly_result

# Copy arg_names if arity matches
if first_arg_type.arg_names and len(ret_type.arg_types) == len(first_arg_type.arg_types):
ret_type = ret_type.copy_modified(arg_names=first_arg_type.arg_names)

# Substitute decorated function's TypeVars with decorator's TypeVars
if first_arg_type.variables:
variables_map = self._build_typevar_substitution_map(
first_arg_type, free_tvars, constrained_typevar_map
)
if variables_map:
ret_type = expand_type(ret_type, variables_map)

return poly_result.copy_modified(ret_type=ret_type)

def _build_typevar_substitution_map(
self,
first_arg_type: CallableType,
free_tvars: list[TypeVarType],
constrained_typevar_map: dict[int, TypeVarType],
) -> dict[TypeVarId, Type]:
"""Build substitution map from decorated function's TypeVars to decorator's TypeVars."""
variables_map: dict[TypeVarId, Type] = {}
for callee_idx in constrained_typevar_map:
if callee_idx >= len(free_tvars):
continue
decorator_typevar = free_tvars[callee_idx]
if not decorator_typevar.values:
continue
decorator_constraints = frozenset(get_proper_type(v) for v in decorator_typevar.values)
for func_tv in first_arg_type.variables:
if not (isinstance(func_tv, TypeVarType) and func_tv.values):
continue
func_constraints = frozenset(get_proper_type(v) for v in func_tv.values)
if func_constraints == decorator_constraints:
variables_map[func_tv.id] = decorator_typevar
return variables_map

def infer_function_type_arguments_pass2(
self,
callee_type: CallableType,
Expand Down
20 changes: 20 additions & 0 deletions test-data/unit/check-async-await.test
Original file line number Diff line number Diff line change
Expand Up @@ -1065,6 +1065,26 @@ reveal_type(coro) # N: Revealed type is "def () -> typing.AwaitableGenerator[bu
[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]

[case testCoroutineDecoratorWithConstrainedTypeVar]
# Test that @types.coroutine preserves both the AwaitableGenerator wrapper
# and the constrained TypeVar polymorphic signature.
from typing import Generator, TypeVar
from types import coroutine

T = TypeVar("T", int, str)

@coroutine
def process(data: T) -> Generator[T, None, T]:
yield data
return data

# Expected: AwaitableGenerator with polymorphic T preserved
reveal_type(process) # N: Revealed type is "def [T in (builtins.int, builtins.str)] (data: T`-1) -> typing.AwaitableGenerator[T`-1, None, T`-1, typing.Generator[T`-1, None, T`-1]]"
reveal_type(process(1)) # N: Revealed type is "typing.AwaitableGenerator[builtins.int, None, builtins.int, typing.Generator[builtins.int, None, builtins.int]]"
reveal_type(process("hello")) # N: Revealed type is "typing.AwaitableGenerator[builtins.str, None, builtins.str, typing.Generator[builtins.str, None, builtins.str]]"
[builtins fixtures/async_await.pyi]
[typing fixtures/typing-async.pyi]

[case asyncIteratorInProtocol]
from typing import AsyncIterator, Protocol

Expand Down
Loading