gh-108751: Add copy.replace() function (GH-108752)

It creates a modified copy of an object by calling the object's
__replace__() method.

It is a generalization of dataclasses.replace(), named tuple's _replace()
method and replace() methods in various classes, and supports all these
stdlib classes.
This commit is contained in:
Serhiy Storchaka 2023-09-06 23:55:42 +03:00 committed by GitHub
parent 9f0c0a46f0
commit 6f3c138dfa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 311 additions and 68 deletions

View file

@ -4,7 +4,7 @@ import copy
import copyreg
import weakref
import abc
from operator import le, lt, ge, gt, eq, ne
from operator import le, lt, ge, gt, eq, ne, attrgetter
import unittest
from test import support
@ -899,7 +899,71 @@ class TestCopy(unittest.TestCase):
g.b()
class TestReplace(unittest.TestCase):
def test_unsupported(self):
self.assertRaises(TypeError, copy.replace, 1)
self.assertRaises(TypeError, copy.replace, [])
self.assertRaises(TypeError, copy.replace, {})
def f(): pass
self.assertRaises(TypeError, copy.replace, f)
class A: pass
self.assertRaises(TypeError, copy.replace, A)
self.assertRaises(TypeError, copy.replace, A())
def test_replace_method(self):
class A:
def __new__(cls, x, y=0):
self = object.__new__(cls)
self.x = x
self.y = y
return self
def __init__(self, *args, **kwargs):
self.z = self.x + self.y
def __replace__(self, **changes):
x = changes.get('x', self.x)
y = changes.get('y', self.y)
return type(self)(x, y)
attrs = attrgetter('x', 'y', 'z')
a = A(11, 22)
self.assertEqual(attrs(copy.replace(a)), (11, 22, 33))
self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23))
self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13))
self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3))
def test_namedtuple(self):
from collections import namedtuple
Point = namedtuple('Point', 'x y', defaults=(0,))
p = Point(11, 22)
self.assertEqual(copy.replace(p), (11, 22))
self.assertEqual(copy.replace(p, x=1), (1, 22))
self.assertEqual(copy.replace(p, y=2), (11, 2))
self.assertEqual(copy.replace(p, x=1, y=2), (1, 2))
with self.assertRaisesRegex(ValueError, 'unexpected field name'):
copy.replace(p, x=1, error=2)
def test_dataclass(self):
from dataclasses import dataclass
@dataclass
class C:
x: int
y: int = 0
attrs = attrgetter('x', 'y')
c = C(11, 22)
self.assertEqual(attrs(copy.replace(c)), (11, 22))
self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22))
self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2))
self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2))
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
copy.replace(c, x=1, error=2)
def global_foo(x, y): return x+y
if __name__ == "__main__":
unittest.main()