mirror of
https://github.com/python/cpython.git
synced 2025-09-26 18:29:57 +00:00
update to fix leak in SSL code
This commit is contained in:
parent
517b9ddda2
commit
54cc54c1fe
4 changed files with 225 additions and 68 deletions
|
@ -174,11 +174,13 @@ class socket(_socket.socket):
|
||||||
if self._closed:
|
if self._closed:
|
||||||
self.close()
|
self.close()
|
||||||
|
|
||||||
|
def _real_close(self):
|
||||||
|
_socket.socket.close(self)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self._closed = True
|
self._closed = True
|
||||||
if self._io_refs <= 0:
|
if self._io_refs <= 0:
|
||||||
_socket.socket.close(self)
|
self._real_close()
|
||||||
|
|
||||||
|
|
||||||
def fromfd(fd, family, type, proto=0):
|
def fromfd(fd, family, type, proto=0):
|
||||||
""" fromfd(fd, family, type[, proto]) -> socket object
|
""" fromfd(fd, family, type[, proto]) -> socket object
|
||||||
|
|
40
Lib/ssl.py
40
Lib/ssl.py
|
@ -80,6 +80,7 @@ from socket import getnameinfo as _getnameinfo
|
||||||
from socket import error as socket_error
|
from socket import error as socket_error
|
||||||
from socket import dup as _dup
|
from socket import dup as _dup
|
||||||
import base64 # for DER-to-PEM translation
|
import base64 # for DER-to-PEM translation
|
||||||
|
import traceback
|
||||||
|
|
||||||
class SSLSocket(socket):
|
class SSLSocket(socket):
|
||||||
|
|
||||||
|
@ -94,16 +95,13 @@ class SSLSocket(socket):
|
||||||
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
|
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
|
||||||
suppress_ragged_eofs=True):
|
suppress_ragged_eofs=True):
|
||||||
|
|
||||||
self._base = None
|
|
||||||
|
|
||||||
if sock is not None:
|
if sock is not None:
|
||||||
# copied this code from socket.accept()
|
socket.__init__(self,
|
||||||
fd = sock.fileno()
|
family=sock.family,
|
||||||
nfd = _dup(fd)
|
type=sock.type,
|
||||||
socket.__init__(self, family=sock.family, type=sock.type,
|
proto=sock.proto,
|
||||||
proto=sock.proto, fileno=nfd)
|
fileno=_dup(sock.fileno()))
|
||||||
sock.close()
|
sock.close()
|
||||||
sock = None
|
|
||||||
elif fileno is not None:
|
elif fileno is not None:
|
||||||
socket.__init__(self, fileno=fileno)
|
socket.__init__(self, fileno=fileno)
|
||||||
else:
|
else:
|
||||||
|
@ -136,10 +134,6 @@ class SSLSocket(socket):
|
||||||
self.close()
|
self.close()
|
||||||
raise x
|
raise x
|
||||||
|
|
||||||
if sock and (self.fileno() != sock.fileno()):
|
|
||||||
self._base = sock
|
|
||||||
else:
|
|
||||||
self._base = None
|
|
||||||
self.keyfile = keyfile
|
self.keyfile = keyfile
|
||||||
self.certfile = certfile
|
self.certfile = certfile
|
||||||
self.cert_reqs = cert_reqs
|
self.cert_reqs = cert_reqs
|
||||||
|
@ -156,18 +150,22 @@ class SSLSocket(socket):
|
||||||
# raise an exception here if you wish to check for spurious closes
|
# raise an exception here if you wish to check for spurious closes
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def read(self, len=None, buffer=None):
|
def read(self, len=0, buffer=None):
|
||||||
"""Read up to LEN bytes and return them.
|
"""Read up to LEN bytes and return them.
|
||||||
Return zero-length string on EOF."""
|
Return zero-length string on EOF."""
|
||||||
|
|
||||||
self._checkClosed()
|
self._checkClosed()
|
||||||
try:
|
try:
|
||||||
if buffer:
|
if buffer:
|
||||||
return self._sslobj.read(buffer, len)
|
v = self._sslobj.read(buffer, len)
|
||||||
else:
|
else:
|
||||||
return self._sslobj.read(len or 1024)
|
v = self._sslobj.read(len or 1024)
|
||||||
|
return v
|
||||||
except SSLError as x:
|
except SSLError as x:
|
||||||
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
|
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
|
||||||
|
if buffer:
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
return b''
|
return b''
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
@ -269,7 +267,6 @@ class SSLSocket(socket):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
v = self.read(nbytes, buffer)
|
v = self.read(nbytes, buffer)
|
||||||
sys.stdout.flush()
|
|
||||||
return v
|
return v
|
||||||
except SSLError as x:
|
except SSLError as x:
|
||||||
if x.args[0] == SSL_ERROR_WANT_READ:
|
if x.args[0] == SSL_ERROR_WANT_READ:
|
||||||
|
@ -302,9 +299,7 @@ class SSLSocket(socket):
|
||||||
def _real_close(self):
|
def _real_close(self):
|
||||||
self._sslobj = None
|
self._sslobj = None
|
||||||
# self._closed = True
|
# self._closed = True
|
||||||
if self._base:
|
socket._real_close(self)
|
||||||
self._base.close()
|
|
||||||
socket.close(self)
|
|
||||||
|
|
||||||
def do_handshake(self, block=False):
|
def do_handshake(self, block=False):
|
||||||
"""Perform a TLS/SSL handshake."""
|
"""Perform a TLS/SSL handshake."""
|
||||||
|
@ -329,8 +324,12 @@ class SSLSocket(socket):
|
||||||
self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
|
self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
|
||||||
self.cert_reqs, self.ssl_version,
|
self.cert_reqs, self.ssl_version,
|
||||||
self.ca_certs)
|
self.ca_certs)
|
||||||
|
try:
|
||||||
if self.do_handshake_on_connect:
|
if self.do_handshake_on_connect:
|
||||||
self.do_handshake()
|
self.do_handshake()
|
||||||
|
except:
|
||||||
|
self._sslobj = None
|
||||||
|
raise
|
||||||
|
|
||||||
def accept(self):
|
def accept(self):
|
||||||
"""Accepts a new connection from a remote client, and returns
|
"""Accepts a new connection from a remote client, and returns
|
||||||
|
@ -348,10 +347,11 @@ class SSLSocket(socket):
|
||||||
self.do_handshake_on_connect),
|
self.do_handshake_on_connect),
|
||||||
addr)
|
addr)
|
||||||
|
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self):
|
||||||
|
# sys.stderr.write("__del__ on %s\n" % repr(self))
|
||||||
self._real_close()
|
self._real_close()
|
||||||
|
|
||||||
|
|
||||||
def wrap_socket(sock, keyfile=None, certfile=None,
|
def wrap_socket(sock, keyfile=None, certfile=None,
|
||||||
server_side=False, cert_reqs=CERT_NONE,
|
server_side=False, cert_reqs=CERT_NONE,
|
||||||
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
|
ssl_version=PROTOCOL_SSLv23, ca_certs=None,
|
||||||
|
|
|
@ -13,6 +13,7 @@ import pprint
|
||||||
import urllib, urlparse
|
import urllib, urlparse
|
||||||
import shutil
|
import shutil
|
||||||
import traceback
|
import traceback
|
||||||
|
import asyncore
|
||||||
|
|
||||||
from BaseHTTPServer import HTTPServer
|
from BaseHTTPServer import HTTPServer
|
||||||
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
from SimpleHTTPServer import SimpleHTTPRequestHandler
|
||||||
|
@ -79,27 +80,6 @@ class BasicTests(unittest.TestCase):
|
||||||
|
|
||||||
class NetworkedTests(unittest.TestCase):
|
class NetworkedTests(unittest.TestCase):
|
||||||
|
|
||||||
def testFetchServerCert(self):
|
|
||||||
|
|
||||||
pem = ssl.get_server_certificate(("svn.python.org", 443))
|
|
||||||
if not pem:
|
|
||||||
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
|
|
||||||
|
|
||||||
try:
|
|
||||||
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
|
|
||||||
except ssl.SSLError as x:
|
|
||||||
#should fail
|
|
||||||
if test_support.verbose:
|
|
||||||
sys.stdout.write("%s\n" % x)
|
|
||||||
else:
|
|
||||||
raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
|
|
||||||
|
|
||||||
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
|
|
||||||
if not pem:
|
|
||||||
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
|
|
||||||
if test_support.verbose:
|
|
||||||
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
|
|
||||||
|
|
||||||
def testConnect(self):
|
def testConnect(self):
|
||||||
|
|
||||||
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
|
s = ssl.wrap_socket(socket.socket(socket.AF_INET),
|
||||||
|
@ -155,6 +135,29 @@ class NetworkedTests(unittest.TestCase):
|
||||||
if test_support.verbose:
|
if test_support.verbose:
|
||||||
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
|
sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)
|
||||||
|
|
||||||
|
def testFetchServerCert(self):
|
||||||
|
|
||||||
|
pem = ssl.get_server_certificate(("svn.python.org", 443))
|
||||||
|
if not pem:
|
||||||
|
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE)
|
||||||
|
except ssl.SSLError as x:
|
||||||
|
#should fail
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write("%s\n" % x)
|
||||||
|
else:
|
||||||
|
raise test_support.TestFailed("Got server certificate %s for svn.python.org!" % pem)
|
||||||
|
|
||||||
|
pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT)
|
||||||
|
if not pem:
|
||||||
|
raise test_support.TestFailed("No server certificate on svn.python.org:443!")
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem)
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import threading
|
import threading
|
||||||
|
@ -333,7 +336,9 @@ else:
|
||||||
def stop (self):
|
def stop (self):
|
||||||
self.active = False
|
self.active = False
|
||||||
|
|
||||||
class AsyncoreHTTPSServer(threading.Thread):
|
class OurHTTPSServer(threading.Thread):
|
||||||
|
|
||||||
|
# This one's based on HTTPServer, which is based on SocketServer
|
||||||
|
|
||||||
class HTTPSServer(HTTPServer):
|
class HTTPSServer(HTTPServer):
|
||||||
|
|
||||||
|
@ -463,6 +468,92 @@ else:
|
||||||
self.server.server_close()
|
self.server.server_close()
|
||||||
|
|
||||||
|
|
||||||
|
class AsyncoreEchoServer(threading.Thread):
|
||||||
|
|
||||||
|
# this one's based on asyncore.dispatcher
|
||||||
|
|
||||||
|
class EchoServer (asyncore.dispatcher):
|
||||||
|
|
||||||
|
class ConnectionHandler (asyncore.dispatcher_with_send):
|
||||||
|
|
||||||
|
def __init__(self, conn, certfile):
|
||||||
|
self.socket = ssl.wrap_socket(conn, server_side=True,
|
||||||
|
certfile=certfile,
|
||||||
|
do_handshake_on_connect=False)
|
||||||
|
asyncore.dispatcher_with_send.__init__(self, self.socket)
|
||||||
|
# now we have to do the handshake
|
||||||
|
# we'll just do it the easy way, and block the connection
|
||||||
|
# till it's finished. If we were doing it right, we'd
|
||||||
|
# do this in multiple calls to handle_read...
|
||||||
|
self.do_handshake(block=True)
|
||||||
|
|
||||||
|
def readable(self):
|
||||||
|
if isinstance(self.socket, ssl.SSLSocket):
|
||||||
|
while self.socket.pending() > 0:
|
||||||
|
self.handle_read_event()
|
||||||
|
return True
|
||||||
|
|
||||||
|
def handle_read(self):
|
||||||
|
data = self.recv(1024)
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write(" server: read %s from client\n" % repr(data))
|
||||||
|
if not data:
|
||||||
|
self.close()
|
||||||
|
else:
|
||||||
|
self.send(str(data, 'ASCII', 'strict').lower().encode('ASCII', 'strict'))
|
||||||
|
|
||||||
|
def handle_close(self):
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write(" server: closed connection %s\n" % self.socket)
|
||||||
|
|
||||||
|
def handle_error(self):
|
||||||
|
raise
|
||||||
|
|
||||||
|
def __init__(self, port, certfile):
|
||||||
|
self.port = port
|
||||||
|
self.certfile = certfile
|
||||||
|
asyncore.dispatcher.__init__(self)
|
||||||
|
self.create_socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
self.bind(('', port))
|
||||||
|
self.listen(5)
|
||||||
|
|
||||||
|
def handle_accept(self):
|
||||||
|
sock_obj, addr = self.accept()
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write(" server: new connection from %s:%s\n" %addr)
|
||||||
|
self.ConnectionHandler(sock_obj, self.certfile)
|
||||||
|
|
||||||
|
def handle_error(self):
|
||||||
|
raise
|
||||||
|
|
||||||
|
def __init__(self, port, certfile):
|
||||||
|
self.flag = None
|
||||||
|
self.active = False
|
||||||
|
self.server = self.EchoServer(port, certfile)
|
||||||
|
threading.Thread.__init__(self)
|
||||||
|
self.setDaemon(True)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return "<%s %s>" % (self.__class__.__name__, self.server)
|
||||||
|
|
||||||
|
def start (self, flag=None):
|
||||||
|
self.flag = flag
|
||||||
|
threading.Thread.start(self)
|
||||||
|
|
||||||
|
def run (self):
|
||||||
|
self.active = True
|
||||||
|
if self.flag:
|
||||||
|
self.flag.set()
|
||||||
|
while self.active:
|
||||||
|
try:
|
||||||
|
asyncore.loop(1)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def stop (self):
|
||||||
|
self.active = False
|
||||||
|
self.server.close()
|
||||||
|
|
||||||
def badCertTest (certfile):
|
def badCertTest (certfile):
|
||||||
server = ThreadedEchoServer(TESTPORT, CERTFILE,
|
server = ThreadedEchoServer(TESTPORT, CERTFILE,
|
||||||
certreqs=ssl.CERT_REQUIRED,
|
certreqs=ssl.CERT_REQUIRED,
|
||||||
|
@ -509,6 +600,7 @@ else:
|
||||||
client_protocol = protocol
|
client_protocol = protocol
|
||||||
try:
|
try:
|
||||||
s = ssl.wrap_socket(socket.socket(),
|
s = ssl.wrap_socket(socket.socket(),
|
||||||
|
server_side=False,
|
||||||
certfile=client_certfile,
|
certfile=client_certfile,
|
||||||
ca_certs=cacertsfile,
|
ca_certs=cacertsfile,
|
||||||
cert_reqs=certreqs,
|
cert_reqs=certreqs,
|
||||||
|
@ -811,11 +903,9 @@ else:
|
||||||
server.stop()
|
server.stop()
|
||||||
server.join()
|
server.join()
|
||||||
|
|
||||||
class AsyncoreTests(unittest.TestCase):
|
def testSocketServer(self):
|
||||||
|
|
||||||
def testAsyncore(self):
|
server = OurHTTPSServer(TESTPORT, CERTFILE)
|
||||||
|
|
||||||
server = AsyncoreHTTPSServer(TESTPORT, CERTFILE)
|
|
||||||
flag = threading.Event()
|
flag = threading.Event()
|
||||||
server.start(flag)
|
server.start(flag)
|
||||||
# wait for it to start
|
# wait for it to start
|
||||||
|
@ -853,6 +943,47 @@ else:
|
||||||
server.stop()
|
server.stop()
|
||||||
server.join()
|
server.join()
|
||||||
|
|
||||||
|
def testAsyncoreServer(self):
|
||||||
|
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write("\n")
|
||||||
|
|
||||||
|
indata="FOO\n"
|
||||||
|
server = AsyncoreEchoServer(TESTPORT, CERTFILE)
|
||||||
|
flag = threading.Event()
|
||||||
|
server.start(flag)
|
||||||
|
# wait for it to start
|
||||||
|
flag.wait()
|
||||||
|
# try to connect
|
||||||
|
try:
|
||||||
|
s = ssl.wrap_socket(socket.socket())
|
||||||
|
s.connect(('127.0.0.1', TESTPORT))
|
||||||
|
except ssl.SSLError as x:
|
||||||
|
raise test_support.TestFailed("Unexpected SSL error: " + str(x))
|
||||||
|
except Exception as x:
|
||||||
|
raise test_support.TestFailed("Unexpected exception: " + str(x))
|
||||||
|
else:
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write(
|
||||||
|
" client: sending %s...\n" % (repr(indata)))
|
||||||
|
s.sendall(indata.encode('ASCII', 'strict'))
|
||||||
|
outdata = s.recv()
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write(" client: read %s\n" % repr(outdata))
|
||||||
|
outdata = str(outdata, 'ASCII', 'strict')
|
||||||
|
if outdata != indata.lower():
|
||||||
|
raise test_support.TestFailed(
|
||||||
|
"bad data <<%s>> (%d) received; expected <<%s>> (%d)\n"
|
||||||
|
% (repr(outdata[:min(len(outdata),20)]), len(outdata),
|
||||||
|
repr(indata[:min(len(indata),20)].lower()), len(indata)))
|
||||||
|
s.write("over\n".encode("ASCII", "strict"))
|
||||||
|
if test_support.verbose:
|
||||||
|
sys.stdout.write(" client: closing connection.\n")
|
||||||
|
s.close()
|
||||||
|
finally:
|
||||||
|
server.stop()
|
||||||
|
server.join()
|
||||||
|
|
||||||
|
|
||||||
def findtestsocket(start, end):
|
def findtestsocket(start, end):
|
||||||
def testbind(i):
|
def testbind(i):
|
||||||
|
@ -900,7 +1031,6 @@ def test_main(verbose=False):
|
||||||
thread_info = test_support.threading_setup()
|
thread_info = test_support.threading_setup()
|
||||||
if thread_info and test_support.is_resource_enabled('network'):
|
if thread_info and test_support.is_resource_enabled('network'):
|
||||||
tests.append(ThreadedTests)
|
tests.append(ThreadedTests)
|
||||||
tests.append(AsyncoreTests)
|
|
||||||
|
|
||||||
test_support.run_unittest(*tests)
|
test_support.run_unittest(*tests)
|
||||||
|
|
||||||
|
|
|
@ -46,6 +46,7 @@ enum py_ssl_error {
|
||||||
PY_SSL_ERROR_WANT_CONNECT,
|
PY_SSL_ERROR_WANT_CONNECT,
|
||||||
/* start of non ssl.h errorcodes */
|
/* start of non ssl.h errorcodes */
|
||||||
PY_SSL_ERROR_EOF, /* special case of SSL_ERROR_SYSCALL */
|
PY_SSL_ERROR_EOF, /* special case of SSL_ERROR_SYSCALL */
|
||||||
|
PY_SSL_ERROR_NO_SOCKET, /* socket has been GC'd */
|
||||||
PY_SSL_ERROR_INVALID_ERROR_CODE
|
PY_SSL_ERROR_INVALID_ERROR_CODE
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -111,7 +112,7 @@ static unsigned int _ssl_locks_count = 0;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
PyObject_HEAD
|
PyObject_HEAD
|
||||||
PySocketSockObject *Socket; /* Socket on which we're layered */
|
PyObject *Socket; /* weakref to socket on which we're layered */
|
||||||
SSL_CTX* ctx;
|
SSL_CTX* ctx;
|
||||||
SSL* ssl;
|
SSL* ssl;
|
||||||
X509* peer_cert;
|
X509* peer_cert;
|
||||||
|
@ -188,13 +189,15 @@ PySSL_SetError(PySSLObject *obj, int ret, char *filename, int lineno)
|
||||||
{
|
{
|
||||||
unsigned long e = ERR_get_error();
|
unsigned long e = ERR_get_error();
|
||||||
if (e == 0) {
|
if (e == 0) {
|
||||||
if (ret == 0 || !obj->Socket) {
|
PySocketSockObject *s
|
||||||
|
= (PySocketSockObject *) PyWeakref_GetObject(obj->Socket);
|
||||||
|
if (ret == 0 || (((PyObject *)s) == Py_None)) {
|
||||||
p = PY_SSL_ERROR_EOF;
|
p = PY_SSL_ERROR_EOF;
|
||||||
errstr =
|
errstr =
|
||||||
"EOF occurred in violation of protocol";
|
"EOF occurred in violation of protocol";
|
||||||
} else if (ret == -1) {
|
} else if (ret == -1) {
|
||||||
/* underlying BIO reported an I/O error */
|
/* underlying BIO reported an I/O error */
|
||||||
return obj->Socket->errorhandler();
|
return s->errorhandler();
|
||||||
} else { /* possible? */
|
} else { /* possible? */
|
||||||
p = PY_SSL_ERROR_SYSCALL;
|
p = PY_SSL_ERROR_SYSCALL;
|
||||||
errstr = "Some I/O error occurred";
|
errstr = "Some I/O error occurred";
|
||||||
|
@ -383,8 +386,7 @@ newPySSLObject(PySocketSockObject *Sock, char *key_file, char *cert_file,
|
||||||
SSL_set_accept_state(self->ssl);
|
SSL_set_accept_state(self->ssl);
|
||||||
PySSL_END_ALLOW_THREADS
|
PySSL_END_ALLOW_THREADS
|
||||||
|
|
||||||
self->Socket = Sock;
|
self->Socket = PyWeakref_NewRef((PyObject *) Sock, Py_None);
|
||||||
Py_INCREF(self->Socket);
|
|
||||||
return self;
|
return self;
|
||||||
fail:
|
fail:
|
||||||
if (errstr)
|
if (errstr)
|
||||||
|
@ -442,6 +444,14 @@ static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
|
||||||
/* XXX If SSL_do_handshake() returns 0, it's also a failure. */
|
/* XXX If SSL_do_handshake() returns 0, it's also a failure. */
|
||||||
sockstate = 0;
|
sockstate = 0;
|
||||||
do {
|
do {
|
||||||
|
PySocketSockObject *sock
|
||||||
|
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
|
||||||
|
if (((PyObject*)sock) == Py_None) {
|
||||||
|
_setSSLError("Underlying socket connection gone",
|
||||||
|
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
PySSL_BEGIN_ALLOW_THREADS
|
PySSL_BEGIN_ALLOW_THREADS
|
||||||
ret = SSL_do_handshake(self->ssl);
|
ret = SSL_do_handshake(self->ssl);
|
||||||
err = SSL_get_error(self->ssl, ret);
|
err = SSL_get_error(self->ssl, ret);
|
||||||
|
@ -450,9 +460,9 @@ static PyObject *PySSL_SSLdo_handshake(PySSLObject *self)
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
if (err == SSL_ERROR_WANT_READ) {
|
if (err == SSL_ERROR_WANT_READ) {
|
||||||
sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
|
sockstate = check_socket_and_wait_for_timeout(sock, 0);
|
||||||
} else if (err == SSL_ERROR_WANT_WRITE) {
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||||
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
|
sockstate = check_socket_and_wait_for_timeout(sock, 1);
|
||||||
} else {
|
} else {
|
||||||
sockstate = SOCKET_OPERATION_OK;
|
sockstate = SOCKET_OPERATION_OK;
|
||||||
}
|
}
|
||||||
|
@ -1140,16 +1150,24 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
|
||||||
int sockstate;
|
int sockstate;
|
||||||
int err;
|
int err;
|
||||||
int nonblocking;
|
int nonblocking;
|
||||||
|
PySocketSockObject *sock
|
||||||
|
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
|
||||||
|
|
||||||
|
if (((PyObject*)sock) == Py_None) {
|
||||||
|
_setSSLError("Underlying socket connection gone",
|
||||||
|
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
if (!PyArg_ParseTuple(args, "y#:write", &data, &count))
|
if (!PyArg_ParseTuple(args, "y#:write", &data, &count))
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
/* just in case the blocking state of the socket has been changed */
|
/* just in case the blocking state of the socket has been changed */
|
||||||
nonblocking = (self->Socket->sock_timeout >= 0.0);
|
nonblocking = (sock->sock_timeout >= 0.0);
|
||||||
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
|
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
|
||||||
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
|
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
|
||||||
|
|
||||||
sockstate = check_socket_and_wait_for_timeout(self->Socket, 1);
|
sockstate = check_socket_and_wait_for_timeout(sock, 1);
|
||||||
if (sockstate == SOCKET_HAS_TIMED_OUT) {
|
if (sockstate == SOCKET_HAS_TIMED_OUT) {
|
||||||
PyErr_SetString(PySSLErrorObject,
|
PyErr_SetString(PySSLErrorObject,
|
||||||
"The write operation timed out");
|
"The write operation timed out");
|
||||||
|
@ -1174,10 +1192,10 @@ static PyObject *PySSL_SSLwrite(PySSLObject *self, PyObject *args)
|
||||||
}
|
}
|
||||||
if (err == SSL_ERROR_WANT_READ) {
|
if (err == SSL_ERROR_WANT_READ) {
|
||||||
sockstate =
|
sockstate =
|
||||||
check_socket_and_wait_for_timeout(self->Socket, 0);
|
check_socket_and_wait_for_timeout(sock, 0);
|
||||||
} else if (err == SSL_ERROR_WANT_WRITE) {
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||||
sockstate =
|
sockstate =
|
||||||
check_socket_and_wait_for_timeout(self->Socket, 1);
|
check_socket_and_wait_for_timeout(sock, 1);
|
||||||
} else {
|
} else {
|
||||||
sockstate = SOCKET_OPERATION_OK;
|
sockstate = SOCKET_OPERATION_OK;
|
||||||
}
|
}
|
||||||
|
@ -1233,10 +1251,17 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
|
||||||
int sockstate;
|
int sockstate;
|
||||||
int err;
|
int err;
|
||||||
int nonblocking;
|
int nonblocking;
|
||||||
|
PySocketSockObject *sock
|
||||||
|
= (PySocketSockObject *) PyWeakref_GetObject(self->Socket);
|
||||||
|
|
||||||
|
if (((PyObject*)sock) == Py_None) {
|
||||||
|
_setSSLError("Underlying socket connection gone",
|
||||||
|
PY_SSL_ERROR_NO_SOCKET, __FILE__, __LINE__);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
|
||||||
if (!PyArg_ParseTuple(args, "|Oi:read", &buf, &count))
|
if (!PyArg_ParseTuple(args, "|Oi:read", &buf, &count))
|
||||||
return NULL;
|
return NULL;
|
||||||
|
|
||||||
if ((buf == NULL) || (buf == Py_None)) {
|
if ((buf == NULL) || (buf == Py_None)) {
|
||||||
if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
|
if (!(buf = PyBytes_FromStringAndSize((char *) 0, len)))
|
||||||
return NULL;
|
return NULL;
|
||||||
|
@ -1254,7 +1279,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
|
||||||
}
|
}
|
||||||
|
|
||||||
/* just in case the blocking state of the socket has been changed */
|
/* just in case the blocking state of the socket has been changed */
|
||||||
nonblocking = (self->Socket->sock_timeout >= 0.0);
|
nonblocking = (sock->sock_timeout >= 0.0);
|
||||||
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
|
BIO_set_nbio(SSL_get_rbio(self->ssl), nonblocking);
|
||||||
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
|
BIO_set_nbio(SSL_get_wbio(self->ssl), nonblocking);
|
||||||
|
|
||||||
|
@ -1264,7 +1289,7 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
|
||||||
PySSL_END_ALLOW_THREADS
|
PySSL_END_ALLOW_THREADS
|
||||||
|
|
||||||
if (!count) {
|
if (!count) {
|
||||||
sockstate = check_socket_and_wait_for_timeout(self->Socket, 0);
|
sockstate = check_socket_and_wait_for_timeout(sock, 0);
|
||||||
if (sockstate == SOCKET_HAS_TIMED_OUT) {
|
if (sockstate == SOCKET_HAS_TIMED_OUT) {
|
||||||
PyErr_SetString(PySSLErrorObject,
|
PyErr_SetString(PySSLErrorObject,
|
||||||
"The read operation timed out");
|
"The read operation timed out");
|
||||||
|
@ -1299,10 +1324,10 @@ static PyObject *PySSL_SSLread(PySSLObject *self, PyObject *args)
|
||||||
}
|
}
|
||||||
if (err == SSL_ERROR_WANT_READ) {
|
if (err == SSL_ERROR_WANT_READ) {
|
||||||
sockstate =
|
sockstate =
|
||||||
check_socket_and_wait_for_timeout(self->Socket, 0);
|
check_socket_and_wait_for_timeout(sock, 0);
|
||||||
} else if (err == SSL_ERROR_WANT_WRITE) {
|
} else if (err == SSL_ERROR_WANT_WRITE) {
|
||||||
sockstate =
|
sockstate =
|
||||||
check_socket_and_wait_for_timeout(self->Socket, 1);
|
check_socket_and_wait_for_timeout(sock, 1);
|
||||||
} else if ((err == SSL_ERROR_ZERO_RETURN) &&
|
} else if ((err == SSL_ERROR_ZERO_RETURN) &&
|
||||||
(SSL_get_shutdown(self->ssl) ==
|
(SSL_get_shutdown(self->ssl) ==
|
||||||
SSL_RECEIVED_SHUTDOWN))
|
SSL_RECEIVED_SHUTDOWN))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue