From da4944125cab1ab709bf61fd586ac9a374875b9b Mon Sep 17 00:00:00 2001 From: Antigravity Agent Date: Wed, 4 Feb 2026 17:26:23 +0400 Subject: [PATCH] Fix issue #875: Reorder IfExp fields to match evaluation order --- libcst/_exceptions.py | 1 - libcst/_nodes/expression.py | 7 +- libcst/_nodes/statement.py | 1 - libcst/_nodes/tests/test_cst_node.py | 6 +- libcst/_nodes/tests/test_ifexp.py | 4 +- libcst/_nodes/tests/test_module.py | 1 - libcst/_parser/conversions/expression.py | 14 +- libcst/_parser/conversions/statement.py | 30 +- libcst/_parser/parso/tests/test_tokenize.py | 38 +-- libcst/_parser/tests/test_parse_errors.py | 72 ++--- libcst/_typed_visitor.py | 9 +- libcst/_typed_visitor_base.py | 1 - .../codemod/commands/add_trailing_commas.py | 7 +- .../tests/test_fix_pyre_directives.py | 12 +- .../visitors/_apply_type_annotations.py | 2 - .../visitors/tests/test_gather_comments.py | 12 +- .../test_gather_string_annotation_names.py | 42 +-- .../tests/test_gather_unused_imports.py | 54 ++-- libcst/display/graphviz.py | 9 +- libcst/display/tests/test_dump_graphviz.py | 8 +- libcst/helpers/tests/test_node_fields.py | 1 - libcst/helpers/tests/test_template.py | 12 +- libcst/matchers/__init__.py | 5 +- libcst/matchers/_return_types.py | 2 - libcst/matchers/tests/test_decorators.py | 150 +++------ .../tests/test_matchers_with_metadata.py | 24 +- libcst/metadata/accessor_provider.py | 1 - .../metadata/tests/test_accessor_provider.py | 1 - libcst/metadata/tests/test_name_provider.py | 68 ++-- libcst/metadata/tests/test_scope_provider.py | 300 ++++++------------ libcst/tests/__main__.py | 1 - libcst/tests/test_add_slots.py | 1 - libcst/tests/test_exceptions.py | 18 +- libcst/tests/test_roundtrip.py | 1 - 34 files changed, 307 insertions(+), 608 deletions(-) diff --git a/libcst/_exceptions.py b/libcst/_exceptions.py index 4d3dd3865..33b9dab8c 100644 --- a/libcst/_exceptions.py +++ b/libcst/_exceptions.py @@ -8,7 +8,6 @@ from libcst._tabs import expand_tabs - _NEWLINE_CHARS: str = "\r\n" diff --git a/libcst/_nodes/expression.py b/libcst/_nodes/expression.py index eb95d9b3c..a66fde91e 100644 --- a/libcst/_nodes/expression.py +++ b/libcst/_nodes/expression.py @@ -18,7 +18,6 @@ from typing import Callable, Generator, Literal, Optional, Sequence, Union from libcst import CSTLogicError - from libcst._add_slots import add_slots from libcst._maybe_sentinel import MaybeSentinel from libcst._nodes.base import CSTCodegenError, CSTNode, CSTValidationError @@ -2741,12 +2740,12 @@ class IfExp(BaseExpression): If statements are provided by :class:`If` and :class:`Else` nodes. """ - #: The test to perform. - test: BaseExpression - #: The expression to evaluate when the test is true. body: BaseExpression + #: The test to perform. + test: BaseExpression + #: The expression to evaluate when the test is false. orelse: BaseExpression diff --git a/libcst/_nodes/statement.py b/libcst/_nodes/statement.py index 5546f1432..b921d2810 100644 --- a/libcst/_nodes/statement.py +++ b/libcst/_nodes/statement.py @@ -10,7 +10,6 @@ from typing import Literal, Optional, Pattern, Sequence, Union from libcst import CSTLogicError - from libcst._add_slots import add_slots from libcst._maybe_sentinel import MaybeSentinel from libcst._nodes.base import CSTNode, CSTValidationError diff --git a/libcst/_nodes/tests/test_cst_node.py b/libcst/_nodes/tests/test_cst_node.py index 8cc30dc63..d3eb7a50b 100644 --- a/libcst/_nodes/tests/test_cst_node.py +++ b/libcst/_nodes/tests/test_cst_node.py @@ -146,8 +146,7 @@ def test_repr(self) -> None: ), ) ), - dedent( - """ + dedent(""" SimpleStatementLine( body=[ Pass( @@ -188,8 +187,7 @@ def test_repr(self) -> None: ), ), ) - """ - ).strip(), + """).strip(), ) def test_visit(self) -> None: diff --git a/libcst/_nodes/tests/test_ifexp.py b/libcst/_nodes/tests/test_ifexp.py index dd260ef34..8eee00631 100644 --- a/libcst/_nodes/tests/test_ifexp.py +++ b/libcst/_nodes/tests/test_ifexp.py @@ -114,8 +114,8 @@ def test_valid( ( ( lambda: cst.IfExp( - cst.Name("bar"), cst.Name("foo"), + cst.Name("bar"), cst.Name("baz"), lpar=(cst.LeftParen(),), ), @@ -123,8 +123,8 @@ def test_valid( ), ( lambda: cst.IfExp( - cst.Name("bar"), cst.Name("foo"), + cst.Name("bar"), cst.Name("baz"), rpar=(cst.RightParen(),), ), diff --git a/libcst/_nodes/tests/test_module.py b/libcst/_nodes/tests/test_module.py index 40de8f8e5..e4d404716 100644 --- a/libcst/_nodes/tests/test_module.py +++ b/libcst/_nodes/tests/test_module.py @@ -8,7 +8,6 @@ import libcst as cst from libcst import parse_module, parse_statement from libcst._nodes.tests.base import CSTNodeTest - from libcst.metadata import CodeRange, MetadataWrapper, PositionProvider from libcst.testing.utils import data_provider diff --git a/libcst/_parser/conversions/expression.py b/libcst/_parser/conversions/expression.py index 79d7ad783..1c5c906ba 100644 --- a/libcst/_parser/conversions/expression.py +++ b/libcst/_parser/conversions/expression.py @@ -159,7 +159,7 @@ def convert_expression_input( config: ParserConfig, children: typing.Sequence[typing.Any] ) -> typing.Any: - (child, endmarker) = children + child, endmarker = children # HACK: UGLY! REMOVE THIS SOON! # Unwrap WithLeadingWhitespace if it exists. It shouldn't exist by this point, but # testlist isn't fully implemented, and we currently leak these partial objects. @@ -177,7 +177,7 @@ def convert_namedexpr_test( return test # Convert all of the operations that have no precedence in a loop - (walrus, value) = assignment + walrus, value = assignment return WithLeadingWhitespace( NamedExpr( target=test.value, @@ -201,7 +201,7 @@ def convert_test( (child,) = children return child else: - (body, if_token, test, else_token, orelse) = children + body, if_token, test, else_token, orelse = children return WithLeadingWhitespace( IfExp( body=body.value, @@ -718,7 +718,7 @@ def convert_trailer_arglist( def convert_trailer_subscriptlist( config: ParserConfig, children: typing.Sequence[typing.Any] ) -> typing.Any: - (lbracket, subscriptlist, rbracket) = children + lbracket, subscriptlist, rbracket = children return SubscriptPartial( lbracket=LeftSquareBracket( whitespace_after=parse_parenthesizable_whitespace( @@ -1556,7 +1556,7 @@ def convert_comp_for( (sync_comp_for,) = children return sync_comp_for else: - (async_tok, sync_comp_for) = children + async_tok, sync_comp_for = children return sync_comp_for.with_changes( # asynchronous steals the `CompFor`'s `whitespace_before`. asynchronous=Asynchronous(whitespace_after=sync_comp_for.whitespace_before), @@ -1594,7 +1594,7 @@ def convert_yield_expr( yield_node = Yield(value=None) else: # Yielding explicit value - (yield_token, yield_arg) = children + yield_token, yield_arg = children yield_node = Yield( value=yield_arg.value, whitespace_after_yield=parse_parenthesizable_whitespace( @@ -1617,7 +1617,7 @@ def convert_yield_arg( return child else: # Its a yield from - (from_token, test) = children + from_token, test = children return WithLeadingWhitespace( From( diff --git a/libcst/_parser/conversions/statement.py b/libcst/_parser/conversions/statement.py index f96c6ea21..6e9c0be27 100644 --- a/libcst/_parser/conversions/statement.py +++ b/libcst/_parser/conversions/statement.py @@ -121,7 +121,7 @@ @with_production("stmt_input", "stmt ENDMARKER") def convert_stmt_input(config: ParserConfig, children: Sequence[Any]) -> Any: - (child, endmarker) = children + child, endmarker = children return child @@ -359,7 +359,7 @@ def convert_pass_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: @with_production("del_stmt", "'del' exprlist") def convert_del_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: - (del_name, exprlist) = children + del_name, exprlist = children return WithLeadingWhitespace( Del( target=exprlist.value, @@ -393,7 +393,7 @@ def convert_return_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: keyword.whitespace_before, ) else: - (keyword, testlist) = children + keyword, testlist = children return WithLeadingWhitespace( Return( value=testlist.value, @@ -635,12 +635,12 @@ def convert_raise_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: exc = None cause = None elif len(children) == 2: - (raise_token, test) = children + raise_token, test = children whitespace_after_raise = parse_simple_whitespace(config, test.whitespace_before) exc = test.value cause = None elif len(children) == 4: - (raise_token, test, from_token, source) = children + raise_token, test, from_token, source = children whitespace_after_raise = parse_simple_whitespace(config, test.whitespace_before) exc = test.value cause = From( @@ -685,7 +685,7 @@ def _construct_nameitems(config: ParserConfig, names: Sequence[Any]) -> List[Nam @with_production("global_stmt", "'global' NAME (',' NAME)*") def convert_global_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: - (global_token, *names) = children + global_token, *names = children return WithLeadingWhitespace( Global( names=tuple(_construct_nameitems(config, names)), @@ -699,7 +699,7 @@ def convert_global_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: @with_production("nonlocal_stmt", "'nonlocal' NAME (',' NAME)*") def convert_nonlocal_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: - (nonlocal_token, *names) = children + nonlocal_token, *names = children return WithLeadingWhitespace( Nonlocal( names=tuple(_construct_nameitems(config, names)), @@ -714,7 +714,7 @@ def convert_nonlocal_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: @with_production("assert_stmt", "'assert' test [',' test]") def convert_assert_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: if len(children) == 2: - (assert_token, test) = children + assert_token, test = children assert_node = Assert( whitespace_after_assert=parse_simple_whitespace( config, test.whitespace_before @@ -723,7 +723,7 @@ def convert_assert_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: msg=None, ) else: - (assert_token, test, comma_token, msg) = children + assert_token, test, comma_token, msg = children assert_node = Assert( whitespace_after_assert=parse_simple_whitespace( config, test.whitespace_before @@ -814,7 +814,7 @@ def convert_while_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: while_token, test, while_colon_token, while_suite, *else_block = children if len(else_block) > 0: - (else_token, else_colon_token, else_suite) = else_block + else_token, else_colon_token, else_suite = else_block orelse = Else( leading_lines=parse_empty_lines(config, else_token.whitespace_before), whitespace_before_colon=parse_simple_whitespace( @@ -854,7 +854,7 @@ def convert_for_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: ) = children if len(else_block) > 0: - (else_token, else_colon_token, else_suite) = else_block + else_token, else_colon_token, else_suite = else_block orelse = Else( leading_lines=parse_empty_lines(config, else_token.whitespace_before), whitespace_before_colon=parse_simple_whitespace( @@ -958,14 +958,14 @@ def convert_except_clause(config: ParserConfig, children: Sequence[Any]) -> Any: test = None name = None elif len(children) == 2: - (except_token, test_node) = children + except_token, test_node = children whitespace_after_except = parse_simple_whitespace( config, except_token.whitespace_after ) test = test_node.value name = None else: - (except_token, test_node, as_token, name_token) = children + except_token, test_node, as_token, name_token = children whitespace_after_except = parse_simple_whitespace( config, except_token.whitespace_after ) @@ -993,7 +993,7 @@ def convert_except_clause(config: ParserConfig, children: Sequence[Any]) -> Any: ) @with_production("with_stmt", "'with' with_item ':' suite", version="<3.1") def convert_with_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: - (with_token, *items, colon_token, suite) = children + with_token, *items, colon_token, suite = children item_nodes: List[WithItem] = [] for with_item, maybe_comma in grouper(items, 2): @@ -1031,7 +1031,7 @@ def convert_with_stmt(config: ParserConfig, children: Sequence[Any]) -> Any: @with_production("with_item", "test ['as' expr]") def convert_with_item(config: ParserConfig, children: Sequence[Any]) -> Any: if len(children) == 3: - (test, as_token, expr_node) = children + test, as_token, expr_node = children test_node = test.value asname = AsName( whitespace_before_as=parse_simple_whitespace( diff --git a/libcst/_parser/parso/tests/test_tokenize.py b/libcst/_parser/parso/tests/test_tokenize.py index c8180047e..84dd35d0b 100644 --- a/libcst/_parser/parso/tests/test_tokenize.py +++ b/libcst/_parser/parso/tests/test_tokenize.py @@ -68,14 +68,12 @@ def test_simple_with_whitespace(self): def test_function_whitespace(self): # Test function definition whitespace identification - fundef = dedent( - """ + fundef = dedent(""" def test_whitespace(*args, **kwargs): x = 1 if x > 0: print(True) - """ - ) + """) token_list = _get_token_list(fundef) for _, value, _, prefix in token_list: if value == "test_whitespace": @@ -122,12 +120,10 @@ def test_tokenize_multiline_III(self): ] def test_identifier_contains_unicode(self): - fundef = dedent( - """ + fundef = dedent(""" def 我あφ(): pass - """ - ) + """) token_list = _get_token_list(fundef) unicode_token = token_list[1] assert unicode_token[0] == NAME @@ -237,13 +233,11 @@ def test_error_string(self): assert endmarker.string == "" def test_indent_error_recovery(self): - code = dedent( - """\ + code = dedent("""\ str( from x import a def - """ - ) + """) lst = _get_token_list(code) expected = [ # `str(` @@ -268,13 +262,11 @@ def test_indent_error_recovery(self): assert [t.type for t in lst] == expected def test_error_token_after_dedent(self): - code = dedent( - """\ + code = dedent("""\ class C: pass $foo - """ - ) + """) lst = _get_token_list(code) expected = [ NAME, @@ -298,23 +290,17 @@ def test_brackets_no_indentation(self): There used to be an issue that the parentheses counting would go below zero. This should not happen. """ - code = dedent( - """\ + code = dedent("""\ } { } - """ - ) + """) lst = _get_token_list(code) assert [t.type for t in lst] == [OP, NEWLINE, OP, OP, NEWLINE, ENDMARKER] def test_form_feed(self): - error_token, endmarker = _get_token_list( - dedent( - '''\ - \f"""''' - ) - ) + error_token, endmarker = _get_token_list(dedent('''\ + \f"""''')) assert error_token.prefix == "\f" assert error_token.string == '"""' assert endmarker.prefix == "" diff --git a/libcst/_parser/tests/test_parse_errors.py b/libcst/_parser/tests/test_parse_errors.py index 7697893df..6e7d04216 100644 --- a/libcst/_parser/tests/test_parse_errors.py +++ b/libcst/_parser/tests/test_parse_errors.py @@ -19,150 +19,126 @@ class ParseErrorsTest(UnitTest): # _wrapped_tokenize raises these exceptions "wrapped_tokenize__invalid_token": ( lambda: cst.parse_module("'"), - dedent( - """ + dedent(""" Syntax Error @ 1:1. "'" is not a valid token. ' ^ - """ - ).strip(), + """).strip(), ), "wrapped_tokenize__expected_dedent": ( lambda: cst.parse_module("if False:\n pass\n pass"), - dedent( - """ + dedent(""" Syntax Error @ 3:1. Inconsistent indentation. Expected a dedent. pass ^ - """ - ).strip(), + """).strip(), ), "wrapped_tokenize__mismatched_braces": ( lambda: cst.parse_module("abcd)"), - dedent( - """ + dedent(""" Syntax Error @ 1:5. Encountered a closing brace without a matching opening brace. abcd) ^ - """ - ).strip(), + """).strip(), ), # _base_parser raises these exceptions "base_parser__unexpected_indent": ( lambda: cst.parse_module(" abcd"), - dedent( - """ + dedent(""" Syntax Error @ 1:5. Incomplete input. Unexpectedly encountered an indent. abcd ^ - """ - ).strip(), + """).strip(), ), "base_parser__unexpected_dedent": ( lambda: cst.parse_module("if False:\n (el for el\n"), - dedent( - """ + dedent(""" Syntax Error @ 3:1. Incomplete input. Encountered a dedent, but expected 'in'. (el for el ^ - """ - ).strip(), + """).strip(), ), "base_parser__multiple_possibilities": ( lambda: cst.parse_module("try: pass"), - dedent( - """ + dedent(""" Syntax Error @ 2:1. Incomplete input. Encountered end of file (EOF), but expected 'except', or 'finally'. try: pass ^ - """ - ).strip(), + """).strip(), ), # conversion functions raise these exceptions. # `_base_parser` is responsible for attaching location information. "convert_nonterminal__dict_unpacking": ( lambda: cst.parse_expression("{**el for el in []}"), - dedent( - """ + dedent(""" Syntax Error @ 1:19. dict unpacking cannot be used in dict comprehension {**el for el in []} ^ - """ - ).strip(), + """).strip(), ), "convert_nonterminal__arglist_non_default_after_default": ( lambda: cst.parse_statement("def fn(first=None, second): ..."), - dedent( - """ + dedent(""" Syntax Error @ 1:26. Cannot have a non-default argument following a default argument. def fn(first=None, second): ... ^ - """ - ).strip(), + """).strip(), ), "convert_nonterminal__arglist_trailing_param_star_without_comma": ( lambda: cst.parse_statement("def fn(abc, *): ..."), - dedent( - """ + dedent(""" Syntax Error @ 1:14. Named (keyword) arguments must follow a bare *. def fn(abc, *): ... ^ - """ - ).strip(), + """).strip(), ), "convert_nonterminal__arglist_trailing_param_star_with_comma": ( lambda: cst.parse_statement("def fn(abc, *,): ..."), - dedent( - """ + dedent(""" Syntax Error @ 1:15. Named (keyword) arguments must follow a bare *. def fn(abc, *,): ... ^ - """ - ).strip(), + """).strip(), ), "convert_nonterminal__class_arg_positional_after_keyword": ( lambda: cst.parse_statement("class Cls(first=None, second): ..."), - dedent( - """ + dedent(""" Syntax Error @ 2:1. Positional argument follows keyword argument. class Cls(first=None, second): ... ^ - """ - ).strip(), + """).strip(), ), "convert_nonterminal__class_arg_positional_expansion_after_keyword": ( lambda: cst.parse_statement("class Cls(first=None, *second): ..."), - dedent( - """ + dedent(""" Syntax Error @ 2:1. Positional argument follows keyword argument. class Cls(first=None, *second): ... ^ - """ - ).strip(), + """).strip(), ), } ) diff --git a/libcst/_typed_visitor.py b/libcst/_typed_visitor.py index 8816f619e..70b60e7de 100644 --- a/libcst/_typed_visitor.py +++ b/libcst/_typed_visitor.py @@ -12,7 +12,6 @@ from libcst._removal_sentinel import RemovalSentinel from libcst._typed_visitor_base import mark_no_op - if TYPE_CHECKING: from libcst._nodes.expression import ( # noqa: F401 Annotation, @@ -2551,19 +2550,19 @@ def visit_IfExp(self, node: "IfExp") -> Optional[bool]: pass @mark_no_op - def visit_IfExp_test(self, node: "IfExp") -> None: + def visit_IfExp_body(self, node: "IfExp") -> None: pass @mark_no_op - def leave_IfExp_test(self, node: "IfExp") -> None: + def leave_IfExp_body(self, node: "IfExp") -> None: pass @mark_no_op - def visit_IfExp_body(self, node: "IfExp") -> None: + def visit_IfExp_test(self, node: "IfExp") -> None: pass @mark_no_op - def leave_IfExp_body(self, node: "IfExp") -> None: + def leave_IfExp_test(self, node: "IfExp") -> None: pass @mark_no_op diff --git a/libcst/_typed_visitor_base.py b/libcst/_typed_visitor_base.py index de751a158..5a218347e 100644 --- a/libcst/_typed_visitor_base.py +++ b/libcst/_typed_visitor_base.py @@ -5,7 +5,6 @@ from typing import Any, Callable, cast, TypeVar - # pyre-fixme[24]: Generic type `Callable` expects 2 type parameters. F = TypeVar("F", bound=Callable) diff --git a/libcst/codemod/commands/add_trailing_commas.py b/libcst/codemod/commands/add_trailing_commas.py index 2f33a4bd7..219829eb7 100644 --- a/libcst/codemod/commands/add_trailing_commas.py +++ b/libcst/codemod/commands/add_trailing_commas.py @@ -10,7 +10,6 @@ import libcst as cst from libcst.codemod import CodemodContext, VisitorBasedCodemodCommand - presets_per_formatter: Dict[str, Dict[str, int]] = { "black": { "parameter_count": 1, @@ -24,8 +23,7 @@ class AddTrailingCommas(VisitorBasedCodemodCommand): - DESCRIPTION: str = textwrap.dedent( - """ + DESCRIPTION: str = textwrap.dedent(""" Codemod that adds trailing commas to arguments in function headers and function calls. @@ -38,8 +36,7 @@ class AddTrailingCommas(VisitorBasedCodemodCommand): Applying this codemod (and then an autoformatter) may make it easier to read function definitions and calls - """ - ) + """) def __init__( self, diff --git a/libcst/codemod/commands/tests/test_fix_pyre_directives.py b/libcst/codemod/commands/tests/test_fix_pyre_directives.py index 4707073a9..c807a1c5e 100644 --- a/libcst/codemod/commands/tests/test_fix_pyre_directives.py +++ b/libcst/codemod/commands/tests/test_fix_pyre_directives.py @@ -14,9 +14,7 @@ def test_no_need_to_fix_simple(self) -> None: """ Tests that a pyre-strict inside the module header doesn't get touched. """ - after = ( - before - ) = """ + after = before = """ # pyre-strict from typing import List @@ -29,9 +27,7 @@ def test_no_need_to_fix_complex_bottom(self) -> None: """ Tests that a pyre-strict inside the module header doesn't get touched. """ - after = ( - before - ) = """ + after = before = """ # This is some header comment. # # pyre-strict @@ -46,9 +42,7 @@ def test_no_need_to_fix_complex_top(self) -> None: """ Tests that a pyre-strict inside the module header doesn't get touched. """ - after = ( - before - ) = """ + after = before = """ # pyre-strict # # This is some header comment. diff --git a/libcst/codemod/visitors/_apply_type_annotations.py b/libcst/codemod/visitors/_apply_type_annotations.py index 593474204..0ae82127b 100644 --- a/libcst/codemod/visitors/_apply_type_annotations.py +++ b/libcst/codemod/visitors/_apply_type_annotations.py @@ -9,7 +9,6 @@ import libcst as cst import libcst.matchers as m - from libcst.codemod._context import CodemodContext from libcst.codemod._visitor import ContextAwareTransformer from libcst.codemod.visitors._add_imports import AddImportsVisitor @@ -19,7 +18,6 @@ from libcst.helpers import get_full_name_for_node from libcst.metadata import PositionProvider, QualifiedNameProvider - NameOrAttribute = Union[cst.Name, cst.Attribute] NAME_OR_ATTRIBUTE = (cst.Name, cst.Attribute) # Union type for *args and **args diff --git a/libcst/codemod/visitors/tests/test_gather_comments.py b/libcst/codemod/visitors/tests/test_gather_comments.py index 725118424..3ac59ad13 100644 --- a/libcst/codemod/visitors/tests/test_gather_comments.py +++ b/libcst/codemod/visitors/tests/test_gather_comments.py @@ -20,25 +20,21 @@ def gather_comments(self, code: str) -> GatherCommentsVisitor: return instance def test_no_comments(self) -> None: - visitor = self.gather_comments( - """ + visitor = self.gather_comments(""" def foo() -> None: pass - """ - ) + """) self.assertEqual(visitor.comments, {}) def test_noqa_comments(self) -> None: - visitor = self.gather_comments( - """ + visitor = self.gather_comments(""" import a.b.c # noqa import d # somethingelse # noqa def foo() -> None: pass - """ - ) + """) self.assertEqual(visitor.comments.keys(), {1, 4}) self.assertTrue(isinstance(visitor.comments[1], Comment)) self.assertEqual(visitor.comments[1].value, "# noqa") diff --git a/libcst/codemod/visitors/tests/test_gather_string_annotation_names.py b/libcst/codemod/visitors/tests/test_gather_string_annotation_names.py index d3c622a3d..c0d699073 100644 --- a/libcst/codemod/visitors/tests/test_gather_string_annotation_names.py +++ b/libcst/codemod/visitors/tests/test_gather_string_annotation_names.py @@ -18,44 +18,35 @@ def gather_names(self, code: str) -> GatherNamesFromStringAnnotationsVisitor: return instance def test_no_annotations(self) -> None: - visitor = self.gather_names( - """ + visitor = self.gather_names(""" def foo() -> None: pass - """ - ) + """) self.assertEqual(visitor.names, set()) def test_simple_string_annotations(self) -> None: - visitor = self.gather_names( - """ + visitor = self.gather_names(""" def foo() -> "None": pass - """ - ) + """) self.assertEqual(visitor.names, {"None"}) def test_concatenated_string_annotations(self) -> None: - visitor = self.gather_names( - """ + visitor = self.gather_names(""" def foo() -> "No" "ne": pass - """ - ) + """) self.assertEqual(visitor.names, {"None"}) def test_typevars(self) -> None: - visitor = self.gather_names( - """ + visitor = self.gather_names(""" from typing import TypeVar as SneakyBastard V = SneakyBastard("V", bound="int") - """ - ) + """) self.assertEqual(visitor.names, {"V", "int"}) def test_complex(self) -> None: - visitor = self.gather_names( - """ + visitor = self.gather_names(""" from typing import TypeVar, TYPE_CHECKING if TYPE_CHECKING: from a import Container, Item @@ -64,30 +55,25 @@ def foo(a: "A") -> "Item": A = TypeVar("A", bound="Container[Item]") class X: var: "ThisIsExpensiveToImport" # noqa - """ - ) + """) self.assertEqual( visitor.names, {"A", "Item", "Container", "ThisIsExpensiveToImport"} ) def test_dotted_names(self) -> None: - visitor = self.gather_names( - """ + visitor = self.gather_names(""" a: "api.http_exceptions.HttpException" - """ - ) + """) self.assertEqual( visitor.names, {"api", "api.http_exceptions", "api.http_exceptions.HttpException"}, ) def test_literals(self) -> None: - visitor = self.gather_names( - """ + visitor = self.gather_names(""" from typing import Literal a: Literal["in"] b: list[Literal["1x"]] c: Literal["Any"] - """ - ) + """) self.assertEqual(visitor.names, set()) diff --git a/libcst/codemod/visitors/tests/test_gather_unused_imports.py b/libcst/codemod/visitors/tests/test_gather_unused_imports.py index e6e0d9bbb..fbce49824 100644 --- a/libcst/codemod/visitors/tests/test_gather_unused_imports.py +++ b/libcst/codemod/visitors/tests/test_gather_unused_imports.py @@ -23,36 +23,29 @@ def gather_imports(self, code: str) -> Set[str]: } def test_no_imports(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" foo = 1 - """ - ) + """) self.assertEqual(imports, set()) def test_dotted_imports(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" import a.b.c, d import x.y a.b(d) - """ - ) + """) self.assertEqual(imports, {"x.y"}) def test_alias(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" from bar import baz as baz_alias import bar as bar_alias bar_alias() - """ - ) + """) self.assertEqual(imports, {"baz_alias"}) def test_import_complex(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" import bar import baz, qux import a.b @@ -65,13 +58,11 @@ def foo() -> None: c.d(qux) x.u j() - """ - ) + """) self.assertEqual(imports, {"bar", "baz", "a.b", "g"}) def test_import_from_complex(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" from bar import qux, quux from a.b import c from d.e import f @@ -82,22 +73,18 @@ def test_import_from_complex(self) -> None: def foo() -> None: f(qux) k() - """ - ) + """) self.assertEqual(imports, {"quux", "c", "o"}) def test_exports(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" import a __all__ = ["a"] - """ - ) + """) self.assertEqual(imports, set()) def test_string_annotation(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" from a import b from c import d import m, n.blah @@ -105,24 +92,19 @@ def test_string_annotation(self) -> None: bar: List["d"] quux: List["m.blah"] alma: List["n.blah"] - """ - ) + """) self.assertEqual(imports, set()) def test_typevars(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" from typing import TypeVar as Sneaky from a import b t = Sneaky("t", bound="b") - """ - ) + """) self.assertEqual(imports, set()) def test_future(self) -> None: - imports = self.gather_imports( - """ + imports = self.gather_imports(""" from __future__ import cool_feature - """ - ) + """) self.assertEqual(imports, set()) diff --git a/libcst/display/graphviz.py b/libcst/display/graphviz.py index e6b5b7481..305ee6fdf 100644 --- a/libcst/display/graphviz.py +++ b/libcst/display/graphviz.py @@ -11,7 +11,6 @@ from libcst import CSTNode from libcst.helpers import filter_node_fields - _syntax_style = ', color="#777777", fillcolor="#eeeeee"' _value_style = ', color="#3e99ed", fillcolor="#b8d9f8"' @@ -145,8 +144,7 @@ def dump_graphviz( ``show_defaults``. """ - graphviz_settings = textwrap.dedent( - r""" + graphviz_settings = textwrap.dedent(r""" layout=dot; rankdir=TB; splines=line; @@ -170,10 +168,7 @@ def dump_graphviz( fontsize=12, penwidth=2, ]; - """[ - 1: - ] - ) + """[1:]) return "\n".join( ["digraph {", graphviz_settings] diff --git a/libcst/display/tests/test_dump_graphviz.py b/libcst/display/tests/test_dump_graphviz.py index 17ce231f0..1fc34af71 100644 --- a/libcst/display/tests/test_dump_graphviz.py +++ b/libcst/display/tests/test_dump_graphviz.py @@ -19,16 +19,12 @@ class CSTDumpGraphvizTest(UnitTest): """Check dump_graphviz contains CST nodes.""" - source_code: str = dedent( - r""" + source_code: str = dedent(r""" def foo(a: str) -> None: pass ; pass return - """[ - 1: - ] - ) + """[1:]) cst: Module @classmethod diff --git a/libcst/helpers/tests/test_node_fields.py b/libcst/helpers/tests/test_node_fields.py index 61d5ec21f..a8aa53fbe 100644 --- a/libcst/helpers/tests/test_node_fields.py +++ b/libcst/helpers/tests/test_node_fields.py @@ -17,7 +17,6 @@ Semicolon, SimpleStatementLine, ) - from libcst.helpers import ( get_node_fields, is_default_node_field, diff --git a/libcst/helpers/tests/test_template.py b/libcst/helpers/tests/test_template.py index cef82dde5..1c87cb796 100644 --- a/libcst/helpers/tests/test_template.py +++ b/libcst/helpers/tests/test_template.py @@ -33,27 +33,23 @@ def code(self, node: cst.CSTNode) -> str: def test_simple_module(self) -> None: module = parse_template_module( - self.dedent( - """ + self.dedent(""" from {module} import {obj} def foo() -> {obj}: return {obj}() - """ - ), + """), module=cst.Name("foo"), obj=cst.Name("Bar"), ) self.assertEqual( module.code, - self.dedent( - """ + self.dedent(""" from foo import Bar def foo() -> Bar: return Bar() - """ - ), + """), ) def test_simple_statement(self) -> None: diff --git a/libcst/matchers/__init__.py b/libcst/matchers/__init__.py index 2857fee1b..e71998c79 100644 --- a/libcst/matchers/__init__.py +++ b/libcst/matchers/__init__.py @@ -10,7 +10,6 @@ import libcst as cst from libcst.matchers._decorators import call_if_inside, call_if_not_inside, leave, visit - from libcst.matchers._matcher_base import ( AbstractBaseMatcherNodeMeta, AllOf, @@ -6218,13 +6217,13 @@ class If(BaseCompoundStatement, BaseStatement, BaseMatcherNode): @dataclass(frozen=True, eq=False, unsafe_hash=False) class IfExp(BaseExpression, BaseMatcherNode): - test: Union[ + body: Union[ BaseExpressionMatchType, DoNotCareSentinel, OneOf[BaseExpressionMatchType], AllOf[BaseExpressionMatchType], ] = DoNotCare() - body: Union[ + test: Union[ BaseExpressionMatchType, DoNotCareSentinel, OneOf[BaseExpressionMatchType], diff --git a/libcst/matchers/_return_types.py b/libcst/matchers/_return_types.py index 2f0500885..759894929 100644 --- a/libcst/matchers/_return_types.py +++ b/libcst/matchers/_return_types.py @@ -75,7 +75,6 @@ Yield, ) from libcst._nodes.module import Module - from libcst._nodes.op import ( Add, AddAssign, @@ -206,7 +205,6 @@ ) from libcst._removal_sentinel import RemovalSentinel - TYPED_FUNCTION_RETURN_MAPPING: TypingDict[Type[CSTNode], object] = { Add: BaseBinaryOp, AddAssign: BaseAugOp, diff --git a/libcst/matchers/tests/test_decorators.py b/libcst/matchers/tests/test_decorators.py index 8b28657c3..15f69ecd8 100644 --- a/libcst/matchers/tests/test_decorators.py +++ b/libcst/matchers/tests/test_decorators.py @@ -47,8 +47,7 @@ def leave_SimpleString( return updated_node # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -57,8 +56,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -82,8 +80,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: self.func_visits.append(node.name.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -92,8 +89,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -118,8 +114,7 @@ def leave_SimpleString(self, original_node: cst.SimpleString) -> None: self.leaves.append(original_node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -128,8 +123,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -153,8 +147,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: self.func_visits.append(node.name.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -163,8 +156,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -185,16 +177,14 @@ def visit_SimpleString(self, node: cst.SimpleString) -> None: self.visits.append(node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" def foo() -> None: return "foo" class A: def foo(self) -> None: return "baz" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -214,16 +204,14 @@ def visit_SimpleString(self, node: cst.SimpleString) -> None: self.visits.append(node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" def foo() -> None: return "foo" class A: def foo(self) -> None: return "baz" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -250,8 +238,7 @@ def leave_SimpleString( return updated_node # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -260,8 +247,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -285,8 +271,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: self.func_visits.append(node.name.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -295,8 +280,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -321,8 +305,7 @@ def leave_SimpleString(self, original_node: cst.SimpleString) -> None: self.leaves.append(original_node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -331,8 +314,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -356,8 +338,7 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> None: self.func_visits.append(node.name.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -366,8 +347,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -397,8 +377,7 @@ def leave_function( return updated_node # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -410,8 +389,7 @@ def bar() -> None: def baz() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -436,8 +414,7 @@ def leave_function(self, original_node: cst.FunctionDef) -> None: self.leaves.append(original_node.name.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -449,8 +426,7 @@ def bar() -> None: def baz() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -480,8 +456,7 @@ def leave_function( return updated_node # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -493,8 +468,7 @@ def bar() -> None: def baz() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -521,8 +495,7 @@ def leave_function(self, original_node: cst.FunctionDef) -> None: self.leaves.append(original_node.name.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -534,8 +507,7 @@ def bar() -> None: def baz() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -575,8 +547,7 @@ def leave_function2( return updated_node # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -588,8 +559,7 @@ def bar() -> None: def baz() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -622,8 +592,7 @@ def leave_function2(self, original_node: cst.FunctionDef) -> None: self.leaves.add(original_node.name.value + "2") # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -635,8 +604,7 @@ def bar() -> None: def baz() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -679,8 +647,7 @@ def leave_string2( return updated_node # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -692,8 +659,7 @@ def bar() -> None: def baz() -> None: return "foobarbaz" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -732,8 +698,7 @@ def leave_string2(self, original_node: cst.SimpleString) -> None: self.leaves.add(literal_eval(original_node.value) + "2") # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -745,8 +710,7 @@ def bar() -> None: def baz() -> None: return "foobarbaz" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -786,8 +750,7 @@ def leave_SimpleString( ) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -799,12 +762,10 @@ def bar() -> None: def baz() -> None: return "foobarbaz" - """ - ) + """) visitor = TestVisitor() actual = module.visit(visitor) - expected = fixture( - """ + expected = fixture(""" a = "foo" b = "bar" @@ -816,8 +777,7 @@ def bar() -> None: def baz() -> None: return "foobarbaz" - """ - ) + """) self.assertTrue(expected.deep_equals(actual)) def test_call_if_inside_visitor_attribute(self) -> None: @@ -837,8 +797,7 @@ def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None: self.leaves.append(node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -847,8 +806,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -873,8 +831,7 @@ def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None: self.leaves.append(node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -883,8 +840,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -909,8 +865,7 @@ def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None: self.leaves.append(node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -919,8 +874,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -945,8 +899,7 @@ def leave_SimpleString_lpar(self, node: cst.SimpleString) -> None: self.leaves.append(node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -955,8 +908,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -978,8 +930,7 @@ def visit_SimpleString(self, node: cst.SimpleString) -> None: self.visits.append(node.value) # Parse a module and verify we visited correctly. - module = fixture( - """ + module = fixture(""" a = "foo" b = "bar" @@ -988,8 +939,7 @@ def foo() -> None: def bar() -> None: return "foobar" - """ - ) + """) visitor = TestVisitor() module.visit(visitor) diff --git a/libcst/matchers/tests/test_matchers_with_metadata.py b/libcst/matchers/tests/test_matchers_with_metadata.py index 63530c37a..21fcc77c0 100644 --- a/libcst/matchers/tests/test_matchers_with_metadata.py +++ b/libcst/matchers/tests/test_matchers_with_metadata.py @@ -498,8 +498,7 @@ def visit_Name(self, node: cst.Name) -> None: ): self.match_names.add(node.value) - module = self._make_fixture( - """ + module = self._make_fixture(""" a = 1 + 2 b = 3 + 4 + d + e def foo() -> str: @@ -509,8 +508,7 @@ def bar() -> int: return b del foo del bar - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -539,8 +537,7 @@ def visit_Name(self, node: cst.Name) -> None: ): self.match_names.add(node.value) - module = self._make_fixture( - """ + module = self._make_fixture(""" a = 1 + 2 b = 3 + 4 + d + e def foo() -> str: @@ -550,8 +547,7 @@ def bar() -> int: return b del foo del bar - """ - ) + """) visitor = TestTransformer() module.visit(visitor) @@ -579,8 +575,7 @@ def _visit_assignments(self, node: cst.Name) -> None: # Only match name nodes that are being assigned to. self.match_names.add(node.value) - module = self._make_fixture( - """ + module = self._make_fixture(""" a = 1 + 2 b = 3 + 4 + d + e def foo() -> str: @@ -590,8 +585,7 @@ def bar() -> int: return b del foo del bar - """ - ) + """) visitor = TestVisitor() module.visit(visitor) @@ -619,8 +613,7 @@ def _visit_assignments(self, node: cst.Name) -> None: # Only match name nodes that are being assigned to. self.match_names.add(node.value) - module = self._make_fixture( - """ + module = self._make_fixture(""" a = 1 + 2 b = 3 + 4 + d + e def foo() -> str: @@ -630,8 +623,7 @@ def bar() -> int: return b del foo del bar - """ - ) + """) visitor = TestTransformer() module.visit(visitor) diff --git a/libcst/metadata/accessor_provider.py b/libcst/metadata/accessor_provider.py index 5d4f22e42..4563984bc 100644 --- a/libcst/metadata/accessor_provider.py +++ b/libcst/metadata/accessor_provider.py @@ -7,7 +7,6 @@ import dataclasses import libcst as cst - from libcst.metadata.base_provider import VisitorMetadataProvider diff --git a/libcst/metadata/tests/test_accessor_provider.py b/libcst/metadata/tests/test_accessor_provider.py index 6ccfad5ee..148ee9a73 100644 --- a/libcst/metadata/tests/test_accessor_provider.py +++ b/libcst/metadata/tests/test_accessor_provider.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. import dataclasses - from textwrap import dedent import libcst as cst diff --git a/libcst/metadata/tests/test_name_provider.py b/libcst/metadata/tests/test_name_provider.py index fbd3631af..f11432d2e 100644 --- a/libcst/metadata/tests/test_name_provider.py +++ b/libcst/metadata/tests/test_name_provider.py @@ -67,22 +67,18 @@ def get_fully_qualified_names(file_path: str, module_str: str) -> Set[QualifiedN class QualifiedNameProviderTest(UnitTest): def test_imports(self) -> None: - qnames = get_qualified_names( - """ + qnames = get_qualified_names(""" from a.b import c as d d - """ - ) + """) self.assertEqual({"a.b.c"}, {qname.name for qname in qnames}) for qname in qnames: self.assertEqual(qname.source, QualifiedNameSource.IMPORT, msg=f"{qname}") def test_builtins(self) -> None: - qnames = get_qualified_names( - """ + qnames = get_qualified_names(""" int(None) - """ - ) + """) self.assertEqual( {"builtins.int", "builtins.None"}, {qname.name for qname in qnames} ) @@ -90,19 +86,16 @@ def test_builtins(self) -> None: self.assertEqual(qname.source, QualifiedNameSource.BUILTIN, msg=f"{qname}") def test_locals(self) -> None: - qnames = get_qualified_names( - """ + qnames = get_qualified_names(""" class X: a: "X" - """ - ) + """) self.assertEqual({"X", "X.a"}, {qname.name for qname in qnames}) for qname in qnames: self.assertEqual(qname.source, QualifiedNameSource.LOCAL, msg=f"{qname}") def test_simple_qualified_names(self) -> None: - m, names = get_qualified_name_metadata_provider( - """ + m, names = get_qualified_name_metadata_provider(""" from a.b import c class Cls: def f(self) -> "c": @@ -112,8 +105,7 @@ def f(self) -> "c": def g(): pass g() - """ - ) + """) cls = ensure_type(m.body[1], cst.ClassDef) f = ensure_type(cls.body.body[0], cst.FunctionDef) self.assertEqual( @@ -159,8 +151,7 @@ def g(): ) def test_nested_qualified_names(self) -> None: - m, names = get_qualified_name_metadata_provider( - """ + m, names = get_qualified_name_metadata_provider(""" class A: def f1(self): def f2(): @@ -177,8 +168,7 @@ class C: pass C() f5() - """ - ) + """) cls_a = ensure_type(m.body[0], cst.ClassDef) self.assertEqual(names[cls_a], {QualifiedName("A", QualifiedNameSource.LOCAL)}) @@ -218,15 +208,13 @@ class C: ) def test_multiple_assignments(self) -> None: - m, names = get_qualified_name_metadata_provider( - """ + m, names = get_qualified_name_metadata_provider(""" if 1: from a import b as c elif 2: from d import e as c c() - """ - ) + """) call = ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr ).value @@ -239,8 +227,7 @@ def test_multiple_assignments(self) -> None: ) def test_comprehension(self) -> None: - m, names = get_qualified_name_metadata_provider( - """ + m, names = get_qualified_name_metadata_provider(""" class C: def fn(self) -> None: [[k for k in i] for i in [j for j in range(10)]] @@ -250,8 +237,7 @@ def fn(self) -> None: # so j has qualified name "C.fn...j". # ListComp k is evaluated inside ListComp i. # so k has qualified name "C.fn....k". - """ - ) + """) cls_def = ensure_type(m.body[0], cst.ClassDef) fn_def = ensure_type(cls_def.body.body[0], cst.FunctionDef) outer_comp = ensure_type( @@ -320,12 +306,10 @@ def visit_Call(self, node: cst.Call) -> Optional[bool]: MetadataWrapper(cst.parse_module("import a;a.b.c()")).visit(TestVisitor(self)) def test_name_in_attribute(self) -> None: - m, names = get_qualified_name_metadata_provider( - """ + m, names = get_qualified_name_metadata_provider(""" obj = object() obj.eval - """ - ) + """) attr = ensure_type( ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr @@ -340,13 +324,11 @@ def test_name_in_attribute(self) -> None: self.assertEqual(names[eval], set()) def test_repeated_values_in_qualified_name(self) -> None: - m, names = get_qualified_name_metadata_provider( - """ + m, names = get_qualified_name_metadata_provider(""" import a class Foo: bar: a.aa.aaa - """ - ) + """) foo = ensure_type(m.body[1], cst.ClassDef) bar = ensure_type( ensure_type( @@ -364,8 +346,7 @@ class Foo: ) def test_multiple_qualified_names(self) -> None: - m, names = get_qualified_name_metadata_provider( - """ + m, names = get_qualified_name_metadata_provider(""" if False: def f(): pass elif False: @@ -375,8 +356,7 @@ def f(): pass import a.b as f f() - """ - ) + """) if_ = ensure_type(m.body[0], cst.If) first_f = ensure_type(if_.body.body[0], cst.FunctionDef) second_f_alias = ensure_type( @@ -431,16 +411,14 @@ def f(): pass ) def test_shadowed_assignments(self) -> None: - m, names = get_qualified_name_metadata_provider( - """ + m, names = get_qualified_name_metadata_provider(""" from lib import a,b,c a = a class Test: b = b def func(): c = c - """ - ) + """) # pyre-fixme[53]: Captured variable `names` is not annotated. def test_name(node: cst.CSTNode, qnames: Set[QualifiedName]) -> None: @@ -568,7 +546,7 @@ def test_with_full_repo_manager(self) -> None: mgr = FullRepoManager(root, [file_path_str], [FullyQualifiedNameProvider]) wrapper = mgr.get_metadata_wrapper_for_path(file_path_str) fqnames = wrapper.resolve(FullyQualifiedNameProvider) - (mod, names) = next(iter(fqnames.items())) + mod, names = next(iter(fqnames.items())) self.assertIsInstance(mod, cst.Module) self.assertEqual( names, {QualifiedName(name="pkg.mod", source=QualifiedNameSource.LOCAL)} diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index a367de39a..045508425 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -50,25 +50,21 @@ def get_scope_metadata_provider( class ScopeProviderTest(UnitTest): def test_not_in_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" pass - """ - ) + """) global_scope = scopes[m] self.assertEqual(global_scope["not_in_scope"], set()) def test_accesses(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" foo = 'toplevel' fn1(foo) fn2(foo) def fn_def(): foo = 'shadow' fn3(foo) - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) global_foo_assignments = list(scope_of_module["foo"]) @@ -118,12 +114,10 @@ def fn_def(): wrapper.visit(DependentVisitor()) def test_fstring_accesses(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" from a import b f"{b}" "hello" - """ - ) + """) global_scope = scopes[m] self.assertIsInstance(global_scope, GlobalScope) global_accesses = list(global_scope.accesses) @@ -138,12 +132,10 @@ def test_fstring_accesses(self) -> None: @data_provider((("any",), ("True",), ("Exception",), ("__name__",))) def test_builtins(self, builtin: str) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def fn(): pass - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) self.assertEqual(len(scope_of_module[builtin]), 1) @@ -162,14 +154,12 @@ def fn(): self.assertEqual(len(scope_of_func_statement["something_not_a_builtin"]), 0) def test_import(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import foo.bar import fizz.buzz as fizzbuzz import a.b.c import d.e.f as g - """ - ) + """) scope_of_module = scopes[m] import_0 = cst.ensure_type( @@ -221,12 +211,10 @@ def test_import(self) -> None: ) def test_dotted_import_access(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import a.b.c, x.y a.b.c(x.z) - """ - ) + """) scope_of_module = scopes[m] first_statement = ensure_type(m.body[1], cst.SimpleStatementLine) call = ensure_type( @@ -253,12 +241,10 @@ def test_dotted_import_access(self) -> None: self.assertEqual(scope_of_module.accesses["x.y"], set()) def test_dotted_import_access_reference_by_node(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import a.b.c a.b.c() - """ - ) + """) scope_of_module = scopes[m] first_statement = ensure_type(m.body[1], cst.SimpleStatementLine) call = ensure_type( @@ -271,15 +257,13 @@ def test_dotted_import_access_reference_by_node(self) -> None: self.assertEqual(a_b_c_access.node, call.func) def test_decorator_access_reference_by_node(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import decorator @decorator def f(): pass - """ - ) + """) scope_of_module = scopes[m] function_def = ensure_type(m.body[1], cst.FunctionDef) decorator = function_def.decorators[0] @@ -292,12 +276,10 @@ def f(): self.assertEqual(scope_of_module.accesses[decorator], {decorator_access}) def test_dotted_import_with_call_access(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import os.path os.path.join("A", "B").lower() - """ - ) + """) scope_of_module = scopes[m] first_statement = ensure_type(m.body[1], cst.SimpleStatementLine) attr = ensure_type( @@ -327,13 +309,11 @@ def test_dotted_import_with_call_access(self) -> None: self.assertEqual(os_path_join_access.node, attr) def test_import_from(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" from foo.bar import a, b as b_renamed from . import c from .foo import d - """ - ) + """) scope_of_module = scopes[m] import_from = cst.ensure_type( @@ -391,13 +371,11 @@ def test_import_from(self) -> None: ) def test_function_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" global_var = None def foo(arg, **kwargs): local_var = 5 - """ - ) + """) scope_of_module = scopes[m] func_def = ensure_type(m.body[1], cst.FunctionDef) self.assertEqual(scopes[func_def], scopes[func_def.name]) @@ -414,16 +392,14 @@ def foo(arg, **kwargs): self.assertTrue("local_var" in scope_of_func) def test_class_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" global_var = None @cls_attr class Cls(cls_attr, kwarg=cls_attr): cls_attr = 5 def f(): pass - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) cls_assignments = list(scope_of_module["Cls"]) @@ -451,16 +427,14 @@ def f(): self.assertTrue("cls_attr" not in scope_of_func) def test_comprehension_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" iterator = None condition = None [elt for target in iterator if condition] {elt for target in iterator if condition} {elt: target for target in iterator if condition} (elt for target in iterator if condition) - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) @@ -525,11 +499,9 @@ def test_comprehension_scope(self) -> None: self.assertTrue("target" in scope_of_generator_expr) def test_nested_comprehension_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" [y for x in iterator for y in x] - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) @@ -666,13 +638,11 @@ def inner_f(): ) def test_local_scope_shadowing_with_functions(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def f(): def f(): f = ... - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) self.assertTrue("f" in scope_of_module) @@ -692,13 +662,11 @@ def f(): self.assertEqual(cast(Assignment, inner_f_assignment).node, inner_f) def test_func_param_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" @decorator def f(x: T=1, *vararg, y: T=2, z, **kwarg) -> RET: pass - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) self.assertTrue("f" in scope_of_module) @@ -749,11 +717,9 @@ def f(x: T=1, *vararg, y: T=2, z, **kwarg) -> RET: self.assertEqual(cast(Assignment, list(scope_of_f["kwarg"])[0]).node, kwarg) def test_lambda_param_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" lambda x=1, *vararg, y=2, z, **kwarg:x - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) @@ -806,14 +772,12 @@ def test_except_handler(self) -> None: See https://docs.python.org/3.4/reference/compound_stmts.html#except We don't create a new block for except body because we don't handle del in our Scope Analysis. """ - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" try: ... except Exception as ex: ... - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) self.assertTrue("ex" in scope_of_module) @@ -825,12 +789,10 @@ def test_except_handler(self) -> None: ) def test_with_asname(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" with open(file_name) as f: ... - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) self.assertTrue("f" in scope_of_module) @@ -842,8 +804,7 @@ def test_with_asname(self) -> None: ) def test_get_qualified_names_for(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" from a.b import c class Cls: def f(self) -> "c": @@ -853,8 +814,7 @@ def f(self) -> "c": def g(): pass g() - """ - ) + """) cls = ensure_type(m.body[1], cst.ClassDef) f = ensure_type(cls.body.body[0], cst.FunctionDef) scope_of_module = scopes[m] @@ -925,8 +885,7 @@ def g(): ) def test_get_qualified_names_for_nested_cases(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" class A: def f1(self): def f2(): @@ -943,8 +902,7 @@ class C: pass C() f5() - """ - ) + """) cls_a = ensure_type(m.body[0], cst.ClassDef) func_f1 = ensure_type(cls_a.body.body[0], cst.FunctionDef) scope_of_cls_a = scopes[func_f1] @@ -989,12 +947,10 @@ class C: ) def test_get_qualified_names_for_the_same_prefix(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" from a import b, bc bc() - """ - ) + """) call = ensure_type( ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr @@ -1008,12 +964,10 @@ def test_get_qualified_names_for_the_same_prefix(self) -> None: ) def test_get_qualified_names_for_dotted_imports(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import a.b.c a(a.b.d) - """ - ) + """) call = ensure_type( ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr @@ -1050,15 +1004,13 @@ def test_get_qualified_names_for_dotted_imports(self) -> None: ) def test_multiple_assignments(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" if 1: from a import b as c elif 2: from d import e as c c() - """ - ) + """) call = ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Expr ).value @@ -1080,8 +1032,7 @@ def test_multiple_assignments(self) -> None: ) def test_assignments_and_accesses(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" a = 1 def f(): a = 2 @@ -1089,8 +1040,7 @@ def f(): def g(): b = a a - """ - ) + """) a_outer_assign = ( ensure_type( ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.Assign @@ -1176,8 +1126,7 @@ def g(): self.assertEqual(len(set(scopes.values())), 3) def test_annotation_access(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" from typing import Literal, NewType, Optional, TypeVar, Callable, cast from a import A, B, C, D, D2, E, E2, F, G, G2, H, I, J, K, K2, L, M def x(a: A): @@ -1197,8 +1146,7 @@ class Test(Generic[J]): pass castedK = cast("K", "K2") castedL = cast("L", M) - """ - ) + """) imp = ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom ) @@ -1321,13 +1269,11 @@ class Test(Generic[J]): references = list(assignment.references) def test_insane_annotation_access(self) -> None: - m, scopes = get_scope_metadata_provider( - r""" + m, scopes = get_scope_metadata_provider(r""" from typing import TypeVar, Optional from a import G TypeVar("G2", bound="Optional[\"G\"]") - """ - ) + """) imp = ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.ImportFrom ) @@ -1345,13 +1291,11 @@ def test_insane_annotation_access(self) -> None: self.assertEqual(list(assignment.references)[0].node, bound) def test_dotted_annotation_access(self) -> None: - m, scopes = get_scope_metadata_provider( - r""" + m, scopes = get_scope_metadata_provider(r""" from typing import TypeVar import a.G TypeVar("G2", bound="a.G") - """ - ) + """) imp = ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Import ) @@ -1369,15 +1313,13 @@ def test_dotted_annotation_access(self) -> None: self.assertEqual(list(assignment.references)[0].node, bound) def test_node_of_scopes(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def f1(): target() class C: attr = target() - """ - ) + """) f1 = ensure_type(m.body[0], cst.FunctionDef) target_call = ensure_type( ensure_type(f1.body.body[0], cst.SimpleStatementLine).body[0], cst.Expr @@ -1394,16 +1336,14 @@ class C: self.assertEqual(cast(ClassScope, c_scope).node, c) def test_with_statement(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import unittest.mock with unittest.mock.patch("something") as obj: obj.f1() unittest.mock - """ - ) + """) import_ = ensure_type(m.body[0], cst.SimpleStatementLine).body[0] assignments = scopes[import_]["unittest"] self.assertEqual(len(assignments), 1) @@ -1427,15 +1367,13 @@ def test_with_statement(self) -> None: ) def test_del_context_names(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import a dic = {} del dic del dic["key"] del a.b - """ - ) + """) dic = ensure_type( ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign @@ -1530,12 +1468,10 @@ def test_self(self) -> None: get_scope_metadata_provider(f.read()) def test_get_qualified_names_for_is_read_only(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" import a import b - """ - ) + """) a = m.body[0] scope = scopes[a] assignments_before = list(scope.assignments) @@ -1576,15 +1512,13 @@ def test_gen_dotted_names(self) -> None: self.assertEqual(names, {"a.b.c", "a.b", "a"}) def test_ordering(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" from a import b class X: x = b b = b y = b - """ - ) + """) global_scope = scopes[m] import_stmt = ensure_type( ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom @@ -1625,15 +1559,13 @@ class X: self.assertIn(y.value, [access.node for access in class_accesses]) def test_ordering_between_scopes(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def f(a): print(a) print(b) a = 1 b = 1 - """ - ) + """) f = cst.ensure_type(m.body[0], cst.FunctionDef) a_param = f.params.params[0].name a_param_assignment = list(scopes[a_param]["a"])[0] @@ -1679,15 +1611,13 @@ def f(a): self.assertEqual(b_global_refs[0].node, second_print.args[0].value) def test_ordering_comprehension(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def f(a): [a for a in [] for b in a] [b for a in [] for b in a] [a for a in [] for a in []] a = 1 - """ - ) + """) f = cst.ensure_type(m.body[0], cst.FunctionDef) a_param = f.params.params[0].name a_param_assignment = list(scopes[a_param]["a"])[0] @@ -1756,13 +1686,11 @@ def f(a): self.assertEqual(a_global_refs, []) def test_ordering_comprehension_confusing(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def f(a): [a for a in a] a = 1 - """ - ) + """) f = cst.ensure_type(m.body[0], cst.FunctionDef) a_param = f.params.params[0].name a_param_assignment = list(scopes[a_param]["a"])[0] @@ -1781,8 +1709,7 @@ def f(a): self.assertEqual(list(a_comp_assignment.references)[0].node, comp.elt) def test_for_scope_ordering(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def f(): for x in []: x @@ -1790,8 +1717,7 @@ class X: def f(): for x in []: x - """ - ) + """) for scope in scopes.values(): for acc in scope.accesses: self.assertEqual( @@ -1804,12 +1730,10 @@ def f(): ) def test_no_out_of_order_references_in_global_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" x = y y = 1 - """ - ) + """) for scope in scopes.values(): for acc in scope.accesses: self.assertEqual( @@ -1824,13 +1748,11 @@ def test_no_out_of_order_references_in_global_scope(self) -> None: def test_walrus_accesses(self) -> None: if sys.version_info < (3, 8): self.skipTest("This python version doesn't support :=") - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" if x := y: y = 1 x - """ - ) + """) for scope in scopes.values(): for acc in scope.accesses: self.assertEqual( @@ -1925,13 +1847,11 @@ def test_parse_string_annotations( parse_mock.assert_has_calls(calls) def test_builtin_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" a = pow(1, 2) def foo(): b = pow(2, 3) - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) self.assertEqual(len(scope_of_module["pow"]), 1) @@ -1957,16 +1877,14 @@ def foo(): self.assertEqual(len(builtin_pow_accesses), 2) def test_override_builtin_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def pow(x, y): return x ** y a = pow(1, 2) def foo(): b = pow(2, 3) - """ - ) + """) scope_of_module = scopes[m] self.assertIsInstance(scope_of_module, GlobalScope) self.assertEqual(len(scope_of_module["pow"]), 1) @@ -1992,13 +1910,11 @@ def foo(): self.assertEqual(len(global_pow_accesses), 2) def test_annotation_access_in_typevar_bound(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" from typing import TypeVar class Test: var: TypeVar("T", bound="Test") - """ - ) + """) imp = ensure_type( ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.ImportFrom ) @@ -2011,12 +1927,10 @@ class Test: def test_prefix_match(self) -> None: """Verify that a name doesn't overmatch on prefix""" - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" def something(): ... - """ - ) + """) scope = scopes[m] self.assertEqual( scope.get_qualified_names_for(cst.Name("something")), @@ -2028,12 +1942,10 @@ def something(): ) def test_type_alias_scope(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" type A = C lol: A - """ - ) + """) alias = ensure_type( ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.TypeAlias ) @@ -2049,13 +1961,11 @@ def test_type_alias_scope(self) -> None: self.assertIsInstance(scopes[alias.value], AnnotationScope) def test_type_alias_param(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" B = int type A[T: B] = T lol: T - """ - ) + """) alias = ensure_type( ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.TypeAlias ) @@ -2079,14 +1989,12 @@ def test_type_alias_param(self) -> None: ) def test_type_alias_tuple_and_paramspec(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" type A[*T] = T lol: T type A[**T] = T lol: T - """ - ) + """) alias_tuple = ensure_type( ensure_type(m.body[0], cst.SimpleStatementLine).body[0], cst.TypeAlias ) @@ -2106,13 +2014,11 @@ def test_type_alias_tuple_and_paramspec(self) -> None: self.assertEqual(t_refs[0].node, alias_paramspec.value) def test_class_type_params(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" class W[T]: def f() -> T: pass def g[T]() -> T: pass - """ - ) + """) cls = ensure_type(m.body[0], cst.ClassDef) cls_scope = scopes[cls.body.body[0]] self.assertEqual(len(t_assignments_in_cls := list(cls_scope["T"])), 1) @@ -2140,12 +2046,10 @@ def g[T]() -> T: pass self.assertEqual(t_refs_in_g[0].node, g.returns.annotation) def test_nested_class_type_params(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" class Outer: class Nested[T: Outer]: pass - """ - ) + """) outer = ensure_type(m.body[0], cst.ClassDef) outer_refs = list(list(scopes[outer]["Outer"])[0].references) self.assertEqual(len(outer_refs), 1) @@ -2157,8 +2061,7 @@ class Nested[T: Outer]: pass ) def test_annotation_refers_to_nested_class(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" class Outer: class Nested: pass @@ -2167,8 +2070,7 @@ class Nested: def meth1[T: Nested](self): pass def meth2[T](self, arg: Nested): pass - """ - ) + """) outer = ensure_type(m.body[0], cst.ClassDef) nested = ensure_type(outer.body.body[0], cst.ClassDef) alias = ensure_type( @@ -2216,13 +2118,11 @@ def meth2[T](self, arg: Nested): pass ) def test_body_isnt_subject_to_special_annotation_rule(self) -> None: - m, scopes = get_scope_metadata_provider( - """ + m, scopes = get_scope_metadata_provider(""" class Outer: class Inner: pass def f[T: Inner](self): Inner - """ - ) + """) outer = ensure_type(m.body[0], cst.ClassDef) # note: this is different from global scope outer_scope = scopes[outer.body.body[0]] diff --git a/libcst/tests/__main__.py b/libcst/tests/__main__.py index df28d1a6d..edc3cbd03 100644 --- a/libcst/tests/__main__.py +++ b/libcst/tests/__main__.py @@ -5,6 +5,5 @@ from unittest import main - if __name__ == "__main__": main(module=None, verbosity=2) diff --git a/libcst/tests/test_add_slots.py b/libcst/tests/test_add_slots.py index e354f60b6..5dd1d30d4 100644 --- a/libcst/tests/test_add_slots.py +++ b/libcst/tests/test_add_slots.py @@ -8,7 +8,6 @@ from typing import ClassVar from libcst._add_slots import add_slots - from libcst.testing.utils import UnitTest diff --git a/libcst/tests/test_exceptions.py b/libcst/tests/test_exceptions.py index f54f1da64..2f4358d4c 100644 --- a/libcst/tests/test_exceptions.py +++ b/libcst/tests/test_exceptions.py @@ -18,29 +18,25 @@ class ExceptionsTest(UnitTest): cst.ParserSyntaxError( "some message", lines=["abcd"], raw_line=1, raw_column=0 ), - dedent( - """ + dedent(""" Syntax Error @ 1:1. some message abcd ^ - """ - ).strip(), + """).strip(), ), "tab_expansion": ( cst.ParserSyntaxError( "some message", lines=["\tabcd\r\n"], raw_line=1, raw_column=2 ), - dedent( - """ + dedent(""" Syntax Error @ 1:10. some message abcd ^ - """ - ).strip(), + """).strip(), ), "shows_last_line_with_text": ( cst.ParserSyntaxError( @@ -49,15 +45,13 @@ class ExceptionsTest(UnitTest): raw_line=5, raw_column=0, ), - dedent( - """ + dedent(""" Syntax Error @ 5:1. some message efgh ^ - """ - ).strip(), + """).strip(), ), "empty_file": ( cst.ParserSyntaxError( diff --git a/libcst/tests/test_roundtrip.py b/libcst/tests/test_roundtrip.py index 96d1e5075..2a5b424eb 100644 --- a/libcst/tests/test_roundtrip.py +++ b/libcst/tests/test_roundtrip.py @@ -9,7 +9,6 @@ from libcst import CSTTransformer, parse_module - fixtures: Path = Path(__file__).parent.parent.parent / "native/libcst/tests/fixtures"