Issue #7232: Add support for the context manager protocol

to the TarFile class.
This commit is contained in:
Lars Gustäbel 2010-03-03 11:55:48 +00:00
parent 8af970ab58
commit 6458104188
4 changed files with 92 additions and 0 deletions

View file

@ -1292,6 +1292,65 @@ class LimitsTest(unittest.TestCase):
tarinfo.tobuf(tarfile.PAX_FORMAT)
class ContextManagerTest(unittest.TestCase):
def test_basic(self):
with tarfile.open(tarname) as tar:
self.assertFalse(tar.closed, "closed inside runtime context")
self.assertTrue(tar.closed, "context manager failed")
def test_closed(self):
# The __enter__() method is supposed to raise IOError
# if the TarFile object is already closed.
tar = tarfile.open(tarname)
tar.close()
with self.assertRaises(IOError):
with tar:
pass
def test_exception(self):
# Test if the IOError exception is passed through properly.
with self.assertRaises(Exception) as exc:
with tarfile.open(tarname) as tar:
raise IOError
self.assertIsInstance(exc.exception, IOError,
"wrong exception raised in context manager")
self.assertTrue(tar.closed, "context manager failed")
def test_no_eof(self):
# __exit__() must not write end-of-archive blocks if an
# exception was raised.
try:
with tarfile.open(tmpname, "w") as tar:
raise Exception
except:
pass
self.assertEqual(os.path.getsize(tmpname), 0,
"context manager wrote an end-of-archive block")
self.assertTrue(tar.closed, "context manager failed")
def test_eof(self):
# __exit__() must write end-of-archive blocks, i.e. call
# TarFile.close() if there was no error.
with tarfile.open(tmpname, "w"):
pass
self.assertNotEqual(os.path.getsize(tmpname), 0,
"context manager wrote no end-of-archive block")
def test_fileobj(self):
# Test that __exit__() did not close the external file
# object.
fobj = open(tmpname, "wb")
try:
with tarfile.open(fileobj=fobj, mode="w") as tar:
raise Exception
except:
pass
self.assertFalse(fobj.closed, "external file object was closed")
self.assertTrue(tar.closed, "context manager failed")
fobj.close()
class GzipMiscReadTest(MiscReadTest):
tarname = gzipname
mode = "r:gz"
@ -1371,6 +1430,7 @@ def test_main():
PaxUnicodeTest,
AppendTest,
LimitsTest,
ContextManagerTest,
]
if hasattr(os, "link"):