gh-63284: Add support for TLS-PSK (pre-shared key) to the ssl module (#103181)

Add support for TLS-PSK (pre-shared key) to the ssl module.

---------

Co-authored-by: Oleg Iarygin <oleg@arhadthedev.net>
Co-authored-by: Gregory P. Smith <greg@krypto.org>
This commit is contained in:
Grant Ramsay 2023-11-27 17:01:44 +13:00 committed by GitHub
parent fb202af447
commit e954ac7205
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 561 additions and 1 deletions

View file

@ -4236,6 +4236,105 @@ class ThreadedTests(unittest.TestCase):
self.assertEqual(str(e.exception),
'Session refers to a different SSLContext.')
@requires_tls_version('TLSv1_2')
def test_psk(self):
psk = bytes.fromhex('deadbeef')
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE
client_context.maximum_version = ssl.TLSVersion.TLSv1_2
client_context.set_ciphers('PSK')
client_context.set_psk_client_callback(lambda hint: (None, psk))
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.maximum_version = ssl.TLSVersion.TLSv1_2
server_context.set_ciphers('PSK')
server_context.set_psk_server_callback(lambda identity: psk)
# correct PSK should connect
server = ThreadedEchoServer(context=server_context)
with server:
with client_context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
# incorrect PSK should fail
incorrect_psk = bytes.fromhex('cafebabe')
client_context.set_psk_client_callback(lambda hint: (None, incorrect_psk))
server = ThreadedEchoServer(context=server_context)
with server:
with client_context.wrap_socket(socket.socket()) as s:
with self.assertRaises(ssl.SSLError):
s.connect((HOST, server.port))
# identity_hint and client_identity should be sent to the other side
identity_hint = 'identity-hint'
client_identity = 'client-identity'
def client_callback(hint):
self.assertEqual(hint, identity_hint)
return client_identity, psk
def server_callback(identity):
self.assertEqual(identity, client_identity)
return psk
client_context.set_psk_client_callback(client_callback)
server_context.set_psk_server_callback(server_callback, identity_hint)
server = ThreadedEchoServer(context=server_context)
with server:
with client_context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
# adding client callback to server or vice versa raises an exception
with self.assertRaisesRegex(ssl.SSLError, 'Cannot add PSK server callback'):
client_context.set_psk_server_callback(server_callback, identity_hint)
with self.assertRaisesRegex(ssl.SSLError, 'Cannot add PSK client callback'):
server_context.set_psk_client_callback(client_callback)
# test with UTF-8 identities
identity_hint = '身份暗示' # Translation: "Identity hint"
client_identity = '客户身份' # Translation: "Customer identity"
client_context.set_psk_client_callback(client_callback)
server_context.set_psk_server_callback(server_callback, identity_hint)
server = ThreadedEchoServer(context=server_context)
with server:
with client_context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
@requires_tls_version('TLSv1_3')
def test_psk_tls1_3(self):
psk = bytes.fromhex('deadbeef')
identity_hint = 'identity-hint'
client_identity = 'client-identity'
def client_callback(hint):
# identity_hint is not sent to the client in TLS 1.3
self.assertIsNone(hint)
return client_identity, psk
def server_callback(identity):
self.assertEqual(identity, client_identity)
return psk
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
client_context.check_hostname = False
client_context.verify_mode = ssl.CERT_NONE
client_context.minimum_version = ssl.TLSVersion.TLSv1_3
client_context.set_ciphers('PSK')
client_context.set_psk_client_callback(client_callback)
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
server_context.minimum_version = ssl.TLSVersion.TLSv1_3
server_context.set_ciphers('PSK')
server_context.set_psk_server_callback(server_callback, identity_hint)
server = ThreadedEchoServer(context=server_context)
with server:
with client_context.wrap_socket(socket.socket()) as s:
s.connect((HOST, server.port))
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
class TestPostHandshakeAuth(unittest.TestCase):