mirror of
https://github.com/python/cpython.git
synced 2025-08-01 23:53:15 +00:00

svn+ssh://pythondev@svn.python.org/python/trunk ........ r79539 | florent.xicluna | 2010-04-01 01:01:03 +0300 (Thu, 01 Apr 2010) | 2 lines Replace catch_warnings with check_warnings when it makes sense. Use assertRaises context manager to simplify some tests. ........
326 lines
8.9 KiB
Python
326 lines
8.9 KiB
Python
"""Unit tests for contextlib.py, and other context managers."""
|
|
|
|
import sys
|
|
import tempfile
|
|
import unittest
|
|
from contextlib import * # Tests __all__
|
|
from test import test_support
|
|
try:
|
|
import threading
|
|
except ImportError:
|
|
threading = None
|
|
|
|
|
|
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based managers built with ``@contextmanager``."""

    # NOTE: test methods carry ``#`` comments instead of docstrings so that
    # unittest's verbose output keeps showing the method names.

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, code after it runs on a
        # normal exit, and the yielded value is what ``as`` binds.
        state = []

        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)

        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A ``finally`` clause around the yield runs when the with-body
        # raises, and the exception still propagates to the caller.
        state = []

        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)

        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield

        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields a second time after the exception has been
        # thrown into it must make __exit__ raise RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:  # deliberately bare: the handler must catch anything
                yield

        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # The generator may catch the exception raised in the with-body;
        # doing so suppresses it for the caller.
        state = []

        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            # ``as`` form (Python 2.6+) instead of the legacy comma form.
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])

        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        """Return a ``@contextmanager`` function carrying extra metadata.

        The returned ``baz`` is first decorated to receive a ``foo='bar'``
        attribute and has the docstring ``"Whee!"``; the attribute tests
        below check what survives the ``contextmanager`` wrapping.
        """
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate

        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")
|
|
|
|
class NestedTestCase(unittest.TestCase):
    """Tests for ``nested()`` combining several context managers."""

    # XXX This needs more work

    def test_nested(self):
        # The values yielded by each manager are bound, in order, by ``as``.
        @contextmanager
        def first():
            yield 1

        @contextmanager
        def second():
            yield 2

        @contextmanager
        def third():
            yield 3

        with nested(first(), second(), third()) as (one, two, three):
            self.assertEqual(one, 1)
            self.assertEqual(two, 2)
            self.assertEqual(three, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and their ``finally`` blocks
        # run right to left when the body raises.
        trace = []

        @contextmanager
        def outer():
            trace.append(1)
            try:
                yield 2
            finally:
                trace.append(3)

        @contextmanager
        def inner():
            trace.append(4)
            try:
                yield 5
            finally:
                trace.append(6)

        with self.assertRaises(ZeroDivisionError):
            with nested(outer(), inner()) as (from_outer, from_inner):
                trace.append(from_outer)
                trace.append(from_inner)
                1 // 0
        self.assertEqual(trace, [1, 4, 2, 5, 6, 3])

    def test_nested_right_exception(self):
        # An exception raised and handled inside one manager's __exit__
        # must not replace the ZeroDivisionError from the with-body.
        @contextmanager
        def gen_manager():
            yield 1

        class NoisyExit(object):
            def __enter__(self):
                return 2

            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:
                    pass

        with self.assertRaises(ZeroDivisionError):
            with nested(gen_manager(), NoisyExit()) as (x, y):
                1 // 0
        self.assertEqual((x, y), (1, 2))

    def test_nested_b_swallows(self):
        # If the second manager suppresses the exception, nothing escapes
        # the with statement.
        @contextmanager
        def passthrough():
            yield

        @contextmanager
        def swallower():
            try:
                yield
            except:
                # Swallow the exception
                pass

        try:
            with nested(passthrough(), swallower()):
                1 // 0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # ``break`` inside the with-body leaves the loop immediately.
        @contextmanager
        def passthrough():
            yield

        counter = 0
        while True:
            counter += 1
            with nested(passthrough(), passthrough()):
                break
            counter += 10
        self.assertEqual(counter, 1)

    def test_nested_continue(self):
        # ``continue`` inside the with-body skips the rest of each iteration.
        @contextmanager
        def passthrough():
            yield

        counter = 0
        while counter < 3:
            counter += 1
            with nested(passthrough(), passthrough()):
                continue
            counter += 10
        self.assertEqual(counter, 3)

    def test_nested_return(self):
        # ``return`` inside the with-body returns that value even though
        # the managers wrap their yield in a catch-all handler.
        @contextmanager
        def catch_all():
            try:
                yield
            except:
                pass

        def returns_from_body():
            with nested(catch_all(), catch_all()):
                return 1
            return 10

        self.assertEqual(returns_from_body(), 1)
|
|
|
|
class ClosingTestCase(unittest.TestCase):
    """Tests for ``closing()``, which calls ``close()`` on block exit."""

    # XXX This needs more work

    def test_closing(self):
        # close() is called exactly once after a normal exit, and ``as``
        # binds the wrapped object.
        calls = []

        class Closeable:
            def close(self):
                calls.append(1)

        resource = Closeable()
        self.assertEqual(calls, [])
        with closing(resource) as bound:
            self.assertEqual(resource, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() is still called when the body raises, and the exception
        # propagates.
        calls = []

        class Closeable:
            def close(self):
                calls.append(1)

        resource = Closeable()
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(resource) as bound:
                self.assertEqual(resource, bound)
                1 // 0
        self.assertEqual(calls, [1])
|
|
|
|
class FileContextTestCase(unittest.TestCase):
    """Check that file objects returned by ``open()`` work in ``with``."""

    def testWithOpen(self):
        # Create the scratch file with NamedTemporaryFile(delete=False)
        # rather than the deprecated tempfile.mktemp(), which only returns a
        # name and leaves a window for another process to claim the path.
        # The file is kept on disk so it can be reopened by name below.
        with tempfile.NamedTemporaryFile(delete=False) as tmp:
            tfn = tmp.name
        try:
            # Reset ``f`` before each with-statement so the ``closed`` checks
            # afterwards cannot be satisfied by an earlier file object.
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            # Leaving the block normally must close the file.
            self.assertTrue(f.closed)
            f = None
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 // 0
            # Leaving the block via an exception must close it as well.
            self.assertTrue(f.closed)
        finally:
            test_support.unlink(tfn)
|
|
|
|
@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Check that ``threading`` synchronisation objects work in ``with``."""

    def boilerPlate(self, lock, locked):
        """Exercise *lock* as a context manager.

        *locked* is a zero-argument callable reporting whether *lock* is
        currently held.  The lock must be held inside the with-body and
        released afterwards, both on a normal exit and when the body raises.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 // 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_locked(sem):
        """Return True if *sem* cannot be acquired without blocking.

        A successful non-blocking acquire is released again immediately, so
        the probe leaves the semaphore's state unchanged.
        """
        if sem.acquire(False):
            sem.release()
            return False
        return True

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()

        def locked():
            return lock._is_owned()

        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, lambda: self._semaphore_locked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, lambda: self._semaphore_locked(lock))
|
|
|
|
# This is needed to make the test actually run under regrtest.py!
def test_main():
    # Run every test in this module while expecting the DeprecationWarning
    # below (presumably triggered by the nested() calls above -- confirm).
    expected_warning = (
        "With-statements now directly support "
        "multiple context managers",
        DeprecationWarning,
    )
    with test_support.check_warnings(expected_warning):
        test_support.run_unittest(__name__)
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    test_main()
|