From f13001781d00ba86a4f1c377b105b962b258b92d Mon Sep 17 00:00:00 2001 From: Benjamin Barrera-Altuna Date: Tue, 28 Apr 2026 22:16:50 -0400 Subject: [PATCH 1/3] Pass Reflex app instance to lifespan tasks # Conflicts: # tests/units/app_mixins/test_lifespan.py --- docs/utility_methods/lifespan_tasks.md | 7 ++- reflex/app_mixins/lifespan.py | 8 ++- tests/units/app_mixins/test_lifespan.py | 72 +++++++++++++++++++++++++ 3 files changed, 83 insertions(+), 4 deletions(-) diff --git a/docs/utility_methods/lifespan_tasks.md b/docs/utility_methods/lifespan_tasks.md index d053b54d6e2..9945e121a36 100644 --- a/docs/utility_methods/lifespan_tasks.md +++ b/docs/utility_methods/lifespan_tasks.md @@ -40,8 +40,11 @@ async def long_running_task(foo, bar): To register a lifespan task, use `app.register_lifespan_task(coro_func, **kwargs)`. Any keyword arguments specified during registration will be passed to the task. -If the task accepts the special argument, `app`, it will be passed the `Starlette` -application instance. +If the task accepts the special argument, `app`, it will be passed the Reflex app +instance (`rx.App`/`LifespanMixin`). + +If the task accepts the special argument, `starlette_app`, it will be passed the +underlying `Starlette` application instance. ```python app = rx.App() diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index bac00517bb0..faf21960371 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -87,7 +87,7 @@ def get_lifespan_tasks(self) -> tuple[asyncio.Task | Callable, ...]: return tuple(self._lifespan_tasks) @contextlib.asynccontextmanager - async def _run_lifespan_tasks(self, app: Starlette): + async def _run_lifespan_tasks(self, starlette_app: Starlette): self._lifespan_tasks_started = True running_tasks = [] try: @@ -100,7 +100,11 @@ async def _run_lifespan_tasks(self, app: Starlette): else: signature = inspect.signature(task) if "app" in signature.parameters: - task = functools.partial(task, app=app) + task = functools.partial(task, app=self) + if "starlette_app" in signature.parameters: + task = functools.partial( + task, starlette_app=starlette_app + ) t_ = task() if isinstance(t_, contextlib._AsyncGeneratorContextManager): await stack.enter_async_context(t_) diff --git a/tests/units/app_mixins/test_lifespan.py b/tests/units/app_mixins/test_lifespan.py index d1d2f38bdd4..f0fef56c2ab 100644 --- a/tests/units/app_mixins/test_lifespan.py +++ b/tests/units/app_mixins/test_lifespan.py @@ -7,6 +7,7 @@ import pytest from reflex_base.utils.exceptions import InvalidLifespanTaskTypeError +from starlette.applications import Starlette from reflex.app_mixins.lifespan import LifespanMixin @@ -38,6 +39,7 @@ def check_for_updates(timeout: int) -> int: assert registered_task() == 10 +@pytest.mark.asyncio async def test_register_lifespan_task_rejects_kwargs_for_asyncio_task(): """Registering kwargs against an asyncio.Task raises a clear error.""" mixin = LifespanMixin() @@ -53,3 +55,73 @@ async def test_register_lifespan_task_rejects_kwargs_for_asyncio_task(): task.cancel() with contextlib.suppress(asyncio.CancelledError): await task + + +@pytest.mark.asyncio +async def test_lifespan_task_app_param_receives_reflex_app_instance(): + """Lifespan tasks should receive the Reflex app instance, not Starlette.""" + + class DummyApp(LifespanMixin): + """Minimal test app based on the lifespan mixin.""" + + app = DummyApp() + received: dict[str, object] = {} + + def lifespan_task(app): + """Record the app argument injected by the lifespan runner.""" + received["app"] = app + + app.register_lifespan_task(lifespan_task) + + async with app._run_lifespan_tasks(Starlette()): + await asyncio.sleep(0) + + assert received["app"] is app + + +@pytest.mark.asyncio +async def test_lifespan_task_starlette_app_param_receives_starlette_instance(): + """Lifespan tasks should receive the Starlette app when requested.""" + + class DummyApp(LifespanMixin): + """Minimal test app based on the lifespan mixin.""" + + app = DummyApp() + received: dict[str, object] = {} + starlette_app = Starlette() + + def lifespan_task(starlette_app): + """Record the Starlette app argument injected by the lifespan runner.""" + received["starlette_app"] = starlette_app + + app.register_lifespan_task(lifespan_task) + + async with app._run_lifespan_tasks(starlette_app): + await asyncio.sleep(0) + + assert received["starlette_app"] is starlette_app + + +@pytest.mark.asyncio +async def test_lifespan_task_both_app_and_starlette_app_params_are_injected(): + """Lifespan tasks should receive both app and starlette_app when declared.""" + + class DummyApp(LifespanMixin): + """Minimal test app based on the lifespan mixin.""" + + app = DummyApp() + received: dict[str, object] = {} + starlette_app = Starlette() + + def lifespan_task(app, starlette_app): + """Record both injected app objects from the lifespan runner.""" + received["app"] = app + received["starlette_app"] = starlette_app + + app.register_lifespan_task(lifespan_task) + + async with app._run_lifespan_tasks(starlette_app): + await asyncio.sleep(0) + + assert received["app"] is app + assert received["starlette_app"] is starlette_app From 73913f57c286166963b25ddcfa9e15789ba512bb Mon Sep 17 00:00:00 2001 From: Benjamin Barrera-Altuna Date: Tue, 28 Apr 2026 22:17:19 -0400 Subject: [PATCH 2/3] test(lifespan): cover starlette_app injection paths # Conflicts: # tests/units/app_mixins/test_lifespan.py --- tests/units/app_mixins/test_lifespan.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/units/app_mixins/test_lifespan.py b/tests/units/app_mixins/test_lifespan.py index f0fef56c2ab..1e8dcef6b24 100644 --- a/tests/units/app_mixins/test_lifespan.py +++ b/tests/units/app_mixins/test_lifespan.py @@ -91,7 +91,11 @@ class DummyApp(LifespanMixin): starlette_app = Starlette() def lifespan_task(starlette_app): - """Record the Starlette app argument injected by the lifespan runner.""" + """Record the Starlette app argument injected by the lifespan runner. + + Args: + starlette_app: Starlette app object injected by the lifespan runner. + """ received["starlette_app"] = starlette_app app.register_lifespan_task(lifespan_task) @@ -114,7 +118,12 @@ class DummyApp(LifespanMixin): starlette_app = Starlette() def lifespan_task(app, starlette_app): - """Record both injected app objects from the lifespan runner.""" + """Record both injected app objects from the lifespan runner. + + Args: + app: Reflex app object injected by the lifespan runner. + starlette_app: Starlette app object injected by the lifespan runner. + """ received["app"] = app received["starlette_app"] = starlette_app From b7e06dd51fef28e91b9b217124bfe49ad8073f04 Mon Sep 17 00:00:00 2001 From: Benjamin Barrera-Altuna Date: Tue, 28 Apr 2026 22:40:25 -0400 Subject: [PATCH 3/3] style(lifespan): apply ruff formatting for pre-commit --- reflex/app_mixins/lifespan.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/reflex/app_mixins/lifespan.py b/reflex/app_mixins/lifespan.py index faf21960371..0f42c583580 100644 --- a/reflex/app_mixins/lifespan.py +++ b/reflex/app_mixins/lifespan.py @@ -102,9 +102,7 @@ async def _run_lifespan_tasks(self, starlette_app: Starlette): if "app" in signature.parameters: task = functools.partial(task, app=self) if "starlette_app" in signature.parameters: - task = functools.partial( - task, starlette_app=starlette_app - ) + task = functools.partial(task, starlette_app=starlette_app) t_ = task() if isinstance(t_, contextlib._AsyncGeneratorContextManager): await stack.enter_async_context(t_)