diff --git a/src/main.py b/src/main.py index 93ad1ac..3be2c5c 100644 --- a/src/main.py +++ b/src/main.py @@ -1,7 +1,7 @@ import argparse import asyncio import sys -from collections.abc import AsyncGenerator +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from pathlib import Path @@ -32,10 +32,12 @@ @asynccontextmanager -async def lifespan(app: FastAPI | None) -> AsyncGenerator[None, None]: # noqa: ARG001 +async def lifespan( + app: FastAPI | None, # noqa: ARG001 # parameter required by FastAPI/Starlette +) -> AsyncIterator[None]: """Manage application lifespan - startup and shutdown events.""" yield - asyncio.gather( + await asyncio.gather( logger.complete(), close_databases(), ) diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index c5623fb..a770a45 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, AsyncIterator from typing import Annotated from fastapi import Depends @@ -11,13 +11,13 @@ from database.users import APIKey, User -async def expdb_connection() -> AsyncGenerator[AsyncConnection, None]: +async def expdb_connection() -> AsyncIterator[AsyncConnection]: engine = expdb_database() async with engine.connect() as connection, connection.begin(): yield connection -async def userdb_connection() -> AsyncGenerator[AsyncConnection, None]: +async def userdb_connection() -> AsyncIterator[AsyncConnection]: engine = user_database() async with engine.connect() as connection, connection.begin(): yield connection @@ -26,7 +26,7 @@ async def userdb_connection() -> AsyncGenerator[AsyncConnection, None]: async def fetch_user( api_key: APIKey | None = None, user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, -) -> AsyncGenerator[User | None, None]: +) -> AsyncGenerator[User | None]: if not (api_key and user_data): yield None return diff --git a/tests/conftest.py b/tests/conftest.py index 483206e..ad86ce4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ import contextlib import json -from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator +from collections.abc import AsyncIterator, Iterable, Iterator from pathlib import Path from typing import Any, NamedTuple @@ -9,11 +9,13 @@ import pytest from _pytest.config import Config from _pytest.nodes import Item +from asgi_lifespan import LifespanManager +from fastapi import FastAPI from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine from database.setup import expdb_database, user_database -from main import create_api, lifespan +from main import create_api from routers.dependencies import expdb_connection, userdb_connection PHP_API_URL = "http://php-api:80/api/v1/json" @@ -51,12 +53,6 @@ async def temporary_records( await connection.commit() -@pytest.fixture(autouse=True, scope="session") -async def one_lifespan() -> AsyncGenerator[None, None]: - async with lifespan(app=None): - yield - - @pytest.fixture async def expdb_test() -> AsyncIterator[AsyncConnection]: async with automatic_rollback(expdb_database()) as connection: @@ -69,20 +65,34 @@ async def user_test() -> AsyncIterator[AsyncConnection]: yield connection -@pytest.fixture +# The PHP API fixture can be session scoped since they do not need access to +# function-scoped database transactions. +@pytest.fixture(scope="session") async def php_api() -> AsyncIterator[httpx.AsyncClient]: async with httpx.AsyncClient(base_url=PHP_API_URL) as client: yield client +@pytest.fixture(scope="session") +async def app() -> AsyncIterator[FastAPI]: + _app = create_api(Path(__file__).parent / "config.test.toml") + async with LifespanManager(_app): + yield _app + + @pytest.fixture async def py_api( - expdb_test: AsyncConnection, user_test: AsyncConnection + expdb_test: AsyncConnection, user_test: AsyncConnection, app: FastAPI ) -> AsyncIterator[httpx.AsyncClient]: - app = create_api(Path(__file__).parent / "config.test.toml") + """Create test client which automatically rolls back database updates on teardown.""" + # Using the function-scoped database fixtures automatically benefits the + # automatic rollbacks, but also lets a test author write to a database + # transaction that is shared with the app. That is, it enables: + # + # def my_test(expdb_test, py_api): + # expdb_test.execute(...) # write some data # noqa: ERA001 + # py_api.get(...) # read that data # noqa: ERA001 - # We use async generator functions because fixtures may not be called directly. - # The async generator returns the test connections for FastAPI to handle properly async def override_expdb() -> AsyncIterator[AsyncConnection]: yield expdb_test @@ -91,8 +101,7 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]: app.dependency_overrides[expdb_connection] = override_expdb app.dependency_overrides[userdb_connection] = override_userdb - # We do not use the Lifespan manager for now because our auto-use fixture - # `one_lifespan` will do setup and teardown at a session scope level instead. + async with httpx.AsyncClient( transport=httpx.ASGITransport(app=app), base_url="http://test", @@ -100,6 +109,9 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]: ) as client: yield client + app.dependency_overrides[expdb_connection] = expdb_connection + app.dependency_overrides[userdb_connection] = userdb_connection + @pytest.fixture def dataset_130() -> Iterator[dict[str, Any]]: diff --git a/tests/routers/openml/migration/setups_migration_test.py b/tests/routers/openml/migration/setups_migration_test.py index a042661..37c8bc1 100644 --- a/tests/routers/openml/migration/setups_migration_test.py +++ b/tests/routers/openml/migration/setups_migration_test.py @@ -1,7 +1,7 @@ import asyncio import contextlib import re -from collections.abc import AsyncGenerator, Callable, Iterable +from collections.abc import AsyncIterator, Callable, Iterable from contextlib import AbstractAsyncContextManager from http import HTTPStatus @@ -22,7 +22,7 @@ def temporary_tags( @contextlib.asynccontextmanager async def _temporary_tags( tags: Iterable[str], setup_id: int, *, persist: bool = False - ) -> AsyncGenerator[None]: + ) -> AsyncIterator[None]: insert_queries = [ ( "INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);",