mirror of
https://github.com/python/cpython.git
synced 2025-11-25 04:34:37 +00:00
most recent changes to SSL module to support non-blocking sockets properly
This commit is contained in:
parent
a37d4c693a
commit
48dc27c040
2 changed files with 65 additions and 12 deletions
35
Lib/ssl.py
35
Lib/ssl.py
|
|
@ -126,12 +126,20 @@ class SSLSocket(socket):
|
|||
keyfile, certfile,
|
||||
cert_reqs, ssl_version, ca_certs)
|
||||
if do_handshake_on_connect:
|
||||
timeout = self.gettimeout()
|
||||
if timeout == 0.0:
|
||||
# non-blocking
|
||||
raise ValueError("do_handshake_on_connect should not be specified for non-blocking sockets")
|
||||
self.do_handshake()
|
||||
|
||||
except socket_error as x:
|
||||
self.close()
|
||||
raise x
|
||||
|
||||
self._base = sock
|
||||
if sock and (self.fileno() != sock.fileno()):
|
||||
self._base = sock
|
||||
else:
|
||||
self._base = None
|
||||
self.keyfile = keyfile
|
||||
self.certfile = certfile
|
||||
self.cert_reqs = cert_reqs
|
||||
|
|
@ -148,7 +156,7 @@ class SSLSocket(socket):
|
|||
# raise an exception here if you wish to check for spurious closes
|
||||
pass
|
||||
|
||||
def read(self, len=1024, buffer=None):
|
||||
def read(self, len=None, buffer=None):
|
||||
"""Read up to LEN bytes and return them.
|
||||
Return zero-length string on EOF."""
|
||||
|
||||
|
|
@ -157,7 +165,7 @@ class SSLSocket(socket):
|
|||
if buffer:
|
||||
return self._sslobj.read(buffer, len)
|
||||
else:
|
||||
return self._sslobj.read(len)
|
||||
return self._sslobj.read(len or 1024)
|
||||
except SSLError as x:
|
||||
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
|
||||
return b''
|
||||
|
|
@ -296,16 +304,18 @@ class SSLSocket(socket):
|
|||
# self._closed = True
|
||||
if self._base:
|
||||
self._base.close()
|
||||
socket._real_close(self)
|
||||
socket.close(self)
|
||||
|
||||
def do_handshake(self):
|
||||
def do_handshake(self, block=False):
|
||||
"""Perform a TLS/SSL handshake."""
|
||||
|
||||
timeout = self.gettimeout()
|
||||
try:
|
||||
if timeout == 0.0 and block:
|
||||
self.settimeout(None)
|
||||
self._sslobj.do_handshake()
|
||||
except:
|
||||
self._sslobj = None
|
||||
raise
|
||||
finally:
|
||||
self.settimeout(timeout)
|
||||
|
||||
def connect(self, addr):
|
||||
"""Connects to remote ADDR, and then wraps the connection in
|
||||
|
|
@ -339,15 +349,20 @@ class SSLSocket(socket):
|
|||
addr)
|
||||
|
||||
|
||||
def __del__(self):
|
||||
self._real_close()
|
||||
|
||||
def wrap_socket(sock, keyfile=None, certfile=None,
|
||||
server_side=False, cert_reqs=CERT_NONE,
|
||||
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
|
||||
do_handshake_on_connect=True):
|
||||
do_handshake_on_connect=True,
|
||||
suppress_ragged_eofs=True):
|
||||
|
||||
return SSLSocket(sock=sock, keyfile=keyfile, certfile=certfile,
|
||||
server_side=server_side, cert_reqs=cert_reqs,
|
||||
ssl_version=ssl_version, ca_certs=ca_certs,
|
||||
do_handshake_on_connect=do_handshake_on_connect)
|
||||
do_handshake_on_connect=do_handshake_on_connect,
|
||||
suppress_ragged_eofs=suppress_ragged_eofs)
|
||||
|
||||
# some utility functions
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue