fix issue #17552: add socket.sendfile() method allowing to send a file over a socket by using high-performance os.sendfile() on UNIX. Patch by Giampaolo Rodola'·

This commit is contained in:
Giampaolo Rodola' 2014-06-11 03:54:30 +02:00
parent b398d33c65
commit 915d14190e
9 changed files with 483 additions and 2 deletions

View file

@ -19,6 +19,8 @@ import signal
import math
import pickle
import struct
import random
import string
try:
import multiprocessing
except ImportError:
@ -5077,6 +5079,275 @@ class TestSocketSharing(SocketTCPTest):
source.close()
@unittest.skipUnless(thread, 'Threading required for this test.')
class SendfileUsingSendTest(ThreadedTCPSocketTest):
"""
Test the send() implementation of socket.sendfile().
"""
FILESIZE = (10 * 1024 * 1024) # 10MB
BUFSIZE = 8192
FILEDATA = b""
TIMEOUT = 2
@classmethod
def setUpClass(cls):
def chunks(total, step):
assert total >= step
while total > step:
yield step
total -= step
if total:
yield total
chunk = b"".join([random.choice(string.ascii_letters).encode()
for i in range(cls.BUFSIZE)])
with open(support.TESTFN, 'wb') as f:
for csize in chunks(cls.FILESIZE, cls.BUFSIZE):
f.write(chunk)
with open(support.TESTFN, 'rb') as f:
cls.FILEDATA = f.read()
assert len(cls.FILEDATA) == cls.FILESIZE
@classmethod
def tearDownClass(cls):
support.unlink(support.TESTFN)
def accept_conn(self):
self.serv.settimeout(self.TIMEOUT)
conn, addr = self.serv.accept()
conn.settimeout(self.TIMEOUT)
self.addCleanup(conn.close)
return conn
def recv_data(self, conn):
received = []
while True:
chunk = conn.recv(self.BUFSIZE)
if not chunk:
break
received.append(chunk)
return b''.join(received)
def meth_from_sock(self, sock):
# Depending on the mixin class being run return either send()
# or sendfile() method implementation.
return getattr(sock, "_sendfile_use_send")
# regular file
def _testRegularFile(self):
address = self.serv.getsockname()
file = open(support.TESTFN, 'rb')
with socket.create_connection(address) as sock, file as file:
meth = self.meth_from_sock(sock)
sent = meth(file)
self.assertEqual(sent, self.FILESIZE)
self.assertEqual(file.tell(), self.FILESIZE)
def testRegularFile(self):
conn = self.accept_conn()
data = self.recv_data(conn)
self.assertEqual(len(data), self.FILESIZE)
self.assertEqual(data, self.FILEDATA)
# non regular file
def _testNonRegularFile(self):
address = self.serv.getsockname()
file = io.BytesIO(self.FILEDATA)
with socket.create_connection(address) as sock, file as file:
sent = sock.sendfile(file)
self.assertEqual(sent, self.FILESIZE)
self.assertEqual(file.tell(), self.FILESIZE)
self.assertRaises(socket._GiveupOnSendfile,
sock._sendfile_use_sendfile, file)
def testNonRegularFile(self):
conn = self.accept_conn()
data = self.recv_data(conn)
self.assertEqual(len(data), self.FILESIZE)
self.assertEqual(data, self.FILEDATA)
# empty file
def _testEmptyFileSend(self):
address = self.serv.getsockname()
filename = support.TESTFN + "2"
with open(filename, 'wb'):
self.addCleanup(support.unlink, filename)
file = open(filename, 'rb')
with socket.create_connection(address) as sock, file as file:
meth = self.meth_from_sock(sock)
sent = meth(file)
self.assertEqual(sent, 0)
self.assertEqual(file.tell(), 0)
def testEmptyFileSend(self):
conn = self.accept_conn()
data = self.recv_data(conn)
self.assertEqual(data, b"")
# offset
def _testOffset(self):
address = self.serv.getsockname()
file = open(support.TESTFN, 'rb')
with socket.create_connection(address) as sock, file as file:
meth = self.meth_from_sock(sock)
sent = meth(file, offset=5000)
self.assertEqual(sent, self.FILESIZE - 5000)
self.assertEqual(file.tell(), self.FILESIZE)
def testOffset(self):
conn = self.accept_conn()
data = self.recv_data(conn)
self.assertEqual(len(data), self.FILESIZE - 5000)
self.assertEqual(data, self.FILEDATA[5000:])
# count
def _testCount(self):
address = self.serv.getsockname()
file = open(support.TESTFN, 'rb')
with socket.create_connection(address, timeout=2) as sock, file as file:
count = 5000007
meth = self.meth_from_sock(sock)
sent = meth(file, count=count)
self.assertEqual(sent, count)
self.assertEqual(file.tell(), count)
def testCount(self):
count = 5000007
conn = self.accept_conn()
data = self.recv_data(conn)
self.assertEqual(len(data), count)
self.assertEqual(data, self.FILEDATA[:count])
# count small
def _testCountSmall(self):
address = self.serv.getsockname()
file = open(support.TESTFN, 'rb')
with socket.create_connection(address, timeout=2) as sock, file as file:
count = 1
meth = self.meth_from_sock(sock)
sent = meth(file, count=count)
self.assertEqual(sent, count)
self.assertEqual(file.tell(), count)
def testCountSmall(self):
count = 1
conn = self.accept_conn()
data = self.recv_data(conn)
self.assertEqual(len(data), count)
self.assertEqual(data, self.FILEDATA[:count])
# count + offset
def _testCountWithOffset(self):
address = self.serv.getsockname()
file = open(support.TESTFN, 'rb')
with socket.create_connection(address, timeout=2) as sock, file as file:
count = 100007
meth = self.meth_from_sock(sock)
sent = meth(file, offset=2007, count=count)
self.assertEqual(sent, count)
self.assertEqual(file.tell(), count + 2007)
def testCountWithOffset(self):
count = 100007
conn = self.accept_conn()
data = self.recv_data(conn)
self.assertEqual(len(data), count)
self.assertEqual(data, self.FILEDATA[2007:count+2007])
# non blocking sockets are not supposed to work
def _testNonBlocking(self):
address = self.serv.getsockname()
file = open(support.TESTFN, 'rb')
with socket.create_connection(address) as sock, file as file:
sock.setblocking(False)
meth = self.meth_from_sock(sock)
self.assertRaises(ValueError, meth, file)
self.assertRaises(ValueError, sock.sendfile, file)
def testNonBlocking(self):
conn = self.accept_conn()
if conn.recv(8192):
self.fail('was not supposed to receive any data')
# timeout (non-triggered)
def _testWithTimeout(self):
address = self.serv.getsockname()
file = open(support.TESTFN, 'rb')
with socket.create_connection(address, timeout=2) as sock, file as file:
meth = self.meth_from_sock(sock)
sent = meth(file)
self.assertEqual(sent, self.FILESIZE)
def testWithTimeout(self):
conn = self.accept_conn()
data = self.recv_data(conn)
self.assertEqual(len(data), self.FILESIZE)
self.assertEqual(data, self.FILEDATA)
# timeout (triggered)
def _testWithTimeoutTriggeredSend(self):
address = self.serv.getsockname()
file = open(support.TESTFN, 'rb')
with socket.create_connection(address, timeout=0.01) as sock, \
file as file:
meth = self.meth_from_sock(sock)
self.assertRaises(socket.timeout, meth, file)
def testWithTimeoutTriggeredSend(self):
conn = self.accept_conn()
conn.recv(88192)
# errors
def _test_errors(self):
pass
def test_errors(self):
with open(support.TESTFN, 'rb') as file:
with socket.socket(type=socket.SOCK_DGRAM) as s:
meth = self.meth_from_sock(s)
self.assertRaisesRegex(
ValueError, "SOCK_STREAM", meth, file)
with open(support.TESTFN, 'rt') as file:
with socket.socket() as s:
meth = self.meth_from_sock(s)
self.assertRaisesRegex(
ValueError, "binary mode", meth, file)
with open(support.TESTFN, 'rb') as file:
with socket.socket() as s:
meth = self.meth_from_sock(s)
self.assertRaisesRegex(TypeError, "positive integer",
meth, file, count='2')
self.assertRaisesRegex(TypeError, "positive integer",
meth, file, count=0.1)
self.assertRaisesRegex(ValueError, "positive integer",
meth, file, count=0)
self.assertRaisesRegex(ValueError, "positive integer",
meth, file, count=-1)
@unittest.skipUnless(thread, 'Threading required for this test.')
@unittest.skipUnless(hasattr(os, "sendfile"),
'os.sendfile() required for this test.')
class SendfileUsingSendfileTest(SendfileUsingSendTest):
"""
Test the sendfile() implementation of socket.sendfile().
"""
def meth_from_sock(self, sock):
return getattr(sock, "_sendfile_use_sendfile")
def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest, UDPTimeoutTest ]
@ -5129,6 +5400,8 @@ def test_main():
InterruptedRecvTimeoutTest,
InterruptedSendTimeoutTest,
TestSocketSharing,
SendfileUsingSendTest,
SendfileUsingSendfileTest,
])
thread_info = support.threading_setup()