gh-119793: Add optional length-checking to map() (GH-120471)

Co-authored-by: Bénédikt Tran <10796600+picnixz@users.noreply.github.com>
Co-authored-by: Pieter Eendebak <pieter.eendebak@gmail.com>
Co-authored-by: Erlend E. Aasland <erlend.aasland@protonmail.com>
Co-authored-by: Raymond Hettinger <rhettinger@users.noreply.github.com>
This commit is contained in:
Nice Zombies 2024-11-04 15:00:19 +01:00 committed by GitHub
parent bfc1d2504c
commit 3032fcd90e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 210 additions and 17 deletions

View file

@ -148,6 +148,9 @@ def filter_char(arg):
def map_char(arg):
return chr(ord(arg)+1)
def pack(*args):
return args
class BuiltinTest(unittest.TestCase):
# Helper to check picklability
def check_iter_pickle(self, it, seq, proto):
@ -1269,6 +1272,108 @@ class BuiltinTest(unittest.TestCase):
m2 = map(map_char, "Is this the real life?")
self.check_iter_pickle(m1, list(m2), proto)
# strict map tests based on strict zip tests
def test_map_pickle_strict(self):
a = (1, 2, 3)
b = (4, 5, 6)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
m1 = map(pack, a, b, strict=True)
self.check_iter_pickle(m1, t, proto)
def test_map_pickle_strict_fail(self):
a = (1, 2, 3)
b = (4, 5, 6, 7)
t = [(1, 4), (2, 5), (3, 6)]
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
m1 = map(pack, a, b, strict=True)
m2 = pickle.loads(pickle.dumps(m1, proto))
self.assertEqual(self.iter_error(m1, ValueError), t)
self.assertEqual(self.iter_error(m2, ValueError), t)
def test_map_strict(self):
self.assertEqual(tuple(map(pack, (1, 2, 3), 'abc', strict=True)),
((1, 'a'), (2, 'b'), (3, 'c')))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2, 3, 4), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2), 'abc', strict=True))
self.assertRaises(ValueError, tuple,
map(pack, (1, 2), (1, 2), 'abc', strict=True))
def test_map_strict_iterators(self):
x = iter(range(5))
y = [0]
z = iter(range(5))
self.assertRaises(ValueError, list,
(map(pack, x, y, z, strict=True)))
self.assertEqual(next(x), 2)
self.assertEqual(next(z), 1)
def test_map_strict_error_handling(self):
class Error(Exception):
pass
class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise Error
return self.size
l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), Error)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), Error)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), Error)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), Error)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])
def test_map_strict_error_handling_stopiteration(self):
class Iter:
def __init__(self, size):
self.size = size
def __iter__(self):
return self
def __next__(self):
self.size -= 1
if self.size < 0:
raise StopIteration
return self.size
l1 = self.iter_error(map(pack, "AB", Iter(1), strict=True), ValueError)
self.assertEqual(l1, [("A", 0)])
l2 = self.iter_error(map(pack, "AB", Iter(2), "A", strict=True), ValueError)
self.assertEqual(l2, [("A", 1, "A")])
l3 = self.iter_error(map(pack, "AB", Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l3, [("A", 1, "A"), ("B", 0, "B")])
l4 = self.iter_error(map(pack, "AB", Iter(3), strict=True), ValueError)
self.assertEqual(l4, [("A", 2), ("B", 1)])
l5 = self.iter_error(map(pack, Iter(1), "AB", strict=True), ValueError)
self.assertEqual(l5, [(0, "A")])
l6 = self.iter_error(map(pack, Iter(2), "A", strict=True), ValueError)
self.assertEqual(l6, [(1, "A")])
l7 = self.iter_error(map(pack, Iter(2), "ABC", strict=True), ValueError)
self.assertEqual(l7, [(1, "A"), (0, "B")])
l8 = self.iter_error(map(pack, Iter(3), "AB", strict=True), ValueError)
self.assertEqual(l8, [(2, "A"), (1, "B")])
def test_max(self):
self.assertEqual(max('123123'), '3')
self.assertEqual(max(1, 2, 3), 3)