diff --git a/Lib/test/test_unittest.py b/Lib/test/test_unittest.py index ab043822daf..04a7322b236 100644 --- a/Lib/test/test_unittest.py +++ b/Lib/test/test_unittest.py @@ -3059,8 +3059,13 @@ class Test_Assertions(TestCase): pass else: self.fail("assertRaises() didn't let exception pass through") - with self.assertRaises(KeyError): - raise KeyError + with self.assertRaises(KeyError) as cm: + try: + raise KeyError + except Exception, e: + raise + self.assertIs(cm.exception, e) + with self.assertRaises(KeyError): raise KeyError("key") try: diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index 63408e35ace..4acfa6539af 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -91,7 +91,7 @@ class _AssertRaisesContext(object): self.expected_regexp = expected_regexp def __enter__(self): - pass + return self def __exit__(self, exc_type, exc_value, tb): if exc_type is None: