From bd4a06ad4b4fc4de830fd41ccd66c323d2998875 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Sat, 15 Feb 2025 17:34:42 +0100 Subject: [PATCH] Use TaskGroup from anyioutils --- src/fps/_module.py | 59 ++++++++++++++++++++-------------------------- 1 file changed, 25 insertions(+), 34 deletions(-) diff --git a/src/fps/_module.py b/src/fps/_module.py index ed194df..72c99d0 100644 --- a/src/fps/_module.py +++ b/src/fps/_module.py @@ -9,8 +9,8 @@ import anyio import structlog -from anyio import Event, create_task_group, fail_after, move_on_after -from anyioutils import create_task, wait, FIRST_COMPLETED +from anyio import Event, fail_after, move_on_after +from anyioutils import FIRST_COMPLETED, TaskGroup, create_task, wait from ._container import Container from ._importer import import_from_string @@ -145,19 +145,13 @@ async def get( self, value_type: type[T_Value], timeout: float = float("inf") ) -> T_Value: log.debug("Module getting value", path=self.path, value_type=value_type) - tasks = [create_task(self._container.get(value_type, self), self._task_group)] + tasks = [create_task(self._container.get(value_type, self))] if self.parent is not None: - tasks.append( - create_task( - self.parent._container.get(value_type, self), self._task_group - ) - ) + tasks.append(create_task(self.parent._container.get(value_type, self))) with fail_after(timeout): - done, pending = await wait( - tasks, self._task_group, return_when=FIRST_COMPLETED - ) + done, pending = await wait(tasks, return_when=FIRST_COMPLETED) for task in pending: - task.cancel(raise_exception=False) + task.cancel() for task in done: break value = await task.wait() @@ -173,11 +167,11 @@ async def __aenter__(self) -> Module: log.debug("Running root module", name=self.path) initialize(self) async with AsyncExitStack() as exit_stack: - self._task_group = await exit_stack.enter_async_context(create_task_group()) + self._task_group = await exit_stack.enter_async_context(TaskGroup()) self._exceptions = [] self._phase = "preparing" with move_on_after(self._prepare_timeout) as scope: - self._task_group.start_soon(self._prepare, name=f"{self.path} _prepare") + create_task(self._prepare(), name=f"{self.path} _prepare") await self._all_prepared() if scope.cancelled_caught: self._get_all_prepare_timeout() @@ -186,7 +180,7 @@ async def __aenter__(self) -> Module: else: self._phase = "starting" with move_on_after(self._start_timeout) as scope: - self._task_group.start_soon(self._start, name=f"{self.path} start") + create_task(self._start(), name=f"{self.path} start") await self._all_started() if scope.cancelled_caught: self._get_all_start_timeout() @@ -200,7 +194,7 @@ async def __aenter__(self) -> Module: async def __aexit__(self, exc_type, exc_value, exc_tb): self._phase = "stopping" with move_on_after(self._stop_timeout) as scope: - self._task_group.start_soon(self._stop, name=f"{self.path} stop") + create_task(self._stop(), name=f"{self.path} stop") await self._all_stopped() self._exit.set() if scope.cancelled_caught: @@ -271,14 +265,13 @@ async def _all_stopped(self): async def _prepare(self) -> None: log.debug("Preparing module", path=self.path) try: - async with create_task_group() as tg: + async with TaskGroup(): for module in self._modules.values(): - module._task_group = tg module._phase = self._phase module._exceptions = self._exceptions - tg.start_soon(module._prepare, name=f"{module.path} _prepare") - tg.start_soon( - self._prepare_and_done, name=f"{self.path} _prepare_and_done" + create_task(module._prepare(), name=f"{module.path} _prepare") + create_task( + self._prepare_and_done(), name=f"{self.path} _prepare_and_done" ) except ExceptionGroup as exc: self._exceptions.append(*exc.exceptions) @@ -300,16 +293,16 @@ def done(self) -> None: self._started.set() log.debug("Module started", path=self.path) else: - self._task_group.start_soon(self._finish) + create_task(self._finish()) async def _finish(self): tasks = ( - create_task(self._drop_and_wait_values(), self._task_group), - create_task(self._exit.wait(), self._task_group), + create_task(self._drop_and_wait_values()), + create_task(self._exit.wait()), ) - done, pending = await wait(tasks, self._task_group, return_when=FIRST_COMPLETED) + done, pending = await wait(tasks, return_when=FIRST_COMPLETED) for task in pending: - task.cancel(raise_exception=False) + task.cancel() async def _drop_and_wait_values(self): self.drop_all() @@ -320,12 +313,11 @@ async def _drop_and_wait_values(self): async def _start(self) -> None: log.debug("Starting module", path=self.path) try: - async with create_task_group() as tg: + async with TaskGroup(): for module in self._modules.values(): - module._task_group = tg module._phase = self._phase - tg.start_soon(module._start, name=f"{module.path} _start") - tg.start_soon(self._start_and_done, name=f"{self.path} _start_and_done") + create_task(module._start(), name=f"{module.path} _start") + create_task(self._start_and_done(), name=f"{self.path} _start_and_done") except ExceptionGroup as exc: self._exceptions.append(*exc.exceptions) self._started.set() @@ -340,16 +332,15 @@ async def start(self) -> None: async def _stop(self) -> None: log.debug("Stopping module", path=self.path) try: - async with create_task_group() as tg: + async with TaskGroup(): for module in self._modules.values(): - module._task_group = tg module._phase = self._phase - tg.start_soon(module._stop, name=f"{module.path} _stop") + create_task(module._stop(), name=f"{module.path} _stop") for context_manager_exit in self._context_manager_exits[::-1]: res = context_manager_exit(None, None, None) if isawaitable(res): await res - tg.start_soon(self._stop_and_done, name=f"{self.path} _stop_and_done") + create_task(self._stop_and_done(), name=f"{self.path} _stop_and_done") except ExceptionGroup as exc: self._exceptions.append(*exc.exceptions) self._stopped.set()