LibCST/libcst/helpers/_template.py
zaicruvoir1rominet ca1f81f049
Avoid raising bare Exception (#1168)
* Keep old exception messages (avoid breaking-changes for users relying on exception messages)

* Move ``get_expected_str`` out of _exceptions.py, where it does not belong, to its own file in _parser/_parsing_check.py
2025-06-07 01:53:44 -07:00

462 lines
19 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from typing import Dict, Mapping, Optional, Set, Union
import libcst as cst
from libcst.helpers.common import ensure_type
# Reserved markers wrapped around a template variable name when a template is
# mangled into parseable source. The suffix is the prefix spelled backwards.
# Templates that already contain either marker are rejected by mangle_template.
TEMPLATE_PREFIX: str = "__LIBCST_MANGLED_NAME_"
TEMPLATE_SUFFIX: str = "_EMAN_DELGNAM_TSCBIL__"
# Every node type that may be passed as a template replacement value.
ValidReplacementType = Union[
    cst.BaseExpression,
    cst.Annotation,
    cst.AssignTarget,
    cst.Param,
    cst.Parameters,
    cst.Arg,
    cst.BaseStatement,
    cst.BaseSmallStatement,
    cst.BaseSuite,
    cst.BaseSlice,
    cst.SubscriptElement,
    cst.Decorator,
]
def mangled_name(var: str) -> str:
    """Return ``var`` wrapped between ``TEMPLATE_PREFIX`` and ``TEMPLATE_SUFFIX``."""
    return "{}{}{}".format(TEMPLATE_PREFIX, var, TEMPLATE_SUFFIX)
def unmangled_name(var: str) -> Optional[str]:
    """
    Recover the template variable name from a mangled identifier.

    Args:
        var: The identifier to inspect.

    Returns:
        The text between ``TEMPLATE_PREFIX`` and ``TEMPLATE_SUFFIX`` when
        ``var`` is exactly prefix + name + suffix (with the suffix occurring
        only at the very end), otherwise ``None``.
    """
    # ``partition`` always yields three parts, so unlike unpacking a two-way
    # ``split`` it cannot raise when a marker is missing from the remainder
    # (for example when the suffix appears before the prefix, or when the two
    # markers overlap on their shared underscore).
    leading, found_prefix, remainder = var.partition(TEMPLATE_PREFIX)
    if not found_prefix or leading:
        # No prefix at all, or something precedes it: not a mangled name.
        return None
    name, found_suffix, trailing = remainder.partition(TEMPLATE_SUFFIX)
    if not found_suffix or trailing:
        # No suffix after the prefix, or something follows it: not a mangled name.
        return None
    return name
def mangle_template(template: str, template_vars: Set[str]) -> str:
    """
    Rewrite every ``{var}`` placeholder in ``template`` to its mangled name.

    Args:
        template: The template source text.
        template_vars: Names of the variables expected to appear as ``{name}``.

    Returns:
        The template with each placeholder replaced by ``mangled_name(var)``.

    Raises:
        ValueError: if the template already contains one of the reserved
            marker strings, or if a variable has no placeholder in it.
    """
    reserved_markers = (TEMPLATE_PREFIX, TEMPLATE_SUFFIX)
    if any(marker in template for marker in reserved_markers):
        raise ValueError("Cannot parse a template containing reserved strings")

    mangled = template
    for var in template_vars:
        placeholder = "{{{}}}".format(var)
        if placeholder not in mangled:
            raise ValueError(
                f'Template string is missing a reference to "{var}" referred to in kwargs'
            )
        mangled = mangled.replace(placeholder, mangled_name(var))
    return mangled
class TemplateTransformer(cst.CSTTransformer):
    """
    Transformer that swaps mangled template placeholders for replacement nodes.

    The replacements given to the constructor are grouped by node type. Each
    ``leave_*`` hook checks whether the node being left is (or directly wraps)
    a ``Name`` whose value unmangles to a key of the matching group and, if
    so, returns a deep clone of that replacement in place of the node.
    """

    def __init__(
        self, template_replacements: Mapping[str, ValidReplacementType]
    ) -> None:
        """
        Sort ``template_replacements`` into one lookup table per node type.

        Raises:
            ValueError: if some replacement value is not an instance of any
                of the supported node types.
        """
        self.simple_replacements: Dict[str, cst.BaseExpression] = {}
        self.annotation_replacements: Dict[str, cst.Annotation] = {}
        self.assignment_replacements: Dict[str, cst.AssignTarget] = {}
        self.param_replacements: Dict[str, cst.Param] = {}
        self.parameters_replacements: Dict[str, cst.Parameters] = {}
        self.arg_replacements: Dict[str, cst.Arg] = {}
        self.small_statement_replacements: Dict[str, cst.BaseSmallStatement] = {}
        self.statement_replacements: Dict[str, cst.BaseStatement] = {}
        self.suite_replacements: Dict[str, cst.BaseSuite] = {}
        self.subscript_element_replacements: Dict[str, cst.SubscriptElement] = {}
        self.subscript_index_replacements: Dict[str, cst.BaseSlice] = {}
        self.decorator_replacements: Dict[str, cst.Decorator] = {}

        for name, value in template_replacements.items():
            # Deliberately independent checks rather than an elif chain: a
            # value is filed under every type it is an instance of.
            if isinstance(value, cst.BaseExpression):
                self.simple_replacements[name] = value
            if isinstance(value, cst.Annotation):
                self.annotation_replacements[name] = value
            if isinstance(value, cst.AssignTarget):
                self.assignment_replacements[name] = value
            if isinstance(value, cst.Param):
                self.param_replacements[name] = value
            if isinstance(value, cst.Parameters):
                self.parameters_replacements[name] = value
            if isinstance(value, cst.Arg):
                self.arg_replacements[name] = value
            if isinstance(value, cst.BaseSmallStatement):
                self.small_statement_replacements[name] = value
            if isinstance(value, cst.BaseStatement):
                self.statement_replacements[name] = value
            if isinstance(value, cst.BaseSuite):
                self.suite_replacements[name] = value
            if isinstance(value, cst.SubscriptElement):
                self.subscript_element_replacements[name] = value
            if isinstance(value, cst.BaseSlice):
                self.subscript_index_replacements[name] = value
            if isinstance(value, cst.Decorator):
                self.decorator_replacements[name] = value

        # Figure out if there are any variables that we can't support
        # inserting into templates.
        supported_vars = {
            *self.simple_replacements,
            *self.annotation_replacements,
            *self.assignment_replacements,
            *self.param_replacements,
            *self.parameters_replacements,
            *self.arg_replacements,
            *self.small_statement_replacements,
            *self.statement_replacements,
            *self.suite_replacements,
            *self.subscript_element_replacements,
            *self.subscript_index_replacements,
            *self.decorator_replacements,
        }
        unsupported_vars = {
            name for name in template_replacements if name not in supported_vars
        }
        if unsupported_vars:
            raise ValueError(
                f'Template replacement for "{next(iter(unsupported_vars))}" is unsupported'
            )

    # We can't use matchers in the hooks below due to circular imports, so the
    # shape checks are spelled out by hand in these two helpers.

    @staticmethod
    def _placeholder_var(node: object) -> Optional[str]:
        """Unmangle ``node`` when it is a ``Name``; return ``None`` otherwise."""
        if not isinstance(node, cst.Name):
            return None
        return unmangled_name(node.value)

    def _sole_expr_var(
        self, node: Union[cst.SimpleStatementLine, cst.SimpleStatementSuite]
    ) -> Optional[str]:
        """
        Unmangle the name in ``node`` when its body is exactly one ``Expr``
        wrapping a ``Name``; return ``None`` for any other shape.
        """
        if len(node.body) != 1:
            return None
        only_statement = node.body[0]
        if not isinstance(only_statement, cst.Expr):
            return None
        return self._placeholder_var(only_statement.value)

    def leave_Name(
        self, original_node: cst.Name, updated_node: cst.Name
    ) -> cst.BaseExpression:
        """Replace a mangled name with its expression replacement, if any."""
        var_name = unmangled_name(updated_node.value)
        if var_name is not None and var_name in self.simple_replacements:
            return self.simple_replacements[var_name].deep_clone()
        # This is not a valid name, don't modify it
        return updated_node

    def leave_Annotation(
        self,
        original_node: cst.Annotation,
        updated_node: cst.Annotation,
    ) -> cst.Annotation:
        """Replace an annotation that is just a placeholder name."""
        var_name = self._placeholder_var(updated_node.annotation)
        if var_name in self.annotation_replacements:
            return self.annotation_replacements[var_name].deep_clone()
        return updated_node

    def leave_AssignTarget(
        self,
        original_node: cst.AssignTarget,
        updated_node: cst.AssignTarget,
    ) -> cst.AssignTarget:
        """Replace an assignment target that is just a placeholder name."""
        var_name = self._placeholder_var(updated_node.target)
        if var_name in self.assignment_replacements:
            return self.assignment_replacements[var_name].deep_clone()
        return updated_node

    def leave_Param(
        self,
        original_node: cst.Param,
        updated_node: cst.Param,
    ) -> cst.Param:
        """Replace a parameter whose name is a placeholder."""
        var_name = unmangled_name(updated_node.name.value)
        if var_name in self.param_replacements:
            return self.param_replacements[var_name].deep_clone()
        return updated_node

    def leave_Parameters(
        self,
        original_node: cst.Parameters,
        updated_node: cst.Parameters,
    ) -> cst.Parameters:
        """Replace a whole parameter list that is a single placeholder."""
        # A very special case for when we use a template variable for all
        # function parameters: exactly one plain positional parameter and
        # nothing else.
        has_single_plain_param = (
            len(updated_node.params) == 1
            and updated_node.star_arg == cst.MaybeSentinel.DEFAULT
            and len(updated_node.kwonly_params) == 0
            and updated_node.star_kwarg is None
            and len(updated_node.posonly_params) == 0
            and updated_node.posonly_ind == cst.MaybeSentinel.DEFAULT
        )
        if not has_single_plain_param:
            return updated_node
        # This parameters node has only one argument, which is possibly
        # a replacement.
        var_name = unmangled_name(updated_node.params[0].name.value)
        if var_name in self.parameters_replacements:
            return self.parameters_replacements[var_name].deep_clone()
        return updated_node

    def leave_Arg(self, original_node: cst.Arg, updated_node: cst.Arg) -> cst.Arg:
        """Replace a call argument whose value is just a placeholder name."""
        var_name = self._placeholder_var(updated_node.value)
        if var_name in self.arg_replacements:
            return self.arg_replacements[var_name].deep_clone()
        return updated_node

    def leave_SimpleStatementLine(
        self,
        original_node: cst.SimpleStatementLine,
        updated_node: cst.SimpleStatementLine,
    ) -> cst.BaseStatement:
        """Replace a line consisting solely of a placeholder with a statement."""
        # A name on a single line is parsed as an Expr node contained in a
        # SimpleStatementLine, so that shape is what we look for when deciding
        # whether to expand the line into a statement replacement.
        var_name = self._sole_expr_var(updated_node)
        if var_name in self.statement_replacements:
            return self.statement_replacements[var_name].deep_clone()
        return updated_node

    def leave_Expr(
        self,
        original_node: cst.Expr,
        updated_node: cst.Expr,
    ) -> cst.BaseSmallStatement:
        """Replace an expression statement that is just a placeholder name."""
        # Same trick as for SimpleStatementLine, one level down, to support
        # templates that substitute a BaseSmallStatement.
        var_name = self._placeholder_var(updated_node.value)
        if var_name in self.small_statement_replacements:
            return self.small_statement_replacements[var_name].deep_clone()
        return updated_node

    def leave_SimpleStatementSuite(
        self,
        original_node: cst.SimpleStatementSuite,
        updated_node: cst.SimpleStatementSuite,
    ) -> cst.BaseSuite:
        """Replace a one-line suite consisting solely of a placeholder."""
        # A name in a simple suite is parsed as an Expr node contained in a
        # SimpleStatementSuite, so that shape is what we look for when
        # deciding whether to expand it into a suite replacement.
        var_name = self._sole_expr_var(updated_node)
        if var_name in self.suite_replacements:
            return self.suite_replacements[var_name].deep_clone()
        return updated_node

    def leave_IndentedBlock(
        self,
        original_node: cst.IndentedBlock,
        updated_node: cst.IndentedBlock,
    ) -> cst.BaseSuite:
        """Replace an indented block consisting solely of a placeholder."""
        # A name in an indented block is parsed as an Expr node contained in a
        # SimpleStatementLine, so the block must hold exactly one such line.
        if len(updated_node.body) != 1:
            return updated_node
        statement_node = updated_node.body[0]
        if not isinstance(statement_node, cst.SimpleStatementLine):
            return updated_node
        var_name = self._sole_expr_var(statement_node)
        if var_name in self.suite_replacements:
            return self.suite_replacements[var_name].deep_clone()
        return updated_node

    def leave_Index(
        self,
        original_node: cst.Index,
        updated_node: cst.Index,
    ) -> cst.BaseSlice:
        """Replace a subscript index whose value is just a placeholder name."""
        var_name = self._placeholder_var(updated_node.value)
        if var_name in self.subscript_index_replacements:
            return self.subscript_index_replacements[var_name].deep_clone()
        return updated_node

    def leave_SubscriptElement(
        self,
        original_node: cst.SubscriptElement,
        updated_node: cst.SubscriptElement,
    ) -> cst.SubscriptElement:
        """Replace a subscript element that wraps just a placeholder name."""
        # A template replacement variable in this position shows up as an
        # Index inside of a SubscriptElement, so that is the shape we match.
        index = updated_node.slice
        if not isinstance(index, cst.Index):
            return updated_node
        var_name = self._placeholder_var(index.value)
        if var_name in self.subscript_element_replacements:
            return self.subscript_element_replacements[var_name].deep_clone()
        return updated_node

    def leave_Decorator(
        self, original_node: cst.Decorator, updated_node: cst.Decorator
    ) -> cst.Decorator:
        """Replace a decorator whose expression is just a placeholder name."""
        var_name = self._placeholder_var(updated_node.decorator)
        if var_name in self.decorator_replacements:
            return self.decorator_replacements[var_name].deep_clone()
        return updated_node
class TemplateChecker(cst.CSTVisitor):
    """Visitor that raises if a mangled template variable is still in the tree."""

    def __init__(self, template_vars: Set[str]) -> None:
        """Remember the template variable names to look for."""
        self.template_vars = template_vars

    def visit_Name(self, node: cst.Name) -> None:
        """
        Check one ``Name`` node against every known template variable.

        Raises:
            ValueError: if the node's value is the mangled form of one of the
                template variables.
        """
        var = next(
            (
                candidate
                for candidate in self.template_vars
                if node.value == mangled_name(candidate)
            ),
            None,
        )
        if var is not None:
            raise ValueError(f'Template variable "{var}" was not replaced properly')
def unmangle_nodes(
    tree: cst.CSTNode,
    template_replacements: Mapping[str, ValidReplacementType],
) -> cst.CSTNode:
    """
    Run a ``TemplateTransformer`` built from ``template_replacements`` over
    ``tree`` and return the visited result, checked to be a ``CSTNode``.
    """
    transformer = TemplateTransformer(template_replacements)
    replaced_tree = tree.visit(transformer)
    return ensure_type(replaced_tree, cst.CSTNode)
# Shared parser config instance used as the default value of the ``config``
# parameter of the parse_template_* functions below.
_DEFAULT_PARTIAL_PARSER_CONFIG: cst.PartialParserConfig = cst.PartialParserConfig()
def parse_template_module(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.Module:
    """
    Accepts an entire python module template, including all leading and trailing
    whitespace. Any :class:`~libcst.CSTNode` provided as a keyword argument to
    this function will be inserted into the template at the appropriate location
    similar to an f-string expansion. For example::

        module = parse_template_module("from {mod} import Foo\\n", mod=Name("bar"))

    The above code will parse to a module containing a single
    :class:`~libcst.FromImport` statement, referencing module ``bar`` and importing
    object ``Foo`` from it. Remember that if you are parsing a template as part
    of a substitution inside a transform, it's considered
    :ref:`best practice <libcst-config_best_practice>` to pass in a ``config``
    from the current module under transformation.

    Note that unlike :func:`~libcst.parse_module`, this function does not support
    bytes as an input. This is due to the fact that it is processed as a template
    before parsing as a module.
    """
    template_vars = set(template_replacements)
    mangled_source = mangle_template(template, template_vars)
    parsed_module = cst.parse_module(mangled_source, config)
    new_module = ensure_type(
        unmangle_nodes(parsed_module, template_replacements), cst.Module
    )
    # Fail loudly if any placeholder survived the substitution pass.
    new_module.visit(TemplateChecker(template_vars))
    return new_module
def parse_template_statement(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> Union[cst.SimpleStatementLine, cst.BaseCompoundStatement]:
    """
    Accepts a statement template followed by a trailing newline. If a trailing
    newline is not provided, one will be added. Any :class:`~libcst.CSTNode`
    provided as a keyword argument to this function will be inserted into the
    template at the appropriate location similar to an f-string expansion. For
    example::

        statement = parse_template_statement("assert x > 0, {msg}", msg=SimpleString('"Uh oh!"'))

    The above code will parse to an assert statement checking that some variable
    ``x`` is greater than zero, or providing the assert message ``"Uh oh!"``.

    Remember that if you are parsing a template as part of a substitution inside
    a transform, it's considered :ref:`best practice <libcst-config_best_practice>`
    to pass in a ``config`` from the current module under transformation.
    """
    template_vars = set(template_replacements)
    mangled_source = mangle_template(template, template_vars)
    parsed_statement = cst.parse_statement(mangled_source, config)
    new_statement = unmangle_nodes(parsed_statement, template_replacements)
    statement_types = (cst.SimpleStatementLine, cst.BaseCompoundStatement)
    if isinstance(new_statement, statement_types):
        # Fail loudly if any placeholder survived the substitution pass.
        new_statement.visit(TemplateChecker(template_vars))
        return new_statement
    raise TypeError(
        f"Expected a statement but got a {new_statement.__class__.__qualname__}!"
    )
def parse_template_expression(
    template: str,
    config: cst.PartialParserConfig = _DEFAULT_PARTIAL_PARSER_CONFIG,
    **template_replacements: ValidReplacementType,
) -> cst.BaseExpression:
    """
    Accepts an expression template on a single line. Leading and trailing whitespace
    is not valid (there's nowhere to store it on the expression node). Any
    :class:`~libcst.CSTNode` provided as a keyword argument to this function will
    be inserted into the template at the appropriate location similar to an
    f-string expansion. For example::

        expression = parse_template_expression("x + {foo}", foo=Name("y"))

    The above code will parse to a :class:`~libcst.BinaryOperation` expression
    adding two names (``x`` and ``y``) together.

    Remember that if you are parsing a template as part of a substitution inside
    a transform, it's considered :ref:`best practice <libcst-config_best_practice>`
    to pass in a ``config`` from the current module under transformation.
    """
    template_vars = set(template_replacements)
    mangled_source = mangle_template(template, template_vars)
    parsed_expression = cst.parse_expression(mangled_source, config)
    new_expression = ensure_type(
        unmangle_nodes(parsed_expression, template_replacements),
        cst.BaseExpression,
    )
    # Fail loudly if any placeholder survived the substitution pass.
    new_expression.visit(TemplateChecker(template_vars))
    return new_expression