Add helper functions for common ways of filtering nodes (#1137)

* Make the nodes fields filtering process - from libcst.tool - public, so that other libraries may provide their own custom representation of LibCST graphs.

* Create functions to access & filter CST-node fields (with appropriate docstrings & tests), in libcst.helpers

* Add new CST-node fields functions to helpers documentation.
This commit is contained in:
zaicruvoir1rominet 2024-05-13 11:20:47 +02:00 committed by GitHub
parent 6783244eab
commit efc53af608
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 482 additions and 80 deletions

View file

@ -32,3 +32,18 @@ Functions that assist in traversing an existing LibCST tree.
.. autofunction:: libcst.helpers.get_full_name_for_node
.. autofunction:: libcst.helpers.get_full_name_for_node_or_raise
.. autofunction:: libcst.helpers.ensure_type
Node fields filtering Helpers
-----------------------------
Function that assist when handling CST nodes' fields.
.. autofunction:: libcst.helpers.filter_node_fields
And lower level functions:
.. autofunction:: libcst.helpers.get_node_fields
.. autofunction:: libcst.helpers.is_whitespace_node_field
.. autofunction:: libcst.helpers.is_syntax_node_field
.. autofunction:: libcst.helpers.is_default_node_field
.. autofunction:: libcst.helpers.get_field_default_value

View file

@ -25,6 +25,14 @@ from libcst.helpers.module import (
insert_header_comments,
ModuleNameAndPackage,
)
from libcst.helpers.node_fields import (
filter_node_fields,
get_field_default_value,
get_node_fields,
is_default_node_field,
is_syntax_node_field,
is_whitespace_node_field,
)
__all__ = [
"calculate_module_and_package",
@ -42,4 +50,10 @@ __all__ = [
"parse_template_statement",
"parse_template_expression",
"ModuleNameAndPackage",
"get_node_fields",
"get_field_default_value",
"is_whitespace_node_field",
"is_syntax_node_field",
"is_default_node_field",
"filter_node_fields",
]

View file

@ -0,0 +1,128 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations
import dataclasses
from typing import TYPE_CHECKING
from libcst import IndentedBlock, Module
from libcst._nodes.deep_equals import deep_equals
if TYPE_CHECKING:
from typing import Sequence
from libcst import CSTNode
def get_node_fields(node: CSTNode) -> Sequence[dataclasses.Field[CSTNode]]:
"""
Returns the sequence of a given CST-node's fields.
"""
return dataclasses.fields(node)
def is_whitespace_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
"""
Returns True if a given CST-node's field is a whitespace-related field
(whitespace, indent, header, footer, etc.).
"""
if "whitespace" in field.name:
return True
if "leading_lines" in field.name:
return True
if "lines_after_decorators" in field.name:
return True
if isinstance(node, (IndentedBlock, Module)) and field.name in [
"header",
"footer",
]:
return True
if isinstance(node, IndentedBlock) and field.name == "indent":
return True
return False
def is_syntax_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
"""
Returns True if a given CST-node's field is a syntax-related field
(colon, semicolon, dot, encoding, etc.).
"""
if isinstance(node, Module) and field.name in [
"encoding",
"default_indent",
"default_newline",
"has_trailing_newline",
]:
return True
type_str = repr(field.type)
if (
"Sentinel" in type_str
and field.name not in ["star_arg", "star", "posonly_ind"]
and "whitespace" not in field.name
):
# This is a value that can optionally be specified, so its
# definitely syntax.
return True
for name in ["Semicolon", "Colon", "Comma", "Dot", "AssignEqual"]:
# These are all nodes that exist for separation syntax
if name in type_str:
return True
return False
def get_field_default_value(field: dataclasses.Field[CSTNode]) -> object:
"""
Returns the default value of a CST-node's field.
"""
if field.default_factory is not dataclasses.MISSING:
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
# dataclasses._DefaultFactory[object]]` is not a function.
return field.default_factory()
return field.default
def is_default_node_field(node: CSTNode, field: dataclasses.Field[CSTNode]) -> bool:
"""
Returns True if a given CST-node's field has its default value.
"""
return deep_equals(getattr(node, field.name), get_field_default_value(field))
def filter_node_fields(
node: CSTNode,
*,
show_defaults: bool,
show_syntax: bool,
show_whitespace: bool,
) -> Sequence[dataclasses.Field[CSTNode]]:
"""
Returns a filtered sequence of a CST-node's fields.
Setting ``show_whitespace`` to ``False`` will filter whitespace fields.
Setting ``show_defaults`` to ``False`` will filter fields if their value is equal to
the default value ; while respecting the value of ``show_whitespace``.
Setting ``show_syntax`` to ``False`` will filter syntax fields ; while respecting
the value of ``show_whitespace`` & ``show_defaults``.
"""
fields: Sequence[dataclasses.Field[CSTNode]] = dataclasses.fields(node)
# Hide all fields prefixed with "_"
fields = [f for f in fields if f.name[0] != "_"]
# Filter whitespace nodes if needed
if not show_whitespace:
fields = [f for f in fields if not is_whitespace_node_field(node, f)]
# Filter values which aren't changed from their defaults
if not show_defaults:
fields = [f for f in fields if not is_default_node_field(node, f)]
# Filter out values which aren't interesting if needed
if not show_syntax:
fields = [f for f in fields if not is_syntax_node_field(node, f)]
return fields

View file

@ -0,0 +1,314 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from unittest import TestCase
from libcst import (
Annotation,
CSTNode,
FunctionDef,
IndentedBlock,
Module,
Param,
parse_module,
Pass,
Semicolon,
SimpleStatementLine,
)
from libcst.helpers import (
get_node_fields,
is_default_node_field,
is_syntax_node_field,
is_whitespace_node_field,
)
class _NodeFieldsTest(TestCase):
"""Node fields related tests."""
module: Module
annotation: Annotation
param: Param
_pass: Pass
semicolon: Semicolon
statement: SimpleStatementLine
indent: IndentedBlock
function: FunctionDef
@classmethod
def setUpClass(cls) -> None:
"""Parse a simple CST and references interesting nodes."""
cls.module = parse_module(
"def foo(a: str) -> None:\n pass ; pass\n return\n"
)
# /!\ Direct access to nodes
# This is done for test purposes on a known CST
# -> For "real code", use visitors to do this "the correct way"
# pyre-ignore[8]: direct access for tests
cls.function = cls.module.body[0]
cls.param = cls.function.params.params[0]
# pyre-ignore[8]: direct access for tests
cls.annotation = cls.param.annotation
# pyre-ignore[8]: direct access for tests
cls.indent = cls.function.body
# pyre-ignore[8]: direct access for tests
cls.statement = cls.indent.body[0]
# pyre-ignore[8]: direct access for tests
cls._pass = cls.statement.body[0]
# pyre-ignore[8]: direct access for tests
cls.semicolon = cls.statement.body[0].semicolon
def test__cst_correctness(self) -> None:
"""Test that the CST is correctly parsed."""
self.assertIsInstance(self.module, Module)
self.assertIsInstance(self.annotation, Annotation)
self.assertIsInstance(self.param, Param)
self.assertIsInstance(self._pass, Pass)
self.assertIsInstance(self.semicolon, Semicolon)
self.assertIsInstance(self.statement, SimpleStatementLine)
self.assertIsInstance(self.indent, IndentedBlock)
self.assertIsInstance(self.function, FunctionDef)
class IsWhitespaceNodeFieldTest(_NodeFieldsTest):
"""``is_whitespace_node_field`` tests."""
def _check_fields(self, is_filtered_field: dict[str, bool], node: CSTNode) -> None:
fields = get_node_fields(node)
self.assertEqual(len(is_filtered_field), len(fields))
for field in fields:
self.assertEqual(
is_filtered_field[field.name],
is_whitespace_node_field(node, field),
f"Node ``{node.__class__.__qualname__}`` field '{field.name}' "
f"{'should have' if is_filtered_field[field.name] else 'should not have'} "
"been filtered by ``is_whitespace_node_field``",
)
def test_module(self) -> None:
"""Check if a CST Module node is correctly filtered."""
is_filtered_field = {
"body": False,
"header": True,
"footer": True,
"encoding": False,
"default_indent": False,
"default_newline": False,
"has_trailing_newline": False,
}
self._check_fields(is_filtered_field, self.module)
def test_annotation(self) -> None:
"""Check if a CST Annotation node is correctly filtered."""
is_filtered_field = {
"annotation": False,
"whitespace_before_indicator": True,
"whitespace_after_indicator": True,
}
self._check_fields(is_filtered_field, self.annotation)
def test_param(self) -> None:
"""Check if a CST Param node is correctly filtered."""
is_filtered_field = {
"name": False,
"annotation": False,
"equal": False,
"default": False,
"comma": False,
"star": False,
"whitespace_after_star": True,
"whitespace_after_param": True,
}
self._check_fields(is_filtered_field, self.param)
def test_semicolon(self) -> None:
"""Check if a CST Semicolon node is correctly filtered."""
is_filtered_field = {
"whitespace_before": True,
"whitespace_after": True,
}
self._check_fields(is_filtered_field, self.semicolon)
def test_statement(self) -> None:
"""Check if a CST SimpleStatementLine node is correctly filtered."""
is_filtered_field = {
"body": False,
"leading_lines": True,
"trailing_whitespace": True,
}
self._check_fields(is_filtered_field, self.statement)
def test_indent(self) -> None:
"""Check if a CST IndentedBlock node is correctly filtered."""
is_filtered_field = {
"body": False,
"header": True,
"indent": True,
"footer": True,
}
self._check_fields(is_filtered_field, self.indent)
def test_function(self) -> None:
"""Check if a CST FunctionDef node is correctly filtered."""
is_filtered_field = {
"name": False,
"params": False,
"body": False,
"decorators": False,
"returns": False,
"asynchronous": False,
"leading_lines": True,
"lines_after_decorators": True,
"whitespace_after_def": True,
"whitespace_after_name": True,
"whitespace_before_params": True,
"whitespace_before_colon": True,
"type_parameters": False,
"whitespace_after_type_parameters": True,
}
self._check_fields(is_filtered_field, self.function)
class IsSyntaxNodeFieldTest(_NodeFieldsTest):
"""``is_syntax_node_field`` tests."""
def _check_fields(self, is_filtered_field: dict[str, bool], node: CSTNode) -> None:
fields = get_node_fields(node)
self.assertEqual(len(is_filtered_field), len(fields))
for field in fields:
self.assertEqual(
is_filtered_field[field.name],
is_syntax_node_field(node, field),
f"Node ``{node.__class__.__qualname__}`` field '{field.name}' "
f"{'should have' if is_filtered_field[field.name] else 'should not have'} "
"been filtered by ``is_syntax_node_field``",
)
def test_module(self) -> None:
"""Check if a CST Module node is correctly filtered."""
is_filtered_field = {
"body": False,
"header": False,
"footer": False,
"encoding": True,
"default_indent": True,
"default_newline": True,
"has_trailing_newline": True,
}
self._check_fields(is_filtered_field, self.module)
def test_param(self) -> None:
"""Check if a CST Param node is correctly filtered."""
is_filtered_field = {
"name": False,
"annotation": False,
"equal": True,
"default": False,
"comma": True,
"star": False,
"whitespace_after_star": False,
"whitespace_after_param": False,
}
self._check_fields(is_filtered_field, self.param)
def test_pass(self) -> None:
"""Check if a CST Pass node is correctly filtered."""
is_filtered_field = {
"semicolon": True,
}
self._check_fields(is_filtered_field, self._pass)
class IsDefaultNodeFieldTest(_NodeFieldsTest):
"""``is_default_node_field`` tests."""
def _check_fields(self, is_filtered_field: dict[str, bool], node: CSTNode) -> None:
fields = get_node_fields(node)
self.assertEqual(len(is_filtered_field), len(fields))
for field in fields:
self.assertEqual(
is_filtered_field[field.name],
is_default_node_field(node, field),
f"Node ``{node.__class__.__qualname__}`` field '{field.name}' "
f"{'should have' if is_filtered_field[field.name] else 'should not have'} "
"been filtered by ``is_default_node_field``",
)
def test_module(self) -> None:
"""Check if a CST Module node is correctly filtered."""
is_filtered_field = {
"body": False,
"header": True,
"footer": True,
"encoding": True,
"default_indent": True,
"default_newline": True,
"has_trailing_newline": True,
}
self._check_fields(is_filtered_field, self.module)
def test_annotation(self) -> None:
"""Check if a CST Annotation node is correctly filtered."""
is_filtered_field = {
"annotation": False,
"whitespace_before_indicator": False,
"whitespace_after_indicator": True,
}
self._check_fields(is_filtered_field, self.annotation)
def test_param(self) -> None:
"""Check if a CST Param node is correctly filtered."""
is_filtered_field = {
"name": False,
"annotation": False,
"equal": True,
"default": True,
"comma": True,
"star": False,
"whitespace_after_star": True,
"whitespace_after_param": True,
}
self._check_fields(is_filtered_field, self.param)
def test_statement(self) -> None:
"""Check if a CST SimpleStatementLine node is correctly filtered."""
is_filtered_field = {
"body": False,
"leading_lines": True,
"trailing_whitespace": True,
}
self._check_fields(is_filtered_field, self.statement)
def test_indent(self) -> None:
"""Check if a CST IndentedBlock node is correctly filtered."""
is_filtered_field = {
"body": False,
"header": True,
"indent": True,
"footer": True,
}
self._check_fields(is_filtered_field, self.indent)
def test_function(self) -> None:
"""Check if a CST FunctionDef node is correctly filtered."""
is_filtered_field = {
"name": False,
"params": False,
"body": False,
"decorators": True,
"returns": False,
"asynchronous": True,
"leading_lines": True,
"lines_after_decorators": True,
"whitespace_after_def": True,
"whitespace_after_name": True,
"whitespace_before_params": True,
"whitespace_before_colon": True,
"type_parameters": True,
"whitespace_after_type_parameters": True,
}
self._check_fields(is_filtered_field, self.function)

View file

@ -22,15 +22,7 @@ from typing import Any, Callable, Dict, List, Sequence, Tuple, Type
import yaml
from libcst import (
CSTNode,
IndentedBlock,
LIBCST_VERSION,
Module,
parse_module,
PartialParserConfig,
)
from libcst._nodes.deep_equals import deep_equals
from libcst import CSTNode, LIBCST_VERSION, parse_module, PartialParserConfig
from libcst._parser.parso.utils import parse_version_string
from libcst.codemod import (
CodemodCommand,
@ -40,6 +32,7 @@ from libcst.codemod import (
gather_files,
parallel_exec_transform_with_prettyprint,
)
from libcst.helpers import filter_node_fields
_DEFAULT_INDENT: str = " "
@ -54,76 +47,14 @@ def _node_repr_recursive( # noqa: C901
) -> List[str]:
if isinstance(node, CSTNode):
# This is a CSTNode, we must pretty-print it.
fields: Sequence["dataclasses.Field[CSTNode]"] = filter_node_fields(
node=node,
show_defaults=show_defaults,
show_syntax=show_syntax,
show_whitespace=show_whitespace,
)
tokens: List[str] = [node.__class__.__name__]
fields: Sequence["dataclasses.Field[object]"] = dataclasses.fields(node)
# Hide all fields prefixed with "_"
fields = [f for f in fields if f.name[0] != "_"]
# Filter whitespace nodes if needed
if not show_whitespace:
def _is_whitespace(field: "dataclasses.Field[object]") -> bool:
if "whitespace" in field.name:
return True
if "leading_lines" in field.name:
return True
if "lines_after_decorators" in field.name:
return True
if isinstance(node, (IndentedBlock, Module)) and field.name in [
"header",
"footer",
]:
return True
if isinstance(node, IndentedBlock) and field.name == "indent":
return True
return False
fields = [f for f in fields if not _is_whitespace(f)]
# Filter values which aren't changed from their defaults
if not show_defaults:
def _get_default(fld: "dataclasses.Field[object]") -> object:
if fld.default_factory is not dataclasses.MISSING:
# pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
# dataclasses._DefaultFactory[object]]` is not a function.
return fld.default_factory()
return fld.default
fields = [
f
for f in fields
if not deep_equals(getattr(node, f.name), _get_default(f))
]
# Filter out values which aren't interesting if needed
if not show_syntax:
def _is_syntax(field: "dataclasses.Field[object]") -> bool:
if isinstance(node, Module) and field.name in [
"encoding",
"default_indent",
"default_newline",
"has_trailing_newline",
]:
return True
type_str = repr(field.type)
if (
"Sentinel" in type_str
and field.name not in ["star_arg", "star", "posonly_ind"]
and "whitespace" not in field.name
):
# This is a value that can optionally be specified, so its
# definitely syntax.
return True
for name in ["Semicolon", "Colon", "Comma", "Dot", "AssignEqual"]:
# These are all nodes that exist for separation syntax
if name in type_str:
return True
return False
fields = [f for f in fields if not _is_syntax(f)]
if len(fields) == 0:
tokens.append("()")
@ -204,12 +135,12 @@ def dump(
from the default contruction of the node while also hiding whitespace and
syntax fields.
Setting ``show_default`` to ``True`` will add fields regardless if their
Setting ``show_defaults`` to ``True`` will add fields regardless if their
value is different from the default value.
Setting ``show_whitespace`` will add whitespace fields and setting
``show_syntax`` will add syntax fields while respecting the value of
``show_default``.
``show_defaults``.
When all keyword args are set to true, the output of this function is
indentical to the __repr__ method of the node.