mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Implement an ensure_type helper
We already have several places in unit tests that need this, as well as a spot in an existing codemod that could benefit from it. So, implement ensure_type which can be used to refine the type of a node to an exact node type. I purposefully left the node type as 'object' instead of libcst.CSTNode so that it could be used with RemovalSentinel or MaybeSentinel without needing to import these.
This commit is contained in:
parent
ffd9e4dd21
commit
60763fc5dc
7 changed files with 52 additions and 20 deletions
|
|
@ -4,6 +4,8 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# pyre-strict
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import libcst.nodes as libcst
|
||||
|
||||
|
||||
|
|
@ -14,3 +16,14 @@ def get_fully_qualified_name(node: libcst.BaseExpression) -> str:
|
|||
return get_fully_qualified_name(node.value) + "." + node.attr.value
|
||||
else:
|
||||
raise Exception(f"Invalid node type {type(node)}!")
|
||||
|
||||
|
||||
_CSTNodeT = TypeVar("_CSTNodeT", bound=libcst.CSTNode)
|
||||
|
||||
|
||||
def ensure_type(node: object, nodetype: Type[_CSTNodeT]) -> _CSTNodeT:
|
||||
if not isinstance(node, nodetype):
|
||||
raise Exception(
|
||||
f"Expected a {nodetype.__name__} bot got a {node.__class__.__name__}!"
|
||||
)
|
||||
return node
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@
|
|||
from typing import Any
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.helpers import ensure_type
|
||||
from libcst.nodes.tests.base import CSTNodeTest
|
||||
from libcst.parser import parse_statement
|
||||
from libcst.testing.utils import data_provider
|
||||
|
|
@ -87,6 +88,12 @@ class AssertConstructionTest(CSTNodeTest):
|
|||
self.assert_invalid(**kwargs)
|
||||
|
||||
|
||||
def _assert_parser(code: str) -> cst.Assert:
|
||||
return ensure_type(
|
||||
ensure_type(parse_statement(code), cst.SimpleStatementLine).body[0], cst.Assert
|
||||
)
|
||||
|
||||
|
||||
class AssertParsingTest(CSTNodeTest):
|
||||
@data_provider(
|
||||
(
|
||||
|
|
@ -94,8 +101,7 @@ class AssertParsingTest(CSTNodeTest):
|
|||
{
|
||||
"node": cst.Assert(cst.Name("True")),
|
||||
"code": "assert True",
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
"parser": (lambda code: parse_statement(code).body[0]),
|
||||
"parser": _assert_parser,
|
||||
"expected_position": None,
|
||||
},
|
||||
# Assert with message
|
||||
|
|
@ -106,8 +112,7 @@ class AssertParsingTest(CSTNodeTest):
|
|||
comma=cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
|
||||
),
|
||||
"code": 'assert True, "Value should be true"',
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
"parser": (lambda code: parse_statement(code).body[0]),
|
||||
"parser": _assert_parser,
|
||||
"expected_position": None,
|
||||
},
|
||||
# Whitespace oddities test
|
||||
|
|
@ -117,8 +122,7 @@ class AssertParsingTest(CSTNodeTest):
|
|||
whitespace_after_assert=cst.SimpleWhitespace(""),
|
||||
),
|
||||
"code": "assert(True)",
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
"parser": (lambda code: parse_statement(code).body[0]),
|
||||
"parser": _assert_parser,
|
||||
"expected_position": None,
|
||||
},
|
||||
# Whitespace rendering test
|
||||
|
|
@ -133,8 +137,7 @@ class AssertParsingTest(CSTNodeTest):
|
|||
msg=cst.SimpleString('"Value should be true"'),
|
||||
),
|
||||
"code": 'assert True , "Value should be true"',
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
"parser": (lambda code: parse_statement(code).body[0]),
|
||||
"parser": _assert_parser,
|
||||
"expected_position": None,
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
from typing import Callable, Optional
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.helpers import ensure_type
|
||||
from libcst.nodes._internal import CodeRange
|
||||
from libcst.nodes.tests.base import CSTNodeTest
|
||||
from libcst.parser import parse_statement
|
||||
|
|
@ -131,7 +132,8 @@ class GlobalParsingTest(CSTNodeTest):
|
|||
self.validate_node(
|
||||
node,
|
||||
code,
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
lambda code: parse_statement(code).body[0],
|
||||
lambda code: ensure_type(
|
||||
parse_statement(code), cst.SimpleStatementLine
|
||||
).body[0],
|
||||
expected_position=position,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
from typing import Callable, Optional
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.helpers import ensure_type
|
||||
from libcst.nodes._internal import CodeRange
|
||||
from libcst.nodes.tests.base import CSTNodeTest
|
||||
from libcst.parser import parse_statement
|
||||
|
|
@ -334,8 +335,9 @@ class ImportParseTest(CSTNodeTest):
|
|||
self.validate_node(
|
||||
node,
|
||||
code,
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
lambda code: parse_statement(code).body[0],
|
||||
lambda code: ensure_type(
|
||||
parse_statement(code), cst.SimpleStatementLine
|
||||
).body[0],
|
||||
expected_position=position,
|
||||
)
|
||||
|
||||
|
|
@ -701,7 +703,8 @@ class ImportFromParseTest(CSTNodeTest):
|
|||
self.validate_node(
|
||||
node,
|
||||
code,
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
lambda code: parse_statement(code).body[0],
|
||||
lambda code: ensure_type(
|
||||
parse_statement(code), cst.SimpleStatementLine
|
||||
).body[0],
|
||||
expected_position=position,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
from typing import Callable, Optional
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.helpers import ensure_type
|
||||
from libcst.nodes._internal import CodeRange
|
||||
from libcst.nodes.tests.base import CSTNodeTest
|
||||
from libcst.parser import parse_statement
|
||||
|
|
@ -133,7 +134,8 @@ class NonlocalParsingTest(CSTNodeTest):
|
|||
self.validate_node(
|
||||
node,
|
||||
code,
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
lambda code: parse_statement(code).body[0],
|
||||
lambda code: ensure_type(
|
||||
parse_statement(code), cst.SimpleStatementLine
|
||||
).body[0],
|
||||
expected_position=position,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
from typing import Callable, Optional
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.helpers import ensure_type
|
||||
from libcst.nodes._internal import CodeRange
|
||||
from libcst.nodes.tests.base import CSTNodeTest
|
||||
from libcst.parser import parse_statement
|
||||
|
|
@ -195,7 +196,8 @@ class RaiseParsingTest(CSTNodeTest):
|
|||
self.validate_node(
|
||||
node,
|
||||
code,
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
lambda code: parse_statement(code).body[0],
|
||||
lambda code: ensure_type(
|
||||
parse_statement(code), cst.SimpleStatementLine
|
||||
).body[0],
|
||||
expected_position=position,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
from typing import Callable, Optional
|
||||
|
||||
import libcst.nodes as cst
|
||||
from libcst.helpers import ensure_type
|
||||
from libcst.nodes._internal import CodeRange
|
||||
from libcst.nodes.tests.base import CSTNodeTest
|
||||
from libcst.parser import parse_statement
|
||||
|
|
@ -216,5 +217,11 @@ class YieldParsingTest(CSTNodeTest):
|
|||
def test_valid(
|
||||
self, node: cst.CSTNode, code: str, position: Optional[CodeRange] = None
|
||||
) -> None:
|
||||
# pyre-fixme[16]: `BaseSuite` has no attribute `__getitem__`.
|
||||
self.validate_node(node, code, lambda code: parse_statement(code).body[0].value)
|
||||
self.validate_node(
|
||||
node,
|
||||
code,
|
||||
lambda code: ensure_type(
|
||||
ensure_type(parse_statement(code), cst.SimpleStatementLine).body[0],
|
||||
cst.Expr,
|
||||
).value,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue