mirror of
https://github.com/python/cpython.git
synced 2025-12-23 09:19:18 +00:00
Merge 100dad9e2c into f9704f1d84
This commit is contained in:
commit
f13d19ef2c
3 changed files with 194 additions and 134 deletions
202
Lib/ipaddress.py
202
Lib/ipaddress.py
|
|
@ -150,6 +150,12 @@ def v6_int_to_packed(address):
|
|||
raise ValueError("Address negative or too large for IPv6")
|
||||
|
||||
|
||||
def _check_ip_version(a, b):
|
||||
if a.version != b.version:
|
||||
# does this need to raise a ValueError?
|
||||
raise TypeError(f"{a} and {b} are not of the same version")
|
||||
|
||||
|
||||
def _split_optional_netmask(address):
|
||||
"""Helper to split the netmask and raise AddressValueError if needed"""
|
||||
addr = str(address).split('/')
|
||||
|
|
@ -213,7 +219,7 @@ def summarize_address_range(first, last):
|
|||
|
||||
Raise:
|
||||
TypeError:
|
||||
If the first and last objects are not IP addresses.
|
||||
If the first or last objects are not IP addresses.
|
||||
If the first and last objects are not the same version.
|
||||
ValueError:
|
||||
If the last object is not greater than the first.
|
||||
|
|
@ -223,9 +229,7 @@ def summarize_address_range(first, last):
|
|||
if (not (isinstance(first, _BaseAddress) and
|
||||
isinstance(last, _BaseAddress))):
|
||||
raise TypeError('first and last must be IP addresses, not networks')
|
||||
if first.version != last.version:
|
||||
raise TypeError("%s and %s are not of the same version" % (
|
||||
first, last))
|
||||
_check_ip_version(first, last)
|
||||
if first > last:
|
||||
raise ValueError('last IP address must be greater than first')
|
||||
|
||||
|
|
@ -316,40 +320,39 @@ def collapse_addresses(addresses):
|
|||
TypeError: If passed a list of mixed version objects.
|
||||
|
||||
"""
|
||||
addrs = []
|
||||
ips = []
|
||||
nets = []
|
||||
|
||||
# split IP addresses and networks
|
||||
# split IP addresses/interfaces and networks
|
||||
for ip in addresses:
|
||||
if isinstance(ip, _BaseAddress):
|
||||
if ips and ips[-1].version != ip.version:
|
||||
raise TypeError("%s and %s are not of the same version" % (
|
||||
ip, ips[-1]))
|
||||
ips.append(ip)
|
||||
elif ip._prefixlen == ip.max_prefixlen:
|
||||
if ips and ips[-1].version != ip.version:
|
||||
raise TypeError("%s and %s are not of the same version" % (
|
||||
ip, ips[-1]))
|
||||
try:
|
||||
ips.append(ip.ip)
|
||||
except AttributeError:
|
||||
ips.append(ip.network_address)
|
||||
if ips:
|
||||
_check_ip_version(ips[-1], ip)
|
||||
if hasattr(ip, "ip") and isinstance(ip.ip, _BaseAddress):
|
||||
ips.append(ip.ip) # interface IP address
|
||||
else:
|
||||
ips.append(ip)
|
||||
elif isinstance(ip, _BaseNetwork):
|
||||
if ip.prefixlen == ip.max_prefixlen:
|
||||
if ips:
|
||||
_check_ip_version(ips[-1], ip)
|
||||
ips.append(ip.network_address) # network address
|
||||
else:
|
||||
if nets:
|
||||
_check_ip_version(nets[-1], ip)
|
||||
nets.append(ip)
|
||||
else:
|
||||
if nets and nets[-1].version != ip.version:
|
||||
raise TypeError("%s and %s are not of the same version" % (
|
||||
ip, nets[-1]))
|
||||
nets.append(ip)
|
||||
raise TypeError(f"{ip} is not an IP object")
|
||||
|
||||
# sort and dedup
|
||||
ips = sorted(set(ips))
|
||||
|
||||
# find consecutive address ranges in the sorted sequence and summarize them
|
||||
nets_from_range = []
|
||||
if ips:
|
||||
for first, last in _find_address_range(ips):
|
||||
addrs.extend(summarize_address_range(first, last))
|
||||
nets_from_range.extend(summarize_address_range(first, last))
|
||||
|
||||
return _collapse_addresses_internal(addrs + nets)
|
||||
return _collapse_addresses_internal(nets_from_range + nets)
|
||||
|
||||
|
||||
def get_mixed_type_key(obj):
|
||||
|
|
@ -567,21 +570,15 @@ class _BaseAddress(_IPAddressBase):
|
|||
return self._ip
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
return (self._ip == other._ip
|
||||
and self.version == other.version)
|
||||
except AttributeError:
|
||||
if not isinstance(other, _BaseAddress):
|
||||
return NotImplemented
|
||||
return self._ip == other._ip and self.version == other.version
|
||||
|
||||
def __lt__(self, other):
|
||||
if not isinstance(other, _BaseAddress):
|
||||
return NotImplemented
|
||||
if self.version != other.version:
|
||||
raise TypeError('%s and %s are not of the same version' % (
|
||||
self, other))
|
||||
if self._ip != other._ip:
|
||||
return self._ip < other._ip
|
||||
return False
|
||||
_check_ip_version(self, other)
|
||||
return self._ip < other._ip
|
||||
|
||||
# Shorthand for Integer addition and subtraction. This is not
|
||||
# meant to ever support addition/subtraction of addresses.
|
||||
|
|
@ -708,9 +705,7 @@ class _BaseNetwork(_IPAddressBase):
|
|||
def __lt__(self, other):
|
||||
if not isinstance(other, _BaseNetwork):
|
||||
return NotImplemented
|
||||
if self.version != other.version:
|
||||
raise TypeError('%s and %s are not of the same version' % (
|
||||
self, other))
|
||||
_check_ip_version(self, other)
|
||||
if self.network_address != other.network_address:
|
||||
return self.network_address < other.network_address
|
||||
if self.netmask != other.netmask:
|
||||
|
|
@ -718,30 +713,31 @@ class _BaseNetwork(_IPAddressBase):
|
|||
return False
|
||||
|
||||
def __eq__(self, other):
|
||||
try:
|
||||
return (self.version == other.version and
|
||||
self.network_address == other.network_address and
|
||||
int(self.netmask) == int(other.netmask))
|
||||
except AttributeError:
|
||||
if not isinstance(other, _BaseNetwork):
|
||||
return NotImplemented
|
||||
return (self.version == other.version and
|
||||
self.network_address == other.network_address and
|
||||
int(self.netmask._ip) == int(other.netmask))
|
||||
|
||||
def __hash__(self):
|
||||
return hash((int(self.network_address), int(self.netmask)))
|
||||
|
||||
def __contains__(self, other):
|
||||
# always false if one is v4 and the other is v6.
|
||||
if self.version != other.version:
|
||||
return False
|
||||
# dealing with another network.
|
||||
if isinstance(other, _BaseNetwork):
|
||||
# should __contains__ actually implement subnet_of()
|
||||
# and supernet_of() instead?
|
||||
return False
|
||||
# dealing with another address
|
||||
else:
|
||||
# address
|
||||
return other._ip & self.netmask._ip == self.network_address._ip
|
||||
if isinstance(other, _BaseAddress):
|
||||
return (
|
||||
self.version == other.version
|
||||
and (other._ip & self.netmask._ip) == self.network_address._ip
|
||||
)
|
||||
return NotImplemented
|
||||
|
||||
def overlaps(self, other):
|
||||
"""Tell if self is partly contained in other."""
|
||||
if not isinstance(other, _BaseNetwork):
|
||||
raise TypeError(f"expecting a network object, not {type(other)}")
|
||||
return self.network_address in other or (
|
||||
self.broadcast_address in other or (
|
||||
other.network_address in self or (
|
||||
|
|
@ -821,13 +817,9 @@ class _BaseNetwork(_IPAddressBase):
|
|||
ValueError: If other is not completely contained by self.
|
||||
|
||||
"""
|
||||
if not self.version == other.version:
|
||||
raise TypeError("%s and %s are not of the same version" % (
|
||||
self, other))
|
||||
|
||||
if not isinstance(other, _BaseNetwork):
|
||||
raise TypeError("%s is not a network object" % other)
|
||||
|
||||
raise TypeError(f"expecting a network object, not {type(other)}")
|
||||
_check_ip_version(self, other)
|
||||
if not other.subnet_of(self):
|
||||
raise ValueError('%s not contained in %s' % (other, self))
|
||||
if other == self:
|
||||
|
|
@ -870,7 +862,7 @@ class _BaseNetwork(_IPAddressBase):
|
|||
'HostA._ip < HostB._ip'
|
||||
|
||||
Args:
|
||||
other: An IP object.
|
||||
other: An IP network object.
|
||||
|
||||
Returns:
|
||||
If the IP versions of self and other are the same, returns:
|
||||
|
|
@ -892,10 +884,9 @@ class _BaseNetwork(_IPAddressBase):
|
|||
TypeError if the IP versions are different.
|
||||
|
||||
"""
|
||||
# does this need to raise a ValueError?
|
||||
if self.version != other.version:
|
||||
raise TypeError('%s and %s are not of the same type' % (
|
||||
self, other))
|
||||
if not isinstance(other, _BaseNetwork):
|
||||
raise TypeError(f"expecting a network object, not {type(other)}")
|
||||
_check_ip_version(self, other)
|
||||
# self.version == other.version below here:
|
||||
if self.network_address < other.network_address:
|
||||
return -1
|
||||
|
|
@ -1026,22 +1017,21 @@ class _BaseNetwork(_IPAddressBase):
|
|||
|
||||
@staticmethod
|
||||
def _is_subnet_of(a, b):
|
||||
try:
|
||||
# Always false if one is v4 and the other is v6.
|
||||
if a.version != b.version:
|
||||
raise TypeError(f"{a} and {b} are not of the same version")
|
||||
return (b.network_address <= a.network_address and
|
||||
b.broadcast_address >= a.broadcast_address)
|
||||
except AttributeError:
|
||||
raise TypeError(f"Unable to test subnet containment "
|
||||
f"between {a} and {b}")
|
||||
# The caller must ensure that 'a' and 'b' are both networks.
|
||||
_check_ip_version(a, b)
|
||||
return (b.network_address <= a.network_address and
|
||||
b.broadcast_address >= a.broadcast_address)
|
||||
|
||||
def subnet_of(self, other):
|
||||
"""Return True if this network is a subnet of other."""
|
||||
if not isinstance(other, _BaseNetwork):
|
||||
raise TypeError(f"expecting a network object, not {type(other)}")
|
||||
return self._is_subnet_of(self, other)
|
||||
|
||||
def supernet_of(self, other):
|
||||
"""Return True if this network is a supernet of other."""
|
||||
if not isinstance(other, _BaseNetwork):
|
||||
raise TypeError(f"expecting a network object, not {type(other)}")
|
||||
return self._is_subnet_of(other, self)
|
||||
|
||||
@property
|
||||
|
|
@ -1429,28 +1419,27 @@ class IPv4Interface(IPv4Address):
|
|||
self._prefixlen)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, IPv4Interface):
|
||||
if isinstance(other, IPv4Address):
|
||||
# avoid falling back to IPv4Address.__eq__(other, self)
|
||||
return False
|
||||
return NotImplemented
|
||||
# An interface with an associated network is NOT the
|
||||
# same as an unassociated address. That's why the hash
|
||||
# takes the extra info into account.
|
||||
address_equal = IPv4Address.__eq__(self, other)
|
||||
if address_equal is NotImplemented or not address_equal:
|
||||
return address_equal
|
||||
try:
|
||||
return self.network == other.network
|
||||
except AttributeError:
|
||||
# An interface with an associated network is NOT the
|
||||
# same as an unassociated address. That's why the hash
|
||||
# takes the extra info into account.
|
||||
return False
|
||||
return address_equal and self.network == other.network
|
||||
|
||||
def __lt__(self, other):
|
||||
# We *do* allow addresses and interfaces to be sorted. The
|
||||
# unassociated address is considered less than all interfaces.
|
||||
address_less = IPv4Address.__lt__(self, other)
|
||||
if address_less is NotImplemented:
|
||||
return NotImplemented
|
||||
try:
|
||||
return (self.network < other.network or
|
||||
self.network == other.network and address_less)
|
||||
except AttributeError:
|
||||
# We *do* allow addresses and interfaces to be sorted. The
|
||||
# unassociated address is considered less than all interfaces.
|
||||
return False
|
||||
if isinstance(other, IPv4Interface):
|
||||
assert address_less is not NotImplemented
|
||||
# compare interfaces by their network first
|
||||
return (self.network < other.network
|
||||
or (self.network == other.network and address_less))
|
||||
return address_less
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self._ip, self._prefixlen, int(self.network.network_address)))
|
||||
|
|
@ -2219,28 +2208,27 @@ class IPv6Interface(IPv6Address):
|
|||
self._prefixlen)
|
||||
|
||||
def __eq__(self, other):
|
||||
if not isinstance(other, IPv6Interface):
|
||||
if isinstance(other, IPv6Address):
|
||||
# avoid falling back to IPv6Address.__eq__(other, self)
|
||||
return False
|
||||
return NotImplemented
|
||||
# An interface with an associated network is NOT the
|
||||
# same as an unassociated address. That's why the hash
|
||||
# takes the extra info into account.
|
||||
address_equal = IPv6Address.__eq__(self, other)
|
||||
if address_equal is NotImplemented or not address_equal:
|
||||
return address_equal
|
||||
try:
|
||||
return self.network == other.network
|
||||
except AttributeError:
|
||||
# An interface with an associated network is NOT the
|
||||
# same as an unassociated address. That's why the hash
|
||||
# takes the extra info into account.
|
||||
return False
|
||||
return address_equal and self.network == other.network
|
||||
|
||||
def __lt__(self, other):
|
||||
# We *do* allow addresses and interfaces to be sorted. The
|
||||
# unassociated address is considered less than all interfaces.
|
||||
address_less = IPv6Address.__lt__(self, other)
|
||||
if address_less is NotImplemented:
|
||||
return address_less
|
||||
try:
|
||||
return (self.network < other.network or
|
||||
self.network == other.network and address_less)
|
||||
except AttributeError:
|
||||
# We *do* allow addresses and interfaces to be sorted. The
|
||||
# unassociated address is considered less than all interfaces.
|
||||
return False
|
||||
if isinstance(other, IPv6Interface):
|
||||
assert address_less is not NotImplemented
|
||||
# compare interfaces by their network first
|
||||
return (self.network < other.network
|
||||
or (self.network == other.network and address_less))
|
||||
return address_less
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self._ip, self._prefixlen, int(self.network.network_address)))
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import pickle
|
|||
import ipaddress
|
||||
import weakref
|
||||
from collections.abc import Iterator
|
||||
from functools import total_ordering
|
||||
from test.support import LARGEST, SMALLEST
|
||||
|
||||
|
||||
|
|
@ -912,11 +913,22 @@ class ComparisonTests(unittest.TestCase):
|
|||
v6intf_scoped = ipaddress.IPv6Interface('::1%scope')
|
||||
|
||||
v4_addresses = [v4addr, v4intf]
|
||||
v4_objects = v4_addresses + [v4net]
|
||||
v4_networks = [v4net]
|
||||
v4_objects = v4_addresses + v4_networks
|
||||
|
||||
v6_addresses = [v6addr, v6intf]
|
||||
v6_objects = v6_addresses + [v6net]
|
||||
v6_networks = [v6net]
|
||||
v6_objects = v6_addresses + v6_networks
|
||||
|
||||
v6_scoped_addresses = [v6addr_scoped, v6intf_scoped]
|
||||
v6_scoped_objects = v6_scoped_addresses + [v6net_scoped]
|
||||
v6_scoped_networks = [v6net_scoped]
|
||||
v6_scoped_objects = v6_scoped_addresses + v6_scoped_networks
|
||||
|
||||
addresses = v4_addresses + v6_addresses
|
||||
addresses_with_scoped = addresses + v6_scoped_addresses
|
||||
|
||||
networks = v4_networks + v6_networks
|
||||
networks_with_scoped = networks + v6_scoped_networks
|
||||
|
||||
objects = v4_objects + v6_objects
|
||||
objects_with_scoped = objects + v6_scoped_objects
|
||||
|
|
@ -935,10 +947,14 @@ class ComparisonTests(unittest.TestCase):
|
|||
# __eq__ should never raise TypeError directly
|
||||
other = object()
|
||||
for obj in self.objects_with_scoped:
|
||||
self.assertNotEqual(obj, other)
|
||||
self.assertFalse(obj == other)
|
||||
self.assertEqual(obj.__eq__(other), NotImplemented)
|
||||
self.assertEqual(obj.__ne__(other), NotImplemented)
|
||||
with self.subTest(obj=obj):
|
||||
self.assertNotEqual(obj, other)
|
||||
|
||||
self.assertFalse(obj == other)
|
||||
self.assertIs(obj.__eq__(other), NotImplemented)
|
||||
|
||||
self.assertTrue(obj != other)
|
||||
self.assertIs(obj.__ne__(other), NotImplemented)
|
||||
|
||||
def test_mixed_type_equality(self):
|
||||
# Ensure none of the internal objects accidentally
|
||||
|
|
@ -1006,30 +1022,54 @@ class ComparisonTests(unittest.TestCase):
|
|||
for rhs in self.objects_with_scoped:
|
||||
if isinstance(lhs, type(rhs)) or isinstance(rhs, type(lhs)):
|
||||
continue
|
||||
self.assertRaises(TypeError, lambda: lhs < rhs)
|
||||
self.assertRaises(TypeError, lambda: lhs > rhs)
|
||||
self.assertRaises(TypeError, lambda: lhs <= rhs)
|
||||
self.assertRaises(TypeError, lambda: lhs >= rhs)
|
||||
|
||||
for dunder in ["__lt__", "__le__", "__ge__", "__gt__"]:
|
||||
with self.subTest(dunder, lhs=lhs, rhs=rhs):
|
||||
func = getattr(operator, dunder)
|
||||
# dunders raise a TypeError or return NotImplemented
|
||||
lhs_method = getattr(lhs, dunder)
|
||||
try:
|
||||
self.assertIs(lhs_method(rhs), NotImplemented)
|
||||
except TypeError as exc:
|
||||
self.assertIn("version", str(exc))
|
||||
rhs_method = getattr(rhs, dunder)
|
||||
try:
|
||||
self.assertIs(rhs_method(lhs), NotImplemented)
|
||||
except TypeError as exc:
|
||||
self.assertIn("version", str(exc))
|
||||
# Using the comparison operator directly must
|
||||
# raise a TypeError, either because we returned
|
||||
# NotImplemented or because of incompatible versions.
|
||||
self.assertRaises(TypeError, func, lhs, rhs)
|
||||
|
||||
def test_foreign_type_ordering(self):
|
||||
other = object()
|
||||
for obj in self.objects_with_scoped:
|
||||
with self.assertRaises(TypeError):
|
||||
obj < other
|
||||
with self.assertRaises(TypeError):
|
||||
obj > other
|
||||
with self.assertRaises(TypeError):
|
||||
obj <= other
|
||||
with self.assertRaises(TypeError):
|
||||
obj >= other
|
||||
self.assertTrue(obj < LARGEST)
|
||||
self.assertFalse(obj > LARGEST)
|
||||
self.assertTrue(obj <= LARGEST)
|
||||
self.assertFalse(obj >= LARGEST)
|
||||
self.assertFalse(obj < SMALLEST)
|
||||
self.assertTrue(obj > SMALLEST)
|
||||
self.assertFalse(obj <= SMALLEST)
|
||||
self.assertTrue(obj >= SMALLEST)
|
||||
with self.subTest(obj=obj):
|
||||
for dunder in ["__lt__", "__le__", "__ge__", "__gt__"]:
|
||||
with self.subTest(dunder):
|
||||
via_meth = getattr(obj, dunder)
|
||||
self.assertIs(via_meth(other), NotImplemented)
|
||||
via_op = getattr(operator, dunder)
|
||||
self.assertRaises(TypeError, via_op, obj, other)
|
||||
|
||||
self.assertIs(obj.__lt__(LARGEST), NotImplemented)
|
||||
self.assertTrue(obj < LARGEST)
|
||||
self.assertIs(obj.__le__(LARGEST), NotImplemented)
|
||||
self.assertTrue(obj <= LARGEST)
|
||||
self.assertIs(obj.__ge__(LARGEST), NotImplemented)
|
||||
self.assertFalse(obj >= LARGEST)
|
||||
self.assertIs(obj.__gt__(LARGEST), NotImplemented)
|
||||
self.assertFalse(obj > LARGEST)
|
||||
|
||||
self.assertIs(obj.__lt__(SMALLEST), NotImplemented)
|
||||
self.assertFalse(obj < SMALLEST)
|
||||
self.assertIs(obj.__le__(SMALLEST), NotImplemented)
|
||||
self.assertFalse(obj <= SMALLEST)
|
||||
self.assertIs(obj.__ge__(SMALLEST), NotImplemented)
|
||||
self.assertTrue(obj >= SMALLEST)
|
||||
self.assertIs(obj.__gt__(SMALLEST), NotImplemented)
|
||||
self.assertTrue(obj > SMALLEST)
|
||||
|
||||
def test_mixed_type_key(self):
|
||||
# with get_mixed_type_key, you can sort addresses and network.
|
||||
|
|
@ -1079,6 +1119,33 @@ class ComparisonTests(unittest.TestCase):
|
|||
self.assertRaises(TypeError, v6net_scoped.__lt__, v4net)
|
||||
self.assertRaises(TypeError, v6net_scoped.__gt__, v4net)
|
||||
|
||||
def test_object_compare_with_always_equal(self):
|
||||
# Check that __eq__/__lt__ for IP objects work for non-IP
|
||||
# objects that share the same attributes as IP objects.
|
||||
class AlwaysEqual:
|
||||
version = None
|
||||
def __eq__(self, other):
|
||||
return True
|
||||
|
||||
same_object = AlwaysEqual()
|
||||
for obj in self.objects_with_scoped:
|
||||
with self.subTest(obj=obj):
|
||||
self.assertEqual(obj, same_object)
|
||||
|
||||
def test_object_compare_with_always_smallest(self):
|
||||
@total_ordering
|
||||
class Smallest:
|
||||
version = None
|
||||
def __lt__(self, other):
|
||||
return True
|
||||
|
||||
smallest = Smallest()
|
||||
for obj in self.objects_with_scoped:
|
||||
with self.subTest(obj=obj):
|
||||
# ensure that we dispatch to Smallest.__lt__ instead.
|
||||
self.assertIs(obj.__lt__(smallest), NotImplemented)
|
||||
self.assertLess(smallest, obj)
|
||||
|
||||
|
||||
class IpaddrUnitTest(unittest.TestCase):
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
:mod:`ipaddress`: fix comparison operators :class:`~ipaddress.IPv4Address`,
|
||||
:class:`~ipaddress.IPv6Address`, :class:`~ipaddress.IPv4Network`,
|
||||
:class:`~ipaddress.IPv6Network`, :class:`~ipaddress.IPv4Interface`, and
|
||||
:class:`~ipaddress.IPv6Interface` to avoid comparing instances of incorrect
|
||||
types. Patch by Bénédikt Tran and yihong0618.
|
||||
Loading…
Add table
Add a link
Reference in a new issue