mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
197 lines
7.2 KiB
Python
197 lines
7.2 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
#
|
|
# This source code is licensed under the MIT license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from textwrap import dedent
|
|
from typing import Union
|
|
|
|
import libcst as cst
|
|
from libcst._removal_sentinel import RemovalSentinel
|
|
from libcst._types import CSTNodeT
|
|
from libcst._visitors import CSTTransformer
|
|
from libcst.testing.utils import data_provider, none_throws, UnitTest
|
|
|
|
_EMPTY_SIMPLE_WHITESPACE = cst.SimpleWhitespace("")
|
|
|
|
|
|
class _TestVisitor(CSTTransformer):
|
|
def __init__(self, test: UnitTest) -> None:
|
|
self.counter = 0
|
|
self.test = test
|
|
|
|
def assert_counter(self, expected: int) -> None:
|
|
self.test.assertEqual(self.counter, expected)
|
|
self.counter += 1
|
|
|
|
def on_visit(self, node: cst.CSTNode) -> bool:
|
|
if isinstance(node, cst.Module):
|
|
self.assert_counter(0)
|
|
elif isinstance(node, cst.SimpleStatementLine):
|
|
self.assert_counter(1)
|
|
elif isinstance(node, cst.Pass):
|
|
self.assert_counter(2)
|
|
elif isinstance(node, cst.Newline):
|
|
self.assert_counter(4)
|
|
return True
|
|
|
|
def on_leave(
|
|
self, original_node: CSTNodeT, updated_node: CSTNodeT
|
|
) -> Union[CSTNodeT, RemovalSentinel]:
|
|
self.test.assertTrue(original_node.deep_equals(updated_node))
|
|
# Don't allow type checkers to accidentally refine our return type.
|
|
return_node = updated_node
|
|
if isinstance(updated_node, cst.Pass):
|
|
self.assert_counter(3)
|
|
elif isinstance(updated_node, cst.Newline):
|
|
self.assert_counter(5)
|
|
elif isinstance(updated_node, cst.SimpleStatementLine):
|
|
self.assert_counter(6)
|
|
elif isinstance(updated_node, cst.Module):
|
|
self.assert_counter(7)
|
|
return return_node
|
|
|
|
|
|
class CSTNodeTest(UnitTest):
|
|
def test_with_changes(self) -> None:
|
|
initial = cst.TrailingWhitespace(
|
|
whitespace=cst.SimpleWhitespace(" \\\n "),
|
|
comment=cst.Comment("# initial"),
|
|
newline=cst.Newline("\r\n"),
|
|
)
|
|
changed = initial.with_changes(comment=cst.Comment("# new comment"))
|
|
|
|
# see that we have the updated fields
|
|
self.assertEqual(none_throws(changed.comment).value, "# new comment")
|
|
# and that the old fields are still there
|
|
self.assertEqual(changed.whitespace.value, " \\\n ")
|
|
self.assertEqual(changed.newline.value, "\r\n")
|
|
|
|
# ensure no mutation actually happened
|
|
self.assertEqual(none_throws(initial.comment).value, "# initial")
|
|
|
|
def test_default_eq(self) -> None:
|
|
sw1 = cst.SimpleWhitespace("")
|
|
sw2 = cst.SimpleWhitespace("")
|
|
self.assertNotEqual(sw1, sw2)
|
|
self.assertEqual(sw1, sw1)
|
|
self.assertEqual(sw2, sw2)
|
|
self.assertTrue(sw1.deep_equals(sw2))
|
|
self.assertTrue(sw2.deep_equals(sw1))
|
|
|
|
def test_hash(self) -> None:
|
|
sw1 = cst.SimpleWhitespace("")
|
|
sw2 = cst.SimpleWhitespace("")
|
|
self.assertNotEqual(hash(sw1), hash(sw2))
|
|
self.assertEqual(hash(sw1), hash(sw1))
|
|
self.assertEqual(hash(sw2), hash(sw2))
|
|
|
|
@data_provider(
|
|
{
|
|
"simple": (cst.SimpleWhitespace(""), cst.SimpleWhitespace("")),
|
|
"identity": (_EMPTY_SIMPLE_WHITESPACE, _EMPTY_SIMPLE_WHITESPACE),
|
|
"nested": (
|
|
cst.EmptyLine(whitespace=cst.SimpleWhitespace("")),
|
|
cst.EmptyLine(whitespace=cst.SimpleWhitespace("")),
|
|
),
|
|
"tuple_versus_list": (
|
|
cst.SimpleStatementLine(body=[cst.Pass()]),
|
|
cst.SimpleStatementLine(body=(cst.Pass(),)),
|
|
),
|
|
}
|
|
)
|
|
def test_deep_equals_success(self, a: cst.CSTNode, b: cst.CSTNode) -> None:
|
|
self.assertTrue(a.deep_equals(b))
|
|
|
|
@data_provider(
|
|
{
|
|
"simple": (cst.SimpleWhitespace(" "), cst.SimpleWhitespace(" ")),
|
|
"nested": (
|
|
cst.EmptyLine(whitespace=cst.SimpleWhitespace(" ")),
|
|
cst.EmptyLine(whitespace=cst.SimpleWhitespace(" ")),
|
|
),
|
|
"list": (
|
|
cst.SimpleStatementLine(body=[cst.Pass(semicolon=cst.Semicolon())]),
|
|
cst.SimpleStatementLine(body=[cst.Pass(semicolon=cst.Semicolon())] * 2),
|
|
),
|
|
}
|
|
)
|
|
def test_deep_equals_fails(self, a: cst.CSTNode, b: cst.CSTNode) -> None:
|
|
self.assertFalse(a.deep_equals(b))
|
|
|
|
def test_repr(self) -> None:
|
|
self.assertEqual(
|
|
repr(
|
|
cst.SimpleStatementLine(
|
|
body=[cst.Pass()],
|
|
# tuple with multiple items
|
|
leading_lines=(
|
|
cst.EmptyLine(
|
|
indent=True,
|
|
whitespace=cst.SimpleWhitespace(""),
|
|
comment=None,
|
|
newline=cst.Newline(),
|
|
),
|
|
cst.EmptyLine(
|
|
indent=True,
|
|
whitespace=cst.SimpleWhitespace(""),
|
|
comment=None,
|
|
newline=cst.Newline(),
|
|
),
|
|
),
|
|
trailing_whitespace=cst.TrailingWhitespace(
|
|
whitespace=cst.SimpleWhitespace(" "),
|
|
comment=cst.Comment("# comment"),
|
|
newline=cst.Newline(),
|
|
),
|
|
)
|
|
),
|
|
dedent(
|
|
"""
|
|
SimpleStatementLine(
|
|
body=[
|
|
Pass(
|
|
semicolon=MaybeSentinel.DEFAULT,
|
|
),
|
|
],
|
|
leading_lines=[
|
|
EmptyLine(
|
|
indent=True,
|
|
whitespace=SimpleWhitespace(
|
|
value='',
|
|
),
|
|
comment=None,
|
|
newline=Newline(
|
|
value=None,
|
|
),
|
|
),
|
|
EmptyLine(
|
|
indent=True,
|
|
whitespace=SimpleWhitespace(
|
|
value='',
|
|
),
|
|
comment=None,
|
|
newline=Newline(
|
|
value=None,
|
|
),
|
|
),
|
|
],
|
|
trailing_whitespace=TrailingWhitespace(
|
|
whitespace=SimpleWhitespace(
|
|
value=' ',
|
|
),
|
|
comment=Comment(
|
|
value='# comment',
|
|
),
|
|
newline=Newline(
|
|
value=None,
|
|
),
|
|
),
|
|
)
|
|
"""
|
|
).strip(),
|
|
)
|
|
|
|
def test_visit(self) -> None:
|
|
tree = cst.Module((cst.SimpleStatementLine((cst.Pass(),)),))
|
|
tree.visit(_TestVisitor(self))
|