Calculate syntactic position for statement nodes (1)

This commit covers most simple statements and if statements.
This commit is contained in:
Ray Zeng 2019-07-25 20:36:55 -07:00 committed by Benjamin Woodruff
parent 3fb60b9706
commit 89fb7fe524
9 changed files with 405 additions and 345 deletions

View file

@ -121,9 +121,11 @@ class Del(BaseSmallStatement):
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("del")
self.whitespace_after_del._codegen(state)
self.target._codegen(state)
with state.record_syntactic_position(self):
state.add_token("del")
self.whitespace_after_del._codegen(state)
self.target._codegen(state)
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
if default_semicolon:
@ -145,7 +147,9 @@ class Pass(BaseSmallStatement):
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("pass")
with state.record_syntactic_position(self):
state.add_token("pass")
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
if default_semicolon:
@ -167,7 +171,9 @@ class Break(BaseSmallStatement):
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("break")
with state.record_syntactic_position(self):
state.add_token("break")
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
if default_semicolon:
@ -189,7 +195,9 @@ class Continue(BaseSmallStatement):
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("continue")
with state.record_syntactic_position(self):
state.add_token("continue")
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
if default_semicolon:
@ -235,19 +243,17 @@ class Return(BaseSmallStatement):
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
value = self.value
state.add_token("return")
whitespace_after_return = self.whitespace_after_return
if isinstance(whitespace_after_return, MaybeSentinel):
with state.record_syntactic_position(self):
state.add_token("return")
whitespace_after_return = self.whitespace_after_return
value = self.value
if isinstance(whitespace_after_return, MaybeSentinel):
if value is not None:
state.add_token(" ")
else:
whitespace_after_return._codegen(state)
if value is not None:
state.add_token(" ")
else:
whitespace_after_return._codegen(state)
if value is not None:
value._codegen(state)
value._codegen(state)
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
@ -279,7 +285,9 @@ class Expr(BaseSmallStatement):
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
self.value._codegen(state)
with state.record_syntactic_position(self):
self.value._codegen(state)
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
if default_semicolon:
@ -316,12 +324,15 @@ class _BaseSimpleStatement(CSTNode, ABC):
body = self.body
if body:
laststmt = len(body) - 1
for idx, stmt in enumerate(body):
stmt._codegen(state, default_semicolon=(idx != laststmt))
with state.record_syntactic_position(self, end_node=body[laststmt]):
for idx, stmt in enumerate(body):
stmt._codegen(state, default_semicolon=(idx != laststmt))
else:
# Empty simple statement blocks are not syntactically valid in Python
# unless they contain a 'pass' statement, so add one here.
state.add_token("pass")
with state.record_syntactic_position(self):
state.add_token("pass")
self.trailing_whitespace._codegen(state)
@ -428,10 +439,12 @@ class Else(CSTNode):
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
state.add_token("else")
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
with state.record_syntactic_position(self, end_node=self.body):
state.add_token("else")
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
class BaseCompoundStatement(CSTNode, ABC):
@ -491,18 +504,21 @@ class If(BaseCompoundStatement):
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
state.add_token("elif" if is_elif else "if")
self.whitespace_before_test._codegen(state)
self.test._codegen(state)
self.whitespace_after_test._codegen(state)
state.add_token(":")
self.body._codegen(state)
orelse = self.orelse
if orelse is not None:
if isinstance(orelse, If): # special-case elif
orelse._codegen(state, is_elif=True)
else: # is an Else clause
orelse._codegen(state)
end_node = self.body if self.orelse is None else self.orelse
with state.record_syntactic_position(self, end_node=end_node):
state.add_token("elif" if is_elif else "if")
self.whitespace_before_test._codegen(state)
self.test._codegen(state)
self.whitespace_after_test._codegen(state)
state.add_token(":")
self.body._codegen(state)
orelse = self.orelse
if orelse is not None:
if isinstance(orelse, If): # special-case elif
orelse._codegen(state, is_elif=True)
else: # is an Else clause
orelse._codegen(state)
@add_slots
@ -566,16 +582,20 @@ class IndentedBlock(BaseSuite):
state.increase_indent(state.default_indent if indent is None else indent)
if self.body:
for stmt in self.body:
# IndentedBlock is responsible for adjusting the current indentation level,
# but its children are responsible for actually adding that indentation to
# the token list.
stmt._codegen(state)
with state.record_syntactic_position(
self, start_node=self.body[0], end_node=self.body[-1]
):
for stmt in self.body:
# IndentedBlock is responsible for adjusting the current indentation level,
# but its children are responsible for actually adding that indentation to
# the token list.
stmt._codegen(state)
else:
# Empty indented blocks are not syntactically valid in Python unless
# they contain a 'pass' statement, so add one here.
state.add_indent_tokens()
state.add_token("pass")
with state.record_syntactic_position(self):
state.add_token("pass")
state.add_token(state.default_newline)
for f in self.footer:

View file

@ -222,7 +222,10 @@ class DummyIndentedBlock(cst.CSTNode):
def _codegen_impl(self, state: CodegenState) -> None:
state.increase_indent(self.value)
self.child._codegen(state)
with state.record_syntactic_position(
self, start_node=self.child, end_node=self.child
):
self.child._codegen(state)
state.decrease_indent()
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "DummyIndentedBlock":

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable, Optional
from typing import Any
import libcst as cst
from libcst import parse_statement
@ -16,9 +16,14 @@ from libcst.testing.utils import data_provider
class DelTest(CSTNodeTest):
@data_provider(
(
(cst.SimpleStatementLine([cst.Del(cst.Name("abc"))]), "del abc\n"),
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine([cst.Del(cst.Name("abc"))]),
"code": "del abc\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 7)),
},
{
"node": cst.SimpleStatementLine(
[
cst.Del(
cst.Name("abc"),
@ -26,10 +31,12 @@ class DelTest(CSTNodeTest):
)
]
),
"del abc\n",
),
(
cst.SimpleStatementLine(
"code": "del abc\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 9)),
},
{
"node": cst.SimpleStatementLine(
[
cst.Del(
cst.Name(
@ -39,32 +46,32 @@ class DelTest(CSTNodeTest):
)
]
),
"del(abc)\n",
),
(
cst.SimpleStatementLine(
"code": "del(abc)\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 8)),
},
{
"node": cst.SimpleStatementLine(
[cst.Del(cst.Name("abc"), semicolon=cst.Semicolon())]
),
"del abc;\n",
),
"code": "del abc;\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 7)),
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
self.validate_node(node, code, parse_statement, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
(
(
lambda: cst.Del(
{
"get_node": lambda: cst.Del(
cst.Name("abc"), whitespace_after_del=cst.SimpleWhitespace("")
),
"Must have at least one space after 'del'.",
),
"expected_re": "Must have at least one space after 'del'.",
},
)
)
def test_invalid(
self, get_node: Callable[[], cst.CSTNode], expected_re: str
) -> None:
self.assert_invalid(get_node, expected_re)
def test_invalid(self, **kwargs: Any) -> None:
self.assert_invalid(**kwargs)

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Optional
from typing import Any
import libcst as cst
from libcst._nodes._internal import CodeRange
@ -15,17 +15,20 @@ from libcst.testing.utils import data_provider
class ElseTest(CSTNodeTest):
@data_provider(
(
(cst.Else(cst.SimpleStatementSuite((cst.Pass(),))), "else: pass\n"),
(
cst.Else(
{
"node": cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
"code": "else: pass\n",
"expected_position": CodeRange.create((1, 0), (1, 10)),
},
{
"node": cst.Else(
cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_before_colon=cst.SimpleWhitespace(" "),
),
"else : pass\n",
),
"code": "else : pass\n",
"expected_position": CodeRange.create((1, 0), (1, 12)),
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
self.validate_node(node, code, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable, Optional
from typing import Any
import libcst as cst
from libcst import parse_statement
@ -17,26 +17,29 @@ class IfTest(CSTNodeTest):
@data_provider(
(
# Simple if without elif or else
(
cst.If(
# pyre-fixme[6]: Incompatible parameter type
{
"node": cst.If(
cst.Name("conditional"), cst.SimpleStatementSuite((cst.Pass(),))
),
"if conditional: pass\n",
parse_statement,
),
"code": "if conditional: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 20)),
},
# else clause
(
cst.If(
{
"node": cst.If(
cst.Name("conditional"),
cst.SimpleStatementSuite((cst.Pass(),)),
orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
),
"if conditional: pass\nelse: pass\n",
parse_statement,
),
"code": "if conditional: pass\nelse: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (2, 10)),
},
# elif clause
(
cst.If(
{
"node": cst.If(
cst.Name("conditional"),
cst.SimpleStatementSuite((cst.Pass(),)),
orelse=cst.If(
@ -45,12 +48,13 @@ class IfTest(CSTNodeTest):
orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
),
),
"if conditional: pass\nelif other_conditional: pass\nelse: pass\n",
parse_statement,
),
"code": "if conditional: pass\nelif other_conditional: pass\nelse: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (3, 10)),
},
# indentation
(
DummyIndentedBlock(
{
"node": DummyIndentedBlock(
" ",
cst.If(
cst.Name("conditional"),
@ -58,36 +62,39 @@ class IfTest(CSTNodeTest):
orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
),
),
" if conditional: pass\n else: pass\n",
None,
),
"code": " if conditional: pass\n else: pass\n",
"parser": None,
"expected_position": CodeRange.create((1, 4), (2, 14)),
},
# with an indented body
(
DummyIndentedBlock(
{
"node": DummyIndentedBlock(
" ",
cst.If(
cst.Name("conditional"),
cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
),
),
" if conditional:\n pass\n",
None,
),
"code": " if conditional:\n pass\n",
"parser": None,
"expected_position": CodeRange.create((1, 4), (2, 12)),
},
# leading_lines
(
cst.If(
{
"node": cst.If(
cst.Name("conditional"),
cst.SimpleStatementSuite((cst.Pass(),)),
leading_lines=(
cst.EmptyLine(comment=cst.Comment("# leading comment")),
),
),
"# leading comment\nif conditional: pass\n",
parse_statement,
),
"code": "# leading comment\nif conditional: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((2, 0), (2, 20)),
},
# whitespace before/after test and else
(
cst.If(
{
"node": cst.If(
cst.Name("conditional"),
cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_before_test=cst.SimpleWhitespace(" "),
@ -97,12 +104,13 @@ class IfTest(CSTNodeTest):
whitespace_before_colon=cst.SimpleWhitespace(" "),
),
),
"if conditional : pass\nelse : pass\n",
parse_statement,
),
"code": "if conditional : pass\nelse : pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (2, 11)),
},
# empty lines between if/elif/else clauses, not captured by the suite.
(
cst.If(
{
"node": cst.If(
cst.Name("test_a"),
cst.SimpleStatementSuite((cst.Pass(),)),
orelse=cst.If(
@ -115,16 +123,11 @@ class IfTest(CSTNodeTest):
),
),
),
"if test_a: pass\n\nelif test_b: pass\n\nelse: pass\n",
parse_statement,
),
"code": "if test_a: pass\n\nelif test_b: pass\n\nelse: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (5, 10)),
},
)
)
def test_valid(
self,
node: cst.CSTNode,
code: str,
parser: Optional[Callable[[str], cst.CSTNode]],
position: Optional[CodeRange] = None,
) -> None:
self.validate_node(node, code, parser, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)

View file

@ -158,7 +158,7 @@ class ModuleTest(CSTNodeTest):
fn = cast(cst.FunctionDef, module.body[0])
stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
pass_stmt = cast(cst.Pass, stmt.body[0])
self.cmp_position(stmt, (2, 0), (3, 0))
self.cmp_position(stmt, (2, 4), (2, 8))
self.cmp_position(pass_stmt, (2, 4), (2, 8))
def test_nested_indent_position(self) -> None:
@ -174,10 +174,10 @@ class ModuleTest(CSTNodeTest):
outer_else = cast(cst.Else, outer_if.orelse)
return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0]
self.cmp_position(outer_if, (1, 0), (6, 0))
self.cmp_position(inner_if, (2, 0), (4, 0))
self.cmp_position(outer_if, (1, 0), (5, 10))
self.cmp_position(inner_if, (2, 4), (3, 13))
self.cmp_position(assign, (3, 8), (3, 13))
self.cmp_position(outer_else, (4, 0), (6, 0))
self.cmp_position(outer_else, (4, 0), (5, 10))
self.cmp_position(return_stmt, (5, 4), (5, 10))
def test_multiline_string_position(self) -> None:
@ -188,6 +188,6 @@ class ModuleTest(CSTNodeTest):
expr = cast(cst.Expr, stmt.body[0])
string = expr.value
self.cmp_position(stmt, (1, 0), (3, 0))
self.cmp_position(stmt, (1, 0), (2, 5))
self.cmp_position(expr, (1, 0), (2, 5))
self.cmp_position(string, (1, 0), (2, 5))

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable, Optional
from typing import Any
import libcst as cst
from libcst import parse_statement
@ -16,42 +16,47 @@ from libcst.testing.utils import data_provider
class ReturnCreateTest(CSTNodeTest):
@data_provider(
(
(cst.SimpleStatementLine([cst.Return()]), "return\n"),
(cst.SimpleStatementLine([cst.Return(cst.Name("abc"))]), "return abc\n"),
{
"node": cst.SimpleStatementLine([cst.Return()]),
"code": "return\n",
"expected_position": CodeRange.create((1, 0), (1, 6)),
},
{
"node": cst.SimpleStatementLine([cst.Return(cst.Name("abc"))]),
"code": "return abc\n",
"expected_position": CodeRange.create((1, 0), (1, 10)),
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
self.validate_node(node, code, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
(
(
lambda: cst.Return(
{
"get_node": lambda: cst.Return(
cst.Name("abc"), whitespace_after_return=cst.SimpleWhitespace("")
),
"Must have at least one space after 'return'.",
),
"expected_re": "Must have at least one space after 'return'.",
},
)
)
def test_invalid(
self, get_node: Callable[[], cst.CSTNode], expected_re: str
) -> None:
self.assert_invalid(get_node, expected_re)
def test_invalid(self, **kwargs: Any) -> None:
self.assert_invalid(**kwargs)
class ReturnParseTest(CSTNodeTest):
@data_provider(
(
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
[cst.Return(whitespace_after_return=cst.SimpleWhitespace(""))]
),
"return\n",
),
(
cst.SimpleStatementLine(
"code": "return\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
[
cst.Return(
cst.Name("abc"),
@ -59,10 +64,11 @@ class ReturnParseTest(CSTNodeTest):
)
]
),
"return abc\n",
),
(
cst.SimpleStatementLine(
"code": "return abc\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
[
cst.Return(
cst.Name("abc"),
@ -70,10 +76,11 @@ class ReturnParseTest(CSTNodeTest):
)
]
),
"return abc\n",
),
(
cst.SimpleStatementLine(
"code": "return abc\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
[
cst.Return(
cst.Name(
@ -83,10 +90,11 @@ class ReturnParseTest(CSTNodeTest):
)
]
),
"return(abc)\n",
),
(
cst.SimpleStatementLine(
"code": "return(abc)\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
[
cst.Return(
cst.Name("abc"),
@ -95,11 +103,10 @@ class ReturnParseTest(CSTNodeTest):
)
]
),
"return abc;\n",
),
"code": "return abc;\n",
"parser": parse_statement,
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
self.validate_node(node, code, parse_statement, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable, Optional
from typing import Any
import libcst as cst
from libcst import parse_statement
@ -17,18 +17,23 @@ class SimpleStatementTest(CSTNodeTest):
@data_provider(
(
# a single-element SimpleStatementLine
(cst.SimpleStatementLine((cst.Pass(),)), "pass\n", parse_statement),
# pyre-fixme[6]: Incompatible parameter type
{
"node": cst.SimpleStatementLine((cst.Pass(),)),
"code": "pass\n",
"parser": parse_statement,
},
# a multi-element SimpleStatementLine
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(cst.Pass(semicolon=cst.Semicolon()), cst.Continue())
),
"pass;continue\n",
parse_statement,
),
"code": "pass;continue\n",
"parser": parse_statement,
},
# a multi-element SimpleStatementLine with whitespace
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(
cst.Pass(
semicolon=cst.Semicolon(
@ -39,67 +44,70 @@ class SimpleStatementTest(CSTNodeTest):
cst.Continue(),
)
),
"pass ; continue\n",
parse_statement,
),
"code": "pass ; continue\n",
"parser": parse_statement,
},
# A more complicated SimpleStatementLine
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(
cst.Pass(semicolon=cst.Semicolon()),
cst.Continue(semicolon=cst.Semicolon()),
cst.Break(),
)
),
"pass;continue;break\n",
parse_statement,
),
"code": "pass;continue;break\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 19)),
},
# a multi-element SimpleStatementLine, inferred semicolons
(
cst.SimpleStatementLine((cst.Pass(), cst.Continue(), cst.Break())),
"pass; continue; break\n",
None, # No test for parsing, since we are using sentinels.
),
{
"node": cst.SimpleStatementLine(
(cst.Pass(), cst.Continue(), cst.Break())
),
"code": "pass; continue; break\n",
"parser": None, # No test for parsing, since we are using sentinels.
},
# some expression statements
(
cst.SimpleStatementLine((cst.Expr(cst.Name("None")),)),
"None\n",
parse_statement,
),
(
cst.SimpleStatementLine((cst.Expr(cst.Name("True")),)),
"True\n",
parse_statement,
),
(
cst.SimpleStatementLine((cst.Expr(cst.Name("False")),)),
"False\n",
parse_statement,
),
(
cst.SimpleStatementLine((cst.Expr(cst.Ellipses()),)),
"...\n",
parse_statement,
),
{
"node": cst.SimpleStatementLine((cst.Expr(cst.Name("None")),)),
"code": "None\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine((cst.Expr(cst.Name("True")),)),
"code": "True\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine((cst.Expr(cst.Name("False")),)),
"code": "False\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine((cst.Expr(cst.Ellipses()),)),
"code": "...\n",
"parser": parse_statement,
},
# Test some numbers
(
cst.SimpleStatementLine((cst.Expr(cst.Integer("5")),)),
"5\n",
parse_statement,
),
(
cst.SimpleStatementLine((cst.Expr(cst.Float("5.5")),)),
"5.5\n",
parse_statement,
),
(
cst.SimpleStatementLine((cst.Expr(cst.Imaginary("5j")),)),
"5j\n",
parse_statement,
),
{
"node": cst.SimpleStatementLine((cst.Expr(cst.Integer("5")),)),
"code": "5\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine((cst.Expr(cst.Float("5.5")),)),
"code": "5.5\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine((cst.Expr(cst.Imaginary("5j")),)),
"code": "5j\n",
"parser": parse_statement,
},
# Test some numbers with parens
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.Integer(
@ -108,11 +116,12 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
"(5)\n",
parse_statement,
),
(
cst.SimpleStatementLine(
"code": "(5)\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 3)),
},
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.Float(
@ -121,11 +130,11 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
"(5.5)\n",
parse_statement,
),
(
cst.SimpleStatementLine(
"code": "(5.5)\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.Imaginary(
@ -134,17 +143,17 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
"(5j)\n",
parse_statement,
),
"code": "(5j)\n",
"parser": parse_statement,
},
# Test some strings
(
cst.SimpleStatementLine((cst.Expr(cst.SimpleString('"abc"')),)),
'"abc"\n',
parse_statement,
),
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine((cst.Expr(cst.SimpleString('"abc"')),)),
"code": '"abc"\n',
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.ConcatenatedString(
@ -153,11 +162,11 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
'"abc""def"\n',
parse_statement,
),
(
cst.SimpleStatementLine(
"code": '"abc""def"\n',
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.ConcatenatedString(
@ -172,12 +181,13 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
'"abc" "def" "ghi"\n',
parse_statement,
),
"code": '"abc" "def" "ghi"\n',
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 17)),
},
# Test parenthesis rules
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.Ellipses(
@ -186,12 +196,12 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
"(...)\n",
parse_statement,
),
"code": "(...)\n",
"parser": parse_statement,
},
# Test parenthesis with whitespace ownership
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.Ellipses(
@ -209,11 +219,11 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
"( ... )\n",
parse_statement,
),
(
cst.SimpleStatementLine(
"code": "( ... )\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.Ellipses(
@ -243,12 +253,13 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
"( ( ( ... ) ) )\n",
parse_statement,
),
"code": "( ( ( ... ) ) )\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 21)),
},
# Test parenthesis rules with expressions
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(
cst.Expr(
cst.Ellipses(
@ -282,33 +293,36 @@ class SimpleStatementTest(CSTNodeTest):
),
)
),
"(\n# Wow, a comment!\n ...\n)\n",
parse_statement,
),
"code": "(\n# Wow, a comment!\n ...\n)\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (4, 1)),
},
# test trailing whitespace
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(cst.Pass(),),
trailing_whitespace=cst.TrailingWhitespace(
whitespace=cst.SimpleWhitespace(" "),
comment=cst.Comment("# trailing comment"),
),
),
"pass # trailing comment\n",
parse_statement,
),
"code": "pass # trailing comment\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (1, 4)),
},
# test leading comment
(
cst.SimpleStatementLine(
{
"node": cst.SimpleStatementLine(
(cst.Pass(),),
leading_lines=(cst.EmptyLine(comment=cst.Comment("# comment")),),
),
"# comment\npass\n",
parse_statement,
),
"code": "# comment\npass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((2, 0), (2, 4)),
},
# test indentation
(
DummyIndentedBlock(
{
"node": DummyIndentedBlock(
" ",
cst.SimpleStatementLine(
(cst.Pass(),),
@ -317,25 +331,23 @@ class SimpleStatementTest(CSTNodeTest):
),
),
),
" # comment\n pass\n",
None,
),
"code": " # comment\n pass\n",
"expected_position": CodeRange.create((2, 4), (2, 8)),
},
# test suite variant
(cst.SimpleStatementSuite((cst.Pass(),)), " pass\n", None),
(
cst.SimpleStatementSuite(
{
"node": cst.SimpleStatementSuite((cst.Pass(),)),
"code": " pass\n",
"expected_position": CodeRange.create((1, 1), (1, 5)),
},
{
"node": cst.SimpleStatementSuite(
(cst.Pass(),), leading_whitespace=cst.SimpleWhitespace("")
),
"pass\n",
None,
),
"code": "pass\n",
"expected_position": CodeRange.create((1, 0), (1, 4)),
},
)
)
def test_valid(
self,
node: cst.CSTNode,
code: str,
parser: Optional[Callable[[str], cst.CSTNode]],
position: Optional[CodeRange] = None,
) -> None:
self.validate_node(node, code, parser, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Optional
from typing import Any
import libcst as cst
from libcst._nodes._internal import CodeRange
@ -15,63 +15,68 @@ from libcst.testing.utils import data_provider
class SmallStatementTest(CSTNodeTest):
@data_provider(
(
(cst.Pass(), "pass"),
(cst.Pass(semicolon=cst.Semicolon()), "pass;"),
(
cst.Pass(
# pyre-fixme[6]: Incompatible parameter type
{"node": cst.Pass(), "code": "pass"},
{"node": cst.Pass(semicolon=cst.Semicolon()), "code": "pass;"},
{
"node": cst.Pass(
semicolon=cst.Semicolon(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
)
),
"pass ; ",
),
(cst.Continue(), "continue"),
(cst.Continue(semicolon=cst.Semicolon()), "continue;"),
(
cst.Continue(
"code": "pass ; ",
"expected_position": CodeRange.create((1, 0), (1, 4)),
},
{"node": cst.Continue(), "code": "continue"},
{"node": cst.Continue(semicolon=cst.Semicolon()), "code": "continue;"},
{
"node": cst.Continue(
semicolon=cst.Semicolon(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
)
),
"continue ; ",
),
(cst.Break(), "break"),
(cst.Break(semicolon=cst.Semicolon()), "break;"),
(
cst.Break(
"code": "continue ; ",
"expected_position": CodeRange.create((1, 0), (1, 8)),
},
{"node": cst.Break(), "code": "break"},
{"node": cst.Break(semicolon=cst.Semicolon()), "code": "break;"},
{
"node": cst.Break(
semicolon=cst.Semicolon(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
)
),
"break ; ",
),
(
cst.Expr(cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y"))),
"x + y",
),
(
cst.Expr(
"code": "break ; ",
"expected_position": CodeRange.create((1, 0), (1, 5)),
},
{
"node": cst.Expr(
cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y"))
),
"code": "x + y",
},
{
"node": cst.Expr(
cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y")),
semicolon=cst.Semicolon(),
),
"x + y;",
),
(
cst.Expr(
"code": "x + y;",
},
{
"node": cst.Expr(
cst.BinaryOperation(cst.Name("x"), cst.Add(), cst.Name("y")),
semicolon=cst.Semicolon(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
),
),
"x + y ; ",
),
"code": "x + y ; ",
"expected_position": CodeRange.create((1, 0), (1, 5)),
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
self.validate_node(node, code, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)