mirror of
https://github.com/python/cpython.git
synced 2025-08-04 17:08:35 +00:00
bpo-45500: Rewrite test_dbm (GH-29002)
* Generate test classes at import time. It allows to filter them when run with unittest. E.g: "./python -m unittest test.test_dbm.TestCase_gnu -v". * Create a database class in a new directory which will be removed after test. It guarantees that all created files and directories be removed and will not conflict with other dbm tests. * Restore dbm._defaultmod after tests. Previously it was set to the last dbm module (dbm.dumb) which affected other tests. * Enable the whichdb test for dbm.dumb. * Move test_keys to the correct test class. It does not test whichdb(). * Remove some outdated code and comments.
This commit is contained in:
parent
236e301b8a
commit
975b94b9de
1 changed files with 55 additions and 67 deletions
|
@ -1,24 +1,21 @@
|
|||
"""Test script for the dbm.open function based on testdumbdbm.py"""
|
||||
|
||||
import unittest
|
||||
import glob
|
||||
import dbm
|
||||
import os
|
||||
from test.support import import_helper
|
||||
from test.support import os_helper
|
||||
|
||||
# Skip tests if dbm module doesn't exist.
|
||||
dbm = import_helper.import_module('dbm')
|
||||
|
||||
try:
|
||||
from dbm import ndbm
|
||||
except ImportError:
|
||||
ndbm = None
|
||||
|
||||
_fname = os_helper.TESTFN
|
||||
dirname = os_helper.TESTFN
|
||||
_fname = os.path.join(dirname, os_helper.TESTFN)
|
||||
|
||||
#
|
||||
# Iterates over every database module supported by dbm currently available,
|
||||
# setting dbm to use each in turn, and yielding that module
|
||||
# Iterates over every database module supported by dbm currently available.
|
||||
#
|
||||
def dbm_iterator():
|
||||
for name in dbm._names:
|
||||
|
@ -32,11 +29,12 @@ def dbm_iterator():
|
|||
#
|
||||
# Clean up all scratch databases we might have created during testing
|
||||
#
|
||||
def delete_files():
|
||||
# we don't know the precise name the underlying database uses
|
||||
# so we use glob to locate all names
|
||||
for f in glob.glob(glob.escape(_fname) + "*"):
|
||||
os_helper.unlink(f)
|
||||
def cleaunup_test_dir():
|
||||
os_helper.rmtree(dirname)
|
||||
|
||||
def setup_test_dir():
|
||||
cleaunup_test_dir()
|
||||
os.mkdir(dirname)
|
||||
|
||||
|
||||
class AnyDBMTestCase:
|
||||
|
@ -144,86 +142,76 @@ class AnyDBMTestCase:
|
|||
for key in self._dict:
|
||||
self.assertEqual(self._dict[key], f[key.encode("ascii")])
|
||||
|
||||
def tearDown(self):
|
||||
delete_files()
|
||||
def test_keys(self):
|
||||
with dbm.open(_fname, 'c') as d:
|
||||
self.assertEqual(d.keys(), [])
|
||||
a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')]
|
||||
for k, v in a:
|
||||
d[k] = v
|
||||
self.assertEqual(sorted(d.keys()), sorted(k for (k, v) in a))
|
||||
for k, v in a:
|
||||
self.assertIn(k, d)
|
||||
self.assertEqual(d[k], v)
|
||||
self.assertNotIn(b'xxx', d)
|
||||
self.assertRaises(KeyError, lambda: d[b'xxx'])
|
||||
|
||||
def setUp(self):
|
||||
self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod)
|
||||
dbm._defaultmod = self.module
|
||||
delete_files()
|
||||
self.addCleanup(cleaunup_test_dir)
|
||||
setup_test_dir()
|
||||
|
||||
|
||||
class WhichDBTestCase(unittest.TestCase):
|
||||
def test_whichdb(self):
|
||||
self.addCleanup(setattr, dbm, '_defaultmod', dbm._defaultmod)
|
||||
_bytes_fname = os.fsencode(_fname)
|
||||
for path in [_fname, os_helper.FakePath(_fname),
|
||||
_bytes_fname, os_helper.FakePath(_bytes_fname)]:
|
||||
for module in dbm_iterator():
|
||||
# Check whether whichdb correctly guesses module name
|
||||
# for databases opened with "module" module.
|
||||
# Try with empty files first
|
||||
name = module.__name__
|
||||
if name == 'dbm.dumb':
|
||||
continue # whichdb can't support dbm.dumb
|
||||
delete_files()
|
||||
f = module.open(path, 'c')
|
||||
f.close()
|
||||
fnames = [_fname, os_helper.FakePath(_fname),
|
||||
_bytes_fname, os_helper.FakePath(_bytes_fname)]
|
||||
for module in dbm_iterator():
|
||||
# Check whether whichdb correctly guesses module name
|
||||
# for databases opened with "module" module.
|
||||
name = module.__name__
|
||||
setup_test_dir()
|
||||
dbm._defaultmod = module
|
||||
# Try with empty files first
|
||||
with module.open(_fname, 'c'): pass
|
||||
for path in fnames:
|
||||
self.assertEqual(name, self.dbm.whichdb(path))
|
||||
# Now add a key
|
||||
f = module.open(path, 'w')
|
||||
# Now add a key
|
||||
with module.open(_fname, 'w') as f:
|
||||
f[b"1"] = b"1"
|
||||
# and test that we can find it
|
||||
self.assertIn(b"1", f)
|
||||
# and read it
|
||||
self.assertEqual(f[b"1"], b"1")
|
||||
f.close()
|
||||
for path in fnames:
|
||||
self.assertEqual(name, self.dbm.whichdb(path))
|
||||
|
||||
@unittest.skipUnless(ndbm, reason='Test requires ndbm')
|
||||
def test_whichdb_ndbm(self):
|
||||
# Issue 17198: check that ndbm which is referenced in whichdb is defined
|
||||
db_file = '{}_ndbm.db'.format(_fname)
|
||||
with open(db_file, 'w'):
|
||||
self.addCleanup(os_helper.unlink, db_file)
|
||||
db_file_bytes = os.fsencode(db_file)
|
||||
self.assertIsNone(self.dbm.whichdb(db_file[:-3]))
|
||||
self.assertIsNone(self.dbm.whichdb(os_helper.FakePath(db_file[:-3])))
|
||||
self.assertIsNone(self.dbm.whichdb(db_file_bytes[:-3]))
|
||||
self.assertIsNone(self.dbm.whichdb(os_helper.FakePath(db_file_bytes[:-3])))
|
||||
|
||||
def tearDown(self):
|
||||
delete_files()
|
||||
with open(_fname + '.db', 'wb'): pass
|
||||
_bytes_fname = os.fsencode(_fname)
|
||||
fnames = [_fname, os_helper.FakePath(_fname),
|
||||
_bytes_fname, os_helper.FakePath(_bytes_fname)]
|
||||
for path in fnames:
|
||||
self.assertIsNone(self.dbm.whichdb(path))
|
||||
|
||||
def setUp(self):
|
||||
delete_files()
|
||||
self.filename = os_helper.TESTFN
|
||||
self.d = dbm.open(self.filename, 'c')
|
||||
self.d.close()
|
||||
self.addCleanup(cleaunup_test_dir)
|
||||
setup_test_dir()
|
||||
self.dbm = import_helper.import_fresh_module('dbm')
|
||||
|
||||
def test_keys(self):
|
||||
self.d = dbm.open(self.filename, 'c')
|
||||
self.assertEqual(self.d.keys(), [])
|
||||
a = [(b'a', b'b'), (b'12345678910', b'019237410982340912840198242')]
|
||||
for k, v in a:
|
||||
self.d[k] = v
|
||||
self.assertEqual(sorted(self.d.keys()), sorted(k for (k, v) in a))
|
||||
for k, v in a:
|
||||
self.assertIn(k, self.d)
|
||||
self.assertEqual(self.d[k], v)
|
||||
self.assertNotIn(b'xxx', self.d)
|
||||
self.assertRaises(KeyError, lambda: self.d[b'xxx'])
|
||||
self.d.close()
|
||||
|
||||
for mod in dbm_iterator():
|
||||
assert mod.__name__.startswith('dbm.')
|
||||
suffix = mod.__name__[4:]
|
||||
testname = f'TestCase_{suffix}'
|
||||
globals()[testname] = type(testname,
|
||||
(AnyDBMTestCase, unittest.TestCase),
|
||||
{'module': mod})
|
||||
|
||||
def load_tests(loader, tests, pattern):
|
||||
classes = []
|
||||
for mod in dbm_iterator():
|
||||
classes.append(type("TestCase-" + mod.__name__,
|
||||
(AnyDBMTestCase, unittest.TestCase),
|
||||
{'module': mod}))
|
||||
for c in classes:
|
||||
tests.addTest(loader.loadTestsFromTestCase(c))
|
||||
return tests
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue