Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions aikido_zen/middleware/init_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@ def test_with_context_with_cache():
assert thread_cache.stats.rate_limited_hits == 0


def test_bypassed_ip_skips_user_blocking():
test_utils.generate_and_set_context(user={"id": "123"}, ip="1.2.3.4")
thread_cache = get_cache()
thread_cache.config.blocked_uids = ["123"]
thread_cache.config.set_bypassed_ips(["1.2.3.4"])

assert should_block_request() == {"block": False}


def test_cache_comms_with_endpoints():
test_utils.generate_and_set_context(user={"id": "456"}, route="/posts/:id")
set_rate_limit_group("my_group")
Expand Down
4 changes: 4 additions & 0 deletions aikido_zen/middleware/should_block_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def should_block_request():
context.executed_middleware = True
context.set_as_current_context()

# Bypassed IPs skip user blocking and rate limiting
if cache.is_bypassed_ip(context.remote_address):
return {"block": False}

# User blocking allows customers to easily take action when attacks are coming from specific accounts
if context.user and cache.is_user_blocked(context.user["id"]):
return {"block": True, "type": "blocked", "trigger": "user"}
Expand Down
10 changes: 9 additions & 1 deletion aikido_zen/sources/functions/request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def request_handler(stage, status_code=0):
try:
if stage == "init":
cache = get_cache()
if ctx.get_current_context() and cache:
context = ctx.get_current_context()
if context and cache and not cache.is_bypassed_ip(context.remote_address):
cache.stats.increment_total_hits()
if stage == "pre_response":
return pre_response()
Expand All @@ -44,6 +45,10 @@ def pre_response():
logger.debug("Request was not complete, not running any pre_response code")
return

# Bypassed IPs skip all allowlist and blocklist checks
if cache.is_bypassed_ip(context.remote_address):
return None

# Per endpoint IP Allowlist
matched_endpoints = cache.config.get_endpoints(context.get_route_metadata())
if not ip_allowed_to_access_route(
Expand Down Expand Up @@ -98,6 +103,9 @@ def post_response(status_code):
if not cache:
return

if cache.is_bypassed_ip(context.remote_address):
return

attack_wave = attack_wave_detector_store.is_attack_wave(context)
if attack_wave:
cache.stats.on_detected_attack_wave(blocked=False)
Expand Down
80 changes: 80 additions & 0 deletions aikido_zen/sources/functions/request_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,43 @@ def test_post_response_no_context(mock_get_comms):
comms.send_data_to_bg_process.assert_not_called()


def test_bypassed_ip_no_stats_in_init():
cache = get_cache()
cache.config.set_bypassed_ips(["1.2.3.4"])
cache.stats.clear()

context = MagicMock()
context.remote_address = "1.2.3.4"
with patch("aikido_zen.context.get_current_context", return_value=context):
request_handler("init")

assert cache.stats.get_record()["requests"]["total"] == 0


def test_non_bypassed_ip_increments_stats_in_init():
cache = get_cache()
cache.config.set_bypassed_ips([])
cache.stats.clear()

context = MagicMock()
context.remote_address = "1.2.3.4"
with patch("aikido_zen.context.get_current_context", return_value=context):
request_handler("init")

assert cache.stats.get_record()["requests"]["total"] == 1


def test_bypassed_ip_no_route_tracking_in_post_response(mock_context):
cache = get_cache()
cache.config.set_bypassed_ips(["5.6.7.8"])
mock_context.remote_address = "5.6.7.8"

with patch("aikido_zen.context.get_current_context", return_value=mock_context):
request_handler("post_response", status_code=200)

assert cache.routes.routes == {}


# Test firewall lists
def set_context(remote_address, user_agent="", route="/posts/:number"):
headers = Headers()
Expand Down Expand Up @@ -189,6 +226,49 @@ def wrapper(*args, **kwargs):
return wrapper


@patch_firewall_lists
def test_bypassed_ip_skips_all_checks(firewall_lists):
set_context("192.168.1.1")
config = ServiceConfig(
endpoints=[
{
"method": "POST",
"route": "/posts/:number",
"graphql": False,
"allowedIPAddresses": ["1.1.1.1"], # 192.168.1.1 not in this list
}
],
last_updated_at=None,
blocked_uids=set(),
bypassed_ips=["192.168.1.1"],
received_any_stats=False,
)
get_cache().config = config
firewall_lists.set_blocked_ips(
[
{
"source": "test",
"description": "Blocked for testing",
"ips": ["192.168.1.1"],
}
]
)
firewall_lists.set_allowed_ips(
[
{
"source": "test",
"description": "Allowed ranges",
"ips": ["4.4.4.0/24"],
}
]
)

# Act
result = request_handler("pre_response")

assert result is None


@patch_firewall_lists
def test_blocked_ip(firewall_lists):
# Arrange
Expand Down
Loading