mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 19:34:08 +00:00 
			
		
		
		
	after a failed import. This is the last checkin in the "change import failure semantics" series.
		
			
				
	
	
		
			82 lines
		
	
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			82 lines
		
	
	
	
		
			2.8 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import os, sys, string, random, tempfile, unittest
 | 
						|
 | 
						|
from test.test_support import run_unittest
 | 
						|
 | 
						|
class TestImport(unittest.TestCase):
    """Check the state left behind when importing a package submodule fails.

    Every test runs against a throw-away package that setUp creates in a
    fresh temporary directory appended to sys.path.  The package contains
    an empty __init__ module and one submodule, ``foo``, whose source is
    rewritten between import attempts by rewrite_file().
    """

    def __init__(self, *args, **kw):
        """Pick a package name that is not in sys.modules, then init TestCase.

        Sets ``package_name`` and ``module_name`` (``<package_name>.foo``).
        All positional and keyword arguments are forwarded unchanged to
        unittest.TestCase.__init__.
        """
        self.package_name = 'PACKAGE_'
        while self.package_name in sys.modules:
            # ascii_letters keeps the name a valid identifier; string.letters
            # is locale-dependent and may contain non-ASCII characters.
            self.package_name += random.choice(string.ascii_letters)
        self.module_name = self.package_name + '.foo'
        unittest.TestCase.__init__(self, *args, **kw)

    def remove_modules(self):
        """Drop the test package and its submodule from sys.modules if present."""
        for module_name in (self.package_name, self.module_name):
            sys.modules.pop(module_name, None)

    def setUp(self):
        """Create the temporary package on disk and make it importable.

        Sets ``test_dir`` (the temporary directory, appended to sys.path),
        ``package_dir`` (the package directory inside it) and
        ``module_path`` (where rewrite_file() writes the ``foo`` source).
        """
        self.test_dir = tempfile.mkdtemp()
        sys.path.append(self.test_dir)
        self.package_dir = os.path.join(self.test_dir, self.package_name)
        os.mkdir(self.package_dir)
        # An empty __init__ file is enough to mark the directory as a
        # package; close it right away so no handle is left open.
        init_path = os.path.join(self.package_dir, '__init__'+os.extsep+'py')
        open(init_path, 'w').close()
        self.module_path = os.path.join(self.package_dir, 'foo'+os.extsep+'py')

    def tearDown(self):
        """Delete the temporary tree and undo the sys.path/sys.modules changes."""
        # Imported locally so this class does not depend on a new
        # module-level import.
        import shutil

        try:
            # rmtree also copes with subdirectories the interpreter may
            # have created inside the package (e.g. a bytecode cache).
            shutil.rmtree(self.test_dir)
        finally:
            # Always restore interpreter state, even if deletion failed,
            # so one broken run cannot pollute later tests.
            occurrences = sys.path.count(self.test_dir)
            if occurrences:
                sys.path.remove(self.test_dir)
            self.remove_modules()
        # Sanity check: the entry added by setUp must still have been there.
        self.assertNotEqual(occurrences, 0)

    def rewrite_file(self, contents):
        """Replace the source of the ``foo`` submodule with *contents*.

        Any compiled file next to the source (module path plus ``c`` or
        ``o``) is removed first so that a stale compiled module cannot be
        picked up in place of the new source.
        """
        for extension in "co":
            compiled_path = self.module_path + extension
            if os.path.exists(compiled_path):
                os.remove(compiled_path)
        f = open(self.module_path, 'w')
        try:
            f.write(contents)
        finally:
            f.close()

    def test_package_import__semantics(self):
        """A failed submodule import leaves no module behind and can be retried."""

        # Generate a couple of broken modules to try importing.

        # ...try loading the module when there's a SyntaxError
        self.rewrite_file('for')
        try:
            __import__(self.module_name)
        except SyntaxError:
            pass
        else:
            raise RuntimeError('Failed to induce SyntaxError')
        self.assertTrue(self.module_name not in sys.modules and
                        not hasattr(sys.modules[self.package_name], 'foo'))

        # ...make up a variable name that isn't bound in __builtins__
        # (__builtins__ may be either the builtin module or its dict,
        # depending on how this module was loaded; handle both).
        builtin_names = __builtins__
        if not isinstance(builtin_names, dict):
            builtin_names = dir(builtin_names)
        var = 'a'
        while var in builtin_names:
            var += random.choice(string.ascii_letters)

        # ...make a module that just contains that
        self.rewrite_file(var)

        try:
            __import__(self.module_name)
        except NameError:
            pass
        else:
            raise RuntimeError('Failed to induce NameError.')

        # ...now change the module so that the NameError doesn't happen
        self.rewrite_file('%s = 1' % var)
        module = __import__(self.module_name).foo
        self.assertEqual(getattr(module, var), 1)
 | 
						|
 | 
						|
 | 
						|
def test_main():
    """Entry point for the regression-test driver: run the TestImport case."""
    run_unittest(TestImport)
 | 
						|
 | 
						|
 | 
						|
# Allow running this test file directly as a script.
if __name__ == "__main__":
    test_main()
 |