mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Refactors codegen to write position information to the `__metadata__` fields in nodes keyed by `BasicPositionProvider` and `SyntacticPositionProvider` as defined in position_metadata.py. This commit also updates `deep_equals` to ignore dataclass fields that are marked `compare=False` to avoid comparing metadata when doing equality checks.
190 lines
6.9 KiB
Python
190 lines
6.9 KiB
Python
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
# pyre-strict
|
|
from typing import Tuple, cast
|
|
|
|
import libcst.nodes as cst
|
|
from libcst.metadata.position_provider import SyntacticPositionProvider
|
|
from libcst.nodes._internal import CodePosition
|
|
from libcst.nodes.tests.base import CSTNodeTest
|
|
from libcst.parser import parse_module
|
|
from libcst.testing.utils import data_provider
|
|
|
|
|
|
class ModuleTest(CSTNodeTest):
    """Tests for ``cst.Module``: rendering, parsing, and position metadata."""

    @staticmethod
    def _parse_with_positions(code: str) -> cst.Module:
        """Parse ``code`` and render it once so position metadata is populated.

        Codegen is what writes position information into each node's
        ``__metadata__``, so ``Module.code`` must be evaluated before any
        position lookup is made.
        """
        module = parse_module(code)
        # Evaluated only for its side effect of running codegen; the rendered
        # string itself is not needed here.
        module.code
        return module

    @data_provider(
        (
            # simplest possible program
            (cst.Module((cst.SimpleStatementLine((cst.Pass(),)),)), "pass\n"),
            # test default_newline
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),), default_newline="\r"
                ),
                "pass\r",
            ),
            # test header/footer
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    header=(cst.EmptyLine(comment=cst.Comment("# header")),),
                    footer=(cst.EmptyLine(comment=cst.Comment("# footer")),),
                ),
                "# header\npass\n# footer\n",
            ),
            # test has_trailing_newline
            (
                cst.Module(
                    (cst.SimpleStatementLine((cst.Pass(),)),),
                    has_trailing_newline=False,
                ),
                "pass",
            ),
            # an empty file
            (cst.Module((), has_trailing_newline=False), ""),
            # a file with only comments
            (
                cst.Module(
                    (),
                    header=(
                        cst.EmptyLine(comment=cst.Comment("# nothing to see here")),
                    ),
                ),
                "# nothing to see here\n",
            ),
            # TODO: test default_indent
        )
    )
    def test_code_and_bytes_properties(self, module: cst.Module, expected: str) -> None:
        """``Module.code`` renders ``expected``; ``Module.bytes`` is its UTF-8 form."""
        self.assertEqual(module.code, expected)
        self.assertEqual(module.bytes, expected.encode("utf-8"))

    @data_provider(
        (
            (cst.Module(()), cst.Newline(), "\n"),
            (cst.Module((), default_newline="\r\n"), cst.Newline(), "\r\n"),
            # has_trailing_newline has no effect on code_for_node
            (cst.Module((), has_trailing_newline=False), cst.Newline(), "\n"),
            # TODO: test default_indent
        )
    )
    def test_code_for_node(
        self, module: cst.Module, node: cst.CSTNode, expected: str
    ) -> None:
        """``code_for_node`` renders ``node`` using the module's settings."""
        self.assertEqual(module.code_for_node(node), expected)

    @data_provider(
        {
            "empty_program": {
                "code": "",
                "expected": cst.Module([], has_trailing_newline=False),
            },
            "empty_program_with_newline": {
                "code": "\n",
                "expected": cst.Module([], has_trailing_newline=True),
            },
            "empty_program_with_comments": {
                "code": "# some comment\n",
                "expected": cst.Module(
                    [], header=[cst.EmptyLine(comment=cst.Comment("# some comment"))]
                ),
            },
            "simple_pass": {
                "code": "pass\n",
                "expected": cst.Module([cst.SimpleStatementLine([cst.Pass()])]),
            },
            "simple_pass_with_header_footer": {
                "code": "# header\npass # trailing\n# footer\n",
                "expected": cst.Module(
                    [
                        cst.SimpleStatementLine(
                            [cst.Pass()],
                            trailing_whitespace=cst.TrailingWhitespace(
                                whitespace=cst.SimpleWhitespace(" "),
                                comment=cst.Comment("# trailing"),
                            ),
                        )
                    ],
                    header=[cst.EmptyLine(comment=cst.Comment("# header"))],
                    footer=[cst.EmptyLine(comment=cst.Comment("# footer"))],
                ),
            },
        }
    )
    def test_parser(self, *, code: str, expected: cst.Module) -> None:
        """Parsing ``code`` yields a tree equal to ``expected``."""
        self.assertEqual(parse_module(code), expected)

    @data_provider(
        {
            "empty": {"code": "", "expected": CodePosition((1, 0), (1, 0))},
            "empty_with_newline": {
                "code": "\n",
                "expected": CodePosition((1, 0), (2, 0)),
            },
            "empty_program_with_comments": {
                "code": "# 2345",
                "expected": CodePosition((1, 0), (2, 0)),
            },
            "simple_pass": {"code": "pass\n", "expected": CodePosition((1, 0), (2, 0))},
            "simple_pass_with_header_footer": {
                "code": "# header\npass # trailing\n# footer\n",
                "expected": CodePosition((1, 0), (4, 0)),
            },
        }
    )
    def test_module_position(self, *, code: str, expected: CodePosition) -> None:
        """The module's syntactic position spans the whole rendered source."""
        module = self._parse_with_positions(code)

        self.assertEqual(module.__metadata__[SyntacticPositionProvider], expected)

    def cmp_position(
        self, node: cst.CSTNode, start: Tuple[int, int], end: Tuple[int, int]
    ) -> None:
        """Assert that ``node``'s syntactic position is ``start``..``end``.

        ``start`` and ``end`` are ``(line, column)`` pairs as accepted by
        ``CodePosition``; the owning module must already have been rendered.
        """
        self.assertEqual(
            node.__metadata__[SyntacticPositionProvider], CodePosition(start, end)
        )

    def test_function_position(self) -> None:
        """Statements inside a function body carry their indented positions."""
        # The body is indented by 4 spaces, hence ``pass`` starting at column 4.
        module = self._parse_with_positions("def foo():\n    pass")

        fn = cast(cst.FunctionDef, module.body[0])
        stmt = cast(cst.SimpleStatementLine, fn.body.body[0])
        pass_stmt = cast(cst.Pass, stmt.body[0])
        self.cmp_position(stmt, (2, 0), (3, 0))
        self.cmp_position(pass_stmt, (2, 4), (2, 8))

    def test_nested_indent_position(self) -> None:
        """Positions account for each level of nested indentation."""
        # Indentation is 4 spaces per level: ``x = 1`` sits at column 8 and
        # ``return`` at column 4, matching the asserted positions below.
        module = self._parse_with_positions(
            "if True:\n    if False:\n        x = 1\nelse:\n    return"
        )

        outer_if = cast(cst.If, module.body[0])
        inner_if = cast(cst.If, outer_if.body.body[0])
        assign = cast(cst.SimpleStatementLine, inner_if.body.body[0]).body[0]

        outer_else = cast(cst.Else, outer_if.orelse)
        return_stmt = cast(cst.SimpleStatementLine, outer_else.body.body[0]).body[0]

        self.cmp_position(outer_if, (1, 0), (6, 0))
        self.cmp_position(inner_if, (2, 0), (4, 0))
        self.cmp_position(assign, (3, 8), (3, 13))
        self.cmp_position(outer_else, (4, 0), (6, 0))
        self.cmp_position(return_stmt, (5, 4), (5, 10))

    def test_multiline_string_position(self) -> None:
        """An expression continued with a backslash spans both physical lines."""
        module = self._parse_with_positions('"abc"\\\n"def"')

        stmt = cast(cst.SimpleStatementLine, module.body[0])
        expr = cast(cst.Expr, stmt.body[0])
        string = expr.value

        self.cmp_position(stmt, (1, 0), (3, 0))
        self.cmp_position(expr, (1, 0), (2, 5))
        self.cmp_position(string, (1, 0), (2, 5))