mirror of
https://github.com/python/cpython.git
synced 2025-10-14 02:43:49 +00:00
Improve diff for assertCountEqual() to actually show the differing counts.
New output looks like this: Traceback (most recent call last): File "test.py", line 5, in test_ce self.assertCountEqual('abracadabra xx', 'simsalabim xx') AssertionError: Element counts were not equal: Expected 5, got 2: 'a' Expected 2, got 1: 'b' Expected 0, got 2: 'i' Expected 0, got 2: 'm' Expected 0, got 1: 'l' Expected 0, got 2: 's' Expected 1, got 0: 'c' Expected 1, got 0: 'd' Expected 2, got 0: 'r'
This commit is contained in:
parent
fca8beed4a
commit
93e233d6e5
3 changed files with 71 additions and 21 deletions
|
@ -10,7 +10,8 @@ import collections
|
||||||
|
|
||||||
from . import result
|
from . import result
|
||||||
from .util import (strclass, safe_repr, sorted_list_difference,
|
from .util import (strclass, safe_repr, sorted_list_difference,
|
||||||
unorderable_list_difference)
|
unorderable_list_difference, _count_diff_all_purpose,
|
||||||
|
_count_diff_hashable)
|
||||||
|
|
||||||
__unittest = True
|
__unittest = True
|
||||||
|
|
||||||
|
@ -1022,23 +1023,22 @@ class TestCase(object):
|
||||||
expected = collections.Counter(expected_seq)
|
expected = collections.Counter(expected_seq)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# Handle case with unhashable elements
|
# Handle case with unhashable elements
|
||||||
missing, unexpected = unorderable_list_difference(expected_seq, actual_seq)
|
differences = _count_diff_all_purpose(expected_seq, actual_seq)
|
||||||
else:
|
else:
|
||||||
if actual == expected:
|
if actual == expected:
|
||||||
return
|
return
|
||||||
missing = list(expected - actual)
|
differences = _count_diff_hashable(expected_seq, actual_seq)
|
||||||
unexpected = list(actual - expected)
|
|
||||||
|
|
||||||
errors = []
|
if differences:
|
||||||
if missing:
|
standardMsg = 'Element counts were not equal:\n'
|
||||||
errors.append('Expected, but missing:\n %s' %
|
lines = []
|
||||||
safe_repr(missing))
|
for act, exp, elem in differences:
|
||||||
if unexpected:
|
line = 'Expected %d, got %d: %r' % (exp, act, elem)
|
||||||
errors.append('Unexpected, but present:\n %s' %
|
lines.append(line)
|
||||||
safe_repr(unexpected))
|
diffMsg = '\n'.join(lines)
|
||||||
if errors:
|
standardMsg = self._truncateMessage(standardMsg, diffMsg)
|
||||||
standardMsg = '\n'.join(errors)
|
msg = self._formatMessage(msg, standardMsg)
|
||||||
self.fail(self._formatMessage(msg, standardMsg))
|
self.fail(msg)
|
||||||
|
|
||||||
def assertMultiLineEqual(self, first, second, msg=None):
|
def assertMultiLineEqual(self, first, second, msg=None):
|
||||||
"""Assert that two multi-line strings are equal."""
|
"""Assert that two multi-line strings are equal."""
|
||||||
|
|
|
@ -229,12 +229,6 @@ class TestLongMessage(unittest.TestCase):
|
||||||
"^Missing: 'key'$",
|
"^Missing: 'key'$",
|
||||||
"^Missing: 'key' : oops$"])
|
"^Missing: 'key' : oops$"])
|
||||||
|
|
||||||
def testassertCountEqual(self):
|
|
||||||
self.assertMessages('assertCountEqual', ([], [None]),
|
|
||||||
[r"\[None\]$", "^oops$",
|
|
||||||
r"\[None\]$",
|
|
||||||
r"\[None\] : oops$"])
|
|
||||||
|
|
||||||
def testAssertMultiLineEqual(self):
|
def testAssertMultiLineEqual(self):
|
||||||
self.assertMessages('assertMultiLineEqual', ("", "foo"),
|
self.assertMessages('assertMultiLineEqual', ("", "foo"),
|
||||||
[r"\+ foo$", "^oops$",
|
[r"\+ foo$", "^oops$",
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
"""Various utility functions."""
|
"""Various utility functions."""
|
||||||
|
|
||||||
|
from collections import namedtuple, Counter
|
||||||
|
|
||||||
__unittest = True
|
__unittest = True
|
||||||
|
|
||||||
_MAX_LENGTH = 80
|
_MAX_LENGTH = 80
|
||||||
|
@ -12,7 +14,6 @@ def safe_repr(obj, short=False):
|
||||||
return result
|
return result
|
||||||
return result[:_MAX_LENGTH] + ' [truncated]...'
|
return result[:_MAX_LENGTH] + ' [truncated]...'
|
||||||
|
|
||||||
|
|
||||||
def strclass(cls):
|
def strclass(cls):
|
||||||
return "%s.%s" % (cls.__module__, cls.__name__)
|
return "%s.%s" % (cls.__module__, cls.__name__)
|
||||||
|
|
||||||
|
@ -77,3 +78,58 @@ def unorderable_list_difference(expected, actual):
|
||||||
def three_way_cmp(x, y):
|
def three_way_cmp(x, y):
|
||||||
"""Return -1 if x < y, 0 if x == y and 1 if x > y"""
|
"""Return -1 if x < y, 0 if x == y and 1 if x > y"""
|
||||||
return (x > y) - (x < y)
|
return (x > y) - (x < y)
|
||||||
|
|
||||||
|
_Mismatch = namedtuple('Mismatch', 'actual expected value')
|
||||||
|
|
||||||
|
def _count_diff_all_purpose(actual, expected):
|
||||||
|
'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
|
||||||
|
# elements need not be hashable
|
||||||
|
s, t = list(actual), list(expected)
|
||||||
|
m, n = len(s), len(t)
|
||||||
|
NULL = object()
|
||||||
|
result = []
|
||||||
|
for i, elem in enumerate(s):
|
||||||
|
if elem is NULL:
|
||||||
|
continue
|
||||||
|
cnt_s = cnt_t = 0
|
||||||
|
for j in range(i, m):
|
||||||
|
if s[j] == elem:
|
||||||
|
cnt_s += 1
|
||||||
|
s[j] = NULL
|
||||||
|
for j, other_elem in enumerate(t):
|
||||||
|
if other_elem == elem:
|
||||||
|
cnt_t += 1
|
||||||
|
t[j] = NULL
|
||||||
|
if cnt_s != cnt_t:
|
||||||
|
diff = _Mismatch(cnt_s, cnt_t, elem)
|
||||||
|
result.append(diff)
|
||||||
|
|
||||||
|
for i, elem in enumerate(t):
|
||||||
|
if elem is NULL:
|
||||||
|
continue
|
||||||
|
cnt_t = 0
|
||||||
|
for j in range(i, n):
|
||||||
|
if t[j] == elem:
|
||||||
|
cnt_t += 1
|
||||||
|
t[j] = NULL
|
||||||
|
diff = _Mismatch(0, cnt_t, elem)
|
||||||
|
result.append(diff)
|
||||||
|
return result
|
||||||
|
|
||||||
|
def _count_diff_hashable(actual, expected):
|
||||||
|
'Returns list of (cnt_act, cnt_exp, elem) triples where the counts differ'
|
||||||
|
# elements must be hashable
|
||||||
|
s, t = Counter(actual), Counter(expected)
|
||||||
|
if s == t:
|
||||||
|
return []
|
||||||
|
result = []
|
||||||
|
for elem, cnt_s in s.items():
|
||||||
|
cnt_t = t[elem]
|
||||||
|
if cnt_s != cnt_t:
|
||||||
|
diff = _Mismatch(cnt_s, cnt_t, elem)
|
||||||
|
result.append(diff)
|
||||||
|
for elem, cnt_t in t.items():
|
||||||
|
if elem not in s:
|
||||||
|
diff = _Mismatch(0, cnt_t, elem)
|
||||||
|
result.append(diff)
|
||||||
|
return result
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue