From 9f0296f903752a7b9c803e8ac4802693d7ccc109 Mon Sep 17 00:00:00 2001 From: Morgan Wowk Date: Fri, 24 Apr 2026 12:07:23 -0700 Subject: [PATCH] fix: Tests - Update middleware tests to use Starlette 1.0.0 routing API --- ...test_instrumentation_request_middleware.py | 56 +++++++------------ tests/test_request_id_concurrency.py | 20 +++---- 2 files changed, 26 insertions(+), 50 deletions(-) diff --git a/tests/test_instrumentation_request_middleware.py b/tests/test_instrumentation_request_middleware.py index 70dfeb3..1b44db1 100644 --- a/tests/test_instrumentation_request_middleware.py +++ b/tests/test_instrumentation_request_middleware.py @@ -5,6 +5,7 @@ from starlette.requests import Request from starlette.responses import Response from starlette.applications import Starlette +from starlette.routing import Route from starlette.testclient import TestClient from cloud_pipelines_backend.instrumentation import contextual_logging @@ -78,12 +79,8 @@ def teardown_method(self): def test_middleware_generates_request_id(self): """Test that middleware generates a request_id for each request.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - request_ids_seen = [] - @app.route("/test") def test_route(request): # Capture the request_id during request processing request_ids_seen.append( @@ -91,6 +88,8 @@ def test_route(request): ) return Response("ok") + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) response = client.get("/test") @@ -101,13 +100,11 @@ def test_route(request): def test_middleware_adds_request_id_to_response_headers(self): """Test that middleware adds request_id to response headers.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - - @app.route("/test") def test_route(request): return Response("ok") + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) response = client.get("/test") @@ -118,14 +115,12 @@ def test_route(request): def test_middleware_clears_request_id_after_request(self): """Test that middleware clears request_id after request completes.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - - @app.route("/test") def test_route(request): assert contextual_logging.get_context_metadata("request_id") is not None return Response("ok") + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) # Before request @@ -140,13 +135,11 @@ def test_route(request): def test_middleware_generates_unique_request_ids(self): """Test that middleware generates unique request_ids for each request.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - - @app.route("/test") def test_route(request): return Response("ok") + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) # Make multiple requests @@ -160,17 +153,15 @@ def test_route(request): def test_middleware_request_id_available_in_route(self): """Test that request_id set by middleware is available in route handler.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - captured_request_id = None - @app.route("/test") def test_route(request): nonlocal captured_request_id captured_request_id = contextual_logging.get_context_metadata("request_id") return Response(f"request_id: {captured_request_id}") + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) response = client.get("/test") @@ -180,10 +171,6 @@ def test_route(request): def test_middleware_handles_exception_in_route(self): """Test that middleware clears request_id even when route raises exception.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - - @app.route("/test") def test_route(request): request_id_during_exception = contextual_logging.get_context_metadata( "request_id" @@ -191,6 +178,8 @@ def test_route(request): assert request_id_during_exception is not None raise ValueError("Test exception") + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app, raise_server_exceptions=False) response = client.get("/test") @@ -200,25 +189,22 @@ def test_route(request): def test_middleware_with_multiple_routes(self): """Test middleware works correctly with multiple routes.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - request_ids_by_route = {} - @app.route("/route1") def route1(request): request_ids_by_route["route1"] = contextual_logging.get_context_metadata( "request_id" ) return Response("route1") - @app.route("/route2") def route2(request): request_ids_by_route["route2"] = contextual_logging.get_context_metadata( "request_id" ) return Response("route2") + app = Starlette(routes=[Route("/route1", route1), Route("/route2", route2)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) response1 = client.get("/route1") @@ -255,9 +241,6 @@ def test_middleware_enables_request_id_in_logs(self): """Test that middleware enables request_id to be used in logging.""" import logging - app = Starlette() - app.add_middleware(RequestContextMiddleware) - logged_request_ids = [] # Create a custom handler to capture log records @@ -275,11 +258,12 @@ def emit(self, record): logger.addHandler(handler) logger.setLevel(logging.INFO) - @app.route("/test") def test_route(request): logger.info("Processing request") return Response("ok") + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) response = client.get("/test") @@ -292,9 +276,6 @@ def test_route(request): def test_middleware_request_id_persists_across_function_calls(self): """Test that request_id persists across function calls within a request.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - request_ids_collected = [] def helper_function(): @@ -303,7 +284,6 @@ def helper_function(): contextual_logging.get_context_metadata("request_id") ) - @app.route("/test") def test_route(request): request_ids_collected.append( contextual_logging.get_context_metadata("request_id") @@ -314,6 +294,8 @@ def test_route(request): ) return Response("ok") + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) response = client.get("/test") diff --git a/tests/test_request_id_concurrency.py b/tests/test_request_id_concurrency.py index 97dce6c..c301816 100644 --- a/tests/test_request_id_concurrency.py +++ b/tests/test_request_id_concurrency.py @@ -4,6 +4,7 @@ import pytest from starlette.applications import Starlette from starlette.responses import JSONResponse +from starlette.routing import Route from starlette.testclient import TestClient from cloud_pipelines_backend.instrumentation import contextual_logging @@ -12,16 +13,12 @@ def test_request_id_isolation_with_concurrent_requests(): """Test that each concurrent request gets its own isolated request_id.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - # Store request_ids seen by each endpoint request_ids_seen = { "endpoint1": [], "endpoint2": [], } - @app.route("/endpoint1") async def endpoint1(request): request_id = contextual_logging.get_context_metadata("request_id") request_ids_seen["endpoint1"].append(request_id) @@ -31,7 +28,6 @@ async def endpoint1(request): assert contextual_logging.get_context_metadata("request_id") == request_id return JSONResponse({"request_id": request_id}) - @app.route("/endpoint2") async def endpoint2(request): request_id = contextual_logging.get_context_metadata("request_id") request_ids_seen["endpoint2"].append(request_id) @@ -41,6 +37,8 @@ async def endpoint2(request): assert contextual_logging.get_context_metadata("request_id") == request_id return JSONResponse({"request_id": request_id}) + app = Starlette(routes=[Route("/endpoint1", endpoint1), Route("/endpoint2", endpoint2)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) # Make concurrent requests @@ -71,9 +69,6 @@ async def endpoint2(request): def test_request_id_isolation_with_nested_async_calls(): """Test that request_id persists correctly through nested async function calls.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - request_ids_collected = [] async def helper_function_1(): @@ -97,7 +92,6 @@ async def helper_function_2(): ("helper2_after", contextual_logging.get_context_metadata("request_id")) ) - @app.route("/test") async def test_route(request): request_ids_collected.append( ("start", contextual_logging.get_context_metadata("request_id")) @@ -108,6 +102,8 @@ async def test_route(request): ) return JSONResponse({"ok": True}) + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) response = client.get("/test") @@ -124,12 +120,8 @@ async def test_route(request): def test_request_id_does_not_leak_between_requests(): """Test that request_id from one request doesn't leak into another.""" - app = Starlette() - app.add_middleware(RequestContextMiddleware) - request_ids_per_request = [] - @app.route("/test") async def test_route(request): # Capture request_id at start start_request_id = contextual_logging.get_context_metadata("request_id") @@ -144,6 +136,8 @@ async def test_route(request): return JSONResponse({"request_id": end_request_id}) + app = Starlette(routes=[Route("/test", test_route)]) + app.add_middleware(RequestContextMiddleware) client = TestClient(app) # Make multiple sequential requests