gh-104400: pygettext: use an AST parser instead of a tokenizer (GH-104402)

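This greatly simplifies the code and fixes many corner cases: messages inside nested functions and default argument values are now extracted, docstrings of nested functions and classes are now extracted, multiline docstrings are cleaned up, and a string literal followed by `.format()` is no longer mistaken for a docstring.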
This commit is contained in:
Tomas R. 2025-02-11 12:51:42 +01:00 committed by GitHub
parent 1da412e574
commit 374abded07
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 174 additions and 251 deletions

View file

@ -15,26 +15,40 @@ msgstr ""
"Generated-By: pygettext.py 1.5\n" "Generated-By: pygettext.py 1.5\n"
#: docstrings.py:7 #: docstrings.py:1
#, docstring
msgid "Module docstring"
msgstr ""
#: docstrings.py:9
#, docstring #, docstring
msgid "" msgid ""
msgstr "" msgstr ""
#: docstrings.py:18 #: docstrings.py:15
#, docstring
msgid "docstring"
msgstr ""
#: docstrings.py:20
#, docstring #, docstring
msgid "" msgid ""
"multiline\n" "multiline\n"
" docstring\n" "docstring"
" "
msgstr "" msgstr ""
#: docstrings.py:25 #: docstrings.py:27
#, docstring #, docstring
msgid "docstring1" msgid "docstring1"
msgstr "" msgstr ""
#: docstrings.py:30 #: docstrings.py:38
#, docstring #, docstring
msgid "Hello, {}!" msgid "nested docstring"
msgstr ""
#: docstrings.py:43
#, docstring
msgid "nested class docstring"
msgstr "" msgstr ""

View file

@ -1,3 +1,5 @@
"""Module docstring"""
# Test docstring extraction # Test docstring extraction
from gettext import gettext as _ from gettext import gettext as _
@ -10,10 +12,10 @@ def test(x):
# Leading empty line # Leading empty line
def test2(x): def test2(x):
"""docstring""" # XXX This should be extracted but isn't. """docstring"""
# XXX Multiline docstrings should be cleaned with `inspect.cleandoc`. # Multiline docstrings are cleaned with `inspect.cleandoc`.
def test3(x): def test3(x):
"""multiline """multiline
docstring docstring
@ -27,15 +29,15 @@ def test4(x):
def test5(x): def test5(x):
"""Hello, {}!""".format("world!") # XXX This should not be extracted. """Hello, {}!""".format("world!") # This should not be extracted.
# Nested docstrings # Nested docstrings
def test6(x): def test6(x):
def inner(y): def inner(y):
"""nested docstring""" # XXX This should be extracted but isn't. """nested docstring"""
class Outer: class Outer:
class Inner: class Inner:
"nested class docstring" # XXX This should be extracted but isn't. "nested class docstring"

View file

@ -19,22 +19,22 @@ msgstr ""
msgid "" msgid ""
msgstr "" msgstr ""
#: messages.py:19 messages.py:20 #: messages.py:19 messages.py:20 messages.py:21
msgid "parentheses" msgid "parentheses"
msgstr "" msgstr ""
#: messages.py:23 #: messages.py:24
msgid "Hello, world!" msgid "Hello, world!"
msgstr "" msgstr ""
#: messages.py:26 #: messages.py:27
msgid "" msgid ""
"Hello,\n" "Hello,\n"
" multiline!\n" " multiline!\n"
msgstr "" msgstr ""
#: messages.py:46 messages.py:89 messages.py:90 messages.py:93 messages.py:94 #: messages.py:46 messages.py:89 messages.py:90 messages.py:93 messages.py:94
#: messages.py:99 #: messages.py:99 messages.py:100 messages.py:101
msgid "foo" msgid "foo"
msgid_plural "foos" msgid_plural "foos"
msgstr[0] "" msgstr[0] ""
@ -68,7 +68,7 @@ msgstr ""
msgid "set" msgid "set"
msgstr "" msgstr ""
#: messages.py:63 #: messages.py:62 messages.py:63
msgid "nested string" msgid "nested string"
msgstr "" msgstr ""
@ -76,6 +76,10 @@ msgstr ""
msgid "baz" msgid "baz"
msgstr "" msgstr ""
#: messages.py:71 messages.py:75
msgid "default value"
msgstr ""
#: messages.py:91 messages.py:92 messages.py:95 messages.py:96 #: messages.py:91 messages.py:92 messages.py:95 messages.py:96
msgctxt "context" msgctxt "context"
msgid "foo" msgid "foo"
@ -83,7 +87,13 @@ msgid_plural "foos"
msgstr[0] "" msgstr[0] ""
msgstr[1] "" msgstr[1] ""
#: messages.py:100 #: messages.py:102
msgid "domain foo" msgid "domain foo"
msgstr "" msgstr ""
#: messages.py:118 messages.py:119
msgid "world"
msgid_plural "worlds"
msgstr[0] ""
msgstr[1] ""

View file

@ -18,6 +18,7 @@ _("")
# Extra parentheses # Extra parentheses
(_("parentheses")) (_("parentheses"))
((_("parentheses"))) ((_("parentheses")))
_(("parentheses"))
# Multiline strings # Multiline strings
_("Hello, " _("Hello, "
@ -32,7 +33,6 @@ _()
_(None) _(None)
_(1) _(1)
_(False) _(False)
_(("invalid"))
_(["invalid"]) _(["invalid"])
_({"invalid"}) _({"invalid"})
_("string"[3]) _("string"[3])
@ -40,7 +40,7 @@ _("string"[:3])
_({"string": "foo"}) _({"string": "foo"})
# pygettext does not allow keyword arguments, but both xgettext and pybabel do # pygettext does not allow keyword arguments, but both xgettext and pybabel do
_(x="kwargs work!") _(x="kwargs are not allowed!")
# Unusual, but valid arguments # Unusual, but valid arguments
_("foo", "bar") _("foo", "bar")
@ -48,7 +48,7 @@ _("something", x="something else")
# .format() # .format()
_("Hello, {}!").format("world") # valid _("Hello, {}!").format("world") # valid
_("Hello, {}!".format("world")) # invalid, but xgettext and pybabel extract the first string _("Hello, {}!".format("world")) # invalid, but xgettext extracts the first string
# Nested structures # Nested structures
_("1"), _("2") _("1"), _("2")
@ -59,7 +59,7 @@ obj = {'a': _("A"), 'b': _("B")}
# Nested functions and classes # Nested functions and classes
def test(): def test():
_("nested string") # XXX This should be extracted but isn't. _("nested string")
[_("nested string")] [_("nested string")]
@ -68,11 +68,11 @@ class Foo:
return _("baz") return _("baz")
def bar(x=_('default value')): # XXX This should be extracted but isn't. def bar(x=_('default value')):
pass pass
def baz(x=[_('default value')]): # XXX This should be extracted but isn't. def baz(x=[_('default value')]):
pass pass
@ -97,6 +97,8 @@ dnpgettext("domain", "context", "foo", "foos", 1)
# Complex arguments # Complex arguments
ngettext("foo", "foos", 42 + (10 - 20)) ngettext("foo", "foos", 42 + (10 - 20))
ngettext("foo", "foos", *args)
ngettext("foo", "foos", **kwargs)
dgettext(["some", {"complex"}, ("argument",)], "domain foo") dgettext(["some", {"complex"}, ("argument",)], "domain foo")
# Invalid calls which are not extracted # Invalid calls which are not extracted
@ -108,3 +110,10 @@ dgettext('domain')
dngettext('domain', 'foo') dngettext('domain', 'foo')
dpgettext('domain', 'context') dpgettext('domain', 'context')
dnpgettext('domain', 'context', 'foo') dnpgettext('domain', 'context', 'foo')
dgettext(*args, 'foo')
dpgettext(*args, 'context', 'foo')
dnpgettext(*args, 'context', 'foo', 'foos')
# f-strings
f"Hello, {_('world')}!"
f"Hello, {ngettext('world', 'worlds', 3)}!"

View file

@ -87,7 +87,7 @@ class Test_pygettext(unittest.TestCase):
self.maxDiff = None self.maxDiff = None
self.assertEqual(normalize_POT_file(expected), normalize_POT_file(actual)) self.assertEqual(normalize_POT_file(expected), normalize_POT_file(actual))
def extract_from_str(self, module_content, *, args=(), strict=True): def extract_from_str(self, module_content, *, args=(), strict=True, with_stderr=False):
"""Return all msgids extracted from module_content.""" """Return all msgids extracted from module_content."""
filename = 'test.py' filename = 'test.py'
with temp_cwd(None): with temp_cwd(None):
@ -98,12 +98,18 @@ class Test_pygettext(unittest.TestCase):
self.assertEqual(res.err, b'') self.assertEqual(res.err, b'')
with open('messages.pot', encoding='utf-8') as fp: with open('messages.pot', encoding='utf-8') as fp:
data = fp.read() data = fp.read()
return self.get_msgids(data) msgids = self.get_msgids(data)
if not with_stderr:
return msgids
return msgids, res.err
def extract_docstrings_from_str(self, module_content): def extract_docstrings_from_str(self, module_content):
"""Return all docstrings extracted from module_content.""" """Return all docstrings extracted from module_content."""
return self.extract_from_str(module_content, args=('--docstrings',), strict=False) return self.extract_from_str(module_content, args=('--docstrings',), strict=False)
def get_stderr(self, module_content):
return self.extract_from_str(module_content, strict=False, with_stderr=True)[1]
def test_header(self): def test_header(self):
"""Make sure the required fields are in the header, according to: """Make sure the required fields are in the header, according to:
http://www.gnu.org/software/gettext/manual/gettext.html#Header-Entry http://www.gnu.org/software/gettext/manual/gettext.html#Header-Entry
@ -407,6 +413,24 @@ class Test_pygettext(unittest.TestCase):
self.assertIn(f'msgid "{text2}"', data) self.assertIn(f'msgid "{text2}"', data)
self.assertNotIn(text3, data) self.assertNotIn(text3, data)
def test_error_messages(self):
"""Test that pygettext outputs error messages to stderr."""
stderr = self.get_stderr(dedent('''\
_(1+2)
ngettext('foo')
dgettext(*args, 'foo')
'''))
# Normalize line endings on Windows
stderr = stderr.decode('utf-8').replace('\r', '')
self.assertEqual(
stderr,
"*** test.py:1: Expected a string constant for argument 1, got 1 + 2\n"
"*** test.py:2: Expected at least 2 positional argument(s) in gettext call, got 1\n"
"*** test.py:3: Variable positional arguments are not allowed in gettext calls\n"
)
def update_POT_snapshots(): def update_POT_snapshots():
for input_file in DATA_DIR.glob('*.py'): for input_file in DATA_DIR.glob('*.py'):

View file

@ -0,0 +1 @@
Fix several message extraction bugs in :program:`pygettext` by switching from a tokenizer to an AST parser.

View file

@ -140,8 +140,6 @@ import importlib.util
import os import os
import sys import sys
import time import time
import tokenize
from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from operator import itemgetter from operator import itemgetter
@ -206,15 +204,6 @@ def escape_nonascii(s, encoding):
return ''.join(escapes[b] for b in s.encode(encoding)) return ''.join(escapes[b] for b in s.encode(encoding))
def is_literal_string(s):
return s[0] in '\'"' or (s[0] in 'rRuU' and s[1] in '\'"')
def safe_eval(s):
# unwrap quotes, safely
return eval(s, {'__builtins__':{}}, {})
def normalize(s, encoding): def normalize(s, encoding):
# This converts the various Python string types into a format that is # This converts the various Python string types into a format that is
# appropriate for .po files, namely much closer to C style. # appropriate for .po files, namely much closer to C style.
@ -296,11 +285,6 @@ DEFAULTKEYWORDS = {
} }
def matches_spec(message, spec):
"""Check if a message has all the keys defined by the keyword spec."""
return all(key in message for key in spec.values())
@dataclass(frozen=True) @dataclass(frozen=True)
class Location: class Location:
filename: str filename: str
@ -325,203 +309,91 @@ class Message:
self.is_docstring |= is_docstring self.is_docstring |= is_docstring
class TokenEater: class GettextVisitor(ast.NodeVisitor):
def __init__(self, options): def __init__(self, options):
self.__options = options super().__init__()
self.__messages = {} self.options = options
self.__state = self.__waiting self.filename = None
self.__data = defaultdict(str) self.messages = {}
self.__curr_arg = 0
self.__curr_keyword = None
self.__lineno = -1
self.__freshmodule = 1
self.__curfile = None
self.__enclosurecount = 0
def __call__(self, ttype, tstring, stup, etup, line): def visit_file(self, node, filename):
# dispatch self.filename = filename
## import token self.visit(node)
## print('ttype:', token.tok_name[ttype], 'tstring:', tstring,
## file=sys.stderr)
self.__state(ttype, tstring, stup[0])
@property def visit_Module(self, node):
def messages(self): self._extract_docstring(node)
return self.__messages self.generic_visit(node)
def __waiting(self, ttype, tstring, lineno): visit_ClassDef = visit_FunctionDef = visit_AsyncFunctionDef = visit_Module
opts = self.__options
# Do docstring extractions, if enabled
if opts.docstrings and not opts.nodocstrings.get(self.__curfile):
# module docstring?
if self.__freshmodule:
if ttype == tokenize.STRING and is_literal_string(tstring):
self.__addentry({'msgid': safe_eval(tstring)}, lineno, is_docstring=True)
self.__freshmodule = 0
return
if ttype in (tokenize.COMMENT, tokenize.NL, tokenize.ENCODING):
return
self.__freshmodule = 0
# class or func/method docstring?
if ttype == tokenize.NAME and tstring in ('class', 'def'):
self.__state = self.__suiteseen
return
if ttype == tokenize.NAME and tstring in ('class', 'def'):
self.__state = self.__ignorenext
return
if ttype == tokenize.NAME and tstring in opts.keywords:
self.__state = self.__keywordseen
self.__curr_keyword = tstring
return
if ttype == tokenize.STRING:
maybe_fstring = ast.parse(tstring, mode='eval').body
if not isinstance(maybe_fstring, ast.JoinedStr):
return
for value in filter(lambda node: isinstance(node, ast.FormattedValue),
maybe_fstring.values):
for call in filter(lambda node: isinstance(node, ast.Call),
ast.walk(value)):
func = call.func
if isinstance(func, ast.Name):
func_name = func.id
elif isinstance(func, ast.Attribute):
func_name = func.attr
else:
continue
if func_name not in opts.keywords: def visit_Call(self, node):
continue self._extract_message(node)
if len(call.args) != 1: self.generic_visit(node)
print((
'*** %(file)s:%(lineno)s: Seen unexpected amount of'
' positional arguments in gettext call: %(source_segment)s'
) % {
'source_segment': ast.get_source_segment(tstring, call) or tstring,
'file': self.__curfile,
'lineno': lineno
}, file=sys.stderr)
continue
if call.keywords:
print((
'*** %(file)s:%(lineno)s: Seen unexpected keyword arguments'
' in gettext call: %(source_segment)s'
) % {
'source_segment': ast.get_source_segment(tstring, call) or tstring,
'file': self.__curfile,
'lineno': lineno
}, file=sys.stderr)
continue
arg = call.args[0]
if not isinstance(arg, ast.Constant):
print((
'*** %(file)s:%(lineno)s: Seen unexpected argument type'
' in gettext call: %(source_segment)s'
) % {
'source_segment': ast.get_source_segment(tstring, call) or tstring,
'file': self.__curfile,
'lineno': lineno
}, file=sys.stderr)
continue
if isinstance(arg.value, str):
self.__curr_keyword = func_name
self.__addentry({'msgid': arg.value}, lineno)
def __suiteseen(self, ttype, tstring, lineno): def _extract_docstring(self, node):
# skip over any enclosure pairs until we see the colon if (not self.options.docstrings or
if ttype == tokenize.OP: self.options.nodocstrings.get(self.filename)):
if tstring == ':' and self.__enclosurecount == 0:
# we see a colon and we're not in an enclosure: end of def
self.__state = self.__suitedocstring
elif tstring in '([{':
self.__enclosurecount += 1
elif tstring in ')]}':
self.__enclosurecount -= 1
def __suitedocstring(self, ttype, tstring, lineno):
# ignore any intervening noise
if ttype == tokenize.STRING and is_literal_string(tstring):
self.__addentry({'msgid': safe_eval(tstring)}, lineno, is_docstring=True)
self.__state = self.__waiting
elif ttype not in (tokenize.NEWLINE, tokenize.INDENT,
tokenize.COMMENT):
# there was no class docstring
self.__state = self.__waiting
def __keywordseen(self, ttype, tstring, lineno):
if ttype == tokenize.OP and tstring == '(':
self.__data.clear()
self.__curr_arg = 0
self.__enclosurecount = 0
self.__lineno = lineno
self.__state = self.__openseen
else:
self.__state = self.__waiting
def __openseen(self, ttype, tstring, lineno):
spec = self.__options.keywords[self.__curr_keyword]
arg_type = spec.get(self.__curr_arg)
expect_string_literal = arg_type is not None
if ttype == tokenize.OP and self.__enclosurecount == 0:
if tstring == ')':
# We've seen the last of the translatable strings. Record the
# line number of the first line of the strings and update the list
# of messages seen. Reset state for the next batch. If there
# were no strings inside _(), then just ignore this entry.
if self.__data:
self.__addentry(self.__data)
self.__state = self.__waiting
return
elif tstring == ',':
# Advance to the next argument
self.__curr_arg += 1
return return
if expect_string_literal: docstring = ast.get_docstring(node)
if ttype == tokenize.STRING and is_literal_string(tstring): if docstring is not None:
self.__data[arg_type] += safe_eval(tstring) lineno = node.body[0].lineno # The first statement is the docstring
elif ttype not in (tokenize.COMMENT, tokenize.INDENT, tokenize.DEDENT, self._add_message(lineno, docstring, is_docstring=True)
tokenize.NEWLINE, tokenize.NL):
# We are inside an argument which is a translatable string and
# we encountered a token that is not a string. This is an error.
self.warn_unexpected_token(tstring)
self.__enclosurecount = 0
self.__state = self.__waiting
elif ttype == tokenize.OP:
if tstring in '([{':
self.__enclosurecount += 1
elif tstring in ')]}':
self.__enclosurecount -= 1
def __ignorenext(self, ttype, tstring, lineno): def _extract_message(self, node):
self.__state = self.__waiting func_name = self._get_func_name(node)
spec = self.options.keywords.get(func_name)
if spec is None:
return
def __addentry(self, msg, lineno=None, *, is_docstring=False): max_index = max(spec)
msgid = msg.get('msgid') has_var_positional = any(isinstance(arg, ast.Starred) for
if msgid in self.__options.toexclude: arg in node.args[:max_index+1])
if has_var_positional:
print(f'*** {self.filename}:{node.lineno}: Variable positional '
f'arguments are not allowed in gettext calls', file=sys.stderr)
return return
if not is_docstring:
spec = self.__options.keywords[self.__curr_keyword] if max_index >= len(node.args):
if not matches_spec(msg, spec): print(f'*** {self.filename}:{node.lineno}: Expected at least '
f'{max(spec) + 1} positional argument(s) in gettext call, '
f'got {len(node.args)}', file=sys.stderr)
return return
if lineno is None:
lineno = self.__lineno msg_data = {}
msgctxt = msg.get('msgctxt') for position, arg_type in spec.items():
msgid_plural = msg.get('msgid_plural') arg = node.args[position]
if not self._is_string_const(arg):
print(f'*** {self.filename}:{arg.lineno}: Expected a string '
f'constant for argument {position + 1}, '
f'got {ast.unparse(arg)}', file=sys.stderr)
return
msg_data[arg_type] = arg.value
lineno = node.lineno
self._add_message(lineno, **msg_data)
def _add_message(
self, lineno, msgid, msgid_plural=None, msgctxt=None, *,
is_docstring=False):
if msgid in self.options.toexclude:
return
key = self._key_for(msgid, msgctxt) key = self._key_for(msgid, msgctxt)
if key in self.__messages: message = self.messages.get(key)
self.__messages[key].add_location( if message:
self.__curfile, message.add_location(
self.filename,
lineno, lineno,
msgid_plural, msgid_plural,
is_docstring=is_docstring, is_docstring=is_docstring,
) )
else: else:
self.__messages[key] = Message( self.messages[key] = Message(
msgid=msgid, msgid=msgid,
msgid_plural=msgid_plural, msgid_plural=msgid_plural,
msgctxt=msgctxt, msgctxt=msgctxt,
locations={Location(self.__curfile, lineno)}, locations={Location(self.filename, lineno)},
is_docstring=is_docstring, is_docstring=is_docstring,
) )
@ -531,19 +403,17 @@ class TokenEater:
return (msgctxt, msgid) return (msgctxt, msgid)
return msgid return msgid
def warn_unexpected_token(self, token): def _get_func_name(self, node):
print(( match node.func:
'*** %(file)s:%(lineno)s: Seen unexpected token "%(token)s"' case ast.Name(id=id):
) % { return id
'token': token, case ast.Attribute(attr=attr):
'file': self.__curfile, return attr
'lineno': self.__lineno case _:
}, file=sys.stderr) return None
def set_filename(self, filename):
self.__curfile = filename
self.__freshmodule = 1
def _is_string_const(self, node):
return isinstance(node, ast.Constant) and isinstance(node.value, str)
def write_pot_file(messages, options, fp): def write_pot_file(messages, options, fp):
timestamp = time.strftime('%Y-%m-%d %H:%M%z') timestamp = time.strftime('%Y-%m-%d %H:%M%z')
@ -717,31 +587,24 @@ def main():
args = expanded args = expanded
# slurp through all the files # slurp through all the files
eater = TokenEater(options) visitor = GettextVisitor(options)
for filename in args: for filename in args:
if filename == '-': if filename == '-':
if options.verbose: if options.verbose:
print('Reading standard input') print('Reading standard input')
fp = sys.stdin.buffer source = sys.stdin.buffer.read()
closep = 0
else: else:
if options.verbose: if options.verbose:
print(f'Working on {filename}') print(f'Working on {filename}')
fp = open(filename, 'rb') with open(filename, 'rb') as fp:
closep = 1 source = fp.read()
try: try:
eater.set_filename(filename) module_tree = ast.parse(source)
try: except SyntaxError:
tokens = tokenize.tokenize(fp.readline) continue
for _token in tokens:
eater(*_token) visitor.visit_file(module_tree, filename)
except tokenize.TokenError as e:
print('%s: %s, line %d, column %d' % (
e.args[0], filename, e.args[1][0], e.args[1][1]),
file=sys.stderr)
finally:
if closep:
fp.close()
# write the output # write the output
if options.outfile == '-': if options.outfile == '-':
@ -753,7 +616,7 @@ def main():
fp = open(options.outfile, 'w') fp = open(options.outfile, 'w')
closep = 1 closep = 1
try: try:
write_pot_file(eater.messages, options, fp) write_pot_file(visitor.messages, options, fp)
finally: finally:
if closep: if closep:
fp.close() fp.close()