From 60763fc5dce8cdc00e2b56e55d48fc9dad43ea5a Mon Sep 17 00:00:00 2001 From: Jennifer Taylor Date: Tue, 23 Jul 2019 12:58:12 -0700 Subject: [PATCH] Implement an ensure_type helper We already have several places in unit tests that need this, as well as a spot in an existing codemod that could benefit from it. So, implement ensure_type which can be used to refine the type of a node to an exact node type. I purposefully left the node type as 'object' instead of libcst.CSTNode so that it could be used with RemovalSentinel or MaybeSentinel without needing to import these. --- libcst/helpers.py | 13 +++++++++++++ libcst/nodes/tests/test_assert.py | 19 +++++++++++-------- libcst/nodes/tests/test_global.py | 6 ++++-- libcst/nodes/tests/test_import.py | 11 +++++++---- libcst/nodes/tests/test_nonlocal.py | 6 ++++-- libcst/nodes/tests/test_raise.py | 6 ++++-- libcst/nodes/tests/test_yield.py | 11 +++++++++-- 7 files changed, 52 insertions(+), 20 deletions(-) diff --git a/libcst/helpers.py b/libcst/helpers.py index 6470e892..2d53ced5 100644 --- a/libcst/helpers.py +++ b/libcst/helpers.py @@ -4,6 +4,8 @@ # LICENSE file in the root directory of this source tree. # pyre-strict +from typing import Type, TypeVar + import libcst.nodes as libcst @@ -14,3 +16,14 @@ def get_fully_qualified_name(node: libcst.BaseExpression) -> str: return get_fully_qualified_name(node.value) + "." + node.attr.value else: raise Exception(f"Invalid node type {type(node)}!") + + +_CSTNodeT = TypeVar("_CSTNodeT", bound=libcst.CSTNode) + + +def ensure_type(node: object, nodetype: Type[_CSTNodeT]) -> _CSTNodeT: + if not isinstance(node, nodetype): + raise Exception( + f"Expected a {nodetype.__name__} bot got a {node.__class__.__name__}!" + ) + return node diff --git a/libcst/nodes/tests/test_assert.py b/libcst/nodes/tests/test_assert.py index aba845e1..561e0326 100644 --- a/libcst/nodes/tests/test_assert.py +++ b/libcst/nodes/tests/test_assert.py @@ -8,6 +8,7 @@ from typing import Any import libcst.nodes as cst +from libcst.helpers import ensure_type from libcst.nodes.tests.base import CSTNodeTest from libcst.parser import parse_statement from libcst.testing.utils import data_provider @@ -87,6 +88,12 @@ class AssertConstructionTest(CSTNodeTest): self.assert_invalid(**kwargs) +def _assert_parser(code: str) -> cst.Assert: + return ensure_type( + ensure_type(parse_statement(code), cst.SimpleStatementLine).body[0], cst.Assert + ) + + class AssertParsingTest(CSTNodeTest): @data_provider( ( @@ -94,8 +101,7 @@ class AssertParsingTest(CSTNodeTest): { "node": cst.Assert(cst.Name("True")), "code": "assert True", - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - "parser": (lambda code: parse_statement(code).body[0]), + "parser": _assert_parser, "expected_position": None, }, # Assert with message @@ -106,8 +112,7 @@ class AssertParsingTest(CSTNodeTest): comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")), ), "code": 'assert True, "Value should be true"', - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - "parser": (lambda code: parse_statement(code).body[0]), + "parser": _assert_parser, "expected_position": None, }, # Whitespace oddities test @@ -117,8 +122,7 @@ class AssertParsingTest(CSTNodeTest): whitespace_after_assert=cst.SimpleWhitespace(""), ), "code": "assert(True)", - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - "parser": (lambda code: parse_statement(code).body[0]), + "parser": _assert_parser, "expected_position": None, }, # Whitespace rendering test @@ -133,8 +137,7 @@ class AssertParsingTest(CSTNodeTest): msg=cst.SimpleString('"Value should be true"'), ), "code": 'assert True , "Value should be true"', - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - "parser": (lambda code: parse_statement(code).body[0]), + "parser": _assert_parser, "expected_position": None, }, ) diff --git a/libcst/nodes/tests/test_global.py b/libcst/nodes/tests/test_global.py index 63664db3..1e9cd438 100644 --- a/libcst/nodes/tests/test_global.py +++ b/libcst/nodes/tests/test_global.py @@ -7,6 +7,7 @@ from typing import Callable, Optional import libcst.nodes as cst +from libcst.helpers import ensure_type from libcst.nodes._internal import CodeRange from libcst.nodes.tests.base import CSTNodeTest from libcst.parser import parse_statement @@ -131,7 +132,8 @@ class GlobalParsingTest(CSTNodeTest): self.validate_node( node, code, - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - lambda code: parse_statement(code).body[0], + lambda code: ensure_type( + parse_statement(code), cst.SimpleStatementLine + ).body[0], expected_position=position, ) diff --git a/libcst/nodes/tests/test_import.py b/libcst/nodes/tests/test_import.py index 4ae89272..11c3f9b0 100644 --- a/libcst/nodes/tests/test_import.py +++ b/libcst/nodes/tests/test_import.py @@ -7,6 +7,7 @@ from typing import Callable, Optional import libcst.nodes as cst +from libcst.helpers import ensure_type from libcst.nodes._internal import CodeRange from libcst.nodes.tests.base import CSTNodeTest from libcst.parser import parse_statement @@ -334,8 +335,9 @@ class ImportParseTest(CSTNodeTest): self.validate_node( node, code, - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - lambda code: parse_statement(code).body[0], + lambda code: ensure_type( + parse_statement(code), cst.SimpleStatementLine + ).body[0], expected_position=position, ) @@ -701,7 +703,8 @@ class ImportFromParseTest(CSTNodeTest): self.validate_node( node, code, - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - lambda code: parse_statement(code).body[0], + lambda code: ensure_type( + parse_statement(code), cst.SimpleStatementLine + ).body[0], expected_position=position, ) diff --git a/libcst/nodes/tests/test_nonlocal.py b/libcst/nodes/tests/test_nonlocal.py index 42ea1fea..f6fcc176 100644 --- a/libcst/nodes/tests/test_nonlocal.py +++ b/libcst/nodes/tests/test_nonlocal.py @@ -7,6 +7,7 @@ from typing import Callable, Optional import libcst.nodes as cst +from libcst.helpers import ensure_type from libcst.nodes._internal import CodeRange from libcst.nodes.tests.base import CSTNodeTest from libcst.parser import parse_statement @@ -133,7 +134,8 @@ class NonlocalParsingTest(CSTNodeTest): self.validate_node( node, code, - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - lambda code: parse_statement(code).body[0], + lambda code: ensure_type( + parse_statement(code), cst.SimpleStatementLine + ).body[0], expected_position=position, ) diff --git a/libcst/nodes/tests/test_raise.py b/libcst/nodes/tests/test_raise.py index 8a6ede16..12fbd99b 100644 --- a/libcst/nodes/tests/test_raise.py +++ b/libcst/nodes/tests/test_raise.py @@ -7,6 +7,7 @@ from typing import Callable, Optional import libcst.nodes as cst +from libcst.helpers import ensure_type from libcst.nodes._internal import CodeRange from libcst.nodes.tests.base import CSTNodeTest from libcst.parser import parse_statement @@ -195,7 +196,8 @@ class RaiseParsingTest(CSTNodeTest): self.validate_node( node, code, - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - lambda code: parse_statement(code).body[0], + lambda code: ensure_type( + parse_statement(code), cst.SimpleStatementLine + ).body[0], expected_position=position, ) diff --git a/libcst/nodes/tests/test_yield.py b/libcst/nodes/tests/test_yield.py index 1fc3d5d3..63c45ea8 100644 --- a/libcst/nodes/tests/test_yield.py +++ b/libcst/nodes/tests/test_yield.py @@ -7,6 +7,7 @@ from typing import Callable, Optional import libcst.nodes as cst +from libcst.helpers import ensure_type from libcst.nodes._internal import CodeRange from libcst.nodes.tests.base import CSTNodeTest from libcst.parser import parse_statement @@ -216,5 +217,11 @@ class YieldParsingTest(CSTNodeTest): def test_valid( self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None ) -> None: - # pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`. - self.validate_node(node, code, lambda code: parse_statement(code).body[0].value) + self.validate_node( + node, + code, + lambda code: ensure_type( + ensure_type(parse_statement(code), cst.SimpleStatementLine).body[0], + cst.Expr, + ).value, + )