Several trivial refactors (#770)

* Enumeration members are singletons. Copying on them would be no-op

* Avoid generating unnecessary `pass` statement

* Several trivial refactor

* Avoid building unnecessary intermediate lists, which are mere slight waste of time and space

* Remove unused import, an overlook from commit 8e6bf9e9

* `collections.abc.Mapping.get()` defaults to return `None` when key doesn't exist

* Just use unittest's `assertRaises` to specify expected exception types, instead of catching every possible `Exception`s, which could suppress legitimate errors and hide bugs

* We know for sure that the body of `CSTTypedTransformerFunctions` won't be empty, so don't bother with complex formal completeness
This commit is contained in:
MapleCCC 2022-09-14 21:33:45 +08:00 committed by GitHub
parent 667c713b38
commit 973895a6c0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 19 additions and 33 deletions

View file

@ -26,7 +26,7 @@ _DEFAULT_PARTIAL_PARSER_CONFIG: PartialParserConfig = PartialParserConfig()
def is_native() -> bool:
typ = os.environ.get("LIBCST_PARSER_TYPE", None)
typ = os.environ.get("LIBCST_PARSER_TYPE")
return typ == "native"

View file

@ -180,6 +180,5 @@ class ParseErrorsTest(UnitTest):
def test_native_fallible_into_py(self) -> None:
with patch("libcst._nodes.expression.Name._validate") as await_validate:
await_validate.side_effect = CSTValidationError("validate is broken")
with self.assertRaises(Exception) as e:
with self.assertRaises((SyntaxError, cst.ParserSyntaxError)):
cst.parse_module("foo")
self.assertIsInstance(e.exception, (SyntaxError, cst.ParserSyntaxError))

View file

@ -6155,8 +6155,6 @@ class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):
class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
pass
@mark_no_op
def leave_Add(self, original_node: "Add", updated_node: "Add") -> "BaseBinaryOp":
return updated_node

View file

@ -3,10 +3,8 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar
from typing import Any, Callable, cast, TypeVar
if TYPE_CHECKING:
from libcst._typed_visitor import CSTTypedBaseFunctions # noqa: F401
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
F = TypeVar("F", bound=Callable)

View file

@ -7,12 +7,12 @@ import inspect
from collections import defaultdict
from collections.abc import Sequence as ABCSequence
from dataclasses import dataclass, fields, replace
from typing import Dict, Generator, List, Mapping, Sequence, Set, Type, Union
from typing import Dict, Iterator, List, Mapping, Sequence, Set, Type, Union
import libcst as cst
def _get_bases() -> Generator[Type[cst.CSTNode], None, None]:
def _get_bases() -> Iterator[Type[cst.CSTNode]]:
"""
Get all base classes that are subclasses of CSTNode but not an actual
node itself. This allows us to keep our types sane by refering to the
@ -27,11 +27,11 @@ def _get_bases() -> Generator[Type[cst.CSTNode], None, None]:
typeclasses: Sequence[Type[cst.CSTNode]] = sorted(
list(_get_bases()), key=lambda base: base.__name__
_get_bases(), key=lambda base: base.__name__
)
def _get_nodes() -> Generator[Type[cst.CSTNode], None, None]:
def _get_nodes() -> Iterator[Type[cst.CSTNode]]:
"""
Grab all CSTNodes that are not a superclass. Basically, anything that a
person might use to generate a tree.
@ -53,7 +53,7 @@ def _get_nodes() -> Generator[Type[cst.CSTNode], None, None]:
all_libcst_nodes: Sequence[Type[cst.CSTNode]] = sorted(
list(_get_nodes()), key=lambda node: node.__name__
_get_nodes(), key=lambda node: node.__name__
)
node_to_bases: Dict[Type[cst.CSTNode], List[Type[cst.CSTNode]]] = {}
for node in all_libcst_nodes:

View file

@ -547,7 +547,7 @@ for node in all_libcst_nodes:
# Make sure to add an __all__ for flake8 and compatibility with "from libcst.matchers import *"
generated_code.append(f"__all__ = {repr(sorted(list(all_exports)))}")
generated_code.append(f"__all__ = {repr(sorted(all_exports))}")
if __name__ == "__main__":

View file

@ -29,7 +29,7 @@ generated_code.append("")
generated_code.append("")
for module, objects in imports.items():
generated_code.append(f"from {module} import (")
generated_code.append(f" {', '.join(sorted(list(objects)))}")
generated_code.append(f" {', '.join(sorted(objects))}")
generated_code.append(")")
# Generate the base visit_ methods

View file

@ -32,7 +32,7 @@ generated_code.append("")
generated_code.append("if TYPE_CHECKING:")
for module, objects in imports.items():
generated_code.append(f" from {module} import ( # noqa: F401")
generated_code.append(f" {', '.join(sorted(list(objects)))}")
generated_code.append(f" {', '.join(sorted(objects))}")
generated_code.append(" )")
@ -87,7 +87,6 @@ for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
generated_code.append("")
generated_code.append("")
generated_code.append("class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):")
generated_code.append(" pass")
for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
name = node.__name__
if name.startswith("Base"):
@ -111,6 +110,7 @@ for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
)
generated_code.append(" return updated_node")
if __name__ == "__main__":
# Output the code
print("\n".join(generated_code))

View file

@ -557,7 +557,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
# Ensure that we have no duplicates, otherwise we might get race conditions
# on write.
files = sorted(list({os.path.abspath(f) for f in files}))
files = sorted({os.path.abspath(f) for f in files})
total = len(files)
progress = Progress(enabled=not hide_progress, total=total)

View file

@ -30,12 +30,10 @@ def call_if_inside(
"""
def inner(original: _CSTVisitFuncT) -> _CSTVisitFuncT:
if not hasattr(original, VISIT_POSITIVE_MATCHER_ATTR):
setattr(original, VISIT_POSITIVE_MATCHER_ATTR, [])
setattr(
original,
VISIT_POSITIVE_MATCHER_ATTR,
[*getattr(original, VISIT_POSITIVE_MATCHER_ATTR), matcher],
[*getattr(original, VISIT_POSITIVE_MATCHER_ATTR, []), matcher],
)
return original
@ -57,12 +55,10 @@ def call_if_not_inside(
"""
def inner(original: _CSTVisitFuncT) -> _CSTVisitFuncT:
if not hasattr(original, VISIT_NEGATIVE_MATCHER_ATTR):
setattr(original, VISIT_NEGATIVE_MATCHER_ATTR, [])
setattr(
original,
VISIT_NEGATIVE_MATCHER_ATTR,
[*getattr(original, VISIT_NEGATIVE_MATCHER_ATTR), matcher],
[*getattr(original, VISIT_NEGATIVE_MATCHER_ATTR, []), matcher],
)
return original
@ -88,12 +84,10 @@ def visit(matcher: BaseMatcherNode) -> Callable[[_CSTVisitFuncT], _CSTVisitFuncT
"""
def inner(original: _CSTVisitFuncT) -> _CSTVisitFuncT:
if not hasattr(original, CONSTRUCTED_VISIT_MATCHER_ATTR):
setattr(original, CONSTRUCTED_VISIT_MATCHER_ATTR, [])
setattr(
original,
CONSTRUCTED_VISIT_MATCHER_ATTR,
[*getattr(original, CONSTRUCTED_VISIT_MATCHER_ATTR), matcher],
[*getattr(original, CONSTRUCTED_VISIT_MATCHER_ATTR, []), matcher],
)
return original
@ -116,12 +110,10 @@ def leave(matcher: BaseMatcherNode) -> Callable[[_CSTVisitFuncT], _CSTVisitFuncT
"""
def inner(original: _CSTVisitFuncT) -> _CSTVisitFuncT:
if not hasattr(original, CONSTRUCTED_LEAVE_MATCHER_ATTR):
setattr(original, CONSTRUCTED_LEAVE_MATCHER_ATTR, [])
setattr(
original,
CONSTRUCTED_LEAVE_MATCHER_ATTR,
[*getattr(original, CONSTRUCTED_LEAVE_MATCHER_ATTR), matcher],
[*getattr(original, CONSTRUCTED_LEAVE_MATCHER_ATTR, []), matcher],
)
return original

View file

@ -4,7 +4,6 @@
# LICENSE file in the root directory of this source tree.
import collections.abc
import copy
import inspect
import re
from abc import ABCMeta
@ -1831,7 +1830,7 @@ class _ReplaceTransformer(libcst.CSTTransformer):
if inspect.isfunction(replacement):
self.replacement = replacement
elif isinstance(replacement, (MaybeSentinel, RemovalSentinel)):
self.replacement = lambda node, matches: copy.deepcopy(replacement)
self.replacement = lambda node, matches: replacement
else:
# pyre-ignore We know this is a CSTNode.
self.replacement = lambda node, matches: replacement.deep_clone()
@ -1946,7 +1945,7 @@ def replace(
"""
if isinstance(tree, (RemovalSentinel, MaybeSentinel)):
# We can't do any replacements on this, so return the tree exactly.
return copy.deepcopy(tree)
return tree
if isinstance(matcher, (AtLeastN, AtMostN)):
# We can't match this, since these matchers are forbidden at top level.
# These are not subclasses of BaseMatcherNode, but in the case that the