bpo-28002: Roundtrip f-strings with ast.unparse better (#19612)

By attempting to avoid backslashes in f-string expressions.
We also now proactively raise errors for some backslashes we can't
avoid while unparsing FormattedValues

Co-authored-by: hauntsaninja <>
Co-authored-by: Shantanu <hauntsaninja@users.noreply.github.com>
Co-authored-by: Batuhan Taskaya <isidentical@gmail.com>
This commit is contained in:
Shantanu 2020-11-20 13:16:42 -08:00 committed by GitHub
parent 9fc319dc03
commit a993e901eb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 115 additions and 37 deletions

View file

@ -662,17 +662,23 @@ class _Precedence(IntEnum):
except ValueError:
return self
_SINGLE_QUOTES = ("'", '"')
_MULTI_QUOTES = ('"""', "'''")
_ALL_QUOTES = (*_SINGLE_QUOTES, *_MULTI_QUOTES)
class _Unparser(NodeVisitor):
"""Methods in this class recursively traverse an AST and
output source code for the abstract syntax; original formatting
is disregarded."""
def __init__(self):
def __init__(self, *, _avoid_backslashes=False):
self._source = []
self._buffer = []
self._precedences = {}
self._type_ignores = {}
self._indent = 0
self._avoid_backslashes = _avoid_backslashes
def interleave(self, inter, f, seq):
"""Call f on each item in seq, calling inter() in between."""
@ -1067,15 +1073,85 @@ class _Unparser(NodeVisitor):
with self.block(extra=self.get_type_comment(node)):
self.traverse(node.body)
def _str_literal_helper(
self, string, *, quote_types=_ALL_QUOTES, escape_special_whitespace=False
):
"""Helper for writing string literals, minimizing escapes.
Returns the tuple (string literal to write, possible quote types).
"""
def escape_char(c):
# \n and \t are non-printable, but we only escape them if
# escape_special_whitespace is True
if not escape_special_whitespace and c in "\n\t":
return c
# Always escape backslashes and other non-printable characters
if c == "\\" or not c.isprintable():
return c.encode("unicode_escape").decode("ascii")
return c
escaped_string = "".join(map(escape_char, string))
possible_quotes = quote_types
if "\n" in escaped_string:
possible_quotes = [q for q in possible_quotes if q in _MULTI_QUOTES]
possible_quotes = [q for q in possible_quotes if q not in escaped_string]
if not possible_quotes:
# If there aren't any possible_quotes, fallback to using repr
# on the original string. Try to use a quote from quote_types,
# e.g., so that we use triple quotes for docstrings.
string = repr(string)
quote = next((q for q in quote_types if string[0] in q), string[0])
return string[1:-1], [quote]
if escaped_string:
# Sort so that we prefer '''"''' over """\""""
possible_quotes.sort(key=lambda q: q[0] == escaped_string[-1])
# If we're using triple quotes and we'd need to escape a final
# quote, escape it
if possible_quotes[0][0] == escaped_string[-1]:
assert len(possible_quotes[0]) == 3
escaped_string = escaped_string[:-1] + "\\" + escaped_string[-1]
return escaped_string, possible_quotes
def _write_str_avoiding_backslashes(self, string, *, quote_types=_ALL_QUOTES):
"""Write string literal value with a best effort attempt to avoid backslashes."""
string, quote_types = self._str_literal_helper(string, quote_types=quote_types)
quote_type = quote_types[0]
self.write(f"{quote_type}{string}{quote_type}")
def visit_JoinedStr(self, node):
self.write("f")
self._fstring_JoinedStr(node, self.buffer_writer)
self.write(repr(self.buffer))
if self._avoid_backslashes:
self._fstring_JoinedStr(node, self.buffer_writer)
self._write_str_avoiding_backslashes(self.buffer)
return
# If we don't need to avoid backslashes globally (i.e., we only need
# to avoid them inside FormattedValues), it's cosmetically preferred
# to use escaped whitespace. That is, it's preferred to use backslashes
# for cases like: f"{x}\n". To accomplish this, we keep track of what
# in our buffer corresponds to FormattedValues and what corresponds to
# Constant parts of the f-string, and allow escapes accordingly.
buffer = []
for value in node.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, self.buffer_writer)
buffer.append((self.buffer, isinstance(value, Constant)))
new_buffer = []
quote_types = _ALL_QUOTES
for value, is_constant in buffer:
# Repeatedly narrow down the list of possible quote_types
value, quote_types = self._str_literal_helper(
value, quote_types=quote_types,
escape_special_whitespace=is_constant
)
new_buffer.append(value)
value = "".join(new_buffer)
quote_type = quote_types[0]
self.write(f"{quote_type}{value}{quote_type}")
def visit_FormattedValue(self, node):
self.write("f")
self._fstring_FormattedValue(node, self.buffer_writer)
self.write(repr(self.buffer))
self._write_str_avoiding_backslashes(self.buffer)
def _fstring_JoinedStr(self, node, write):
for value in node.values:
@ -1090,11 +1166,13 @@ class _Unparser(NodeVisitor):
def _fstring_FormattedValue(self, node, write):
write("{")
unparser = type(self)()
unparser = type(self)(_avoid_backslashes=True)
unparser.set_precedence(_Precedence.TEST.next(), node.value)
expr = unparser.visit(node.value)
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
if "\\" in expr:
raise ValueError("Unable to avoid backslash in f-string expression part")
write(expr)
if node.conversion != -1:
conversion = chr(node.conversion)
@ -1111,33 +1189,17 @@ class _Unparser(NodeVisitor):
self.write(node.id)
def _write_docstring(self, node):
def esc_char(c):
if c in ("\n", "\t"):
# In the AST form, we don't know the author's intentation
# about how this should be displayed. We'll only escape
# \n and \t, because they are more likely to be unescaped
# in the source
return c
return c.encode('unicode_escape').decode('ascii')
self.fill()
if node.kind == "u":
self.write("u")
value = node.value
if value:
# Preserve quotes in the docstring by escaping them
value = "".join(map(esc_char, value))
if value[-1] == '"':
value = value.replace('"', '\\"', -1)
value = value.replace('"""', '""\\"')
self.write(f'"""{value}"""')
self._write_str_avoiding_backslashes(node.value, quote_types=_MULTI_QUOTES)
def _write_constant(self, value):
if isinstance(value, (float, complex)):
# Substitute overflowing decimal literal for AST infinities.
self.write(repr(value).replace("inf", _INFSTR))
elif self._avoid_backslashes and isinstance(value, str):
self._write_str_avoiding_backslashes(value)
else:
self.write(repr(value))