gh-93852: Add test.support.create_unix_domain_name() (#93914)

test_asyncio, test_logging, test_socket and test_socketserver now
create AF_UNIX domains in the current directory to no longer fail
with OSError("AF_UNIX path too long") if the temporary directory (the
TMPDIR environment variable) is too long.

Modify the following tests to use create_unix_domain_name():

* test_asyncio
* test_logging
* test_socket
* test_socketserver

test_asyncio.utils: remove unused time import.
This commit is contained in:
Victor Stinner 2022-06-17 13:16:51 +02:00 committed by GitHub
parent ffc228dd4e
commit c5b750dc0b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 82 additions and 86 deletions

View file

@ -1,8 +1,10 @@
import contextlib import contextlib
import errno import errno
import os.path
import socket import socket
import unittest
import sys import sys
import tempfile
import unittest
from .. import support from .. import support
from . import warnings_helper from . import warnings_helper
@ -270,3 +272,14 @@ def transient_internet(resource_name, *, timeout=_NOT_SET, errnos=()):
# __cause__ or __context__? # __cause__ or __context__?
finally: finally:
socket.setdefaulttimeout(old_timeout) socket.setdefaulttimeout(old_timeout)
def create_unix_domain_name():
"""
Create a UNIX domain name: socket.bind() argument of a AF_UNIX socket.
Return a path relative to the current directory to get a short path
(around 27 ASCII characters).
"""
return tempfile.mktemp(prefix="test_python_", suffix='.sock',
dir=os.path.curdir)

View file

@ -315,11 +315,15 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
self.loop.run_until_complete(coro) self.loop.run_until_complete(coro)
def test_create_unix_server_existing_path_nonsock(self): def test_create_unix_server_existing_path_nonsock(self):
with tempfile.NamedTemporaryFile() as file: path = test_utils.gen_unix_socket_path()
coro = self.loop.create_unix_server(lambda: None, file.name) self.addCleanup(os_helper.unlink, path)
with self.assertRaisesRegex(OSError, # create the file
'Address.*is already in use'): open(path, "wb").close()
self.loop.run_until_complete(coro)
coro = self.loop.create_unix_server(lambda: None, path)
with self.assertRaisesRegex(OSError,
'Address.*is already in use'):
self.loop.run_until_complete(coro)
def test_create_unix_server_ssl_bool(self): def test_create_unix_server_ssl_bool(self):
coro = self.loop.create_unix_server(lambda: None, path='spam', coro = self.loop.create_unix_server(lambda: None, path='spam',
@ -356,20 +360,18 @@ class SelectorEventLoopUnixSocketTests(test_utils.TestCase):
'no socket.SOCK_NONBLOCK (linux only)') 'no socket.SOCK_NONBLOCK (linux only)')
@socket_helper.skip_unless_bind_unix_socket @socket_helper.skip_unless_bind_unix_socket
def test_create_unix_server_path_stream_bittype(self): def test_create_unix_server_path_stream_bittype(self):
sock = socket.socket( fn = test_utils.gen_unix_socket_path()
socket.AF_UNIX, socket.SOCK_STREAM | socket.SOCK_NONBLOCK) self.addCleanup(os_helper.unlink, fn)
with tempfile.NamedTemporaryFile() as file:
fn = file.name sock = socket.socket(socket.AF_UNIX,
try: socket.SOCK_STREAM | socket.SOCK_NONBLOCK)
with sock: with sock:
sock.bind(fn) sock.bind(fn)
coro = self.loop.create_unix_server(lambda: None, path=None, coro = self.loop.create_unix_server(lambda: None, path=None,
sock=sock) sock=sock)
srv = self.loop.run_until_complete(coro) srv = self.loop.run_until_complete(coro)
srv.close() srv.close()
self.loop.run_until_complete(srv.wait_closed()) self.loop.run_until_complete(srv.wait_closed())
finally:
os.unlink(fn)
def test_create_unix_server_ssl_timeout_with_plain_sock(self): def test_create_unix_server_ssl_timeout_with_plain_sock(self):
coro = self.loop.create_unix_server(lambda: None, path='spam', coro = self.loop.create_unix_server(lambda: None, path='spam',

View file

@ -11,9 +11,7 @@ import selectors
import socket import socket
import socketserver import socketserver
import sys import sys
import tempfile
import threading import threading
import time
import unittest import unittest
import weakref import weakref
@ -34,6 +32,7 @@ from asyncio import futures
from asyncio import tasks from asyncio import tasks
from asyncio.log import logger from asyncio.log import logger
from test import support from test import support
from test.support import socket_helper
from test.support import threading_helper from test.support import threading_helper
@ -251,8 +250,7 @@ if hasattr(socket, 'AF_UNIX'):
def gen_unix_socket_path(): def gen_unix_socket_path():
with tempfile.NamedTemporaryFile() as file: return socket_helper.create_unix_domain_name()
return file.name
@contextlib.contextmanager @contextlib.contextmanager

View file

@ -1828,12 +1828,6 @@ class SocketHandlerTest(BaseTest):
time.sleep(self.sock_hdlr.retryTime - now + 0.001) time.sleep(self.sock_hdlr.retryTime - now + 0.001)
self.root_logger.error('Nor this') self.root_logger.error('Nor this')
def _get_temp_domain_socket():
fn = make_temp_file(prefix='test_logging_', suffix='.sock')
# just need a name - file can't be present, or we'll get an
# 'address already in use' error.
os.remove(fn)
return fn
@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required") @unittest.skipUnless(hasattr(socket, "AF_UNIX"), "Unix sockets required")
class UnixSocketHandlerTest(SocketHandlerTest): class UnixSocketHandlerTest(SocketHandlerTest):
@ -1845,13 +1839,10 @@ class UnixSocketHandlerTest(SocketHandlerTest):
def setUp(self): def setUp(self):
# override the definition in the base class # override the definition in the base class
self.address = _get_temp_domain_socket() self.address = socket_helper.create_unix_domain_name()
self.addCleanup(os_helper.unlink, self.address)
SocketHandlerTest.setUp(self) SocketHandlerTest.setUp(self)
def tearDown(self):
SocketHandlerTest.tearDown(self)
os_helper.unlink(self.address)
@support.requires_working_socket() @support.requires_working_socket()
@threading_helper.requires_working_threading() @threading_helper.requires_working_threading()
class DatagramHandlerTest(BaseTest): class DatagramHandlerTest(BaseTest):
@ -1928,13 +1919,10 @@ class UnixDatagramHandlerTest(DatagramHandlerTest):
def setUp(self): def setUp(self):
# override the definition in the base class # override the definition in the base class
self.address = _get_temp_domain_socket() self.address = socket_helper.create_unix_domain_name()
self.addCleanup(os_helper.unlink, self.address)
DatagramHandlerTest.setUp(self) DatagramHandlerTest.setUp(self)
def tearDown(self):
DatagramHandlerTest.tearDown(self)
os_helper.unlink(self.address)
@support.requires_working_socket() @support.requires_working_socket()
@threading_helper.requires_working_threading() @threading_helper.requires_working_threading()
class SysLogHandlerTest(BaseTest): class SysLogHandlerTest(BaseTest):
@ -2022,13 +2010,10 @@ class UnixSysLogHandlerTest(SysLogHandlerTest):
def setUp(self): def setUp(self):
# override the definition in the base class # override the definition in the base class
self.address = _get_temp_domain_socket() self.address = socket_helper.create_unix_domain_name()
self.addCleanup(os_helper.unlink, self.address)
SysLogHandlerTest.setUp(self) SysLogHandlerTest.setUp(self)
def tearDown(self):
SysLogHandlerTest.tearDown(self)
os_helper.unlink(self.address)
@unittest.skipUnless(socket_helper.IPV6_ENABLED, @unittest.skipUnless(socket_helper.IPV6_ENABLED,
'IPv6 support required for this test.') 'IPv6 support required for this test.')
class IPv6SysLogHandlerTest(SysLogHandlerTest): class IPv6SysLogHandlerTest(SysLogHandlerTest):

View file

@ -4,31 +4,30 @@ from test.support import os_helper
from test.support import socket_helper from test.support import socket_helper
from test.support import threading_helper from test.support import threading_helper
import _thread as thread
import array
import contextlib
import errno import errno
import io import io
import itertools import itertools
import socket import math
import os
import pickle
import platform
import queue
import random
import re
import select import select
import signal
import socket
import string
import struct
import sys
import tempfile import tempfile
import threading
import time import time
import traceback import traceback
import queue
import sys
import os
import platform
import array
import contextlib
from weakref import proxy from weakref import proxy
import signal
import math
import pickle
import re
import struct
import random
import shutil
import string
import _thread as thread
import threading
try: try:
import multiprocessing import multiprocessing
except ImportError: except ImportError:
@ -605,17 +604,18 @@ class SocketTestBase(unittest.TestCase):
def setUp(self): def setUp(self):
self.serv = self.newSocket() self.serv = self.newSocket()
self.addCleanup(self.close_server)
self.bindServer() self.bindServer()
def close_server(self):
self.serv.close()
self.serv = None
def bindServer(self): def bindServer(self):
"""Bind server socket and set self.serv_addr to its address.""" """Bind server socket and set self.serv_addr to its address."""
self.bindSock(self.serv) self.bindSock(self.serv)
self.serv_addr = self.serv.getsockname() self.serv_addr = self.serv.getsockname()
def tearDown(self):
self.serv.close()
self.serv = None
class SocketListeningTestMixin(SocketTestBase): class SocketListeningTestMixin(SocketTestBase):
"""Mixin to listen on the server socket.""" """Mixin to listen on the server socket."""
@ -700,15 +700,10 @@ class UnixSocketTestBase(SocketTestBase):
# can't send anything that might be problematic for a privileged # can't send anything that might be problematic for a privileged
# user running the tests. # user running the tests.
def setUp(self):
self.dir_path = tempfile.mkdtemp()
self.addCleanup(os.rmdir, self.dir_path)
super().setUp()
def bindSock(self, sock): def bindSock(self, sock):
path = tempfile.mktemp(dir=self.dir_path) path = socket_helper.create_unix_domain_name()
socket_helper.bind_unix_socket(sock, path)
self.addCleanup(os_helper.unlink, path) self.addCleanup(os_helper.unlink, path)
socket_helper.bind_unix_socket(sock, path)
class UnixStreamBase(UnixSocketTestBase): class UnixStreamBase(UnixSocketTestBase):
"""Base class for Unix-domain SOCK_STREAM tests.""" """Base class for Unix-domain SOCK_STREAM tests."""
@ -1905,17 +1900,18 @@ class GeneralModuleTests(unittest.TestCase):
self._test_socket_fileno(s, socket.AF_INET6, socket.SOCK_STREAM) self._test_socket_fileno(s, socket.AF_INET6, socket.SOCK_STREAM)
if hasattr(socket, "AF_UNIX"): if hasattr(socket, "AF_UNIX"):
tmpdir = tempfile.mkdtemp() unix_name = socket_helper.create_unix_domain_name()
self.addCleanup(shutil.rmtree, tmpdir) self.addCleanup(os_helper.unlink, unix_name)
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.addCleanup(s.close) with s:
try: try:
s.bind(os.path.join(tmpdir, 'socket')) s.bind(unix_name)
except PermissionError: except PermissionError:
pass pass
else: else:
self._test_socket_fileno(s, socket.AF_UNIX, self._test_socket_fileno(s, socket.AF_UNIX,
socket.SOCK_STREAM) socket.SOCK_STREAM)
def test_socket_fileno_rejects_float(self): def test_socket_fileno_rejects_float(self):
with self.assertRaises(TypeError): with self.assertRaises(TypeError):

View file

@ -8,7 +8,6 @@ import os
import select import select
import signal import signal
import socket import socket
import tempfile
import threading import threading
import unittest import unittest
import socketserver import socketserver
@ -98,8 +97,7 @@ class SocketServerTest(unittest.TestCase):
else: else:
# XXX: We need a way to tell AF_UNIX to pick its own name # XXX: We need a way to tell AF_UNIX to pick its own name
# like AF_INET provides port==0. # like AF_INET provides port==0.
dir = None fn = socket_helper.create_unix_domain_name()
fn = tempfile.mktemp(prefix='unix_socket.', dir=dir)
self.test_files.append(fn) self.test_files.append(fn)
return fn return fn

View file

@ -0,0 +1,4 @@
test_asyncio, test_logging, test_socket and test_socketserver now create
AF_UNIX domains in the current directory to no longer fail with
``OSError("AF_UNIX path too long")`` if the temporary directory (the
:envvar:`TMPDIR` environment variable) is too long. Patch by Victor Stinner.