Skip to content

Commit 5ff06de

Browse files
committed
chore: add ci/cd
1 parent 141f1a2 commit 5ff06de

2 files changed

Lines changed: 101 additions & 7 deletions

File tree

src/core/dependency/rate_limit.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from fastapi import Request
1+
from types import SimpleNamespace
2+
3+
from fastapi import Request, Response
24
from fastapi_limiter import FastAPILimiter
35
from fastapi_limiter.depends import RateLimiter
46

@@ -20,6 +22,31 @@
2022
settings = get_settings()
2123

2224

25+
def _rate_limiter_request(request: Request) -> Request:
26+
if "app" not in request.scope:
27+
return request
28+
29+
routes = []
30+
for route in request.app.routes:
31+
if hasattr(route, "path") and hasattr(route, "methods"):
32+
routes.append(route)
33+
continue
34+
35+
effective_route_contexts = getattr(route, "effective_route_contexts", None)
36+
if effective_route_contexts is None:
37+
continue
38+
39+
routes.extend(
40+
nested_route
41+
for nested_route in effective_route_contexts()
42+
if hasattr(nested_route, "path") and hasattr(nested_route, "methods")
43+
)
44+
45+
scope = dict(request.scope)
46+
scope["app"] = SimpleNamespace(routes=routes)
47+
return Request(scope, receive=request.receive)
48+
49+
2350
async def custom_identifier(request: Request) -> str:
2451
"""Smart identifier: User ID > Proxy IP > Direct IP"""
2552
user_id = getattr(request.state, "user_id", None)
@@ -46,7 +73,7 @@ async def close_rate_limiter():
4673
await redis_client.aclose()
4774

4875

49-
async def apply_global_rate_limit(request: Request):
76+
async def apply_global_rate_limit(request: Request, response: Response):
5077
if request.url.path in EXEMPT_PATHS:
5178
return
5279

@@ -56,4 +83,4 @@ async def apply_global_rate_limit(request: Request):
5683
seconds = 60 if "minute" in period else 1
5784

5885
limiter = RateLimiter(times=times, seconds=seconds)
59-
await limiter(request)
86+
await limiter(_rate_limiter_request(request), response)

tests/core/test_security_todo.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import pytest
2-
from fastapi import FastAPI, Request
2+
from fastapi import FastAPI, Request, Response
33

44
from src.core.bootstrap.exception import register_exception
55
from src.core.bootstrap.middleware import register_middleware
66
from src.core.config.setting import Settings
77
from src.core.dependency import rate_limit as rate_limit_module
8-
from src.core.dependency.rate_limit import apply_global_rate_limit
8+
from src.core.dependency.rate_limit import apply_global_rate_limit, custom_identifier
99
from src.core.exceptions.handler import (
1010
DOMAIN_EXCEPTION_MAP,
1111
domain_exception_handler,
@@ -23,7 +23,7 @@ def __init__(self, times: int, seconds: int):
2323
self.seconds = seconds
2424
created_limiters.append(self)
2525

26-
async def __call__(self, request):
26+
async def __call__(self, request, response):
2727
return None
2828

2929
request = Request(
@@ -44,12 +44,79 @@ async def __call__(self, request):
4444

4545
import asyncio
4646

47-
asyncio.run(apply_global_rate_limit(request))
47+
asyncio.run(apply_global_rate_limit(request, Response()))
4848

4949
assert created_limiters[0].times == 42
5050
assert created_limiters[0].seconds == 60
5151

5252

53+
def test_rate_limit_passes_response_to_limiter(monkeypatch):
54+
limiter_calls = []
55+
56+
class FakeLimiter:
57+
def __init__(self, times: int, seconds: int):
58+
self.times = times
59+
self.seconds = seconds
60+
61+
async def __call__(self, request, response):
62+
limiter_calls.append((request, response))
63+
64+
request = Request(
65+
{
66+
"type": "http",
67+
"method": "GET",
68+
"path": "/api/v1/todos/",
69+
"headers": [],
70+
"query_string": b"",
71+
"server": ("testserver", 80),
72+
"scheme": "http",
73+
"client": ("testclient", 50000),
74+
}
75+
)
76+
response = Response()
77+
78+
monkeypatch.setattr(rate_limit_module, "RateLimiter", FakeLimiter)
79+
80+
import asyncio
81+
82+
asyncio.run(apply_global_rate_limit(request, response))
83+
84+
assert limiter_calls == [(request, response)]
85+
86+
87+
def test_rate_limit_handles_included_router_entries():
88+
class FakeRedis:
89+
async def script_load(self, script):
90+
return "sha"
91+
92+
async def evalsha(self, sha, keys, key, times, milliseconds):
93+
return 0
94+
95+
app = create_app(Settings(APP_ENV="development"))
96+
request = Request(
97+
{
98+
"type": "http",
99+
"app": app,
100+
"method": "GET",
101+
"path": "/api/v1/todos/",
102+
"headers": [],
103+
"query_string": b"",
104+
"server": ("testserver", 80),
105+
"scheme": "http",
106+
"client": ("testclient", 50000),
107+
}
108+
)
109+
110+
import asyncio
111+
from fastapi_limiter import FastAPILimiter
112+
113+
async def run_rate_limit():
114+
await FastAPILimiter.init(FakeRedis(), identifier=custom_identifier)
115+
await apply_global_rate_limit(request, Response())
116+
117+
asyncio.run(run_rate_limit())
118+
119+
53120
def test_cors_middleware_uses_environment_driven_settings(monkeypatch):
54121
monkeypatch.setattr(
55122
"src.core.bootstrap.middleware.settings",

0 commit comments

Comments
 (0)