mirror of
https://github.com/python/cpython.git
synced 2025-07-23 11:15:24 +00:00
merge 2to3 improvements
This commit is contained in:
parent afcd5f36f0
commit dd6a4edc45
18 changed files with 405 additions and 140 deletions
Lib/lib2to3/fixer_base.py

@@ -33,6 +33,8 @@ class BaseFix(object):
     explicit = False # Is this ignored by refactor.py -f all?
     run_order = 5   # Fixers will be sorted by run order before execution
                     # Lower numbers will be run first.
+    _accept_type = None # [Advanced and not public] This tells RefactoringTool
+                        # which node type to accept when there's not a pattern.
 
     # Shortcut for access to Python grammar symbols
     syms = pygram.python_symbols
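The new _accept_type hook lets a fixer bypass the pattern compiler entirely: the refactoring engine only hands it nodes of the declared token type, and the fixer's overridden match() does the rest (fix_ne and fix_numliterals below are converted to this style). A minimal sketch of such a fixer; FixDiamond is a hypothetical name, the mechanics mirror fix_ne:

# Sketch of a pattern-less fixer driven by _accept_type; FixDiamond is
# hypothetical, the mechanics mirror fix_ne from this commit.
from lib2to3 import fixer_base, pytree
from lib2to3.pgen2 import token

class FixDiamond(fixer_base.BaseFix):
    _accept_type = token.NOTEQUAL  # only NOTEQUAL leaves reach match()

    def match(self, node):
        # No compiled pattern; just inspect the leaf directly.
        return node.value == "<>"

    def transform(self, node, results):
        return pytree.Leaf(token.NOTEQUAL, "!=", prefix=node.prefix)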
Lib/lib2to3/fixes/fix_import.py

@@ -12,7 +12,7 @@ Becomes:
 
 # Local imports
 from .. import fixer_base
-from os.path import dirname, join, exists, pathsep
+from os.path import dirname, join, exists, sep
 from ..fixer_util import FromImport, syms, token
 
 
@@ -84,7 +84,7 @@ class FixImport(fixer_base.BaseFix):
         # so can't be a relative import.
         if not exists(join(dirname(base_path), '__init__.py')):
             return False
-        for ext in ['.py', pathsep, '.pyc', '.so', '.sl', '.pyd']:
+        for ext in ['.py', sep, '.pyc', '.so', '.sl', '.pyd']:
             if exists(base_path + ext):
                 return True
         return False
Lib/lib2to3/fixes/fix_imports.py

@@ -84,8 +84,6 @@ def build_pattern(mapping=MAPPING):
 
 class FixImports(fixer_base.BaseFix):
 
-    order = "pre" # Pre-order tree traversal
-
     # This is overridden in fix_imports2.
     mapping = MAPPING
 
Lib/lib2to3/fixes/fix_long.py

@@ -5,18 +5,15 @@
 """
 
 # Local imports
-from .. import fixer_base
-from ..fixer_util import Name, Number, is_probably_builtin
+from lib2to3 import fixer_base
+from lib2to3.fixer_util import is_probably_builtin
 
 
 class FixLong(fixer_base.BaseFix):
 
     PATTERN = "'long'"
 
-    static_int = Name("int")
-
     def transform(self, node, results):
         if is_probably_builtin(node):
-            new = self.static_int.clone()
-            new.prefix = node.prefix
-            return new
+            node.value = "int"
+            node.changed()
Lib/lib2to3/fixes/fix_ne.py

@@ -12,9 +12,11 @@ from .. import fixer_base
 class FixNe(fixer_base.BaseFix):
     # This is so simple that we don't need the pattern compiler.
 
+    _accept_type = token.NOTEQUAL
+
     def match(self, node):
         # Override
-        return node.type == token.NOTEQUAL and node.value == "<>"
+        return node.value == "<>"
 
     def transform(self, node, results):
         new = pytree.Leaf(token.NOTEQUAL, "!=", prefix=node.prefix)
Lib/lib2to3/fixes/fix_numliterals.py

@@ -12,10 +12,11 @@ from ..fixer_util import Number
 class FixNumliterals(fixer_base.BaseFix):
     # This is so simple that we don't need the pattern compiler.
 
+    _accept_type = token.NUMBER
+
     def match(self, node):
         # Override
-        return (node.type == token.NUMBER and
-                (node.value.startswith("0") or node.value[-1] in "Ll"))
+        return (node.value.startswith("0") or node.value[-1] in "Ll")
 
     def transform(self, node, results):
         val = node.value
40  Lib/lib2to3/fixes/fix_operator.py  (new file)
@@ -0,0 +1,40 @@
+"""Fixer for operator.{isCallable,sequenceIncludes}
+
+operator.isCallable(obj)       -> hasattr(obj, '__call__')
+operator.sequenceIncludes(obj) -> operator.contains(obj)
+"""
+
+# Local imports
+from .. import fixer_base
+from ..fixer_util import Call, Name, String
+
+class FixOperator(fixer_base.BaseFix):
+
+    methods = "method=('isCallable'|'sequenceIncludes')"
+    func = "'(' func=any ')'"
+    PATTERN = """
+              power< module='operator'
+                trailer< '.' {methods} > trailer< {func} > >
+              |
+              power< {methods} trailer< {func} > >
+              """.format(methods=methods, func=func)
+
+    def transform(self, node, results):
+        method = results["method"][0]
+
+        if method.value == "sequenceIncludes":
+            if "module" not in results:
+                # operator may not be in scope, so we can't make a change.
+                self.warning(node, "You should use operator.contains here.")
+            else:
+                method.value = "contains"
+                method.changed()
+        elif method.value == "isCallable":
+            if "module" not in results:
+                self.warning(node,
+                             "You should use hasattr(%s, '__call__') here." %
+                             results["func"].value)
+            else:
+                func = results["func"]
+                args = [func.clone(), String(", "), String("'__call__'")]
+                return Call(Name("hasattr"), args, prefix=node.prefix)
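To see the new fixer end to end, it can be driven through RefactoringTool; a small demonstration, assuming a lib2to3 that includes this commit's fix_operator:

# Runs only the new operator fixer over a snippet; assumes a lib2to3
# installation that contains fix_operator (added by this commit).
from lib2to3.refactor import RefactoringTool

rt = RefactoringTool(["lib2to3.fixes.fix_operator"])
tree = rt.refactor_string("x = operator.isCallable(f)\n", "<example>")
print(tree)  # x = hasattr(f, '__call__')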
Lib/lib2to3/fixes/fix_print.py

@@ -26,20 +26,15 @@ parend_expr = patcomp.compile_pattern(
       )
 
 
-class FixPrint(fixer_base.ConditionalFix):
+class FixPrint(fixer_base.BaseFix):
 
     PATTERN = """
               simple_stmt< any* bare='print' any* > | print_stmt
               """
 
-    skip_on = '__future__.print_function'
-
     def transform(self, node, results):
         assert results
 
-        if self.should_skip(node):
-            return
-
         bare_print = results.get("bare")
 
         if bare_print:
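FixPrint no longer inherits ConditionalFix and no longer checks skip_on: skipping files that already use print() as a function is handled at the parser level now (see the refactor.py changes below). The basic transformation is unchanged; a quick demonstration, assuming this commit's lib2to3:

# The print fixer still rewrites print statements into calls; assumes
# this commit's lib2to3.
from lib2to3.refactor import RefactoringTool

rt = RefactoringTool(["lib2to3.fixes.fix_print"])
print(rt.refactor_string("print 'Hello, world!'\n", "<example>"))
# print('Hello, world!')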
Lib/lib2to3/fixes/fix_urllib.py

@@ -12,13 +12,13 @@ from ..fixer_util import Name, Comma, FromImport, Newline, attr_chain
 MAPPING = {'urllib':  [
                 ('urllib.request',
                     ['URLOpener', 'FancyURLOpener', 'urlretrieve',
-                     '_urlopener', 'urlopen', 'urlcleanup']),
+                     '_urlopener', 'urlopen', 'urlcleanup',
+                     'pathname2url', 'url2pathname']),
                 ('urllib.parse',
                     ['quote', 'quote_plus', 'unquote', 'unquote_plus',
-                     'urlencode', 'pathname2url', 'url2pathname', 'splitattr',
-                     'splithost', 'splitnport', 'splitpasswd', 'splitport',
-                     'splitquery', 'splittag', 'splittype', 'splituser',
-                     'splitvalue', ]),
+                     'urlencode', 'splitattr', 'splithost', 'splitnport',
+                     'splitpasswd', 'splitport', 'splitquery', 'splittag',
+                     'splittype', 'splituser', 'splitvalue', ]),
                 ('urllib.error',
                     ['ContentTooShortError'])],
            'urllib2' : [
Lib/lib2to3/main.py

@@ -4,19 +4,31 @@ Main program for 2to3.
 
 import sys
 import os
+import difflib
 import logging
 import shutil
 import optparse
 
 from . import refactor
 
 
+def diff_texts(a, b, filename):
+    """Return a unified diff of two strings."""
+    a = a.splitlines()
+    b = b.splitlines()
+    return difflib.unified_diff(a, b, filename, filename,
+                                "(original)", "(refactored)",
+                                lineterm="")
+
+
 class StdoutRefactoringTool(refactor.MultiprocessRefactoringTool):
     """
     Prints output to stdout.
     """
 
-    def __init__(self, fixers, options, explicit, nobackups):
+    def __init__(self, fixers, options, explicit, nobackups, show_diffs):
         self.nobackups = nobackups
+        self.show_diffs = show_diffs
         super(StdoutRefactoringTool, self).__init__(fixers, options, explicit)
 
     def log_error(self, msg, *args, **kwargs):
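diff_texts is a thin wrapper over the standard library's difflib; for reference, this is the shape of output it produces:

# Shows what diff_texts() produces; difflib.unified_diff is a stdlib API.
import difflib

a = "def parrot(): pass".splitlines()
b = "def cheese(): pass".splitlines()
for line in difflib.unified_diff(a, b, "example.py", "example.py",
                                 "(original)", "(refactored)", lineterm=""):
    print(line)
# --- example.py (original)
# +++ example.py (refactored)
# @@ -1 +1 @@
# -def parrot(): pass
# +def cheese(): pass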
@@ -42,9 +54,17 @@ class StdoutRefactoringTool(refactor.MultiprocessRefactoringTool):
         if not self.nobackups:
             shutil.copymode(backup, filename)
 
-    def print_output(self, lines):
-        for line in lines:
-            print(line)
+    def print_output(self, old, new, filename, equal):
+        if equal:
+            self.log_message("No changes to %s", filename)
+        else:
+            self.log_message("Refactored %s", filename)
+            if self.show_diffs:
+                for line in diff_texts(old, new, filename):
+                    print(line)
+
+def warn(msg):
+    print >> sys.stderr, "WARNING: %s" % (msg,)
 
 
 def main(fixer_pkg, args=None):
@@ -70,9 +90,12 @@ def main(fixer_pkg, args=None):
     parser.add_option("-l", "--list-fixes", action="store_true",
                       help="List available transformations (fixes/fix_*.py)")
     parser.add_option("-p", "--print-function", action="store_true",
-                      help="Modify the grammar so that print() is a function")
+                      help="DEPRECATED Modify the grammar so that print() is "
+                           "a function")
    parser.add_option("-v", "--verbose", action="store_true",
                      help="More verbose logging")
+    parser.add_option("--no-diffs", action="store_true",
+                      help="Don't show diffs of the refactoring")
     parser.add_option("-w", "--write", action="store_true",
                       help="Write back modified files")
     parser.add_option("-n", "--nobackups", action="store_true", default=False,
@@ -81,6 +104,11 @@ def main(fixer_pkg, args=None):
     # Parse command line arguments
     refactor_stdin = False
     options, args = parser.parse_args(args)
+    if not options.write and options.no_diffs:
+        warn("not writing files and not printing diffs; that's not very useful")
+    if options.print_function:
+        warn("-p is deprecated; "
+             "detection of from __future__ import print_function is automatic")
     if not options.write and options.nobackups:
         parser.error("Can't use -n without -w")
     if options.list_fixes:
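Together the new options mean a run can write files back without echoing diffs, while -p now only warns. The script equivalent of the command line "2to3 -w --no-diffs example.py" (example.py is a hypothetical input file):

# Script equivalent of "2to3 -w --no-diffs example.py"; example.py is a
# hypothetical input file.
from lib2to3.main import main

main("lib2to3.fixes", args=["-w", "--no-diffs", "example.py"])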
@@ -104,7 +132,6 @@ def main(fixer_pkg, args=None):
     logging.basicConfig(format='%(name)s: %(message)s', level=level)
 
     # Initialize the refactoring tool
-    rt_opts = {"print_function" : options.print_function}
     avail_fixes = set(refactor.get_fixers_from_package(fixer_pkg))
     unwanted_fixes = set(fixer_pkg + ".fix_" + fix for fix in options.nofix)
     explicit = set()
@@ -119,8 +146,8 @@ def main(fixer_pkg, args=None):
     else:
         requested = avail_fixes.union(explicit)
     fixer_names = requested.difference(unwanted_fixes)
-    rt = StdoutRefactoringTool(sorted(fixer_names), rt_opts, sorted(explicit),
-                               options.nobackups)
+    rt = StdoutRefactoringTool(sorted(fixer_names), None, sorted(explicit),
+                               options.nobackups, not options.no_diffs)
 
     # Refactor all files and directories passed as arguments
     if not rt.errors:
Lib/lib2to3/patcomp.py

@@ -14,7 +14,7 @@ __author__ = "Guido van Rossum <guido@python.org>"
 import os
 
 # Fairly local imports
-from .pgen2 import driver, literals, token, tokenize, parse
+from .pgen2 import driver, literals, token, tokenize, parse, grammar
 
 # Really local imports
 from . import pytree
@@ -138,7 +138,7 @@ class PatternCompiler(object):
             node = nodes[0]
             if node.type == token.STRING:
                 value = str(literals.evalString(node.value))
-                return pytree.LeafPattern(content=value)
+                return pytree.LeafPattern(_type_of_literal(value), value)
             elif node.type == token.NAME:
                 value = node.value
                 if value.isupper():
@@ -179,6 +179,15 @@ TOKEN_MAP = {"NAME": token.NAME,
              "TOKEN": None}
 
 
+def _type_of_literal(value):
+    if value[0].isalpha():
+        return token.NAME
+    elif value in grammar.opmap:
+        return grammar.opmap[value]
+    else:
+        return None
+
+
 def pattern_convert(grammar, raw_node_info):
     """Converts raw node information to a Node or Leaf instance."""
     type, value, context, children = raw_node_info
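_type_of_literal gives string literals in patterns a concrete token type, so a pattern literal like 'long' compiles to a NAME leaf while '(' maps through the grammar's operator table. A quick illustration of the underlying lookup (grammar.opmap and the token module are existing pgen2 APIs):

# The lookup _type_of_literal relies on: alphabetic literals are NAMEs,
# operator literals resolve through grammar.opmap.
from lib2to3.pgen2 import grammar, token

print("long"[0].isalpha())               # True -> token.NAME
print(grammar.opmap["("] == token.LPAR)  # True -> concrete operator type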
Lib/lib2to3/pgen2/grammar.py

@@ -97,6 +97,19 @@ class Grammar(object):
         f.close()
         self.__dict__.update(d)
 
+    def copy(self):
+        """
+        Copy the grammar.
+        """
+        new = self.__class__()
+        for dict_attr in ("symbol2number", "number2symbol", "dfas", "keywords",
+                          "tokens", "symbol2label"):
+            setattr(new, dict_attr, getattr(self, dict_attr).copy())
+        new.labels = self.labels[:]
+        new.states = self.states[:]
+        new.start = self.start
+        return new
+
     def report(self):
         """Dump the grammar tables to standard output, for debugging."""
         from pprint import pprint
Lib/lib2to3/pygram.py

@@ -28,4 +28,8 @@ class Symbols(object):
 
 python_grammar = driver.load_grammar(_GRAMMAR_FILE)
 
 python_symbols = Symbols(python_grammar)
+
+python_grammar_no_print_statement = python_grammar.copy()
+del python_grammar_no_print_statement.keywords["print"]
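With copy() in place, pygram can derive a second grammar whose only difference is that "print" is no longer a keyword, so print(...) parses as an ordinary call. A quick check, assuming this commit's pygram:

# Parses print('hello') with the new keyword-less grammar; assumes this
# commit's pygram and the stock pgen2 driver.
from lib2to3 import pygram, pytree
from lib2to3.pgen2 import driver

d = driver.Driver(pygram.python_grammar_no_print_statement,
                  convert=pytree.convert)
tree = d.parse_string("print('hello')\n")
print(tree)  # accepted as a normal function call, not a print_stmt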
Lib/lib2to3/refactor.py

@@ -14,14 +14,15 @@ __author__ = "Guido van Rossum <guido@python.org>"
 # Python imports
 import os
 import sys
-import difflib
 import logging
 import operator
-from collections import defaultdict
+import collections
+import io
+import warnings
 from itertools import chain
 
 # Local imports
-from .pgen2 import driver, tokenize
+from .pgen2 import driver, tokenize, token
 from . import pytree, pygram
@@ -37,7 +38,12 @@ def get_all_fix_names(fixer_pkg, remove_prefix=True):
             fix_names.append(name[:-3])
     return fix_names
 
-def get_head_types(pat):
+
+class _EveryNode(Exception):
+    pass
+
+
+def _get_head_types(pat):
     """ Accepts a pytree Pattern Node and returns a set
         of the pattern types which will match first. """
 
@@ -45,34 +51,50 @@ def _get_head_types(pat):
         # NodePatters must either have no type and no content
         #   or a type and content -- so they don't get any farther
         # Always return leafs
+        if pat.type is None:
+            raise _EveryNode
         return set([pat.type])
 
     if isinstance(pat, pytree.NegatedPattern):
         if pat.content:
-            return get_head_types(pat.content)
-        return set([None]) # Negated Patterns don't have a type
+            return _get_head_types(pat.content)
+        raise _EveryNode # Negated Patterns don't have a type
 
     if isinstance(pat, pytree.WildcardPattern):
         # Recurse on each node in content
         r = set()
         for p in pat.content:
             for x in p:
-                r.update(get_head_types(x))
+                r.update(_get_head_types(x))
         return r
 
     raise Exception("Oh no! I don't understand pattern %s" %(pat))
 
-def get_headnode_dict(fixer_list):
+
+def _get_headnode_dict(fixer_list):
     """ Accepts a list of fixers and returns a dictionary
         of head node type --> fixer list.  """
-    head_nodes = defaultdict(list)
+    head_nodes = collections.defaultdict(list)
+    every = []
     for fixer in fixer_list:
-        if not fixer.pattern:
-            head_nodes[None].append(fixer)
-            continue
-        for t in get_head_types(fixer.pattern):
-            head_nodes[t].append(fixer)
-    return head_nodes
+        if fixer.pattern:
+            try:
+                heads = _get_head_types(fixer.pattern)
+            except _EveryNode:
+                every.append(fixer)
+            else:
+                for node_type in heads:
+                    head_nodes[node_type].append(fixer)
+        else:
+            if fixer._accept_type is not None:
+                head_nodes[fixer._accept_type].append(fixer)
+            else:
+                every.append(fixer)
+    for node_type in chain(pygram.python_grammar.symbol2number.values(),
+                           pygram.python_grammar.tokens):
+        head_nodes[node_type].extend(every)
+    return dict(head_nodes)
 
 
 def get_fixers_from_package(pkg_name):
     """
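_get_headnode_dict now indexes each fixer under every concrete head type it can match, falls back to _accept_type for pattern-less fixers, and appends the truly unrestricted ones to every bucket, so traverse_by only has to consult fixers[node.type]. The same indexing idea in isolation (a generic sketch, not the lib2to3 API):

# Generic sketch of head-node indexing: group handlers by the node types
# they can match, and append catch-all handlers to every bucket.
import collections

def build_dispatch(handlers, all_types):
    buckets = collections.defaultdict(list)
    every = []
    for head_types, handler in handlers:
        if head_types is None:       # can match anything
            every.append(handler)
        else:
            for t in head_types:
                buckets[t].append(handler)
    for t in all_types:
        buckets[t].extend(every)
    return dict(buckets)

dispatch = build_dispatch([({"NUMBER"}, "fix_numliterals"),
                           (None, "fix_everything")],
                          all_types={"NAME", "NUMBER", "STRING"})
print(dispatch["NUMBER"])  # ['fix_numliterals', 'fix_everything']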
@@ -101,13 +123,56 @@ else:
     _to_system_newlines = _identity
 
 
+def _detect_future_print(source):
+    have_docstring = False
+    gen = tokenize.generate_tokens(io.StringIO(source).readline)
+    def advance():
+        tok = next(gen)
+        return tok[0], tok[1]
+    ignore = frozenset((token.NEWLINE, tokenize.NL, token.COMMENT))
+    try:
+        while True:
+            tp, value = advance()
+            if tp in ignore:
+                continue
+            elif tp == token.STRING:
+                if have_docstring:
+                    break
+                have_docstring = True
+            elif tp == token.NAME:
+                if value == "from":
+                    tp, value = advance()
+                    if tp != token.NAME and value != "__future__":
+                        break
+                    tp, value = advance()
+                    if tp != token.NAME and value != "import":
+                        break
+                    tp, value = advance()
+                    if tp == token.OP and value == "(":
+                        tp, value = advance()
+                    while tp == token.NAME:
+                        if value == "print_function":
+                            return True
+                        tp, value = advance()
+                        if tp != token.OP and value != ",":
+                            break
+                        tp, value = advance()
+                    else:
+                        break
+                else:
+                    break
+    except StopIteration:
+        pass
+    return False
+
+
 class FixerError(Exception):
     """A fixer could not be loaded."""
 
 
 class RefactoringTool(object):
 
-    _default_options = {"print_function": False}
+    _default_options = {}
 
     CLASS_PREFIX = "Fix" # The prefix for fixer classes
     FILE_PREFIX = "fix_" # The prefix for modules with a fixer within
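_detect_future_print scans tokens rather than parsing, so the grammar can be chosen before any parse happens. It can be exercised directly; a small check, assuming this commit's refactor module (note the leading underscore: the helper is private):

# Direct use of the private helper added above; assumes this commit's
# lib2to3.refactor.
from lib2to3 import refactor

print(refactor._detect_future_print(
    "from __future__ import print_function\n"))      # True
print(refactor._detect_future_print("print 'hi'\n"))  # False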
@@ -124,20 +189,21 @@ class RefactoringTool(object):
         self.explicit = explicit or []
         self.options = self._default_options.copy()
         if options is not None:
+            if "print_function" in options:
+                warnings.warn("the 'print_function' option is deprecated",
+                              DeprecationWarning)
             self.options.update(options)
         self.errors = []
         self.logger = logging.getLogger("RefactoringTool")
         self.fixer_log = []
         self.wrote = False
-        if self.options["print_function"]:
-            del pygram.python_grammar.keywords["print"]
         self.driver = driver.Driver(pygram.python_grammar,
                                     convert=pytree.convert,
                                     logger=self.logger)
         self.pre_order, self.post_order = self.get_fixers()
 
-        self.pre_order_heads = get_headnode_dict(self.pre_order)
-        self.post_order_heads = get_headnode_dict(self.post_order)
+        self.pre_order_heads = _get_headnode_dict(self.pre_order)
+        self.post_order_heads = _get_headnode_dict(self.post_order)
 
         self.files = []  # List of files that were or should be modified
@@ -196,8 +262,9 @@ class RefactoringTool(object):
             msg = msg % args
         self.logger.debug(msg)
 
-    def print_output(self, lines):
-        """Called with lines of output to give to the user."""
+    def print_output(self, old_text, new_text, filename, equal):
+        """Called with the old version, new version, and filename of a
+        refactored file."""
         pass
 
     def refactor(self, items, write=False, doctests_only=False):
@@ -220,7 +287,8 @@ class RefactoringTool(object):
             dirnames.sort()
             filenames.sort()
             for name in filenames:
-                if not name.startswith(".") and name.endswith("py"):
+                if not name.startswith(".") and \
+                        os.path.splitext(name)[1].endswith("py"):
                     fullname = os.path.join(dirpath, name)
                     self.refactor_file(fullname, write, doctests_only)
             # Modify dirnames in-place to remove subdirs with leading dots
@@ -276,12 +344,16 @@ class RefactoringTool(object):
             An AST corresponding to the refactored input stream; None if
             there were errors during the parse.
         """
+        if _detect_future_print(data):
+            self.driver.grammar = pygram.python_grammar_no_print_statement
         try:
             tree = self.driver.parse_string(data)
         except Exception as err:
             self.log_error("Can't parse %s: %s: %s",
                            name, err.__class__.__name__, err)
             return
+        finally:
+            self.driver.grammar = pygram.python_grammar
         self.log_debug("Refactoring %s", name)
         self.refactor_tree(tree, name)
         return tree
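The grammar switch means source that declares print_function refactors without tripping over print(...) calls, and the print fixer leaves those calls alone because they already parse as calls. Mirroring the updated test below, assuming this commit's lib2to3:

# print() calls guarded by the future import are now parsed and left
# unchanged; assumes this commit's lib2to3.
from lib2to3.refactor import RefactoringTool

rt = RefactoringTool(["lib2to3.fixes.fix_print"])
src = "from __future__ import print_function\nprint('Hai!', end=' ')\n"
print(str(rt.refactor_string(src, "<example>")) == src)  # True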
@@ -338,12 +410,11 @@ class RefactoringTool(object):
         if not fixers:
             return
         for node in traversal:
-            for fixer in fixers[node.type] + fixers[None]:
+            for fixer in fixers[node.type]:
                 results = fixer.match(node)
                 if results:
                     new = fixer.transform(node, results)
-                    if new is not None and (new != node or
-                                            str(new) != str(node)):
+                    if new is not None:
                         node.replace(new)
                         node = new
@@ -357,10 +428,11 @@ class RefactoringTool(object):
         old_text = self._read_python_source(filename)[0]
         if old_text is None:
             return
-        if old_text == new_text:
+        equal = old_text == new_text
+        self.print_output(old_text, new_text, filename, equal)
+        if equal:
             self.log_debug("No changes to %s", filename)
             return
-        self.print_output(diff_texts(old_text, new_text, filename))
         if write:
             self.write_file(new_text, filename, old_text, encoding)
         else:
@@ -582,12 +654,3 @@ class MultiprocessRefactoringTool(RefactoringTool):
         else:
             return super(MultiprocessRefactoringTool, self).refactor_file(
                 *args, **kwargs)
-
-
-def diff_texts(a, b, filename):
-    """Return a unified diff of two strings."""
-    a = a.splitlines()
-    b = b.splitlines()
-    return difflib.unified_diff(a, b, filename, filename,
-                                "(original)", "(refactored)",
-                                lineterm="")
Lib/lib2to3/tests/data/different_encoding.py

@@ -1,3 +1,6 @@
 #!/usr/bin/env python
-# -*- coding: iso-8859-1 -*-
+# -*- coding: utf-8 -*-
 print u'ßàáâãäåæçèéêëìíîïðñòóôõöøùúûüýþÿÀÁÂÃÄÅÆÇÈÉÊËÌÍÎÏÐÑÒÓÔÕÖØÙÚÛÜÝÞ'
+
+def f(x):
+    print '%s\t-> α(%2i):%s β(%s)'
Lib/lib2to3/tests/test_fixers.py

@@ -18,8 +18,6 @@ class FixerTestCase(support.TestCase):
     def setUp(self, fix_list=None, fixer_pkg="lib2to3", options=None):
         if fix_list is None:
             fix_list = [self.fixer]
-        if options is None:
-            options = {"print_function" : False}
         self.refactor = support.get_refactorer(fixer_pkg, fix_list, options)
         self.fixer_log = []
         self.filename = "<string>"
@@ -58,8 +56,7 @@ class FixerTestCase(support.TestCase):
     def assert_runs_after(self, *names):
         fixes = [self.fixer]
         fixes.extend(names)
-        options = {"print_function" : False}
-        r = support.get_refactorer("lib2to3", fixes, options)
+        r = support.get_refactorer("lib2to3", fixes)
         (pre, post) = r.get_fixers()
         n = "fix_" + self.fixer
         if post and post[-1].__class__.__module__.endswith(n):
@@ -379,18 +376,15 @@ class Test_print(FixerTestCase):
         self.unchanged(s)
 
     def test_idempotency_print_as_function(self):
-        print_stmt = pygram.python_grammar.keywords.pop("print")
-        try:
-            s = """print(1, 1+1, 1+1+1)"""
-            self.unchanged(s)
+        self.refactor.driver.grammar = pygram.python_grammar_no_print_statement
+        s = """print(1, 1+1, 1+1+1)"""
+        self.unchanged(s)
 
-            s = """print()"""
-            self.unchanged(s)
+        s = """print()"""
+        self.unchanged(s)
 
-            s = """print('')"""
-            self.unchanged(s)
-        finally:
-            pygram.python_grammar.keywords["print"] = print_stmt
+        s = """print('')"""
+        self.unchanged(s)
 
     def test_1(self):
         b = """print 1, 1+1, 1+1+1"""
@@ -462,31 +456,15 @@ class Test_print(FixerTestCase):
         a = """print(file=sys.stderr)"""
         self.check(b, a)
 
-    # With from __future__ import print_function
     def test_with_future_print_function(self):
-        # XXX: These tests won't actually do anything until the parser
-        #     is fixed so it won't crash when it sees print(x=y).
-        #     When #2412 is fixed, the try/except block can be taken
-        #     out and the tests can be run like normal.
-        # MvL: disable entirely for now, so that it doesn't print to stdout
-        return
-        try:
-            s = "from __future__ import print_function\n"\
-                "print('Hai!', end=' ')"
-            self.unchanged(s)
+        s = "from __future__ import print_function\n" \
+            "print('Hai!', end=' ')"
+        self.unchanged(s)
 
         b = "print 'Hello, world!'"
         a = "print('Hello, world!')"
         self.check(b, a)
 
-            s = "from __future__ import *\n"\
-                "print('Hai!', end=' ')"
-            self.unchanged(s)
-        except:
-            return
-        else:
-            self.assertFalse(True, "#2421 has been fixed -- printing tests "\
-                "need to be updated!")
+
 class Test_exec(FixerTestCase):
     fixer = "exec"
@@ -1705,6 +1683,11 @@ class Test_imports_fixer_order(FixerTestCase, ImportsFixerTests):
         for key in ('dbhash', 'dumbdbm', 'dbm', 'gdbm'):
             self.modules[key] = mapping1[key]
 
+    def test_after_local_imports_refactoring(self):
+        for fix in ("imports", "imports2"):
+            self.fixer = fix
+            self.assert_runs_after("import")
+
 
 class Test_urllib(FixerTestCase):
     fixer = "urllib"
@@ -3504,6 +3487,7 @@ class Test_itertools_imports(FixerTestCase):
         s = "from itertools import foo"
         self.unchanged(s)
 
+
 class Test_import(FixerTestCase):
     fixer = "import"
 
@@ -3538,8 +3522,7 @@ class Test_import(FixerTestCase):
 
         self.always_exists = False
         self.present_files = set(['__init__.py'])
-        expected_extensions = ('.py', os.path.pathsep, '.pyc', '.so',
-                               '.sl', '.pyd')
+        expected_extensions = ('.py', os.path.sep, '.pyc', '.so', '.sl', '.pyd')
         names_to_test = (p("/spam/eggs.py"), "ni.py", p("../../shrubbery.py"))
 
         for name in names_to_test:
@@ -3569,6 +3552,13 @@ class Test_import(FixerTestCase):
         self.present_files = set(["__init__.py", "bar.py"])
         self.check(b, a)
 
+    def test_import_from_package(self):
+        b = "import bar"
+        a = "from . import bar"
+        self.always_exists = False
+        self.present_files = set(["__init__.py", "bar/"])
+        self.check(b, a)
+
     def test_comments_and_indent(self):
         b = "import bar # Foo"
         a = "from . import bar # Foo"
@@ -4095,3 +4085,26 @@ class Test_getcwdu(FixerTestCase):
         b = """os.getcwdu ( )"""
         a = """os.getcwd ( )"""
         self.check(b, a)
+
+
+class Test_operator(FixerTestCase):
+
+    fixer = "operator"
+
+    def test_operator_isCallable(self):
+        b = "operator.isCallable(x)"
+        a = "hasattr(x, '__call__')"
+        self.check(b, a)
+
+    def test_operator_sequenceIncludes(self):
+        b = "operator.sequenceIncludes(x, y)"
+        a = "operator.contains(x, y)"
+        self.check(b, a)
+
+    def test_bare_isCallable(self):
+        s = "isCallable(x)"
+        self.warns_unchanged(s, "You should use hasattr(x, '__call__') here.")
+
+    def test_bare_sequenceIncludes(self):
+        s = "sequenceIncludes(x, y)"
+        self.warns_unchanged(s, "You should use operator.contains here.")
Lib/lib2to3/tests/test_refactor.py

@@ -7,9 +7,12 @@ import os
 import operator
 import io
 import tempfile
+import shutil
 import unittest
+import warnings
 
 from lib2to3 import refactor, pygram, fixer_base
+from lib2to3.pgen2 import token
 
 from . import support
 
@@ -42,14 +45,11 @@ class TestRefactoringTool(unittest.TestCase):
         return refactor.RefactoringTool(fixers, options, explicit)
 
     def test_print_function_option(self):
-        gram = pygram.python_grammar
-        save = gram.keywords["print"]
-        try:
-            rt = self.rt({"print_function" : True})
-            self.assertRaises(KeyError, operator.itemgetter("print"),
-                              gram.keywords)
-        finally:
-            gram.keywords["print"] = save
+        with warnings.catch_warnings(record=True) as w:
+            refactor.RefactoringTool(_DEFAULT_FIXERS, {"print_function" : True})
+        self.assertEqual(len(w), 1)
+        msg, = w
+        self.assertTrue(msg.category is DeprecationWarning)
 
     def test_fixer_loading_helpers(self):
         contents = ["explicit", "first", "last", "parrot", "preorder"]
@@ -61,19 +61,63 @@ class TestRefactoringTool(unittest.TestCase):
         self.assertEqual(full_names,
                          ["myfixes.fix_" + name for name in contents])
 
+    def test_detect_future_print(self):
+        run = refactor._detect_future_print
+        self.assertFalse(run(""))
+        self.assertTrue(run("from __future__ import print_function"))
+        self.assertFalse(run("from __future__ import generators"))
+        self.assertFalse(run("from __future__ import generators, feature"))
+        input = "from __future__ import generators, print_function"
+        self.assertTrue(run(input))
+        input = "from __future__ import print_function, generators"
+        self.assertTrue(run(input))
+        input = "from __future__ import (print_function,)"
+        self.assertTrue(run(input))
+        input = "from __future__ import (generators, print_function)"
+        self.assertTrue(run(input))
+        input = "from __future__ import (generators, nested_scopes)"
+        self.assertFalse(run(input))
+        input = """from __future__ import generators
+from __future__ import print_function"""
+        self.assertTrue(run(input))
+        self.assertFalse(run("from"))
+        self.assertFalse(run("from 4"))
+        self.assertFalse(run("from x"))
+        self.assertFalse(run("from x 5"))
+        self.assertFalse(run("from x im"))
+        self.assertFalse(run("from x import"))
+        self.assertFalse(run("from x import 4"))
+        input = "'docstring'\nfrom __future__ import print_function"
+        self.assertTrue(run(input))
+        input = "'docstring'\n'somng'\nfrom __future__ import print_function"
+        self.assertFalse(run(input))
+        input = "# comment\nfrom __future__ import print_function"
+        self.assertTrue(run(input))
+        input = "# comment\n'doc'\nfrom __future__ import print_function"
+        self.assertTrue(run(input))
+        input = "class x: pass\nfrom __future__ import print_function"
+        self.assertFalse(run(input))
+
     def test_get_headnode_dict(self):
         class NoneFix(fixer_base.BaseFix):
-            PATTERN = None
+            pass
 
         class FileInputFix(fixer_base.BaseFix):
             PATTERN = "file_input< any * >"
 
+        class SimpleFix(fixer_base.BaseFix):
+            PATTERN = "'name'"
+
         no_head = NoneFix({}, [])
         with_head = FileInputFix({}, [])
-        d = refactor.get_headnode_dict([no_head, with_head])
-        expected = {None: [no_head],
-                    pygram.python_symbols.file_input : [with_head]}
-        self.assertEqual(d, expected)
+        simple = SimpleFix({}, [])
+        d = refactor._get_headnode_dict([no_head, with_head, simple])
+        top_fixes = d.pop(pygram.python_symbols.file_input)
+        self.assertEqual(top_fixes, [with_head, no_head])
+        name_fixes = d.pop(token.NAME)
+        self.assertEqual(name_fixes, [simple, no_head])
+        for fixes in d.values():
+            self.assertEqual(fixes, [no_head])
 
     def test_fixer_loading(self):
         from myfixes.fix_first import FixFirst
@@ -106,10 +150,10 @@ class TestRefactoringTool(unittest.TestCase):
 
         class MyRT(refactor.RefactoringTool):
 
-            def print_output(self, lines):
-                diff_lines.extend(lines)
+            def print_output(self, old_text, new_text, filename, equal):
+                results.extend([old_text, new_text, filename, equal])
 
-        diff_lines = []
+        results = []
         rt = MyRT(_DEFAULT_FIXERS)
         save = sys.stdin
         sys.stdin = io.StringIO("def parrot(): pass\n\n")
@@ -117,12 +161,10 @@ class TestRefactoringTool(unittest.TestCase):
             rt.refactor_stdin()
         finally:
             sys.stdin = save
-        expected = """--- <stdin> (original)
-+++ <stdin> (refactored)
-@@ -1,2 +1,2 @@
--def parrot(): pass
-+def cheese(): pass""".splitlines()
-        self.assertEqual(diff_lines[:-1], expected)
+        expected = ["def parrot(): pass\n\n",
+                    "def cheese(): pass\n\n",
+                    "<stdin>", False]
+        self.assertEqual(results, expected)
 
     def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS):
         def read_file():
@@ -145,6 +187,37 @@ class TestRefactoringTool(unittest.TestCase):
         test_file = os.path.join(FIXER_DIR, "parrot_example.py")
         self.check_file_refactoring(test_file, _DEFAULT_FIXERS)
 
+    def test_refactor_dir(self):
+        def check(structure, expected):
+            def mock_refactor_file(self, f, *args):
+                got.append(f)
+            save_func = refactor.RefactoringTool.refactor_file
+            refactor.RefactoringTool.refactor_file = mock_refactor_file
+            rt = self.rt()
+            got = []
+            dir = tempfile.mkdtemp(prefix="2to3-test_refactor")
+            try:
+                os.mkdir(os.path.join(dir, "a_dir"))
+                for fn in structure:
+                    open(os.path.join(dir, fn), "wb").close()
+                rt.refactor_dir(dir)
+            finally:
+                refactor.RefactoringTool.refactor_file = save_func
+                shutil.rmtree(dir)
+            self.assertEqual(got,
+                             [os.path.join(dir, path) for path in expected])
+        check([], [])
+        tree = ["nothing",
+                "hi.py",
+                ".dumb",
+                ".after.py",
+                "sappy"]
+        expected = ["hi.py"]
+        check(tree, expected)
+        tree = ["hi.py",
+                "a_dir/stuff.py"]
+        check(tree, tree)
+
     def test_file_encoding(self):
         fn = os.path.join(TEST_DATA_DIR, "different_encoding.py")
         self.check_file_refactoring(fn)
Lib/lib2to3/tests/test_util.py

@@ -1,4 +1,4 @@
-""" Test suite for the code in fixes.util """
+""" Test suite for the code in fixer_util """
 
 # Testing imports
 from . import support
@@ -7,10 +7,10 @@ from . import support
 import os.path
 
 # Local imports
-from .. import pytree
-from .. import fixer_util
-from ..fixer_util import Attr, Name
+from lib2to3.pytree import Node, Leaf
+from lib2to3 import fixer_util
+from lib2to3.fixer_util import Attr, Name, Call, Comma
+from lib2to3.pgen2 import token
 
 def parse(code, strip_levels=0):
     # The topmost node is file_input, which we don't care about.
@@ -24,7 +24,7 @@ def parse(code, strip_levels=0):
 class MacroTestCase(support.TestCase):
     def assertStr(self, node, string):
         if isinstance(node, (tuple, list)):
-            node = pytree.Node(fixer_util.syms.simple_stmt, node)
+            node = Node(fixer_util.syms.simple_stmt, node)
         self.assertEqual(str(node), string)
 
 
@@ -78,6 +78,31 @@ class Test_Name(MacroTestCase):
         self.assertStr(Name("a", prefix="b"), "ba")
 
 
+class Test_Call(MacroTestCase):
+    def _Call(self, name, args=None, prefix=None):
+        """Help the next test"""
+        children = []
+        if isinstance(args, list):
+            for arg in args:
+                children.append(arg)
+                children.append(Comma())
+            children.pop()
+        return Call(Name(name), children, prefix)
+
+    def test(self):
+        kids = [None,
+                [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 2),
+                 Leaf(token.NUMBER, 3)],
+                [Leaf(token.NUMBER, 1), Leaf(token.NUMBER, 3),
+                 Leaf(token.NUMBER, 2), Leaf(token.NUMBER, 4)],
+                [Leaf(token.STRING, "b"), Leaf(token.STRING, "j", prefix=" ")]
+                ]
+        self.assertStr(self._Call("A"), "A()")
+        self.assertStr(self._Call("b", kids[1]), "b(1,2,3)")
+        self.assertStr(self._Call("a.b().c", kids[2]), "a.b().c(1,3,2,4)")
+        self.assertStr(self._Call("d", kids[3], prefix=" "), " d(b, j)")
+
+
 class Test_does_tree_import(support.TestCase):
     def _find_bind_rec(self, name, node):
         # Search a tree for a binding -- used to find the starting