This commit is contained in:
Bénédikt Tran 2025-12-23 14:12:58 +05:30 committed by GitHub
commit f13d19ef2c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 194 additions and 134 deletions

View file

@ -150,6 +150,12 @@ def v6_int_to_packed(address):
raise ValueError("Address negative or too large for IPv6")
def _check_ip_version(a, b):
if a.version != b.version:
# does this need to raise a ValueError?
raise TypeError(f"{a} and {b} are not of the same version")
def _split_optional_netmask(address):
"""Helper to split the netmask and raise AddressValueError if needed"""
addr = str(address).split('/')
@ -213,7 +219,7 @@ def summarize_address_range(first, last):
Raise:
TypeError:
If the first and last objects are not IP addresses.
If the first or last objects are not IP addresses.
If the first and last objects are not the same version.
ValueError:
If the last object is not greater than the first.
@ -223,9 +229,7 @@ def summarize_address_range(first, last):
if (not (isinstance(first, _BaseAddress) and
isinstance(last, _BaseAddress))):
raise TypeError('first and last must be IP addresses, not networks')
if first.version != last.version:
raise TypeError("%s and %s are not of the same version" % (
first, last))
_check_ip_version(first, last)
if first > last:
raise ValueError('last IP address must be greater than first')
@ -316,40 +320,39 @@ def collapse_addresses(addresses):
TypeError: If passed a list of mixed version objects.
"""
addrs = []
ips = []
nets = []
# split IP addresses and networks
# split IP addresses/interfaces and networks
for ip in addresses:
if isinstance(ip, _BaseAddress):
if ips and ips[-1].version != ip.version:
raise TypeError("%s and %s are not of the same version" % (
ip, ips[-1]))
ips.append(ip)
elif ip._prefixlen == ip.max_prefixlen:
if ips and ips[-1].version != ip.version:
raise TypeError("%s and %s are not of the same version" % (
ip, ips[-1]))
try:
ips.append(ip.ip)
except AttributeError:
ips.append(ip.network_address)
if ips:
_check_ip_version(ips[-1], ip)
if hasattr(ip, "ip") and isinstance(ip.ip, _BaseAddress):
ips.append(ip.ip) # interface IP address
else:
ips.append(ip)
elif isinstance(ip, _BaseNetwork):
if ip.prefixlen == ip.max_prefixlen:
if ips:
_check_ip_version(ips[-1], ip)
ips.append(ip.network_address) # network address
else:
if nets:
_check_ip_version(nets[-1], ip)
nets.append(ip)
else:
if nets and nets[-1].version != ip.version:
raise TypeError("%s and %s are not of the same version" % (
ip, nets[-1]))
nets.append(ip)
raise TypeError(f"{ip} is not an IP object")
# sort and dedup
ips = sorted(set(ips))
# find consecutive address ranges in the sorted sequence and summarize them
nets_from_range = []
if ips:
for first, last in _find_address_range(ips):
addrs.extend(summarize_address_range(first, last))
nets_from_range.extend(summarize_address_range(first, last))
return _collapse_addresses_internal(addrs + nets)
return _collapse_addresses_internal(nets_from_range + nets)
def get_mixed_type_key(obj):
@ -567,21 +570,15 @@ class _BaseAddress(_IPAddressBase):
return self._ip
def __eq__(self, other):
try:
return (self._ip == other._ip
and self.version == other.version)
except AttributeError:
if not isinstance(other, _BaseAddress):
return NotImplemented
return self._ip == other._ip and self.version == other.version
def __lt__(self, other):
if not isinstance(other, _BaseAddress):
return NotImplemented
if self.version != other.version:
raise TypeError('%s and %s are not of the same version' % (
self, other))
if self._ip != other._ip:
return self._ip < other._ip
return False
_check_ip_version(self, other)
return self._ip < other._ip
# Shorthand for Integer addition and subtraction. This is not
# meant to ever support addition/subtraction of addresses.
@ -708,9 +705,7 @@ class _BaseNetwork(_IPAddressBase):
def __lt__(self, other):
if not isinstance(other, _BaseNetwork):
return NotImplemented
if self.version != other.version:
raise TypeError('%s and %s are not of the same version' % (
self, other))
_check_ip_version(self, other)
if self.network_address != other.network_address:
return self.network_address < other.network_address
if self.netmask != other.netmask:
@ -718,30 +713,31 @@ class _BaseNetwork(_IPAddressBase):
return False
def __eq__(self, other):
try:
return (self.version == other.version and
self.network_address == other.network_address and
int(self.netmask) == int(other.netmask))
except AttributeError:
if not isinstance(other, _BaseNetwork):
return NotImplemented
return (self.version == other.version and
self.network_address == other.network_address and
int(self.netmask._ip) == int(other.netmask))
def __hash__(self):
return hash((int(self.network_address), int(self.netmask)))
def __contains__(self, other):
# always false if one is v4 and the other is v6.
if self.version != other.version:
return False
# dealing with another network.
if isinstance(other, _BaseNetwork):
# should __contains__ actually implement subnet_of()
# and supernet_of() instead?
return False
# dealing with another address
else:
# address
return other._ip & self.netmask._ip == self.network_address._ip
if isinstance(other, _BaseAddress):
return (
self.version == other.version
and (other._ip & self.netmask._ip) == self.network_address._ip
)
return NotImplemented
def overlaps(self, other):
"""Tell if self is partly contained in other."""
if not isinstance(other, _BaseNetwork):
raise TypeError(f"expecting a network object, not {type(other)}")
return self.network_address in other or (
self.broadcast_address in other or (
other.network_address in self or (
@ -821,13 +817,9 @@ class _BaseNetwork(_IPAddressBase):
ValueError: If other is not completely contained by self.
"""
if not self.version == other.version:
raise TypeError("%s and %s are not of the same version" % (
self, other))
if not isinstance(other, _BaseNetwork):
raise TypeError("%s is not a network object" % other)
raise TypeError(f"expecting a network object, not {type(other)}")
_check_ip_version(self, other)
if not other.subnet_of(self):
raise ValueError('%s not contained in %s' % (other, self))
if other == self:
@ -870,7 +862,7 @@ class _BaseNetwork(_IPAddressBase):
'HostA._ip < HostB._ip'
Args:
other: An IP object.
other: An IP network object.
Returns:
If the IP versions of self and other are the same, returns:
@ -892,10 +884,9 @@ class _BaseNetwork(_IPAddressBase):
TypeError if the IP versions are different.
"""
# does this need to raise a ValueError?
if self.version != other.version:
raise TypeError('%s and %s are not of the same type' % (
self, other))
if not isinstance(other, _BaseNetwork):
raise TypeError(f"expecting a network object, not {type(other)}")
_check_ip_version(self, other)
# self.version == other.version below here:
if self.network_address < other.network_address:
return -1
@ -1026,22 +1017,21 @@ class _BaseNetwork(_IPAddressBase):
@staticmethod
def _is_subnet_of(a, b):
try:
# Always false if one is v4 and the other is v6.
if a.version != b.version:
raise TypeError(f"{a} and {b} are not of the same version")
return (b.network_address <= a.network_address and
b.broadcast_address >= a.broadcast_address)
except AttributeError:
raise TypeError(f"Unable to test subnet containment "
f"between {a} and {b}")
# The caller must ensure that 'a' and 'b' are both networks.
_check_ip_version(a, b)
return (b.network_address <= a.network_address and
b.broadcast_address >= a.broadcast_address)
def subnet_of(self, other):
"""Return True if this network is a subnet of other."""
if not isinstance(other, _BaseNetwork):
raise TypeError(f"expecting a network object, not {type(other)}")
return self._is_subnet_of(self, other)
def supernet_of(self, other):
"""Return True if this network is a supernet of other."""
if not isinstance(other, _BaseNetwork):
raise TypeError(f"expecting a network object, not {type(other)}")
return self._is_subnet_of(other, self)
@property
@ -1429,28 +1419,27 @@ class IPv4Interface(IPv4Address):
self._prefixlen)
def __eq__(self, other):
if not isinstance(other, IPv4Interface):
if isinstance(other, IPv4Address):
# avoid falling back to IPv4Address.__eq__(other, self)
return False
return NotImplemented
# An interface with an associated network is NOT the
# same as an unassociated address. That's why the hash
# takes the extra info into account.
address_equal = IPv4Address.__eq__(self, other)
if address_equal is NotImplemented or not address_equal:
return address_equal
try:
return self.network == other.network
except AttributeError:
# An interface with an associated network is NOT the
# same as an unassociated address. That's why the hash
# takes the extra info into account.
return False
return address_equal and self.network == other.network
def __lt__(self, other):
# We *do* allow addresses and interfaces to be sorted. The
# unassociated address is considered less than all interfaces.
address_less = IPv4Address.__lt__(self, other)
if address_less is NotImplemented:
return NotImplemented
try:
return (self.network < other.network or
self.network == other.network and address_less)
except AttributeError:
# We *do* allow addresses and interfaces to be sorted. The
# unassociated address is considered less than all interfaces.
return False
if isinstance(other, IPv4Interface):
assert address_less is not NotImplemented
# compare interfaces by their network first
return (self.network < other.network
or (self.network == other.network and address_less))
return address_less
def __hash__(self):
return hash((self._ip, self._prefixlen, int(self.network.network_address)))
@ -2219,28 +2208,27 @@ class IPv6Interface(IPv6Address):
self._prefixlen)
def __eq__(self, other):
if not isinstance(other, IPv6Interface):
if isinstance(other, IPv6Address):
# avoid falling back to IPv6Address.__eq__(other, self)
return False
return NotImplemented
# An interface with an associated network is NOT the
# same as an unassociated address. That's why the hash
# takes the extra info into account.
address_equal = IPv6Address.__eq__(self, other)
if address_equal is NotImplemented or not address_equal:
return address_equal
try:
return self.network == other.network
except AttributeError:
# An interface with an associated network is NOT the
# same as an unassociated address. That's why the hash
# takes the extra info into account.
return False
return address_equal and self.network == other.network
def __lt__(self, other):
# We *do* allow addresses and interfaces to be sorted. The
# unassociated address is considered less than all interfaces.
address_less = IPv6Address.__lt__(self, other)
if address_less is NotImplemented:
return address_less
try:
return (self.network < other.network or
self.network == other.network and address_less)
except AttributeError:
# We *do* allow addresses and interfaces to be sorted. The
# unassociated address is considered less than all interfaces.
return False
if isinstance(other, IPv6Interface):
assert address_less is not NotImplemented
# compare interfaces by their network first
return (self.network < other.network
or (self.network == other.network and address_less))
return address_less
def __hash__(self):
return hash((self._ip, self._prefixlen, int(self.network.network_address)))

View file

@ -13,6 +13,7 @@ import pickle
import ipaddress
import weakref
from collections.abc import Iterator
from functools import total_ordering
from test.support import LARGEST, SMALLEST
@ -912,11 +913,22 @@ class ComparisonTests(unittest.TestCase):
v6intf_scoped = ipaddress.IPv6Interface('::1%scope')
v4_addresses = [v4addr, v4intf]
v4_objects = v4_addresses + [v4net]
v4_networks = [v4net]
v4_objects = v4_addresses + v4_networks
v6_addresses = [v6addr, v6intf]
v6_objects = v6_addresses + [v6net]
v6_networks = [v6net]
v6_objects = v6_addresses + v6_networks
v6_scoped_addresses = [v6addr_scoped, v6intf_scoped]
v6_scoped_objects = v6_scoped_addresses + [v6net_scoped]
v6_scoped_networks = [v6net_scoped]
v6_scoped_objects = v6_scoped_addresses + v6_scoped_networks
addresses = v4_addresses + v6_addresses
addresses_with_scoped = addresses + v6_scoped_addresses
networks = v4_networks + v6_networks
networks_with_scoped = networks + v6_scoped_networks
objects = v4_objects + v6_objects
objects_with_scoped = objects + v6_scoped_objects
@ -935,10 +947,14 @@ class ComparisonTests(unittest.TestCase):
# __eq__ should never raise TypeError directly
other = object()
for obj in self.objects_with_scoped:
self.assertNotEqual(obj, other)
self.assertFalse(obj == other)
self.assertEqual(obj.__eq__(other), NotImplemented)
self.assertEqual(obj.__ne__(other), NotImplemented)
with self.subTest(obj=obj):
self.assertNotEqual(obj, other)
self.assertFalse(obj == other)
self.assertIs(obj.__eq__(other), NotImplemented)
self.assertTrue(obj != other)
self.assertIs(obj.__ne__(other), NotImplemented)
def test_mixed_type_equality(self):
# Ensure none of the internal objects accidentally
@ -1006,30 +1022,54 @@ class ComparisonTests(unittest.TestCase):
for rhs in self.objects_with_scoped:
if isinstance(lhs, type(rhs)) or isinstance(rhs, type(lhs)):
continue
self.assertRaises(TypeError, lambda: lhs < rhs)
self.assertRaises(TypeError, lambda: lhs > rhs)
self.assertRaises(TypeError, lambda: lhs <= rhs)
self.assertRaises(TypeError, lambda: lhs >= rhs)
for dunder in ["__lt__", "__le__", "__ge__", "__gt__"]:
with self.subTest(dunder, lhs=lhs, rhs=rhs):
func = getattr(operator, dunder)
# dunders raise a TypeError or return NotImplemented
lhs_method = getattr(lhs, dunder)
try:
self.assertIs(lhs_method(rhs), NotImplemented)
except TypeError as exc:
self.assertIn("version", str(exc))
rhs_method = getattr(rhs, dunder)
try:
self.assertIs(rhs_method(lhs), NotImplemented)
except TypeError as exc:
self.assertIn("version", str(exc))
# Using the comparison operator directly must
# raise a TypeError, either because we returned
# NotImplemented or because of incompatible versions.
self.assertRaises(TypeError, func, lhs, rhs)
def test_foreign_type_ordering(self):
other = object()
for obj in self.objects_with_scoped:
with self.assertRaises(TypeError):
obj < other
with self.assertRaises(TypeError):
obj > other
with self.assertRaises(TypeError):
obj <= other
with self.assertRaises(TypeError):
obj >= other
self.assertTrue(obj < LARGEST)
self.assertFalse(obj > LARGEST)
self.assertTrue(obj <= LARGEST)
self.assertFalse(obj >= LARGEST)
self.assertFalse(obj < SMALLEST)
self.assertTrue(obj > SMALLEST)
self.assertFalse(obj <= SMALLEST)
self.assertTrue(obj >= SMALLEST)
with self.subTest(obj=obj):
for dunder in ["__lt__", "__le__", "__ge__", "__gt__"]:
with self.subTest(dunder):
via_meth = getattr(obj, dunder)
self.assertIs(via_meth(other), NotImplemented)
via_op = getattr(operator, dunder)
self.assertRaises(TypeError, via_op, obj, other)
self.assertIs(obj.__lt__(LARGEST), NotImplemented)
self.assertTrue(obj < LARGEST)
self.assertIs(obj.__le__(LARGEST), NotImplemented)
self.assertTrue(obj <= LARGEST)
self.assertIs(obj.__ge__(LARGEST), NotImplemented)
self.assertFalse(obj >= LARGEST)
self.assertIs(obj.__gt__(LARGEST), NotImplemented)
self.assertFalse(obj > LARGEST)
self.assertIs(obj.__lt__(SMALLEST), NotImplemented)
self.assertFalse(obj < SMALLEST)
self.assertIs(obj.__le__(SMALLEST), NotImplemented)
self.assertFalse(obj <= SMALLEST)
self.assertIs(obj.__ge__(SMALLEST), NotImplemented)
self.assertTrue(obj >= SMALLEST)
self.assertIs(obj.__gt__(SMALLEST), NotImplemented)
self.assertTrue(obj > SMALLEST)
def test_mixed_type_key(self):
# with get_mixed_type_key, you can sort addresses and network.
@ -1079,6 +1119,33 @@ class ComparisonTests(unittest.TestCase):
self.assertRaises(TypeError, v6net_scoped.__lt__, v4net)
self.assertRaises(TypeError, v6net_scoped.__gt__, v4net)
def test_object_compare_with_always_equal(self):
# Check that __eq__/__lt__ for IP objects work for non-IP
# objects that share the same attributes as IP objects.
class AlwaysEqual:
version = None
def __eq__(self, other):
return True
same_object = AlwaysEqual()
for obj in self.objects_with_scoped:
with self.subTest(obj=obj):
self.assertEqual(obj, same_object)
def test_object_compare_with_always_smallest(self):
@total_ordering
class Smallest:
version = None
def __lt__(self, other):
return True
smallest = Smallest()
for obj in self.objects_with_scoped:
with self.subTest(obj=obj):
# ensure that we dispatch to Smallest.__lt__ instead.
self.assertIs(obj.__lt__(smallest), NotImplemented)
self.assertLess(smallest, obj)
class IpaddrUnitTest(unittest.TestCase):

View file

@ -0,0 +1,5 @@
:mod:`ipaddress`: fix comparison operators :class:`~ipaddress.IPv4Address`,
:class:`~ipaddress.IPv6Address`, :class:`~ipaddress.IPv4Network`,
:class:`~ipaddress.IPv6Network`, :class:`~ipaddress.IPv4Interface`, and
:class:`~ipaddress.IPv6Interface` to avoid comparing instances of incorrect
types. Patch by Bénédikt Tran and yihong0618.