Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import argparse
import asyncio
import sys
from collections.abc import AsyncGenerator
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path

Expand Down Expand Up @@ -32,10 +32,12 @@


@asynccontextmanager
async def lifespan(app: FastAPI | None) -> AsyncGenerator[None, None]: # noqa: ARG001
async def lifespan(
app: FastAPI | None, # noqa: ARG001 # parameter required by FastAPI/Starlette
) -> AsyncIterator[None]:
"""Manage application lifespan - startup and shutdown events."""
yield
asyncio.gather(
await asyncio.gather(
logger.complete(),
close_databases(),
)
Expand Down
8 changes: 4 additions & 4 deletions src/routers/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Annotated

from fastapi import Depends
Expand All @@ -11,13 +11,13 @@
from database.users import APIKey, User


async def expdb_connection() -> AsyncGenerator[AsyncConnection, None]:
async def expdb_connection() -> AsyncIterator[AsyncConnection]:
engine = expdb_database()
async with engine.connect() as connection, connection.begin():
yield connection


async def userdb_connection() -> AsyncGenerator[AsyncConnection, None]:
async def userdb_connection() -> AsyncIterator[AsyncConnection]:
engine = user_database()
async with engine.connect() as connection, connection.begin():
yield connection
Expand All @@ -26,7 +26,7 @@ async def userdb_connection() -> AsyncGenerator[AsyncConnection, None]:
async def fetch_user(
api_key: APIKey | None = None,
user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None,
) -> AsyncGenerator[User | None, None]:
) -> AsyncGenerator[User | None]:
if not (api_key and user_data):
yield None
return
Expand Down
42 changes: 27 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import json
from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator
from collections.abc import AsyncIterator, Iterable, Iterator
from pathlib import Path
from typing import Any, NamedTuple

Expand All @@ -9,11 +9,13 @@
import pytest
from _pytest.config import Config
from _pytest.nodes import Item
from asgi_lifespan import LifespanManager
from fastapi import FastAPI
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine

from database.setup import expdb_database, user_database
from main import create_api, lifespan
from main import create_api
from routers.dependencies import expdb_connection, userdb_connection

PHP_API_URL = "http://php-api:80/api/v1/json"
Expand Down Expand Up @@ -51,12 +53,6 @@ async def temporary_records(
await connection.commit()


@pytest.fixture(autouse=True, scope="session")
async def one_lifespan() -> AsyncGenerator[None, None]:
async with lifespan(app=None):
yield


@pytest.fixture
async def expdb_test() -> AsyncIterator[AsyncConnection]:
async with automatic_rollback(expdb_database()) as connection:
Expand All @@ -69,20 +65,34 @@ async def user_test() -> AsyncIterator[AsyncConnection]:
yield connection


@pytest.fixture
# The PHP API fixture can be session scoped since they do not need access to
# function-scoped database transactions.
@pytest.fixture(scope="session")
async def php_api() -> AsyncIterator[httpx.AsyncClient]:
async with httpx.AsyncClient(base_url=PHP_API_URL) as client:
yield client


@pytest.fixture(scope="session")
async def app() -> AsyncIterator[FastAPI]:
_app = create_api(Path(__file__).parent / "config.test.toml")
async with LifespanManager(_app):
yield _app


@pytest.fixture
async def py_api(
expdb_test: AsyncConnection, user_test: AsyncConnection
expdb_test: AsyncConnection, user_test: AsyncConnection, app: FastAPI
) -> AsyncIterator[httpx.AsyncClient]:
app = create_api(Path(__file__).parent / "config.test.toml")
"""Create test client which automatically rolls back database updates on teardown."""
# Using the function-scoped database fixtures automatically benefits the
# automatic rollbacks, but also lets a test author write to a database
# transaction that is shared with the app. That is, it enables:
#
# def my_test(expdb_test, py_api):
# expdb_test.execute(...) # write some data # noqa: ERA001
# py_api.get(...) # read that data # noqa: ERA001

# We use async generator functions because fixtures may not be called directly.
# The async generator returns the test connections for FastAPI to handle properly
async def override_expdb() -> AsyncIterator[AsyncConnection]:
yield expdb_test

Expand All @@ -91,15 +101,17 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]:

app.dependency_overrides[expdb_connection] = override_expdb
app.dependency_overrides[userdb_connection] = override_userdb
# We do not use the Lifespan manager for now because our auto-use fixture
# `one_lifespan` will do setup and teardown at a session scope level instead.

async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://test",
follow_redirects=True,
) as client:
yield client

app.dependency_overrides[expdb_connection] = expdb_connection
app.dependency_overrides[userdb_connection] = userdb_connection


@pytest.fixture
def dataset_130() -> Iterator[dict[str, Any]]:
Expand Down
4 changes: 2 additions & 2 deletions tests/routers/openml/migration/setups_migration_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import contextlib
import re
from collections.abc import AsyncGenerator, Callable, Iterable
from collections.abc import AsyncIterator, Callable, Iterable
from contextlib import AbstractAsyncContextManager
from http import HTTPStatus

Expand All @@ -22,7 +22,7 @@ def temporary_tags(
@contextlib.asynccontextmanager
async def _temporary_tags(
tags: Iterable[str], setup_id: int, *, persist: bool = False
) -> AsyncGenerator[None]:
) -> AsyncIterator[None]:
insert_queries = [
(
"INSERT INTO setup_tag(`id`,`tag`,`uploader`) VALUES (:setup_id, :tag, :user_id);",
Expand Down
Loading