Skip to content

Use TaskGroup from anyioutils #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 25 additions & 34 deletions src/fps/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

import anyio
import structlog
from anyio import Event, create_task_group, fail_after, move_on_after
from anyioutils import create_task, wait, FIRST_COMPLETED
from anyio import Event, fail_after, move_on_after
from anyioutils import FIRST_COMPLETED, TaskGroup, create_task, wait

from ._container import Container
from ._importer import import_from_string
Expand Down Expand Up @@ -145,19 +145,13 @@ async def get(
self, value_type: type[T_Value], timeout: float = float("inf")
) -> T_Value:
log.debug("Module getting value", path=self.path, value_type=value_type)
tasks = [create_task(self._container.get(value_type, self), self._task_group)]
tasks = [create_task(self._container.get(value_type, self))]
if self.parent is not None:
tasks.append(
create_task(
self.parent._container.get(value_type, self), self._task_group
)
)
tasks.append(create_task(self.parent._container.get(value_type, self)))
with fail_after(timeout):
done, pending = await wait(
tasks, self._task_group, return_when=FIRST_COMPLETED
)
done, pending = await wait(tasks, return_when=FIRST_COMPLETED)
for task in pending:
task.cancel(raise_exception=False)
task.cancel()
for task in done:
break
value = await task.wait()
Expand All @@ -173,11 +167,11 @@ async def __aenter__(self) -> Module:
log.debug("Running root module", name=self.path)
initialize(self)
async with AsyncExitStack() as exit_stack:
self._task_group = await exit_stack.enter_async_context(create_task_group())
self._task_group = await exit_stack.enter_async_context(TaskGroup())
self._exceptions = []
self._phase = "preparing"
with move_on_after(self._prepare_timeout) as scope:
self._task_group.start_soon(self._prepare, name=f"{self.path} _prepare")
create_task(self._prepare(), name=f"{self.path} _prepare")
await self._all_prepared()
if scope.cancelled_caught:
self._get_all_prepare_timeout()
Expand All @@ -186,7 +180,7 @@ async def __aenter__(self) -> Module:
else:
self._phase = "starting"
with move_on_after(self._start_timeout) as scope:
self._task_group.start_soon(self._start, name=f"{self.path} start")
create_task(self._start(), name=f"{self.path} start")
await self._all_started()
if scope.cancelled_caught:
self._get_all_start_timeout()
Expand All @@ -200,7 +194,7 @@ async def __aenter__(self) -> Module:
async def __aexit__(self, exc_type, exc_value, exc_tb):
self._phase = "stopping"
with move_on_after(self._stop_timeout) as scope:
self._task_group.start_soon(self._stop, name=f"{self.path} stop")
create_task(self._stop(), name=f"{self.path} stop")
await self._all_stopped()
self._exit.set()
if scope.cancelled_caught:
Expand Down Expand Up @@ -271,14 +265,13 @@ async def _all_stopped(self):
async def _prepare(self) -> None:
log.debug("Preparing module", path=self.path)
try:
async with create_task_group() as tg:
async with TaskGroup():
for module in self._modules.values():
module._task_group = tg
module._phase = self._phase
module._exceptions = self._exceptions
tg.start_soon(module._prepare, name=f"{module.path} _prepare")
tg.start_soon(
self._prepare_and_done, name=f"{self.path} _prepare_and_done"
create_task(module._prepare(), name=f"{module.path} _prepare")
create_task(
self._prepare_and_done(), name=f"{self.path} _prepare_and_done"
)
except ExceptionGroup as exc:
self._exceptions.append(*exc.exceptions)
Expand All @@ -300,16 +293,16 @@ def done(self) -> None:
self._started.set()
log.debug("Module started", path=self.path)
else:
self._task_group.start_soon(self._finish)
create_task(self._finish())

async def _finish(self):
tasks = (
create_task(self._drop_and_wait_values(), self._task_group),
create_task(self._exit.wait(), self._task_group),
create_task(self._drop_and_wait_values()),
create_task(self._exit.wait()),
)
done, pending = await wait(tasks, self._task_group, return_when=FIRST_COMPLETED)
done, pending = await wait(tasks, return_when=FIRST_COMPLETED)
for task in pending:
task.cancel(raise_exception=False)
task.cancel()

async def _drop_and_wait_values(self):
self.drop_all()
Expand All @@ -320,12 +313,11 @@ async def _drop_and_wait_values(self):
async def _start(self) -> None:
log.debug("Starting module", path=self.path)
try:
async with create_task_group() as tg:
async with TaskGroup():
for module in self._modules.values():
module._task_group = tg
module._phase = self._phase
tg.start_soon(module._start, name=f"{module.path} _start")
tg.start_soon(self._start_and_done, name=f"{self.path} _start_and_done")
create_task(module._start(), name=f"{module.path} _start")
create_task(self._start_and_done(), name=f"{self.path} _start_and_done")
except ExceptionGroup as exc:
self._exceptions.append(*exc.exceptions)
self._started.set()
Expand All @@ -340,16 +332,15 @@ async def start(self) -> None:
async def _stop(self) -> None:
log.debug("Stopping module", path=self.path)
try:
async with create_task_group() as tg:
async with TaskGroup():
for module in self._modules.values():
module._task_group = tg
module._phase = self._phase
tg.start_soon(module._stop, name=f"{module.path} _stop")
create_task(module._stop(), name=f"{module.path} _stop")
for context_manager_exit in self._context_manager_exits[::-1]:
res = context_manager_exit(None, None, None)
if isawaitable(res):
await res
tg.start_soon(self._stop_and_done, name=f"{self.path} _stop_and_done")
create_task(self._stop_and_done(), name=f"{self.path} _stop_and_done")
except ExceptionGroup as exc:
self._exceptions.append(*exc.exceptions)
self._stopped.set()
Expand Down
Loading