Implement an ensure_type helper

We already have several places in unit tests that need this, as well
as a spot in an existing codemod that could benefit from it. So,
implement ensure_type which can be used to refine the type of a node to
an exact node type. I purposefully left the node type as 'object'
instead of libcst.CSTNode so that it could be used with RemovalSentinel
or MaybeSentinel without needing to import these.
This commit is contained in:
Jennifer Taylor 2019-07-23 12:58:12 -07:00
parent ffd9e4dd21
commit 60763fc5dc
7 changed files with 52 additions and 20 deletions

View file

@ -4,6 +4,8 @@
# LICENSE file in the root directory of this source tree.
# pyre-strict
from typing import Type, TypeVar
import libcst.nodes as libcst
@ -14,3 +16,14 @@ def get_fully_qualified_name(node: libcst.BaseExpression) -> str:
return get_fully_qualified_name(node.value) + "." + node.attr.value
else:
raise Exception(f"Invalid node type {type(node)}!")
_CSTNodeT = TypeVar("_CSTNodeT", bound=libcst.CSTNode)
def ensure_type(node: object, nodetype: Type[_CSTNodeT]) -> _CSTNodeT:
if not isinstance(node, nodetype):
raise Exception(
f"Expected a {nodetype.__name__} bot got a {node.__class__.__name__}!"
)
return node

View file

@ -8,6 +8,7 @@
from typing import Any
import libcst.nodes as cst
from libcst.helpers import ensure_type
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_statement
from libcst.testing.utils import data_provider
@ -87,6 +88,12 @@ class AssertConstructionTest(CSTNodeTest):
self.assert_invalid(**kwargs)
def _assert_parser(code: str) -> cst.Assert:
return ensure_type(
ensure_type(parse_statement(code), cst.SimpleStatementLine).body[0], cst.Assert
)
class AssertParsingTest(CSTNodeTest):
@data_provider(
(
@ -94,8 +101,7 @@ class AssertParsingTest(CSTNodeTest):
{
"node": cst.Assert(cst.Name("True")),
"code": "assert True",
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
"parser": (lambda code: parse_statement(code).body[0]),
"parser": _assert_parser,
"expected_position": None,
},
# Assert with message
@ -106,8 +112,7 @@ class AssertParsingTest(CSTNodeTest):
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
),
"code": 'assert True, "Value should be true"',
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
"parser": (lambda code: parse_statement(code).body[0]),
"parser": _assert_parser,
"expected_position": None,
},
# Whitespace oddities test
@ -117,8 +122,7 @@ class AssertParsingTest(CSTNodeTest):
whitespace_after_assert=cst.SimpleWhitespace(""),
),
"code": "assert(True)",
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
"parser": (lambda code: parse_statement(code).body[0]),
"parser": _assert_parser,
"expected_position": None,
},
# Whitespace rendering test
@ -133,8 +137,7 @@ class AssertParsingTest(CSTNodeTest):
msg=cst.SimpleString('"Value should be true"'),
),
"code": 'assert True , "Value should be true"',
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
"parser": (lambda code: parse_statement(code).body[0]),
"parser": _assert_parser,
"expected_position": None,
},
)

View file

@ -7,6 +7,7 @@
from typing import Callable, Optional
import libcst.nodes as cst
from libcst.helpers import ensure_type
from libcst.nodes._internal import CodeRange
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_statement
@ -131,7 +132,8 @@ class GlobalParsingTest(CSTNodeTest):
self.validate_node(
node,
code,
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
lambda code: parse_statement(code).body[0],
lambda code: ensure_type(
parse_statement(code), cst.SimpleStatementLine
).body[0],
expected_position=position,
)

View file

@ -7,6 +7,7 @@
from typing import Callable, Optional
import libcst.nodes as cst
from libcst.helpers import ensure_type
from libcst.nodes._internal import CodeRange
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_statement
@ -334,8 +335,9 @@ class ImportParseTest(CSTNodeTest):
self.validate_node(
node,
code,
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
lambda code: parse_statement(code).body[0],
lambda code: ensure_type(
parse_statement(code), cst.SimpleStatementLine
).body[0],
expected_position=position,
)
@ -701,7 +703,8 @@ class ImportFromParseTest(CSTNodeTest):
self.validate_node(
node,
code,
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
lambda code: parse_statement(code).body[0],
lambda code: ensure_type(
parse_statement(code), cst.SimpleStatementLine
).body[0],
expected_position=position,
)

View file

@ -7,6 +7,7 @@
from typing import Callable, Optional
import libcst.nodes as cst
from libcst.helpers import ensure_type
from libcst.nodes._internal import CodeRange
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_statement
@ -133,7 +134,8 @@ class NonlocalParsingTest(CSTNodeTest):
self.validate_node(
node,
code,
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
lambda code: parse_statement(code).body[0],
lambda code: ensure_type(
parse_statement(code), cst.SimpleStatementLine
).body[0],
expected_position=position,
)

View file

@ -7,6 +7,7 @@
from typing import Callable, Optional
import libcst.nodes as cst
from libcst.helpers import ensure_type
from libcst.nodes._internal import CodeRange
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_statement
@ -195,7 +196,8 @@ class RaiseParsingTest(CSTNodeTest):
self.validate_node(
node,
code,
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
lambda code: parse_statement(code).body[0],
lambda code: ensure_type(
parse_statement(code), cst.SimpleStatementLine
).body[0],
expected_position=position,
)

View file

@ -7,6 +7,7 @@
from typing import Callable, Optional
import libcst.nodes as cst
from libcst.helpers import ensure_type
from libcst.nodes._internal import CodeRange
from libcst.nodes.tests.base import CSTNodeTest
from libcst.parser import parse_statement
@ -216,5 +217,11 @@ class YieldParsingTest(CSTNodeTest):
def test_valid(
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
) -> None:
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
self.validate_node(node, code, lambda code: parse_statement(code).body[0].value)
self.validate_node(
node,
code,
lambda code: ensure_type(
ensure_type(parse_statement(code), cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
)