mirror of
https://github.com/python/cpython.git
synced 2025-08-31 05:58:33 +00:00
bpo-34670: Add TLS 1.3 post handshake auth (GH-9460)
Add SSLContext.post_handshake_auth and SSLSocket.verify_client_post_handshake for TLS 1.3 post-handshake authentication. Signed-off-by: Christian Heimes <christian@python.org>q https://bugs.python.org/issue34670
This commit is contained in:
parent
4b860fd777
commit
9fb051f032
9 changed files with 370 additions and 16 deletions
|
@ -218,7 +218,7 @@ def testing_context(server_cert=SIGNED_CERTFILE):
|
|||
|
||||
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
server_context.load_cert_chain(server_cert)
|
||||
client_context.load_verify_locations(SIGNING_CA)
|
||||
server_context.load_verify_locations(SIGNING_CA)
|
||||
|
||||
return client_context, server_context, hostname
|
||||
|
||||
|
@ -2262,6 +2262,23 @@ class ThreadedEchoServer(threading.Thread):
|
|||
sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
|
||||
data = self.sslconn.get_channel_binding("tls-unique")
|
||||
self.write(repr(data).encode("us-ascii") + b"\n")
|
||||
elif stripped == b'PHA':
|
||||
if support.verbose and self.server.connectionchatty:
|
||||
sys.stdout.write(" server: initiating post handshake auth\n")
|
||||
try:
|
||||
self.sslconn.verify_client_post_handshake()
|
||||
except ssl.SSLError as e:
|
||||
self.write(repr(e).encode("us-ascii") + b"\n")
|
||||
else:
|
||||
self.write(b"OK\n")
|
||||
elif stripped == b'HASCERT':
|
||||
if self.sslconn.getpeercert() is not None:
|
||||
self.write(b'TRUE\n')
|
||||
else:
|
||||
self.write(b'FALSE\n')
|
||||
elif stripped == b'GETCERT':
|
||||
cert = self.sslconn.getpeercert()
|
||||
self.write(repr(cert).encode("us-ascii") + b"\n")
|
||||
else:
|
||||
if (support.verbose and
|
||||
self.server.connectionchatty):
|
||||
|
@ -4148,6 +4165,179 @@ class ThreadedTests(unittest.TestCase):
|
|||
'Session refers to a different SSLContext.')
|
||||
|
||||
|
||||
@unittest.skipUnless(ssl.HAS_TLSv1_3, "Test needs TLS 1.3")
|
||||
class TestPostHandshakeAuth(unittest.TestCase):
|
||||
def test_pha_setter(self):
|
||||
protocols = [
|
||||
ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
|
||||
]
|
||||
for protocol in protocols:
|
||||
ctx = ssl.SSLContext(protocol)
|
||||
self.assertEqual(ctx.post_handshake_auth, False)
|
||||
|
||||
ctx.post_handshake_auth = True
|
||||
self.assertEqual(ctx.post_handshake_auth, True)
|
||||
|
||||
ctx.verify_mode = ssl.CERT_REQUIRED
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
|
||||
self.assertEqual(ctx.post_handshake_auth, True)
|
||||
|
||||
ctx.post_handshake_auth = False
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
|
||||
self.assertEqual(ctx.post_handshake_auth, False)
|
||||
|
||||
ctx.verify_mode = ssl.CERT_OPTIONAL
|
||||
ctx.post_handshake_auth = True
|
||||
self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
|
||||
self.assertEqual(ctx.post_handshake_auth, True)
|
||||
|
||||
def test_pha_required(self):
|
||||
client_context, server_context, hostname = testing_context()
|
||||
server_context.post_handshake_auth = True
|
||||
server_context.verify_mode = ssl.CERT_REQUIRED
|
||||
client_context.post_handshake_auth = True
|
||||
client_context.load_cert_chain(SIGNED_CERTFILE)
|
||||
|
||||
server = ThreadedEchoServer(context=server_context, chatty=False)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket(),
|
||||
server_hostname=hostname) as s:
|
||||
s.connect((HOST, server.port))
|
||||
s.write(b'HASCERT')
|
||||
self.assertEqual(s.recv(1024), b'FALSE\n')
|
||||
s.write(b'PHA')
|
||||
self.assertEqual(s.recv(1024), b'OK\n')
|
||||
s.write(b'HASCERT')
|
||||
self.assertEqual(s.recv(1024), b'TRUE\n')
|
||||
# PHA method just returns true when cert is already available
|
||||
s.write(b'PHA')
|
||||
self.assertEqual(s.recv(1024), b'OK\n')
|
||||
s.write(b'GETCERT')
|
||||
cert_text = s.recv(4096).decode('us-ascii')
|
||||
self.assertIn('Python Software Foundation CA', cert_text)
|
||||
|
||||
def test_pha_required_nocert(self):
|
||||
client_context, server_context, hostname = testing_context()
|
||||
server_context.post_handshake_auth = True
|
||||
server_context.verify_mode = ssl.CERT_REQUIRED
|
||||
client_context.post_handshake_auth = True
|
||||
|
||||
server = ThreadedEchoServer(context=server_context, chatty=False)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket(),
|
||||
server_hostname=hostname) as s:
|
||||
s.connect((HOST, server.port))
|
||||
s.write(b'PHA')
|
||||
# receive CertificateRequest
|
||||
self.assertEqual(s.recv(1024), b'OK\n')
|
||||
# send empty Certificate + Finish
|
||||
s.write(b'HASCERT')
|
||||
# receive alert
|
||||
with self.assertRaisesRegex(
|
||||
ssl.SSLError,
|
||||
'tlsv13 alert certificate required'):
|
||||
s.recv(1024)
|
||||
|
||||
def test_pha_optional(self):
|
||||
if support.verbose:
|
||||
sys.stdout.write("\n")
|
||||
|
||||
client_context, server_context, hostname = testing_context()
|
||||
server_context.post_handshake_auth = True
|
||||
server_context.verify_mode = ssl.CERT_REQUIRED
|
||||
client_context.post_handshake_auth = True
|
||||
client_context.load_cert_chain(SIGNED_CERTFILE)
|
||||
|
||||
# check CERT_OPTIONAL
|
||||
server_context.verify_mode = ssl.CERT_OPTIONAL
|
||||
server = ThreadedEchoServer(context=server_context, chatty=False)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket(),
|
||||
server_hostname=hostname) as s:
|
||||
s.connect((HOST, server.port))
|
||||
s.write(b'HASCERT')
|
||||
self.assertEqual(s.recv(1024), b'FALSE\n')
|
||||
s.write(b'PHA')
|
||||
self.assertEqual(s.recv(1024), b'OK\n')
|
||||
s.write(b'HASCERT')
|
||||
self.assertEqual(s.recv(1024), b'TRUE\n')
|
||||
|
||||
def test_pha_optional_nocert(self):
|
||||
if support.verbose:
|
||||
sys.stdout.write("\n")
|
||||
|
||||
client_context, server_context, hostname = testing_context()
|
||||
server_context.post_handshake_auth = True
|
||||
server_context.verify_mode = ssl.CERT_OPTIONAL
|
||||
client_context.post_handshake_auth = True
|
||||
|
||||
server = ThreadedEchoServer(context=server_context, chatty=False)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket(),
|
||||
server_hostname=hostname) as s:
|
||||
s.connect((HOST, server.port))
|
||||
s.write(b'HASCERT')
|
||||
self.assertEqual(s.recv(1024), b'FALSE\n')
|
||||
s.write(b'PHA')
|
||||
self.assertEqual(s.recv(1024), b'OK\n')
|
||||
# optional doens't fail when client does not have a cert
|
||||
s.write(b'HASCERT')
|
||||
self.assertEqual(s.recv(1024), b'FALSE\n')
|
||||
|
||||
def test_pha_no_pha_client(self):
|
||||
client_context, server_context, hostname = testing_context()
|
||||
server_context.post_handshake_auth = True
|
||||
server_context.verify_mode = ssl.CERT_REQUIRED
|
||||
client_context.load_cert_chain(SIGNED_CERTFILE)
|
||||
|
||||
server = ThreadedEchoServer(context=server_context, chatty=False)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket(),
|
||||
server_hostname=hostname) as s:
|
||||
s.connect((HOST, server.port))
|
||||
with self.assertRaisesRegex(ssl.SSLError, 'not server'):
|
||||
s.verify_client_post_handshake()
|
||||
s.write(b'PHA')
|
||||
self.assertIn(b'extension not received', s.recv(1024))
|
||||
|
||||
def test_pha_no_pha_server(self):
|
||||
# server doesn't have PHA enabled, cert is requested in handshake
|
||||
client_context, server_context, hostname = testing_context()
|
||||
server_context.verify_mode = ssl.CERT_REQUIRED
|
||||
client_context.post_handshake_auth = True
|
||||
client_context.load_cert_chain(SIGNED_CERTFILE)
|
||||
|
||||
server = ThreadedEchoServer(context=server_context, chatty=False)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket(),
|
||||
server_hostname=hostname) as s:
|
||||
s.connect((HOST, server.port))
|
||||
s.write(b'HASCERT')
|
||||
self.assertEqual(s.recv(1024), b'TRUE\n')
|
||||
# PHA doesn't fail if there is already a cert
|
||||
s.write(b'PHA')
|
||||
self.assertEqual(s.recv(1024), b'OK\n')
|
||||
s.write(b'HASCERT')
|
||||
self.assertEqual(s.recv(1024), b'TRUE\n')
|
||||
|
||||
def test_pha_not_tls13(self):
|
||||
# TLS 1.2
|
||||
client_context, server_context, hostname = testing_context()
|
||||
server_context.verify_mode = ssl.CERT_REQUIRED
|
||||
client_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
client_context.post_handshake_auth = True
|
||||
client_context.load_cert_chain(SIGNED_CERTFILE)
|
||||
|
||||
server = ThreadedEchoServer(context=server_context, chatty=False)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket(),
|
||||
server_hostname=hostname) as s:
|
||||
s.connect((HOST, server.port))
|
||||
# PHA fails for TLS != 1.3
|
||||
s.write(b'PHA')
|
||||
self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
|
||||
|
||||
|
||||
def test_main(verbose=False):
|
||||
if support.verbose:
|
||||
import warnings
|
||||
|
@ -4183,6 +4373,7 @@ def test_main(verbose=False):
|
|||
tests = [
|
||||
ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
|
||||
SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
|
||||
TestPostHandshakeAuth
|
||||
]
|
||||
|
||||
if support.is_resource_enabled('network'):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue