mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
gh-108751: Add copy.replace() function (GH-108752)
It creates a modified copy of an object by calling the object's __replace__() method. It is a generalization of dataclasses.replace(), named tuple's _replace() method and replace() methods in various classes, and supports all these stdlib classes.
This commit is contained in:
parent
9f0c0a46f0
commit
6f3c138dfa
19 changed files with 311 additions and 68 deletions
|
@ -4,7 +4,7 @@ import copy
|
|||
import copyreg
|
||||
import weakref
|
||||
import abc
|
||||
from operator import le, lt, ge, gt, eq, ne
|
||||
from operator import le, lt, ge, gt, eq, ne, attrgetter
|
||||
|
||||
import unittest
|
||||
from test import support
|
||||
|
@ -899,7 +899,71 @@ class TestCopy(unittest.TestCase):
|
|||
g.b()
|
||||
|
||||
|
||||
class TestReplace(unittest.TestCase):
|
||||
|
||||
def test_unsupported(self):
|
||||
self.assertRaises(TypeError, copy.replace, 1)
|
||||
self.assertRaises(TypeError, copy.replace, [])
|
||||
self.assertRaises(TypeError, copy.replace, {})
|
||||
def f(): pass
|
||||
self.assertRaises(TypeError, copy.replace, f)
|
||||
class A: pass
|
||||
self.assertRaises(TypeError, copy.replace, A)
|
||||
self.assertRaises(TypeError, copy.replace, A())
|
||||
|
||||
def test_replace_method(self):
|
||||
class A:
|
||||
def __new__(cls, x, y=0):
|
||||
self = object.__new__(cls)
|
||||
self.x = x
|
||||
self.y = y
|
||||
return self
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.z = self.x + self.y
|
||||
|
||||
def __replace__(self, **changes):
|
||||
x = changes.get('x', self.x)
|
||||
y = changes.get('y', self.y)
|
||||
return type(self)(x, y)
|
||||
|
||||
attrs = attrgetter('x', 'y', 'z')
|
||||
a = A(11, 22)
|
||||
self.assertEqual(attrs(copy.replace(a)), (11, 22, 33))
|
||||
self.assertEqual(attrs(copy.replace(a, x=1)), (1, 22, 23))
|
||||
self.assertEqual(attrs(copy.replace(a, y=2)), (11, 2, 13))
|
||||
self.assertEqual(attrs(copy.replace(a, x=1, y=2)), (1, 2, 3))
|
||||
|
||||
def test_namedtuple(self):
|
||||
from collections import namedtuple
|
||||
Point = namedtuple('Point', 'x y', defaults=(0,))
|
||||
p = Point(11, 22)
|
||||
self.assertEqual(copy.replace(p), (11, 22))
|
||||
self.assertEqual(copy.replace(p, x=1), (1, 22))
|
||||
self.assertEqual(copy.replace(p, y=2), (11, 2))
|
||||
self.assertEqual(copy.replace(p, x=1, y=2), (1, 2))
|
||||
with self.assertRaisesRegex(ValueError, 'unexpected field name'):
|
||||
copy.replace(p, x=1, error=2)
|
||||
|
||||
def test_dataclass(self):
|
||||
from dataclasses import dataclass
|
||||
@dataclass
|
||||
class C:
|
||||
x: int
|
||||
y: int = 0
|
||||
|
||||
attrs = attrgetter('x', 'y')
|
||||
c = C(11, 22)
|
||||
self.assertEqual(attrs(copy.replace(c)), (11, 22))
|
||||
self.assertEqual(attrs(copy.replace(c, x=1)), (1, 22))
|
||||
self.assertEqual(attrs(copy.replace(c, y=2)), (11, 2))
|
||||
self.assertEqual(attrs(copy.replace(c, x=1, y=2)), (1, 2))
|
||||
with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
|
||||
copy.replace(c, x=1, error=2)
|
||||
|
||||
|
||||
def global_foo(x, y): return x+y
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue