mirror of
https://github.com/python/cpython.git
synced 2025-10-17 12:18:23 +00:00
Use context managers in test_ssl to simplify test writing.
This commit is contained in:
commit
6b15c90fd8
1 changed files with 43 additions and 83 deletions
|
@ -986,6 +986,14 @@ else:
|
||||||
threading.Thread.__init__(self)
|
threading.Thread.__init__(self)
|
||||||
self.daemon = True
|
self.daemon = True
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start(threading.Event())
|
||||||
|
self.flag.wait()
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
self.stop()
|
||||||
|
self.join()
|
||||||
|
|
||||||
def start(self, flag=None):
|
def start(self, flag=None):
|
||||||
self.flag = flag
|
self.flag = flag
|
||||||
threading.Thread.start(self)
|
threading.Thread.start(self)
|
||||||
|
@ -1097,6 +1105,20 @@ else:
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "<%s %s>" % (self.__class__.__name__, self.server)
|
return "<%s %s>" % (self.__class__.__name__, self.server)
|
||||||
|
|
||||||
|
def __enter__(self):
|
||||||
|
self.start(threading.Event())
|
||||||
|
self.flag.wait()
|
||||||
|
|
||||||
|
def __exit__(self, *args):
|
||||||
|
if support.verbose:
|
||||||
|
sys.stdout.write(" cleanup: stopping server.\n")
|
||||||
|
self.stop()
|
||||||
|
if support.verbose:
|
||||||
|
sys.stdout.write(" cleanup: joining server thread.\n")
|
||||||
|
self.join()
|
||||||
|
if support.verbose:
|
||||||
|
sys.stdout.write(" cleanup: successfully joined.\n")
|
||||||
|
|
||||||
def start (self, flag=None):
|
def start (self, flag=None):
|
||||||
self.flag = flag
|
self.flag = flag
|
||||||
threading.Thread.start(self)
|
threading.Thread.start(self)
|
||||||
|
@ -1124,12 +1146,7 @@ else:
|
||||||
certreqs=ssl.CERT_REQUIRED,
|
certreqs=ssl.CERT_REQUIRED,
|
||||||
cacerts=CERTFILE, chatty=False,
|
cacerts=CERTFILE, chatty=False,
|
||||||
connectionchatty=False)
|
connectionchatty=False)
|
||||||
flag = threading.Event()
|
with server:
|
||||||
server.start(flag)
|
|
||||||
# wait for it to start
|
|
||||||
flag.wait()
|
|
||||||
# try to connect
|
|
||||||
try:
|
|
||||||
try:
|
try:
|
||||||
with socket.socket() as sock:
|
with socket.socket() as sock:
|
||||||
s = ssl.wrap_socket(sock,
|
s = ssl.wrap_socket(sock,
|
||||||
|
@ -1149,9 +1166,6 @@ else:
|
||||||
sys.stdout.write("\IOError is %s\n" % str(x))
|
sys.stdout.write("\IOError is %s\n" % str(x))
|
||||||
else:
|
else:
|
||||||
raise AssertionError("Use of invalid cert should have failed!")
|
raise AssertionError("Use of invalid cert should have failed!")
|
||||||
finally:
|
|
||||||
server.stop()
|
|
||||||
server.join()
|
|
||||||
|
|
||||||
def server_params_test(client_context, server_context, indata=b"FOO\n",
|
def server_params_test(client_context, server_context, indata=b"FOO\n",
|
||||||
chatty=True, connectionchatty=False):
|
chatty=True, connectionchatty=False):
|
||||||
|
@ -1162,12 +1176,7 @@ else:
|
||||||
server = ThreadedEchoServer(context=server_context,
|
server = ThreadedEchoServer(context=server_context,
|
||||||
chatty=chatty,
|
chatty=chatty,
|
||||||
connectionchatty=False)
|
connectionchatty=False)
|
||||||
flag = threading.Event()
|
with server:
|
||||||
server.start(flag)
|
|
||||||
# wait for it to start
|
|
||||||
flag.wait()
|
|
||||||
# try to connect
|
|
||||||
try:
|
|
||||||
s = client_context.wrap_socket(socket.socket())
|
s = client_context.wrap_socket(socket.socket())
|
||||||
s.connect((HOST, server.port))
|
s.connect((HOST, server.port))
|
||||||
for arg in [indata, bytearray(indata), memoryview(indata)]:
|
for arg in [indata, bytearray(indata), memoryview(indata)]:
|
||||||
|
@ -1195,9 +1204,6 @@ else:
|
||||||
}
|
}
|
||||||
s.close()
|
s.close()
|
||||||
return stats
|
return stats
|
||||||
finally:
|
|
||||||
server.stop()
|
|
||||||
server.join()
|
|
||||||
|
|
||||||
def try_protocol_combo(server_protocol, client_protocol, expect_success,
|
def try_protocol_combo(server_protocol, client_protocol, expect_success,
|
||||||
certsreqs=None, server_options=0, client_options=0):
|
certsreqs=None, server_options=0, client_options=0):
|
||||||
|
@ -1266,12 +1272,7 @@ else:
|
||||||
context.load_verify_locations(CERTFILE)
|
context.load_verify_locations(CERTFILE)
|
||||||
context.load_cert_chain(CERTFILE)
|
context.load_cert_chain(CERTFILE)
|
||||||
server = ThreadedEchoServer(context=context, chatty=False)
|
server = ThreadedEchoServer(context=context, chatty=False)
|
||||||
flag = threading.Event()
|
with server:
|
||||||
server.start(flag)
|
|
||||||
# wait for it to start
|
|
||||||
flag.wait()
|
|
||||||
# try to connect
|
|
||||||
try:
|
|
||||||
s = context.wrap_socket(socket.socket())
|
s = context.wrap_socket(socket.socket())
|
||||||
s.connect((HOST, server.port))
|
s.connect((HOST, server.port))
|
||||||
cert = s.getpeercert()
|
cert = s.getpeercert()
|
||||||
|
@ -1294,9 +1295,6 @@ else:
|
||||||
after = ssl.cert_time_to_seconds(cert['notAfter'])
|
after = ssl.cert_time_to_seconds(cert['notAfter'])
|
||||||
self.assertLess(before, after)
|
self.assertLess(before, after)
|
||||||
s.close()
|
s.close()
|
||||||
finally:
|
|
||||||
server.stop()
|
|
||||||
server.join()
|
|
||||||
|
|
||||||
def test_empty_cert(self):
|
def test_empty_cert(self):
|
||||||
"""Connecting with an empty cert file"""
|
"""Connecting with an empty cert file"""
|
||||||
|
@ -1456,13 +1454,8 @@ else:
|
||||||
starttls_server=True,
|
starttls_server=True,
|
||||||
chatty=True,
|
chatty=True,
|
||||||
connectionchatty=True)
|
connectionchatty=True)
|
||||||
flag = threading.Event()
|
|
||||||
server.start(flag)
|
|
||||||
# wait for it to start
|
|
||||||
flag.wait()
|
|
||||||
# try to connect
|
|
||||||
wrapped = False
|
wrapped = False
|
||||||
try:
|
with server:
|
||||||
s = socket.socket()
|
s = socket.socket()
|
||||||
s.setblocking(1)
|
s.setblocking(1)
|
||||||
s.connect((HOST, server.port))
|
s.connect((HOST, server.port))
|
||||||
|
@ -1509,9 +1502,6 @@ else:
|
||||||
conn.close()
|
conn.close()
|
||||||
else:
|
else:
|
||||||
s.close()
|
s.close()
|
||||||
finally:
|
|
||||||
server.stop()
|
|
||||||
server.join()
|
|
||||||
|
|
||||||
def test_socketserver(self):
|
def test_socketserver(self):
|
||||||
"""Using a SocketServer to create and manage SSL connections."""
|
"""Using a SocketServer to create and manage SSL connections."""
|
||||||
|
@ -1547,12 +1537,7 @@ else:
|
||||||
|
|
||||||
indata = b"FOO\n"
|
indata = b"FOO\n"
|
||||||
server = AsyncoreEchoServer(CERTFILE)
|
server = AsyncoreEchoServer(CERTFILE)
|
||||||
flag = threading.Event()
|
with server:
|
||||||
server.start(flag)
|
|
||||||
# wait for it to start
|
|
||||||
flag.wait()
|
|
||||||
# try to connect
|
|
||||||
try:
|
|
||||||
s = ssl.wrap_socket(socket.socket())
|
s = ssl.wrap_socket(socket.socket())
|
||||||
s.connect(('127.0.0.1', server.port))
|
s.connect(('127.0.0.1', server.port))
|
||||||
if support.verbose:
|
if support.verbose:
|
||||||
|
@ -1573,15 +1558,6 @@ else:
|
||||||
s.close()
|
s.close()
|
||||||
if support.verbose:
|
if support.verbose:
|
||||||
sys.stdout.write(" client: connection closed.\n")
|
sys.stdout.write(" client: connection closed.\n")
|
||||||
finally:
|
|
||||||
if support.verbose:
|
|
||||||
sys.stdout.write(" cleanup: stopping server.\n")
|
|
||||||
server.stop()
|
|
||||||
if support.verbose:
|
|
||||||
sys.stdout.write(" cleanup: joining server thread.\n")
|
|
||||||
server.join()
|
|
||||||
if support.verbose:
|
|
||||||
sys.stdout.write(" cleanup: successfully joined.\n")
|
|
||||||
|
|
||||||
def test_recv_send(self):
|
def test_recv_send(self):
|
||||||
"""Test recv(), send() and friends."""
|
"""Test recv(), send() and friends."""
|
||||||
|
@ -1594,19 +1570,14 @@ else:
|
||||||
cacerts=CERTFILE,
|
cacerts=CERTFILE,
|
||||||
chatty=True,
|
chatty=True,
|
||||||
connectionchatty=False)
|
connectionchatty=False)
|
||||||
flag = threading.Event()
|
with server:
|
||||||
server.start(flag)
|
s = ssl.wrap_socket(socket.socket(),
|
||||||
# wait for it to start
|
server_side=False,
|
||||||
flag.wait()
|
certfile=CERTFILE,
|
||||||
# try to connect
|
ca_certs=CERTFILE,
|
||||||
s = ssl.wrap_socket(socket.socket(),
|
cert_reqs=ssl.CERT_NONE,
|
||||||
server_side=False,
|
ssl_version=ssl.PROTOCOL_TLSv1)
|
||||||
certfile=CERTFILE,
|
s.connect((HOST, server.port))
|
||||||
ca_certs=CERTFILE,
|
|
||||||
cert_reqs=ssl.CERT_NONE,
|
|
||||||
ssl_version=ssl.PROTOCOL_TLSv1)
|
|
||||||
s.connect((HOST, server.port))
|
|
||||||
try:
|
|
||||||
# helper methods for standardising recv* method signatures
|
# helper methods for standardising recv* method signatures
|
||||||
def _recv_into():
|
def _recv_into():
|
||||||
b = bytearray(b"\0"*100)
|
b = bytearray(b"\0"*100)
|
||||||
|
@ -1702,9 +1673,6 @@ else:
|
||||||
|
|
||||||
s.write(b"over\n")
|
s.write(b"over\n")
|
||||||
s.close()
|
s.close()
|
||||||
finally:
|
|
||||||
server.stop()
|
|
||||||
server.join()
|
|
||||||
|
|
||||||
def test_handshake_timeout(self):
|
def test_handshake_timeout(self):
|
||||||
# Issue #5103: SSL handshake must respect the socket timeout
|
# Issue #5103: SSL handshake must respect the socket timeout
|
||||||
|
@ -1768,19 +1736,14 @@ else:
|
||||||
cacerts=CERTFILE,
|
cacerts=CERTFILE,
|
||||||
chatty=True,
|
chatty=True,
|
||||||
connectionchatty=False)
|
connectionchatty=False)
|
||||||
flag = threading.Event()
|
with server:
|
||||||
server.start(flag)
|
s = ssl.wrap_socket(socket.socket(),
|
||||||
# wait for it to start
|
server_side=False,
|
||||||
flag.wait()
|
certfile=CERTFILE,
|
||||||
# try to connect
|
ca_certs=CERTFILE,
|
||||||
s = ssl.wrap_socket(socket.socket(),
|
cert_reqs=ssl.CERT_NONE,
|
||||||
server_side=False,
|
ssl_version=ssl.PROTOCOL_TLSv1)
|
||||||
certfile=CERTFILE,
|
s.connect((HOST, server.port))
|
||||||
ca_certs=CERTFILE,
|
|
||||||
cert_reqs=ssl.CERT_NONE,
|
|
||||||
ssl_version=ssl.PROTOCOL_TLSv1)
|
|
||||||
s.connect((HOST, server.port))
|
|
||||||
try:
|
|
||||||
# get the data
|
# get the data
|
||||||
cb_data = s.get_channel_binding("tls-unique")
|
cb_data = s.get_channel_binding("tls-unique")
|
||||||
if support.verbose:
|
if support.verbose:
|
||||||
|
@ -1819,9 +1782,6 @@ else:
|
||||||
self.assertEqual(peer_data_repr,
|
self.assertEqual(peer_data_repr,
|
||||||
repr(new_cb_data).encode("us-ascii"))
|
repr(new_cb_data).encode("us-ascii"))
|
||||||
s.close()
|
s.close()
|
||||||
finally:
|
|
||||||
server.stop()
|
|
||||||
server.join()
|
|
||||||
|
|
||||||
def test_compression(self):
|
def test_compression(self):
|
||||||
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue