From 093acc994b5428e0c4c0f3e7dc9bf896fa5f6dc8 Mon Sep 17 00:00:00 2001 From: Ray Zeng Date: Thu, 25 Jul 2019 20:40:41 -0700 Subject: [PATCH] Calculate syntactic position for statement nodes (2) Calculates positions for try, except, finally and import statements --- libcst/_nodes/_statement.py | 131 +++++---- libcst/_nodes/tests/test_import.py | 453 +++++++++++++++-------------- libcst/_nodes/tests/test_try.py | 202 ++++++------- 3 files changed, 406 insertions(+), 380 deletions(-) diff --git a/libcst/_nodes/_statement.py b/libcst/_nodes/_statement.py index 343b6b3b..776afca8 100644 --- a/libcst/_nodes/_statement.py +++ b/libcst/_nodes/_statement.py @@ -696,17 +696,19 @@ class ExceptHandler(CSTNode): for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() - state.add_token("except") - self.whitespace_after_except._codegen(state) - typenode = self.type - if typenode is not None: - typenode._codegen(state) - namenode = self.name - if namenode is not None: - namenode._codegen(state) - self.whitespace_before_colon._codegen(state) - state.add_token(":") - self.body._codegen(state) + + with state.record_syntactic_position(self, end_node=self.body): + state.add_token("except") + self.whitespace_after_except._codegen(state) + typenode = self.type + if typenode is not None: + typenode._codegen(state) + namenode = self.name + if namenode is not None: + namenode._codegen(state) + self.whitespace_before_colon._codegen(state) + state.add_token(":") + self.body._codegen(state) @add_slots @@ -733,10 +735,12 @@ class Finally(CSTNode): for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() - state.add_token("finally") - self.whitespace_before_colon._codegen(state) - state.add_token(":") - self.body._codegen(state) + + with state.record_syntactic_position(self, end_node=self.body): + state.add_token("finally") + self.whitespace_before_colon._codegen(state) + state.add_token(":") + self.body._codegen(state) @add_slots @@ -789,18 +793,25 @@ class Try(BaseCompoundStatement): for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() - state.add_token("try") - self.whitespace_before_colon._codegen(state) - state.add_token(":") - self.body._codegen(state) - for handler in self.handlers: - handler._codegen(state) + + end_node = self.body + if len(self.handlers) > 0: + end_node = self.handlers[-1] orelse = self.orelse - if orelse is not None: - orelse._codegen(state) + end_node = end_node if orelse is None else orelse finalbody = self.finalbody - if finalbody is not None: - finalbody._codegen(state) + end_node = end_node if finalbody is None else finalbody + with state.record_syntactic_position(self, end_node=end_node): + state.add_token("try") + self.whitespace_before_colon._codegen(state) + state.add_token(":") + self.body._codegen(state) + for handler in self.handlers: + handler._codegen(state) + if orelse is not None: + orelse._codegen(state) + if finalbody is not None: + finalbody._codegen(state) @dataclass(frozen=True) @@ -834,10 +845,12 @@ class ImportAlias(CSTNode): ) def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None: - self.name._codegen(state) - asname = self.asname - if asname is not None: - asname._codegen(state) + with state.record_syntactic_position(self): + self.name._codegen(state) + asname = self.asname + if asname is not None: + asname._codegen(state) + comma = self.comma if comma is MaybeSentinel.DEFAULT and default_comma: state.add_token(", ") @@ -884,11 +897,13 @@ class Import(BaseSmallStatement): def _codegen_impl( self, state: CodegenState, default_semicolon: bool = False ) -> None: - state.add_token("import") - self.whitespace_after_import._codegen(state) - lastname = len(self.names) - 1 - for i, name in enumerate(self.names): - name._codegen(state, default_comma=(i != lastname)) + with state.record_syntactic_position(self): + state.add_token("import") + self.whitespace_after_import._codegen(state) + lastname = len(self.names) - 1 + for i, name in enumerate(self.names): + name._codegen(state, default_comma=(i != lastname)) + semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): if default_semicolon: @@ -991,28 +1006,32 @@ class ImportFrom(BaseSmallStatement): def _codegen_impl( self, state: CodegenState, default_semicolon: bool = False ) -> None: - state.add_token("from") - self.whitespace_after_from._codegen(state) - for dot in self.relative: - dot._codegen(state) - module = self.module - if module is not None: - module._codegen(state) - self.whitespace_before_import._codegen(state) - state.add_token("import") - self.whitespace_after_import._codegen(state) - lpar = self.lpar - if lpar is not None: - lpar._codegen(state) - if isinstance(self.names, Sequence): - lastname = len(self.names) - 1 - for i, name in enumerate(self.names): - name._codegen(state, default_comma=(i != lastname)) - if isinstance(self.names, ImportStar): - self.names._codegen(state) - rpar = self.rpar - if rpar is not None: - rpar._codegen(state) + end_node = self.names[-1] if isinstance(self.names, Sequence) else self.names + end_node = end_node if self.rpar is None else self.rpar + with state.record_syntactic_position(self, end_node=end_node): + state.add_token("from") + self.whitespace_after_from._codegen(state) + for dot in self.relative: + dot._codegen(state) + module = self.module + if module is not None: + module._codegen(state) + self.whitespace_before_import._codegen(state) + state.add_token("import") + self.whitespace_after_import._codegen(state) + lpar = self.lpar + if lpar is not None: + lpar._codegen(state) + if isinstance(self.names, Sequence): + lastname = len(self.names) - 1 + for i, name in enumerate(self.names): + name._codegen(state, default_comma=(i != lastname)) + if isinstance(self.names, ImportStar): + self.names._codegen(state) + rpar = self.rpar + if rpar is not None: + rpar._codegen(state) + semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): if default_semicolon: diff --git a/libcst/_nodes/tests/test_import.py b/libcst/_nodes/tests/test_import.py index 71e10b34..d0e52ad0 100644 --- a/libcst/_nodes/tests/test_import.py +++ b/libcst/_nodes/tests/test_import.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Callable, Optional +from typing import Any import libcst as cst from libcst import parse_statement @@ -18,30 +18,34 @@ class ImportCreateTest(CSTNodeTest): @data_provider( ( # Simple import statement - (cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)), "import foo"), - ( - cst.Import( + # pyre-fixme[6]: Incompatible parameter type + { + "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)), + "code": "import foo", + }, + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")) ), ) ), - "import foo.bar", - ), - ( - cst.Import( + "code": "import foo.bar", + }, + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")) ), ) ), - "import foo.bar", - ), + "code": "import foo.bar", + }, # Comma-separated list of imports - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")) @@ -51,11 +55,12 @@ class ImportCreateTest(CSTNodeTest): ), ) ), - "import foo.bar, foo.baz", - ), + "code": "import foo.bar, foo.baz", + "expected_position": CodeRange.create((1, 0), (1, 23)), + }, # Import with an alias - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")), @@ -63,11 +68,11 @@ class ImportCreateTest(CSTNodeTest): ), ) ), - "import foo.bar as baz", - ), + "code": "import foo.bar as baz", + }, # Import with an alias, comma separated - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")), @@ -79,11 +84,11 @@ class ImportCreateTest(CSTNodeTest): ), ) ), - "import foo.bar as baz, foo.baz as bar", - ), + "code": "import foo.bar as baz, foo.baz as bar", + }, # Combine for fun and profit - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")), @@ -100,11 +105,11 @@ class ImportCreateTest(CSTNodeTest): ), ) ), - "import foo.bar as baz, insta.gram, foo.baz, unittest as ut", - ), + "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut", + }, # Verify whitespace works everywhere. - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute( @@ -136,40 +141,42 @@ class ImportCreateTest(CSTNodeTest): ), whitespace_after_import=cst.SimpleWhitespace(" "), ), - "import foo . bar as baz , unittest as ut", - ), + "code": "import foo . bar as baz , unittest as ut", + "expected_position": CodeRange.create((1, 0), (1, 46)), + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: - self.validate_node(node, code, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) @data_provider( ( - (lambda: cst.Import(names=()), "at least one ImportAlias"), - ( - lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)), - "empty name identifier", - ), - ( - lambda: cst.Import( + { + "get_node": lambda: cst.Import(names=()), + "expected_re": "at least one ImportAlias", + }, + { + "get_node": lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)), + "expected_re": "empty name identifier", + }, + { + "get_node": lambda: cst.Import( names=( cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))), ) ), - "empty name identifier", - ), - ( - lambda: cst.Import( + "expected_re": "empty name identifier", + }, + { + "get_node": lambda: cst.Import( names=( cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))), ) ), - "empty name identifier", - ), - ( - lambda: cst.Import( + "expected_re": "empty name identifier", + }, + { + "get_node": lambda: cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")), @@ -177,10 +184,10 @@ class ImportCreateTest(CSTNodeTest): ), ) ), - "trailing comma", - ), - ( - lambda: cst.Import( + "expected_re": "trailing comma", + }, + { + "get_node": lambda: cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")) @@ -188,44 +195,45 @@ class ImportCreateTest(CSTNodeTest): ), whitespace_after_import=cst.SimpleWhitespace(""), ), - "at least one space", - ), + "expected_re": "at least one space", + }, ) ) - def test_invalid( - self, get_node: Callable[[], cst.CSTNode], expected_re: str - ) -> None: - self.assert_invalid(get_node, expected_re) + def test_invalid(self, **kwargs: Any) -> None: + self.assert_invalid(**kwargs) class ImportParseTest(CSTNodeTest): @data_provider( ( # Simple import statement - (cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)), "import foo"), - ( - cst.Import( + { + "node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)), + "code": "import foo", + }, + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")) ), ) ), - "import foo.bar", - ), - ( - cst.Import( + "code": "import foo.bar", + }, + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")) ), ) ), - "import foo.bar", - ), + "code": "import foo.bar", + }, # Comma-separated list of imports - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")), @@ -236,11 +244,11 @@ class ImportParseTest(CSTNodeTest): ), ) ), - "import foo.bar, foo.baz", - ), + "code": "import foo.bar, foo.baz", + }, # Import with an alias - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")), @@ -248,11 +256,11 @@ class ImportParseTest(CSTNodeTest): ), ) ), - "import foo.bar as baz", - ), + "code": "import foo.bar as baz", + }, # Import with an alias, comma separated - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")), @@ -265,11 +273,11 @@ class ImportParseTest(CSTNodeTest): ), ) ), - "import foo.bar as baz, foo.baz as bar", - ), + "code": "import foo.bar as baz, foo.baz as bar", + }, # Combine for fun and profit - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute(cst.Name("foo"), cst.Name("bar")), @@ -289,11 +297,11 @@ class ImportParseTest(CSTNodeTest): ), ) ), - "import foo.bar as baz, insta.gram, foo.baz, unittest as ut", - ), + "code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut", + }, # Verify whitespace works everywhere. - ( - cst.Import( + { + "node": cst.Import( names=( cst.ImportAlias( cst.Attribute( @@ -325,20 +333,16 @@ class ImportParseTest(CSTNodeTest): ), whitespace_after_import=cst.SimpleWhitespace(" "), ), - "import foo . bar as baz , unittest as ut", - ), + "code": "import foo . bar as baz , unittest as ut", + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: + def test_valid(self, **kwargs: Any) -> None: self.validate_node( - node, - code, - lambda code: ensure_type( + parser=lambda code: ensure_type( parse_statement(code), cst.SimpleStatementLine ).body[0], - expected_position=position, + **kwargs, ) @@ -346,15 +350,16 @@ class ImportFromCreateTest(CSTNodeTest): @data_provider( ( # Simple from import statement - ( - cst.ImportFrom( + # pyre-fixme[6]: Incompatible parameter type + { + "node": cst.ImportFrom( module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),) ), - "from foo import bar", - ), + "code": "from foo import bar", + }, # From import statement with alias - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), names=( cst.ImportAlias( @@ -362,64 +367,66 @@ class ImportFromCreateTest(CSTNodeTest): ), ), ), - "from foo import bar as baz", - ), + "code": "from foo import bar as baz", + }, # Multiple imports - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), names=( cst.ImportAlias(cst.Name("bar")), cst.ImportAlias(cst.Name("baz")), ), ), - "from foo import bar, baz", - ), + "code": "from foo import bar, baz", + }, # Trailing comma - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), names=( cst.ImportAlias(cst.Name("bar"), comma=cst.Comma()), cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()), ), ), - "from foo import bar,baz,", - ), + "code": "from foo import bar,baz,", + "expected_position": CodeRange.create((1, 0), (1, 23)), + }, # Star import statement - ( - cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()), - "from foo import *", - ), + { + "node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()), + "code": "from foo import *", + "expected_position": CodeRange.create((1, 0), (1, 17)), + }, # Simple relative import statement - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( relative=(cst.Dot(),), module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), ), - "from .foo import bar", - ), - ( - cst.ImportFrom( + "code": "from .foo import bar", + }, + { + "node": cst.ImportFrom( relative=(cst.Dot(), cst.Dot()), module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), ), - "from ..foo import bar", - ), + "code": "from ..foo import bar", + }, # Relative only import - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( relative=(cst.Dot(), cst.Dot()), module=None, names=(cst.ImportAlias(cst.Name("bar")),), ), - "from .. import bar", - ), + "code": "from .. import bar", + }, # Parenthesis - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), lpar=cst.LeftParen(), names=( @@ -429,11 +436,12 @@ class ImportFromCreateTest(CSTNodeTest): ), rpar=cst.RightParen(), ), - "from foo import (bar as baz)", - ), + "code": "from foo import (bar as baz)", + "expected_position": CodeRange.create((1, 0), (1, 28)), + }, # Verify whitespace works everywhere. - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( relative=( cst.Dot( whitespace_before=cst.SimpleWhitespace(" "), @@ -473,102 +481,99 @@ class ImportFromCreateTest(CSTNodeTest): whitespace_before_import=cst.SimpleWhitespace(" "), whitespace_after_import=cst.SimpleWhitespace(" "), ), - "from . . foo import ( bar as baz , unittest as ut )", - ), + "code": "from . . foo import ( bar as baz , unittest as ut )", + "expected_position": CodeRange.create((1, 0), (1, 61)), + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: - self.validate_node(node, code, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) @data_provider( ( - ( - lambda: cst.ImportFrom( + { + "get_node": lambda: cst.ImportFrom( module=None, names=(cst.ImportAlias(cst.Name("bar")),) ), - "Must have a module specified", - ), - ( - lambda: cst.ImportFrom(module=cst.Name("foo"), names=()), - "at least one ImportAlias", - ), - ( - lambda: cst.ImportFrom( + "expected_re": "Must have a module specified", + }, + { + "get_node": lambda: cst.ImportFrom(module=cst.Name("foo"), names=()), + "expected_re": "at least one ImportAlias", + }, + { + "get_node": lambda: cst.ImportFrom( module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), lpar=cst.LeftParen(), ), - "left paren without right paren", - ), - ( - lambda: cst.ImportFrom( + "expected_re": "left paren without right paren", + }, + { + "get_node": lambda: cst.ImportFrom( module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), rpar=cst.RightParen(), ), - "right paren without left paren", - ), - ( - lambda: cst.ImportFrom( + "expected_re": "right paren without left paren", + }, + { + "get_node": lambda: cst.ImportFrom( module=cst.Name("foo"), names=cst.ImportStar(), lpar=cst.LeftParen() ), - "cannot have parens", - ), - ( - lambda: cst.ImportFrom( + "expected_re": "cannot have parens", + }, + { + "get_node": lambda: cst.ImportFrom( module=cst.Name("foo"), names=cst.ImportStar(), rpar=cst.RightParen(), ), - "cannot have parens", - ), - ( - lambda: cst.ImportFrom( + "expected_re": "cannot have parens", + }, + { + "get_node": lambda: cst.ImportFrom( module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), whitespace_after_from=cst.SimpleWhitespace(""), ), - "one space after from", - ), - ( - lambda: cst.ImportFrom( + "expected_re": "one space after from", + }, + { + "get_node": lambda: cst.ImportFrom( module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), whitespace_before_import=cst.SimpleWhitespace(""), ), - "one space before import", - ), - ( - lambda: cst.ImportFrom( + "expected_re": "one space before import", + }, + { + "get_node": lambda: cst.ImportFrom( module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), whitespace_after_import=cst.SimpleWhitespace(""), ), - "one space after import", - ), + "expected_re": "one space after import", + }, ) ) - def test_invalid( - self, get_node: Callable[[], cst.CSTNode], expected_re: str - ) -> None: - self.assert_invalid(get_node, expected_re) + def test_invalid(self, **kwargs: Any) -> None: + self.assert_invalid(**kwargs) class ImportFromParseTest(CSTNodeTest): @data_provider( ( # Simple from import statement - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),) ), - "from foo import bar", - ), + "code": "from foo import bar", + }, # From import statement with alias - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), names=( cst.ImportAlias( @@ -576,11 +581,11 @@ class ImportFromParseTest(CSTNodeTest): ), ), ), - "from foo import bar as baz", - ), + "code": "from foo import bar as baz", + }, # Multiple imports - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), names=( cst.ImportAlias( @@ -590,11 +595,11 @@ class ImportFromParseTest(CSTNodeTest): cst.ImportAlias(cst.Name("baz")), ), ), - "from foo import bar, baz", - ), + "code": "from foo import bar, baz", + }, # Trailing comma - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), names=( cst.ImportAlias( @@ -604,42 +609,42 @@ class ImportFromParseTest(CSTNodeTest): cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()), ), ), - "from foo import bar, baz,", - ), + "code": "from foo import bar, baz,", + }, # Star import statement - ( - cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()), - "from foo import *", - ), + { + "node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()), + "code": "from foo import *", + }, # Simple relative import statement - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( relative=(cst.Dot(),), module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), ), - "from .foo import bar", - ), - ( - cst.ImportFrom( + "code": "from .foo import bar", + }, + { + "node": cst.ImportFrom( relative=(cst.Dot(), cst.Dot()), module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),), ), - "from ..foo import bar", - ), + "code": "from ..foo import bar", + }, # Relative only import - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( relative=(cst.Dot(), cst.Dot()), module=None, names=(cst.ImportAlias(cst.Name("bar")),), ), - "from .. import bar", - ), + "code": "from .. import bar", + }, # Parenthesis - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( module=cst.Name("foo"), lpar=cst.LeftParen(), names=( @@ -649,11 +654,11 @@ class ImportFromParseTest(CSTNodeTest): ), rpar=cst.RightParen(), ), - "from foo import (bar as baz)", - ), + "code": "from foo import (bar as baz)", + }, # Verify whitespace works everywhere. - ( - cst.ImportFrom( + { + "node": cst.ImportFrom( relative=( cst.Dot( whitespace_before=cst.SimpleWhitespace(""), @@ -693,18 +698,14 @@ class ImportFromParseTest(CSTNodeTest): whitespace_before_import=cst.SimpleWhitespace(" "), whitespace_after_import=cst.SimpleWhitespace(" "), ), - "from . . foo import ( bar as baz , unittest as ut )", - ), + "code": "from . . foo import ( bar as baz , unittest as ut )", + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: + def test_valid(self, **kwargs: Any) -> None: self.validate_node( - node, - code, - lambda code: ensure_type( + parser=lambda code: ensure_type( parse_statement(code), cst.SimpleStatementLine ).body[0], - expected_position=position, + **kwargs, ) diff --git a/libcst/_nodes/tests/test_try.py b/libcst/_nodes/tests/test_try.py index a0342fce..be0809fa 100644 --- a/libcst/_nodes/tests/test_try.py +++ b/libcst/_nodes/tests/test_try.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Callable, Optional +from typing import Any import libcst as cst from libcst import parse_statement @@ -17,8 +17,9 @@ class TryTest(CSTNodeTest): @data_provider( ( # Simple try/except block - ( - cst.Try( + # pyre-fixme[6]: Incompatible parameter type + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), handlers=( cst.ExceptHandler( @@ -27,12 +28,13 @@ class TryTest(CSTNodeTest): ), ), ), - "try: pass\nexcept: pass\n", - parse_statement, - ), + "code": "try: pass\nexcept: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (2, 12)), + }, # Try/except with a class - ( - cst.Try( + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), handlers=( cst.ExceptHandler( @@ -41,12 +43,12 @@ class TryTest(CSTNodeTest): ), ), ), - "try: pass\nexcept Exception: pass\n", - parse_statement, - ), + "code": "try: pass\nexcept Exception: pass\n", + "parser": parse_statement, + }, # Try/except with a named class - ( - cst.Try( + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), handlers=( cst.ExceptHandler( @@ -56,12 +58,13 @@ class TryTest(CSTNodeTest): ), ), ), - "try: pass\nexcept Exception as exc: pass\n", - parse_statement, - ), + "code": "try: pass\nexcept Exception as exc: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (2, 29)), + }, # Try/except with multiple clauses - ( - cst.Try( + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), handlers=( cst.ExceptHandler( @@ -80,24 +83,26 @@ class TryTest(CSTNodeTest): ), ), ), - "try: pass\n" + "code": "try: pass\n" + "except TypeError as e: pass\n" + "except KeyError as e: pass\n" + "except: pass\n", - parse_statement, - ), + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (4, 12)), + }, # Simple try/finally block - ( - cst.Try( + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), ), - "try: pass\nfinally: pass\n", - parse_statement, - ), + "code": "try: pass\nfinally: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (2, 13)), + }, # Simple try/except/finally block - ( - cst.Try( + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), handlers=( cst.ExceptHandler( @@ -107,12 +112,13 @@ class TryTest(CSTNodeTest): ), finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), ), - "try: pass\nexcept: pass\nfinally: pass\n", - parse_statement, - ), + "code": "try: pass\nexcept: pass\nfinally: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (3, 13)), + }, # Simple try/except/else block - ( - cst.Try( + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), handlers=( cst.ExceptHandler( @@ -122,12 +128,13 @@ class TryTest(CSTNodeTest): ), orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), ), - "try: pass\nexcept: pass\nelse: pass\n", - parse_statement, - ), + "code": "try: pass\nexcept: pass\nelse: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (3, 10)), + }, # Simple try/except/else block/finally - ( - cst.Try( + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), handlers=( cst.ExceptHandler( @@ -138,12 +145,13 @@ class TryTest(CSTNodeTest): orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), ), - "try: pass\nexcept: pass\nelse: pass\nfinally: pass\n", - parse_statement, - ), + "code": "try: pass\nexcept: pass\nelse: pass\nfinally: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (4, 13)), + }, # Verify whitespace in various locations - ( - cst.Try( + { + "node": cst.Try( leading_lines=(cst.EmptyLine(comment=cst.Comment("# 1")),), body=cst.SimpleStatementSuite((cst.Pass(),)), handlers=( @@ -172,12 +180,13 @@ class TryTest(CSTNodeTest): ), whitespace_before_colon=cst.SimpleWhitespace(" "), ), - "# 1\ntry : pass\n# 2\nexcept TypeError as e : pass\n# 3\nelse : pass\n# 4\nfinally : pass\n", - parse_statement, - ), + "code": "# 1\ntry : pass\n# 2\nexcept TypeError as e : pass\n# 3\nelse : pass\n# 4\nfinally : pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((2, 0), (8, 14)), + }, # Please don't write code like this - ( - cst.Try( + { + "node": cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), handlers=( cst.ExceptHandler( @@ -198,17 +207,18 @@ class TryTest(CSTNodeTest): orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), ), - "try: pass\n" + "code": "try: pass\n" + "except TypeError as e: pass\n" + "except KeyError as e: pass\n" + "except: pass\n" + "else: pass\n" + "finally: pass\n", - parse_statement, - ), + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (6, 13)), + }, # Verify indentation - ( - DummyIndentedBlock( + { + "node": DummyIndentedBlock( " ", cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), @@ -232,17 +242,17 @@ class TryTest(CSTNodeTest): finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), ), ), - " try: pass\n" + "code": " try: pass\n" + " except TypeError as e: pass\n" + " except KeyError as e: pass\n" + " except: pass\n" + " else: pass\n" + " finally: pass\n", - None, - ), + "parser": None, + }, # Verify indentation in bodies - ( - DummyIndentedBlock( + { + "node": DummyIndentedBlock( " ", cst.Try( cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), @@ -262,7 +272,7 @@ class TryTest(CSTNodeTest): ), ), ), - " try:\n" + "code": " try:\n" + " pass\n" + " except:\n" + " pass\n" @@ -270,64 +280,60 @@ class TryTest(CSTNodeTest): + " pass\n" + " finally:\n" + " pass\n", - None, - ), + "parser": None, + }, ) ) - def test_valid( - self, - node: cst.CSTNode, - code: str, - parser: Optional[Callable[[str], cst.CSTNode]], - position: Optional[CodeRange] = None, - ) -> None: - self.validate_node(node, code, parser, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) @data_provider( ( - (lambda: cst.AsName(cst.Name("")), "empty name identifier"), - ( - lambda: cst.AsName( + # pyre-fixme[6]: Incompatible parameter type + { + "get_node": lambda: cst.AsName(cst.Name("")), + "expected_re": "empty name identifier", + }, + { + "get_node": lambda: cst.AsName( cst.Name("bla"), whitespace_after_as=cst.SimpleWhitespace("") ), - "between 'as'", - ), - ( - lambda: cst.AsName( + "expected_re": "between 'as'", + }, + { + "get_node": lambda: cst.AsName( cst.Name("bla"), whitespace_before_as=cst.SimpleWhitespace("") ), - "before 'as'", - ), - ( - lambda: cst.ExceptHandler( + "expected_re": "before 'as'", + }, + { + "get_node": lambda: cst.ExceptHandler( cst.SimpleStatementSuite((cst.Pass(),)), name=cst.AsName(cst.Name("bla")), ), - "name for an empty type", - ), - ( - lambda: cst.ExceptHandler( + "expected_re": "name for an empty type", + }, + { + "get_node": lambda: cst.ExceptHandler( cst.SimpleStatementSuite((cst.Pass(),)), type=cst.Name("TypeError"), whitespace_after_except=cst.SimpleWhitespace(""), ), - "at least one space after except", - ), - ( - lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))), - "at least one ExceptHandler or Finally", - ), - ( - lambda: cst.Try( + "expected_re": "at least one space after except", + }, + { + "get_node": lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))), + "expected_re": "at least one ExceptHandler or Finally", + }, + { + "get_node": lambda: cst.Try( cst.SimpleStatementSuite((cst.Pass(),)), orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))), ), - "at least one ExceptHandler in order to have an Else", - ), + "expected_re": "at least one ExceptHandler in order to have an Else", + }, ) ) - def test_invalid( - self, get_node: Callable[[], cst.CSTNode], expected_re: str - ) -> None: - self.assert_invalid(get_node, expected_re) + def test_invalid(self, **kwargs: Any) -> None: + self.assert_invalid(**kwargs)