add get_full_name_for_expression helper

This commit is contained in:
Jimmy Lai 2020-01-09 15:43:29 -08:00 committed by jimmylai
parent 2fb0db33d1
commit dffb27b5b9
6 changed files with 85 additions and 30 deletions

14
docs/source/helpers.rst Normal file
View file

@ -0,0 +1,14 @@
=======
Helpers
=======
Helpers are higher level functions built for reducing recurring code boilerplate.
We add helpers as method of ``CSTNode`` or ``libcst.helpers`` package based on those principles:
- ``CSTNode`` method: simple, read-only and only require data of the direct children of a CSTNode.
- ``libcst.helpers``: node transforms or require recursively traversing the syntax tree.
libcst.helpers
--------------
.. autofunction:: libcst.helpers.module.insert_header_comments
.. autofunction:: libcst.helpers.expression.get_full_name_for_node

View file

@ -41,6 +41,7 @@ LibCST
metadata
matchers
codemods
helpers
experimental

View file

@ -9,6 +9,7 @@ from typing import Dict, List, Optional, Sequence, Set, Tuple, Union
import libcst
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareVisitor
from libcst.helpers.expression import get_full_name_for_node
class GatherImportsVisitor(ContextAwareVisitor):
@ -62,15 +63,12 @@ class GatherImportsVisitor(ContextAwareVisitor):
self.all_imports: List[Union[libcst.Import, libcst.ImportFrom]] = []
def _get_string_name(self, node: Optional[libcst.CSTNode]) -> str:
if node is None:
return ""
elif isinstance(node, libcst.Name):
return node.value
elif isinstance(node, libcst.Attribute):
return self._get_string_name(node.value) + "." + node.attr.value
else:
name = "" if node is None else get_full_name_for_node(node)
if name is None:
raise Exception(f"Invalid node type {type(node)}!")
return name
def visit_Import(self, node: libcst.Import) -> None:
# Track this import statement for later analysis.
self.all_imports.append(node)

View file

@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
from typing import Optional, Union
import libcst as cst
def get_full_name_for_node(node: Union[str, cst.CSTNode]) -> Optional[str]:
"""Return a dot concatenated full name for str, :class:`~libcst.Name`, :class:`~libcst.Attribute`.
:class:`~libcst.Call`, :class:`~libcst.Subscript`, :class:`~libcst.FunctionDef`, :class:`~libcst.ClassDef`.
Return ``None`` for not supported Node.
"""
if isinstance(node, cst.Name):
return node.value
elif isinstance(node, str):
return node
elif isinstance(node, cst.Attribute):
return f"{get_full_name_for_node(node.value)}.{node.attr.value}"
elif isinstance(node, cst.Call):
return get_full_name_for_node(node.func)
elif isinstance(node, cst.Subscript):
return get_full_name_for_node(node.value)
elif isinstance(node, (cst.FunctionDef, cst.ClassDef)):
return get_full_name_for_node(node.name)
return None

View file

@ -0,0 +1,30 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
# pyre-strict
from typing import Optional, Union
import libcst as cst
from libcst.helpers.expression import get_full_name_for_node
from libcst.testing.utils import UnitTest, data_provider
class ExpressionTest(UnitTest):
@data_provider(
(
("a string", "a string"),
(cst.Name("a_name"), "a_name"),
(cst.parse_expression("a.b.c"), "a.b.c"),
(cst.parse_expression("a.b()"), "a.b"),
(cst.parse_expression("a.b.c[i]"), "a.b.c"),
(cst.parse_statement("def fun(): pass"), "fun"),
(cst.parse_statement("class cls: pass"), "cls"),
(cst.parse_statement("(a.b()).c()"), None), # not a supported Node type
)
)
def test_get_full_name_for_expression(
self, input: Union[str, cst.CSTNode], output: Optional[str],
) -> None:
self.assertEqual(get_full_name_for_node(input), output)

View file

@ -27,6 +27,7 @@ from typing import (
import libcst as cst
from libcst._add_slots import add_slots
from libcst.helpers.expression import get_full_name_for_node
from libcst.metadata.base_provider import BatchableMetadataProvider
from libcst.metadata.expression_context_provider import (
ExpressionContext,
@ -189,22 +190,6 @@ class QualifiedName:
class _NameUtil:
@staticmethod
def get_full_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]:
if isinstance(node, cst.Name):
return node.value
elif isinstance(node, str):
return node
elif isinstance(node, cst.Attribute):
return f"{_NameUtil.get_full_name_for(node.value)}.{node.attr.value}"
elif isinstance(node, cst.Call):
return _NameUtil.get_full_name_for(node.func)
elif isinstance(node, cst.Subscript):
return _NameUtil.get_full_name_for(node.value)
elif isinstance(node, (cst.FunctionDef, cst.ClassDef)):
return _NameUtil.get_full_name_for(node.name)
return None
@staticmethod
def get_name_for(node: Union[str, cst.CSTNode]) -> Optional[str]:
"""A helper function to retrieve simple name str from a CSTNode or str"""
@ -230,11 +215,11 @@ class _NameUtil:
module_attr = assignment_node.module
if module_attr:
# TODO: for relative import, keep the relative Dot in the qualified name
module = _NameUtil.get_full_name_for(module_attr)
module = get_full_name_for_node(module_attr)
import_names = assignment_node.names
if not isinstance(import_names, cst.ImportStar):
for name in import_names:
real_name = _NameUtil.get_full_name_for(name.name)
real_name = get_full_name_for_node(name.name)
as_name = real_name
if name and name.asname:
name_asname = name.asname
@ -415,7 +400,7 @@ class Scope(abc.ABC):
resolve, e.g. ``List[Union[int, str]]``.
"""
results = set()
full_name = _NameUtil.get_full_name_for(node)
full_name = get_full_name_for_node(node)
if full_name is None:
return results
parts = full_name.split(".")
@ -647,9 +632,7 @@ class ScopeVisitor(cst.CSTVisitor):
self.scope.record_assignment(node.name.value, node)
self.provider.set_metadata(node.name, self.scope)
with self._new_scope(
FunctionScope, node, _NameUtil.get_full_name_for(node.name)
):
with self._new_scope(FunctionScope, node, get_full_name_for_node(node.name)):
node.params.visit(self)
node.body.visit(self)
@ -724,7 +707,7 @@ class ScopeVisitor(cst.CSTVisitor):
for keyword in node.keywords:
keyword.visit(self)
with self._new_scope(ClassScope, node, _NameUtil.get_full_name_for(node.name)):
with self._new_scope(ClassScope, node, get_full_name_for_node(node.name)):
for statement in node.body.body:
statement.visit(self)