mirror of
https://github.com/python/cpython.git
synced 2025-08-30 05:35:08 +00:00
update to fix leak in SSL code
This commit is contained in:
parent
517b9ddda2
commit
54cc54c1fe
4 changed files with 225 additions and 68 deletions
|
@ -13,6 +13,7 @@ import pprint
|
|||
import urllib, urlparse
|
||||
import shutil
|
||||
import traceback
|
||||
import asyncore
|
||||
|
||||
from BaseHTTPServer import HTTPServer
|
||||
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
||||
|
@ -79,27 +80,6 @@ class BasicTests(unittest.TestCase):
|
|||
|
||||
class NetworkedTests(unittest.TestCase):
|
||||
|
||||
def testFetchServerCert(self):
|
||||
|
||||
pem = ssl.get_server_certificate(("svn.python.org", 443))
|
||||
if not pem:
|
||||
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
|
||||
|
||||
try:
|
||||
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
|
||||
except ssl.SSLError as x:
|
||||
#should fail
|
||||
if test_support.verbose:
|
||||
sys.stdout.write("%s\n" % x)
|
||||
else:
|
||||
raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
|
||||
|
||||
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
|
||||
if not pem:
|
||||
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
|
||||
if test_support.verbose:
|
||||
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
|
||||
|
||||
def testConnect(self):
|
||||
|
||||
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
|
||||
|
@ -155,6 +135,29 @@ class NetworkedTests(unittest.TestCase):
|
|||
if test_support.verbose:
|
||||
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
|
||||
|
||||
def testFetchServerCert(self):
|
||||
|
||||
pem = ssl.get_server_certificate(("svn.python.org", 443))
|
||||
if not pem:
|
||||
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
|
||||
|
||||
return
|
||||
|
||||
try:
|
||||
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
|
||||
except ssl.SSLError as x:
|
||||
#should fail
|
||||
if test_support.verbose:
|
||||
sys.stdout.write("%s\n" % x)
|
||||
else:
|
||||
raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
|
||||
|
||||
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
|
||||
if not pem:
|
||||
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
|
||||
if test_support.verbose:
|
||||
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
|
||||
|
||||
|
||||
try:
|
||||
import threading
|
||||
|
@ -333,7 +336,9 @@ else:
|
|||
def stop (self):
|
||||
self.active = False
|
||||
|
||||
class AsyncoreHTTPSServer(threading.Thread):
|
||||
class OurHTTPSServer(threading.Thread):
|
||||
|
||||
# This one's based on HTTPServer, which is based on SocketServer
|
||||
|
||||
class HTTPSServer(HTTPServer):
|
||||
|
||||
|
@ -463,6 +468,92 @@ else:
|
|||
self.server.server_close()
|
||||
|
||||
|
||||
class AsyncoreEchoServer(threading.Thread):
|
||||
|
||||
# this one's based on asyncore.dispatcher
|
||||
|
||||
class EchoServer (asyncore.dispatcher):
|
||||
|
||||
class ConnectionHandler (asyncore.dispatcher_with_send):
|
||||
|
||||
def __init__(self, conn, certfile):
|
||||
self.socket = ssl.wrap_socket(conn, server_side=True,
|
||||
certfile=certfile,
|
||||
do_handshake_on_connect=False)
|
||||
asyncore.dispatcher_with_send.__init__(self, self.socket)
|
||||
# now we have to do the handshake
|
||||
# we'll just do it the easy way, and block the connection
|
||||
# till it's finished. If we were doing it right, we'd
|
||||
# do this in multiple calls to handle_read...
|
||||
self.do_handshake(block=True)
|
||||
|
||||
def readable(self):
|
||||
if isinstance(self.socket, ssl.SSLSocket):
|
||||
while self.socket.pending() > 0:
|
||||
self.handle_read_event()
|
||||
return True
|
||||
|
||||
def handle_read(self):
|
||||
data = self.recv(1024)
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" server: read %s from client\n" % repr(data))
|
||||
if not data:
|
||||
self.close()
|
||||
else:
|
||||
self.send(str(data, 'ASCII', 'strict').lower().encode('ASCII', 'strict'))
|
||||
|
||||
def handle_close(self):
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" server: closed connection %s\n" % self.socket)
|
||||
|
||||
def handle_error(self):
|
||||
raise
|
||||
|
||||
def __init__(self, port, certfile):
|
||||
self.port = port
|
||||
self.certfile = certfile
|
||||
asyncore.dispatcher.__init__(self)
|
||||
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
self.bind(('', port))
|
||||
self.listen(5)
|
||||
|
||||
def handle_accept(self):
|
||||
sock_obj, addr = self.accept()
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" server: new connection from %s:%s\n" %addr)
|
||||
self.ConnectionHandler(sock_obj, self.certfile)
|
||||
|
||||
def handle_error(self):
|
||||
raise
|
||||
|
||||
def __init__(self, port, certfile):
|
||||
self.flag = None
|
||||
self.active = False
|
||||
self.server = self.EchoServer(port, certfile)
|
||||
threading.Thread.__init__(self)
|
||||
self.setDaemon(True)
|
||||
|
||||
def __str__(self):
|
||||
return "<%s %s>" % (self.__class__.__name__, self.server)
|
||||
|
||||
def start (self, flag=None):
|
||||
self.flag = flag
|
||||
threading.Thread.start(self)
|
||||
|
||||
def run (self):
|
||||
self.active = True
|
||||
if self.flag:
|
||||
self.flag.set()
|
||||
while self.active:
|
||||
try:
|
||||
asyncore.loop(1)
|
||||
except:
|
||||
pass
|
||||
|
||||
def stop (self):
|
||||
self.active = False
|
||||
self.server.close()
|
||||
|
||||
def badCertTest (certfile):
|
||||
server = ThreadedEchoServer(TESTPORT, CERTFILE,
|
||||
certreqs=ssl.CERT_REQUIRED,
|
||||
|
@ -509,6 +600,7 @@ else:
|
|||
client_protocol = protocol
|
||||
try:
|
||||
s = ssl.wrap_socket(socket.socket(),
|
||||
server_side=False,
|
||||
certfile=client_certfile,
|
||||
ca_certs=cacertsfile,
|
||||
cert_reqs=certreqs,
|
||||
|
@ -811,11 +903,9 @@ else:
|
|||
server.stop()
|
||||
server.join()
|
||||
|
||||
class AsyncoreTests(unittest.TestCase):
|
||||
def testSocketServer(self):
|
||||
|
||||
def testAsyncore(self):
|
||||
|
||||
server = AsyncoreHTTPSServer(TESTPORT, CERTFILE)
|
||||
server = OurHTTPSServer(TESTPORT, CERTFILE)
|
||||
flag = threading.Event()
|
||||
server.start(flag)
|
||||
# wait for it to start
|
||||
|
@ -853,6 +943,47 @@ else:
|
|||
server.stop()
|
||||
server.join()
|
||||
|
||||
def testAsyncoreServer(self):
|
||||
|
||||
if test_support.verbose:
|
||||
sys.stdout.write("\n")
|
||||
|
||||
indata="FOO\n"
|
||||
server = AsyncoreEchoServer(TESTPORT, CERTFILE)
|
||||
flag = threading.Event()
|
||||
server.start(flag)
|
||||
# wait for it to start
|
||||
flag.wait()
|
||||
# try to connect
|
||||
try:
|
||||
s = ssl.wrap_socket(socket.socket())
|
||||
s.connect(('127.0.0.1', TESTPORT))
|
||||
except ssl.SSLError as x:
|
||||
raise test_support.TestFailed("Unexpected SSL error: " + str(x))
|
||||
except Exception as x:
|
||||
raise test_support.TestFailed("Unexpected exception: " + str(x))
|
||||
else:
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(
|
||||
" client: sending %s...\n" % (repr(indata)))
|
||||
s.sendall(indata.encode('ASCII', 'strict'))
|
||||
outdata = s.recv()
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" client: read %s\n" % repr(outdata))
|
||||
outdata = str(outdata, 'ASCII', 'strict')
|
||||
if outdata != indata.lower():
|
||||
raise test_support.TestFailed(
|
||||
"bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
|
||||
% (repr(outdata[:min(len(outdata),20)]), len(outdata),
|
||||
repr(indata[:min(len(indata),20)].lower()), len(indata)))
|
||||
s.write("over\n".encode("ASCII", "strict"))
|
||||
if test_support.verbose:
|
||||
sys.stdout.write(" client: closing connection.\n")
|
||||
s.close()
|
||||
finally:
|
||||
server.stop()
|
||||
server.join()
|
||||
|
||||
|
||||
def findtestsocket(start, end):
|
||||
def testbind(i):
|
||||
|
@ -900,7 +1031,6 @@ def test_main(verbose=False):
|
|||
thread_info = test_support.threading_setup()
|
||||
if thread_info and test_support.is_resource_enabled('network'):
|
||||
tests.append(ThreadedTests)
|
||||
tests.append(AsyncoreTests)
|
||||
|
||||
test_support.run_unittest(*tests)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue