bpo-28724: Add methods send_fds and recv_fds to the socket module (GH-12889)

The socket module now has the socket.send_fds() and socket.recv.fds() functions.
Contributed by Joannah Nanjekye, Shinya Okano (original patch)
and Victor Stinner.

Co-Authored-By: Victor Stinner <vstinner@redhat.com>
This commit is contained in:
Joannah Nanjekye 2019-09-11 18:12:21 +01:00 committed by Victor Stinner
parent 58ab13479d
commit 8d120f75fb
5 changed files with 105 additions and 1 deletions

45
Lib/test/test_socket.py Normal file → Executable file
View file

@ -6436,11 +6436,53 @@ class CreateServerFunctionalTest(unittest.TestCase):
self.echo_server(sock)
self.echo_client(("::1", port), socket.AF_INET6)
@requireAttrs(socket, "send_fds")
@requireAttrs(socket, "recv_fds")
@requireAttrs(socket, "AF_UNIX")
class SendRecvFdsTests(unittest.TestCase):
def testSendAndRecvFds(self):
def close_pipes(pipes):
for fd1, fd2 in pipes:
os.close(fd1)
os.close(fd2)
def close_fds(fds):
for fd in fds:
os.close(fd)
# send 10 file descriptors
pipes = [os.pipe() for _ in range(10)]
self.addCleanup(close_pipes, pipes)
fds = [rfd for rfd, wfd in pipes]
# use a UNIX socket pair to exchange file descriptors locally
sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
with sock1, sock2:
socket.send_fds(sock1, [MSG], fds)
# request more data and file descriptors than expected
msg, fds2, flags, addr = socket.recv_fds(sock2, len(MSG) * 2, len(fds) * 2)
self.addCleanup(close_fds, fds2)
self.assertEqual(msg, MSG)
self.assertEqual(len(fds2), len(fds))
self.assertEqual(flags, 0)
# don't test addr
# test that file descriptors are connected
for index, fds in enumerate(pipes):
rfd, wfd = fds
os.write(wfd, str(index).encode())
for index, rfd in enumerate(fds2):
data = os.read(rfd, 100)
self.assertEqual(data, str(index).encode())
def test_main():
tests = [GeneralModuleTests, BasicTCPTest, TCPCloserTest, TCPTimeoutTest,
TestExceptions, BufferIOTest, BasicTCPTest2, BasicUDPTest,
UDPTimeoutTest, CreateServerTest, CreateServerFunctionalTest]
UDPTimeoutTest, CreateServerTest, CreateServerFunctionalTest,
SendRecvFdsTests]
tests.extend([
NonBlockingTCPTests,
@ -6513,5 +6555,6 @@ def test_main():
support.run_unittest(*tests)
support.threading_cleanup(*thread_info)
if __name__ == "__main__":
test_main()