Issue #21965: Add support for in-memory SSL to the ssl module.

Patch by Geert Jansen.
This commit is contained in:
Antoine Pitrou 2014-10-05 20:41:53 +02:00
parent 414e15a88d
commit b1fdf47ff5
5 changed files with 926 additions and 102 deletions

View file

@ -518,9 +518,14 @@ class BasicSocketTests(unittest.TestCase):
def test_unknown_channel_binding(self):
# should raise ValueError for unknown type
s = socket.socket(socket.AF_INET)
with ssl.wrap_socket(s) as ss:
s.bind(('127.0.0.1', 0))
s.listen()
c = socket.socket(socket.AF_INET)
c.connect(s.getsockname())
with ssl.wrap_socket(c, do_handshake_on_connect=False) as ss:
with self.assertRaises(ValueError):
ss.get_channel_binding("unknown-type")
s.close()
@unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES,
"'tls-unique' channel binding not available")
@ -1247,6 +1252,69 @@ class SSLErrorTests(unittest.TestCase):
self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)
class MemoryBIOTests(unittest.TestCase):
def test_read_write(self):
bio = ssl.MemoryBIO()
bio.write(b'foo')
self.assertEqual(bio.read(), b'foo')
self.assertEqual(bio.read(), b'')
bio.write(b'foo')
bio.write(b'bar')
self.assertEqual(bio.read(), b'foobar')
self.assertEqual(bio.read(), b'')
bio.write(b'baz')
self.assertEqual(bio.read(2), b'ba')
self.assertEqual(bio.read(1), b'z')
self.assertEqual(bio.read(1), b'')
def test_eof(self):
bio = ssl.MemoryBIO()
self.assertFalse(bio.eof)
self.assertEqual(bio.read(), b'')
self.assertFalse(bio.eof)
bio.write(b'foo')
self.assertFalse(bio.eof)
bio.write_eof()
self.assertFalse(bio.eof)
self.assertEqual(bio.read(2), b'fo')
self.assertFalse(bio.eof)
self.assertEqual(bio.read(1), b'o')
self.assertTrue(bio.eof)
self.assertEqual(bio.read(), b'')
self.assertTrue(bio.eof)
def test_pending(self):
bio = ssl.MemoryBIO()
self.assertEqual(bio.pending, 0)
bio.write(b'foo')
self.assertEqual(bio.pending, 3)
for i in range(3):
bio.read(1)
self.assertEqual(bio.pending, 3-i-1)
for i in range(3):
bio.write(b'x')
self.assertEqual(bio.pending, i+1)
bio.read()
self.assertEqual(bio.pending, 0)
def test_buffer_types(self):
bio = ssl.MemoryBIO()
bio.write(b'foo')
self.assertEqual(bio.read(), b'foo')
bio.write(bytearray(b'bar'))
self.assertEqual(bio.read(), b'bar')
bio.write(memoryview(b'baz'))
self.assertEqual(bio.read(), b'baz')
def test_error_types(self):
bio = ssl.MemoryBIO()
self.assertRaises(TypeError, bio.write, 'foo')
self.assertRaises(TypeError, bio.write, None)
self.assertRaises(TypeError, bio.write, True)
self.assertRaises(TypeError, bio.write, 1)
class NetworkedTests(unittest.TestCase):
def test_connect(self):
@ -1577,6 +1645,95 @@ class NetworkedTests(unittest.TestCase):
self.assertIs(ss.context, ctx2)
self.assertIs(ss._sslobj.context, ctx2)
class NetworkedBIOTests(unittest.TestCase):
def ssl_io_loop(self, sock, incoming, outgoing, func, *args, **kwargs):
# A simple IO loop. Call func(*args) depending on the error we get
# (WANT_READ or WANT_WRITE) move data between the socket and the BIOs.
timeout = kwargs.get('timeout', 10)
count = 0
while True:
errno = None
count += 1
try:
ret = func(*args)
except ssl.SSLError as e:
# Note that we get a spurious -1/SSL_ERROR_SYSCALL for
# non-blocking IO. The SSL_shutdown manpage hints at this.
# It *should* be safe to just ignore SYS_ERROR_SYSCALL because
# with a Memory BIO there's no syscalls (for IO at least).
if e.errno not in (ssl.SSL_ERROR_WANT_READ,
ssl.SSL_ERROR_WANT_WRITE,
ssl.SSL_ERROR_SYSCALL):
raise
errno = e.errno
# Get any data from the outgoing BIO irrespective of any error, and
# send it to the socket.
buf = outgoing.read()
sock.sendall(buf)
# If there's no error, we're done. For WANT_READ, we need to get
# data from the socket and put it in the incoming BIO.
if errno is None:
break
elif errno == ssl.SSL_ERROR_WANT_READ:
buf = sock.recv(32768)
if buf:
incoming.write(buf)
else:
incoming.write_eof()
if support.verbose:
sys.stdout.write("Needed %d calls to complete %s().\n"
% (count, func.__name__))
return ret
def test_handshake(self):
with support.transient_internet("svn.python.org"):
sock = socket.socket(socket.AF_INET)
sock.connect(("svn.python.org", 443))
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_REQUIRED
ctx.load_verify_locations(SVN_PYTHON_ORG_ROOT_CERT)
if ssl.HAS_SNI:
ctx.check_hostname = True
sslobj = ctx.wrap_bio(incoming, outgoing, False, 'svn.python.org')
else:
ctx.check_hostname = False
sslobj = ctx.wrap_bio(incoming, outgoing, False)
self.assertIs(sslobj._sslobj.owner, sslobj)
self.assertIsNone(sslobj.cipher())
self.assertRaises(ValueError, sslobj.getpeercert)
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertIsNone(sslobj.get_channel_binding('tls-unique'))
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
self.assertTrue(sslobj.cipher())
self.assertTrue(sslobj.getpeercert())
if 'tls-unique' in ssl.CHANNEL_BINDING_TYPES:
self.assertTrue(sslobj.get_channel_binding('tls-unique'))
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
self.assertRaises(ssl.SSLError, sslobj.write, b'foo')
sock.close()
def test_read_write_data(self):
with support.transient_internet("svn.python.org"):
sock = socket.socket(socket.AF_INET)
sock.connect(("svn.python.org", 443))
incoming = ssl.MemoryBIO()
outgoing = ssl.MemoryBIO()
ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
ctx.verify_mode = ssl.CERT_NONE
sslobj = ctx.wrap_bio(incoming, outgoing, False)
self.ssl_io_loop(sock, incoming, outgoing, sslobj.do_handshake)
req = b'GET / HTTP/1.0\r\n\r\n'
self.ssl_io_loop(sock, incoming, outgoing, sslobj.write, req)
buf = self.ssl_io_loop(sock, incoming, outgoing, sslobj.read, 1024)
self.assertEqual(buf[:5], b'HTTP/')
self.ssl_io_loop(sock, incoming, outgoing, sslobj.unwrap)
sock.close()
try:
import threading
except ImportError:
@ -3061,10 +3218,11 @@ def test_main(verbose=False):
if not os.path.exists(filename):
raise support.TestFailed("Can't read certificate file %r" % filename)
tests = [ContextTests, BasicSocketTests, SSLErrorTests]
tests = [ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests]
if support.is_resource_enabled('network'):
tests.append(NetworkedTests)
tests.append(NetworkedBIOTests)
if _have_threads:
thread_info = support.threading_setup()