mirror of
				https://github.com/python/cpython.git
				synced 2025-11-03 19:34:08 +00:00 
			
		
		
		
	gh-128308: pass **kwargs to asyncio task_factory (#128768)
				
					
				
			Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
		
							parent
							
								
									6c914bf85c
								
							
						
					
					
						commit
						38a9956876
					
				
					 8 changed files with 48 additions and 29 deletions
				
			
		| 
						 | 
					@ -392,9 +392,9 @@ Creating Futures and Tasks
 | 
				
			||||||
 | 
					
 | 
				
			||||||
   If *factory* is ``None`` the default task factory will be set.
 | 
					   If *factory* is ``None`` the default task factory will be set.
 | 
				
			||||||
   Otherwise, *factory* must be a *callable* with the signature matching
 | 
					   Otherwise, *factory* must be a *callable* with the signature matching
 | 
				
			||||||
   ``(loop, coro, context=None)``, where *loop* is a reference to the active
 | 
					   ``(loop, coro, **kwargs)``, where *loop* is a reference to the active
 | 
				
			||||||
   event loop, and *coro* is a coroutine object.  The callable
 | 
					   event loop, and *coro* is a coroutine object.  The callable
 | 
				
			||||||
   must return a :class:`asyncio.Future`-compatible object.
 | 
					   must pass on all *kwargs*, and return a :class:`asyncio.Task`-compatible object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
.. method:: loop.get_task_factory()
 | 
					.. method:: loop.get_task_factory()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -458,25 +458,18 @@ class BaseEventLoop(events.AbstractEventLoop):
 | 
				
			||||||
        """Create a Future object attached to the loop."""
 | 
					        """Create a Future object attached to the loop."""
 | 
				
			||||||
        return futures.Future(loop=self)
 | 
					        return futures.Future(loop=self)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_task(self, coro, *, name=None, context=None):
 | 
					    def create_task(self, coro, **kwargs):
 | 
				
			||||||
        """Schedule a coroutine object.
 | 
					        """Schedule a coroutine object.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Return a task object.
 | 
					        Return a task object.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        self._check_closed()
 | 
					        self._check_closed()
 | 
				
			||||||
        if self._task_factory is None:
 | 
					        if self._task_factory is not None:
 | 
				
			||||||
            task = tasks.Task(coro, loop=self, name=name, context=context)
 | 
					            return self._task_factory(self, coro, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        task = tasks.Task(coro, loop=self, **kwargs)
 | 
				
			||||||
        if task._source_traceback:
 | 
					        if task._source_traceback:
 | 
				
			||||||
            del task._source_traceback[-1]
 | 
					            del task._source_traceback[-1]
 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            if context is None:
 | 
					 | 
				
			||||||
                # Use legacy API if context is not needed
 | 
					 | 
				
			||||||
                task = self._task_factory(self, coro)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                task = self._task_factory(self, coro, context=context)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            task.set_name(name)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            return task
 | 
					            return task
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
| 
						 | 
					@ -490,9 +483,10 @@ class BaseEventLoop(events.AbstractEventLoop):
 | 
				
			||||||
        If factory is None the default task factory will be set.
 | 
					        If factory is None the default task factory will be set.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        If factory is a callable, it should have a signature matching
 | 
					        If factory is a callable, it should have a signature matching
 | 
				
			||||||
        '(loop, coro)', where 'loop' will be a reference to the active
 | 
					        '(loop, coro, **kwargs)', where 'loop' will be a reference to the active
 | 
				
			||||||
        event loop, 'coro' will be a coroutine object.  The callable
 | 
					        event loop, 'coro' will be a coroutine object, and **kwargs will be
 | 
				
			||||||
        must return a Future.
 | 
					        arbitrary keyword arguments that should be passed on to Task.
 | 
				
			||||||
 | 
					        The callable must return a Task.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        if factory is not None and not callable(factory):
 | 
					        if factory is not None and not callable(factory):
 | 
				
			||||||
            raise TypeError('task factory must be a callable or None')
 | 
					            raise TypeError('task factory must be a callable or None')
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -329,7 +329,7 @@ class AbstractEventLoop:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Method scheduling a coroutine object: create a task.
 | 
					    # Method scheduling a coroutine object: create a task.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def create_task(self, coro, *, name=None, context=None):
 | 
					    def create_task(self, coro, **kwargs):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Methods for interacting with threads.
 | 
					    # Methods for interacting with threads.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -833,8 +833,8 @@ class BaseEventLoopTests(test_utils.TestCase):
 | 
				
			||||||
            loop.close()
 | 
					            loop.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_create_named_task_with_custom_factory(self):
 | 
					    def test_create_named_task_with_custom_factory(self):
 | 
				
			||||||
        def task_factory(loop, coro):
 | 
					        def task_factory(loop, coro, **kwargs):
 | 
				
			||||||
            return asyncio.Task(coro, loop=loop)
 | 
					            return asyncio.Task(coro, loop=loop, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        async def test():
 | 
					        async def test():
 | 
				
			||||||
            pass
 | 
					            pass
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -302,6 +302,18 @@ class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
       self.run_coro(run())
 | 
					       self.run_coro(run())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_name(self):
 | 
				
			||||||
 | 
					        name = None
 | 
				
			||||||
 | 
					        async def coro():
 | 
				
			||||||
 | 
					            nonlocal name
 | 
				
			||||||
 | 
					            name = asyncio.current_task().get_name()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        async def main():
 | 
				
			||||||
 | 
					            task = self.loop.create_task(coro(), name="test name")
 | 
				
			||||||
 | 
					            self.assertEqual(name, "test name")
 | 
				
			||||||
 | 
					            await task
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.run_coro(coro())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class AsyncTaskCounter:
 | 
					class AsyncTaskCounter:
 | 
				
			||||||
    def __init__(self, loop, *, task_class, eager):
 | 
					    def __init__(self, loop, *, task_class, eager):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -112,8 +112,8 @@ class TestPyFreeThreading(TestFreeThreading, TestCase):
 | 
				
			||||||
    all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
 | 
					    all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
 | 
				
			||||||
    current_task = staticmethod(asyncio.tasks._py_current_task)
 | 
					    current_task = staticmethod(asyncio.tasks._py_current_task)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def factory(self, loop, coro, context=None):
 | 
					    def factory(self, loop, coro, **kwargs):
 | 
				
			||||||
        return asyncio.tasks._PyTask(coro, loop=loop, context=context)
 | 
					        return asyncio.tasks._PyTask(coro, loop=loop, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
 | 
					@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
 | 
				
			||||||
| 
						 | 
					@ -121,16 +121,16 @@ class TestCFreeThreading(TestFreeThreading, TestCase):
 | 
				
			||||||
    all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
 | 
					    all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
 | 
				
			||||||
    current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
 | 
					    current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def factory(self, loop, coro, context=None):
 | 
					    def factory(self, loop, coro, **kwargs):
 | 
				
			||||||
        return asyncio.tasks._CTask(coro, loop=loop, context=context)
 | 
					        return asyncio.tasks._CTask(coro, loop=loop, **kwargs)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestEagerPyFreeThreading(TestPyFreeThreading):
 | 
					class TestEagerPyFreeThreading(TestPyFreeThreading):
 | 
				
			||||||
    def factory(self, loop, coro, context=None):
 | 
					    def factory(self, loop, coro, eager_start=True, **kwargs):
 | 
				
			||||||
        return asyncio.tasks._PyTask(coro, loop=loop, context=context, eager_start=True)
 | 
					        return asyncio.tasks._PyTask(coro, loop=loop, **kwargs, eager_start=eager_start)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
 | 
					@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
 | 
				
			||||||
class TestEagerCFreeThreading(TestCFreeThreading, TestCase):
 | 
					class TestEagerCFreeThreading(TestCFreeThreading, TestCase):
 | 
				
			||||||
    def factory(self, loop, coro, context=None):
 | 
					    def factory(self, loop, coro, eager_start=True, **kwargs):
 | 
				
			||||||
        return asyncio.tasks._CTask(coro, loop=loop, context=context, eager_start=True)
 | 
					        return asyncio.tasks._CTask(coro, loop=loop, **kwargs, eager_start=eager_start)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1040,6 +1040,18 @@ class BaseTestTaskGroup:
 | 
				
			||||||
        self.assertIsNotNone(exc)
 | 
					        self.assertIsNotNone(exc)
 | 
				
			||||||
        self.assertListEqual(gc.get_referrers(exc), no_other_refs())
 | 
					        self.assertListEqual(gc.get_referrers(exc), no_other_refs())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def test_name(self):
 | 
				
			||||||
 | 
					        name = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        async def asyncfn():
 | 
				
			||||||
 | 
					            nonlocal name
 | 
				
			||||||
 | 
					            name = asyncio.current_task().get_name()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        async with asyncio.TaskGroup() as tg:
 | 
				
			||||||
 | 
					            tg.create_task(asyncfn(), name="example name")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(name, "example name")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
 | 
					class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
 | 
				
			||||||
    loop_factory = asyncio.EventLoop
 | 
					    loop_factory = asyncio.EventLoop
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -0,0 +1 @@
 | 
				
			||||||
 | 
					Support the *name* keyword argument for eager tasks in :func:`asyncio.loop.create_task`,  :func:`asyncio.create_task` and  :func:`asyncio.TaskGroup.create_task`, by passing on all *kwargs* to the task factory set by :func:`asyncio.loop.set_task_factory`.
 | 
				
			||||||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue