update to fix leak in SSL code

This commit is contained in:
Bill Janssen 2007-12-14 22:08:56 +00:00
parent 517b9ddda2
commit 54cc54c1fe
4 changed files with 225 additions and 68 deletions

View file

@ -80,6 +80,7 @@ from socket import getnameinfo as _getnameinfo
from socket import error as socket_error
from socket import dup as _dup
import base64 # for DER-to-PEM translation
import traceback
class SSLSocket(socket):
@ -94,16 +95,13 @@ class SSLSocket(socket):
family=AF_INET, type=SOCK_STREAM, proto=0, fileno=None,
suppress_ragged_eofs=True):
self._base = None
if sock is not None:
# copied this code from socket.accept()
fd = sock.fileno()
nfd = _dup(fd)
socket.__init__(self, family=sock.family, type=sock.type,
proto=sock.proto, fileno=nfd)
socket.__init__(self,
family=sock.family,
type=sock.type,
proto=sock.proto,
fileno=_dup(sock.fileno()))
sock.close()
sock = None
elif fileno is not None:
socket.__init__(self, fileno=fileno)
else:
@ -136,10 +134,6 @@ class SSLSocket(socket):
self.close()
raise x
if sock and (self.fileno() != sock.fileno()):
self._base = sock
else:
self._base = None
self.keyfile = keyfile
self.certfile = certfile
self.cert_reqs = cert_reqs
@ -156,19 +150,23 @@ class SSLSocket(socket):
# raise an exception here if you wish to check for spurious closes
pass
def read(self, len=None, buffer=None):
def read(self, len=0, buffer=None):
"""Read up to LEN bytes and return them.
Return zero-length string on EOF."""
self._checkClosed()
try:
if buffer:
return self._sslobj.read(buffer, len)
v = self._sslobj.read(buffer, len)
else:
return self._sslobj.read(len or 1024)
v = self._sslobj.read(len or 1024)
return v
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
return b''
if buffer:
return 0
else:
return b''
else:
raise
@ -269,7 +267,6 @@ class SSLSocket(socket):
while True:
try:
v = self.read(nbytes, buffer)
sys.stdout.flush()
return v
except SSLError as x:
if x.args[0] == SSL_ERROR_WANT_READ:
@ -302,9 +299,7 @@ class SSLSocket(socket):
def _real_close(self):
self._sslobj = None
# self._closed = True
if self._base:
self._base.close()
socket.close(self)
socket._real_close(self)
def do_handshake(self, block=False):
"""Perform a TLS/SSL handshake."""
@ -329,8 +324,12 @@ class SSLSocket(socket):
self._sslobj = _ssl.sslwrap(self, False, self.keyfile, self.certfile,
self.cert_reqs, self.ssl_version,
self.ca_certs)
if self.do_handshake_on_connect:
self.do_handshake()
try:
if self.do_handshake_on_connect:
self.do_handshake()
except:
self._sslobj = None
raise
def accept(self):
"""Accepts a new connection from a remote client, and returns
@ -348,10 +347,11 @@ class SSLSocket(socket):
self.do_handshake_on_connect),
addr)
def __del__(self):
# sys.stderr.write("__del__ on %s\n" % repr(self))
self._real_close()
def wrap_socket(sock, keyfile=None, certfile=None,
server_side=False, cert_reqs=CERT_NONE,
ssl_version=PROTOCOL_SSLv23, ca_certs=None,