diff --git a/.github/workflows/qa-tests.yml b/.github/workflows/qa-tests.yml index 075a75b54..aa06c7593 100644 --- a/.github/workflows/qa-tests.yml +++ b/.github/workflows/qa-tests.yml @@ -52,4 +52,4 @@ jobs: app_port: 8080 sleep_before_test: 30 config_update_delay: 100 - skip_tests: test_bypassed_ip_for_geo_blocking,test_demo_apps_generic_tests,test_path_traversal,test_outbound_domain_blocking,test_bypassed_ip,test_wave_attack,test_block_traffic_by_countries,test_user_rate_limiting_1_minute + skip_tests: test_demo_apps_generic_tests,test_path_traversal,test_outbound_domain_blocking,test_wave_attack,test_block_traffic_by_countries,test_user_rate_limiting_1_minute diff --git a/aikido_zen/thread/thread_cache.py b/aikido_zen/thread/thread_cache.py index 0072bac3c..6eb60bddf 100644 --- a/aikido_zen/thread/thread_cache.py +++ b/aikido_zen/thread/thread_cache.py @@ -44,6 +44,10 @@ def reset(self): last_updated_at=-1, received_any_stats=False, ) + self._clear_synced_deltas() + + def _clear_synced_deltas(self): + """Clears delta counters synced to the background process.""" self.middleware_installed = False self.hostnames.clear() self.users.clear() @@ -51,28 +55,49 @@ def reset(self): self.ai_stats.clear() PackagesStore.clear() + def _restore_synced_deltas(self, payload): + """Merges a previously-cleared payload back, used when an IPC sync fails.""" + self.middleware_installed = ( + self.middleware_installed or payload["middleware_installed"] + ) + for entry in payload["hostnames"]: + self.hostnames.add(entry["hostname"], entry["port"], entry["hits"]) + for entry in payload["users"]: + self.users.add_user_from_entry(entry) + self.stats.import_from_record(payload["stats"]) + self.ai_stats.import_list(payload["ai_stats"]) + for pkg in payload["packages"]: + existing = PackagesStore.get_package(pkg["name"]) + if existing: + existing["cleared"] = False + def renew(self): if not comms.get_comms(): return - # send stored data and receive new config and routes + # Clear deltas before the IPC, not after. Clearing post-response would + # wipe any increments that arrived in the window where the IPC released + # the GIL. + payload = { + "current_routes": self.routes.get_routes_with_hits(), + "middleware_installed": self.middleware_installed, + "hostnames": self.hostnames.as_array(), + "users": self.users.as_array(), + "stats": self.stats.get_record(), + "ai_stats": self.ai_stats.get_stats(), + "packages": PackagesStore.export(), + } + self._clear_synced_deltas() + res = comms.get_comms().send_data_to_bg_process( action="SYNC_DATA", - obj={ - "current_routes": self.routes.get_routes_with_hits(), - "middleware_installed": self.middleware_installed, - "hostnames": self.hostnames.as_array(), - "users": self.users.as_array(), - "stats": self.stats.get_record(), - "ai_stats": self.ai_stats.get_stats(), - "packages": PackagesStore.export(), - }, + obj=payload, receive=True, ) if not res["success"] or not res["data"]: + self._restore_synced_deltas(payload) return - self.reset() # update config if isinstance(res["data"].get("config"), ServiceConfig): self.config = res["data"]["config"] diff --git a/aikido_zen/thread/thread_cache_test.py b/aikido_zen/thread/thread_cache_test.py index f9d55cc60..21e294025 100644 --- a/aikido_zen/thread/thread_cache_test.py +++ b/aikido_zen/thread/thread_cache_test.py @@ -429,6 +429,61 @@ def test_renew_called_with_empty_routes(mock_get_comms, thread_cache: ThreadCach ) +@patch("aikido_zen.background_process.comms.get_comms") +def test_renew_preserves_increments_during_ipc( + mock_get_comms, thread_cache: ThreadCache +): + """Increments arriving during the IPC call survive on the response path - + the snapshot is sent, but the live counter keeps the concurrent increment.""" + mock_comms = MagicMock() + mock_get_comms.return_value = mock_comms + + thread_cache.stats.increment_total_hits() + thread_cache.stats.increment_total_hits() + + def simulate_concurrent_increment(*args, **kwargs): + thread_cache.stats.increment_total_hits() + return {"success": True, "data": {"routes": {}}} + + mock_comms.send_data_to_bg_process.side_effect = simulate_concurrent_increment + + thread_cache.renew() + + sent_total = mock_comms.send_data_to_bg_process.call_args.kwargs["obj"]["stats"][ + "requests" + ]["total"] + assert sent_total == 2 + assert thread_cache.stats.get_record()["requests"]["total"] == 1 + + +@patch("aikido_zen.background_process.comms.get_comms") +def test_renew_restores_deltas_on_ipc_failure( + mock_get_comms, thread_cache: ThreadCache +): + """If the IPC fails, the cleared deltas must be merged back on top of any + concurrent increments - nothing lost, nothing double-counted.""" + mock_comms = MagicMock() + mock_get_comms.return_value = mock_comms + + thread_cache.stats.increment_total_hits() + thread_cache.stats.increment_total_hits() + thread_cache.ai_stats.on_ai_call("openai", "gpt-4o", 100, 50) + thread_cache.middleware_installed = True + + def fail_after_concurrent_increment(*args, **kwargs): + thread_cache.stats.increment_total_hits() + return {"success": False} + + mock_comms.send_data_to_bg_process.side_effect = fail_after_concurrent_increment + + thread_cache.renew() + + # 2 from the snapshot + 1 from the concurrent increment + assert thread_cache.stats.get_record()["requests"]["total"] == 3 + assert thread_cache.middleware_installed is True + assert thread_cache.ai_stats.get_stats()[0]["calls"] == 1 + + @patch("aikido_zen.background_process.comms.get_comms") def test_renew_called_with_no_requests(mock_get_comms, thread_cache: ThreadCache): """Test that renew calls send_data_to_bg_process with zero requests."""