diff --git a/packages/pyright-internal/src/analyzer/operations.ts b/packages/pyright-internal/src/analyzer/operations.ts index d4ed00b39d5c..02aea0fd4e59 100644 --- a/packages/pyright-internal/src/analyzer/operations.ts +++ b/packages/pyright-internal/src/analyzer/operations.ts @@ -765,6 +765,33 @@ export function getTypeOfUnaryOperation( return { type, isIncomplete, magicMethodDeprecationInfo: deprecatedInfo }; } +// Helper function to check if an expression is a simple name or `not ` with a literal bool type. +// We avoid narrowing for these cases because the variable could be reassigned. +function isBoolLiteralName(expr: ExpressionNode, exprType: Type, evaluator: TypeEvaluator): boolean { + // Check for simple name references + if (expr.nodeType === ParseNodeType.Name) { + return ( + isClassInstance(exprType) && + ClassType.isBuiltIn(exprType, 'bool') && + exprType.priv.literalValue !== undefined + ); + } + + // Check for `not ` expressions + if (expr.nodeType === ParseNodeType.UnaryOperation && expr.d.operator === OperatorType.Not) { + // Evaluate the inner expression's type (not the `not` expression's type). + // Note: makeTopLevelTypeVarsConcrete is needed here for the isClassInstance/ + // ClassType.isBuiltIn checks, unlike the outer call site which was removed + // because canBeTruthy/canBeFalsy do it internally. + const innerType = evaluator.makeTopLevelTypeVarsConcrete( + evaluator.getTypeOfExpression(expr.d.expr).type + ); + return isBoolLiteralName(expr.d.expr, innerType, evaluator); + } + + return false; +} + export function getTypeOfTernaryOperation( evaluator: TypeEvaluator, node: TernaryNode, @@ -778,7 +805,9 @@ export function getTypeOfTernaryOperation( return { type: UnknownType.create() }; } - evaluator.getTypeOfExpression(node.d.testExpr); + // Get the narrowed type of the test expression at this point in the code flow. + const testExprTypeResult = evaluator.getTypeOfExpression(node.d.testExpr); + const testExprType = testExprTypeResult.type; const typesToCombine: Type[] = []; let isIncomplete = false; @@ -790,7 +819,19 @@ export function getTypeOfTernaryOperation( fileInfo.definedConstants ); - if (constExprValue !== false && evaluator.isNodeReachable(node.d.ifExpr)) { + // Check if we should apply flow-sensitive narrowing. We avoid narrowing for + // simple name references with literal bool types because the variable could + // be reassigned, even though the type is a literal. This also applies to + // `not ` expressions to maintain consistency. + // Note: This guard is specific to ternary expressions. The and/or operators + // don't need this guard because they operate on already-evaluated types from + // their operands, not on types that may have been narrowed by upstream flow analysis. + const shouldApplyNarrowing = !isBoolLiteralName(node.d.testExpr, testExprType, evaluator); + + // Determine if the if-branch is reachable based on static evaluation, + // general reachability, and flow-sensitive type narrowing. + const testCanBeTruthy = shouldApplyNarrowing ? evaluator.canBeTruthy(testExprType) : true; + if (constExprValue !== false && evaluator.isNodeReachable(node.d.ifExpr) && testCanBeTruthy) { const ifType = evaluator.getTypeOfExpression(node.d.ifExpr, flags, inferenceContext); typesToCombine.push(ifType.type); if (ifType.isIncomplete) { @@ -801,7 +842,10 @@ export function getTypeOfTernaryOperation( } } - if (constExprValue !== true && evaluator.isNodeReachable(node.d.elseExpr)) { + // Determine if the else-branch is reachable based on static evaluation, + // general reachability, and flow-sensitive type narrowing. + const testCanBeFalsy = shouldApplyNarrowing ? evaluator.canBeFalsy(testExprType) : true; + if (constExprValue !== true && evaluator.isNodeReachable(node.d.elseExpr) && testCanBeFalsy) { const elseType = evaluator.getTypeOfExpression(node.d.elseExpr, flags, inferenceContext); typesToCombine.push(elseType.type); if (elseType.isIncomplete) { diff --git a/packages/pyright-internal/src/tests/samples/conditionalExpr1.py b/packages/pyright-internal/src/tests/samples/conditionalExpr1.py new file mode 100644 index 000000000000..8c61ac760070 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/conditionalExpr1.py @@ -0,0 +1,106 @@ +# This sample tests type narrowing for conditional expressions (ternary operator) +# where the condition is narrowed such that one branch is known to be unreachable. + +from typing import assert_type + + +class Node: + pass + + +class Wrapper: + def __init__(self, child: Node): + self.child = child + + +class SpecialNode(Node): + pass + + +def func1(plan: Wrapper | None): + # After the guard, plan is known to be truthy (not None) and + # plan.child is known to be SpecialNode. + if not (plan and isinstance(plan.child, SpecialNode)): + return + + # This should be fine - direct assignment works. + ts1: SpecialNode = plan.child + + # This should also be fine - the else branch (None) is unreachable + # because plan is known to be truthy in this context. + ts2: SpecialNode = plan.child if plan else None + + # Also verify the inferred type. + assert_type(plan.child if plan else None, SpecialNode) + + +def func2(val: int | None): + # After this guard, val is known to be truthy (not None and not 0). + if not val: + return + + # The else branch is unreachable since val is known to be truthy. + # Using different types to make the test meaningful. + ts1: int = val if val else "fallback" + assert_type(val if val else "fallback", int) + + +def func3(val: str | None): + # After this guard, val is known to be None (falsy). + if val: + return + + # The if branch is unreachable since val is known to be falsy (None). + # Using different types to make the test meaningful - without pruning, + # this would be str | None instead of just None. + ts1: None = "unreachable" if val else None + assert_type("unreachable" if val else None, None) + + +def func4(val: int | None): + # After this guard, val is known to be not None (but could still be 0, which is falsy). + if val is None: + return + + # The else branch is still reachable since val could be 0 (falsy). + # This test verifies that we don't over-narrow. + ts1: int | str = val if val else "zero" + assert_type(val if val else "zero", int | str) + + +def func5(val: int | None): + # After this guard, val is known to be truthy (not None and not 0). + if not val: + return + + # The else branch is unreachable since val is known to be truthy. + # Using different types to make the test meaningful. + ts1: int = val if val else "fallback" + assert_type(val if val else "fallback", int) + + +def func6(val: list[int] | None): + # After this guard, val is known to be not None. + if val is None: + return + + # However, val could still be an empty list (falsy), so the else branch + # is still reachable. Both branches should contribute to the type. + ts1: list[int] | str = val if val else "empty" + assert_type(val if val else "empty", list[int] | str) + + +def func_bool_literal(): + # Test that the bool literal guard prevents over-narrowing for mutable variables. + maybe = True + # Both branches should remain since maybe could be reassigned. + ts: int | str = 1 if maybe else "no" + assert_type(1 if maybe else "no", int | str) + + +def func_bool_literal_not(): + # Test that the bool literal guard also applies to `not` expressions. + flag = True + # Both branches should remain since flag could be reassigned. + ts: int | str = 1 if not flag else "yes" + assert_type(1 if not flag else "yes", int | str) diff --git a/packages/pyright-internal/src/tests/typeEvaluator5.test.ts b/packages/pyright-internal/src/tests/typeEvaluator5.test.ts index bf35f9c0e94b..b1fe3aa6ad5d 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator5.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator5.test.ts @@ -315,6 +315,11 @@ test('Conditional1', () => { TestUtils.validateResults(analysisResults, 15); }); +test('ConditionalExpr1', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['conditionalExpr1.py']); + TestUtils.validateResults(analysisResults, 0); +}); + test('TypePrinter1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['typePrinter1.py']); TestUtils.validateResults(analysisResults, 0);