gh-106727: Make inspect.getsource smarter for class for same name definitions (#106815)

This commit is contained in:
Tian Gao 2023-07-18 15:20:31 -08:00 committed by GitHub
parent 505eede38d
commit 663854d73b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 71 additions and 12 deletions

View file

@ -1034,9 +1034,13 @@ class ClassFoundException(Exception):
class _ClassFinder(ast.NodeVisitor): class _ClassFinder(ast.NodeVisitor):
def __init__(self, qualname): def __init__(self, cls, tree, lines, qualname):
self.stack = [] self.stack = []
self.cls = cls
self.tree = tree
self.lines = lines
self.qualname = qualname self.qualname = qualname
self.lineno_found = []
def visit_FunctionDef(self, node): def visit_FunctionDef(self, node):
self.stack.append(node.name) self.stack.append(node.name)
@ -1057,11 +1061,48 @@ class _ClassFinder(ast.NodeVisitor):
line_number = node.lineno line_number = node.lineno
# decrement by one since lines starts with indexing by zero # decrement by one since lines starts with indexing by zero
line_number -= 1 self.lineno_found.append((line_number - 1, node.end_lineno))
raise ClassFoundException(line_number)
self.generic_visit(node) self.generic_visit(node)
self.stack.pop() self.stack.pop()
def get_lineno(self):
self.visit(self.tree)
lineno_found_number = len(self.lineno_found)
if lineno_found_number == 0:
raise OSError('could not find class definition')
elif lineno_found_number == 1:
return self.lineno_found[0][0]
else:
# We have multiple candidates for the class definition.
# Now we have to guess.
# First, let's see if there are any method definitions
for member in self.cls.__dict__.values():
if isinstance(member, types.FunctionType):
for lineno, end_lineno in self.lineno_found:
if lineno <= member.__code__.co_firstlineno <= end_lineno:
return lineno
class_strings = [(''.join(self.lines[lineno: end_lineno]), lineno)
for lineno, end_lineno in self.lineno_found]
# Maybe the class has a docstring and it's unique?
if self.cls.__doc__:
ret = None
for candidate, lineno in class_strings:
if self.cls.__doc__.strip() in candidate:
if ret is None:
ret = lineno
else:
break
else:
if ret is not None:
return ret
# We are out of ideas, just return the last one found, which is
# slightly better than previous ones
return self.lineno_found[-1][0]
def findsource(object): def findsource(object):
"""Return the entire source file and starting line number for an object. """Return the entire source file and starting line number for an object.
@ -1098,14 +1139,8 @@ def findsource(object):
qualname = object.__qualname__ qualname = object.__qualname__
source = ''.join(lines) source = ''.join(lines)
tree = ast.parse(source) tree = ast.parse(source)
class_finder = _ClassFinder(qualname) class_finder = _ClassFinder(object, tree, lines, qualname)
try: return lines, class_finder.get_lineno()
class_finder.visit(tree)
except ClassFoundException as e:
line_number = e.args[0]
return lines, line_number
else:
raise OSError('could not find class definition')
if ismethod(object): if ismethod(object):
object = object.__func__ object = object.__func__

View file

@ -290,3 +290,23 @@ post_line_parenthesized_lambda1 = (lambda: ()
nested_lambda = ( nested_lambda = (
lambda right: [].map( lambda right: [].map(
lambda length: ())) lambda length: ()))
# line 294
if True:
class cls296:
def f():
pass
else:
class cls296:
def g():
pass
# line 304
if False:
class cls310:
def f():
pass
else:
class cls310:
def g():
pass

View file

@ -949,7 +949,6 @@ class TestBuggyCases(GetSourceBase):
self.assertSourceEqual(mod2.cls196.cls200, 198, 201) self.assertSourceEqual(mod2.cls196.cls200, 198, 201)
def test_class_inside_conditional(self): def test_class_inside_conditional(self):
self.assertSourceEqual(mod2.cls238, 238, 240)
self.assertSourceEqual(mod2.cls238.cls239, 239, 240) self.assertSourceEqual(mod2.cls238.cls239, 239, 240)
def test_multiple_children_classes(self): def test_multiple_children_classes(self):
@ -975,6 +974,10 @@ class TestBuggyCases(GetSourceBase):
self.assertSourceEqual(mod2.cls226, 231, 235) self.assertSourceEqual(mod2.cls226, 231, 235)
self.assertSourceEqual(asyncio.run(mod2.cls226().func232()), 233, 234) self.assertSourceEqual(asyncio.run(mod2.cls226().func232()), 233, 234)
def test_class_definition_same_name_diff_methods(self):
self.assertSourceEqual(mod2.cls296, 296, 298)
self.assertSourceEqual(mod2.cls310, 310, 312)
class TestNoEOL(GetSourceBase): class TestNoEOL(GetSourceBase):
def setUp(self): def setUp(self):
self.tempdir = TESTFN + '_dir' self.tempdir = TESTFN + '_dir'

View file

@ -0,0 +1 @@
Make :func:`inspect.getsource` smarter for class for same name definitions