Skip to content
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ dev = [
"pytest-mock",
"pytest-asyncio",
"httpx",
"asgi-lifespan",
"hypothesis",
"deepdiff",
"pytest-xdist",
Expand Down
41 changes: 39 additions & 2 deletions src/core/logging.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Utility functions for logging."""

import sys
import time
import uuid
from collections.abc import Awaitable, Callable
from pathlib import Path
Expand All @@ -20,7 +21,11 @@ def setup_log_sinks(configuration_file: Path | None = None) -> None:
sink = sink_configuration.pop("sink")
if sink == "sys.stderr":
sink = sys.stderr
logger.add(sink, serialize=True, **sink_configuration)
# Logs the additionally provided data as JSON.
sink_configuration.setdefault("serialize", True)
# Decouples log calls from I/O and makes it multiprocessing safe.
sink_configuration.setdefault("enqueue", True)
logger.add(sink, **sink_configuration)


async def add_request_context_to_log(
Expand All @@ -29,10 +34,42 @@ async def add_request_context_to_log(
) -> Response:
"""Add a unique request id to each log call."""
identifier = uuid.uuid4().hex
with logger.contextualize(request_id=identifier):
with logger.contextualize(
request_id=identifier,
method=request.method,
path=request.url.path,
):
return await call_next(request)


async def log_request_duration(
    request: Request,
    call_next: Callable[[Request], Awaitable[Response]],
) -> Response:
    """Log wallclock and process time spent while handling a request.

    A monotonic clock and the process-time clock are sampled immediately
    before and after awaiting ``call_next``; the differences are written as
    one info-level record together with the response status code.

    The reported times cannot be attributed solely to this request. Several
    requests may be handled concurrently in the same process, so process time
    may include work done for other requests. Wallclock time has the same
    caveat and is additionally influenced by e.g., context switches.
    """
    ns_per_ms = 1_000_000

    # Sample the monotonic clock first, then the process clock.
    wall_start_ns = time.monotonic_ns()
    cpu_start_ns = time.process_time_ns()

    response: Response = await call_next(request)

    wall_elapsed_ns = time.monotonic_ns() - wall_start_ns
    cpu_elapsed_ns = time.process_time_ns() - cpu_start_ns

    # Millisecond values feed the human-readable message; the raw nanosecond
    # values and the status code are passed along as additional fields.
    logger.info(
        "Request took {mono_ms} ms wallclock time (process time {process_ms} ms)",
        mono_ms=int(wall_elapsed_ns / ns_per_ms),
        process_ms=int(cpu_elapsed_ns / ns_per_ms),
        wallclock_time_ns=wall_elapsed_ns,
        process_time_ns=cpu_elapsed_ns,
        status=response.status_code,
    )
    return response


async def request_response_logger(
request: Request,
call_next: Callable[[Request], Awaitable[Response]],
Expand Down
16 changes: 13 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import argparse
import asyncio
import sys
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
Expand All @@ -10,7 +11,12 @@

from config import load_configuration
from core.errors import ProblemDetailError, problem_detail_exception_handler
from core.logging import add_request_context_to_log, request_response_logger, setup_log_sinks
from core.logging import (
add_request_context_to_log,
log_request_duration,
request_response_logger,
setup_log_sinks,
)
from database.setup import close_databases
from routers.mldcat_ap.dataset import router as mldcat_ap_router
from routers.openml.datasets import router as datasets_router
Expand All @@ -26,10 +32,13 @@


@asynccontextmanager
async def lifespan(app: FastAPI | None) -> AsyncGenerator[None, None]:  # noqa: ARG001
    """Manage application lifespan - startup and shutdown events.

    Nothing happens on startup. On shutdown, ``logger.complete()`` and
    ``close_databases()`` are awaited together.

    ``app`` is unused; ``None`` is accepted so the context manager can be
    entered without an application instance.
    """
    yield
    # `gather` must be awaited: on its own it only schedules the awaitables,
    # and the context manager would exit before they have finished.
    await asyncio.gather(
        logger.complete(),
        close_databases(),
    )


def _parse_args() -> argparse.Namespace:
Expand Down Expand Up @@ -72,6 +81,7 @@ def create_api(configuration_file: Path | None = None) -> FastAPI:
# Order matters! Each added middleware wraps the previous, creating a stack.
# See also: https://fastapi.tiangolo.com/tutorial/middleware/#multiple-middleware-execution-order
app.middleware("http")(request_response_logger)
app.middleware("http")(log_request_duration)
app.middleware("http")(add_request_context_to_log)

app.add_exception_handler(ProblemDetailError, problem_detail_exception_handler) # type: ignore[arg-type]
Expand Down
24 changes: 18 additions & 6 deletions src/routers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Annotated

from fastapi import Depends
from loguru import logger
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncConnection

Expand All @@ -25,21 +26,32 @@ async def userdb_connection() -> AsyncGenerator[AsyncConnection, None]:
async def fetch_user(
    api_key: APIKey | None = None,
    user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None,
) -> AsyncGenerator[User | None, None]:
    """Yield the user identified by `api_key`, or `None` if not authenticated.

    This is an async generator so that, for an authenticated user, the
    remainder of the consumer's work runs inside
    ``logger.contextualize(user_id=...)``.

    Yields:
        ``None`` when no API key or no user database connection is provided,
        otherwise the `User` returned by ``User.fetch``.

    Raises:
        AuthenticationFailedError: if an API key was provided but
            ``User.fetch`` found no matching user.
    """
    if not (api_key and user_data):
        yield None
        return

    user = await User.fetch(api_key, user_data)
    # Only the last four characters of the key are ever written to the log.
    masked_key = api_key[-4:]
    if not user:
        logger.info("Authentication failed.", api_key=masked_key)
        msg = "Invalid API key provided."
        raise AuthenticationFailedError(msg)

    logger.info(
        "User {identifier} authenticated with api key ending in '{api_key}'.",
        identifier=user.user_id,
        api_key=masked_key,
    )
    # Keep the context open across the yield so log calls made while the
    # generator is suspended carry the user id.
    with logger.contextualize(user_id=user.user_id):
        yield user


def fetch_user_or_raise(
    user: Annotated[User | None, Depends(fetch_user)] = None,
) -> User:
    """Return the authenticated user, refusing requests that have none.

    Raises:
        AuthenticationRequiredError: if `user` is ``None``.
    """
    if user is not None:
        return user

    logger.info("Unauthenticated user tried to access endpoint that requires authentication.")
    msg = "No API key provided."
    raise AuthenticationRequiredError(msg)
Expand Down
8 changes: 8 additions & 0 deletions src/routers/openml/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Annotated, Any, Literal, NamedTuple

from fastapi import APIRouter, Body, Depends
from loguru import logger
from sqlalchemy import bindparam, text
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection
Expand Down Expand Up @@ -61,6 +62,7 @@ async def tag_dataset(
raise TagAlreadyExistsError(msg)

await database.datasets.tag(data_id, tag, user_id=user.user_id, connection=expdb_db)
logger.info("Dataset {dataset_id} tagged '{tag}'.", dataset_id=data_id, tag=tag)
return {
"data_tag": {"id": str(data_id), "tag": [*tags, tag]},
}
Expand Down Expand Up @@ -375,6 +377,12 @@ async def update_dataset_status(
msg = f"Unknown status transition: {current_status} -> {status}"
raise InternalError(msg)

logger.info(
"Dataset {dataset_id} changed from {previous} to {current}",
dataset_id=dataset_id,
previous=current_status.status if current_status else DatasetStatus.IN_PREPARATION,
current=status,
)
return {"dataset_id": dataset_id, "status": status}


Expand Down
8 changes: 8 additions & 0 deletions src/routers/openml/setups.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from typing import Annotated

from fastapi import APIRouter, Body, Depends, Path
from loguru import logger
from sqlalchemy.ext.asyncio import AsyncConnection

import database.setups
Expand Down Expand Up @@ -65,6 +66,7 @@ async def tag_setup(
raise TagAlreadyExistsError(msg)

await database.setups.tag(setup_id, tag, user.user_id, expdb_db)
logger.info("Setup {setup_id} tagged '{tag}'.", setup_id=setup_id, tag=tag)
all_tags = [t.tag for t in setup_tags] + [tag]
return {"setup_tag": {"id": str(setup_id), "tag": all_tags}}

Expand Down Expand Up @@ -94,9 +96,15 @@ async def untag_setup(
msg = (
f"You may not remove tag {tag!r} of setup {setup_id} because it was not created by you."
)
logger.warning(
"User attempted to remove tag '{tag}' from setup {setup_id}.",
setup_id=setup_id,
tag=tag,
)
raise TagNotOwnedError(msg)

await database.setups.untag(setup_id, matched_tag_row.tag, expdb_db)
logger.info("Setup {setup_id} had tag '{tag}' removed.", setup_id=setup_id, tag=tag)
remaining_tags = [
t.tag for t in setup_tags if t.tag.casefold() != matched_tag_row.tag.casefold()
]
Expand Down
18 changes: 18 additions & 0 deletions src/routers/openml/study.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated, Literal

from fastapi import APIRouter, Body, Depends
from loguru import logger
from pydantic import BaseModel
from sqlalchemy.engine import Row
from sqlalchemy.ext.asyncio import AsyncConnection
Expand Down Expand Up @@ -73,6 +74,12 @@ async def attach_to_study(
# PHP lets *anyone* edit *any* study. We're not going to do that.
if study.creator != user.user_id and not await user.is_admin():
msg = f"Study {study_id} can only be edited by its creator."
logger.warning(
"User {user_id} attempted to attach entities to study they do not own.",
study_id=study_id,
entity_ids=entity_ids,
user_id=user.user_id,
)
raise StudyNotEditableError(msg)
if study.status != StudyStatus.IN_PREPARATION:
msg = f"Study {study_id} can only be edited while in preparation."
Expand All @@ -93,6 +100,12 @@ async def attach_to_study(
except ValueError as e:
msg = str(e)
raise StudyConflictError(msg) from e
logger.info(
"User {user_id} attached entities to study {study_id}.",
study_id=study_id,
entity_ids=entity_ids,
user_id=user.user_id,
)
return AttachDetachResponse(study_id=study_id, main_entity_type=study.type_)


Expand Down Expand Up @@ -124,6 +137,11 @@ async def create_study(
user=user,
expdb=expdb,
)
logger.info(
"User {user_id} created study {study_id}.",
study_id=study_id,
user_id=user.user_id,
)
# Make sure that invalid fields raise an error (e.g., "task_ids")
return {"study_id": study_id}

Expand Down
12 changes: 10 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib
import json
from collections.abc import AsyncIterator, Iterable, Iterator
from collections.abc import AsyncGenerator, AsyncIterator, Iterable, Iterator
from pathlib import Path
from typing import Any, NamedTuple

Expand All @@ -13,7 +13,7 @@
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine

from database.setup import expdb_database, user_database
from main import create_api
from main import create_api, lifespan
from routers.dependencies import expdb_connection, userdb_connection

PHP_API_URL = "http://php-api:80/api/v1/json"
Expand Down Expand Up @@ -51,6 +51,12 @@ async def temporary_records(
await connection.commit()


@pytest.fixture(autouse=True, scope="session")
async def one_lifespan() -> AsyncGenerator[None, None]:
    """Enter the application's `lifespan` once for the whole test session.

    No application instance is passed (``app=None``); the tests run while the
    context manager is open and it is exited when the session ends.
    """
    async with lifespan(app=None):
        yield


@pytest.fixture
async def expdb_test() -> AsyncIterator[AsyncConnection]:
async with automatic_rollback(expdb_database()) as connection:
Expand Down Expand Up @@ -85,6 +91,8 @@ async def override_userdb() -> AsyncIterator[AsyncConnection]:

app.dependency_overrides[expdb_connection] = override_expdb
app.dependency_overrides[userdb_connection] = override_userdb
# We do not use the Lifespan manager for now because our auto-use fixture
# `one_lifespan` will do setup and teardown at a session scope level instead.
async with httpx.AsyncClient(
transport=httpx.ASGITransport(app=app),
base_url="http://test",
Expand Down
11 changes: 8 additions & 3 deletions tests/dependencies/fetch_user_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from contextlib import aclosing

import pytest
from sqlalchemy.ext.asyncio import AsyncConnection

Expand All @@ -16,19 +18,22 @@
],
)
async def test_fetch_user(api_key: str, user: User, user_test: AsyncConnection) -> None:
db_user = await fetch_user(api_key, user_data=user_test)
async with aclosing(fetch_user(api_key, user_data=user_test)) as agen:
db_user = await anext(agen)
assert isinstance(db_user, User)
assert user.user_id == db_user.user_id
assert set(await user.get_groups()) == set(await db_user.get_groups())


async def test_fetch_user_no_key_no_user() -> None:
    """Without an API key, the first item produced by `fetch_user` is `None`."""
    # `fetch_user` is an async generator: advance it with `anext` instead of
    # awaiting it, and let `aclosing` close it when the block exits.
    async with aclosing(fetch_user(api_key=None)) as agen:
        assert await anext(agen) is None


async def test_fetch_user_invalid_key_raises(user_test: AsyncConnection) -> None:
    """An unknown API key makes `fetch_user` raise `AuthenticationFailedError`."""
    with pytest.raises(AuthenticationFailedError):
        # The error surfaces when the async generator is first advanced, not
        # when it is created.
        async with aclosing(fetch_user(api_key=ApiKey.INVALID, user_data=user_test)) as agen:
            await anext(agen)


async def test_fetch_user_or_raise_raises_if_no_user() -> None:
Expand Down
Loading