Mirror of https://github.com/python/cpython.git, synced 2025-11-04 03:44:55 +00:00.
			
		
		
		
	Allows contextvars from the main thread to be accessed in the separate thread used in `asyncio.to_thread()`. See the [discussion](https://github.com/python/cpython/pull/20143#discussion_r427808225) in GH-20143 for context. Automerge-Triggered-By: @aeros
		
			
				
	
	
		
			93 lines
		
	
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			93 lines
		
	
	
	
		
			2.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
"""Tests for asyncio/threads.py"""
 | 
						|
 | 
						|
import asyncio
 | 
						|
import unittest
 | 
						|
 | 
						|
from contextvars import ContextVar
 | 
						|
from unittest import mock
 | 
						|
from test.test_asyncio import utils as test_utils
 | 
						|
 | 
						|
 | 
						|
def tearDownModule():
    """Reset asyncio's event loop policy after this module's tests finish."""
    # Passing None asks asyncio to reinstate its default policy.
    default_policy = None
    asyncio.set_event_loop_policy(default_policy)
 | 
						|
 | 
						|
 | 
						|
class ToThreadTests(test_utils.TestCase):
    """Behavioural tests for ``asyncio.to_thread()``.

    Each test drives its coroutine to completion on a private event loop
    that ``setUp`` creates and installs, and that ``tearDown`` shuts down
    (default executor first) and uninstalls.
    """

    def setUp(self):
        super().setUp()
        # A dedicated loop per test, also registered as the current loop.
        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

    def tearDown(self):
        loop = self.loop
        # Shut the default executor down (waiting for its threads) before
        # closing the loop that owns it.
        loop.run_until_complete(loop.shutdown_default_executor())
        loop.close()
        asyncio.set_event_loop(None)
        self.loop = None
        super().tearDown()

    def _run_to_completion(self, coro):
        """Run *coro* on this test's loop and return whatever it returns."""
        return self.loop.run_until_complete(coro)

    def test_to_thread(self):
        """The worker's return value is delivered back to the awaiting coroutine."""
        async def scenario():
            return await asyncio.to_thread(sum, [40, 2])

        self.assertEqual(self._run_to_completion(scenario()), 42)

    def test_to_thread_exception(self):
        """An exception raised in the worker propagates to the awaiter."""
        def boom():
            raise RuntimeError("test")

        async def scenario():
            await asyncio.to_thread(boom)

        with self.assertRaisesRegex(RuntimeError, "test"):
            self._run_to_completion(scenario())

    def test_to_thread_once(self):
        """A single to_thread() call invokes the callable exactly once."""
        target = mock.Mock()

        async def scenario():
            await asyncio.to_thread(target)

        self._run_to_completion(scenario())
        target.assert_called_once()

    def test_to_thread_concurrent(self):
        """Ten to_thread() calls gathered together each invoke the callable."""
        target = mock.Mock()

        async def scenario():
            # to_thread() returns a coroutine; gather() schedules them all.
            pending = [asyncio.to_thread(target) for _ in range(10)]
            await asyncio.gather(*pending)

        self._run_to_completion(scenario())
        self.assertEqual(target.call_count, 10)

    def test_to_thread_args_kwargs(self):
        """Positional and keyword arguments are forwarded to the callable."""
        # Unlike run_in_executor(), to_thread() should directly accept kwargs.
        target = mock.Mock()

        async def scenario():
            await asyncio.to_thread(target, 'test', something=True)

        self._run_to_completion(scenario())
        target.assert_called_once_with('test', something=True)

    def test_to_thread_contextvars(self):
        """A ContextVar set by the awaiting coroutine is visible in the worker."""
        marker = ContextVar('test_ctx')

        async def scenario():
            marker.set('parrot')
            # The worker reads the variable; it only sees 'parrot' if the
            # caller's context was carried over to the worker thread.
            return await asyncio.to_thread(marker.get)

        self.assertEqual(self._run_to_completion(scenario()), 'parrot')
 | 
						|
 | 
						|
 | 
						|
if __name__ == "__main__":
    # Executing this file directly runs its tests through unittest's
    # command-line runner.
    unittest.main()
 |