mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
Issue #14814: In the spirit of TOOWTDI, ditch the redundant version parameter to the factory functions by using the appropriate direct class references instead
This commit is contained in:
parent
072b1e1485
commit
51c3067551
2 changed files with 48 additions and 94 deletions
118
Lib/ipaddress.py
118
Lib/ipaddress.py
|
@ -36,34 +36,22 @@ class NetmaskValueError(ValueError):
|
||||||
"""A Value Error related to the netmask."""
|
"""A Value Error related to the netmask."""
|
||||||
|
|
||||||
|
|
||||||
def ip_address(address, version=None):
|
def ip_address(address):
|
||||||
"""Take an IP string/int and return an object of the correct type.
|
"""Take an IP string/int and return an object of the correct type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
address: A string or integer, the IP address. Either IPv4 or
|
address: A string or integer, the IP address. Either IPv4 or
|
||||||
IPv6 addresses may be supplied; integers less than 2**32 will
|
IPv6 addresses may be supplied; integers less than 2**32 will
|
||||||
be considered to be IPv4 by default.
|
be considered to be IPv4 by default.
|
||||||
version: An integer, 4 or 6. If set, don't try to automatically
|
|
||||||
determine what the IP address type is. Important for things
|
|
||||||
like ip_address(1), which could be IPv4, '192.0.2.1', or IPv6,
|
|
||||||
'2001:db8::1'.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An IPv4Address or IPv6Address object.
|
An IPv4Address or IPv6Address object.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the *address* passed isn't either a v4 or a v6
|
ValueError: if the *address* passed isn't either a v4 or a v6
|
||||||
address, or if the version is not None, 4, or 6.
|
address
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if version is not None:
|
|
||||||
if version == 4:
|
|
||||||
return IPv4Address(address)
|
|
||||||
elif version == 6:
|
|
||||||
return IPv6Address(address)
|
|
||||||
else:
|
|
||||||
raise ValueError()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return IPv4Address(address)
|
return IPv4Address(address)
|
||||||
except (AddressValueError, NetmaskValueError):
|
except (AddressValueError, NetmaskValueError):
|
||||||
|
@ -78,35 +66,22 @@ def ip_address(address, version=None):
|
||||||
address)
|
address)
|
||||||
|
|
||||||
|
|
||||||
def ip_network(address, version=None, strict=True):
|
def ip_network(address, strict=True):
|
||||||
"""Take an IP string/int and return an object of the correct type.
|
"""Take an IP string/int and return an object of the correct type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
address: A string or integer, the IP network. Either IPv4 or
|
address: A string or integer, the IP network. Either IPv4 or
|
||||||
IPv6 networks may be supplied; integers less than 2**32 will
|
IPv6 networks may be supplied; integers less than 2**32 will
|
||||||
be considered to be IPv4 by default.
|
be considered to be IPv4 by default.
|
||||||
version: An integer, 4 or 6. If set, don't try to automatically
|
|
||||||
determine what the IP address type is. Important for things
|
|
||||||
like ip_network(1), which could be IPv4, '192.0.2.1/32', or IPv6,
|
|
||||||
'2001:db8::1/128'.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An IPv4Network or IPv6Network object.
|
An IPv4Network or IPv6Network object.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the string passed isn't either a v4 or a v6
|
ValueError: if the string passed isn't either a v4 or a v6
|
||||||
address. Or if the network has host bits set. Or if the version
|
address. Or if the network has host bits set.
|
||||||
is not None, 4, or 6.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if version is not None:
|
|
||||||
if version == 4:
|
|
||||||
return IPv4Network(address, strict)
|
|
||||||
elif version == 6:
|
|
||||||
return IPv6Network(address, strict)
|
|
||||||
else:
|
|
||||||
raise ValueError()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return IPv4Network(address, strict)
|
return IPv4Network(address, strict)
|
||||||
except (AddressValueError, NetmaskValueError):
|
except (AddressValueError, NetmaskValueError):
|
||||||
|
@ -121,24 +96,20 @@ def ip_network(address, version=None, strict=True):
|
||||||
address)
|
address)
|
||||||
|
|
||||||
|
|
||||||
def ip_interface(address, version=None):
|
def ip_interface(address):
|
||||||
"""Take an IP string/int and return an object of the correct type.
|
"""Take an IP string/int and return an object of the correct type.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
address: A string or integer, the IP address. Either IPv4 or
|
address: A string or integer, the IP address. Either IPv4 or
|
||||||
IPv6 addresses may be supplied; integers less than 2**32 will
|
IPv6 addresses may be supplied; integers less than 2**32 will
|
||||||
be considered to be IPv4 by default.
|
be considered to be IPv4 by default.
|
||||||
version: An integer, 4 or 6. If set, don't try to automatically
|
|
||||||
determine what the IP address type is. Important for things
|
|
||||||
like ip_interface(1), which could be IPv4, '192.0.2.1/32', or IPv6,
|
|
||||||
'2001:db8::1/128'.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
An IPv4Interface or IPv6Interface object.
|
An IPv4Interface or IPv6Interface object.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if the string passed isn't either a v4 or a v6
|
ValueError: if the string passed isn't either a v4 or a v6
|
||||||
address. Or if the version is not None, 4, or 6.
|
address.
|
||||||
|
|
||||||
Notes:
|
Notes:
|
||||||
The IPv?Interface classes describe an Address on a particular
|
The IPv?Interface classes describe an Address on a particular
|
||||||
|
@ -146,14 +117,6 @@ def ip_interface(address, version=None):
|
||||||
and Network classes.
|
and Network classes.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if version is not None:
|
|
||||||
if version == 4:
|
|
||||||
return IPv4Interface(address)
|
|
||||||
elif version == 6:
|
|
||||||
return IPv6Interface(address)
|
|
||||||
else:
|
|
||||||
raise ValueError()
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return IPv4Interface(address)
|
return IPv4Interface(address)
|
||||||
except (AddressValueError, NetmaskValueError):
|
except (AddressValueError, NetmaskValueError):
|
||||||
|
@ -281,7 +244,7 @@ def summarize_address_range(first, last):
|
||||||
If the first and last objects are not the same version.
|
If the first and last objects are not the same version.
|
||||||
ValueError:
|
ValueError:
|
||||||
If the last object is not greater than the first.
|
If the last object is not greater than the first.
|
||||||
If the version is not 4 or 6.
|
If the version of the first address is not 4 or 6.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if (not (isinstance(first, _BaseAddress) and
|
if (not (isinstance(first, _BaseAddress) and
|
||||||
|
@ -318,7 +281,7 @@ def summarize_address_range(first, last):
|
||||||
if current == ip._ALL_ONES:
|
if current == ip._ALL_ONES:
|
||||||
break
|
break
|
||||||
first_int = current + 1
|
first_int = current + 1
|
||||||
first = ip_address(first_int, version=first._version)
|
first = first.__class__(first_int)
|
||||||
|
|
||||||
|
|
||||||
def _collapse_addresses_recursive(addresses):
|
def _collapse_addresses_recursive(addresses):
|
||||||
|
@ -586,12 +549,12 @@ class _BaseAddress(_IPAddressBase):
|
||||||
def __add__(self, other):
|
def __add__(self, other):
|
||||||
if not isinstance(other, int):
|
if not isinstance(other, int):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return ip_address(int(self) + other, version=self._version)
|
return self.__class__(int(self) + other)
|
||||||
|
|
||||||
def __sub__(self, other):
|
def __sub__(self, other):
|
||||||
if not isinstance(other, int):
|
if not isinstance(other, int):
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
return ip_address(int(self) - other, version=self._version)
|
return self.__class__(int(self) - other)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return '%s(%r)' % (self.__class__.__name__, str(self))
|
return '%s(%r)' % (self.__class__.__name__, str(self))
|
||||||
|
@ -612,13 +575,12 @@ class _BaseAddress(_IPAddressBase):
|
||||||
|
|
||||||
class _BaseNetwork(_IPAddressBase):
|
class _BaseNetwork(_IPAddressBase):
|
||||||
|
|
||||||
"""A generic IP object.
|
"""A generic IP network object.
|
||||||
|
|
||||||
This IP class contains the version independent methods which are
|
This IP class contains the version independent methods which are
|
||||||
used by networks.
|
used by networks.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
self._cache = {}
|
self._cache = {}
|
||||||
|
|
||||||
|
@ -642,14 +604,14 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
bcast = int(self.broadcast_address) - 1
|
bcast = int(self.broadcast_address) - 1
|
||||||
while cur <= bcast:
|
while cur <= bcast:
|
||||||
cur += 1
|
cur += 1
|
||||||
yield ip_address(cur - 1, version=self._version)
|
yield self._address_class(cur - 1)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
cur = int(self.network_address)
|
cur = int(self.network_address)
|
||||||
bcast = int(self.broadcast_address)
|
bcast = int(self.broadcast_address)
|
||||||
while cur <= bcast:
|
while cur <= bcast:
|
||||||
cur += 1
|
cur += 1
|
||||||
yield ip_address(cur - 1, version=self._version)
|
yield self._address_class(cur - 1)
|
||||||
|
|
||||||
def __getitem__(self, n):
|
def __getitem__(self, n):
|
||||||
network = int(self.network_address)
|
network = int(self.network_address)
|
||||||
|
@ -657,12 +619,12 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
if n >= 0:
|
if n >= 0:
|
||||||
if network + n > broadcast:
|
if network + n > broadcast:
|
||||||
raise IndexError
|
raise IndexError
|
||||||
return ip_address(network + n, version=self._version)
|
return self._address_class(network + n)
|
||||||
else:
|
else:
|
||||||
n += 1
|
n += 1
|
||||||
if broadcast + n < network:
|
if broadcast + n < network:
|
||||||
raise IndexError
|
raise IndexError
|
||||||
return ip_address(broadcast + n, version=self._version)
|
return self._address_class(broadcast + n)
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
if self._version != other._version:
|
if self._version != other._version:
|
||||||
|
@ -746,8 +708,8 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
def broadcast_address(self):
|
def broadcast_address(self):
|
||||||
x = self._cache.get('broadcast_address')
|
x = self._cache.get('broadcast_address')
|
||||||
if x is None:
|
if x is None:
|
||||||
x = ip_address(int(self.network_address) | int(self.hostmask),
|
x = self._address_class(int(self.network_address) |
|
||||||
version=self._version)
|
int(self.hostmask))
|
||||||
self._cache['broadcast_address'] = x
|
self._cache['broadcast_address'] = x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
@ -755,14 +717,14 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
def hostmask(self):
|
def hostmask(self):
|
||||||
x = self._cache.get('hostmask')
|
x = self._cache.get('hostmask')
|
||||||
if x is None:
|
if x is None:
|
||||||
x = ip_address(int(self.netmask) ^ self._ALL_ONES,
|
x = self._address_class(int(self.netmask) ^ self._ALL_ONES)
|
||||||
version=self._version)
|
|
||||||
self._cache['hostmask'] = x
|
self._cache['hostmask'] = x
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def network(self):
|
def network(self):
|
||||||
return ip_network('%s/%d' % (str(self.network_address),
|
# XXX (ncoghlan): This is redundant now and will likely be removed
|
||||||
|
return self.__class__('%s/%d' % (str(self.network_address),
|
||||||
self.prefixlen))
|
self.prefixlen))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -786,6 +748,10 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
def version(self):
|
def version(self):
|
||||||
raise NotImplementedError('BaseNet has no version')
|
raise NotImplementedError('BaseNet has no version')
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _address_class(self):
|
||||||
|
raise NotImplementedError('BaseNet has no associated address class')
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prefixlen(self):
|
def prefixlen(self):
|
||||||
return self._prefixlen
|
return self._prefixlen
|
||||||
|
@ -840,9 +806,8 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
raise StopIteration
|
raise StopIteration
|
||||||
|
|
||||||
# Make sure we're comparing the network of other.
|
# Make sure we're comparing the network of other.
|
||||||
other = ip_network('%s/%s' % (str(other.network_address),
|
other = other.__class__('%s/%s' % (str(other.network_address),
|
||||||
str(other.prefixlen)),
|
str(other.prefixlen)))
|
||||||
version=other._version)
|
|
||||||
|
|
||||||
s1, s2 = self.subnets()
|
s1, s2 = self.subnets()
|
||||||
while s1 != other and s2 != other:
|
while s1 != other and s2 != other:
|
||||||
|
@ -973,9 +938,9 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
'prefix length diff %d is invalid for netblock %s' % (
|
'prefix length diff %d is invalid for netblock %s' % (
|
||||||
new_prefixlen, str(self)))
|
new_prefixlen, str(self)))
|
||||||
|
|
||||||
first = ip_network('%s/%s' % (str(self.network_address),
|
first = self.__class__('%s/%s' %
|
||||||
str(self._prefixlen + prefixlen_diff)),
|
(str(self.network_address),
|
||||||
version=self._version)
|
str(self._prefixlen + prefixlen_diff)))
|
||||||
|
|
||||||
yield first
|
yield first
|
||||||
current = first
|
current = first
|
||||||
|
@ -983,16 +948,17 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
broadcast = current.broadcast_address
|
broadcast = current.broadcast_address
|
||||||
if broadcast == self.broadcast_address:
|
if broadcast == self.broadcast_address:
|
||||||
return
|
return
|
||||||
new_addr = ip_address(int(broadcast) + 1, version=self._version)
|
new_addr = self._address_class(int(broadcast) + 1)
|
||||||
current = ip_network('%s/%s' % (str(new_addr), str(new_prefixlen)),
|
current = self.__class__('%s/%s' % (str(new_addr),
|
||||||
version=self._version)
|
str(new_prefixlen)))
|
||||||
|
|
||||||
yield current
|
yield current
|
||||||
|
|
||||||
def masked(self):
|
def masked(self):
|
||||||
"""Return the network object with the host bits masked out."""
|
"""Return the network object with the host bits masked out."""
|
||||||
return ip_network('%s/%d' % (self.network_address, self._prefixlen),
|
# XXX (ncoghlan): This is redundant now and will likely be removed
|
||||||
version=self._version)
|
return self.__class__('%s/%d' % (self.network_address,
|
||||||
|
self._prefixlen))
|
||||||
|
|
||||||
def supernet(self, prefixlen_diff=1, new_prefix=None):
|
def supernet(self, prefixlen_diff=1, new_prefix=None):
|
||||||
"""The supernet containing the current network.
|
"""The supernet containing the current network.
|
||||||
|
@ -1030,11 +996,10 @@ class _BaseNetwork(_IPAddressBase):
|
||||||
'current prefixlen is %d, cannot have a prefixlen_diff of %d' %
|
'current prefixlen is %d, cannot have a prefixlen_diff of %d' %
|
||||||
(self.prefixlen, prefixlen_diff))
|
(self.prefixlen, prefixlen_diff))
|
||||||
# TODO (pmoody): optimize this.
|
# TODO (pmoody): optimize this.
|
||||||
t = ip_network('%s/%d' % (str(self.network_address),
|
t = self.__class__('%s/%d' % (str(self.network_address),
|
||||||
self.prefixlen - prefixlen_diff),
|
self.prefixlen - prefixlen_diff),
|
||||||
version=self._version, strict=False)
|
strict=False)
|
||||||
return ip_network('%s/%d' % (str(t.network_address), t.prefixlen),
|
return t.__class__('%s/%d' % (str(t.network_address), t.prefixlen))
|
||||||
version=t._version)
|
|
||||||
|
|
||||||
|
|
||||||
class _BaseV4(object):
|
class _BaseV4(object):
|
||||||
|
@ -1391,6 +1356,9 @@ class IPv4Network(_BaseV4, _BaseNetwork):
|
||||||
.prefixlen: 27
|
.prefixlen: 27
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
# Class to use when creating address objects
|
||||||
|
# TODO (ncoghlan): Investigate using IPv4Interface instead
|
||||||
|
_address_class = IPv4Address
|
||||||
|
|
||||||
# the valid octets for host and netmasks. only useful for IPv4.
|
# the valid octets for host and netmasks. only useful for IPv4.
|
||||||
_valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0))
|
_valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0))
|
||||||
|
@ -2071,6 +2039,10 @@ class IPv6Network(_BaseV6, _BaseNetwork):
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# Class to use when creating address objects
|
||||||
|
# TODO (ncoghlan): Investigate using IPv6Interface instead
|
||||||
|
_address_class = IPv6Address
|
||||||
|
|
||||||
def __init__(self, address, strict=True):
|
def __init__(self, address, strict=True):
|
||||||
"""Instantiate a new IPv6 Network object.
|
"""Instantiate a new IPv6 Network object.
|
||||||
|
|
||||||
|
|
|
@ -780,12 +780,6 @@ class IpaddrUnitTest(unittest.TestCase):
|
||||||
self.assertEqual(self.ipv4_address.version, 4)
|
self.assertEqual(self.ipv4_address.version, 4)
|
||||||
self.assertEqual(self.ipv6_address.version, 6)
|
self.assertEqual(self.ipv6_address.version, 6)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
ipaddress.ip_address('1', version=[])
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
ipaddress.ip_address('1', version=5)
|
|
||||||
|
|
||||||
def testMaxPrefixLength(self):
|
def testMaxPrefixLength(self):
|
||||||
self.assertEqual(self.ipv4_interface.max_prefixlen, 32)
|
self.assertEqual(self.ipv4_interface.max_prefixlen, 32)
|
||||||
self.assertEqual(self.ipv6_interface.max_prefixlen, 128)
|
self.assertEqual(self.ipv6_interface.max_prefixlen, 128)
|
||||||
|
@ -1052,12 +1046,7 @@ class IpaddrUnitTest(unittest.TestCase):
|
||||||
|
|
||||||
def testForceVersion(self):
|
def testForceVersion(self):
|
||||||
self.assertEqual(ipaddress.ip_network(1).version, 4)
|
self.assertEqual(ipaddress.ip_network(1).version, 4)
|
||||||
self.assertEqual(ipaddress.ip_network(1, version=6).version, 6)
|
self.assertEqual(ipaddress.IPv6Network(1).version, 6)
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
ipaddress.ip_network(1, version='l')
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
ipaddress.ip_network(1, version=3)
|
|
||||||
|
|
||||||
def testWithStar(self):
|
def testWithStar(self):
|
||||||
self.assertEqual(str(self.ipv4_interface.with_prefixlen), "1.2.3.4/24")
|
self.assertEqual(str(self.ipv4_interface.with_prefixlen), "1.2.3.4/24")
|
||||||
|
@ -1148,13 +1137,6 @@ class IpaddrUnitTest(unittest.TestCase):
|
||||||
sixtofouraddr.sixtofour)
|
sixtofouraddr.sixtofour)
|
||||||
self.assertFalse(bad_addr.sixtofour)
|
self.assertFalse(bad_addr.sixtofour)
|
||||||
|
|
||||||
def testIpInterfaceVersion(self):
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
ipaddress.ip_interface(1, version=123)
|
|
||||||
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
ipaddress.ip_interface(1, version='')
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue