11import pytest
2- from fastapi import FastAPI , Request
2+ from fastapi import FastAPI , Request , Response
33
44from src .core .bootstrap .exception import register_exception
55from src .core .bootstrap .middleware import register_middleware
66from src .core .config .setting import Settings
77from src .core .dependency import rate_limit as rate_limit_module
8- from src .core .dependency .rate_limit import apply_global_rate_limit
8+ from src .core .dependency .rate_limit import apply_global_rate_limit , custom_identifier
99from src .core .exceptions .handler import (
1010 DOMAIN_EXCEPTION_MAP ,
1111 domain_exception_handler ,
@@ -23,7 +23,7 @@ def __init__(self, times: int, seconds: int):
2323 self .seconds = seconds
2424 created_limiters .append (self )
2525
26- async def __call__ (self , request ):
26+ async def __call__ (self , request , response ):
2727 return None
2828
2929 request = Request (
@@ -44,12 +44,79 @@ async def __call__(self, request):
4444
4545 import asyncio
4646
47- asyncio .run (apply_global_rate_limit (request ))
47+ asyncio .run (apply_global_rate_limit (request , Response () ))
4848
4949 assert created_limiters [0 ].times == 42
5050 assert created_limiters [0 ].seconds == 60
5151
5252
53+ def test_rate_limit_passes_response_to_limiter (monkeypatch ):
54+ limiter_calls = []
55+
56+ class FakeLimiter :
57+ def __init__ (self , times : int , seconds : int ):
58+ self .times = times
59+ self .seconds = seconds
60+
61+ async def __call__ (self , request , response ):
62+ limiter_calls .append ((request , response ))
63+
64+ request = Request (
65+ {
66+ "type" : "http" ,
67+ "method" : "GET" ,
68+ "path" : "/api/v1/todos/" ,
69+ "headers" : [],
70+ "query_string" : b"" ,
71+ "server" : ("testserver" , 80 ),
72+ "scheme" : "http" ,
73+ "client" : ("testclient" , 50000 ),
74+ }
75+ )
76+ response = Response ()
77+
78+ monkeypatch .setattr (rate_limit_module , "RateLimiter" , FakeLimiter )
79+
80+ import asyncio
81+
82+ asyncio .run (apply_global_rate_limit (request , response ))
83+
84+ assert limiter_calls == [(request , response )]
85+
86+
87+ def test_rate_limit_handles_included_router_entries ():
88+ class FakeRedis :
89+ async def script_load (self , script ):
90+ return "sha"
91+
92+ async def evalsha (self , sha , keys , key , times , milliseconds ):
93+ return 0
94+
95+ app = create_app (Settings (APP_ENV = "development" ))
96+ request = Request (
97+ {
98+ "type" : "http" ,
99+ "app" : app ,
100+ "method" : "GET" ,
101+ "path" : "/api/v1/todos/" ,
102+ "headers" : [],
103+ "query_string" : b"" ,
104+ "server" : ("testserver" , 80 ),
105+ "scheme" : "http" ,
106+ "client" : ("testclient" , 50000 ),
107+ }
108+ )
109+
110+ import asyncio
111+ from fastapi_limiter import FastAPILimiter
112+
113+ async def run_rate_limit ():
114+ await FastAPILimiter .init (FakeRedis (), identifier = custom_identifier )
115+ await apply_global_rate_limit (request , Response ())
116+
117+ asyncio .run (run_rate_limit ())
118+
119+
53120def test_cors_middleware_uses_environment_driven_settings (monkeypatch ):
54121 monkeypatch .setattr (
55122 "src.core.bootstrap.middleware.settings" ,
0 commit comments