gh-78502: Add a trackfd parameter to mmap.mmap() (GH-25425)

If *trackfd* is False, the file descriptor specified by *fileno*
will not be duplicated.

Co-authored-by: Erlend E. Aasland <erlend@python.org>
Co-authored-by: Petr Viktorin <encukou@gmail.com>
Co-authored-by: Serhiy Storchaka <storchaka@gmail.com>
This commit is contained in:
Zackery Spytz 2024-01-15 23:51:46 -08:00 committed by GitHub
parent 42b90cf0d6
commit 8fd287b18f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 101 additions and 11 deletions

View file

@ -4,6 +4,7 @@ from test.support import (
from test.support.import_helper import import_module
from test.support.os_helper import TESTFN, unlink
import unittest
import errno
import os
import re
import itertools
@ -266,6 +267,62 @@ class MmapTests(unittest.TestCase):
self.assertRaises(TypeError, m.write_byte, 0)
m.close()
@unittest.skipIf(os.name == 'nt', 'trackfd not present on Windows')
def test_trackfd_parameter(self):
size = 64
with open(TESTFN, "wb") as f:
f.write(b"a"*size)
for close_original_fd in True, False:
with self.subTest(close_original_fd=close_original_fd):
with open(TESTFN, "r+b") as f:
with mmap.mmap(f.fileno(), size, trackfd=False) as m:
if close_original_fd:
f.close()
self.assertEqual(len(m), size)
with self.assertRaises(OSError) as err_cm:
m.size()
self.assertEqual(err_cm.exception.errno, errno.EBADF)
with self.assertRaises(ValueError):
m.resize(size * 2)
with self.assertRaises(ValueError):
m.resize(size // 2)
self.assertEqual(m.closed, False)
# Smoke-test other API
m.write_byte(ord('X'))
m[2] = ord('Y')
m.flush()
with open(TESTFN, "rb") as f:
self.assertEqual(f.read(4), b'XaYa')
self.assertEqual(m.tell(), 1)
m.seek(0)
self.assertEqual(m.tell(), 0)
self.assertEqual(m.read_byte(), ord('X'))
self.assertEqual(m.closed, True)
self.assertEqual(os.stat(TESTFN).st_size, size)
@unittest.skipIf(os.name == 'nt', 'trackfd not present on Windows')
def test_trackfd_neg1(self):
size = 64
with mmap.mmap(-1, size, trackfd=False) as m:
with self.assertRaises(OSError):
m.size()
with self.assertRaises(ValueError):
m.resize(size // 2)
self.assertEqual(len(m), size)
m[0] = ord('a')
assert m[0] == ord('a')
@unittest.skipIf(os.name != 'nt', 'trackfd only fails on Windows')
def test_no_trackfd_parameter_on_windows(self):
# 'trackffd' is an invalid keyword argument for this function
size = 64
with self.assertRaises(TypeError):
mmap.mmap(-1, size, trackfd=True)
with self.assertRaises(TypeError):
mmap.mmap(-1, size, trackfd=False)
def test_bad_file_desc(self):
# Try opening a bad file descriptor...
self.assertRaises(OSError, mmap.mmap, -2, 4096)