mirror of
https://github.com/python/cpython.git
synced 2025-08-25 03:04:55 +00:00
Sockets facelift. APIs that could return binary data (e.g. aton() and
recv()) now return bytes, not str or str8. The socket.py code is redone; it now subclasses _socket.socket and instead of having its own _fileobject for makefile(), it uses io.SocketIO. Some stuff in io.py was moved around to make this work. (I really need to rethink my policy regarding readline() and read(-1) on raw files; and readline() on buffered files ought to use peeking(). Later.)
This commit is contained in:
parent
88effc1251
commit
7d0a8264ff
5 changed files with 200 additions and 516 deletions
395
Lib/socket.py
395
Lib/socket.py
|
@ -54,7 +54,7 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
import os, sys
|
||||
import os, sys, io
|
||||
|
||||
try:
|
||||
from errno import EBADF
|
||||
|
@ -66,14 +66,6 @@ __all__.extend(os._get_exports_list(_socket))
|
|||
if _have_ssl:
|
||||
__all__.extend(os._get_exports_list(_ssl))
|
||||
|
||||
_realsocket = socket
|
||||
if _have_ssl:
|
||||
_realssl = ssl
|
||||
def ssl(sock, keyfile=None, certfile=None):
|
||||
if hasattr(sock, "_sock"):
|
||||
sock = sock._sock
|
||||
return _realssl(sock, keyfile, certfile)
|
||||
|
||||
# WSA error codes
|
||||
if sys.platform.lower().startswith("win"):
|
||||
errorTab = {}
|
||||
|
@ -95,6 +87,99 @@ if sys.platform.lower().startswith("win"):
|
|||
__all__.append("errorTab")
|
||||
|
||||
|
||||
_os_has_dup = hasattr(os, "dup")
|
||||
if _os_has_dup:
|
||||
def fromfd(fd, family=AF_INET, type=SOCK_STREAM, proto=0):
|
||||
nfd = os.dup(fd)
|
||||
return socket(family, type, proto, fileno=nfd)
|
||||
|
||||
|
||||
class socket(_socket.socket):
|
||||
|
||||
"""A subclass of _socket.socket adding the makefile() method."""
|
||||
|
||||
__slots__ = ["__weakref__"]
|
||||
if not _os_has_dup:
|
||||
__slots__.append("_base")
|
||||
|
||||
def __repr__(self):
|
||||
"""Wrap __repr__() to reveal the real class name."""
|
||||
s = _socket.socket.__repr__(self)
|
||||
if s.startswith("<socket object"):
|
||||
s = "<%s.%s%s" % (self.__class__.__module__,
|
||||
self.__class__.__name__,
|
||||
s[7:])
|
||||
return s
|
||||
|
||||
def accept(self):
|
||||
"""Wrap accept() to give the connection the right type."""
|
||||
conn, addr = _socket.socket.accept(self)
|
||||
fd = conn.fileno()
|
||||
nfd = fd
|
||||
if _os_has_dup:
|
||||
nfd = os.dup(fd)
|
||||
wrapper = socket(self.family, self.type, self.proto, fileno=nfd)
|
||||
if fd == nfd:
|
||||
wrapper._base = conn # Keep the base alive
|
||||
else:
|
||||
conn.close()
|
||||
return wrapper, addr
|
||||
|
||||
if not _os_has_dup:
|
||||
def close(self):
|
||||
"""Wrap close() to close the _base as well."""
|
||||
_socket.socket.close(self)
|
||||
base = getattr(self, "_base", None)
|
||||
if base is not None:
|
||||
base.close()
|
||||
|
||||
def makefile(self, mode="r", buffering=None, *,
|
||||
encoding=None, newline=None):
|
||||
"""Return an I/O stream connected to the socket.
|
||||
|
||||
The arguments are as for io.open() after the filename,
|
||||
except the only mode characters supported are 'r', 'w' and 'b'.
|
||||
The semantics are similar too. (XXX refactor to share code?)
|
||||
"""
|
||||
for c in mode:
|
||||
if c not in {"r", "w", "b"}:
|
||||
raise ValueError("invalid mode %r (only r, w, b allowed)")
|
||||
writing = "w" in mode
|
||||
reading = "r" in mode or not writing
|
||||
assert reading or writing
|
||||
binary = "b" in mode
|
||||
rawmode = ""
|
||||
if reading:
|
||||
rawmode += "r"
|
||||
if writing:
|
||||
rawmode += "w"
|
||||
raw = io.SocketIO(self, rawmode)
|
||||
if buffering is None:
|
||||
buffering = -1
|
||||
if buffering < 0:
|
||||
buffering = io.DEFAULT_BUFFER_SIZE
|
||||
if buffering == 0:
|
||||
if not binary:
|
||||
raise ValueError("unbuffered streams must be binary")
|
||||
raw.name = self.fileno()
|
||||
raw.mode = mode
|
||||
return raw
|
||||
if reading and writing:
|
||||
buffer = io.BufferedRWPair(raw, raw, buffering)
|
||||
elif reading:
|
||||
buffer = io.BufferedReader(raw, buffering)
|
||||
else:
|
||||
assert writing
|
||||
buffer = io.BufferedWriter(raw, buffering)
|
||||
if binary:
|
||||
buffer.name = self.fileno()
|
||||
buffer.mode = mode
|
||||
return buffer
|
||||
text = io.TextIOWrapper(buffer, encoding, newline)
|
||||
text.name = self.fileno()
|
||||
self.mode = mode
|
||||
return text
|
||||
|
||||
|
||||
def getfqdn(name=''):
|
||||
"""Get fully qualified domain name from name.
|
||||
|
@ -122,298 +207,6 @@ def getfqdn(name=''):
|
|||
return name
|
||||
|
||||
|
||||
_socketmethods = (
|
||||
'bind', 'connect', 'connect_ex', 'fileno', 'listen',
|
||||
'getpeername', 'getsockname', 'getsockopt', 'setsockopt',
|
||||
'sendall', 'setblocking',
|
||||
'settimeout', 'gettimeout', 'shutdown')
|
||||
|
||||
if sys.platform == "riscos":
|
||||
_socketmethods = _socketmethods + ('sleeptaskw',)
|
||||
|
||||
# All the method names that must be delegated to either the real socket
|
||||
# object or the _closedsocket object.
|
||||
_delegate_methods = ("recv", "recvfrom", "recv_into", "recvfrom_into",
|
||||
"send", "sendto")
|
||||
|
||||
class _closedsocket(object):
|
||||
__slots__ = []
|
||||
def _dummy(*args):
|
||||
raise error(EBADF, 'Bad file descriptor')
|
||||
# All _delegate_methods must also be initialized here.
|
||||
send = recv = recv_into = sendto = recvfrom = recvfrom_into = _dummy
|
||||
__getattr__ = _dummy
|
||||
|
||||
class _socketobject(object):
|
||||
|
||||
__doc__ = _realsocket.__doc__
|
||||
|
||||
__slots__ = ["_sock", "__weakref__"] + list(_delegate_methods)
|
||||
|
||||
def __init__(self, family=AF_INET, type=SOCK_STREAM, proto=0, _sock=None):
|
||||
if _sock is None:
|
||||
_sock = _realsocket(family, type, proto)
|
||||
self._sock = _sock
|
||||
for method in _delegate_methods:
|
||||
setattr(self, method, getattr(_sock, method))
|
||||
|
||||
def close(self):
|
||||
self._sock = _closedsocket()
|
||||
dummy = self._sock._dummy
|
||||
for method in _delegate_methods:
|
||||
setattr(self, method, dummy)
|
||||
close.__doc__ = _realsocket.close.__doc__
|
||||
|
||||
def accept(self):
|
||||
sock, addr = self._sock.accept()
|
||||
return _socketobject(_sock=sock), addr
|
||||
accept.__doc__ = _realsocket.accept.__doc__
|
||||
|
||||
def dup(self):
|
||||
"""dup() -> socket object
|
||||
|
||||
Return a new socket object connected to the same system resource."""
|
||||
return _socketobject(_sock=self._sock)
|
||||
|
||||
def makefile(self, mode='r', bufsize=-1):
|
||||
"""makefile([mode[, bufsize]]) -> file object
|
||||
|
||||
Return a regular file object corresponding to the socket. The mode
|
||||
and bufsize arguments are as for the built-in open() function."""
|
||||
return _fileobject(self._sock, mode, bufsize)
|
||||
|
||||
family = property(lambda self: self._sock.family, doc="the socket family")
|
||||
type = property(lambda self: self._sock.type, doc="the socket type")
|
||||
proto = property(lambda self: self._sock.proto, doc="the socket protocol")
|
||||
|
||||
_s = ("def %s(self, *args): return self._sock.%s(*args)\n\n"
|
||||
"%s.__doc__ = _realsocket.%s.__doc__\n")
|
||||
for _m in _socketmethods:
|
||||
exec(_s % (_m, _m, _m, _m))
|
||||
del _m, _s
|
||||
|
||||
socket = SocketType = _socketobject
|
||||
|
||||
class _fileobject(object):
|
||||
"""Faux file object attached to a socket object."""
|
||||
|
||||
default_bufsize = 8192
|
||||
name = "<socket>"
|
||||
|
||||
__slots__ = ["mode", "bufsize",
|
||||
# "closed" is a property, see below
|
||||
"_sock", "_rbufsize", "_wbufsize", "_rbuf", "_wbuf",
|
||||
"_close"]
|
||||
|
||||
def __init__(self, sock, mode='rb', bufsize=-1, close=False):
|
||||
self._sock = sock
|
||||
self.mode = mode # Not actually used in this version
|
||||
if bufsize < 0:
|
||||
bufsize = self.default_bufsize
|
||||
self.bufsize = bufsize
|
||||
if bufsize == 0:
|
||||
self._rbufsize = 1
|
||||
elif bufsize == 1:
|
||||
self._rbufsize = self.default_bufsize
|
||||
else:
|
||||
self._rbufsize = bufsize
|
||||
self._wbufsize = bufsize
|
||||
self._rbuf = "" # A string
|
||||
self._wbuf = [] # A list of strings
|
||||
self._close = close
|
||||
|
||||
def _getclosed(self):
|
||||
return self._sock is None
|
||||
closed = property(_getclosed, doc="True if the file is closed")
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
if self._sock:
|
||||
self.flush()
|
||||
finally:
|
||||
if self._close:
|
||||
self._sock.close()
|
||||
self._sock = None
|
||||
|
||||
def __del__(self):
|
||||
try:
|
||||
self.close()
|
||||
except:
|
||||
# close() may fail if __init__ didn't complete
|
||||
pass
|
||||
|
||||
def flush(self):
|
||||
if self._wbuf:
|
||||
buffer = "".join(self._wbuf)
|
||||
self._wbuf = []
|
||||
self._sock.sendall(buffer)
|
||||
|
||||
def fileno(self):
|
||||
return self._sock.fileno()
|
||||
|
||||
def write(self, data):
|
||||
data = str(data) # XXX Should really reject non-string non-buffers
|
||||
if not data:
|
||||
return
|
||||
self._wbuf.append(data)
|
||||
if (self._wbufsize == 0 or
|
||||
self._wbufsize == 1 and '\n' in data or
|
||||
self._get_wbuf_len() >= self._wbufsize):
|
||||
self.flush()
|
||||
|
||||
def writelines(self, list):
|
||||
# XXX We could do better here for very long lists
|
||||
# XXX Should really reject non-string non-buffers
|
||||
self._wbuf.extend(filter(None, map(str, list)))
|
||||
if (self._wbufsize <= 1 or
|
||||
self._get_wbuf_len() >= self._wbufsize):
|
||||
self.flush()
|
||||
|
||||
def _get_wbuf_len(self):
|
||||
buf_len = 0
|
||||
for x in self._wbuf:
|
||||
buf_len += len(x)
|
||||
return buf_len
|
||||
|
||||
def read(self, size=-1):
|
||||
data = self._rbuf
|
||||
if size < 0:
|
||||
# Read until EOF
|
||||
buffers = []
|
||||
if data:
|
||||
buffers.append(data)
|
||||
self._rbuf = ""
|
||||
if self._rbufsize <= 1:
|
||||
recv_size = self.default_bufsize
|
||||
else:
|
||||
recv_size = self._rbufsize
|
||||
while True:
|
||||
data = self._sock.recv(recv_size)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
return "".join(buffers)
|
||||
else:
|
||||
# Read until size bytes or EOF seen, whichever comes first
|
||||
buf_len = len(data)
|
||||
if buf_len >= size:
|
||||
self._rbuf = data[size:]
|
||||
return data[:size]
|
||||
buffers = []
|
||||
if data:
|
||||
buffers.append(data)
|
||||
self._rbuf = ""
|
||||
while True:
|
||||
left = size - buf_len
|
||||
recv_size = max(self._rbufsize, left)
|
||||
data = self._sock.recv(recv_size)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
n = len(data)
|
||||
if n >= left:
|
||||
self._rbuf = data[left:]
|
||||
buffers[-1] = data[:left]
|
||||
break
|
||||
buf_len += n
|
||||
return "".join(buffers)
|
||||
|
||||
def readline(self, size=-1):
|
||||
data = self._rbuf
|
||||
if size < 0:
|
||||
# Read until \n or EOF, whichever comes first
|
||||
if self._rbufsize <= 1:
|
||||
# Speed up unbuffered case
|
||||
assert data == ""
|
||||
buffers = []
|
||||
recv = self._sock.recv
|
||||
while data != "\n":
|
||||
data = recv(1)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
return "".join(buffers)
|
||||
nl = data.find('\n')
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
self._rbuf = data[nl:]
|
||||
return data[:nl]
|
||||
buffers = []
|
||||
if data:
|
||||
buffers.append(data)
|
||||
self._rbuf = ""
|
||||
while True:
|
||||
data = self._sock.recv(self._rbufsize)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
nl = data.find('\n')
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
self._rbuf = data[nl:]
|
||||
buffers[-1] = data[:nl]
|
||||
break
|
||||
return "".join(buffers)
|
||||
else:
|
||||
# Read until size bytes or \n or EOF seen, whichever comes first
|
||||
nl = data.find('\n', 0, size)
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
self._rbuf = data[nl:]
|
||||
return data[:nl]
|
||||
buf_len = len(data)
|
||||
if buf_len >= size:
|
||||
self._rbuf = data[size:]
|
||||
return data[:size]
|
||||
buffers = []
|
||||
if data:
|
||||
buffers.append(data)
|
||||
self._rbuf = ""
|
||||
while True:
|
||||
data = self._sock.recv(self._rbufsize)
|
||||
if not data:
|
||||
break
|
||||
buffers.append(data)
|
||||
left = size - buf_len
|
||||
nl = data.find('\n', 0, left)
|
||||
if nl >= 0:
|
||||
nl += 1
|
||||
self._rbuf = data[nl:]
|
||||
buffers[-1] = data[:nl]
|
||||
break
|
||||
n = len(data)
|
||||
if n >= left:
|
||||
self._rbuf = data[left:]
|
||||
buffers[-1] = data[:left]
|
||||
break
|
||||
buf_len += n
|
||||
return "".join(buffers)
|
||||
|
||||
def readlines(self, sizehint=0):
|
||||
total = 0
|
||||
list = []
|
||||
while True:
|
||||
line = self.readline()
|
||||
if not line:
|
||||
break
|
||||
list.append(line)
|
||||
total += len(line)
|
||||
if sizehint and total >= sizehint:
|
||||
break
|
||||
return list
|
||||
|
||||
# Iterator protocols
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
line = self.readline()
|
||||
if not line:
|
||||
raise StopIteration
|
||||
return line
|
||||
|
||||
|
||||
def create_connection(address, timeout=None):
|
||||
"""Connect to address (host, port) with an optional timeout.
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue