This commit is contained in:
GalaxySnail 2025-12-23 17:54:23 +09:00 committed by GitHub
commit 3a7ed76312
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 125 additions and 21 deletions

View file

@ -572,7 +572,8 @@ if hasattr(_socket.socket, "sendmsg"):
import array
return sock.sendmsg(buffers, [(_socket.SOL_SOCKET,
_socket.SCM_RIGHTS, array.array("i", fds))])
_socket.SCM_RIGHTS, array.array("i", fds))],
flags, address)
__all__.append("send_fds")
if hasattr(_socket.socket, "recvmsg"):
@ -587,14 +588,14 @@ if hasattr(_socket.socket, "recvmsg"):
# Array of ints
fds = array.array("i")
msg, ancdata, flags, addr = sock.recvmsg(bufsize,
_socket.CMSG_LEN(maxfds * fds.itemsize))
msg, ancdata, msg_flags, addr = sock.recvmsg(bufsize,
_socket.CMSG_LEN(maxfds * fds.itemsize), flags)
for cmsg_level, cmsg_type, cmsg_data in ancdata:
if (cmsg_level == _socket.SOL_SOCKET and cmsg_type == _socket.SCM_RIGHTS):
fds.frombytes(cmsg_data[:
len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])
return msg, list(fds), flags, addr
return msg, list(fds), msg_flags, addr
__all__.append("recv_fds")
if hasattr(_socket.socket, "share"):

View file

@ -3605,10 +3605,11 @@ class SendmsgStreamTests(SendmsgTests):
# XXX: would be nice to have more tests for sendmsg flags argument.
# Linux supports MSG_DONTWAIT when sending, but in general, it
# only works when receiving. Could add other platforms if they
# Linux and FreeBSD support MSG_DONTWAIT when sending, but in general,
# it only works when receiving. Could add other platforms if they
# support it too.
@skipWithClientIf(sys.platform not in {"linux", "android"},
@requireAttrs(socket, "MSG_DONTWAIT")
@skipWithClientIf(sys.platform in ("darwin",),
"MSG_DONTWAIT not known to work on this platform when "
"sending")
def testSendmsgDontWait(self):
@ -7433,19 +7434,34 @@ class CreateServerFunctionalTest(unittest.TestCase):
@requireAttrs(socket, "recv_fds")
@requireAttrs(socket, "AF_UNIX")
class SendRecvFdsTests(unittest.TestCase):
def testSendAndRecvFds(self):
def close_pipes(pipes):
for fd1, fd2 in pipes:
os.close(fd1)
os.close(fd2)
def _cleanup_fds(self, fds):
def close_fds(fds):
for fd in fds:
os.close(fd)
self.addCleanup(close_fds, fds)
def _test_pipe(self, rfd, wfd, msg):
# POSIX requires PIPE_BUF to be at least 512 bytes.
PIPE_BUF = 512
assert len(msg) < PIPE_BUF
os.write(wfd, msg)
data = os.read(rfd, PIPE_BUF)
self.assertEqual(data, msg)
@staticmethod
def _recv_one_fd(sock, bufsize, flags=0):
if sys.platform.startswith("freebsd"):
# FreeBSD requires at least CMSG_LEN(2*sizeof(int)),
# otherwise the access control message is truncated.
max_fds = 2
else:
max_fds = 1
return socket.recv_fds(sock, bufsize, max_fds, flags)
def test_send_and_recv_fds(self):
# send 10 file descriptors
pipes = [os.pipe() for _ in range(10)]
self.addCleanup(close_pipes, pipes)
self._cleanup_fds(fd for pair in pipes for fd in pair)
fds = [rfd for rfd, wfd in pipes]
# use a UNIX socket pair to exchange file descriptors locally
@ -7454,7 +7470,7 @@ class SendRecvFdsTests(unittest.TestCase):
socket.send_fds(sock1, [MSG], fds)
# request more data and file descriptors than expected
msg, fds2, flags, addr = socket.recv_fds(sock2, len(MSG) * 2, len(fds) * 2)
self.addCleanup(close_fds, fds2)
self._cleanup_fds(fds2)
self.assertEqual(msg, MSG)
self.assertEqual(len(fds2), len(fds))
@ -7462,13 +7478,98 @@ class SendRecvFdsTests(unittest.TestCase):
# don't test addr
# test that file descriptors are connected
for index, fds in enumerate(pipes):
rfd, wfd = fds
os.write(wfd, str(index).encode())
for index, ((_, wfd), rfd) in enumerate(zip(pipes, fds2, strict=True)):
self._test_pipe(rfd, wfd, str(index).encode())
for index, rfd in enumerate(fds2):
data = os.read(rfd, 100)
self.assertEqual(data, str(index).encode())
def test_send_recv_fds_with_addrs(self):
rfd, wfd = os.pipe()
self.addCleanup(os.close, rfd)
self.addCleanup(os.close, wfd)
with tempfile.TemporaryDirectory() as tmpdir, \
socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock1, \
socket.socket(socket.AF_UNIX, socket.SOCK_DGRAM) as sock2:
sock1_addr = os.path.join(tmpdir, "sock1")
sock2_addr = os.path.join(tmpdir, "sock2")
sock1.bind(sock1_addr)
sock2.bind(sock2_addr)
sock2.setblocking(False)
socket.send_fds(sock1, [MSG], [rfd], address=sock2_addr)
msg, fds, flags, addr = self._recv_one_fd(sock2, len(MSG))
self._cleanup_fds(fds)
self.assertEqual(msg, MSG)
if hasattr(socket, "MSG_CTRUNC"):
self.assertEqual(flags & socket.MSG_CTRUNC, 0)
self.assertEqual(len(fds), 1)
self.assertEqual(addr, sock1_addr)
self._test_pipe(fds[0], wfd, MSG)
@requireAttrs(socket, "MSG_PEEK")
@unittest.skipUnless(sys.platform in ("linux", "android"), "works on Linux")
def test_recv_fds_peek(self):
rfd, wfd = os.pipe()
self.addCleanup(os.close, rfd)
self.addCleanup(os.close, wfd)
sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_DGRAM)
with sock1, sock2:
socket.send_fds(sock1, [MSG], [rfd])
sock2.setblocking(False)
# peek message on sock2
peek_len = len(MSG) // 2
self.assertGreater(peek_len, 0)
msg, fds, flags, addr = socket.recv_fds(sock2, peek_len, 1,
flags=socket.MSG_PEEK)
self._cleanup_fds(fds)
self.assertEqual(len(msg), peek_len)
self.assertEqual(msg, MSG[:peek_len])
self.assertEqual(flags & socket.MSG_TRUNC, socket.MSG_TRUNC)
if hasattr(socket, "MSG_CTRUNC"):
self.assertEqual(flags & socket.MSG_CTRUNC, 0)
self.assertEqual(len(fds), 1)
self._test_pipe(fds[0], wfd, MSG)
# will raise BlockingIOError if MSG_PEEK didn't work
msg, fds, flags, addr = socket.recv_fds(sock2, len(MSG), 1)
self._cleanup_fds(fds)
self.assertEqual(msg, MSG)
if hasattr(socket, "MSG_CTRUNC"):
self.assertEqual(flags & socket.MSG_CTRUNC, 0)
self.assertEqual(len(fds), 1)
self._test_pipe(fds[0], wfd, MSG)
@requireAttrs(socket, "MSG_DONTWAIT")
@unittest.skipIf(sys.platform in ("darwin",),
"MSG_DONTWAIT not known to work on this platform when "
"sending")
def test_send_fds_dontwait(self):
rfd, wfd = os.pipe()
self.addCleanup(os.close, rfd)
self.addCleanup(os.close, wfd)
# use SOCK_STREAM instead of SOCK_DGRAM to support *BSD platforms
# ref: https://docs.python.org/3/library/asyncio-protocol.html#datagram-protocols
sock1, sock2 = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
with sock1, sock2:
sock1.setblocking(True)
with self.assertRaises(BlockingIOError):
for _ in range(64 * 1024):
socket.send_fds(sock1, [MSG], [rfd], socket.MSG_DONTWAIT)
msg, fds, flags, addr = self._recv_one_fd(sock2, len(MSG))
self._cleanup_fds(fds)
self.assertEqual(msg, MSG)
if hasattr(socket, "MSG_CTRUNC"):
self.assertEqual(flags & socket.MSG_CTRUNC, 0)
self.assertEqual(len(fds), 1)
self._test_pipe(fds[0], wfd, MSG)
class FreeThreadingTests(unittest.TestCase):

View file

@ -0,0 +1,2 @@
Fix ``flags`` and ``address`` parameters which were ignored in
:func:`socket.send_fds` and :func:`socket.recv_fds`.