Issue 14814: Make the ipaddress code easier to follow by using newer language features (patch by Serhiy Storchaka)

This commit is contained in:
Nick Coghlan 2012-07-07 21:43:30 +10:00
parent 79d79a0f29
commit 7319f69f49

View file

@ -214,8 +214,10 @@ def _count_righthand_zero_bits(number, bits):
if number == 0: if number == 0:
return bits return bits
for i in range(bits): for i in range(bits):
if (number >> i) % 2: if (number >> i) & 1:
return i return i
# All bits of interest were zero, even if there are more in the number
return bits
def summarize_address_range(first, last): def summarize_address_range(first, last):
@ -263,20 +265,13 @@ def summarize_address_range(first, last):
first_int = first._ip first_int = first._ip
last_int = last._ip last_int = last._ip
while first_int <= last_int: while first_int <= last_int:
nbits = _count_righthand_zero_bits(first_int, ip_bits) nbits = min(_count_righthand_zero_bits(first_int, ip_bits),
current = None (last_int - first_int + 1).bit_length() - 1)
while nbits >= 0: net = ip('%s/%d' % (first, ip_bits - nbits))
addend = 2**nbits - 1
current = first_int + addend
nbits -= 1
if current <= last_int:
break
prefix = _get_prefix_length(first_int, current, ip_bits)
net = ip('%s/%d' % (first, prefix))
yield net yield net
if current == ip._ALL_ONES: first_int += 1 << nbits
if first_int - 1 == ip._ALL_ONES:
break break
first_int = current + 1
first = first.__class__(first_int) first = first.__class__(first_int)
@ -304,26 +299,28 @@ def _collapse_addresses_recursive(addresses):
passed. passed.
""" """
ret_array = [] while True:
optimized = False last_addr = None
ret_array = []
optimized = False
for cur_addr in addresses: for cur_addr in addresses:
if not ret_array: if not ret_array:
ret_array.append(cur_addr) last_addr = cur_addr
continue ret_array.append(cur_addr)
if (cur_addr.network_address >= ret_array[-1].network_address and elif (cur_addr.network_address >= last_addr.network_address and
cur_addr.broadcast_address <= ret_array[-1].broadcast_address): cur_addr.broadcast_address <= last_addr.broadcast_address):
optimized = True optimized = True
elif cur_addr == list(ret_array[-1].supernet().subnets())[1]: elif cur_addr == list(last_addr.supernet().subnets())[1]:
ret_array.append(ret_array.pop().supernet()) ret_array[-1] = last_addr = last_addr.supernet()
optimized = True optimized = True
else: else:
ret_array.append(cur_addr) last_addr = cur_addr
ret_array.append(cur_addr)
if optimized: addresses = ret_array
return _collapse_addresses_recursive(ret_array) if not optimized:
return addresses
return ret_array
def collapse_addresses(addresses): def collapse_addresses(addresses):
@ -452,13 +449,7 @@ class _IPAddressBase:
An integer, the prefix length. An integer, the prefix length.
""" """
while mask: return mask - _count_righthand_zero_bits(ip_int, mask)
if ip_int & 1 == 1:
break
ip_int >>= 1
mask -= 1
return mask
def _ip_string_from_prefix(self, prefixlen=None): def _ip_string_from_prefix(self, prefixlen=None):
"""Turn a prefix length into a dotted decimal string. """Turn a prefix length into a dotted decimal string.
@ -597,18 +588,16 @@ class _BaseNetwork(_IPAddressBase):
or broadcast addresses. or broadcast addresses.
""" """
cur = int(self.network_address) + 1 network = int(self.network_address)
bcast = int(self.broadcast_address) - 1 broadcast = int(self.broadcast_address)
while cur <= bcast: for x in range(network + 1, broadcast):
cur += 1 yield self._address_class(x)
yield self._address_class(cur - 1)
def __iter__(self): def __iter__(self):
cur = int(self.network_address) network = int(self.network_address)
bcast = int(self.broadcast_address) broadcast = int(self.broadcast_address)
while cur <= bcast: for x in range(network, broadcast + 1):
cur += 1 yield self._address_class(x)
yield self._address_class(cur - 1)
def __getitem__(self, n): def __getitem__(self, n):
network = int(self.network_address) network = int(self.network_address)
@ -998,7 +987,7 @@ class _BaseV4:
_DECIMAL_DIGITS = frozenset('0123456789') _DECIMAL_DIGITS = frozenset('0123456789')
# the valid octets for host and netmasks. only useful for IPv4. # the valid octets for host and netmasks. only useful for IPv4.
_valid_mask_octets = set((255, 254, 252, 248, 240, 224, 192, 128, 0)) _valid_mask_octets = frozenset((255, 254, 252, 248, 240, 224, 192, 128, 0))
def __init__(self, address): def __init__(self, address):
self._version = 4 self._version = 4
@ -1027,13 +1016,10 @@ class _BaseV4:
if len(octets) != 4: if len(octets) != 4:
raise AddressValueError("Expected 4 octets in %r" % ip_str) raise AddressValueError("Expected 4 octets in %r" % ip_str)
packed_ip = 0 try:
for oc in octets: return int.from_bytes(map(self._parse_octet, octets), 'big')
try: except ValueError as exc:
packed_ip = (packed_ip << 8) | self._parse_octet(oc) raise AddressValueError("%s in %r" % (exc, ip_str)) from None
except ValueError as exc:
raise AddressValueError("%s in %r" % (exc, ip_str)) from None
return packed_ip
def _parse_octet(self, octet_str): def _parse_octet(self, octet_str):
"""Convert a decimal octet into an integer. """Convert a decimal octet into an integer.
@ -1075,11 +1061,7 @@ class _BaseV4:
The IP address as a string in dotted decimal notation. The IP address as a string in dotted decimal notation.
""" """
octets = [] return '.'.join(map(str, ip_int.to_bytes(4, 'big')))
for _ in range(4):
octets.insert(0, str(ip_int & 0xFF))
ip_int >>= 8
return '.'.join(octets)
def _is_valid_netmask(self, netmask): def _is_valid_netmask(self, netmask):
"""Verify that the netmask is valid. """Verify that the netmask is valid.
@ -1095,17 +1077,16 @@ class _BaseV4:
""" """
mask = netmask.split('.') mask = netmask.split('.')
if len(mask) == 4: if len(mask) == 4:
for x in mask: try:
try: for x in mask:
if int(x) in self._valid_mask_octets: if int(x) not in self._valid_mask_octets:
continue return False
except ValueError: except ValueError:
pass
# Found something that isn't an integer or isn't valid # Found something that isn't an integer or isn't valid
return False return False
if [y for idx, y in enumerate(mask) if idx > 0 and for idx, y in enumerate(mask):
y > mask[idx - 1]]: if idx > 0 and y > mask[idx - 1]:
return False return False
return True return True
try: try:
netmask = int(netmask) netmask = int(netmask)
@ -1125,7 +1106,7 @@ class _BaseV4:
""" """
bits = ip_str.split('.') bits = ip_str.split('.')
try: try:
parts = [int(x) for x in bits if int(x) in self._valid_mask_octets] parts = [x for x in map(int, bits) if x in self._valid_mask_octets]
except ValueError: except ValueError:
return False return False
if len(parts) != len(bits): if len(parts) != len(bits):
@ -1526,14 +1507,14 @@ class _BaseV6:
# Disregarding the endpoints, find '::' with nothing in between. # Disregarding the endpoints, find '::' with nothing in between.
# This indicates that a run of zeroes has been skipped. # This indicates that a run of zeroes has been skipped.
try: skip_index = None
skip_index, = ( for i in range(1, len(parts) - 1):
[i for i in range(1, len(parts) - 1) if not parts[i]] or if not parts[i]:
[None]) if skip_index is not None:
except ValueError: # Can't have more than one '::'
# Can't have more than one '::' msg = "At most one '::' permitted in %r" % ip_str
msg = "At most one '::' permitted in %r" % ip_str raise AddressValueError(msg)
raise AddressValueError(msg) from None skip_index = i
# parts_hi is the number of parts to copy from above/before the '::' # parts_hi is the number of parts to copy from above/before the '::'
# parts_lo is the number of parts to copy from below/after the '::' # parts_lo is the number of parts to copy from below/after the '::'
@ -1680,9 +1661,7 @@ class _BaseV6:
raise ValueError('IPv6 address is too large') raise ValueError('IPv6 address is too large')
hex_str = '%032x' % ip_int hex_str = '%032x' % ip_int
hextets = [] hextets = ['%x' % int(hex_str[x:x+4], 16) for x in range(0, 32, 4)]
for x in range(0, 32, 4):
hextets.append('%x' % int(hex_str[x:x+4], 16))
hextets = self._compress_hextets(hextets) hextets = self._compress_hextets(hextets)
return ':'.join(hextets) return ':'.join(hextets)
@ -1705,11 +1684,8 @@ class _BaseV6:
ip_str = str(self) ip_str = str(self)
ip_int = self._ip_int_from_string(ip_str) ip_int = self._ip_int_from_string(ip_str)
parts = [] hex_str = '%032x' % ip_int
for i in range(self._HEXTET_COUNT): parts = [hex_str[x:x+4] for x in range(0, 32, 4)]
parts.append('%04x' % (ip_int & 0xFFFF))
ip_int >>= 16
parts.reverse()
if isinstance(self, (_BaseNetwork, IPv6Interface)): if isinstance(self, (_BaseNetwork, IPv6Interface)):
return '%s/%d' % (':'.join(parts), self.prefixlen) return '%s/%d' % (':'.join(parts), self.prefixlen)
return ':'.join(parts) return ':'.join(parts)
@ -1756,9 +1732,9 @@ class _BaseV6:
IPv6Network('FE00::/9')] IPv6Network('FE00::/9')]
if isinstance(self, _BaseAddress): if isinstance(self, _BaseAddress):
return len([x for x in reserved_networks if self in x]) > 0 return any(self in x for x in reserved_networks)
return len([x for x in reserved_networks if self.network_address in x return any(self.network_address in x and self.broadcast_address in x
and self.broadcast_address in x]) > 0 for x in reserved_networks)
@property @property
def is_link_local(self): def is_link_local(self):