From ba7ff9b7837cf6ed1b8820433eff035476ba1a22 Mon Sep 17 00:00:00 2001 From: MohammadPatelNHS <247976665+MohammadPatelNHS@users.noreply.github.com> Date: Thu, 16 Apr 2026 09:35:43 +0000 Subject: [PATCH] [CDAPI-72]: Implemented pdm client --- .github/workflows/preview-env.yaml | 4 +- .../images/pathology-api/resources/.gitignore | 1 + mocks/src/apim_mock/auth_check.py | 10 +- mocks/src/apim_mock/handler.py | 11 +- mocks/src/common/test_utils.py | 9 + mocks/src/common/utils.py | 8 + mocks/src/pdm_mock/handler.py | 32 ++- mocks/src/pdm_mock/test_handler.py | 54 ++++- pathology-api/lambda_handler.py | 44 +++- pathology-api/src/pathology_api/handler.py | 22 +- pathology-api/src/pathology_api/http.py | 2 +- pathology-api/src/pathology_api/logging.py | 2 +- pathology-api/src/pathology_api/pdm.py | 62 ++++++ .../src/pathology_api/request_context.py | 23 +- pathology-api/src/pathology_api/test_apim.py | 12 +- .../src/pathology_api/test_handler.py | 54 ++++- pathology-api/src/pathology_api/test_http.py | 11 +- .../src/pathology_api/test_logging.py | 67 ++++-- pathology-api/src/pathology_api/test_mns.py | 13 +- pathology-api/src/pathology_api/test_pdm.py | 200 ++++++++++++++++++ .../src/pathology_api/test_request_context.py | 51 +++-- pathology-api/test_lambda_handler.py | 53 ++++- .../tests/integration/test_endpoints.py | 4 +- 23 files changed, 636 insertions(+), 113 deletions(-) create mode 100644 mocks/src/common/test_utils.py create mode 100644 mocks/src/common/utils.py create mode 100644 pathology-api/src/pathology_api/pdm.py create mode 100644 pathology-api/src/pathology_api/test_pdm.py diff --git a/.github/workflows/preview-env.yaml b/.github/workflows/preview-env.yaml index a11d6d02..d29ab8f0 100644 --- a/.github/workflows/preview-env.yaml +++ b/.github/workflows/preview-env.yaml @@ -182,7 +182,7 @@ jobs: APIM_MTLS_KEY_NAME=$MTLS_KEY, \ APIM_KEY_ID=$KEY_ID, \ APIM_TOKEN_URL=$MOCK_URL/apim/oauth2/token, \ - PDM_BUNDLE_URL=$MOCK_URL/apim/check_auth, \ + PDM_BUNDLE_URL=$MOCK_URL/pdm/FHIR/R4/Bundle, \ MNS_EVENT_URL=$MOCK_URL/mns/events, \ CLIENT_TIMEOUT=$CLIENT_TIMEOUT, \ JWKS_SECRET_NAME=$JWKS_SECRET}" || true @@ -205,7 +205,7 @@ jobs: APIM_MTLS_CERT_NAME=$MTLS_CERT, \ APIM_MTLS_KEY_NAME=$MTLS_KEY, \ APIM_TOKEN_URL=$MOCK_URL/apim/oauth2/token, \ - PDM_BUNDLE_URL=$MOCK_URL/apim/check_auth, \ + PDM_BUNDLE_URL=$MOCK_URL/pdm/FHIR/R4/Bundle, \ MNS_EVENT_URL=$MOCK_URL/mns/events, \ CLIENT_TIMEOUT=$CLIENT_TIMEOUT, \ JWKS_SECRET_NAME=$JWKS_SECRET}" \ diff --git a/infrastructure/images/pathology-api/resources/.gitignore b/infrastructure/images/pathology-api/resources/.gitignore index 796b96d1..1e60e223 100644 --- a/infrastructure/images/pathology-api/resources/.gitignore +++ b/infrastructure/images/pathology-api/resources/.gitignore @@ -1 +1,2 @@ /build +/.aws diff --git a/mocks/src/apim_mock/auth_check.py b/mocks/src/apim_mock/auth_check.py index bae56a39..1ae60109 100644 --- a/mocks/src/apim_mock/auth_check.py +++ b/mocks/src/apim_mock/auth_check.py @@ -2,19 +2,14 @@ from typing import Any from boto3.dynamodb.conditions import Attr +from common.logging import get_logger from common.storage_helper import StorageHelper -JWT_ALGORITHMS = ["RS512"] -REQUESTS_TIMEOUT = 5 -DEFAULT_TOKEN_LIFETIME = 599 - -AUTH_URL = os.environ["AUTH_URL"] -PUBLIC_KEY_URL = os.environ["PUBLIC_KEY_URL"] -API_KEY = os.environ["API_KEY"] TOKEN_TABLE_NAME = os.environ["TOKEN_TABLE_NAME"] BRANCH_NAME = os.environ["DDB_INDEX_TAG"] storage_helper = StorageHelper(TOKEN_TABLE_NAME, BRANCH_NAME) +_logger = get_logger(__name__) class AuthenticationError(Exception): @@ -24,6 +19,7 @@ class AuthenticationError(Exception): def check_authenticated(request_headers: dict[str, Any]) -> None: auth_token = request_headers.get("Authorization", "").replace("Bearer ", "") + _logger.debug("Querying DynamoDB table for access token") filter_expression = Attr("access_token").eq(auth_token) query_result = storage_helper.find_items(filter_expression) diff --git a/mocks/src/apim_mock/handler.py b/mocks/src/apim_mock/handler.py index be774385..bdd6844f 100644 --- a/mocks/src/apim_mock/handler.py +++ b/mocks/src/apim_mock/handler.py @@ -1,6 +1,5 @@ import json import os -import re import secrets import string from datetime import datetime, timedelta, timezone @@ -15,6 +14,7 @@ from aws_lambda_powertools.event_handler.router import APIGatewayHttpRouter from common.logging import get_logger from common.storage_helper import BaseMockItem, StorageHelper +from common.utils import check_valid_uuid4 from requests import HTTPError JWT_ALGORITHMS = ["RS512"] @@ -156,7 +156,7 @@ def _validate_assertions(assertions: dict[str, Any]) -> None: if not jti: raise ValueError("Missing 'jti' claim in client_assertion JWT") - if not _check_valid_uuid4(jti): + if not check_valid_uuid4(jti): raise ValueError("Invalid UUID4 value for jti") if not assertions.get("exp"): @@ -171,13 +171,6 @@ def _validate_assertions(assertions: dict[str, Any]) -> None: ) -def _check_valid_uuid4(string: str) -> bool: - uuid_regex = ( - r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$" - ) - return re.match(uuid_regex, string) is not None - - def _generate_random_token() -> str: return "".join( secrets.choice( diff --git a/mocks/src/common/test_utils.py b/mocks/src/common/test_utils.py new file mode 100644 index 00000000..07b1ec27 --- /dev/null +++ b/mocks/src/common/test_utils.py @@ -0,0 +1,9 @@ +from common.utils import check_valid_uuid4 + + +class TestUtils: + def test_check_valid_uuid_with_valid_uuid(self) -> None: + assert check_valid_uuid4("8c64be5f-3d7a-4b7b-8260-b716d122bdaf") + + def test_check_valid_uuid_with_invalid_uuid(self) -> None: + assert not check_valid_uuid4("invalid-uuid") diff --git a/mocks/src/common/utils.py b/mocks/src/common/utils.py new file mode 100644 index 00000000..b6e054c2 --- /dev/null +++ b/mocks/src/common/utils.py @@ -0,0 +1,8 @@ +import re + + +def check_valid_uuid4(string: str) -> bool: + uuid_regex = ( + r"^[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$" + ) + return re.match(uuid_regex, string) is not None diff --git a/mocks/src/pdm_mock/handler.py b/mocks/src/pdm_mock/handler.py index 20634200..c9355fe2 100644 --- a/mocks/src/pdm_mock/handler.py +++ b/mocks/src/pdm_mock/handler.py @@ -11,6 +11,7 @@ from aws_lambda_powertools.event_handler.router import APIGatewayHttpRouter from common.logging import get_logger from common.storage_helper import BaseMockItem, StorageHelper +from common.utils import check_valid_uuid4 PDM_TABLE_NAME = os.environ["PDM_TABLE_NAME"] BRANCH_NAME = os.environ["DDB_INDEX_TAG"] @@ -93,8 +94,11 @@ def _fetch_patient_from_payload(payload: dict[str, Any]) -> str | None: def handle_post_request(payload: dict[str, Any]) -> PDMResponse: if (patient := _fetch_patient_from_payload(payload)) in REQUEST_HANDLERS: + _logger.debug("Using magic patient id bypass, %s", patient) return REQUEST_HANDLERS[patient]() + _logger.debug("Not using magic patient id bypass") + document_id = str(uuid4()) created_document = { **payload, @@ -115,7 +119,7 @@ def handle_post_request(payload: dict[str, Any]) -> PDMResponse: _write_document_to_table(item) - return {"status_code": 200, "response": created_document} + return {"status_code": 201, "response": created_document} def handle_get_request(document_id: str) -> PDMResponse: @@ -127,10 +131,12 @@ def handle_get_request(document_id: str) -> PDMResponse: def _write_document_to_table(item: DocumentItem) -> None: + _logger.debug("Writing document to dynamodb table") storage_helper.put_item(item) def _get_document_from_table(document_id: str) -> DocumentItem: + _logger.debug("Retrieving document from dynamodb table") item = storage_helper.get_item_by_session_id(document_id) return cast("DocumentItem", item) @@ -151,6 +157,20 @@ def create_document() -> Response[str]: check_authenticated(request_headers) + _logger.debug("Passed Auth Check") + + x_request_id = request_headers.get("X-Request-ID") + if not x_request_id: + _logger.error("Missing X-Request-ID header.") + return _with_default_headers( + _create_operation_outcome(400, "Missing X-Request-ID header", "required") + ) + if not check_valid_uuid4(x_request_id): + _logger.error("Invalid X-Request-ID header. Value provided: %s", x_request_id) + return _with_default_headers( + _create_operation_outcome(400, "Invalid X-Request-ID header", "invalid") + ) + try: payload = pdm_routes.current_event.json_body except json.JSONDecodeError as err: @@ -174,7 +194,15 @@ def create_document() -> Response[str]: _logger.exception("Error handling PDM request") return Response(status_code=500, body=json.dumps({"error": str(err)})) - return _with_default_headers(response) + return Response( + body=json.dumps(response["response"]), + status_code=response["status_code"], + headers={ + "Content-Type": "application/fhir+json", + "x-request-id": x_request_id, + "etag": 'W/"1"', + }, + ) @pdm_routes.get("/pdm/mock/Bundle/") diff --git a/mocks/src/pdm_mock/test_handler.py b/mocks/src/pdm_mock/test_handler.py index c25d3b0f..eb119f0b 100644 --- a/mocks/src/pdm_mock/test_handler.py +++ b/mocks/src/pdm_mock/test_handler.py @@ -158,7 +158,7 @@ def test_handle_post_request( response = handler.handle_post_request(basic_document_payload) assert response == { - "status_code": 200, + "status_code": 201, "response": { "resourceType": "Bundle", "id": "uuid4", @@ -308,6 +308,7 @@ def test_create_document( event = self._create_test_event( path_params="pdm/FHIR/R4/Bundle", request_method="POST", + headers={"X-Request-ID": "8c64be5f-3d7a-4b7b-8260-b716d122bdaf"}, body=json.dumps({"test": "data"}), ) context = LambdaContext() @@ -315,7 +316,50 @@ def test_create_document( with patch("boto3.resource"): response = lambda_app.resolve(event, context) - assert response["statusCode"] == 200 + assert response["statusCode"] == 201 + + @pytest.mark.parametrize( + ("headers", "expected_issue_code", "expected_error_message"), + [ + pytest.param({}, "required", "Missing X-Request-ID header"), + pytest.param( + {"X-Request-ID": "invalid"}, "invalid", "Invalid X-Request-ID header" + ), + ], + ) + @patch("pdm_mock.handler.check_authenticated") + def test_pdm_invalid_or_missing_x_request_id( + self, + check_authenticated_mock: MagicMock, + headers: dict[str, str], + expected_issue_code: str, + expected_error_message: str, + lambda_app: APIGatewayHttpResolver, + ) -> None: + check_authenticated_mock.return_value = True + + event = self._create_test_event( + path_params="pdm/FHIR/R4/Bundle", + request_method="POST", + headers=headers, + body=json.dumps({"test": "data"}), + ) + context = LambdaContext() + response = lambda_app.resolve(event, context) + + assert response["statusCode"] == 400 + assert json.loads(response["body"]) == { + "resourceType": "OperationOutcome", + "issue": [ + { + "severity": "error", + "code": expected_issue_code, + "details": { + "text": expected_error_message, + }, + } + ], + } @patch("pdm_mock.handler.check_authenticated") def test_pdm_mock_failed_authentication( @@ -328,10 +372,11 @@ def test_pdm_mock_failed_authentication( event = self._create_test_event( path_params="pdm/FHIR/R4/Bundle", request_method="POST", + headers={"X-Request-ID": "8c64be5f-3d7a-4b7b-8260-b716d122bdaf"}, body=json.dumps({"test": "data"}), ) context = LambdaContext() - with pytest.raises(AuthenticationError, match=""): + with pytest.raises(AuthenticationError, match=r"^$"): lambda_app.resolve(event, context) @patch("pdm_mock.handler.check_authenticated") @@ -345,6 +390,7 @@ def test_create_document_invalid_body( event = self._create_test_event( path_params="pdm/FHIR/R4/Bundle", request_method="POST", + headers={"X-Request-ID": "8c64be5f-3d7a-4b7b-8260-b716d122bdaf"}, body="Invalid Body", ) context = LambdaContext() @@ -376,6 +422,7 @@ def test_create_document_invalid_payload( event = self._create_test_event( path_params="pdm/FHIR/R4/Bundle", request_method="POST", + headers={"X-Request-ID": "8c64be5f-3d7a-4b7b-8260-b716d122bdaf"}, body="", ) context = LambdaContext() @@ -412,6 +459,7 @@ def test_pdm_mock_create_document_internal_server_error( event = self._create_test_event( path_params="pdm/FHIR/R4/Bundle", request_method="POST", + headers={"X-Request-ID": "8c64be5f-3d7a-4b7b-8260-b716d122bdaf"}, body=json.dumps({"test": "data"}), ) diff --git a/pathology-api/lambda_handler.py b/pathology-api/lambda_handler.py index 3cbb93c5..785e7d00 100644 --- a/pathology-api/lambda_handler.py +++ b/pathology-api/lambda_handler.py @@ -2,6 +2,7 @@ from functools import reduce from json import JSONDecodeError from typing import Any +from uuid import uuid4 import pydantic from aws_lambda_powertools.event_handler import ( @@ -14,7 +15,11 @@ from pathology_api.handler import handle_request from pathology_api.logging import get_logger from pathology_api.mns import MnsException -from pathology_api.request_context import reset_correlation_id, set_correlation_id +from pathology_api.pdm import PdmException +from pathology_api.request_context import ( + reset_correlation_id, + set_correlation_id, +) _logger = get_logger(__name__) _CORRELATION_ID_HEADER = "nhsd-correlation-id" @@ -112,6 +117,15 @@ def handle_exception(exception: Exception) -> Response[str]: ) +@_exception_handler(PdmException) +def handle_pdm_excepton(exception: PdmException) -> Response[str]: + _logger.exception("PDMClientError encountered: %s", exception) + return _with_default_headers( + status_code=500, + body=OperationOutcome.create_validation_error(exception.message), + ) + + @_exception_handler(MnsException) def handle_mns_exception(exception: MnsException) -> Response[str]: _logger.exception("Failed to publish MNS event: %s", exception) @@ -123,6 +137,13 @@ def handle_mns_exception(exception: MnsException) -> Response[str]: @app.get("/_status") def status() -> Response[str]: + pathology_api_correlation_id = str(uuid4()) + + set_correlation_id( + full_id=pathology_api_correlation_id, + short_id=pathology_api_correlation_id, + ) + _logger.debug("Status check endpoint called") return Response( status_code=200, @@ -133,12 +154,20 @@ def status() -> Response[str]: @app.post("/FHIR/R4/Bundle") def post_result() -> Response[str]: - correlation_id = app.current_event.headers.get(_CORRELATION_ID_HEADER) + correlation_id_header = app.current_event.headers.get(_CORRELATION_ID_HEADER) - if not correlation_id: + pathology_api_correlation_id = str(uuid4()) + if not correlation_id_header: + set_correlation_id( + full_id=pathology_api_correlation_id, + short_id=pathology_api_correlation_id, + ) raise ValueError(f"Missing required header: {_CORRELATION_ID_HEADER}") - set_correlation_id(correlation_id) + set_correlation_id( + full_id=f"{correlation_id_header}.{pathology_api_correlation_id}", + short_id=pathology_api_correlation_id, + ) _logger.debug("Post result endpoint called.") try: @@ -155,11 +184,12 @@ def post_result() -> Response[str]: bundle = Bundle.model_validate(payload, by_alias=True) - response = handle_request(bundle) + pdm_response = handle_request(bundle) - return _with_default_headers( + return Response( status_code=200, - body=response, + headers={"Content-Type": "application/fhir+json", "etag": pdm_response.etag}, + body=pdm_response.bundle.model_dump_json(by_alias=True, exclude_none=True), ) diff --git a/pathology-api/src/pathology_api/handler.py b/pathology-api/src/pathology_api/handler.py index b73c782e..1ed50626 100644 --- a/pathology-api/src/pathology_api/handler.py +++ b/pathology-api/src/pathology_api/handler.py @@ -1,9 +1,6 @@ -import uuid - from pathology_api.exception import ValidationError from pathology_api.fhir.r4.elements import ( LiteralReference, - Meta, OrganizationIdentifier, ReferenceExtension, ) @@ -16,6 +13,7 @@ ) from pathology_api.logging import get_logger from pathology_api.mns import create_event +from pathology_api.pdm import PdmResponse, post_document _logger = get_logger(__name__) @@ -125,7 +123,7 @@ def _fetch_requesting_organisation( return organisation_identifiers[0] -def handle_request(bundle: Bundle) -> Bundle: +def handle_request(bundle: Bundle) -> PdmResponse: _logger.debug("Bundle entries: %s", bundle.entries) _validate_bundle(bundle) @@ -146,22 +144,16 @@ def handle_request(bundle: Bundle) -> Bundle: if subject is None: raise ValidationError("Composition does not define a valid subject identifier") - return_bundle = Bundle.create( - id=str(uuid.uuid4()), - meta=Meta.with_last_updated(), - identifier=bundle.identifier, - type=bundle.bundle_type, - entry=bundle.entries, - ) - _logger.debug("Return bundle: %s", return_bundle) + pdm_response = post_document(bundle) + _logger.debug("Return bundle: %s", pdm_response.bundle) - if return_bundle.id is None: + if pdm_response.bundle.id is None: raise ValueError("Bundle returned from PDM does not include an ID.") create_event( requesting_org=requesting_organisation.value, nhs_number=subject.identifier.value, - bundle_id=return_bundle.id, + bundle_id=pdm_response.bundle.id, ) - return return_bundle + return pdm_response diff --git a/pathology-api/src/pathology_api/http.py b/pathology-api/src/pathology_api/http.py index 5c8ec272..f7f0022e 100644 --- a/pathology-api/src/pathology_api/http.py +++ b/pathology-api/src/pathology_api/http.py @@ -43,7 +43,7 @@ def send( ) -> requests.Response: kwargs["timeout"] = self._timeout if "X-Correlation-ID" not in request.headers: - request.headers["X-Correlation-ID"] = get_correlation_id() + request.headers["X-Correlation-ID"] = get_correlation_id().short_id _logger.info( "Sending HTTP request. method=%s url=%s headers=%s", diff --git a/pathology-api/src/pathology_api/logging.py b/pathology-api/src/pathology_api/logging.py index fc59087e..1d1068c4 100644 --- a/pathology-api/src/pathology_api/logging.py +++ b/pathology-api/src/pathology_api/logging.py @@ -10,7 +10,7 @@ class _CorrelationIdFilter(logging.Filter): """Injects the current correlation ID into every log record.""" def filter(self, record: logging.LogRecord) -> bool: - record.correlation_id = get_correlation_id() + record.correlation_id = get_correlation_id().full_id return True diff --git a/pathology-api/src/pathology_api/pdm.py b/pathology-api/src/pathology_api/pdm.py new file mode 100644 index 00000000..433b1804 --- /dev/null +++ b/pathology-api/src/pathology_api/pdm.py @@ -0,0 +1,62 @@ +from typing import Any, NamedTuple +from uuid import uuid4 + +import requests + +from pathology_api import environment +from pathology_api.fhir.r4.resources import Bundle +from pathology_api.logging import get_logger + +_logger = get_logger(__name__) + + +class PdmException(Exception): + """ + Custom exception for validation errors in the PDM Client. + Note that any message here will be provided in the error response returned to users. + """ + + def __init__(self, message: str): + self.message = message + super().__init__(message) + + +class PdmResponse(NamedTuple): + bundle: Bundle + etag: str + + +@environment.apim_authenticator().auth +def _make_post_request(session: requests.Session, document: Bundle) -> Any: + response = session.post( + url=environment.values()["pdm_url"], + data=document.model_dump_json(by_alias=True, exclude_none=True), + headers={"Content-Type": "application/fhir+json", "X-Request-ID": str(uuid4())}, + ) + + return response + + +def post_document(document: Bundle) -> PdmResponse: + + response = _make_post_request(document) + + _logger.debug( + "Result of post request. status_code=%s data=%s", + response.status_code, + response.text, + ) + + if response.status_code == 201: + returned_document = response.json() + etag = response.headers.get("etag") + pdm_response = PdmResponse( + Bundle.model_validate(returned_document, by_alias=True), etag + ) + return pdm_response + elif response.status_code == 401: + raise PdmException("An unexpected internal server error has occured") + # all other responses including 5xx and 4xx return same format for now + else: + pdm_error = response.text + raise PdmException(f"Failed to send document: {pdm_error}") diff --git a/pathology-api/src/pathology_api/request_context.py b/pathology-api/src/pathology_api/request_context.py index 365d481a..8f2ef586 100644 --- a/pathology-api/src/pathology_api/request_context.py +++ b/pathology-api/src/pathology_api/request_context.py @@ -1,18 +1,29 @@ from contextvars import ContextVar +from typing import NamedTuple -_correlation_id: ContextVar[str] = ContextVar("correlation_id", default="") +class CorrelationID(NamedTuple): + full_id: str + short_id: str -def set_correlation_id(value: str) -> None: + +_correlation_id: ContextVar[CorrelationID | None] = ContextVar( + "correlation_id", default=None +) + + +def set_correlation_id(full_id: str, short_id: str) -> None: """Set the correlation ID for the current request context.""" - _correlation_id.set(value) + _correlation_id.set(CorrelationID(full_id=full_id, short_id=short_id)) def reset_correlation_id() -> None: """Reset the correlation ID to the default empty string.""" - _correlation_id.set("") + _correlation_id.set(None) -def get_correlation_id() -> str: +def get_correlation_id() -> CorrelationID: """Get the correlation ID for the current request context.""" - return _correlation_id.get() + if (correlation_id := _correlation_id.get()) is None: + raise ValueError("Correlation ID is not set in the current context.") + return correlation_id diff --git a/pathology-api/src/pathology_api/test_apim.py b/pathology-api/src/pathology_api/test_apim.py index bd3c663b..99bb8d0d 100644 --- a/pathology-api/src/pathology_api/test_apim.py +++ b/pathology-api/src/pathology_api/test_apim.py @@ -1,4 +1,4 @@ -from collections.abc import Callable +from collections.abc import Callable, Generator from datetime import datetime, timedelta, timezone from typing import Any from unittest.mock import MagicMock, Mock, patch @@ -8,9 +8,19 @@ from jwt import InvalidKeyError from pathology_api.apim import ApimAuthenticationException, ApimAuthenticator +from pathology_api.request_context import reset_correlation_id, set_correlation_id class TestApimAuthenticator: + @pytest.fixture(autouse=True) + def set_correlation_id_for_logger(self) -> Generator[None, None, None]: + set_correlation_id( + full_id="test_id_long", + short_id="test_id", + ) + yield + reset_correlation_id() + def setup_method(self) -> None: self.mock_session = Mock() diff --git a/pathology-api/src/pathology_api/test_handler.py b/pathology-api/src/pathology_api/test_handler.py index 53337026..039c77a4 100644 --- a/pathology-api/src/pathology_api/test_handler.py +++ b/pathology-api/src/pathology_api/test_handler.py @@ -1,5 +1,5 @@ import datetime -from collections.abc import Callable +from collections.abc import Callable, Generator from typing import Any from unittest.mock import patch @@ -11,6 +11,7 @@ Extension, LiteralReference, LogicalReference, + Meta, OrganizationIdentifier, PatientIdentifier, ReferenceExtension, @@ -22,14 +23,17 @@ PractitionerRole, ServiceRequest, ) +from pathology_api.request_context import reset_correlation_id, set_correlation_id from pathology_api.test_utils import BundleBuilder with ( patch("pathology_api.environment.apim_authenticator"), patch("pathology_api.mns.create_event") as create_event_mock, + patch("pathology_api.pdm.post_document") as post_document_mock, ): from pathology_api.handler import handle_request from pathology_api.mns import MnsException + from pathology_api.pdm import PdmException, PdmResponse def _missing_resource_scenarios() -> list[Any]: @@ -212,6 +216,15 @@ def _invalid_organization_scenarios() -> list[Any]: class TestHandleRequest: + @pytest.fixture(autouse=True) + def set_correlation_id_for_logger(self) -> Generator[None, None, None]: + set_correlation_id( + full_id="test_id_long", + short_id="test_id", + ) + yield + reset_correlation_id() + def _build_valid_test_result(self) -> Bundle: organisation_entry = Bundle.Entry( fullUrl="organisation", @@ -263,10 +276,21 @@ def test_handle_request( ) -> None: # Arrange bundle = build_valid_test_result("nhs_number_1", "ods_code") + expected_bundle = Bundle.create( + id="generated_id", + type="document", + meta=Meta( + last_updated=datetime.datetime.now(tz=datetime.timezone.utc), + version_id="1", + ), + entry=bundle.entries, + ) + expected_etag = "generated_etag" + + post_document_mock.return_value = PdmResponse(expected_bundle, expected_etag) - before_call = datetime.datetime.now(tz=datetime.timezone.utc) - result_bundle = handle_request(bundle) - after_call = datetime.datetime.now(tz=datetime.timezone.utc) + pdm_response = handle_request(bundle) + result_bundle = pdm_response.bundle assert result_bundle is not None @@ -280,10 +304,11 @@ def test_handle_request( created_meta = result_bundle.meta assert created_meta.last_updated is not None - assert before_call <= created_meta.last_updated - assert created_meta.last_updated <= after_call + assert created_meta.version_id == "1" - assert created_meta.version_id is None + assert pdm_response.etag == "generated_etag" + + post_document_mock.assert_called_with(bundle) create_event_mock.assert_called_once_with( requesting_org="ods_code", @@ -291,7 +316,7 @@ def test_handle_request( bundle_id=result_bundle.id, ) - def test_handle_request_raises_error_when_create_bundle_fails( + def test_handle_request_raises_error_when_create_event_fails( self, build_valid_test_result: Callable[[str, str], Bundle] ) -> None: # Arrange @@ -303,6 +328,19 @@ def test_handle_request_raises_error_when_create_bundle_fails( with pytest.raises(MnsException, match=expected_error_message): handle_request(bundle) + def test_handle_request_raises_error_when_post_request_fails( + self, + build_valid_test_result: Callable[[str, str], Bundle], + ) -> None: + # Arrange + bundle = build_valid_test_result("nhs_number_1", "ods_code") + + expected_error_message = "An unexpected internal server error has occured" + post_document_mock.side_effect = PdmException(expected_error_message) + + with pytest.raises(PdmException, match=expected_error_message): + handle_request(bundle) + @pytest.mark.parametrize( ("bundle", "expected_error_message"), _missing_resource_scenarios() ) diff --git a/pathology-api/src/pathology_api/test_http.py b/pathology-api/src/pathology_api/test_http.py index 319a7883..f34b5ae2 100644 --- a/pathology-api/src/pathology_api/test_http.py +++ b/pathology-api/src/pathology_api/test_http.py @@ -148,8 +148,9 @@ def test_adapter_applies_defaults(self) -> None: with patch.object( requests.adapters.HTTPAdapter, "send", autospec=True ) as mock_send: - expected_correlation_id = "correaltion-id" - set_correlation_id(expected_correlation_id) + set_correlation_id( + full_id="test-correlation-id-long", short_id="test-correlation-id" + ) expected_timeout = timedelta(seconds=30) adapter = SessionManager._Adapter(timeout=expected_timeout.total_seconds()) # noqa: SLF001 - Private access for testing @@ -170,7 +171,7 @@ def test_adapter_applies_defaults(self) -> None: timeout=expected_timeout.total_seconds(), ) - assert mock_request.headers["X-Correlation-ID"] == expected_correlation_id + assert mock_request.headers["X-Correlation-ID"] == "test-correlation-id" def test_adapter_overrides_defaults(self) -> None: with patch.object( @@ -209,7 +210,9 @@ def test_adapter_request_error(self) -> None: mock_request = Mock() mock_request.headers = {} - set_correlation_id("test-correlation-id") + set_correlation_id( + full_id="test-correlation-id-long", short_id="test-correlation-id" + ) with pytest.raises(requests.RequestException, match="request failed"): adapter.send(mock_request) diff --git a/pathology-api/src/pathology_api/test_logging.py b/pathology-api/src/pathology_api/test_logging.py index b2c577ee..09b1199e 100644 --- a/pathology-api/src/pathology_api/test_logging.py +++ b/pathology-api/src/pathology_api/test_logging.py @@ -5,7 +5,11 @@ from aws_lambda_powertools import Logger from pathology_api.logging import LogProvider, _CorrelationIdFilter, get_logger -from pathology_api.request_context import reset_correlation_id, set_correlation_id +from pathology_api.request_context import ( + CorrelationID, + reset_correlation_id, + set_correlation_id, +) def _make_log_record() -> logging.LogRecord: @@ -28,65 +32,79 @@ def setup_method(self) -> None: def test_filter_is_a_logging_filter_subclass(self) -> None: assert issubclass(_CorrelationIdFilter, logging.Filter) - def test_filter_always_returns_true(self) -> None: + def test_filter_always_returns_true( + self, + ) -> None: + set_correlation_id(full_id="abc-123-long", short_id="abc-123") + f = _CorrelationIdFilter() record = _make_log_record() assert f.filter(record) is True - def test_filter_injects_empty_string_when_no_correlation_id_set(self) -> None: + def test_filter_raises_exception_when_no_correlation_id_set(self) -> None: f = _CorrelationIdFilter() record = _make_log_record() - f.filter(record) - assert record.correlation_id == "" # type: ignore[attr-defined] + with pytest.raises( + ValueError, match="Correlation ID is not set in the current context." + ): + f.filter(record) def test_filter_injects_active_correlation_id(self) -> None: f = _CorrelationIdFilter() record = _make_log_record() - set_correlation_id("abc-123") + set_correlation_id(full_id="abc-123-long", short_id="abc-123") f.filter(record) - assert record.correlation_id == "abc-123" # type: ignore[attr-defined] + assert record.correlation_id == "abc-123-long" # type: ignore[attr-defined] def test_filter_injects_empty_string_after_correlation_id_reset( self, ) -> None: f = _CorrelationIdFilter() - set_correlation_id("to-be-cleared") + set_correlation_id(full_id="to-be-cleared-long", short_id="to-be-cleared") record_during = _make_log_record() f.filter(record_during) - assert record_during.correlation_id == "to-be-cleared" # type: ignore[attr-defined] + assert record_during.correlation_id == "to-be-cleared-long" # type: ignore[attr-defined] reset_correlation_id() record_after = _make_log_record() - f.filter(record_after) - assert record_after.correlation_id == "" # type: ignore[attr-defined] + with pytest.raises( + ValueError, match="Correlation ID is not set in the current context." + ): + f.filter(record_after) - def test_filter_uses_get_correlation_id(self) -> None: + def test_filter_uses_full_correlation_id(self) -> None: f = _CorrelationIdFilter() record = _make_log_record() + correlation_id = CorrelationID(full_id="mocked-id-long", short_id="mocked-id") with patch( - "pathology_api.logging.get_correlation_id", return_value="mocked-id" + "pathology_api.logging.get_correlation_id", + return_value=correlation_id, ) as mock_fn: f.filter(record) mock_fn.assert_called_once() - assert record.correlation_id == "mocked-id" # type: ignore[attr-defined] + assert record.correlation_id == "mocked-id-long" # type: ignore[attr-defined] def test_filter_overwrites_existing_correlation_id_attribute(self) -> None: f = _CorrelationIdFilter() record = _make_log_record() record.correlation_id = "old-id" - set_correlation_id("new-id") + set_correlation_id(full_id="new-id-long", short_id="new-id") f.filter(record) reset_correlation_id() - assert record.correlation_id == "new-id" # type: ignore[attr-defined] + assert record.correlation_id == "new-id-long" # type: ignore[attr-defined] def test_filter_handles_different_correlation_id_values(self) -> None: f = _CorrelationIdFilter() - values = ["uuid-1234-5678", "X-Corr-99", "a" * 100] + values: list[dict[str, str]] = [ + {"full_id": "uuid-1234-5678-long", "short_id": "uuid-1234"}, + {"full_id": "X-Corr-99-long", "short_id": "X-Corr-99"}, + {"full_id": "a" * 110, "short_id": "a" * 100}, + ] for value in values: record = _make_log_record() - set_correlation_id(value) + set_correlation_id(**value) f.filter(record) reset_correlation_id() - assert record.correlation_id == value # type: ignore[attr-defined] + assert record.correlation_id == value["full_id"] # type: ignore[attr-defined] class TestLogProvider: @@ -135,14 +153,17 @@ def test_correlation_id_filter_is_applied_to_log_records(self) -> None: get_logger("service-filter-applied") stdlib_logger = logging.getLogger("service-filter-applied") record = _make_log_record() - set_correlation_id("applied-id") + set_correlation_id(full_id="applied-id-long", short_id="applied-id") stdlib_logger.filter(record) reset_correlation_id() - assert record.correlation_id == "applied-id" # type: ignore[attr-defined] + assert record.correlation_id == "applied-id-long" # type: ignore[attr-defined] def test_correlation_id_filter_injects_empty_string_by_default(self) -> None: get_logger("service-filter-empty") stdlib_logger = logging.getLogger("service-filter-empty") record = _make_log_record() - stdlib_logger.filter(record) - assert record.correlation_id == "" # type: ignore[attr-defined] + + with pytest.raises( + ValueError, match="Correlation ID is not set in the current context." + ): + stdlib_logger.filter(record) diff --git a/pathology-api/src/pathology_api/test_mns.py b/pathology-api/src/pathology_api/test_mns.py index 230b8a49..91edcb7c 100644 --- a/pathology-api/src/pathology_api/test_mns.py +++ b/pathology-api/src/pathology_api/test_mns.py @@ -1,6 +1,6 @@ import importlib import uuid -from collections.abc import Callable +from collections.abc import Callable, Generator from datetime import datetime, timezone from json import JSONDecodeError from typing import Any @@ -9,6 +9,8 @@ import pytest import requests +from pathology_api.request_context import reset_correlation_id, set_correlation_id + mock_session = Mock() @@ -37,6 +39,15 @@ class TestMns: def setup_method(self) -> None: mock_session.reset_mock(return_value=True, side_effect=True) + @pytest.fixture(autouse=True) + def set_correlation_id_for_logger(self) -> Generator[None, None, None]: + set_correlation_id( + full_id="test_id_long", + short_id="test_id", + ) + yield + reset_correlation_id() + @patch("pathology_api.environment.values") @patch("pathology_api.mns.datetime") @patch("pathology_api.mns.uuid") diff --git a/pathology-api/src/pathology_api/test_pdm.py b/pathology-api/src/pathology_api/test_pdm.py new file mode 100644 index 00000000..9ea8036f --- /dev/null +++ b/pathology-api/src/pathology_api/test_pdm.py @@ -0,0 +1,200 @@ +import datetime +import importlib +from collections.abc import Callable, Generator +from typing import Any +from unittest.mock import Mock, patch + +import pytest + +from pathology_api.fhir.r4.elements import LogicalReference, PatientIdentifier +from pathology_api.fhir.r4.resources import Bundle, Composition +from pathology_api.request_context import reset_correlation_id, set_correlation_id + +mock_session = Mock() + + +def _mock_auth() -> Callable[..., Any]: + def _auth_decorator(func: Callable[..., Any]) -> Callable[..., Any]: + def wrapper(*args: Any, **kwargs: Any) -> Any: + return func(mock_session, *args, **kwargs) + + return wrapper + + return _auth_decorator + + +with patch("pathology_api.environment.apim_authenticator") as apim_authenticator_mock: + import pathology_api.pdm + + apim_authenticator_mock.return_value.auth = _mock_auth() + + # Reload the module to ensure the patched authenticator is used in case it has + # already been imported + importlib.reload(pathology_api.pdm) + from pathology_api.pdm import PdmException, post_document + + +@pytest.fixture +def default_returned_bundle() -> dict[str, Any]: + return { + "type": "document", + "resourceType": "Bundle", + "id": "8aa429d6-6281-481c-ae31-df043636245e", + "meta": { + "last_updated": datetime.datetime.now(tz=datetime.timezone.utc), + "version_id": "1", + }, + "entry": [ + { + "fullUrl": "patient", + "resource": { + "id": None, + "meta": None, + "resourceType": "Composition", + "subject": { + "identifier": { + "system": "https://fhir.nhs.uk/Id/nhs-number", + "value": "nhs_number", + } + }, + }, + } + ], + } + + +class TestPDMClient: + def setup_method(self) -> None: + mock_session.reset_mock(return_value=True, side_effect=True) + + @pytest.fixture(autouse=True) + def set_correlation_id_for_logger(self) -> Generator[None, None, None]: + set_correlation_id( + full_id="test_id_long", + short_id="test_id", + ) + yield + reset_correlation_id() + + @patch("pathology_api.pdm.uuid4") + def test_post_document_success( + self, mock_uuid: Mock, default_returned_bundle: dict[str, Any] + ) -> None: + + ## Arrange + mock_session.post.return_value.status_code = 201 + mock_session.post.return_value.json.return_value = default_returned_bundle + mock_session.post.return_value.headers = {"etag": 'W/"1"'} + mock_uuid.return_value = "x_request_id" + + bundle = Bundle.create( + type="document", + entry=[ + Bundle.Entry( + fullUrl="patient", + resource=Composition.create( + subject=LogicalReference( + PatientIdentifier.from_nhs_number("nhs_number") + ) + ), + ) + ], + ) + + response = post_document(bundle) + + result_bundle = response.bundle + assert result_bundle is not None + assert type(result_bundle) is Bundle + assert result_bundle.id is not None + + assert result_bundle.bundle_type == bundle.bundle_type + assert result_bundle.entries == bundle.entries + assert result_bundle.meta is not None + + result_meta = result_bundle.meta + + assert result_meta.last_updated is not None + assert result_meta.version_id == "1" + + assert response.etag == 'W/"1"' + + mock_session.post.assert_called_once_with( + url="pdm_url", + data=bundle.model_dump_json(by_alias=True, exclude_none=True), + headers={ + "Content-Type": "application/fhir+json", + "X-Request-ID": "x_request_id", + }, + ) + + def test_post_document_401(self) -> None: + + mock_session.post.return_value.status_code = 401 + mock_session.post.return_value.text = "error message" + + bundle = Bundle.create( + type="document", + entry=[ + Bundle.Entry( + fullUrl="patient", + resource=Composition.create( + subject=LogicalReference( + PatientIdentifier.from_nhs_number("nhs_number") + ) + ), + ) + ], + ) + + with pytest.raises( + PdmException, match="An unexpected internal server error has occured" + ): + post_document(bundle) + + def test_post_document_4xx(self) -> None: + + mock_session.post.return_value.status_code = 400 + mock_session.post.return_value.text = "error message" + + bundle = Bundle.create( + type="document", + entry=[ + Bundle.Entry( + fullUrl="patient", + resource=Composition.create( + subject=LogicalReference( + PatientIdentifier.from_nhs_number("nhs_number") + ) + ), + ) + ], + ) + + with pytest.raises( + PdmException, match="Failed to send document: error message" + ): + post_document(bundle) + + def test_post_document_5xx(self) -> None: + mock_session.post.return_value.status_code = 500 + mock_session.post.return_value.text = "error message" + + bundle = Bundle.create( + type="document", + entry=[ + Bundle.Entry( + fullUrl="patient", + resource=Composition.create( + subject=LogicalReference( + PatientIdentifier.from_nhs_number("nhs_number") + ) + ), + ) + ], + ) + + with pytest.raises( + PdmException, match="Failed to send document: error message" + ): + post_document(bundle) diff --git a/pathology-api/src/pathology_api/test_request_context.py b/pathology-api/src/pathology_api/test_request_context.py index 48f5256a..ebd98631 100644 --- a/pathology-api/src/pathology_api/test_request_context.py +++ b/pathology-api/src/pathology_api/test_request_context.py @@ -1,6 +1,9 @@ import threading +import pytest + from pathology_api.request_context import ( + CorrelationID, get_correlation_id, reset_correlation_id, set_correlation_id, @@ -9,36 +12,55 @@ class TestSetAndGetCorrelationId: def test_correlation_id_is_set_and_retrieved(self) -> None: - set_correlation_id("round-trip-test-123") - assert get_correlation_id() == "round-trip-test-123" + set_correlation_id( + full_id="round-trip-test-123-long", + short_id="round-trip-test-123", + ) + correlation_id = get_correlation_id() + assert correlation_id.full_id == "round-trip-test-123-long" + assert correlation_id.short_id == "round-trip-test-123" reset_correlation_id() def test_correlation_id_is_cleared_after_reset(self) -> None: - set_correlation_id("round-trip-test-123") + set_correlation_id( + full_id="round-trip-test-123-long", + short_id="round-trip-test-123", + ) reset_correlation_id() - assert get_correlation_id() == "" + with pytest.raises( + ValueError, match="Correlation ID is not set in the current context." + ): + get_correlation_id() - def test_default_correlation_id_is_empty_string(self) -> None: - assert get_correlation_id() == "" + def test_default_correlation_throws_error(self) -> None: + with pytest.raises( + ValueError, match="Correlation ID is not set in the current context." + ): + get_correlation_id() def test_correlation_id_is_cleared_when_reset_called_after_exception( self, ) -> None: try: - set_correlation_id("will-be-cleared") + set_correlation_id( + full_id="will-be-cleared-long", short_id="will-be-cleared" + ) raise RuntimeError("simulated mid-handler failure") except RuntimeError: pass finally: reset_correlation_id() - assert get_correlation_id() == "" + with pytest.raises( + ValueError, match="Correlation ID is not set in the current context." + ): + get_correlation_id() def test_correlation_id_does_not_bleed_between_threads(self) -> None: - results: dict[str, str] = {} + results: dict[str, CorrelationID] = {} def thread_a() -> None: - set_correlation_id("thread-a-id") + set_correlation_id(full_id="thread-a-id-long", short_id="thread-a-id") import time time.sleep(0.05) @@ -46,8 +68,9 @@ def thread_a() -> None: reset_correlation_id() def thread_b() -> None: - set_correlation_id("thread-b-id") + set_correlation_id(full_id="thread-b-id-long", short_id="thread-b-id") results["b"] = get_correlation_id() + reset_correlation_id() t_a = threading.Thread(target=thread_a) @@ -57,5 +80,7 @@ def thread_b() -> None: t_a.join() t_b.join() - assert results["a"] == "thread-a-id" - assert results["b"] == "thread-b-id" + assert results["a"].full_id == "thread-a-id-long" + assert results["a"].short_id == "thread-a-id" + assert results["b"].full_id == "thread-b-id-long" + assert results["b"].short_id == "thread-b-id" diff --git a/pathology-api/test_lambda_handler.py b/pathology-api/test_lambda_handler.py index fe8608be..4e5bca5c 100644 --- a/pathology-api/test_lambda_handler.py +++ b/pathology-api/test_lambda_handler.py @@ -12,6 +12,7 @@ ): from lambda_handler import handler from pathology_api.mns import MnsException + from pathology_api.pdm import PdmException, PdmResponse from pathology_api.exception import ValidationError from pathology_api.fhir.r4.elements import Meta @@ -77,13 +78,14 @@ def test_create_test_result_success( post_event: dict[str, Any], context: LambdaContext, ) -> None: - expected_response = Bundle.create( + expected_bundle = Bundle.create( id="test-id", type="document", meta=Meta.with_last_updated(), entry=bundle.entries, ) - handle_request_mock.return_value = expected_response + expected_etag = 'W/"1"' + handle_request_mock.return_value = PdmResponse(expected_bundle, expected_etag) response = handler(post_event, context) @@ -94,11 +96,18 @@ def test_create_test_result_success( assert isinstance(response_body, str) response_bundle = Bundle.model_validate_json(response_body, by_alias=True) - assert response_bundle == expected_response + assert response_bundle == expected_bundle + @patch("lambda_handler.uuid4") def test_correlation_id_is_set_on_all_log_records_during_request( - self, caplog: pytest.LogCaptureFixture, bundle: Bundle, context: LambdaContext + self, + uuid_mock: MagicMock, + caplog: pytest.LogCaptureFixture, + bundle: Bundle, + context: LambdaContext, ) -> None: + uuid_mock.return_value = "test_uuid" + event = self._create_test_event( body=bundle.model_dump_json(by_alias=True), path_params="FHIR/R4/Bundle", @@ -116,13 +125,19 @@ def test_correlation_id_is_set_on_all_log_records_during_request( for record in caplog.records: assert ( getattr(record, "correlation_id", None) - == "b876145d-1ebf-4e22-8ff8-275b570c1123" + == "b876145d-1ebf-4e22-8ff8-275b570c1123.test_uuid" ) + @patch("lambda_handler.uuid4") def test_correlation_id_is_cleared_after_request( - self, caplog: pytest.LogCaptureFixture, bundle: Bundle, context: LambdaContext + self, + uuid_mock: MagicMock, + caplog: pytest.LogCaptureFixture, + bundle: Bundle, + context: LambdaContext, ) -> None: # First request sets a correlation ID + uuid_mock.return_value = "test_uuid" event = self._create_test_event( body=bundle.model_dump_json(by_alias=True), path_params="FHIR/R4/Bundle", @@ -134,7 +149,15 @@ def test_correlation_id_is_cleared_after_request( caplog.at_level(logging.DEBUG), ): handler(event, context) + + for record in caplog.records: + assert ( + getattr(record, "correlation_id", None) + == "c876145d-1ebf-4e22-8ff8-275b570c1ec4.test_uuid" + ) + caplog.clear() + uuid_mock.return_value = "different_uuid" # Second request with a different correlation ID — no bleed-through event2 = self._create_test_event( @@ -152,7 +175,7 @@ def test_correlation_id_is_cleared_after_request( for record in caplog.records: assert ( getattr(record, "correlation_id", None) - == "d876145d-1ebf-4e22-8ff8-275b570c1ec4" + == "d876145d-1ebf-4e22-8ff8-275b570c1ec4.different_uuid" ) def test_missing_correlation_id_header_returns_500( @@ -209,8 +232,10 @@ def test_correlation_id_is_cleared_after_exception_mid_handler( ) handler(event, context) - - assert get_correlation_id() == "" + with pytest.raises( + ValueError, match="Correlation ID is not set in the current context." + ): + get_correlation_id() def test_create_test_result_no_payload(self, context: LambdaContext) -> None: event = self._create_test_event( @@ -307,6 +332,16 @@ def test_create_test_result_invalid_json(self, context: LambdaContext) -> None: 500, id="MnsException", ), + pytest.param( + PdmException("Test PDM error"), + { + "severity": "error", + "code": "invalid", + "diagnostics": "Test PDM error", + }, + 500, + id="PdmException", + ), ], ) @patch("lambda_handler.handle_request") diff --git a/pathology-api/tests/integration/test_endpoints.py b/pathology-api/tests/integration/test_endpoints.py index c11e4a2e..08b76a76 100644 --- a/pathology-api/tests/integration/test_endpoints.py +++ b/pathology-api/tests/integration/test_endpoints.py @@ -46,7 +46,9 @@ def test_bundle_returns_200( assert response_bundle.meta is not None response_meta = response_bundle.meta assert response_meta.last_updated is not None - assert response_meta.version_id is None + assert response_meta.version_id == "1" + + assert response.headers["etag"] == 'W/"1"' def test_no_payload_returns_error(self, client: Client) -> None: response = client.send_without_payload(