diff --git a/libcst/codemod/commands/convert_precent_format_to_fstring.py b/libcst/codemod/commands/convert_precent_format_to_fstring.py new file mode 100644 index 00000000..2f2a9aa0 --- /dev/null +++ b/libcst/codemod/commands/convert_precent_format_to_fstring.py @@ -0,0 +1,126 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# pyre-strict +import re +from typing import Callable, cast + +import libcst as cst +import libcst.matchers as m +from libcst.codemod import VisitorBasedCodemodCommand + + +USE_FSTRING_SIMPLE_EXPRESSION_MAX_LENGTH = 30 + + +def _match_simple_string(node: cst.CSTNode) -> bool: + if isinstance(node, cst.SimpleString) and not node.prefix.lower().startswith("b"): + # SimpleString can be a bytes and fstring don't support bytes + return re.fullmatch("[^%]*(%s[^%]*)+", node.raw_value) is not None + return False + + +def _gen_match_simple_expression(module: cst.Module) -> Callable[[cst.CSTNode], bool]: + def _match_simple_expression(node: cst.CSTNode) -> bool: + # either each element in Tuple is simple expression or the entire expression is simple. + if ( + isinstance(node, cst.Tuple) + and all( + len(module.code_for_node(elm.value)) + < USE_FSTRING_SIMPLE_EXPRESSION_MAX_LENGTH + for elm in node.elements + ) + ) or len(module.code_for_node(node)) < USE_FSTRING_SIMPLE_EXPRESSION_MAX_LENGTH: + return True + return False + + return _match_simple_expression + + +class EscapeStringQuote(cst.CSTTransformer): + def __init__(self, quote: str) -> None: + self.quote = quote + super().__init__() + + def leave_SimpleString( + self, original_node: cst.SimpleString, updated_node: cst.SimpleString + ) -> cst.SimpleString: + if self.quote == original_node.quote: + for quo in ["'", '"', "'''", '"""']: + if quo != original_node.quote and quo not in original_node.raw_value: + escaped_string = cst.SimpleString( + original_node.prefix + quo + original_node.raw_value + quo + ) + if escaped_string.evaluated_value != original_node.evaluated_value: + raise Exception( + f"Failed to escape string:\n original:{original_node.value}\n escaped:{escaped_string.value}" + ) + else: + return escaped_string + raise Exception( + f"Cannot find a good quote for escaping the SimpleString: {original_node.value}" + ) + return original_node + + +class ConvertPercentFormatStringCommand(VisitorBasedCodemodCommand): + DESCRIPTION: str = "Converts simple % style string format to f-string." + + def leave_BinaryOperation( + self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation + ) -> cst.BaseExpression: + expr_key = "expr" + extracts = m.extract( + original_node, + m.BinaryOperation( + left=m.MatchIfTrue(_match_simple_string), + operator=m.Modulo(), + right=m.SaveMatchedNode( + m.MatchIfTrue(_gen_match_simple_expression(self.module)), expr_key, + ), + ), + ) + + if extracts: + expr = extracts[expr_key] + parts = [] + simple_string = cst.ensure_type(original_node.left, cst.SimpleString) + innards = simple_string.raw_value.replace("{", "{{").replace("}", "}}") + tokens = innards.split("%s") + token = tokens[0] + if len(token) > 0: + parts.append(cst.FormattedStringText(value=token)) + expressions = ( + [elm.value for elm in expr.elements] + if isinstance(expr, cst.Tuple) + else [expr] + ) + escape_transformer = EscapeStringQuote(simple_string.quote) + i = 1 + while i < len(tokens): + if i - 1 >= len(expressions): + # the %-string doesn't come with same number of elements in tuple + return original_node + try: + parts.append( + cst.FormattedStringExpression( + expression=cast( + cst.BaseExpression, + expressions[i - 1].visit(escape_transformer), + ) + ) + ) + except Exception: + return original_node + token = tokens[i] + if len(token) > 0: + parts.append(cst.FormattedStringText(value=token)) + i += 1 + start = f"f{simple_string.prefix}{simple_string.quote}" + return cst.FormattedString( + parts=parts, start=start, end=simple_string.quote + ) + + return original_node diff --git a/libcst/codemod/commands/tests/test_convert_percent_format_to_fstring.py b/libcst/codemod/commands/tests/test_convert_percent_format_to_fstring.py new file mode 100644 index 00000000..a9e5ff09 --- /dev/null +++ b/libcst/codemod/commands/tests/test_convert_percent_format_to_fstring.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# +# pyre-strict +from libcst.codemod import CodemodTest +from libcst.codemod.commands.convert_precent_format_to_fstring import ( + ConvertPercentFormatStringCommand, +) + + +class ConvertPercentFormatStringCommandTest(CodemodTest): + TRANSFORM = ConvertPercentFormatStringCommand + + def test_simple_cases(self) -> None: + self.assertCodemod('"a name: %s" % name', 'f"a name: {name}"') + self.assertCodemod( + '"an attribute %s ." % obj.attr', 'f"an attribute {obj.attr} ."' + ) + self.assertCodemod('r"raw string value=%s" % val', 'fr"raw string value={val}"') + self.assertCodemod( + '"The type of var: %s" % type(var)', 'f"The type of var: {type(var)}"' + ) + self.assertCodemod( + '"type of var: %s, value of var: %s" % (type(var), var)', + 'f"type of var: {type(var)}, value of var: {var}"', + ) + self.assertCodemod( + '"var1: %s, var2: %s, var3: %s, var4: %s" % (class_object.attribute, dict_lookup["some_key"], some_module.some_function(), var4)', + '''f"var1: {class_object.attribute}, var2: {dict_lookup['some_key']}, var3: {some_module.some_function()}, var4: {var4}"''', + ) + + def test_escaping(self) -> None: + self.assertCodemod('"%s" % "hi"', '''f"{'hi'}"''') # escape quote + self.assertCodemod('"{%s}" % val', 'f"{{{val}}}"') # escape curly bracket + self.assertCodemod('"{%s" % val', 'f"{{{val}"') # escape curly bracket + self.assertCodemod( + "'%s\" double quote is used' % var", "f'{var}\" double quote is used'" + ) # escape quote + self.assertCodemod( + '"a list: %s" % " ".join(var)', '''f"a list: {' '.join(var)}"''' + ) # escape quote + + def test_not_supported_case(self) -> None: + code = '"%s" % obj.this_is_a_very_long_expression(parameter)["a_very_long_key"]' + self.assertCodemod(code, code) + code = 'b"a type %s" % var' + self.assertCodemod(code, code)