mirror of
https://github.com/python/cpython.git
synced 2025-11-02 19:12:55 +00:00
Issue 14814: Ensure ordering semantics across all 3 entity types in ipaddress are consistent and well-defined
This commit is contained in:
parent
9a9c28ce7a
commit
3008ec070f
3 changed files with 169 additions and 128 deletions
136
Lib/ipaddress.py
136
Lib/ipaddress.py
|
|
@ -12,7 +12,7 @@ __version__ = '1.0'
|
|||
|
||||
|
||||
import struct
|
||||
|
||||
import functools
|
||||
|
||||
IPV4LENGTH = 32
|
||||
IPV6LENGTH = 128
|
||||
|
|
@ -405,7 +405,38 @@ def get_mixed_type_key(obj):
|
|||
return NotImplemented
|
||||
|
||||
|
||||
class _IPAddressBase:
|
||||
class _TotalOrderingMixin:
|
||||
# Helper that derives the other comparison operations from
|
||||
# __lt__ and __eq__
|
||||
def __eq__(self, other):
|
||||
raise NotImplementedError
|
||||
def __ne__(self, other):
|
||||
equal = self.__eq__(other)
|
||||
if equal is NotImplemented:
|
||||
return NotImplemented
|
||||
return not equal
|
||||
def __lt__(self, other):
|
||||
raise NotImplementedError
|
||||
def __le__(self, other):
|
||||
less = self.__lt__(other)
|
||||
if less is NotImplemented or not less:
|
||||
return self.__eq__(other)
|
||||
return less
|
||||
def __gt__(self, other):
|
||||
less = self.__lt__(other)
|
||||
if less is NotImplemented:
|
||||
return NotImplemented
|
||||
equal = self.__eq__(other)
|
||||
if equal is NotImplemented:
|
||||
return NotImplemented
|
||||
return not (less or equal)
|
||||
def __ge__(self, other):
|
||||
less = self.__lt__(other)
|
||||
if less is NotImplemented:
|
||||
return NotImplemented
|
||||
return not less
|
||||
|
||||
class _IPAddressBase(_TotalOrderingMixin):
|
||||
|
||||
"""The mother class."""
|
||||
|
||||
|
|
@ -465,7 +496,6 @@ class _IPAddressBase:
|
|||
prefixlen = self._prefixlen
|
||||
return self._string_from_ip_int(self._ip_int_from_prefix(prefixlen))
|
||||
|
||||
|
||||
class _BaseAddress(_IPAddressBase):
|
||||
|
||||
"""A generic IP object.
|
||||
|
|
@ -493,24 +523,6 @@ class _BaseAddress(_IPAddressBase):
|
|||
except AttributeError:
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
eq = self.__eq__(other)
|
||||
if eq is NotImplemented:
|
||||
return NotImplemented
|
||||
return not eq
|
||||
|
||||
def __le__(self, other):
|
||||
gt = self.__gt__(other)
|
||||
if gt is NotImplemented:
|
||||
return NotImplemented
|
||||
return not gt
|
||||
|
||||
def __ge__(self, other):
|
||||
lt = self.__lt__(other)
|
||||
if lt is NotImplemented:
|
||||
return NotImplemented
|
||||
return not lt
|
||||
|
||||
def __lt__(self, other):
|
||||
if self._version != other._version:
|
||||
raise TypeError('%s and %s are not of the same version' % (
|
||||
|
|
@ -522,17 +534,6 @@ class _BaseAddress(_IPAddressBase):
|
|||
return self._ip < other._ip
|
||||
return False
|
||||
|
||||
def __gt__(self, other):
|
||||
if self._version != other._version:
|
||||
raise TypeError('%s and %s are not of the same version' % (
|
||||
self, other))
|
||||
if not isinstance(other, _BaseAddress):
|
||||
raise TypeError('%s and %s are not of the same type' % (
|
||||
self, other))
|
||||
if self._ip != other._ip:
|
||||
return self._ip > other._ip
|
||||
return False
|
||||
|
||||
# Shorthand for Integer addition and subtraction. This is not
|
||||
# meant to ever support addition/subtraction of addresses.
|
||||
def __add__(self, other):
|
||||
|
|
@ -625,31 +626,6 @@ class _BaseNetwork(_IPAddressBase):
|
|||
return self.netmask < other.netmask
|
||||
return False
|
||||
|
||||
def __gt__(self, other):
|
||||
if self._version != other._version:
|
||||
raise TypeError('%s and %s are not of the same version' % (
|
||||
self, other))
|
||||
if not isinstance(other, _BaseNetwork):
|
||||
raise TypeError('%s and %s are not of the same type' % (
|
||||
self, other))
|
||||
if self.network_address != other.network_address:
|
||||
return self.network_address > other.network_address
|
||||
if self.netmask != other.netmask:
|
||||
return self.netmask > other.netmask
|
||||
return False
|
||||
|
||||
def __le__(self, other):
|
||||
gt = self.__gt__(other)
|
||||
if gt is NotImplemented:
|
||||
return NotImplemented
|
||||
return not gt
|
||||
|
||||
def __ge__(self, other):
|
||||
lt = self.__lt__(other)
|
||||
if lt is NotImplemented:
|
||||
return NotImplemented
|
||||
return not lt
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
return (self._version == other._version and
|
||||
|
|
@ -658,12 +634,6 @@ class _BaseNetwork(_IPAddressBase):
|
|||
except AttributeError:
|
||||
return NotImplemented
|
||||
|
||||
def __ne__(self, other):
|
||||
eq = self.__eq__(other)
|
||||
if eq is NotImplemented:
|
||||
return NotImplemented
|
||||
return not eq
|
||||
|
||||
def __hash__(self):
|
||||
return hash(int(self.network_address) ^ int(self.netmask))
|
||||
|
||||
|
|
@ -1292,11 +1262,27 @@ class IPv4Interface(IPv4Address):
|
|||
self.network.prefixlen)
|
||||
|
||||
def __eq__(self, other):
|
||||
address_equal = IPv4Address.__eq__(self, other)
|
||||
if not address_equal or address_equal is NotImplemented:
|
||||
return address_equal
|
||||
try:
|
||||
return (IPv4Address.__eq__(self, other) and
|
||||
self.network == other.network)
|
||||
return self.network == other.network
|
||||
except AttributeError:
|
||||
# An interface with an associated network is NOT the
|
||||
# same as an unassociated address. That's why the hash
|
||||
# takes the extra info into account.
|
||||
return False
|
||||
|
||||
def __lt__(self, other):
|
||||
address_less = IPv4Address.__lt__(self, other)
|
||||
if address_less is NotImplemented:
|
||||
return NotImplemented
|
||||
try:
|
||||
return self.network < other.network
|
||||
except AttributeError:
|
||||
# We *do* allow addresses and interfaces to be sorted. The
|
||||
# unassociated address is considered less than all interfaces.
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return self._ip ^ self._prefixlen ^ int(self.network.network_address)
|
||||
|
|
@ -1928,11 +1914,27 @@ class IPv6Interface(IPv6Address):
|
|||
self.network.prefixlen)
|
||||
|
||||
def __eq__(self, other):
|
||||
address_equal = IPv6Address.__eq__(self, other)
|
||||
if not address_equal or address_equal is NotImplemented:
|
||||
return address_equal
|
||||
try:
|
||||
return (IPv6Address.__eq__(self, other) and
|
||||
self.network == other.network)
|
||||
return self.network == other.network
|
||||
except AttributeError:
|
||||
# An interface with an associated network is NOT the
|
||||
# same as an unassociated address. That's why the hash
|
||||
# takes the extra info into account.
|
||||
return False
|
||||
|
||||
def __lt__(self, other):
|
||||
address_less = IPv6Address.__lt__(self, other)
|
||||
if address_less is NotImplemented:
|
||||
return NotImplemented
|
||||
try:
|
||||
return self.network < other.network
|
||||
except AttributeError:
|
||||
# We *do* allow addresses and interfaces to be sorted. The
|
||||
# unassociated address is considered less than all interfaces.
|
||||
return False
|
||||
|
||||
def __hash__(self):
|
||||
return self._ip ^ self._prefixlen ^ int(self.network.network_address)
|
||||
|
|
|
|||
|
|
@ -415,6 +415,93 @@ class FactoryFunctionErrors(ErrorReporting):
|
|||
self.assertFactoryError(ipaddress.ip_network, "network")
|
||||
|
||||
|
||||
class ComparisonTests(unittest.TestCase):
|
||||
|
||||
v4addr = ipaddress.IPv4Address(1)
|
||||
v4net = ipaddress.IPv4Network(1)
|
||||
v4intf = ipaddress.IPv4Interface(1)
|
||||
v6addr = ipaddress.IPv6Address(1)
|
||||
v6net = ipaddress.IPv6Network(1)
|
||||
v6intf = ipaddress.IPv6Interface(1)
|
||||
|
||||
v4_addresses = [v4addr, v4intf]
|
||||
v4_objects = v4_addresses + [v4net]
|
||||
v6_addresses = [v6addr, v6intf]
|
||||
v6_objects = v6_addresses + [v6net]
|
||||
objects = v4_objects + v6_objects
|
||||
|
||||
def test_foreign_type_equality(self):
|
||||
# __eq__ should never raise TypeError directly
|
||||
other = object()
|
||||
for obj in self.objects:
|
||||
self.assertNotEqual(obj, other)
|
||||
self.assertFalse(obj == other)
|
||||
self.assertEqual(obj.__eq__(other), NotImplemented)
|
||||
self.assertEqual(obj.__ne__(other), NotImplemented)
|
||||
|
||||
def test_mixed_type_equality(self):
|
||||
# Ensure none of the internal objects accidentally
|
||||
# expose the right set of attributes to become "equal"
|
||||
for lhs in self.objects:
|
||||
for rhs in self.objects:
|
||||
if lhs is rhs:
|
||||
continue
|
||||
self.assertNotEqual(lhs, rhs)
|
||||
|
||||
def test_containment(self):
|
||||
for obj in self.v4_addresses:
|
||||
self.assertIn(obj, self.v4net)
|
||||
for obj in self.v6_addresses:
|
||||
self.assertIn(obj, self.v6net)
|
||||
for obj in self.v4_objects + [self.v6net]:
|
||||
self.assertNotIn(obj, self.v6net)
|
||||
for obj in self.v6_objects + [self.v4net]:
|
||||
self.assertNotIn(obj, self.v4net)
|
||||
|
||||
def test_mixed_type_ordering(self):
|
||||
for lhs in self.objects:
|
||||
for rhs in self.objects:
|
||||
if isinstance(lhs, type(rhs)) or isinstance(rhs, type(lhs)):
|
||||
continue
|
||||
self.assertRaises(TypeError, lambda: lhs < rhs)
|
||||
self.assertRaises(TypeError, lambda: lhs > rhs)
|
||||
self.assertRaises(TypeError, lambda: lhs <= rhs)
|
||||
self.assertRaises(TypeError, lambda: lhs >= rhs)
|
||||
|
||||
def test_mixed_type_key(self):
|
||||
# with get_mixed_type_key, you can sort addresses and network.
|
||||
v4_ordered = [self.v4addr, self.v4net, self.v4intf]
|
||||
v6_ordered = [self.v6addr, self.v6net, self.v6intf]
|
||||
self.assertEqual(v4_ordered,
|
||||
sorted(self.v4_objects,
|
||||
key=ipaddress.get_mixed_type_key))
|
||||
self.assertEqual(v6_ordered,
|
||||
sorted(self.v6_objects,
|
||||
key=ipaddress.get_mixed_type_key))
|
||||
self.assertEqual(v4_ordered + v6_ordered,
|
||||
sorted(self.objects,
|
||||
key=ipaddress.get_mixed_type_key))
|
||||
self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object))
|
||||
|
||||
def test_incompatible_versions(self):
|
||||
# These should always raise TypeError
|
||||
v4addr = ipaddress.ip_address('1.1.1.1')
|
||||
v4net = ipaddress.ip_network('1.1.1.1')
|
||||
v6addr = ipaddress.ip_address('::1')
|
||||
v6net = ipaddress.ip_address('::1')
|
||||
|
||||
self.assertRaises(TypeError, v4addr.__lt__, v6addr)
|
||||
self.assertRaises(TypeError, v4addr.__gt__, v6addr)
|
||||
self.assertRaises(TypeError, v4net.__lt__, v6net)
|
||||
self.assertRaises(TypeError, v4net.__gt__, v6net)
|
||||
|
||||
self.assertRaises(TypeError, v6addr.__lt__, v4addr)
|
||||
self.assertRaises(TypeError, v6addr.__gt__, v4addr)
|
||||
self.assertRaises(TypeError, v6net.__lt__, v4net)
|
||||
self.assertRaises(TypeError, v6net.__gt__, v4net)
|
||||
|
||||
|
||||
|
||||
class IpaddrUnitTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
|
|
@ -495,67 +582,6 @@ class IpaddrUnitTest(unittest.TestCase):
|
|||
self.assertEqual(str(self.ipv6_network.hostmask),
|
||||
'::ffff:ffff:ffff:ffff')
|
||||
|
||||
def testEqualityChecks(self):
|
||||
# __eq__ should never raise TypeError directly
|
||||
other = object()
|
||||
def assertEqualityNotImplemented(instance):
|
||||
self.assertEqual(instance.__eq__(other), NotImplemented)
|
||||
self.assertEqual(instance.__ne__(other), NotImplemented)
|
||||
self.assertFalse(instance == other)
|
||||
self.assertTrue(instance != other)
|
||||
|
||||
assertEqualityNotImplemented(self.ipv4_address)
|
||||
assertEqualityNotImplemented(self.ipv4_network)
|
||||
assertEqualityNotImplemented(self.ipv4_interface)
|
||||
assertEqualityNotImplemented(self.ipv6_address)
|
||||
assertEqualityNotImplemented(self.ipv6_network)
|
||||
assertEqualityNotImplemented(self.ipv6_interface)
|
||||
|
||||
def testBadVersionComparison(self):
|
||||
# These should always raise TypeError
|
||||
v4addr = ipaddress.ip_address('1.1.1.1')
|
||||
v4net = ipaddress.ip_network('1.1.1.1')
|
||||
v6addr = ipaddress.ip_address('::1')
|
||||
v6net = ipaddress.ip_address('::1')
|
||||
|
||||
self.assertRaises(TypeError, v4addr.__lt__, v6addr)
|
||||
self.assertRaises(TypeError, v4addr.__gt__, v6addr)
|
||||
self.assertRaises(TypeError, v4net.__lt__, v6net)
|
||||
self.assertRaises(TypeError, v4net.__gt__, v6net)
|
||||
|
||||
self.assertRaises(TypeError, v6addr.__lt__, v4addr)
|
||||
self.assertRaises(TypeError, v6addr.__gt__, v4addr)
|
||||
self.assertRaises(TypeError, v6net.__lt__, v4net)
|
||||
self.assertRaises(TypeError, v6net.__gt__, v4net)
|
||||
|
||||
def testMixedTypeComparison(self):
|
||||
v4addr = ipaddress.ip_address('1.1.1.1')
|
||||
v4net = ipaddress.ip_network('1.1.1.1/32')
|
||||
v6addr = ipaddress.ip_address('::1')
|
||||
v6net = ipaddress.ip_network('::1/128')
|
||||
|
||||
self.assertFalse(v4net.__contains__(v6net))
|
||||
self.assertFalse(v6net.__contains__(v4net))
|
||||
|
||||
self.assertRaises(TypeError, lambda: v4addr < v4net)
|
||||
self.assertRaises(TypeError, lambda: v4addr > v4net)
|
||||
self.assertRaises(TypeError, lambda: v4net < v4addr)
|
||||
self.assertRaises(TypeError, lambda: v4net > v4addr)
|
||||
|
||||
self.assertRaises(TypeError, lambda: v6addr < v6net)
|
||||
self.assertRaises(TypeError, lambda: v6addr > v6net)
|
||||
self.assertRaises(TypeError, lambda: v6net < v6addr)
|
||||
self.assertRaises(TypeError, lambda: v6net > v6addr)
|
||||
|
||||
# with get_mixed_type_key, you can sort addresses and network.
|
||||
self.assertEqual([v4addr, v4net],
|
||||
sorted([v4net, v4addr],
|
||||
key=ipaddress.get_mixed_type_key))
|
||||
self.assertEqual([v6addr, v6net],
|
||||
sorted([v6net, v6addr],
|
||||
key=ipaddress.get_mixed_type_key))
|
||||
self.assertEqual(NotImplemented, ipaddress.get_mixed_type_key(object))
|
||||
|
||||
def testIpFromInt(self):
|
||||
self.assertEqual(self.ipv4_interface._ip,
|
||||
ipaddress.IPv4Interface(16909060)._ip)
|
||||
|
|
@ -1049,6 +1075,16 @@ class IpaddrUnitTest(unittest.TestCase):
|
|||
self.assertTrue(ipaddress.ip_address('::1') <=
|
||||
ipaddress.ip_address('::2'))
|
||||
|
||||
def testInterfaceComparison(self):
|
||||
self.assertTrue(ipaddress.ip_interface('1.1.1.1') <=
|
||||
ipaddress.ip_interface('1.1.1.1'))
|
||||
self.assertTrue(ipaddress.ip_interface('1.1.1.1') <=
|
||||
ipaddress.ip_interface('1.1.1.2'))
|
||||
self.assertTrue(ipaddress.ip_interface('::1') <=
|
||||
ipaddress.ip_interface('::1'))
|
||||
self.assertTrue(ipaddress.ip_interface('::1') <=
|
||||
ipaddress.ip_interface('::2'))
|
||||
|
||||
def testNetworkComparison(self):
|
||||
# ip1 and ip2 have the same network address
|
||||
ip1 = ipaddress.IPv4Network('1.1.1.0/24')
|
||||
|
|
|
|||
|
|
@ -23,6 +23,9 @@ Core and Builtins
|
|||
Library
|
||||
-------
|
||||
|
||||
- Issue #14814: implement more consistent ordering and sorting behaviour
|
||||
for ipaddress objects
|
||||
|
||||
- Issue #14814: ipaddress network objects correctly return NotImplemented
|
||||
when compared to arbitrary objects instead of raising TypeError
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue