Calculate syntactic position for statement nodes (2)

Calculates positions for try, except, finally and import statements
This commit is contained in:
Ray Zeng 2019-07-25 20:40:41 -07:00 committed by Benjamin Woodruff
parent 89fb7fe524
commit 093acc994b
3 changed files with 406 additions and 380 deletions

View file

@ -696,17 +696,19 @@ class ExceptHandler(CSTNode):
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
state.add_token("except")
self.whitespace_after_except._codegen(state)
typenode = self.type
if typenode is not None:
typenode._codegen(state)
namenode = self.name
if namenode is not None:
namenode._codegen(state)
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
with state.record_syntactic_position(self, end_node=self.body):
state.add_token("except")
self.whitespace_after_except._codegen(state)
typenode = self.type
if typenode is not None:
typenode._codegen(state)
namenode = self.name
if namenode is not None:
namenode._codegen(state)
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
@add_slots
@ -733,10 +735,12 @@ class Finally(CSTNode):
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
state.add_token("finally")
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
with state.record_syntactic_position(self, end_node=self.body):
state.add_token("finally")
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
@add_slots
@ -789,18 +793,25 @@ class Try(BaseCompoundStatement):
for ll in self.leading_lines:
ll._codegen(state)
state.add_indent_tokens()
state.add_token("try")
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
for handler in self.handlers:
handler._codegen(state)
end_node = self.body
if len(self.handlers) > 0:
end_node = self.handlers[-1]
orelse = self.orelse
if orelse is not None:
orelse._codegen(state)
end_node = end_node if orelse is None else orelse
finalbody = self.finalbody
if finalbody is not None:
finalbody._codegen(state)
end_node = end_node if finalbody is None else finalbody
with state.record_syntactic_position(self, end_node=end_node):
state.add_token("try")
self.whitespace_before_colon._codegen(state)
state.add_token(":")
self.body._codegen(state)
for handler in self.handlers:
handler._codegen(state)
if orelse is not None:
orelse._codegen(state)
if finalbody is not None:
finalbody._codegen(state)
@dataclass(frozen=True)
@ -834,10 +845,12 @@ class ImportAlias(CSTNode):
)
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
self.name._codegen(state)
asname = self.asname
if asname is not None:
asname._codegen(state)
with state.record_syntactic_position(self):
self.name._codegen(state)
asname = self.asname
if asname is not None:
asname._codegen(state)
comma = self.comma
if comma is MaybeSentinel.DEFAULT and default_comma:
state.add_token(", ")
@ -884,11 +897,13 @@ class Import(BaseSmallStatement):
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("import")
self.whitespace_after_import._codegen(state)
lastname = len(self.names) - 1
for i, name in enumerate(self.names):
name._codegen(state, default_comma=(i != lastname))
with state.record_syntactic_position(self):
state.add_token("import")
self.whitespace_after_import._codegen(state)
lastname = len(self.names) - 1
for i, name in enumerate(self.names):
name._codegen(state, default_comma=(i != lastname))
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
if default_semicolon:
@ -991,28 +1006,32 @@ class ImportFrom(BaseSmallStatement):
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
state.add_token("from")
self.whitespace_after_from._codegen(state)
for dot in self.relative:
dot._codegen(state)
module = self.module
if module is not None:
module._codegen(state)
self.whitespace_before_import._codegen(state)
state.add_token("import")
self.whitespace_after_import._codegen(state)
lpar = self.lpar
if lpar is not None:
lpar._codegen(state)
if isinstance(self.names, Sequence):
lastname = len(self.names) - 1
for i, name in enumerate(self.names):
name._codegen(state, default_comma=(i != lastname))
if isinstance(self.names, ImportStar):
self.names._codegen(state)
rpar = self.rpar
if rpar is not None:
rpar._codegen(state)
end_node = self.names[-1] if isinstance(self.names, Sequence) else self.names
end_node = end_node if self.rpar is None else self.rpar
with state.record_syntactic_position(self, end_node=end_node):
state.add_token("from")
self.whitespace_after_from._codegen(state)
for dot in self.relative:
dot._codegen(state)
module = self.module
if module is not None:
module._codegen(state)
self.whitespace_before_import._codegen(state)
state.add_token("import")
self.whitespace_after_import._codegen(state)
lpar = self.lpar
if lpar is not None:
lpar._codegen(state)
if isinstance(self.names, Sequence):
lastname = len(self.names) - 1
for i, name in enumerate(self.names):
name._codegen(state, default_comma=(i != lastname))
if isinstance(self.names, ImportStar):
self.names._codegen(state)
rpar = self.rpar
if rpar is not None:
rpar._codegen(state)
semicolon = self.semicolon
if isinstance(semicolon, MaybeSentinel):
if default_semicolon:

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable, Optional
from typing import Any
import libcst as cst
from libcst import parse_statement
@ -18,30 +18,34 @@ class ImportCreateTest(CSTNodeTest):
@data_provider(
(
# Simple import statement
(cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)), "import foo"),
(
cst.Import(
# pyre-fixme[6]: Incompatible parameter type
{
"node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
"code": "import foo",
},
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar"))
),
)
),
"import foo.bar",
),
(
cst.Import(
"code": "import foo.bar",
},
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar"))
),
)
),
"import foo.bar",
),
"code": "import foo.bar",
},
# Comma-separated list of imports
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar"))
@ -51,11 +55,12 @@ class ImportCreateTest(CSTNodeTest):
),
)
),
"import foo.bar, foo.baz",
),
"code": "import foo.bar, foo.baz",
"expected_position": CodeRange.create((1, 0), (1, 23)),
},
# Import with an alias
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar")),
@ -63,11 +68,11 @@ class ImportCreateTest(CSTNodeTest):
),
)
),
"import foo.bar as baz",
),
"code": "import foo.bar as baz",
},
# Import with an alias, comma separated
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar")),
@ -79,11 +84,11 @@ class ImportCreateTest(CSTNodeTest):
),
)
),
"import foo.bar as baz, foo.baz as bar",
),
"code": "import foo.bar as baz, foo.baz as bar",
},
# Combine for fun and profit
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar")),
@ -100,11 +105,11 @@ class ImportCreateTest(CSTNodeTest):
),
)
),
"import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
),
"code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
},
# Verify whitespace works everywhere.
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(
@ -136,40 +141,42 @@ class ImportCreateTest(CSTNodeTest):
),
whitespace_after_import=cst.SimpleWhitespace(" "),
),
"import foo . bar as baz , unittest as ut",
),
"code": "import foo . bar as baz , unittest as ut",
"expected_position": CodeRange.create((1, 0), (1, 46)),
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
self.validate_node(node, code, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
(
(lambda: cst.Import(names=()), "at least one ImportAlias"),
(
lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)),
"empty name identifier",
),
(
lambda: cst.Import(
{
"get_node": lambda: cst.Import(names=()),
"expected_re": "at least one ImportAlias",
},
{
"get_node": lambda: cst.Import(names=(cst.ImportAlias(cst.Name("")),)),
"expected_re": "empty name identifier",
},
{
"get_node": lambda: cst.Import(
names=(
cst.ImportAlias(cst.Attribute(cst.Name(""), cst.Name("bla"))),
)
),
"empty name identifier",
),
(
lambda: cst.Import(
"expected_re": "empty name identifier",
},
{
"get_node": lambda: cst.Import(
names=(
cst.ImportAlias(cst.Attribute(cst.Name("bla"), cst.Name(""))),
)
),
"empty name identifier",
),
(
lambda: cst.Import(
"expected_re": "empty name identifier",
},
{
"get_node": lambda: cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar")),
@ -177,10 +184,10 @@ class ImportCreateTest(CSTNodeTest):
),
)
),
"trailing comma",
),
(
lambda: cst.Import(
"expected_re": "trailing comma",
},
{
"get_node": lambda: cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar"))
@ -188,44 +195,45 @@ class ImportCreateTest(CSTNodeTest):
),
whitespace_after_import=cst.SimpleWhitespace(""),
),
"at least one space",
),
"expected_re": "at least one space",
},
)
)
def test_invalid(
self, get_node: Callable[[], cst.CSTNode], expected_re: str
) -> None:
self.assert_invalid(get_node, expected_re)
def test_invalid(self, **kwargs: Any) -> None:
self.assert_invalid(**kwargs)
class ImportParseTest(CSTNodeTest):
@data_provider(
(
# Simple import statement
(cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)), "import foo"),
(
cst.Import(
{
"node": cst.Import(names=(cst.ImportAlias(cst.Name("foo")),)),
"code": "import foo",
},
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar"))
),
)
),
"import foo.bar",
),
(
cst.Import(
"code": "import foo.bar",
},
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar"))
),
)
),
"import foo.bar",
),
"code": "import foo.bar",
},
# Comma-separated list of imports
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar")),
@ -236,11 +244,11 @@ class ImportParseTest(CSTNodeTest):
),
)
),
"import foo.bar, foo.baz",
),
"code": "import foo.bar, foo.baz",
},
# Import with an alias
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar")),
@ -248,11 +256,11 @@ class ImportParseTest(CSTNodeTest):
),
)
),
"import foo.bar as baz",
),
"code": "import foo.bar as baz",
},
# Import with an alias, comma separated
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar")),
@ -265,11 +273,11 @@ class ImportParseTest(CSTNodeTest):
),
)
),
"import foo.bar as baz, foo.baz as bar",
),
"code": "import foo.bar as baz, foo.baz as bar",
},
# Combine for fun and profit
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(cst.Name("foo"), cst.Name("bar")),
@ -289,11 +297,11 @@ class ImportParseTest(CSTNodeTest):
),
)
),
"import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
),
"code": "import foo.bar as baz, insta.gram, foo.baz, unittest as ut",
},
# Verify whitespace works everywhere.
(
cst.Import(
{
"node": cst.Import(
names=(
cst.ImportAlias(
cst.Attribute(
@ -325,20 +333,16 @@ class ImportParseTest(CSTNodeTest):
),
whitespace_after_import=cst.SimpleWhitespace(" "),
),
"import foo . bar as baz , unittest as ut",
),
"code": "import foo . bar as baz , unittest as ut",
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(
node,
code,
lambda code: ensure_type(
parser=lambda code: ensure_type(
parse_statement(code), cst.SimpleStatementLine
).body[0],
expected_position=position,
**kwargs,
)
@ -346,15 +350,16 @@ class ImportFromCreateTest(CSTNodeTest):
@data_provider(
(
# Simple from import statement
(
cst.ImportFrom(
# pyre-fixme[6]: Incompatible parameter type
{
"node": cst.ImportFrom(
module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),)
),
"from foo import bar",
),
"code": "from foo import bar",
},
# From import statement with alias
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
names=(
cst.ImportAlias(
@ -362,64 +367,66 @@ class ImportFromCreateTest(CSTNodeTest):
),
),
),
"from foo import bar as baz",
),
"code": "from foo import bar as baz",
},
# Multiple imports
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
names=(
cst.ImportAlias(cst.Name("bar")),
cst.ImportAlias(cst.Name("baz")),
),
),
"from foo import bar, baz",
),
"code": "from foo import bar, baz",
},
# Trailing comma
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
names=(
cst.ImportAlias(cst.Name("bar"), comma=cst.Comma()),
cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()),
),
),
"from foo import bar,baz,",
),
"code": "from foo import bar,baz,",
"expected_position": CodeRange.create((1, 0), (1, 23)),
},
# Star import statement
(
cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
"from foo import *",
),
{
"node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
"code": "from foo import *",
"expected_position": CodeRange.create((1, 0), (1, 17)),
},
# Simple relative import statement
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
relative=(cst.Dot(),),
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
),
"from .foo import bar",
),
(
cst.ImportFrom(
"code": "from .foo import bar",
},
{
"node": cst.ImportFrom(
relative=(cst.Dot(), cst.Dot()),
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
),
"from ..foo import bar",
),
"code": "from ..foo import bar",
},
# Relative only import
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
relative=(cst.Dot(), cst.Dot()),
module=None,
names=(cst.ImportAlias(cst.Name("bar")),),
),
"from .. import bar",
),
"code": "from .. import bar",
},
# Parenthesis
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
lpar=cst.LeftParen(),
names=(
@ -429,11 +436,12 @@ class ImportFromCreateTest(CSTNodeTest):
),
rpar=cst.RightParen(),
),
"from foo import (bar as baz)",
),
"code": "from foo import (bar as baz)",
"expected_position": CodeRange.create((1, 0), (1, 28)),
},
# Verify whitespace works everywhere.
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
relative=(
cst.Dot(
whitespace_before=cst.SimpleWhitespace(" "),
@ -473,102 +481,99 @@ class ImportFromCreateTest(CSTNodeTest):
whitespace_before_import=cst.SimpleWhitespace(" "),
whitespace_after_import=cst.SimpleWhitespace(" "),
),
"from . . foo import ( bar as baz , unittest as ut )",
),
"code": "from . . foo import ( bar as baz , unittest as ut )",
"expected_position": CodeRange.create((1, 0), (1, 61)),
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
self.validate_node(node, code, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
(
(
lambda: cst.ImportFrom(
{
"get_node": lambda: cst.ImportFrom(
module=None, names=(cst.ImportAlias(cst.Name("bar")),)
),
"Must have a module specified",
),
(
lambda: cst.ImportFrom(module=cst.Name("foo"), names=()),
"at least one ImportAlias",
),
(
lambda: cst.ImportFrom(
"expected_re": "Must have a module specified",
},
{
"get_node": lambda: cst.ImportFrom(module=cst.Name("foo"), names=()),
"expected_re": "at least one ImportAlias",
},
{
"get_node": lambda: cst.ImportFrom(
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
lpar=cst.LeftParen(),
),
"left paren without right paren",
),
(
lambda: cst.ImportFrom(
"expected_re": "left paren without right paren",
},
{
"get_node": lambda: cst.ImportFrom(
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
rpar=cst.RightParen(),
),
"right paren without left paren",
),
(
lambda: cst.ImportFrom(
"expected_re": "right paren without left paren",
},
{
"get_node": lambda: cst.ImportFrom(
module=cst.Name("foo"), names=cst.ImportStar(), lpar=cst.LeftParen()
),
"cannot have parens",
),
(
lambda: cst.ImportFrom(
"expected_re": "cannot have parens",
},
{
"get_node": lambda: cst.ImportFrom(
module=cst.Name("foo"),
names=cst.ImportStar(),
rpar=cst.RightParen(),
),
"cannot have parens",
),
(
lambda: cst.ImportFrom(
"expected_re": "cannot have parens",
},
{
"get_node": lambda: cst.ImportFrom(
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
whitespace_after_from=cst.SimpleWhitespace(""),
),
"one space after from",
),
(
lambda: cst.ImportFrom(
"expected_re": "one space after from",
},
{
"get_node": lambda: cst.ImportFrom(
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
whitespace_before_import=cst.SimpleWhitespace(""),
),
"one space before import",
),
(
lambda: cst.ImportFrom(
"expected_re": "one space before import",
},
{
"get_node": lambda: cst.ImportFrom(
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
whitespace_after_import=cst.SimpleWhitespace(""),
),
"one space after import",
),
"expected_re": "one space after import",
},
)
)
def test_invalid(
self, get_node: Callable[[], cst.CSTNode], expected_re: str
) -> None:
self.assert_invalid(get_node, expected_re)
def test_invalid(self, **kwargs: Any) -> None:
self.assert_invalid(**kwargs)
class ImportFromParseTest(CSTNodeTest):
@data_provider(
(
# Simple from import statement
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"), names=(cst.ImportAlias(cst.Name("bar")),)
),
"from foo import bar",
),
"code": "from foo import bar",
},
# From import statement with alias
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
names=(
cst.ImportAlias(
@ -576,11 +581,11 @@ class ImportFromParseTest(CSTNodeTest):
),
),
),
"from foo import bar as baz",
),
"code": "from foo import bar as baz",
},
# Multiple imports
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
names=(
cst.ImportAlias(
@ -590,11 +595,11 @@ class ImportFromParseTest(CSTNodeTest):
cst.ImportAlias(cst.Name("baz")),
),
),
"from foo import bar, baz",
),
"code": "from foo import bar, baz",
},
# Trailing comma
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
names=(
cst.ImportAlias(
@ -604,42 +609,42 @@ class ImportFromParseTest(CSTNodeTest):
cst.ImportAlias(cst.Name("baz"), comma=cst.Comma()),
),
),
"from foo import bar, baz,",
),
"code": "from foo import bar, baz,",
},
# Star import statement
(
cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
"from foo import *",
),
{
"node": cst.ImportFrom(module=cst.Name("foo"), names=cst.ImportStar()),
"code": "from foo import *",
},
# Simple relative import statement
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
relative=(cst.Dot(),),
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
),
"from .foo import bar",
),
(
cst.ImportFrom(
"code": "from .foo import bar",
},
{
"node": cst.ImportFrom(
relative=(cst.Dot(), cst.Dot()),
module=cst.Name("foo"),
names=(cst.ImportAlias(cst.Name("bar")),),
),
"from ..foo import bar",
),
"code": "from ..foo import bar",
},
# Relative only import
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
relative=(cst.Dot(), cst.Dot()),
module=None,
names=(cst.ImportAlias(cst.Name("bar")),),
),
"from .. import bar",
),
"code": "from .. import bar",
},
# Parenthesis
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
module=cst.Name("foo"),
lpar=cst.LeftParen(),
names=(
@ -649,11 +654,11 @@ class ImportFromParseTest(CSTNodeTest):
),
rpar=cst.RightParen(),
),
"from foo import (bar as baz)",
),
"code": "from foo import (bar as baz)",
},
# Verify whitespace works everywhere.
(
cst.ImportFrom(
{
"node": cst.ImportFrom(
relative=(
cst.Dot(
whitespace_before=cst.SimpleWhitespace(""),
@ -693,18 +698,14 @@ class ImportFromParseTest(CSTNodeTest):
whitespace_before_import=cst.SimpleWhitespace(" "),
whitespace_after_import=cst.SimpleWhitespace(" "),
),
"from . . foo import ( bar as baz , unittest as ut )",
),
"code": "from . . foo import ( bar as baz , unittest as ut )",
},
)
)
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(
node,
code,
lambda code: ensure_type(
parser=lambda code: ensure_type(
parse_statement(code), cst.SimpleStatementLine
).body[0],
expected_position=position,
**kwargs,
)

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Callable, Optional
from typing import Any
import libcst as cst
from libcst import parse_statement
@ -17,8 +17,9 @@ class TryTest(CSTNodeTest):
@data_provider(
(
# Simple try/except block
(
cst.Try(
# pyre-fixme[6]: Incompatible parameter type
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
cst.ExceptHandler(
@ -27,12 +28,13 @@ class TryTest(CSTNodeTest):
),
),
),
"try: pass\nexcept: pass\n",
parse_statement,
),
"code": "try: pass\nexcept: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (2, 12)),
},
# Try/except with a class
(
cst.Try(
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
cst.ExceptHandler(
@ -41,12 +43,12 @@ class TryTest(CSTNodeTest):
),
),
),
"try: pass\nexcept Exception: pass\n",
parse_statement,
),
"code": "try: pass\nexcept Exception: pass\n",
"parser": parse_statement,
},
# Try/except with a named class
(
cst.Try(
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
cst.ExceptHandler(
@ -56,12 +58,13 @@ class TryTest(CSTNodeTest):
),
),
),
"try: pass\nexcept Exception as exc: pass\n",
parse_statement,
),
"code": "try: pass\nexcept Exception as exc: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (2, 29)),
},
# Try/except with multiple clauses
(
cst.Try(
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
cst.ExceptHandler(
@ -80,24 +83,26 @@ class TryTest(CSTNodeTest):
),
),
),
"try: pass\n"
"code": "try: pass\n"
+ "except TypeError as e: pass\n"
+ "except KeyError as e: pass\n"
+ "except: pass\n",
parse_statement,
),
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (4, 12)),
},
# Simple try/finally block
(
cst.Try(
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
),
"try: pass\nfinally: pass\n",
parse_statement,
),
"code": "try: pass\nfinally: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (2, 13)),
},
# Simple try/except/finally block
(
cst.Try(
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
cst.ExceptHandler(
@ -107,12 +112,13 @@ class TryTest(CSTNodeTest):
),
finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
),
"try: pass\nexcept: pass\nfinally: pass\n",
parse_statement,
),
"code": "try: pass\nexcept: pass\nfinally: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (3, 13)),
},
# Simple try/except/else block
(
cst.Try(
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
cst.ExceptHandler(
@ -122,12 +128,13 @@ class TryTest(CSTNodeTest):
),
orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
),
"try: pass\nexcept: pass\nelse: pass\n",
parse_statement,
),
"code": "try: pass\nexcept: pass\nelse: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (3, 10)),
},
# Simple try/except/else block/finally
(
cst.Try(
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
cst.ExceptHandler(
@ -138,12 +145,13 @@ class TryTest(CSTNodeTest):
orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
),
"try: pass\nexcept: pass\nelse: pass\nfinally: pass\n",
parse_statement,
),
"code": "try: pass\nexcept: pass\nelse: pass\nfinally: pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (4, 13)),
},
# Verify whitespace in various locations
(
cst.Try(
{
"node": cst.Try(
leading_lines=(cst.EmptyLine(comment=cst.Comment("# 1")),),
body=cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
@ -172,12 +180,13 @@ class TryTest(CSTNodeTest):
),
whitespace_before_colon=cst.SimpleWhitespace(" "),
),
"# 1\ntry : pass\n# 2\nexcept TypeError as e : pass\n# 3\nelse : pass\n# 4\nfinally : pass\n",
parse_statement,
),
"code": "# 1\ntry : pass\n# 2\nexcept TypeError as e : pass\n# 3\nelse : pass\n# 4\nfinally : pass\n",
"parser": parse_statement,
"expected_position": CodeRange.create((2, 0), (8, 14)),
},
# Please don't write code like this
(
cst.Try(
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=(
cst.ExceptHandler(
@ -198,17 +207,18 @@ class TryTest(CSTNodeTest):
orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
),
"try: pass\n"
"code": "try: pass\n"
+ "except TypeError as e: pass\n"
+ "except KeyError as e: pass\n"
+ "except: pass\n"
+ "else: pass\n"
+ "finally: pass\n",
parse_statement,
),
"parser": parse_statement,
"expected_position": CodeRange.create((1, 0), (6, 13)),
},
# Verify indentation
(
DummyIndentedBlock(
{
"node": DummyIndentedBlock(
" ",
cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
@ -232,17 +242,17 @@ class TryTest(CSTNodeTest):
finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
),
),
" try: pass\n"
"code": " try: pass\n"
+ " except TypeError as e: pass\n"
+ " except KeyError as e: pass\n"
+ " except: pass\n"
+ " else: pass\n"
+ " finally: pass\n",
None,
),
"parser": None,
},
# Verify indentation in bodies
(
DummyIndentedBlock(
{
"node": DummyIndentedBlock(
" ",
cst.Try(
cst.IndentedBlock((cst.SimpleStatementLine((cst.Pass(),)),)),
@ -262,7 +272,7 @@ class TryTest(CSTNodeTest):
),
),
),
" try:\n"
"code": " try:\n"
+ " pass\n"
+ " except:\n"
+ " pass\n"
@ -270,64 +280,60 @@ class TryTest(CSTNodeTest):
+ " pass\n"
+ " finally:\n"
+ " pass\n",
None,
),
"parser": None,
},
)
)
def test_valid(
self,
node: cst.CSTNode,
code: str,
parser: Optional[Callable[[str], cst.CSTNode]],
position: Optional[CodeRange] = None,
) -> None:
self.validate_node(node, code, parser, expected_position=position)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
(
(lambda: cst.AsName(cst.Name("")), "empty name identifier"),
(
lambda: cst.AsName(
# pyre-fixme[6]: Incompatible parameter type
{
"get_node": lambda: cst.AsName(cst.Name("")),
"expected_re": "empty name identifier",
},
{
"get_node": lambda: cst.AsName(
cst.Name("bla"), whitespace_after_as=cst.SimpleWhitespace("")
),
"between 'as'",
),
(
lambda: cst.AsName(
"expected_re": "between 'as'",
},
{
"get_node": lambda: cst.AsName(
cst.Name("bla"), whitespace_before_as=cst.SimpleWhitespace("")
),
"before 'as'",
),
(
lambda: cst.ExceptHandler(
"expected_re": "before 'as'",
},
{
"get_node": lambda: cst.ExceptHandler(
cst.SimpleStatementSuite((cst.Pass(),)),
name=cst.AsName(cst.Name("bla")),
),
"name for an empty type",
),
(
lambda: cst.ExceptHandler(
"expected_re": "name for an empty type",
},
{
"get_node": lambda: cst.ExceptHandler(
cst.SimpleStatementSuite((cst.Pass(),)),
type=cst.Name("TypeError"),
whitespace_after_except=cst.SimpleWhitespace(""),
),
"at least one space after except",
),
(
lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))),
"at least one ExceptHandler or Finally",
),
(
lambda: cst.Try(
"expected_re": "at least one space after except",
},
{
"get_node": lambda: cst.Try(cst.SimpleStatementSuite((cst.Pass(),))),
"expected_re": "at least one ExceptHandler or Finally",
},
{
"get_node": lambda: cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
orelse=cst.Else(cst.SimpleStatementSuite((cst.Pass(),))),
finalbody=cst.Finally(cst.SimpleStatementSuite((cst.Pass(),))),
),
"at least one ExceptHandler in order to have an Else",
),
"expected_re": "at least one ExceptHandler in order to have an Else",
},
)
)
def test_invalid(
self, get_node: Callable[[], cst.CSTNode], expected_re: str
) -> None:
self.assert_invalid(get_node, expected_re)
def test_invalid(self, **kwargs: Any) -> None:
self.assert_invalid(**kwargs)