Issue 14814: Eliminate bytes warnings from ipaddress by correctly throwing an exception early when given bytes data of the wrong length. Also removes 2.x backwards compatibility code from associated tests.

This commit is contained in:
Nick Coghlan 2012-07-07 01:43:31 +10:00
parent 3c2570caf2
commit 5cf896fea8
2 changed files with 46 additions and 31 deletions

View file

@ -1250,7 +1250,9 @@ class IPv4Address(_BaseV4, _BaseAddress):
return return
# Constructing from a packed address # Constructing from a packed address
if isinstance(address, bytes) and len(address) == 4: if isinstance(address, bytes):
if len(address) != 4:
raise AddressValueError(address)
self._ip = struct.unpack('!I', address)[0] self._ip = struct.unpack('!I', address)[0]
return return
@ -1379,7 +1381,9 @@ class IPv4Network(_BaseV4, _BaseNetwork):
_BaseNetwork.__init__(self, address) _BaseNetwork.__init__(self, address)
# Constructing from a packed address # Constructing from a packed address
if isinstance(address, bytes) and len(address) == 4: if isinstance(address, bytes):
if len(address) != 4:
raise AddressValueError(address)
self.network_address = IPv4Address( self.network_address = IPv4Address(
struct.unpack('!I', address)[0]) struct.unpack('!I', address)[0])
self._prefixlen = self._max_prefixlen self._prefixlen = self._max_prefixlen
@ -1864,7 +1868,9 @@ class IPv6Address(_BaseV6, _BaseAddress):
return return
# Constructing from a packed address # Constructing from a packed address
if isinstance(address, bytes) and len(address) == 16: if isinstance(address, bytes):
if len(address) != 16:
raise AddressValueError(address)
tmp = struct.unpack('!QQ', address) tmp = struct.unpack('!QQ', address)
self._ip = (tmp[0] << 64) | tmp[1] self._ip = (tmp[0] << 64) | tmp[1]
return return
@ -1996,7 +2002,9 @@ class IPv6Network(_BaseV6, _BaseNetwork):
return return
# Constructing from a packed address # Constructing from a packed address
if isinstance(address, bytes) and len(address) == 16: if isinstance(address, bytes):
if len(address) != 16:
raise AddressValueError(address)
tmp = struct.unpack('!QQ', address) tmp = struct.unpack('!QQ', address)
self.network_address = IPv6Address((tmp[0] << 64) | tmp[1]) self.network_address = IPv6Address((tmp[0] << 64) | tmp[1])
self._prefixlen = self._max_prefixlen self._prefixlen = self._max_prefixlen

View file

@ -8,10 +8,6 @@ import unittest
import ipaddress import ipaddress
# Compatibility function to cast str to bytes objects
_cb = lambda bytestr: bytes(bytestr, 'charmap')
class IpaddrUnitTest(unittest.TestCase): class IpaddrUnitTest(unittest.TestCase):
def setUp(self): def setUp(self):
@ -267,25 +263,36 @@ class IpaddrUnitTest(unittest.TestCase):
6) 6)
def testIpFromPacked(self): def testIpFromPacked(self):
ip = ipaddress.ip_network address = ipaddress.ip_address
self.assertEqual(self.ipv4_interface._ip, self.assertEqual(self.ipv4_interface._ip,
ipaddress.ip_interface(_cb('\x01\x02\x03\x04'))._ip) ipaddress.ip_interface(b'\x01\x02\x03\x04')._ip)
self.assertEqual(ip('255.254.253.252'), self.assertEqual(address('255.254.253.252'),
ip(_cb('\xff\xfe\xfd\xfc'))) address(b'\xff\xfe\xfd\xfc'))
self.assertRaises(ValueError, ipaddress.ip_network, _cb('\x00' * 3))
self.assertRaises(ValueError, ipaddress.ip_network, _cb('\x00' * 5))
self.assertEqual(self.ipv6_interface.ip, self.assertEqual(self.ipv6_interface.ip,
ipaddress.ip_interface( ipaddress.ip_interface(
_cb('\x20\x01\x06\x58\x02\x2a\xca\xfe' b'\x20\x01\x06\x58\x02\x2a\xca\xfe'
'\x02\x00\x00\x00\x00\x00\x00\x01')).ip) b'\x02\x00\x00\x00\x00\x00\x00\x01').ip)
self.assertEqual(ip('ffff:2:3:4:ffff::'), self.assertEqual(address('ffff:2:3:4:ffff::'),
ip(_cb('\xff\xff\x00\x02\x00\x03\x00\x04' + address(b'\xff\xff\x00\x02\x00\x03\x00\x04' +
'\xff\xff' + '\x00' * 6))) b'\xff\xff' + b'\x00' * 6))
self.assertEqual(ip('::'), self.assertEqual(address('::'),
ip(_cb('\x00' * 16))) address(b'\x00' * 16))
self.assertRaises(ValueError, ip, _cb('\x00' * 15))
self.assertRaises(ValueError, ip, _cb('\x00' * 17)) def testIpFromPackedErrors(self):
def assertInvalidPackedAddress(f, length):
self.assertRaises(ValueError, f, b'\x00' * length)
assertInvalidPackedAddress(ipaddress.ip_address, 3)
assertInvalidPackedAddress(ipaddress.ip_address, 5)
assertInvalidPackedAddress(ipaddress.ip_address, 15)
assertInvalidPackedAddress(ipaddress.ip_address, 17)
assertInvalidPackedAddress(ipaddress.ip_interface, 3)
assertInvalidPackedAddress(ipaddress.ip_interface, 5)
assertInvalidPackedAddress(ipaddress.ip_interface, 15)
assertInvalidPackedAddress(ipaddress.ip_interface, 17)
assertInvalidPackedAddress(ipaddress.ip_network, 3)
assertInvalidPackedAddress(ipaddress.ip_network, 5)
assertInvalidPackedAddress(ipaddress.ip_network, 15)
assertInvalidPackedAddress(ipaddress.ip_network, 17)
def testGetIp(self): def testGetIp(self):
self.assertEqual(int(self.ipv4_interface.ip), 16909060) self.assertEqual(int(self.ipv4_interface.ip), 16909060)
@ -893,17 +900,17 @@ class IpaddrUnitTest(unittest.TestCase):
def testPacked(self): def testPacked(self):
self.assertEqual(self.ipv4_address.packed, self.assertEqual(self.ipv4_address.packed,
_cb('\x01\x02\x03\x04')) b'\x01\x02\x03\x04')
self.assertEqual(ipaddress.IPv4Interface('255.254.253.252').packed, self.assertEqual(ipaddress.IPv4Interface('255.254.253.252').packed,
_cb('\xff\xfe\xfd\xfc')) b'\xff\xfe\xfd\xfc')
self.assertEqual(self.ipv6_address.packed, self.assertEqual(self.ipv6_address.packed,
_cb('\x20\x01\x06\x58\x02\x2a\xca\xfe' b'\x20\x01\x06\x58\x02\x2a\xca\xfe'
'\x02\x00\x00\x00\x00\x00\x00\x01')) b'\x02\x00\x00\x00\x00\x00\x00\x01')
self.assertEqual(ipaddress.IPv6Interface('ffff:2:3:4:ffff::').packed, self.assertEqual(ipaddress.IPv6Interface('ffff:2:3:4:ffff::').packed,
_cb('\xff\xff\x00\x02\x00\x03\x00\x04\xff\xff' b'\xff\xff\x00\x02\x00\x03\x00\x04\xff\xff'
+ '\x00' * 6)) + b'\x00' * 6)
self.assertEqual(ipaddress.IPv6Interface('::1:0:0:0:0').packed, self.assertEqual(ipaddress.IPv6Interface('::1:0:0:0:0').packed,
_cb('\x00' * 6 + '\x00\x01' + '\x00' * 8)) b'\x00' * 6 + b'\x00\x01' + b'\x00' * 8)
def testIpStrFromPrefixlen(self): def testIpStrFromPrefixlen(self):
ipv4 = ipaddress.IPv4Interface('1.2.3.4/24') ipv4 = ipaddress.IPv4Interface('1.2.3.4/24')