mirror of
https://github.com/python/cpython.git
synced 2025-08-27 12:16:04 +00:00
gh-128308: pass **kwargs
to asyncio task_factory (#128768)
Co-authored-by: Kumar Aditya <kumaraditya@python.org>
This commit is contained in:
parent
6c914bf85c
commit
38a9956876
8 changed files with 48 additions and 29 deletions
|
@ -392,9 +392,9 @@ Creating Futures and Tasks
|
||||||
|
|
||||||
If *factory* is ``None`` the default task factory will be set.
|
If *factory* is ``None`` the default task factory will be set.
|
||||||
Otherwise, *factory* must be a *callable* with the signature matching
|
Otherwise, *factory* must be a *callable* with the signature matching
|
||||||
``(loop, coro, context=None)``, where *loop* is a reference to the active
|
``(loop, coro, **kwargs)``, where *loop* is a reference to the active
|
||||||
event loop, and *coro* is a coroutine object. The callable
|
event loop, and *coro* is a coroutine object. The callable
|
||||||
must return a :class:`asyncio.Future`-compatible object.
|
must pass on all *kwargs*, and return a :class:`asyncio.Task`-compatible object.
|
||||||
|
|
||||||
.. method:: loop.get_task_factory()
|
.. method:: loop.get_task_factory()
|
||||||
|
|
||||||
|
|
|
@ -458,25 +458,18 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
"""Create a Future object attached to the loop."""
|
"""Create a Future object attached to the loop."""
|
||||||
return futures.Future(loop=self)
|
return futures.Future(loop=self)
|
||||||
|
|
||||||
def create_task(self, coro, *, name=None, context=None):
|
def create_task(self, coro, **kwargs):
|
||||||
"""Schedule a coroutine object.
|
"""Schedule a coroutine object.
|
||||||
|
|
||||||
Return a task object.
|
Return a task object.
|
||||||
"""
|
"""
|
||||||
self._check_closed()
|
self._check_closed()
|
||||||
if self._task_factory is None:
|
if self._task_factory is not None:
|
||||||
task = tasks.Task(coro, loop=self, name=name, context=context)
|
return self._task_factory(self, coro, **kwargs)
|
||||||
if task._source_traceback:
|
|
||||||
del task._source_traceback[-1]
|
|
||||||
else:
|
|
||||||
if context is None:
|
|
||||||
# Use legacy API if context is not needed
|
|
||||||
task = self._task_factory(self, coro)
|
|
||||||
else:
|
|
||||||
task = self._task_factory(self, coro, context=context)
|
|
||||||
|
|
||||||
task.set_name(name)
|
|
||||||
|
|
||||||
|
task = tasks.Task(coro, loop=self, **kwargs)
|
||||||
|
if task._source_traceback:
|
||||||
|
del task._source_traceback[-1]
|
||||||
try:
|
try:
|
||||||
return task
|
return task
|
||||||
finally:
|
finally:
|
||||||
|
@ -490,9 +483,10 @@ class BaseEventLoop(events.AbstractEventLoop):
|
||||||
If factory is None the default task factory will be set.
|
If factory is None the default task factory will be set.
|
||||||
|
|
||||||
If factory is a callable, it should have a signature matching
|
If factory is a callable, it should have a signature matching
|
||||||
'(loop, coro)', where 'loop' will be a reference to the active
|
'(loop, coro, **kwargs)', where 'loop' will be a reference to the active
|
||||||
event loop, 'coro' will be a coroutine object. The callable
|
event loop, 'coro' will be a coroutine object, and **kwargs will be
|
||||||
must return a Future.
|
arbitrary keyword arguments that should be passed on to Task.
|
||||||
|
The callable must return a Task.
|
||||||
"""
|
"""
|
||||||
if factory is not None and not callable(factory):
|
if factory is not None and not callable(factory):
|
||||||
raise TypeError('task factory must be a callable or None')
|
raise TypeError('task factory must be a callable or None')
|
||||||
|
|
|
@ -329,7 +329,7 @@ class AbstractEventLoop:
|
||||||
|
|
||||||
# Method scheduling a coroutine object: create a task.
|
# Method scheduling a coroutine object: create a task.
|
||||||
|
|
||||||
def create_task(self, coro, *, name=None, context=None):
|
def create_task(self, coro, **kwargs):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
# Methods for interacting with threads.
|
# Methods for interacting with threads.
|
||||||
|
|
|
@ -833,8 +833,8 @@ class BaseEventLoopTests(test_utils.TestCase):
|
||||||
loop.close()
|
loop.close()
|
||||||
|
|
||||||
def test_create_named_task_with_custom_factory(self):
|
def test_create_named_task_with_custom_factory(self):
|
||||||
def task_factory(loop, coro):
|
def task_factory(loop, coro, **kwargs):
|
||||||
return asyncio.Task(coro, loop=loop)
|
return asyncio.Task(coro, loop=loop, **kwargs)
|
||||||
|
|
||||||
async def test():
|
async def test():
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -302,6 +302,18 @@ class CEagerTaskFactoryLoopTests(EagerTaskFactoryLoopTests, test_utils.TestCase)
|
||||||
|
|
||||||
self.run_coro(run())
|
self.run_coro(run())
|
||||||
|
|
||||||
|
def test_name(self):
|
||||||
|
name = None
|
||||||
|
async def coro():
|
||||||
|
nonlocal name
|
||||||
|
name = asyncio.current_task().get_name()
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
task = self.loop.create_task(coro(), name="test name")
|
||||||
|
self.assertEqual(name, "test name")
|
||||||
|
await task
|
||||||
|
|
||||||
|
self.run_coro(coro())
|
||||||
|
|
||||||
class AsyncTaskCounter:
|
class AsyncTaskCounter:
|
||||||
def __init__(self, loop, *, task_class, eager):
|
def __init__(self, loop, *, task_class, eager):
|
||||||
|
|
|
@ -112,8 +112,8 @@ class TestPyFreeThreading(TestFreeThreading, TestCase):
|
||||||
all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
|
all_tasks = staticmethod(asyncio.tasks._py_all_tasks)
|
||||||
current_task = staticmethod(asyncio.tasks._py_current_task)
|
current_task = staticmethod(asyncio.tasks._py_current_task)
|
||||||
|
|
||||||
def factory(self, loop, coro, context=None):
|
def factory(self, loop, coro, **kwargs):
|
||||||
return asyncio.tasks._PyTask(coro, loop=loop, context=context)
|
return asyncio.tasks._PyTask(coro, loop=loop, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
|
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
|
||||||
|
@ -121,16 +121,16 @@ class TestCFreeThreading(TestFreeThreading, TestCase):
|
||||||
all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
|
all_tasks = staticmethod(getattr(asyncio.tasks, "_c_all_tasks", None))
|
||||||
current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
|
current_task = staticmethod(getattr(asyncio.tasks, "_c_current_task", None))
|
||||||
|
|
||||||
def factory(self, loop, coro, context=None):
|
def factory(self, loop, coro, **kwargs):
|
||||||
return asyncio.tasks._CTask(coro, loop=loop, context=context)
|
return asyncio.tasks._CTask(coro, loop=loop, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TestEagerPyFreeThreading(TestPyFreeThreading):
|
class TestEagerPyFreeThreading(TestPyFreeThreading):
|
||||||
def factory(self, loop, coro, context=None):
|
def factory(self, loop, coro, eager_start=True, **kwargs):
|
||||||
return asyncio.tasks._PyTask(coro, loop=loop, context=context, eager_start=True)
|
return asyncio.tasks._PyTask(coro, loop=loop, **kwargs, eager_start=eager_start)
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
|
@unittest.skipUnless(hasattr(asyncio.tasks, "_c_all_tasks"), "requires _asyncio")
|
||||||
class TestEagerCFreeThreading(TestCFreeThreading, TestCase):
|
class TestEagerCFreeThreading(TestCFreeThreading, TestCase):
|
||||||
def factory(self, loop, coro, context=None):
|
def factory(self, loop, coro, eager_start=True, **kwargs):
|
||||||
return asyncio.tasks._CTask(coro, loop=loop, context=context, eager_start=True)
|
return asyncio.tasks._CTask(coro, loop=loop, **kwargs, eager_start=eager_start)
|
||||||
|
|
|
@ -1040,6 +1040,18 @@ class BaseTestTaskGroup:
|
||||||
self.assertIsNotNone(exc)
|
self.assertIsNotNone(exc)
|
||||||
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
self.assertListEqual(gc.get_referrers(exc), no_other_refs())
|
||||||
|
|
||||||
|
async def test_name(self):
|
||||||
|
name = None
|
||||||
|
|
||||||
|
async def asyncfn():
|
||||||
|
nonlocal name
|
||||||
|
name = asyncio.current_task().get_name()
|
||||||
|
|
||||||
|
async with asyncio.TaskGroup() as tg:
|
||||||
|
tg.create_task(asyncfn(), name="example name")
|
||||||
|
|
||||||
|
self.assertEqual(name, "example name")
|
||||||
|
|
||||||
|
|
||||||
class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
|
class TestTaskGroup(BaseTestTaskGroup, unittest.IsolatedAsyncioTestCase):
|
||||||
loop_factory = asyncio.EventLoop
|
loop_factory = asyncio.EventLoop
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
Support the *name* keyword argument for eager tasks in :func:`asyncio.loop.create_task`, :func:`asyncio.create_task` and :func:`asyncio.TaskGroup.create_task`, by passing on all *kwargs* to the task factory set by :func:`asyncio.loop.set_task_factory`.
|
Loading…
Add table
Add a link
Reference in a new issue