bpo-43417: Better buffer handling for ast.unparse (GH-24772)

This commit is contained in:
Batuhan Taskaya 2021-05-09 02:32:04 +03:00 committed by GitHub
parent a0bd9e9c11
commit 3d98ececda
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 87 additions and 66 deletions

View file

@ -678,7 +678,6 @@ class _Unparser(NodeVisitor):
def __init__(self, *, _avoid_backslashes=False):
self._source = []
self._buffer = []
self._precedences = {}
self._type_ignores = {}
self._indent = 0
@ -721,14 +720,15 @@ class _Unparser(NodeVisitor):
"""Append a piece of text"""
self._source.append(text)
def buffer_writer(self, text):
self._buffer.append(text)
@contextmanager
def buffered(self, buffer = None):
if buffer is None:
buffer = []
@property
def buffer(self):
value = "".join(self._buffer)
self._buffer.clear()
return value
original_source = self._source
self._source = buffer
yield buffer
self._source = original_source
@contextmanager
def block(self, *, extra = None):
@ -1127,9 +1127,9 @@ class _Unparser(NodeVisitor):
def visit_JoinedStr(self, node):
self.write("f")
if self._avoid_backslashes:
self._fstring_JoinedStr(node, self.buffer_writer)
self._write_str_avoiding_backslashes(self.buffer)
return
with self.buffered() as buffer:
self._write_fstring_inner(node)
return self._write_str_avoiding_backslashes("".join(buffer))
# If we don't need to avoid backslashes globally (i.e., we only need
# to avoid them inside FormattedValues), it's cosmetically preferred
@ -1137,60 +1137,62 @@ class _Unparser(NodeVisitor):
# for cases like: f"{x}\n". To accomplish this, we keep track of what
# in our buffer corresponds to FormattedValues and what corresponds to
# Constant parts of the f-string, and allow escapes accordingly.
buffer = []
fstring_parts = []
for value in node.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, self.buffer_writer)
buffer.append((self.buffer, isinstance(value, Constant)))
new_buffer = []
quote_types = _ALL_QUOTES
for value, is_constant in buffer:
# Repeatedly narrow down the list of possible quote_types
value, quote_types = self._str_literal_helper(
value, quote_types=quote_types,
escape_special_whitespace=is_constant
with self.buffered() as buffer:
self._write_fstring_inner(value)
fstring_parts.append(
("".join(buffer), isinstance(value, Constant))
)
new_buffer.append(value)
value = "".join(new_buffer)
new_fstring_parts = []
quote_types = list(_ALL_QUOTES)
for value, is_constant in fstring_parts:
value, quote_types = self._str_literal_helper(
value,
quote_types=quote_types,
escape_special_whitespace=is_constant,
)
new_fstring_parts.append(value)
value = "".join(new_fstring_parts)
quote_type = quote_types[0]
self.write(f"{quote_type}{value}{quote_type}")
def _write_fstring_inner(self, node):
if isinstance(node, JoinedStr):
# for both the f-string itself, and format_spec
for value in node.values:
self._write_fstring_inner(value)
elif isinstance(node, Constant) and isinstance(node.value, str):
value = node.value.replace("{", "{{").replace("}", "}}")
self.write(value)
elif isinstance(node, FormattedValue):
self.visit_FormattedValue(node)
else:
raise ValueError(f"Unexpected node inside JoinedStr, {node!r}")
def visit_FormattedValue(self, node):
self.write("f")
self._fstring_FormattedValue(node, self.buffer_writer)
self._write_str_avoiding_backslashes(self.buffer)
def unparse_inner(inner):
unparser = type(self)(_avoid_backslashes=True)
unparser.set_precedence(_Precedence.TEST.next(), inner)
return unparser.visit(inner)
def _fstring_JoinedStr(self, node, write):
for value in node.values:
meth = getattr(self, "_fstring_" + type(value).__name__)
meth(value, write)
def _fstring_Constant(self, node, write):
if not isinstance(node.value, str):
raise ValueError("Constants inside JoinedStr should be a string.")
value = node.value.replace("{", "{{").replace("}", "}}")
write(value)
def _fstring_FormattedValue(self, node, write):
write("{")
unparser = type(self)(_avoid_backslashes=True)
unparser.set_precedence(_Precedence.TEST.next(), node.value)
expr = unparser.visit(node.value)
if expr.startswith("{"):
write(" ") # Separate pair of opening brackets as "{ {"
if "\\" in expr:
raise ValueError("Unable to avoid backslash in f-string expression part")
write(expr)
if node.conversion != -1:
conversion = chr(node.conversion)
if conversion not in "sra":
raise ValueError("Unknown f-string conversion.")
write(f"!{conversion}")
if node.format_spec:
write(":")
meth = getattr(self, "_fstring_" + type(node.format_spec).__name__)
meth(node.format_spec, write)
write("}")
with self.delimit("{", "}"):
expr = unparse_inner(node.value)
if "\\" in expr:
raise ValueError(
"Unable to avoid backslash in f-string expression part"
)
if expr.startswith("{"):
# Separate pair of opening brackets as "{ {"
self.write(" ")
self.write(expr)
if node.conversion != -1:
self.write(f"!{chr(node.conversion)}")
if node.format_spec:
self.write(":")
self._write_fstring_inner(node.format_spec)
def visit_Name(self, node):
self.write(node.id)