Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/qa-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,4 @@ jobs:
app_port: 8080
sleep_before_test: 30
config_update_delay: 100
skip_tests: test_bypassed_ip_for_geo_blocking,test_demo_apps_generic_tests,test_path_traversal,test_outbound_domain_blocking,test_bypassed_ip,test_wave_attack,test_block_traffic_by_countries,test_user_rate_limiting_1_minute
skip_tests: test_demo_apps_generic_tests,test_path_traversal,test_outbound_domain_blocking,test_wave_attack,test_block_traffic_by_countries,test_user_rate_limiting_1_minute
47 changes: 36 additions & 11 deletions aikido_zen/thread/thread_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,35 +44,60 @@ def reset(self):
last_updated_at=-1,
received_any_stats=False,
)
self._clear_synced_deltas()
Comment thread
tomaisthorpe marked this conversation as resolved.

def _clear_synced_deltas(self):
    """Reset every delta counter that is flushed to the background process.

    Invoked right before a sync payload is sent (and on cache reset) so
    that increments recorded afterwards start from a clean slate.
    """
    self.middleware_installed = False
    # Wipe each per-interval collection; order mirrors the sync payload.
    for delta_store in (self.hostnames, self.users, self.stats, self.ai_stats):
        delta_store.clear()
    PackagesStore.clear()

def _restore_synced_deltas(self, payload):
    """Merge a previously-cleared payload back, used when an IPC sync fails.

    The snapshotted deltas in *payload* are re-applied on top of whatever
    increments happened concurrently, so nothing is lost and nothing is
    double-counted.
    """
    # Equivalent to `self.middleware_installed or payload[...]`: keep the
    # live truthy value, otherwise take the snapshotted one.
    if not self.middleware_installed:
        self.middleware_installed = payload["middleware_installed"]

    for host_entry in payload["hostnames"]:
        self.hostnames.add(host_entry["hostname"], host_entry["port"], host_entry["hits"])
    for user_entry in payload["users"]:
        self.users.add_user_from_entry(user_entry)

    self.stats.import_from_record(payload["stats"])
    self.ai_stats.import_list(payload["ai_stats"])

    # Packages are not re-added; only un-mark the ones still in the store.
    for pkg in payload["packages"]:
        stored = PackagesStore.get_package(pkg["name"])
        if stored:
            stored["cleared"] = False

def renew(self):
if not comms.get_comms():
return

# send stored data and receive new config and routes
# Clear deltas before the IPC, not after. Clearing post-response would
# wipe any increments that arrived in the window where the IPC released
# the GIL.
payload = {
"current_routes": self.routes.get_routes_with_hits(),
"middleware_installed": self.middleware_installed,
"hostnames": self.hostnames.as_array(),
"users": self.users.as_array(),
"stats": self.stats.get_record(),
"ai_stats": self.ai_stats.get_stats(),
"packages": PackagesStore.export(),
}
self._clear_synced_deltas()

res = comms.get_comms().send_data_to_bg_process(
action="SYNC_DATA",
obj={
"current_routes": self.routes.get_routes_with_hits(),
"middleware_installed": self.middleware_installed,
"hostnames": self.hostnames.as_array(),
"users": self.users.as_array(),
"stats": self.stats.get_record(),
"ai_stats": self.ai_stats.get_stats(),
"packages": PackagesStore.export(),
},
obj=payload,
receive=True,
)
if not res["success"] or not res["data"]:
self._restore_synced_deltas(payload)
return

self.reset()
# update config
if isinstance(res["data"].get("config"), ServiceConfig):
self.config = res["data"]["config"]
Expand Down
55 changes: 55 additions & 0 deletions aikido_zen/thread/thread_cache_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,61 @@ def test_renew_called_with_empty_routes(mock_get_comms, thread_cache: ThreadCach
)


@patch("aikido_zen.background_process.comms.get_comms")
def test_renew_preserves_increments_during_ipc(
    mock_get_comms, thread_cache: ThreadCache
):
    """Increments arriving during the IPC call survive on the response path -
    the snapshot is sent, but the live counter keeps the concurrent increment."""
    comms_mock = MagicMock()
    mock_get_comms.return_value = comms_mock

    # Two hits recorded before the sync starts.
    for _ in range(2):
        thread_cache.stats.increment_total_hits()

    def ipc_with_racing_hit(*args, **kwargs):
        # A hit lands while the IPC call is in flight.
        thread_cache.stats.increment_total_hits()
        return {"success": True, "data": {"routes": {}}}

    comms_mock.send_data_to_bg_process.side_effect = ipc_with_racing_hit

    thread_cache.renew()

    sent_stats = comms_mock.send_data_to_bg_process.call_args.kwargs["obj"]["stats"]
    # The snapshot carried exactly the pre-sync hits...
    assert sent_stats["requests"]["total"] == 2
    # ...and the racing hit is still on the live counter afterwards.
    assert thread_cache.stats.get_record()["requests"]["total"] == 1


@patch("aikido_zen.background_process.comms.get_comms")
def test_renew_restores_deltas_on_ipc_failure(
    mock_get_comms, thread_cache: ThreadCache
):
    """If the IPC fails, the cleared deltas must be merged back on top of any
    concurrent increments - nothing lost, nothing double-counted."""
    comms_mock = MagicMock()
    mock_get_comms.return_value = comms_mock

    # Pre-sync deltas: two hits, one AI call, middleware marked installed.
    for _ in range(2):
        thread_cache.stats.increment_total_hits()
    thread_cache.ai_stats.on_ai_call("openai", "gpt-4o", 100, 50)
    thread_cache.middleware_installed = True

    def failing_ipc_with_racing_hit(*args, **kwargs):
        # A hit lands mid-IPC, then the sync reports failure.
        thread_cache.stats.increment_total_hits()
        return {"success": False}

    comms_mock.send_data_to_bg_process.side_effect = failing_ipc_with_racing_hit

    thread_cache.renew()

    # 2 from the restored snapshot + 1 from the concurrent increment
    assert thread_cache.stats.get_record()["requests"]["total"] == 3
    assert thread_cache.middleware_installed is True
    assert thread_cache.ai_stats.get_stats()[0]["calls"] == 1


@patch("aikido_zen.background_process.comms.get_comms")
def test_renew_called_with_no_requests(mock_get_comms, thread_cache: ThreadCache):
"""Test that renew calls send_data_to_bg_process with zero requests."""
Expand Down
Loading