From 89fb7fe524ebfd33fd85a54f5282d92d877cda9f Mon Sep 17 00:00:00 2001 From: Ray Zeng Date: Thu, 25 Jul 2019 20:36:55 -0700 Subject: [PATCH] Calculate syntactic position for statement nodes (1) This commit covers most simple statements and if statements. --- libcst/_nodes/_statement.py | 108 ++++--- libcst/_nodes/tests/base.py | 5 +- libcst/_nodes/tests/test_del.py | 59 ++-- libcst/_nodes/tests/test_else.py | 23 +- libcst/_nodes/tests/test_if.py | 101 +++---- libcst/_nodes/tests/test_module.py | 10 +- libcst/_nodes/tests/test_return.py | 85 +++--- libcst/_nodes/tests/test_simple_statement.py | 284 ++++++++++--------- libcst/_nodes/tests/test_small_statement.py | 75 ++--- 9 files changed, 405 insertions(+), 345 deletions(-) diff --git a/libcst/_nodes/_statement.py b/libcst/_nodes/_statement.py index 223edda5..343b6b3b 100644 --- a/libcst/_nodes/_statement.py +++ b/libcst/_nodes/_statement.py @@ -121,9 +121,11 @@ class Del(BaseSmallStatement): def _codegen_impl( self, state: CodegenState, default_semicolon: bool = False ) -> None: - state.add_token("del") - self.whitespace_after_del._codegen(state) - self.target._codegen(state) + with state.record_syntactic_position(self): + state.add_token("del") + self.whitespace_after_del._codegen(state) + self.target._codegen(state) + semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): if default_semicolon: @@ -145,7 +147,9 @@ class Pass(BaseSmallStatement): def _codegen_impl( self, state: CodegenState, default_semicolon: bool = False ) -> None: - state.add_token("pass") + with state.record_syntactic_position(self): + state.add_token("pass") + semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): if default_semicolon: @@ -167,7 +171,9 @@ class Break(BaseSmallStatement): def _codegen_impl( self, state: CodegenState, default_semicolon: bool = False ) -> None: - state.add_token("break") + with state.record_syntactic_position(self): + state.add_token("break") + semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): if default_semicolon: @@ -189,7 +195,9 @@ class Continue(BaseSmallStatement): def _codegen_impl( self, state: CodegenState, default_semicolon: bool = False ) -> None: - state.add_token("continue") + with state.record_syntactic_position(self): + state.add_token("continue") + semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): if default_semicolon: @@ -235,19 +243,17 @@ class Return(BaseSmallStatement): def _codegen_impl( self, state: CodegenState, default_semicolon: bool = False ) -> None: - value = self.value - - state.add_token("return") - - whitespace_after_return = self.whitespace_after_return - if isinstance(whitespace_after_return, MaybeSentinel): + with state.record_syntactic_position(self): + state.add_token("return") + whitespace_after_return = self.whitespace_after_return + value = self.value + if isinstance(whitespace_after_return, MaybeSentinel): + if value is not None: + state.add_token(" ") + else: + whitespace_after_return._codegen(state) if value is not None: - state.add_token(" ") - else: - whitespace_after_return._codegen(state) - - if value is not None: - value._codegen(state) + value._codegen(state) semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): @@ -279,7 +285,9 @@ class Expr(BaseSmallStatement): def _codegen_impl( self, state: CodegenState, default_semicolon: bool = False ) -> None: - self.value._codegen(state) + with state.record_syntactic_position(self): + self.value._codegen(state) + semicolon = self.semicolon if isinstance(semicolon, MaybeSentinel): if default_semicolon: @@ -316,12 +324,15 @@ class _BaseSimpleStatement(CSTNode, ABC): body = self.body if body: laststmt = len(body) - 1 - for idx, stmt in enumerate(body): - stmt._codegen(state, default_semicolon=(idx != laststmt)) + with state.record_syntactic_position(self, end_node=body[laststmt]): + for idx, stmt in enumerate(body): + stmt._codegen(state, default_semicolon=(idx != laststmt)) else: # Empty simple statement blocks are not syntactically valid in Python # unless they contain a 'pass' statement, so add one here. - state.add_token("pass") + with state.record_syntactic_position(self): + state.add_token("pass") + self.trailing_whitespace._codegen(state) @@ -428,10 +439,12 @@ class Else(CSTNode): for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() - state.add_token("else") - self.whitespace_before_colon._codegen(state) - state.add_token(":") - self.body._codegen(state) + + with state.record_syntactic_position(self, end_node=self.body): + state.add_token("else") + self.whitespace_before_colon._codegen(state) + state.add_token(":") + self.body._codegen(state) class BaseCompoundStatement(CSTNode, ABC): @@ -491,18 +504,21 @@ class If(BaseCompoundStatement): for ll in self.leading_lines: ll._codegen(state) state.add_indent_tokens() - state.add_token("elif" if is_elif else "if") - self.whitespace_before_test._codegen(state) - self.test._codegen(state) - self.whitespace_after_test._codegen(state) - state.add_token(":") - self.body._codegen(state) - orelse = self.orelse - if orelse is not None: - if isinstance(orelse, If): # special-case elif - orelse._codegen(state, is_elif=True) - else: # is an Else clause - orelse._codegen(state) + + end_node = self.body if self.orelse is None else self.orelse + with state.record_syntactic_position(self, end_node=end_node): + state.add_token("elif" if is_elif else "if") + self.whitespace_before_test._codegen(state) + self.test._codegen(state) + self.whitespace_after_test._codegen(state) + state.add_token(":") + self.body._codegen(state) + orelse = self.orelse + if orelse is not None: + if isinstance(orelse, If): # special-case elif + orelse._codegen(state, is_elif=True) + else: # is an Else clause + orelse._codegen(state) @add_slots @@ -566,16 +582,20 @@ class IndentedBlock(BaseSuite): state.increase_indent(state.default_indent if indent is None else indent) if self.body: - for stmt in self.body: - # IndentedBlock is responsible for adjusting the current indentation level, - # but its children are responsible for actually adding that indentation to - # the token list. - stmt._codegen(state) + with state.record_syntactic_position( + self, start_node=self.body[0], end_node=self.body[-1] + ): + for stmt in self.body: + # IndentedBlock is responsible for adjusting the current indentation level, + # but its children are responsible for actually adding that indentation to + # the token list. + stmt._codegen(state) else: # Empty indented blocks are not syntactically valid in Python unless # they contain a 'pass' statement, so add one here. state.add_indent_tokens() - state.add_token("pass") + with state.record_syntactic_position(self): + state.add_token("pass") state.add_token(state.default_newline) for f in self.footer: diff --git a/libcst/_nodes/tests/base.py b/libcst/_nodes/tests/base.py index 269033b9..ac6184a8 100644 --- a/libcst/_nodes/tests/base.py +++ b/libcst/_nodes/tests/base.py @@ -222,7 +222,10 @@ class DummyIndentedBlock(cst.CSTNode): def _codegen_impl(self, state: CodegenState) -> None: state.increase_indent(self.value) - self.child._codegen(state) + with state.record_syntactic_position( + self, start_node=self.child, end_node=self.child + ): + self.child._codegen(state) state.decrease_indent() def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "DummyIndentedBlock": diff --git a/libcst/_nodes/tests/test_del.py b/libcst/_nodes/tests/test_del.py index d9c6dc3b..3a42a20a 100644 --- a/libcst/_nodes/tests/test_del.py +++ b/libcst/_nodes/tests/test_del.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Callable, Optional +from typing import Any import libcst as cst from libcst import parse_statement @@ -16,9 +16,14 @@ from libcst.testing.utils import data_provider class DelTest(CSTNodeTest): @data_provider( ( - (cst.SimpleStatementLine([cst.Del(cst.Name("abc"))]), "del abc\n"), - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine([cst.Del(cst.Name("abc"))]), + "code": "del abc\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 7)), + }, + { + "node": cst.SimpleStatementLine( [ cst.Del( cst.Name("abc"), @@ -26,10 +31,12 @@ class DelTest(CSTNodeTest): ) ] ), - "del abc\n", - ), - ( - cst.SimpleStatementLine( + "code": "del abc\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 9)), + }, + { + "node": cst.SimpleStatementLine( [ cst.Del( cst.Name( @@ -39,32 +46,32 @@ class DelTest(CSTNodeTest): ) ] ), - "del(abc)\n", - ), - ( - cst.SimpleStatementLine( + "code": "del(abc)\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 8)), + }, + { + "node": cst.SimpleStatementLine( [cst.Del(cst.Name("abc"), semicolon=cst.Semicolon())] ), - "del abc;\n", - ), + "code": "del abc;\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 7)), + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: - self.validate_node(node, code, parse_statement, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) @data_provider( ( - ( - lambda: cst.Del( + { + "get_node": lambda: cst.Del( cst.Name("abc"), whitespace_after_del=cst.SimpleWhitespace("") ), - "Must have at least one space after 'del'.", - ), + "expected_re": "Must have at least one space after 'del'.", + }, ) ) - def test_invalid( - self, get_node: Callable[[], cst.CSTNode], expected_re: str - ) -> None: - self.assert_invalid(get_node, expected_re) + def test_invalid(self, **kwargs: Any) -> None: + self.assert_invalid(**kwargs) diff --git a/libcst/_nodes/tests/test_else.py b/libcst/_nodes/tests/test_else.py index c25e3c8e..79b04cb3 100644 --- a/libcst/_nodes/tests/test_else.py +++ b/libcst/_nodes/tests/test_else.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Optional +from typing import Any import libcst as cst from libcst._nodes._internal import CodeRange @@ -15,17 +15,20 @@ from libcst.testing.utils import data_provider class ElseTest(CSTNodeTest): @data_provider( ( - (cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), "else: pass\n"), - ( - cst.Else( + { + "node": cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), + "code": "else: pass\n", + "expected_position": CodeRange.create((1, 0), (1, 10)), + }, + { + "node": cst.Else( cst.SimpleStatementSuite((cst.Pass(),)), whitespace_before_colon=cst.SimpleWhitespace(" "), ), - "else : pass\n", - ), + "code": "else : pass\n", + "expected_position": CodeRange.create((1, 0), (1, 12)), + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: - self.validate_node(node, code, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) diff --git a/libcst/_nodes/tests/test_if.py b/libcst/_nodes/tests/test_if.py index e94011ab..87fae5e6 100644 --- a/libcst/_nodes/tests/test_if.py +++ b/libcst/_nodes/tests/test_if.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Callable, Optional +from typing import Any import libcst as cst from libcst import parse_statement @@ -17,26 +17,29 @@ class IfTest(CSTNodeTest): @data_provider( ( # Simple if without elif or else - ( - cst.If( + # pyre-fixme[6]: Incompatible parameter type + { + "node": cst.If( cst.Name("conditional"), cst.SimpleStatementSuite((cst.Pass(),)) ), - "if conditional: pass\n", - parse_statement, - ), + "code": "if conditional: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 20)), + }, # else clause - ( - cst.If( + { + "node": cst.If( cst.Name("conditional"), cst.SimpleStatementSuite((cst.Pass(),)), orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), ), - "if conditional: pass\nelse: pass\n", - parse_statement, - ), + "code": "if conditional: pass\nelse: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (2, 10)), + }, # elif clause - ( - cst.If( + { + "node": cst.If( cst.Name("conditional"), cst.SimpleStatementSuite((cst.Pass(),)), orelse=cst.If( @@ -45,12 +48,13 @@ class IfTest(CSTNodeTest): orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), ), ), - "if conditional: pass\nelif other_conditional: pass\nelse: pass\n", - parse_statement, - ), + "code": "if conditional: pass\nelif other_conditional: pass\nelse: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (3, 10)), + }, # indentation - ( - DummyIndentedBlock( + { + "node": DummyIndentedBlock( " ", cst.If( cst.Name("conditional"), @@ -58,36 +62,39 @@ class IfTest(CSTNodeTest): orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), ), ), - " if conditional: pass\n else: pass\n", - None, - ), + "code": " if conditional: pass\n else: pass\n", + "parser": None, + "expected_position": CodeRange.create((1, 4), (2, 14)), + }, # with an indented body - ( - DummyIndentedBlock( + { + "node": DummyIndentedBlock( " ", cst.If( cst.Name("conditional"), cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)), ), ), - " if conditional:\n pass\n", - None, - ), + "code": " if conditional:\n pass\n", + "parser": None, + "expected_position": CodeRange.create((1, 4), (2, 12)), + }, # leading_lines - ( - cst.If( + { + "node": cst.If( cst.Name("conditional"), cst.SimpleStatementSuite((cst.Pass(),)), leading_lines=( cst.EmptyLine(comment=cst.Comment("# leading comment")), ), ), - "# leading comment\nif conditional: pass\n", - parse_statement, - ), + "code": "# leading comment\nif conditional: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((2, 0), (2, 20)), + }, # whitespace before/after test and else - ( - cst.If( + { + "node": cst.If( cst.Name("conditional"), cst.SimpleStatementSuite((cst.Pass(),)), whitespace_before_test=cst.SimpleWhitespace(" "), @@ -97,12 +104,13 @@ class IfTest(CSTNodeTest): whitespace_before_colon=cst.SimpleWhitespace(" "), ), ), - "if conditional : pass\nelse : pass\n", - parse_statement, - ), + "code": "if conditional : pass\nelse : pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (2, 11)), + }, # empty lines between if/elif/else clauses, not captured by the suite. - ( - cst.If( + { + "node": cst.If( cst.Name("test_a"), cst.SimpleStatementSuite((cst.Pass(),)), orelse=cst.If( @@ -115,16 +123,11 @@ class IfTest(CSTNodeTest): ), ), ), - "if test_a: pass\n\nelif test_b: pass\n\nelse: pass\n", - parse_statement, - ), + "code": "if test_a: pass\n\nelif test_b: pass\n\nelse: pass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (5, 10)), + }, ) ) - def test_valid( - self, - node: cst.CSTNode, - code: str, - parser: Optional[Callable[[str], cst.CSTNode]], - position: Optional[CodeRange] = None, - ) -> None: - self.validate_node(node, code, parser, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) diff --git a/libcst/_nodes/tests/test_module.py b/libcst/_nodes/tests/test_module.py index d41c1bbf..2cbe407b 100644 --- a/libcst/_nodes/tests/test_module.py +++ b/libcst/_nodes/tests/test_module.py @@ -158,7 +158,7 @@ class ModuleTest(CSTNodeTest): fn = cast(cst.FunctionDef, module.body[0]) stmt = cast(cst.SimpleStatementLine, fn.body.body[0]) pass_stmt = cast(cst.Pass, stmt.body[0]) - self.cmp_position(stmt, (2, 0), (3, 0)) + self.cmp_position(stmt, (2, 4), (2, 8)) self.cmp_position(pass_stmt, (2, 4), (2, 8)) def test_nested_indent_position(self) -> None: @@ -174,10 +174,10 @@ class ModuleTest(CSTNodeTest): outer_else = cast(cst.Else, outer_if.orelse) return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0] - self.cmp_position(outer_if, (1, 0), (6, 0)) - self.cmp_position(inner_if, (2, 0), (4, 0)) + self.cmp_position(outer_if, (1, 0), (5, 10)) + self.cmp_position(inner_if, (2, 4), (3, 13)) self.cmp_position(assign, (3, 8), (3, 13)) - self.cmp_position(outer_else, (4, 0), (6, 0)) + self.cmp_position(outer_else, (4, 0), (5, 10)) self.cmp_position(return_stmt, (5, 4), (5, 10)) def test_multiline_string_position(self) -> None: @@ -188,6 +188,6 @@ class ModuleTest(CSTNodeTest): expr = cast(cst.Expr, stmt.body[0]) string = expr.value - self.cmp_position(stmt, (1, 0), (3, 0)) + self.cmp_position(stmt, (1, 0), (2, 5)) self.cmp_position(expr, (1, 0), (2, 5)) self.cmp_position(string, (1, 0), (2, 5)) diff --git a/libcst/_nodes/tests/test_return.py b/libcst/_nodes/tests/test_return.py index c9f4df69..d24248cd 100644 --- a/libcst/_nodes/tests/test_return.py +++ b/libcst/_nodes/tests/test_return.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Callable, Optional +from typing import Any import libcst as cst from libcst import parse_statement @@ -16,42 +16,47 @@ from libcst.testing.utils import data_provider class ReturnCreateTest(CSTNodeTest): @data_provider( ( - (cst.SimpleStatementLine([cst.Return()]), "return\n"), - (cst.SimpleStatementLine([cst.Return(cst.Name("abc"))]), "return abc\n"), + { + "node": cst.SimpleStatementLine([cst.Return()]), + "code": "return\n", + "expected_position": CodeRange.create((1, 0), (1, 6)), + }, + { + "node": cst.SimpleStatementLine([cst.Return(cst.Name("abc"))]), + "code": "return abc\n", + "expected_position": CodeRange.create((1, 0), (1, 10)), + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: - self.validate_node(node, code, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) @data_provider( ( - ( - lambda: cst.Return( + { + "get_node": lambda: cst.Return( cst.Name("abc"), whitespace_after_return=cst.SimpleWhitespace("") ), - "Must have at least one space after 'return'.", - ), + "expected_re": "Must have at least one space after 'return'.", + }, ) ) - def test_invalid( - self, get_node: Callable[[], cst.CSTNode], expected_re: str - ) -> None: - self.assert_invalid(get_node, expected_re) + def test_invalid(self, **kwargs: Any) -> None: + self.assert_invalid(**kwargs) class ReturnParseTest(CSTNodeTest): @data_provider( ( - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( [cst.Return(whitespace_after_return=cst.SimpleWhitespace(""))] ), - "return\n", - ), - ( - cst.SimpleStatementLine( + "code": "return\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine( [ cst.Return( cst.Name("abc"), @@ -59,10 +64,11 @@ class ReturnParseTest(CSTNodeTest): ) ] ), - "return abc\n", - ), - ( - cst.SimpleStatementLine( + "code": "return abc\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine( [ cst.Return( cst.Name("abc"), @@ -70,10 +76,11 @@ class ReturnParseTest(CSTNodeTest): ) ] ), - "return abc\n", - ), - ( - cst.SimpleStatementLine( + "code": "return abc\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine( [ cst.Return( cst.Name( @@ -83,10 +90,11 @@ class ReturnParseTest(CSTNodeTest): ) ] ), - "return(abc)\n", - ), - ( - cst.SimpleStatementLine( + "code": "return(abc)\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine( [ cst.Return( cst.Name("abc"), @@ -95,11 +103,10 @@ class ReturnParseTest(CSTNodeTest): ) ] ), - "return abc;\n", - ), + "code": "return abc;\n", + "parser": parse_statement, + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: - self.validate_node(node, code, parse_statement, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) diff --git a/libcst/_nodes/tests/test_simple_statement.py b/libcst/_nodes/tests/test_simple_statement.py index 6c4949a2..c255daba 100644 --- a/libcst/_nodes/tests/test_simple_statement.py +++ b/libcst/_nodes/tests/test_simple_statement.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Callable, Optional +from typing import Any import libcst as cst from libcst import parse_statement @@ -17,18 +17,23 @@ class SimpleStatementTest(CSTNodeTest): @data_provider( ( # a single-element SimpleStatementLine - (cst.SimpleStatementLine((cst.Pass(),)), "pass\n", parse_statement), + # pyre-fixme[6]: Incompatible parameter type + { + "node": cst.SimpleStatementLine((cst.Pass(),)), + "code": "pass\n", + "parser": parse_statement, + }, # a multi-element SimpleStatementLine - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( (cst.Pass(semicolon=cst.Semicolon()), cst.Continue()) ), - "pass;continue\n", - parse_statement, - ), + "code": "pass;continue\n", + "parser": parse_statement, + }, # a multi-element SimpleStatementLine with whitespace - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( ( cst.Pass( semicolon=cst.Semicolon( @@ -39,67 +44,70 @@ class SimpleStatementTest(CSTNodeTest): cst.Continue(), ) ), - "pass ; continue\n", - parse_statement, - ), + "code": "pass ; continue\n", + "parser": parse_statement, + }, # A more complicated SimpleStatementLine - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( ( cst.Pass(semicolon=cst.Semicolon()), cst.Continue(semicolon=cst.Semicolon()), cst.Break(), ) ), - "pass;continue;break\n", - parse_statement, - ), + "code": "pass;continue;break\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 19)), + }, # a multi-element SimpleStatementLine, inferred semicolons - ( - cst.SimpleStatementLine((cst.Pass(), cst.Continue(), cst.Break())), - "pass; continue; break\n", - None, # No test for parsing, since we are using sentinels. - ), + { + "node": cst.SimpleStatementLine( + (cst.Pass(), cst.Continue(), cst.Break()) + ), + "code": "pass; continue; break\n", + "parser": None, # No test for parsing, since we are using sentinels. + }, # some expression statements - ( - cst.SimpleStatementLine((cst.Expr(cst.Name("None")),)), - "None\n", - parse_statement, - ), - ( - cst.SimpleStatementLine((cst.Expr(cst.Name("True")),)), - "True\n", - parse_statement, - ), - ( - cst.SimpleStatementLine((cst.Expr(cst.Name("False")),)), - "False\n", - parse_statement, - ), - ( - cst.SimpleStatementLine((cst.Expr(cst.Ellipses()),)), - "...\n", - parse_statement, - ), + { + "node": cst.SimpleStatementLine((cst.Expr(cst.Name("None")),)), + "code": "None\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine((cst.Expr(cst.Name("True")),)), + "code": "True\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine((cst.Expr(cst.Name("False")),)), + "code": "False\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine((cst.Expr(cst.Ellipses()),)), + "code": "...\n", + "parser": parse_statement, + }, # Test some numbers - ( - cst.SimpleStatementLine((cst.Expr(cst.Integer("5")),)), - "5\n", - parse_statement, - ), - ( - cst.SimpleStatementLine((cst.Expr(cst.Float("5.5")),)), - "5.5\n", - parse_statement, - ), - ( - cst.SimpleStatementLine((cst.Expr(cst.Imaginary("5j")),)), - "5j\n", - parse_statement, - ), + { + "node": cst.SimpleStatementLine((cst.Expr(cst.Integer("5")),)), + "code": "5\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine((cst.Expr(cst.Float("5.5")),)), + "code": "5.5\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine((cst.Expr(cst.Imaginary("5j")),)), + "code": "5j\n", + "parser": parse_statement, + }, # Test some numbers with parens - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.Integer( @@ -108,11 +116,12 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - "(5)\n", - parse_statement, - ), - ( - cst.SimpleStatementLine( + "code": "(5)\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 3)), + }, + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.Float( @@ -121,11 +130,11 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - "(5.5)\n", - parse_statement, - ), - ( - cst.SimpleStatementLine( + "code": "(5.5)\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.Imaginary( @@ -134,17 +143,17 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - "(5j)\n", - parse_statement, - ), + "code": "(5j)\n", + "parser": parse_statement, + }, # Test some strings - ( - cst.SimpleStatementLine((cst.Expr(cst.SimpleString('"abc"')),)), - '"abc"\n', - parse_statement, - ), - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine((cst.Expr(cst.SimpleString('"abc"')),)), + "code": '"abc"\n', + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.ConcatenatedString( @@ -153,11 +162,11 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - '"abc""def"\n', - parse_statement, - ), - ( - cst.SimpleStatementLine( + "code": '"abc""def"\n', + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.ConcatenatedString( @@ -172,12 +181,13 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - '"abc" "def" "ghi"\n', - parse_statement, - ), + "code": '"abc" "def" "ghi"\n', + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 17)), + }, # Test parenthesis rules - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.Ellipses( @@ -186,12 +196,12 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - "(...)\n", - parse_statement, - ), + "code": "(...)\n", + "parser": parse_statement, + }, # Test parenthesis with whitespace ownership - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.Ellipses( @@ -209,11 +219,11 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - "( ... )\n", - parse_statement, - ), - ( - cst.SimpleStatementLine( + "code": "( ... )\n", + "parser": parse_statement, + }, + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.Ellipses( @@ -243,12 +253,13 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - "( ( ( ... ) ) )\n", - parse_statement, - ), + "code": "( ( ( ... ) ) )\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 21)), + }, # Test parenthesis rules with expressions - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( ( cst.Expr( cst.Ellipses( @@ -282,33 +293,36 @@ class SimpleStatementTest(CSTNodeTest): ), ) ), - "(\n# Wow, a comment!\n ...\n)\n", - parse_statement, - ), + "code": "(\n# Wow, a comment!\n ...\n)\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (4, 1)), + }, # test trailing whitespace - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( (cst.Pass(),), trailing_whitespace=cst.TrailingWhitespace( whitespace=cst.SimpleWhitespace(" "), comment=cst.Comment("# trailing comment"), ), ), - "pass # trailing comment\n", - parse_statement, - ), + "code": "pass # trailing comment\n", + "parser": parse_statement, + "expected_position": CodeRange.create((1, 0), (1, 4)), + }, # test leading comment - ( - cst.SimpleStatementLine( + { + "node": cst.SimpleStatementLine( (cst.Pass(),), leading_lines=(cst.EmptyLine(comment=cst.Comment("# comment")),), ), - "# comment\npass\n", - parse_statement, - ), + "code": "# comment\npass\n", + "parser": parse_statement, + "expected_position": CodeRange.create((2, 0), (2, 4)), + }, # test indentation - ( - DummyIndentedBlock( + { + "node": DummyIndentedBlock( " ", cst.SimpleStatementLine( (cst.Pass(),), @@ -317,25 +331,23 @@ class SimpleStatementTest(CSTNodeTest): ), ), ), - " # comment\n pass\n", - None, - ), + "code": " # comment\n pass\n", + "expected_position": CodeRange.create((2, 4), (2, 8)), + }, # test suite variant - (cst.SimpleStatementSuite((cst.Pass(),)), " pass\n", None), - ( - cst.SimpleStatementSuite( + { + "node": cst.SimpleStatementSuite((cst.Pass(),)), + "code": " pass\n", + "expected_position": CodeRange.create((1, 1), (1, 5)), + }, + { + "node": cst.SimpleStatementSuite( (cst.Pass(),), leading_whitespace=cst.SimpleWhitespace("") ), - "pass\n", - None, - ), + "code": "pass\n", + "expected_position": CodeRange.create((1, 0), (1, 4)), + }, ) ) - def test_valid( - self, - node: cst.CSTNode, - code: str, - parser: Optional[Callable[[str], cst.CSTNode]], - position: Optional[CodeRange] = None, - ) -> None: - self.validate_node(node, code, parser, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs) diff --git a/libcst/_nodes/tests/test_small_statement.py b/libcst/_nodes/tests/test_small_statement.py index 8ed938a7..6f97eea8 100644 --- a/libcst/_nodes/tests/test_small_statement.py +++ b/libcst/_nodes/tests/test_small_statement.py @@ -4,7 +4,7 @@ # LICENSE file in the root directory of this source tree. # pyre-strict -from typing import Optional +from typing import Any import libcst as cst from libcst._nodes._internal import CodeRange @@ -15,63 +15,68 @@ from libcst.testing.utils import data_provider class SmallStatementTest(CSTNodeTest): @data_provider( ( - (cst.Pass(), "pass"), - (cst.Pass(semicolon=cst.Semicolon()), "pass;"), - ( - cst.Pass( + # pyre-fixme[6]: Incompatible parameter type + {"node": cst.Pass(), "code": "pass"}, + {"node": cst.Pass(semicolon=cst.Semicolon()), "code": "pass;"}, + { + "node": cst.Pass( semicolon=cst.Semicolon( whitespace_before=cst.SimpleWhitespace(" "), whitespace_after=cst.SimpleWhitespace(" "), ) ), - "pass ; ", - ), - (cst.Continue(), "continue"), - (cst.Continue(semicolon=cst.Semicolon()), "continue;"), - ( - cst.Continue( + "code": "pass ; ", + "expected_position": CodeRange.create((1, 0), (1, 4)), + }, + {"node": cst.Continue(), "code": "continue"}, + {"node": cst.Continue(semicolon=cst.Semicolon()), "code": "continue;"}, + { + "node": cst.Continue( semicolon=cst.Semicolon( whitespace_before=cst.SimpleWhitespace(" "), whitespace_after=cst.SimpleWhitespace(" "), ) ), - "continue ; ", - ), - (cst.Break(), "break"), - (cst.Break(semicolon=cst.Semicolon()), "break;"), - ( - cst.Break( + "code": "continue ; ", + "expected_position": CodeRange.create((1, 0), (1, 8)), + }, + {"node": cst.Break(), "code": "break"}, + {"node": cst.Break(semicolon=cst.Semicolon()), "code": "break;"}, + { + "node": cst.Break( semicolon=cst.Semicolon( whitespace_before=cst.SimpleWhitespace(" "), whitespace_after=cst.SimpleWhitespace(" "), ) ), - "break ; ", - ), - ( - cst.Expr(cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y"))), - "x + y", - ), - ( - cst.Expr( + "code": "break ; ", + "expected_position": CodeRange.create((1, 0), (1, 5)), + }, + { + "node": cst.Expr( + cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y")) + ), + "code": "x + y", + }, + { + "node": cst.Expr( cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y")), semicolon=cst.Semicolon(), ), - "x + y;", - ), - ( - cst.Expr( + "code": "x + y;", + }, + { + "node": cst.Expr( cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y")), semicolon=cst.Semicolon( whitespace_before=cst.SimpleWhitespace(" "), whitespace_after=cst.SimpleWhitespace(" "), ), ), - "x + y ; ", - ), + "code": "x + y ; ", + "expected_position": CodeRange.create((1, 0), (1, 5)), + }, ) ) - def test_valid( - self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None - ) -> None: - self.validate_node(node, code, expected_position=position) + def test_valid(self, **kwargs: Any) -> None: + self.validate_node(**kwargs)