Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 19 additions & 37 deletions tests/test_instrumentation_request_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from starlette.requests import Request
from starlette.responses import Response
from starlette.applications import Starlette
from starlette.routing import Route
from starlette.testclient import TestClient

from cloud_pipelines_backend.instrumentation import contextual_logging
Expand Down Expand Up @@ -78,19 +79,17 @@ def teardown_method(self):

def test_middleware_generates_request_id(self):
"""Test that middleware generates a request_id for each request."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

request_ids_seen = []

@app.route("/test")
def test_route(request):
# Capture the request_id during request processing
request_ids_seen.append(
contextual_logging.get_context_metadata("request_id")
)
return Response("ok")

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)
response = client.get("/test")

Expand All @@ -101,13 +100,11 @@ def test_route(request):

def test_middleware_adds_request_id_to_response_headers(self):
"""Test that middleware adds request_id to response headers."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

@app.route("/test")
def test_route(request):
return Response("ok")

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)
response = client.get("/test")

Expand All @@ -118,14 +115,12 @@ def test_route(request):

def test_middleware_clears_request_id_after_request(self):
"""Test that middleware clears request_id after request completes."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

@app.route("/test")
def test_route(request):
assert contextual_logging.get_context_metadata("request_id") is not None
return Response("ok")

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)

# Before request
Expand All @@ -140,13 +135,11 @@ def test_route(request):

def test_middleware_generates_unique_request_ids(self):
"""Test that middleware generates unique request_ids for each request."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

@app.route("/test")
def test_route(request):
return Response("ok")

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)

# Make multiple requests
Expand All @@ -160,17 +153,15 @@ def test_route(request):

def test_middleware_request_id_available_in_route(self):
"""Test that request_id set by middleware is available in route handler."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

captured_request_id = None

@app.route("/test")
def test_route(request):
nonlocal captured_request_id
captured_request_id = contextual_logging.get_context_metadata("request_id")
return Response(f"request_id: {captured_request_id}")

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)
response = client.get("/test")

Expand All @@ -180,17 +171,15 @@ def test_route(request):

def test_middleware_handles_exception_in_route(self):
"""Test that middleware clears request_id even when route raises exception."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

@app.route("/test")
def test_route(request):
request_id_during_exception = contextual_logging.get_context_metadata(
"request_id"
)
assert request_id_during_exception is not None
raise ValueError("Test exception")

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app, raise_server_exceptions=False)
response = client.get("/test")

Expand All @@ -200,25 +189,22 @@ def test_route(request):

def test_middleware_with_multiple_routes(self):
"""Test middleware works correctly with multiple routes."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

request_ids_by_route = {}

@app.route("/route1")
def route1(request):
request_ids_by_route["route1"] = contextual_logging.get_context_metadata(
"request_id"
)
return Response("route1")

@app.route("/route2")
def route2(request):
request_ids_by_route["route2"] = contextual_logging.get_context_metadata(
"request_id"
)
return Response("route2")

app = Starlette(routes=[Route("/route1", route1), Route("/route2", route2)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)

response1 = client.get("/route1")
Expand Down Expand Up @@ -255,9 +241,6 @@ def test_middleware_enables_request_id_in_logs(self):
"""Test that middleware enables request_id to be used in logging."""
import logging

app = Starlette()
app.add_middleware(RequestContextMiddleware)

logged_request_ids = []

# Create a custom handler to capture log records
Expand All @@ -275,11 +258,12 @@ def emit(self, record):
logger.addHandler(handler)
logger.setLevel(logging.INFO)

@app.route("/test")
def test_route(request):
logger.info("Processing request")
return Response("ok")

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)
response = client.get("/test")

Expand All @@ -292,9 +276,6 @@ def test_route(request):

def test_middleware_request_id_persists_across_function_calls(self):
"""Test that request_id persists across function calls within a request."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

request_ids_collected = []

def helper_function():
Expand All @@ -303,7 +284,6 @@ def helper_function():
contextual_logging.get_context_metadata("request_id")
)

@app.route("/test")
def test_route(request):
request_ids_collected.append(
contextual_logging.get_context_metadata("request_id")
Expand All @@ -314,6 +294,8 @@ def test_route(request):
)
return Response("ok")

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)
response = client.get("/test")

Expand Down
20 changes: 7 additions & 13 deletions tests/test_request_id_concurrency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route
from starlette.testclient import TestClient

from cloud_pipelines_backend.instrumentation import contextual_logging
Expand All @@ -12,16 +13,12 @@

def test_request_id_isolation_with_concurrent_requests():
"""Test that each concurrent request gets its own isolated request_id."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

# Store request_ids seen by each endpoint
request_ids_seen = {
"endpoint1": [],
"endpoint2": [],
}

@app.route("/endpoint1")
async def endpoint1(request):
request_id = contextual_logging.get_context_metadata("request_id")
request_ids_seen["endpoint1"].append(request_id)
Expand All @@ -31,7 +28,6 @@ async def endpoint1(request):
assert contextual_logging.get_context_metadata("request_id") == request_id
return JSONResponse({"request_id": request_id})

@app.route("/endpoint2")
async def endpoint2(request):
request_id = contextual_logging.get_context_metadata("request_id")
request_ids_seen["endpoint2"].append(request_id)
Expand All @@ -41,6 +37,8 @@ async def endpoint2(request):
assert contextual_logging.get_context_metadata("request_id") == request_id
return JSONResponse({"request_id": request_id})

app = Starlette(routes=[Route("/endpoint1", endpoint1), Route("/endpoint2", endpoint2)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)

# Make concurrent requests
Expand Down Expand Up @@ -71,9 +69,6 @@ async def endpoint2(request):

def test_request_id_isolation_with_nested_async_calls():
"""Test that request_id persists correctly through nested async function calls."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

request_ids_collected = []

async def helper_function_1():
Expand All @@ -97,7 +92,6 @@ async def helper_function_2():
("helper2_after", contextual_logging.get_context_metadata("request_id"))
)

@app.route("/test")
async def test_route(request):
request_ids_collected.append(
("start", contextual_logging.get_context_metadata("request_id"))
Expand All @@ -108,6 +102,8 @@ async def test_route(request):
)
return JSONResponse({"ok": True})

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)
response = client.get("/test")

Expand All @@ -124,12 +120,8 @@ async def test_route(request):

def test_request_id_does_not_leak_between_requests():
"""Test that request_id from one request doesn't leak into another."""
app = Starlette()
app.add_middleware(RequestContextMiddleware)

request_ids_per_request = []

@app.route("/test")
async def test_route(request):
# Capture request_id at start
start_request_id = contextual_logging.get_context_metadata("request_id")
Expand All @@ -144,6 +136,8 @@ async def test_route(request):

return JSONResponse({"request_id": end_request_id})

app = Starlette(routes=[Route("/test", test_route)])
app.add_middleware(RequestContextMiddleware)
client = TestClient(app)

# Make multiple sequential requests
Expand Down
Loading