mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
Several trivial refactors (#770)
* Enumeration members are singletons. Copying on them would be no-op
* Avoid generating unnecessary `pass` statement
* Several trivial refactor
* Avoid building unnecessary intermediate lists, which are mere slight waste of time and space
* Remove unused import, an overlook from commit 8e6bf9e9
* `collections.abc.Mapping.get()` defaults to return `None` when key doesn't exist
* Just use unittest's `assertRaises` to specify expected exception types, instead of catching every possible `Exception`s, which could suppress legitimate errors and hide bugs
* We know for sure that the body of `CSTTypedTransformerFunctions` won't be empty, so don't bother with complex formal completeness
This commit is contained in:
parent
667c713b38
commit
973895a6c0
11 changed files with 19 additions and 33 deletions
|
|
@ -26,7 +26,7 @@ _DEFAULT_PARTIAL_PARSER_CONFIG: PartialParserConfig = PartialParserConfig()
|
|||
|
||||
|
||||
def is_native() -> bool:
|
||||
typ = os.environ.get("LIBCST_PARSER_TYPE", None)
|
||||
typ = os.environ.get("LIBCST_PARSER_TYPE")
|
||||
return typ == "native"
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -180,6 +180,5 @@ class ParseErrorsTest(UnitTest):
|
|||
def test_native_fallible_into_py(self) -> None:
|
||||
with patch("libcst._nodes.expression.Name._validate") as await_validate:
|
||||
await_validate.side_effect = CSTValidationError("validate is broken")
|
||||
with self.assertRaises(Exception) as e:
|
||||
with self.assertRaises((SyntaxError, cst.ParserSyntaxError)):
|
||||
cst.parse_module("foo")
|
||||
self.assertIsInstance(e.exception, (SyntaxError, cst.ParserSyntaxError))
|
||||
|
|
|
|||
|
|
@ -6155,8 +6155,6 @@ class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):
|
|||
|
||||
|
||||
class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
|
||||
pass
|
||||
|
||||
@mark_no_op
|
||||
def leave_Add(self, original_node: "Add", updated_node: "Add") -> "BaseBinaryOp":
|
||||
return updated_node
|
||||
|
|
|
|||
|
|
@ -3,10 +3,8 @@
|
|||
# This source code is licensed under the MIT license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar
|
||||
from typing import Any, Callable, cast, TypeVar
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from libcst._typed_visitor import CSTTypedBaseFunctions # noqa: F401
|
||||
|
||||
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ import inspect
|
|||
from collections import defaultdict
|
||||
from collections.abc import Sequence as ABCSequence
|
||||
from dataclasses import dataclass, fields, replace
|
||||
from typing import Dict, Generator, List, Mapping, Sequence, Set, Type, Union
|
||||
from typing import Dict, Iterator, List, Mapping, Sequence, Set, Type, Union
|
||||
|
||||
import libcst as cst
|
||||
|
||||
|
||||
def _get_bases() -> Generator[Type[cst.CSTNode], None, None]:
|
||||
def _get_bases() -> Iterator[Type[cst.CSTNode]]:
|
||||
"""
|
||||
Get all base classes that are subclasses of CSTNode but not an actual
|
||||
node itself. This allows us to keep our types sane by refering to the
|
||||
|
|
@ -27,11 +27,11 @@ def _get_bases() -> Generator[Type[cst.CSTNode], None, None]:
|
|||
|
||||
|
||||
typeclasses: Sequence[Type[cst.CSTNode]] = sorted(
|
||||
list(_get_bases()), key=lambda base: base.__name__
|
||||
_get_bases(), key=lambda base: base.__name__
|
||||
)
|
||||
|
||||
|
||||
def _get_nodes() -> Generator[Type[cst.CSTNode], None, None]:
|
||||
def _get_nodes() -> Iterator[Type[cst.CSTNode]]:
|
||||
"""
|
||||
Grab all CSTNodes that are not a superclass. Basically, anything that a
|
||||
person might use to generate a tree.
|
||||
|
|
@ -53,7 +53,7 @@ def _get_nodes() -> Generator[Type[cst.CSTNode], None, None]:
|
|||
|
||||
|
||||
all_libcst_nodes: Sequence[Type[cst.CSTNode]] = sorted(
|
||||
list(_get_nodes()), key=lambda node: node.__name__
|
||||
_get_nodes(), key=lambda node: node.__name__
|
||||
)
|
||||
node_to_bases: Dict[Type[cst.CSTNode], List[Type[cst.CSTNode]]] = {}
|
||||
for node in all_libcst_nodes:
|
||||
|
|
|
|||
|
|
@ -547,7 +547,7 @@ for node in all_libcst_nodes:
|
|||
|
||||
|
||||
# Make sure to add an __all__ for flake8 and compatibility with "from libcst.matchers import *"
|
||||
generated_code.append(f"__all__ = {repr(sorted(list(all_exports)))}")
|
||||
generated_code.append(f"__all__ = {repr(sorted(all_exports))}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ generated_code.append("")
|
|||
generated_code.append("")
|
||||
for module, objects in imports.items():
|
||||
generated_code.append(f"from {module} import (")
|
||||
generated_code.append(f" {', '.join(sorted(list(objects)))}")
|
||||
generated_code.append(f" {', '.join(sorted(objects))}")
|
||||
generated_code.append(")")
|
||||
|
||||
# Generate the base visit_ methods
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ generated_code.append("")
|
|||
generated_code.append("if TYPE_CHECKING:")
|
||||
for module, objects in imports.items():
|
||||
generated_code.append(f" from {module} import ( # noqa: F401")
|
||||
generated_code.append(f" {', '.join(sorted(list(objects)))}")
|
||||
generated_code.append(f" {', '.join(sorted(objects))}")
|
||||
generated_code.append(" )")
|
||||
|
||||
|
||||
|
|
@ -87,7 +87,6 @@ for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
|
|||
generated_code.append("")
|
||||
generated_code.append("")
|
||||
generated_code.append("class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):")
|
||||
generated_code.append(" pass")
|
||||
for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
|
||||
name = node.__name__
|
||||
if name.startswith("Base"):
|
||||
|
|
@ -111,6 +110,7 @@ for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
|
|||
)
|
||||
generated_code.append(" return updated_node")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Output the code
|
||||
print("\n".join(generated_code))
|
||||
|
|
|
|||
|
|
@ -557,7 +557,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
|
|||
|
||||
# Ensure that we have no duplicates, otherwise we might get race conditions
|
||||
# on write.
|
||||
files = sorted(list({os.path.abspath(f) for f in files}))
|
||||
files = sorted({os.path.abspath(f) for f in files})
|
||||
total = len(files)
|
||||
progress = Progress(enabled=not hide_progress, total=total)
|
||||
|
||||
|
|
|
|||
|
|
@ -30,12 +30,10 @@ def call_if_inside(
|
|||
"""
|
||||
|
||||
def inner(original: _CSTVisitFuncT) -> _CSTVisitFuncT:
|
||||
if not hasattr(original, VISIT_POSITIVE_MATCHER_ATTR):
|
||||
setattr(original, VISIT_POSITIVE_MATCHER_ATTR, [])
|
||||
setattr(
|
||||
original,
|
||||
VISIT_POSITIVE_MATCHER_ATTR,
|
||||
[*getattr(original, VISIT_POSITIVE_MATCHER_ATTR), matcher],
|
||||
[*getattr(original, VISIT_POSITIVE_MATCHER_ATTR, []), matcher],
|
||||
)
|
||||
return original
|
||||
|
||||
|
|
@ -57,12 +55,10 @@ def call_if_not_inside(
|
|||
"""
|
||||
|
||||
def inner(original: _CSTVisitFuncT) -> _CSTVisitFuncT:
|
||||
if not hasattr(original, VISIT_NEGATIVE_MATCHER_ATTR):
|
||||
setattr(original, VISIT_NEGATIVE_MATCHER_ATTR, [])
|
||||
setattr(
|
||||
original,
|
||||
VISIT_NEGATIVE_MATCHER_ATTR,
|
||||
[*getattr(original, VISIT_NEGATIVE_MATCHER_ATTR), matcher],
|
||||
[*getattr(original, VISIT_NEGATIVE_MATCHER_ATTR, []), matcher],
|
||||
)
|
||||
return original
|
||||
|
||||
|
|
@ -88,12 +84,10 @@ def visit(matcher: BaseMatcherNode) -> Callable[[_CSTVisitFuncT], _CSTVisitFuncT
|
|||
"""
|
||||
|
||||
def inner(original: _CSTVisitFuncT) -> _CSTVisitFuncT:
|
||||
if not hasattr(original, CONSTRUCTED_VISIT_MATCHER_ATTR):
|
||||
setattr(original, CONSTRUCTED_VISIT_MATCHER_ATTR, [])
|
||||
setattr(
|
||||
original,
|
||||
CONSTRUCTED_VISIT_MATCHER_ATTR,
|
||||
[*getattr(original, CONSTRUCTED_VISIT_MATCHER_ATTR), matcher],
|
||||
[*getattr(original, CONSTRUCTED_VISIT_MATCHER_ATTR, []), matcher],
|
||||
)
|
||||
return original
|
||||
|
||||
|
|
@ -116,12 +110,10 @@ def leave(matcher: BaseMatcherNode) -> Callable[[_CSTVisitFuncT], _CSTVisitFuncT
|
|||
"""
|
||||
|
||||
def inner(original: _CSTVisitFuncT) -> _CSTVisitFuncT:
|
||||
if not hasattr(original, CONSTRUCTED_LEAVE_MATCHER_ATTR):
|
||||
setattr(original, CONSTRUCTED_LEAVE_MATCHER_ATTR, [])
|
||||
setattr(
|
||||
original,
|
||||
CONSTRUCTED_LEAVE_MATCHER_ATTR,
|
||||
[*getattr(original, CONSTRUCTED_LEAVE_MATCHER_ATTR), matcher],
|
||||
[*getattr(original, CONSTRUCTED_LEAVE_MATCHER_ATTR, []), matcher],
|
||||
)
|
||||
return original
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,6 @@
|
|||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import collections.abc
|
||||
import copy
|
||||
import inspect
|
||||
import re
|
||||
from abc import ABCMeta
|
||||
|
|
@ -1831,7 +1830,7 @@ class _ReplaceTransformer(libcst.CSTTransformer):
|
|||
if inspect.isfunction(replacement):
|
||||
self.replacement = replacement
|
||||
elif isinstance(replacement, (MaybeSentinel, RemovalSentinel)):
|
||||
self.replacement = lambda node, matches: copy.deepcopy(replacement)
|
||||
self.replacement = lambda node, matches: replacement
|
||||
else:
|
||||
# pyre-ignore We know this is a CSTNode.
|
||||
self.replacement = lambda node, matches: replacement.deep_clone()
|
||||
|
|
@ -1946,7 +1945,7 @@ def replace(
|
|||
"""
|
||||
if isinstance(tree, (RemovalSentinel, MaybeSentinel)):
|
||||
# We can't do any replacements on this, so return the tree exactly.
|
||||
return copy.deepcopy(tree)
|
||||
return tree
|
||||
if isinstance(matcher, (AtLeastN, AtMostN)):
|
||||
# We can't match this, since these matchers are forbidden at top level.
|
||||
# These are not subclasses of BaseMatcherNode, but in the case that the
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue