mirror of
https://github.com/python/cpython.git
synced 2025-08-04 08:59:19 +00:00
gh-63284: Add support for TLS-PSK (pre-shared key) to the ssl module (#103181)
Add support for TLS-PSK (pre-shared key) to the ssl module. --------- Co-authored-by: Oleg Iarygin <oleg@arhadthedev.net> Co-authored-by: Gregory P. Smith <greg@krypto.org>
This commit is contained in:
parent
fb202af447
commit
e954ac7205
10 changed files with 561 additions and 1 deletions
|
@ -4236,6 +4236,105 @@ class ThreadedTests(unittest.TestCase):
|
|||
self.assertEqual(str(e.exception),
|
||||
'Session refers to a different SSLContext.')
|
||||
|
||||
@requires_tls_version('TLSv1_2')
|
||||
def test_psk(self):
|
||||
psk = bytes.fromhex('deadbeef')
|
||||
|
||||
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
client_context.check_hostname = False
|
||||
client_context.verify_mode = ssl.CERT_NONE
|
||||
client_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
client_context.set_ciphers('PSK')
|
||||
client_context.set_psk_client_callback(lambda hint: (None, psk))
|
||||
|
||||
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
server_context.maximum_version = ssl.TLSVersion.TLSv1_2
|
||||
server_context.set_ciphers('PSK')
|
||||
server_context.set_psk_server_callback(lambda identity: psk)
|
||||
|
||||
# correct PSK should connect
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
# incorrect PSK should fail
|
||||
incorrect_psk = bytes.fromhex('cafebabe')
|
||||
client_context.set_psk_client_callback(lambda hint: (None, incorrect_psk))
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
with self.assertRaises(ssl.SSLError):
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
# identity_hint and client_identity should be sent to the other side
|
||||
identity_hint = 'identity-hint'
|
||||
client_identity = 'client-identity'
|
||||
|
||||
def client_callback(hint):
|
||||
self.assertEqual(hint, identity_hint)
|
||||
return client_identity, psk
|
||||
|
||||
def server_callback(identity):
|
||||
self.assertEqual(identity, client_identity)
|
||||
return psk
|
||||
|
||||
client_context.set_psk_client_callback(client_callback)
|
||||
server_context.set_psk_server_callback(server_callback, identity_hint)
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
# adding client callback to server or vice versa raises an exception
|
||||
with self.assertRaisesRegex(ssl.SSLError, 'Cannot add PSK server callback'):
|
||||
client_context.set_psk_server_callback(server_callback, identity_hint)
|
||||
with self.assertRaisesRegex(ssl.SSLError, 'Cannot add PSK client callback'):
|
||||
server_context.set_psk_client_callback(client_callback)
|
||||
|
||||
# test with UTF-8 identities
|
||||
identity_hint = '身份暗示' # Translation: "Identity hint"
|
||||
client_identity = '客户身份' # Translation: "Customer identity"
|
||||
|
||||
client_context.set_psk_client_callback(client_callback)
|
||||
server_context.set_psk_server_callback(server_callback, identity_hint)
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
@requires_tls_version('TLSv1_3')
|
||||
def test_psk_tls1_3(self):
|
||||
psk = bytes.fromhex('deadbeef')
|
||||
identity_hint = 'identity-hint'
|
||||
client_identity = 'client-identity'
|
||||
|
||||
def client_callback(hint):
|
||||
# identity_hint is not sent to the client in TLS 1.3
|
||||
self.assertIsNone(hint)
|
||||
return client_identity, psk
|
||||
|
||||
def server_callback(identity):
|
||||
self.assertEqual(identity, client_identity)
|
||||
return psk
|
||||
|
||||
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
|
||||
client_context.check_hostname = False
|
||||
client_context.verify_mode = ssl.CERT_NONE
|
||||
client_context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
client_context.set_ciphers('PSK')
|
||||
client_context.set_psk_client_callback(client_callback)
|
||||
|
||||
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
|
||||
server_context.minimum_version = ssl.TLSVersion.TLSv1_3
|
||||
server_context.set_ciphers('PSK')
|
||||
server_context.set_psk_server_callback(server_callback, identity_hint)
|
||||
|
||||
server = ThreadedEchoServer(context=server_context)
|
||||
with server:
|
||||
with client_context.wrap_socket(socket.socket()) as s:
|
||||
s.connect((HOST, server.port))
|
||||
|
||||
|
||||
@unittest.skipUnless(has_tls_version('TLSv1_3'), "Test needs TLS 1.3")
|
||||
class TestPostHandshakeAuth(unittest.TestCase):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue