diff --git a/.gitignore b/.gitignore index 69174f9d..2bec1033 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,6 @@ results/* *.so *.html *.csv -*.json *.txt *.parquet .DS_Store @@ -77,6 +76,7 @@ instance/ # Sphinx documentation docs/_build/ +docs/output/ # PyBuilder .pybuilder/ @@ -324,7 +324,9 @@ cython_debug/ # Others docs/api/ +!/docs/api/ !/docs/api/index.rst +!/docs/api/*.rst # requirements.txt !*/requirements.*.txt diff --git a/config.yaml b/config.yaml index 6e816380..8a55e98f 100644 --- a/config.yaml +++ b/config.yaml @@ -13,7 +13,7 @@ logging: data_inspection.inspector: debug: false data_analysis.detector: - debug: false + debug: true pipeline: scaling: @@ -23,50 +23,58 @@ pipeline: modules: log_storage.logserver: executor: thread - max_workers: 1 + max_workers: 4 log_collection.collector: executor: thread max_workers: 1 instances: dga_collector: - max_workers: 1 + max_workers: 2 domainator_collector: - max_workers: 1 + max_workers: 2 log_filtering.prefilter: executor: thread - max_workers: 1 + max_workers: 2 instances: dga_filter: - max_workers: 1 + max_workers: 2 no_filter: - max_workers: 1 + max_workers: 2 data_inspection.inspector: executor: thread - max_workers: 1 + max_workers: 2 instances: dga_inspector: - max_workers: 1 + max_workers: 2 no_inspector: - max_workers: 1 + max_workers: 2 data_analysis.detector: executor: thread - max_workers: 1 + max_workers: 2 instances: RF-dga_detector: - max_workers: 1 + max_workers: 2 domainator: + max_workers: 3 + domainator_attributor: + max_workers: 1 + domainator_attributor_behaviour: + max_workers: 1 + domainator_attributor_identification_behaviour: + max_workers: 1 + domainator_attributor_identification: max_workers: 1 pipeline.alerter: executor: thread - max_workers: 1 + max_workers: 2 instances: generic: - max_workers: 1 + max_workers: 2 attributor: - max_workers: 1 + max_workers: 2 monitoring.agent: executor: thread - max_workers: 1 + max_workers: 2 log_storage: logserver: input_file: "/opt/file.txt" @@ -144,33 +152,65 @@ pipeline: detector_module_name: "dga_detector" detector_class_name: "DGADetector" model: rf + use_scaler: false checksum: 5db8bfb617e80361362c33b1d1afc6d762c28e9fa9275fb11514a3bdef76bb88 base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/ threshold: 0.5 consume_from: inspector inspector_name: dga_inspector - next_detectors: domainator + next_detectors: "" send_to_alerter: true produce_topics: "" - name: "domainator" detector_module_name: "domainator_detector" detector_class_name: "DomainatorDetector" model: domainator - checksum: 9d86d66b4976c9b325bed0934a9a9eb3a20960b08be9afe491454624cc0aaa6c + use_scaler: false + checksum: a4aac4c585f1e614c3cf0d737e80b960c5de6e87b253f7cdd07125d9ce486476 base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/ - threshold: 0.5 + threshold: 0.05 consume_from: inspector inspector_name: "domainator_inspector" - next_detectors: "domainator_attributor" + next_detectors: + - "domainator_attributor_behaviour" + - "domainator_attributor_identification_behaviour" + - "domainator_attributor_identification" send_to_alerter: true produce_topics: "" - - name: "domainator_attributor" + - name: "domainator_attributor_behaviour" detector_module_name: "domainator_attributor" detector_class_name: "DomainatorAttributor" - model: domainator - checksum: 9d86d66b4976c9b325bed0934a9a9eb3a20960b08be9afe491454624cc0aaa6c + model: domainator_attributor_behaviour + use_scaler: false + checksum: d8f302edc166ecc80985838a30b5dff16ccc83480ea3c2480652f49c8f6b5e9b base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/ - threshold: 0.5 + threshold: 0.05 + consume_from: detector + detector_name: "domainator" + next_detectors: "" + send_to_alerter: true + produce_topics: "" + - name: "domainator_attributor_identification_behaviour" + detector_module_name: "domainator_attributor" + detector_class_name: "DomainatorAttributor" + model: domainator_attributor_identification_behaviour + use_scaler: false + checksum: 9a0970b4160b22f4c3c5ac99760f0ace5500dd25c5a195ff13254ad3c11d5dcd + base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/ + threshold: 0.05 + consume_from: detector + detector_name: "domainator" + next_detectors: "" + send_to_alerter: true + produce_topics: "" + - name: "domainator_attributor_identification" + detector_module_name: "domainator_attributor" + detector_class_name: "DomainatorAttributor" + model: domainator_attributor_identification + use_scaler: false + checksum: 360bd26881beabce7e7581963240915de807c48b5e4a3501a657139f2ecb8a8b + base_url: https://heibox.uni-heidelberg.de/d/0d5cbcbe16cd46a58021/ + threshold: 0.05 consume_from: detector detector_name: "domainator" next_detectors: "" @@ -180,6 +220,9 @@ pipeline: alerting: log_to_file: true log_file_path: "/opt/logs/alerts.txt" + log_rotation: + enabled: true + retention_days: 7 log_to_kafka: true external_kafka_topic: "hamstring_alerts" plugins: [] @@ -212,6 +255,40 @@ environment: internal_port: 19094 external_port: 8099 node_ip: 127.0.0.1 + kafka_consumer: + # Allow long-running detector batches without Kafka forcing a group rebalance. + # Default librdkafka value is 300000 ms (5 minutes), which can be too short + # for model inference plus downstream alert/monitoring writes. + max_poll_interval_ms: 1800000 + kafka_topics: + replication_factor: 3 + auto_expand_partitions: true + stages: + logserver_in: + partitions: 12 + replication_factor: 3 + logserver_to_collector: + partitions: 12 + replication_factor: 3 + batch_sender_to_prefilter: + partitions: 12 + replication_factor: 3 + prefilter_to_inspector: + partitions: 12 + replication_factor: 3 + inspector_to_detector: + partitions: 12 + replication_factor: 3 + detector_to_alerter: + partitions: 12 + replication_factor: 3 + detector_to_detector: + partitions: 12 + replication_factor: 3 + topics: + hamstring_alerts: + partitions: 12 + replication_factor: 3 kafka_topics_prefix: pipeline: logserver_in: "hamstring_input" diff --git a/docker/create_tables/alerts.sql b/docker/create_tables/alerts.sql index af14b1a2..16b11a59 100644 --- a/docker/create_tables/alerts.sql +++ b/docker/create_tables/alerts.sql @@ -8,4 +8,5 @@ CREATE TABLE IF NOT EXISTS alerts ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(alert_timestamp) -ORDER BY (alert_timestamp, src_ip, suspicious_batch_id); +ORDER BY (alert_timestamp, src_ip, suspicious_batch_id) +TTL toDateTime(alert_timestamp) + INTERVAL 60 DAY; \ No newline at end of file diff --git a/docker/create_tables/batch_timestamps.sql b/docker/create_tables/batch_timestamps.sql index dba9bbc6..57b23e5e 100644 --- a/docker/create_tables/batch_timestamps.sql +++ b/docker/create_tables/batch_timestamps.sql @@ -9,4 +9,5 @@ CREATE TABLE IF NOT EXISTS batch_timestamps ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(timestamp) -ORDER BY (stage, status, timestamp, instance_name, batch_id); +ORDER BY (stage, status, timestamp, instance_name, batch_id) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; diff --git a/docker/create_tables/batch_tree.sql b/docker/create_tables/batch_tree.sql index d542dd65..ec8540e5 100644 --- a/docker/create_tables/batch_tree.sql +++ b/docker/create_tables/batch_tree.sql @@ -11,4 +11,5 @@ CREATE TABLE IF NOT EXISTS batch_tree ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(timestamp) -ORDER BY (stage, status, timestamp, instance_name, batch_row_id, parent_batch_row_id); +ORDER BY (stage, status, timestamp, instance_name, batch_row_id, parent_batch_row_id) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; diff --git a/docker/create_tables/failed_loglines.sql b/docker/create_tables/failed_loglines.sql index dc4fed10..93702050 100644 --- a/docker/create_tables/failed_loglines.sql +++ b/docker/create_tables/failed_loglines.sql @@ -6,4 +6,5 @@ CREATE TABLE IF NOT EXISTS failed_loglines ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(timestamp_failed) -ORDER BY (timestamp_failed, timestamp_in); +ORDER BY (timestamp_failed, timestamp_in) +TTL toDateTime(timestamp_failed) + INTERVAL 1 DAY; diff --git a/docker/create_tables/fill_levels.sql b/docker/create_tables/fill_levels.sql index 02affa07..fc87e1af 100644 --- a/docker/create_tables/fill_levels.sql +++ b/docker/create_tables/fill_levels.sql @@ -6,4 +6,5 @@ CREATE TABLE IF NOT EXISTS fill_levels ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(timestamp) -ORDER BY (stage, entry_type, timestamp); +ORDER BY (stage, entry_type, timestamp) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; diff --git a/docker/create_tables/logline_timestamps.sql b/docker/create_tables/logline_timestamps.sql index fa81af92..3ecfd9ed 100644 --- a/docker/create_tables/logline_timestamps.sql +++ b/docker/create_tables/logline_timestamps.sql @@ -7,4 +7,5 @@ CREATE TABLE IF NOT EXISTS logline_timestamps ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(timestamp) -ORDER BY (stage, status, timestamp, logline_id); +ORDER BY (stage, status, timestamp, logline_id) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; diff --git a/docker/create_tables/logline_to_batches.sql b/docker/create_tables/logline_to_batches.sql index 2c3c8254..8b306171 100644 --- a/docker/create_tables/logline_to_batches.sql +++ b/docker/create_tables/logline_to_batches.sql @@ -1,6 +1,11 @@ CREATE TABLE IF NOT EXISTS logline_to_batches ( + timestamp DateTime64(6) NOT NULL, logline_id UUID NOT NULL, batch_id UUID NOT NULL ) ENGINE = MergeTree -ORDER BY (batch_id, logline_id); +ORDER BY (timestamp, batch_id, logline_id) +PARTITION BY toYYYYMM(timestamp) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; + + diff --git a/docker/create_tables/loglines.sql b/docker/create_tables/loglines.sql index ff6cfbc5..0b01c175 100644 --- a/docker/create_tables/loglines.sql +++ b/docker/create_tables/loglines.sql @@ -7,4 +7,5 @@ CREATE TABLE IF NOT EXISTS loglines ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(timestamp) -ORDER BY (timestamp, src_ip, subnet_id, logline_id); +ORDER BY (timestamp, src_ip, subnet_id, logline_id) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; diff --git a/docker/create_tables/server_log_terminal_events.sql b/docker/create_tables/server_log_terminal_events.sql new file mode 100644 index 00000000..d4d9dae7 --- /dev/null +++ b/docker/create_tables/server_log_terminal_events.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS server_log_terminal_events ( + message_id UUID NOT NULL, + stage LowCardinality(String) NOT NULL, + status LowCardinality(String) NOT NULL, + timestamp DateTime64(6) NOT NULL +) +ENGINE = MergeTree +PARTITION BY toYYYYMM(timestamp) +ORDER BY (stage, status, timestamp, message_id) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; diff --git a/docker/create_tables/server_log_to_logline.sql b/docker/create_tables/server_log_to_logline.sql new file mode 100644 index 00000000..96fe645f --- /dev/null +++ b/docker/create_tables/server_log_to_logline.sql @@ -0,0 +1,10 @@ +CREATE TABLE IF NOT EXISTS server_log_to_logline ( + message_id UUID NOT NULL, + logline_id UUID NOT NULL, + timestamp DateTime64(6) NOT NULL, +) +ENGINE = MergeTree +ORDER BY (timestamp, message_id, logline_id) +PARTITION BY toYYYYMM(timestamp) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; + diff --git a/docker/create_tables/server_logs.sql b/docker/create_tables/server_logs.sql index 2494925c..d166d9a5 100644 --- a/docker/create_tables/server_logs.sql +++ b/docker/create_tables/server_logs.sql @@ -5,4 +5,5 @@ CREATE TABLE IF NOT EXISTS server_logs ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(timestamp_in) -ORDER BY (timestamp_in, message_id); +ORDER BY (timestamp_in, message_id) +TTL toDateTime(timestamp_in) + INTERVAL 1 DAY; \ No newline at end of file diff --git a/docker/create_tables/server_logs_timestamps.sql b/docker/create_tables/server_logs_timestamps.sql index 62dbd16d..81a6b972 100644 --- a/docker/create_tables/server_logs_timestamps.sql +++ b/docker/create_tables/server_logs_timestamps.sql @@ -5,4 +5,5 @@ CREATE TABLE IF NOT EXISTS server_logs_timestamps ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(event_timestamp) -ORDER BY (event, event_timestamp, message_id); +ORDER BY (event, event_timestamp, message_id) +TTL toDateTime(event_timestamp) + INTERVAL 1 DAY; \ No newline at end of file diff --git a/docker/create_tables/suspicious_batch_timestamps.sql b/docker/create_tables/suspicious_batch_timestamps.sql index 63993ce4..4bc857b9 100644 --- a/docker/create_tables/suspicious_batch_timestamps.sql +++ b/docker/create_tables/suspicious_batch_timestamps.sql @@ -10,4 +10,5 @@ CREATE TABLE IF NOT EXISTS suspicious_batch_timestamps ( ) ENGINE = MergeTree PARTITION BY toYYYYMM(timestamp) -ORDER BY (stage, status, timestamp, instance_name, suspicious_batch_id, src_ip); +ORDER BY (stage, status, timestamp, instance_name, suspicious_batch_id, src_ip) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; \ No newline at end of file diff --git a/docker/create_tables/suspicious_batches_to_batch.sql b/docker/create_tables/suspicious_batches_to_batch.sql index d95fdda6..4ddfeeab 100644 --- a/docker/create_tables/suspicious_batches_to_batch.sql +++ b/docker/create_tables/suspicious_batches_to_batch.sql @@ -1,6 +1,9 @@ CREATE TABLE IF NOT EXISTS suspicious_batches_to_batch ( + timestamp DateTime64(6) NOT NULL, suspicious_batch_id UUID NOT NULL, batch_id UUID NOT NULL ) ENGINE = MergeTree -ORDER BY (batch_id, suspicious_batch_id); +ORDER BY (timestamp, batch_id, suspicious_batch_id) +PARTITION BY toYYYYMM(timestamp) +TTL toDateTime(timestamp) + INTERVAL 1 DAY; diff --git a/docker/create_tables/zz_monitoring_rollups.sql b/docker/create_tables/zz_monitoring_rollups.sql index 0155cd9d..a6782a7a 100644 --- a/docker/create_tables/zz_monitoring_rollups.sql +++ b/docker/create_tables/zz_monitoring_rollups.sql @@ -5,7 +5,9 @@ CREATE TABLE IF NOT EXISTS alerts_1m ( ) ENGINE = AggregatingMergeTree PARTITION BY toYYYYMM(time_bucket) -ORDER BY (time_bucket, src_ip); +ORDER BY (time_bucket, src_ip) +TTL toDateTime(time_bucket) + INTERVAL 1 DAY; + CREATE MATERIALIZED VIEW IF NOT EXISTS alerts_1m_mv TO alerts_1m @@ -28,7 +30,8 @@ CREATE TABLE IF NOT EXISTS fill_levels_1m ( ) ENGINE = AggregatingMergeTree PARTITION BY toYYYYMM(time_bucket) -ORDER BY (stage, entry_type, time_bucket); +ORDER BY (stage, entry_type, time_bucket) +TTL toDateTime(time_bucket) + INTERVAL 1 DAY; CREATE MATERIALIZED VIEW IF NOT EXISTS fill_levels_1m_mv TO fill_levels_1m @@ -52,7 +55,8 @@ CREATE TABLE IF NOT EXISTS server_log_latencies ( ) ENGINE = AggregatingMergeTree PARTITION BY toYYYYMM(event_date) -ORDER BY (event_date, message_id); +ORDER BY (event_date, message_id) +TTL toDateTime(event_date) + INTERVAL 1 DAY; CREATE MATERIALIZED VIEW IF NOT EXISTS server_log_start_latency_mv TO server_log_latencies @@ -85,7 +89,8 @@ CREATE TABLE IF NOT EXISTS logline_stage_latencies ( ) ENGINE = AggregatingMergeTree PARTITION BY toYYYYMM(event_date) -ORDER BY (stage, event_date, logline_id); +ORDER BY (stage, event_date, logline_id) +TTL toDateTime(event_date) + INTERVAL 1 DAY; CREATE MATERIALIZED VIEW IF NOT EXISTS logline_stage_latencies_mv TO logline_stage_latencies @@ -109,7 +114,8 @@ CREATE TABLE IF NOT EXISTS batch_stage_latencies ( ) ENGINE = AggregatingMergeTree PARTITION BY toYYYYMM(event_date) -ORDER BY (stage, event_date, instance_name, batch_id); +ORDER BY (stage, event_date, instance_name, batch_id) +TTL toDateTime(event_date) + INTERVAL 1 DAY; CREATE MATERIALIZED VIEW IF NOT EXISTS batch_stage_latencies_mv TO batch_stage_latencies @@ -134,7 +140,8 @@ CREATE TABLE IF NOT EXISTS suspicious_batch_stage_latencies ( ) ENGINE = AggregatingMergeTree PARTITION BY toYYYYMM(event_date) -ORDER BY (stage, event_date, instance_name, suspicious_batch_id); +ORDER BY (stage, event_date, instance_name, suspicious_batch_id) +TTL toDateTime(event_date) + INTERVAL 1 DAY; CREATE MATERIALIZED VIEW IF NOT EXISTS suspicious_batch_stage_latencies_mv TO suspicious_batch_stage_latencies diff --git a/docker/docker-compose/dev/docker-compose.pipeline.yml b/docker/docker-compose/dev/docker-compose.pipeline.yml index 6fecad92..eb67a7d2 100644 --- a/docker/docker-compose/dev/docker-compose.pipeline.yml +++ b/docker/docker-compose/dev/docker-compose.pipeline.yml @@ -8,6 +8,9 @@ services: - ../../../config.yaml:/app/config.yaml environment: - GROUP_ID=log_storage + deploy: + mode: "replicated" + replicas: 1 networks: hamstring: @@ -16,6 +19,9 @@ services: context: ../../.. dockerfile: docker/dockerfiles/Dockerfile.logcollector restart: "unless-stopped" + deploy: + mode: "replicated" + replicas: 1 volumes: - ../../../config.yaml:/app/config.yaml networks: diff --git a/docker/grafana-provisioning/dashboards/alert_inspector.json b/docker/grafana-provisioning/dashboards/alert_inspector.json new file mode 100644 index 00000000..71b04796 --- /dev/null +++ b/docker/grafana-provisioning/dashboards/alert_inspector.json @@ -0,0 +1,321 @@ +{ + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "editable": true, + "fiscalYearStartMonth": 0, + "graphTooltip": 0, + "links": [ + { + "asDropdown": true, + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [ + "HAMSTRING" + ], + "targetBlank": false, + "title": "Dashboards", + "tooltip": "Open another HAMSTRING dashboard", + "type": "dashboards", + "url": "" + } + ], + "panels": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "PDEE91DDB90597936" + }, + "fieldConfig": { + "defaults": { + "custom": { + "align": "auto", + "cellOptions": { + "type": "auto", + "wrapText": false + }, + "filterable": true, + "inspect": true + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [ + { + "matcher": { + "id": "byName", + "options": "Suspicious batch ID" + }, + "properties": [ + { + "id": "links", + "value": [ + { + "targetBlank": false, + "title": "Inspect alert", + "url": "/d/edxz3fduc2m0af/alert-inspector?from=${__from}&to=${__to}&var-IpFilter=&var-DomainFilter=&var-BatchFilter=${__value.raw}&var-DetectorFilter=" + } + ] + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Client IP address" + }, + "properties": [ + { + "id": "links", + "value": [ + { + "targetBlank": false, + "title": "Filter by IP", + "url": "/d/edxz3fduc2m0af/alert-inspector?from=${__from}&to=${__to}&var-IpFilter=${__value.raw}&var-DomainFilter=&var-BatchFilter=&var-DetectorFilter=" + } + ] + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Detectors" + }, + "properties": [ + { + "id": "links", + "value": [ + { + "targetBlank": false, + "title": "Filter by detector", + "url": "/d/edxz3fduc2m0af/alert-inspector?from=${__from}&to=${__to}&var-IpFilter=&var-DomainFilter=&var-BatchFilter=&var-DetectorFilter=${__value.raw}" + } + ] + } + ] + }, + { + "matcher": { + "id": "byName", + "options": "Domains" + }, + "properties": [ + { + "id": "links", + "value": [ + { + "targetBlank": false, + "title": "Filter by domain", + "url": "/d/edxz3fduc2m0af/alert-inspector?from=${__from}&to=${__to}&var-IpFilter=&var-DomainFilter=${__value.raw}&var-BatchFilter=&var-DetectorFilter=" + } + ] + } + ] + } + ] + }, + "gridPos": { + "h": 10, + "w": 24, + "x": 0, + "y": 0 + }, + "id": 1, + "options": { + "cellHeight": "sm", + "footer": { + "countRows": false, + "enablePagination": true, + "fields": "", + "reducer": [ + "sum" + ], + "show": false + }, + "showHeader": true + }, + "pluginVersion": "11.2.2+security-01", + "targets": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "PDEE91DDB90597936" + }, + "editorType": "sql", + "format": 1, + "pluginVersion": "4.6.0", + "queryType": "table", + "rawSql": "WITH parsed_alerts AS (\n SELECT\n alert_timestamp,\n src_ip,\n suspicious_batch_id,\n overall_score,\n arrayStringConcat(arrayDistinct(arrayFilter(x -> x != '', arrayMap(entry -> multiIf(\n substring(entry, 1, 1) = '\"', replaceRegexpAll(entry, '^\"|\"$', ''),\n JSONHas(entry, 'domain_name'), JSONExtractString(entry, 'domain_name'),\n entry\n ), arrayFlatten(arrayMap(item -> if(substring(item, 1, 1) = '[', JSONExtractArrayRaw(item), [item]), JSONExtractArrayRaw(domain_names)))))), ', ') AS domains,\n arrayStringConcat(arrayDistinct(arrayFilter(x -> x != '', arrayMap(warning -> JSONExtractString(warning, 'name'), JSONExtractArrayRaw(result)))), ', ') AS detector_names,\n result\n FROM alerts\n WHERE alert_timestamp >= $__fromTime\n AND alert_timestamp <= $__toTime\n)\nSELECT\n alert_timestamp AS \"Time\",\n src_ip AS \"Client IP address\",\n suspicious_batch_id AS \"Suspicious batch ID\",\n round(overall_score, 4) AS \"Score\",\n detector_names AS \"Detectors\",\n domains AS \"Domains\"\nFROM parsed_alerts\nWHERE ('${IpFilter}' = '' OR positionCaseInsensitive(src_ip, '${IpFilter}') > 0)\n AND ('${DomainFilter}' = '' OR positionCaseInsensitive(domains, '${DomainFilter}') > 0)\n AND ('${BatchFilter}' = '' OR positionCaseInsensitive(toString(suspicious_batch_id), '${BatchFilter}') > 0)\n AND ('${DetectorFilter}' = '' OR positionCaseInsensitive(detector_names, '${DetectorFilter}') > 0)\nORDER BY alert_timestamp DESC\nLIMIT 500;", + "refId": "A" + } + ], + "title": "Alerts", + "type": "table" + }, + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "PDEE91DDB90597936" + }, + "fieldConfig": { + "defaults": { + "custom": { + "align": "auto", + "cellOptions": { + "type": "auto", + "wrapText": true + }, + "filterable": true, + "inspect": true + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 10, + "w": 24, + "x": 0, + "y": 10 + }, + "id": 2, + "options": { + "cellHeight": "lg", + "footer": { + "countRows": false, + "enablePagination": true, + "fields": "", + "reducer": [ + "sum" + ], + "show": false + }, + "showHeader": true + }, + "pluginVersion": "11.2.2+security-01", + "targets": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "PDEE91DDB90597936" + }, + "editorType": "sql", + "format": 1, + "pluginVersion": "4.6.0", + "queryType": "table", + "rawSql": "WITH parsed_alerts AS (\n SELECT\n alert_timestamp,\n src_ip,\n suspicious_batch_id,\n overall_score,\n arrayStringConcat(arrayDistinct(arrayFilter(x -> x != '', arrayMap(entry -> multiIf(\n substring(entry, 1, 1) = '\"', replaceRegexpAll(entry, '^\"|\"$', ''),\n JSONHas(entry, 'domain_name'), JSONExtractString(entry, 'domain_name'),\n entry\n ), arrayFlatten(arrayMap(item -> if(substring(item, 1, 1) = '[', JSONExtractArrayRaw(item), [item]), JSONExtractArrayRaw(domain_names)))))), ', ') AS domains,\n result\n FROM alerts\n WHERE alert_timestamp >= $__fromTime\n AND alert_timestamp <= $__toTime\n), detector_outputs AS (\n SELECT\n alert_timestamp,\n src_ip,\n suspicious_batch_id,\n domains,\n warning,\n multiIf(\n JSONExtractString(warning, 'detector_name') != '', JSONExtractString(warning, 'detector_name'),\n JSONExtractString(warning, 'name') != '', JSONExtractString(warning, 'name'),\n ''\n ) AS detector_name,\n multiIf(\n JSONHas(warning, 'score'), JSONExtractFloat(warning, 'score'),\n JSONHas(warning, 'probability'), JSONExtractFloat(warning, 'probability'),\n 0\n ) AS score,\n JSONExtractString(warning, 'predicted_class') AS predicted_class,\n multiIf(\n JSONHas(warning, 'attributes'), JSONExtractRaw(warning, 'attributes'),\n JSONHas(warning, 'class_probabilities'), JSONExtractRaw(warning, 'class_probabilities'),\n ''\n ) AS detector_class_output,\n multiIf(\n JSONHas(warning, 'domains'), JSONExtractRaw(warning, 'domains'),\n JSONHas(warning, 'domain_names'), JSONExtractRaw(warning, 'domain_names'),\n ''\n ) AS detector_domains,\n multiIf(JSONHas(warning, 'logline_ids'), JSONExtractRaw(warning, 'logline_ids'), '') AS logline_ids,\n multiIf(JSONHas(warning, 'server_message_ids'), JSONExtractRaw(warning, 'server_message_ids'), '') AS server_message_ids,\n JSONExtractRaw(warning, 'request') AS request_raw,\n multiIf(JSONHas(warning, 'raw_detector_output'), JSONExtractRaw(warning, 'raw_detector_output'), warning) AS raw_detector_output\n FROM parsed_alerts\n ARRAY JOIN JSONExtractArrayRaw(result) AS warning\n)\nSELECT\n alert_timestamp AS \"Time\",\n src_ip AS \"Client IP address\",\n suspicious_batch_id AS \"Suspicious batch ID\",\n detector_name AS \"Detector\",\n round(score, 4) AS \"Score\",\n predicted_class AS \"Predicted class\",\n detector_class_output AS \"Class output\",\n detector_domains AS \"Detector domains\",\n logline_ids AS \"Logline IDs\",\n server_message_ids AS \"Server message IDs\",\n request_raw AS \"Raw request/window\",\n raw_detector_output AS \"Raw detector output\"\nFROM detector_outputs\nWHERE ('${IpFilter}' = '' OR positionCaseInsensitive(src_ip, '${IpFilter}') > 0)\n AND ('${DomainFilter}' = '' OR positionCaseInsensitive(concat(domains, ' ', detector_domains, ' ', request_raw), '${DomainFilter}') > 0)\n AND ('${BatchFilter}' = '' OR positionCaseInsensitive(toString(suspicious_batch_id), '${BatchFilter}') > 0)\n AND ('${DetectorFilter}' = '' OR positionCaseInsensitive(detector_name, '${DetectorFilter}') > 0)\nORDER BY alert_timestamp DESC, score DESC\nLIMIT 500;", + "refId": "A" + } + ], + "title": "Detector outputs", + "type": "table" + } + ], + "refresh": "auto", + "schemaVersion": 39, + "tags": [ + "HAMSTRING" + ], + "templating": { + "list": [ + { + "current": { + "selected": false, + "text": "", + "value": "" + }, + "hide": 0, + "label": "IP", + "name": "IpFilter", + "options": [], + "query": "", + "skipUrlSync": false, + "type": "textbox" + }, + { + "current": { + "selected": false, + "text": "", + "value": "" + }, + "hide": 0, + "label": "DNS name", + "name": "DomainFilter", + "options": [], + "query": "", + "skipUrlSync": false, + "type": "textbox" + }, + { + "current": { + "selected": false, + "text": "", + "value": "" + }, + "hide": 0, + "label": "Selected batch ID", + "name": "BatchFilter", + "options": [], + "query": "", + "skipUrlSync": false, + "type": "textbox" + }, + { + "current": { + "selected": false, + "text": "", + "value": "" + }, + "hide": 0, + "label": "Detector", + "name": "DetectorFilter", + "options": [], + "query": "", + "skipUrlSync": false, + "type": "textbox" + } + ] + }, + "time": { + "from": "now-24h", + "to": "now" + }, + "timepicker": {}, + "timezone": "browser", + "title": "Alert Inspector", + "uid": "edxz3fduc2m0af", + "version": 1, + "weekStart": "" +} diff --git a/docker/grafana-provisioning/dashboards/alerts.json b/docker/grafana-provisioning/dashboards/alerts.json index 7cdcd06c..436382c9 100644 --- a/docker/grafana-provisioning/dashboards/alerts.json +++ b/docker/grafana-provisioning/dashboards/alerts.json @@ -18,7 +18,22 @@ "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, - "links": [], + "links": [ + { + "asDropdown": true, + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [ + "HAMSTRING" + ], + "targetBlank": false, + "title": "Dashboards", + "tooltip": "Open another HAMSTRING dashboard", + "type": "dashboards", + "url": "" + } + ], "panels": [ { "datasource": { @@ -629,7 +644,7 @@ }, "pluginVersion": "4.6.0", "queryType": "table", - "rawSql": "WITH\n JSONExtractArrayRaw(domain_names) AS top_items,\n arrayFlatten(arrayMap(item -> if(substring(item, 1, 1) = '[', JSONExtractArrayRaw(item), [item]), top_items)) AS entries,\n arrayFilter(x -> x != '', arrayMap(entry -> JSONExtractString(entry, 'domain_name'), entries)) AS domains\nSELECT\n src_ip AS \"Client IP address\",\n arrayStringConcat(arrayDistinct(domains), ', ') AS \"Domains used\"\nFROM alerts\nORDER BY alert_timestamp DESC\nLIMIT 20;", + "rawSql": "WITH\n JSONExtractArrayRaw(domain_names) AS top_items,\n arrayFlatten(arrayMap(item -> if(substring(item, 1, 1) = '[', JSONExtractArrayRaw(item), [item]), top_items)) AS entries,\n arrayFilter(x -> x != '', arrayMap(entry -> multiIf(\n substring(entry, 1, 1) = '\"', replaceRegexpAll(entry, '^\"|\"$', ''),\n JSONHas(entry, 'domain_name'), JSONExtractString(entry, 'domain_name'),\n entry\n ), entries)) AS domains\nSELECT\n src_ip AS \"Client IP address\",\n arrayStringConcat(arrayDistinct(domains), ', ') AS \"Domains used\"\nFROM alerts\nORDER BY alert_timestamp DESC\nLIMIT 20;", "refId": "A" } ], @@ -725,7 +740,9 @@ ], "refresh": "auto", "schemaVersion": 39, - "tags": [], + "tags": [ + "HAMSTRING" + ], "templating": { "list": [ { diff --git a/docker/grafana-provisioning/dashboards/kafka_exporter.json b/docker/grafana-provisioning/dashboards/kafka_exporter.json new file mode 100644 index 00000000..cd8da095 --- /dev/null +++ b/docker/grafana-provisioning/dashboards/kafka_exporter.json @@ -0,0 +1,615 @@ +{ + "__requires": [ + { + "type": "grafana", + "id": "grafana", + "name": "Grafana", + "version": "5.1.1" + }, + { + "type": "panel", + "id": "graph", + "name": "Graph", + "version": "5.0.0" + }, + { + "type": "datasource", + "id": "prometheus", + "name": "Prometheus", + "version": "5.0.0" + } + ], + "annotations": { + "list": [ + { + "builtIn": 1, + "datasource": { + "type": "grafana", + "uid": "-- Grafana --" + }, + "enable": true, + "hide": true, + "iconColor": "rgba(0, 211, 255, 1)", + "name": "Annotations & Alerts", + "type": "dashboard" + } + ] + }, + "description": "Kafka resource usage and throughput", + "editable": true, + "gnetId": 7589, + "graphTooltip": 0, + "id": null, + "iteration": 1534756791145, + "links": [ + { + "asDropdown": true, + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [ + "HAMSTRING" + ], + "targetBlank": false, + "title": "Dashboards", + "tooltip": "Open another HAMSTRING dashboard", + "type": "dashboards", + "url": "" + } + ], + "panels": [ + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "prometheus", + "fill": 0, + "gridPos": { + "h": 10, + "w": 10, + "x": 0, + "y": 0 + }, + "id": 14, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "max": true, + "min": false, + "rightSide": false, + "show": true, + "sideWidth": 480, + "sort": "max", + "sortDesc": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "connected", + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(rate(kafka_topic_partition_current_offset{job=~\"$job\",instance=~\"$instance\",topic=~\"$topic\"}[1m])) by (topic)", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "{{topic}}", + "refId": "B" + } + ], + "thresholds": [], + "timeFrom": null, + "timeShift": null, + "title": "Message in per second", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "prometheus", + "fill": 0, + "gridPos": { + "h": 10, + "w": 10, + "x": 10, + "y": 0 + }, + "id": 12, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "max": true, + "min": false, + "rightSide": false, + "show": true, + "sideWidth": 480, + "sortDesc": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "connected", + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(kafka_consumergroup_lag{job=~\"$job\",instance=~\"$instance\",topic=~\"$topic\"}) by (consumergroup, topic)", + "format": "time_series", + "instant": false, + "interval": "", + "intervalFactor": 1, + "legendFormat": "{{consumergroup}} (topic: {{topic}})", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeShift": null, + "title": "Lag by Consumer Group", + "tooltip": { + "shared": true, + "sort": 2, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": "", + "logBase": 1, + "max": null, + "min": "0", + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "prometheus", + "fill": 0, + "gridPos": { + "h": 10, + "w": 10, + "x": 0, + "y": 10 + }, + "id": 16, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "max": true, + "min": false, + "rightSide": false, + "show": true, + "sideWidth": 480, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "connected", + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(delta(kafka_topic_partition_current_offset{job=~\"$job\",instance=~\"$instance\",topic=~\"$topic\"}[5m])/5) by (topic)", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "{{topic}}", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeShift": null, + "title": "Message in per minute", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "transparent": false, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": false, + "dashLength": 10, + "dashes": false, + "datasource": "prometheus", + "fill": 0, + "gridPos": { + "h": 10, + "w": 10, + "x": 10, + "y": 10 + }, + "id": 18, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "max": true, + "min": false, + "rightSide": false, + "show": true, + "sideWidth": 480, + "sort": "current", + "sortDesc": true, + "total": false, + "values": true + }, + "lines": true, + "linewidth": 1, + "links": [], + "nullPointMode": "connected", + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum(delta(kafka_consumergroup_current_offset{job=~\"$job\",instance=~\"$instance\",topic=~\"$topic\"}[5m])/5) by (consumergroup, topic)", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "{{consumergroup}} (topic: {{topic}})", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeShift": null, + "title": "Message consume per minute", + "tooltip": { + "shared": true, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "time", + "name": null, + "show": true, + "values": [] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + }, + { + "aliasColors": {}, + "bars": true, + "dashLength": 10, + "dashes": false, + "datasource": "prometheus", + "fill": 1, + "gridPos": { + "h": 7, + "w": 20, + "x": 0, + "y": 20 + }, + "id": 8, + "legend": { + "alignAsTable": true, + "avg": false, + "current": true, + "max": false, + "min": false, + "rightSide": true, + "show": true, + "sideWidth": 420, + "total": false, + "values": true + }, + "lines": false, + "linewidth": 1, + "links": [], + "nullPointMode": "null", + "percentage": false, + "pointradius": 5, + "points": false, + "renderer": "flot", + "seriesOverrides": [], + "spaceLength": 10, + "stack": false, + "steppedLine": false, + "targets": [ + { + "expr": "sum by(topic) (kafka_topic_partitions{job=~\"$job\",instance=~\"$instance\",topic=~\"$topic\"})", + "format": "time_series", + "intervalFactor": 1, + "legendFormat": "{{topic}}", + "refId": "A" + } + ], + "thresholds": [], + "timeFrom": null, + "timeShift": null, + "title": "Partitions per Topic", + "tooltip": { + "shared": false, + "sort": 0, + "value_type": "individual" + }, + "type": "graph", + "xaxis": { + "buckets": null, + "mode": "series", + "name": null, + "show": false, + "values": [ + "current" + ] + }, + "yaxes": [ + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + }, + { + "format": "short", + "label": null, + "logBase": 1, + "max": null, + "min": null, + "show": true + } + ], + "yaxis": { + "align": false, + "alignLevel": null + } + } + ], + "refresh": "10s", + "schemaVersion": 16, + "style": "dark", + "tags": [ + "HAMSTRING", + "Kafka" + ], + "templating": { + "list": [ + { + "allValue": null, + "current": { + "selected": true, + "text": "kafka-exporter", + "value": "kafka-exporter" + }, + "datasource": "prometheus", + "hide": 0, + "includeAll": false, + "label": "Job", + "multi": false, + "name": "job", + "options": [], + "query": "label_values(kafka_topic_partition_current_offset, job)", + "refresh": 1, + "regex": "", + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": null, + "current": {}, + "datasource": "prometheus", + "hide": 0, + "includeAll": false, + "label": "Instance", + "multi": false, + "name": "instance", + "options": [], + "query": "label_values(kafka_topic_partition_current_offset{job=~\"$job\"}, instance)", + "refresh": 1, + "regex": "", + "sort": 0, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "", + "type": "query", + "useTags": false + }, + { + "allValue": ".*", + "current": {}, + "datasource": "prometheus", + "hide": 0, + "includeAll": true, + "label": "Topic", + "multi": true, + "name": "topic", + "options": [], + "query": "label_values(kafka_topic_partition_current_offset{job=~\"$job\",instance=~\"$instance\",topic!=\"__consumer_offsets\"}, topic)", + "refresh": 1, + "regex": "", + "sort": 1, + "tagValuesQuery": "", + "tags": [], + "tagsQuery": "topic", + "type": "query", + "useTags": false + } + ] + }, + "time": { + "from": "now-24h", + "to": "now" + }, + "timepicker": { + "refresh_intervals": [ + "5s", + "10s", + "30s", + "1m", + "5m", + "15m", + "30m", + "1h", + "2h", + "1d" + ], + "time_options": [ + "5m", + "15m", + "1h", + "6h", + "12h", + "24h", + "2d", + "7d", + "30d" + ] + }, + "timezone": "browser", + "title": "HAMSTRING Kafka Exporter", + "uid": "hamstring-kafka-exporter", + "version": 50 +} diff --git a/docker/grafana-provisioning/dashboards/latencies.json b/docker/grafana-provisioning/dashboards/latencies.json index 742c466b..3e686ff8 100644 --- a/docker/grafana-provisioning/dashboards/latencies.json +++ b/docker/grafana-provisioning/dashboards/latencies.json @@ -19,7 +19,22 @@ "fiscalYearStartMonth": 0, "graphTooltip": 0, "id": 2, - "links": [], + "links": [ + { + "asDropdown": true, + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [ + "HAMSTRING" + ], + "targetBlank": false, + "title": "Dashboards", + "tooltip": "Open another HAMSTRING dashboard", + "type": "dashboards", + "url": "" + } + ], "liveNow": false, "panels": [ { @@ -643,13 +658,192 @@ ], "type": "bargauge" }, + { + "datasource": { + "default": false, + "type": "grafana-clickhouse-datasource", + "uid": "PDEE91DDB90597936" + }, + "description": "Average time for alert-producing requests. Uses exact LogServer intake to alerter processed timestamps when mapping data exists; otherwise falls back to historical collector-to-alert latency plus average LogServer latency.", + "fieldConfig": { + "defaults": { + "color": { + "mode": "thresholds" + }, + "mappings": [], + "min": 0, + "noValue": "-", + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 1000000 + }, + { + "color": "red", + "value": 5000000 + } + ] + }, + "unit": "\u00b5s" + }, + "overrides": [] + }, + "gridPos": { + "h": 3, + "w": 19, + "x": 0, + "y": 14 + }, + "id": 81, + "options": { + "colorMode": "value", + "graphMode": "none", + "justifyMode": "auto", + "orientation": "auto", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showPercentChange": false, + "textMode": "auto", + "wideLayout": true, + "percentChangeColorMode": "inverted" + }, + "pluginVersion": "11.2.2+security-01", + "targets": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "PDEE91DDB90597936" + }, + "editorType": "sql", + "format": 1, + "meta": { + "builderOptions": { + "columns": [], + "database": "", + "limit": 1000, + "mode": "list", + "queryType": "table", + "table": "" + } + }, + "pluginVersion": "4.10.1", + "queryType": "table", + "rawSql": "SELECT if(isFinite(mapped.average_latency_us), mapped.average_latency_us, if(isFinite(fallback.average_latency_us), fallback.average_latency_us, 0)) AS \"Average alerting request roundtrip latency\"\nFROM (\n SELECT avg(latency_us) AS average_latency_us\n FROM (\n SELECT\n sl.message_id AS message_id,\n sltl.logline_id AS logline_id,\n a.suspicious_batch_id AS suspicious_batch_id,\n dateDiff('microsecond', min(sl.timestamp_in), max(ste.timestamp)) AS latency_us\n FROM alerts a\n INNER JOIN suspicious_batches_to_batch sbtb\n ON a.suspicious_batch_id = sbtb.suspicious_batch_id\n INNER JOIN logline_to_batches ltb\n ON sbtb.batch_id = ltb.batch_id\n INNER JOIN loglines ll\n ON ltb.logline_id = ll.logline_id\n AND ll.src_ip = a.src_ip\n INNER JOIN server_log_to_logline sltl\n ON ltb.logline_id = sltl.logline_id\n INNER JOIN server_logs sl\n ON sltl.message_id = sl.message_id\n INNER JOIN server_log_terminal_events ste\n ON sl.message_id = ste.message_id\n AND ste.stage = 'pipeline.alerter'\n AND ste.status = 'processed'\n WHERE toDate(ste.timestamp) >= toDate($__fromTime)\n AND toDate(ste.timestamp) <= toDate($__toTime)\n AND ste.timestamp >= $__fromTime\n AND ste.timestamp <= $__toTime\n AND sl.timestamp_in <= ste.timestamp\n GROUP BY sl.message_id, sltl.logline_id, a.suspicious_batch_id\n )\n WHERE latency_us > 0\n) mapped\nCROSS JOIN (\n SELECT avg(latency_us) AS average_latency_us\n FROM (\n SELECT\n ltb.logline_id AS logline_id,\n a.suspicious_batch_id AS suspicious_batch_id,\n dateDiff('microsecond', min(lt.timestamp), min(a.alert_timestamp)) + any(server_latency_us) AS latency_us\n FROM alerts a\n INNER JOIN suspicious_batches_to_batch sbtb\n ON a.suspicious_batch_id = sbtb.suspicious_batch_id\n INNER JOIN logline_to_batches ltb\n ON sbtb.batch_id = ltb.batch_id\n INNER JOIN loglines ll\n ON ltb.logline_id = ll.logline_id\n AND ll.src_ip = a.src_ip\n INNER JOIN logline_timestamps lt\n ON ltb.logline_id = lt.logline_id\n CROSS JOIN (\n SELECT if(isFinite(avg(latency_us)), avg(latency_us), 0) AS server_latency_us\n FROM server_log_latency_values\n WHERE event_date >= toDate($__fromTime)\n AND event_date <= toDate($__toTime)\n AND end_timestamp >= $__fromTime\n AND end_timestamp <= $__toTime\n ) server_latency\n WHERE lt.stage = 'log_collection.collector'\n AND lt.status = 'in_process'\n AND toDate(a.alert_timestamp) >= toDate($__fromTime)\n AND toDate(a.alert_timestamp) <= toDate($__toTime)\n AND a.alert_timestamp >= $__fromTime\n AND a.alert_timestamp <= $__toTime\n AND lt.timestamp <= a.alert_timestamp\n GROUP BY ltb.logline_id, a.suspicious_batch_id\n )\n WHERE latency_us > 0\n) fallback;\n", + "refId": "alerting_request_roundtrip_latency" + } + ], + "title": "Average alerting request roundtrip latency", + "type": "stat" + }, + { + "datasource": { + "default": false, + "type": "grafana-clickhouse-datasource", + "uid": "PDEE91DDB90597936" + }, + "description": "Average LogServer-to-terminal-stage latency for all completed logevents. Uses exact mapped terminal events when available; otherwise falls back to historical collector-terminal latency plus average LogServer latency.", + "fieldConfig": { + "defaults": { + "color": { + "fixedColor": "text", + "mode": "continuous-GrYlRd" + }, + "fieldMinMax": false, + "mappings": [], + "min": 0, + "noValue": "-", + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "yellow", + "value": 1000000 + }, + { + "color": "red", + "value": 5000000 + } + ] + }, + "unit": "\u00b5s" + }, + "overrides": [] + }, + "gridPos": { + "h": 3, + "w": 19, + "x": 0, + "y": 17 + }, + "id": 82, + "options": { + "displayMode": "gradient", + "maxVizHeight": 300, + "minVizHeight": 16, + "minVizWidth": 8, + "namePlacement": "top", + "orientation": "horizontal", + "reduceOptions": { + "calcs": [ + "lastNotNull" + ], + "fields": "", + "values": false + }, + "showUnfilled": true, + "sizing": "auto", + "valueMode": "text" + }, + "pluginVersion": "11.2.2+security-01", + "targets": [ + { + "datasource": { + "type": "grafana-clickhouse-datasource", + "uid": "PDEE91DDB90597936" + }, + "editorType": "sql", + "format": 1, + "meta": { + "builderOptions": { + "columns": [], + "database": "", + "limit": 1000, + "mode": "list", + "queryType": "table", + "table": "" + } + }, + "pluginVersion": "4.10.1", + "queryType": "table", + "rawSql": "SELECT if(isFinite(mapped.average_latency_us), mapped.average_latency_us, if(isFinite(fallback.average_latency_us), fallback.average_latency_us, 0)) AS \"Average all logevents roundtrip latency\"\nFROM (\n SELECT avg(latency_us) AS average_latency_us\n FROM (\n SELECT\n message_id,\n dateDiff('microsecond', start_timestamp, terminal_timestamp) AS latency_us\n FROM (\n SELECT\n sl.message_id AS message_id,\n min(sl.timestamp_in) AS start_timestamp,\n max(terminal.terminal_timestamp) AS terminal_timestamp\n FROM server_logs sl\n INNER JOIN (\n SELECT\n sltl.message_id AS message_id,\n max(lt.timestamp) AS terminal_timestamp\n FROM server_log_to_logline sltl\n INNER JOIN logline_timestamps lt\n ON sltl.logline_id = lt.logline_id\n WHERE lt.is_active = false\n GROUP BY sltl.message_id\n\n UNION ALL\n\n SELECT\n message_id,\n max(timestamp) AS terminal_timestamp\n FROM server_log_terminal_events\n GROUP BY message_id\n ) terminal\n ON sl.message_id = terminal.message_id\n GROUP BY sl.message_id\n )\n WHERE terminal_timestamp >= start_timestamp\n AND toDate(terminal_timestamp) >= toDate($__fromTime)\n AND toDate(terminal_timestamp) <= toDate($__toTime)\n AND terminal_timestamp >= $__fromTime\n AND terminal_timestamp <= $__toTime\n )\n WHERE latency_us > 0\n) mapped\nCROSS JOIN (\n SELECT avg(latency_us) AS average_latency_us\n FROM (\n SELECT\n dateDiff('microsecond', collector_started_at, terminal_timestamp) + server_latency_us AS latency_us\n FROM (\n SELECT\n lt.logline_id AS logline_id,\n minIf(lt.timestamp, lt.stage = 'log_collection.collector' AND lt.status = 'in_process') AS collector_started_at,\n maxIf(lt.timestamp, lt.is_active = false) AS terminal_timestamp,\n any(server_latency_us) AS server_latency_us\n FROM logline_timestamps lt\n CROSS JOIN (\n SELECT if(isFinite(avg(latency_us)), avg(latency_us), 0) AS server_latency_us\n FROM server_log_latency_values\n WHERE event_date >= toDate($__fromTime)\n AND event_date <= toDate($__toTime)\n AND end_timestamp >= $__fromTime\n AND end_timestamp <= $__toTime\n ) server_latency\n GROUP BY lt.logline_id\n )\n WHERE collector_started_at > toDateTime64(0, 6)\n AND terminal_timestamp >= collector_started_at\n AND toDate(terminal_timestamp) >= toDate($__fromTime)\n AND toDate(terminal_timestamp) <= toDate($__toTime)\n AND terminal_timestamp >= $__fromTime\n AND terminal_timestamp <= $__toTime\n\n UNION ALL\n\n SELECT\n dateDiff('microsecond', timestamp_in, timestamp_failed) + server_latency_us AS latency_us\n FROM failed_loglines\n CROSS JOIN (\n SELECT if(isFinite(avg(latency_us)), avg(latency_us), 0) AS server_latency_us\n FROM server_log_latency_values\n WHERE event_date >= toDate($__fromTime)\n AND event_date <= toDate($__toTime)\n AND end_timestamp >= $__fromTime\n AND end_timestamp <= $__toTime\n ) server_latency\n WHERE timestamp_failed >= timestamp_in\n AND toDate(timestamp_failed) >= toDate($__fromTime)\n AND toDate(timestamp_failed) <= toDate($__toTime)\n AND timestamp_failed >= $__fromTime\n AND timestamp_failed <= $__toTime\n )\n WHERE latency_us > 0\n) fallback;\n", + "refId": "all_logevents_roundtrip_latency" + } + ], + "title": "Average all logevents roundtrip latency", + "type": "bargauge" + }, { "collapsed": false, "gridPos": { "h": 1, "w": 24, "x": 0, - "y": 14 + "y": 20 }, "id": 33, "panels": [], @@ -775,7 +969,7 @@ "h": 6, "w": 6, "x": 0, - "y": 15 + "y": 21 }, "id": 22, "options": { @@ -875,7 +1069,7 @@ "h": 3, "w": 5, "x": 6, - "y": 15 + "y": 21 }, "id": 29, "options": { @@ -1042,7 +1236,7 @@ "h": 6, "w": 6, "x": 12, - "y": 15 + "y": 21 }, "id": 25, "options": { @@ -1142,7 +1336,7 @@ "h": 3, "w": 5, "x": 18, - "y": 15 + "y": 21 }, "id": 35, "options": { @@ -1222,7 +1416,7 @@ "h": 3, "w": 5, "x": 6, - "y": 18 + "y": 24 }, "id": 28, "options": { @@ -1301,7 +1495,7 @@ "h": 3, "w": 5, "x": 18, - "y": 18 + "y": 24 }, "id": 36, "options": { @@ -1468,7 +1662,7 @@ "h": 6, "w": 6, "x": 0, - "y": 21 + "y": 27 }, "id": 23, "options": { @@ -1568,7 +1762,7 @@ "h": 3, "w": 5, "x": 6, - "y": 21 + "y": 27 }, "id": 30, "options": { @@ -1735,7 +1929,7 @@ "h": 6, "w": 6, "x": 12, - "y": 21 + "y": 27 }, "id": 26, "options": { @@ -1835,7 +2029,7 @@ "h": 3, "w": 5, "x": 18, - "y": 21 + "y": 27 }, "id": 37, "options": { @@ -1915,7 +2109,7 @@ "h": 3, "w": 5, "x": 6, - "y": 24 + "y": 30 }, "id": 31, "options": { @@ -1994,7 +2188,7 @@ "h": 3, "w": 5, "x": 18, - "y": 24 + "y": 30 }, "id": 38, "options": { @@ -2160,7 +2354,7 @@ "h": 6, "w": 6, "x": 0, - "y": 27 + "y": 33 }, "id": 24, "options": { @@ -2260,7 +2454,7 @@ "h": 3, "w": 5, "x": 6, - "y": 27 + "y": 33 }, "id": 32, "options": { @@ -2427,7 +2621,7 @@ "h": 6, "w": 6, "x": 12, - "y": 27 + "y": 33 }, "id": 27, "options": { @@ -2527,7 +2721,7 @@ "h": 3, "w": 5, "x": 18, - "y": 27 + "y": 33 }, "id": 39, "options": { @@ -2607,7 +2801,7 @@ "h": 3, "w": 5, "x": 6, - "y": 30 + "y": 36 }, "id": 34, "options": { @@ -2686,7 +2880,7 @@ "h": 3, "w": 5, "x": 18, - "y": 30 + "y": 36 }, "id": 40, "options": { @@ -2740,7 +2934,7 @@ "h": 1, "w": 24, "x": 0, - "y": 33 + "y": 39 }, "id": 71, "panels": [], @@ -2865,7 +3059,7 @@ "h": 6, "w": 8, "x": 0, - "y": 34 + "y": 40 }, "id": 74, "options": { @@ -3052,7 +3246,7 @@ "h": 6, "w": 8, "x": 8, - "y": 34 + "y": 40 }, "id": 75, "options": { @@ -3237,7 +3431,7 @@ "h": 6, "w": 8, "x": 16, - "y": 34 + "y": 40 }, "id": 76, "options": { @@ -3350,7 +3544,7 @@ "h": 3, "w": 8, "x": 0, - "y": 40 + "y": 46 }, "id": 57, "options": { @@ -3443,7 +3637,7 @@ "h": 3, "w": 8, "x": 8, - "y": 40 + "y": 46 }, "id": 48, "options": { @@ -3536,7 +3730,7 @@ "h": 3, "w": 8, "x": 16, - "y": 40 + "y": 46 }, "id": 49, "options": { @@ -3586,7 +3780,9 @@ ], "refresh": "5s", "schemaVersion": 39, - "tags": [], + "tags": [ + "HAMSTRING" + ], "templating": { "list": [ { diff --git a/docker/grafana-provisioning/dashboards/log_volumes.json b/docker/grafana-provisioning/dashboards/log_volumes.json index 4b05ebcf..d73dcf3e 100644 --- a/docker/grafana-provisioning/dashboards/log_volumes.json +++ b/docker/grafana-provisioning/dashboards/log_volumes.json @@ -18,7 +18,22 @@ "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, - "links": [], + "links": [ + { + "asDropdown": true, + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [ + "HAMSTRING" + ], + "targetBlank": false, + "title": "Dashboards", + "tooltip": "Open another HAMSTRING dashboard", + "type": "dashboards", + "url": "" + } + ], "liveNow": false, "panels": [ { @@ -2149,7 +2164,9 @@ ], "refresh": "auto", "schemaVersion": 39, - "tags": [], + "tags": [ + "HAMSTRING" + ], "templating": { "list": [ { diff --git a/docker/grafana-provisioning/dashboards/overview.json b/docker/grafana-provisioning/dashboards/overview.json index d3a809b5..af78d5c6 100644 --- a/docker/grafana-provisioning/dashboards/overview.json +++ b/docker/grafana-provisioning/dashboards/overview.json @@ -18,7 +18,22 @@ "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 0, - "links": [], + "links": [ + { + "asDropdown": true, + "icon": "external link", + "includeVars": true, + "keepTime": true, + "tags": [ + "HAMSTRING" + ], + "targetBlank": false, + "title": "Dashboards", + "tooltip": "Open another HAMSTRING dashboard", + "type": "dashboards", + "url": "" + } + ], "panels": [ { "datasource": { @@ -922,7 +937,9 @@ ], "refresh": "auto", "schemaVersion": 39, - "tags": [], + "tags": [ + "HAMSTRING" + ], "templating": { "list": [ { diff --git a/docker/grafana-provisioning/datasources.yaml b/docker/grafana-provisioning/datasources.yaml index fbfcdc08..57c157fa 100644 --- a/docker/grafana-provisioning/datasources.yaml +++ b/docker/grafana-provisioning/datasources.yaml @@ -15,7 +15,7 @@ datasources: - name: prometheus type: prometheus access: proxy - url: http://prometheus:9088 + url: http://prometheus:9090 isDefault: true jsonData: httpMethod: POST diff --git a/docs/api/base.rst b/docs/api/base.rst new file mode 100644 index 00000000..f494dcfe --- /dev/null +++ b/docs/api/base.rst @@ -0,0 +1,5 @@ +Base +==== + +.. automodule:: src.base + diff --git a/docs/api/detector.rst b/docs/api/detector.rst new file mode 100644 index 00000000..03c31da2 --- /dev/null +++ b/docs/api/detector.rst @@ -0,0 +1,5 @@ +Detector +======== + +.. automodule:: src.detector + diff --git a/docs/api/inspector.rst b/docs/api/inspector.rst new file mode 100644 index 00000000..7ca22f1e --- /dev/null +++ b/docs/api/inspector.rst @@ -0,0 +1,5 @@ +Inspector +========= + +.. automodule:: src.inspector + diff --git a/docs/api/logcollector.rst b/docs/api/logcollector.rst new file mode 100644 index 00000000..f7c3ee09 --- /dev/null +++ b/docs/api/logcollector.rst @@ -0,0 +1,5 @@ +Log Collector +============= + +.. automodule:: src.logcollector + diff --git a/docs/api/logserver.rst b/docs/api/logserver.rst new file mode 100644 index 00000000..a8394239 --- /dev/null +++ b/docs/api/logserver.rst @@ -0,0 +1,5 @@ +Log Server +========== + +.. automodule:: src.logserver + diff --git a/docs/api/monitoring.rst b/docs/api/monitoring.rst new file mode 100644 index 00000000..652c988c --- /dev/null +++ b/docs/api/monitoring.rst @@ -0,0 +1,5 @@ +Monitoring +========== + +.. automodule:: src.monitoring + diff --git a/docs/api/prefilter.rst b/docs/api/prefilter.rst new file mode 100644 index 00000000..f806c639 --- /dev/null +++ b/docs/api/prefilter.rst @@ -0,0 +1,5 @@ +Prefilter +========= + +.. automodule:: src.prefilter + diff --git a/docs/api/train.rst b/docs/api/train.rst new file mode 100644 index 00000000..836ba206 --- /dev/null +++ b/docs/api/train.rst @@ -0,0 +1,5 @@ +Train +===== + +.. automodule:: src.train + diff --git a/docs/api/version.rst b/docs/api/version.rst new file mode 100644 index 00000000..e3c1f69f --- /dev/null +++ b/docs/api/version.rst @@ -0,0 +1,5 @@ +Version +======= + +.. automodule:: src.version + diff --git a/docs/conf.py b/docs/conf.py index 99c11c6f..d2ac102e 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -59,13 +59,13 @@ napoleon_use_rtype = False intersphinx_mapping = { - "python": ("https://docs.python.org/3/", None), - "sphinx": ("https://www.sphinx-doc.org/en/master/", None), + "python": ("https://docs.python.org/3", None), + "sphinx": ("https://www.sphinx-doc.org/en/master", None), } intersphinx_disabled_domains = ["std"] templates_path = ["_templates"] -exclude_patterns = ['_build, "Thumbs.db', ".DS_Store"] +exclude_patterns = ["_build", "output", "Thumbs.db", ".DS_Store"] # -- Options for HTML output html_theme = "sphinx_book_theme" @@ -73,7 +73,7 @@ "use_repository_button": True, "repository_url": "https://github.com/stefanDeveloper/HAMSTRING", } -html_logo = "../assets/hamstring_logo_readthedocs.png" +html_logo = "../assets/hamstring.svg" # -- Options for EPUB output epub_show_urls = "footnote" diff --git a/docs/configuration.rst b/docs/configuration.rst index 8832f516..6d7af4d6 100644 --- a/docs/configuration.rst +++ b/docs/configuration.rst @@ -119,9 +119,21 @@ functionality of the modules. ``pipeline.scaling`` ^^^^^^^^^^^^^^^^^^^^ -Controls the executor used by each pipeline module when it runs blocking work from the asyncio event loop. -Values from ``defaults`` apply to every module and can be overridden per module under ``modules``. Modules -that run several configured instances can also override an individual instance under ``instances``. +Controls how many independent workers each pipeline module starts. Each worker owns its own Kafka +consumer and producer, so workers consuming the same topic join the same Kafka consumer group and can +process different partitions in parallel. Values from ``defaults`` apply to every module and can be +overridden per module under ``modules``. Modules that run several configured instances can also override +an individual instance under ``instances``. + +Scaling is resolved in this order: + +#. ``pipeline.scaling.defaults`` +#. ``pipeline.scaling.modules.`` +#. ``pipeline.scaling.modules..instances.`` + +The instance names are the configured pipeline object names, not Docker service names. For example, +``log_collection.collector.instances.dga_collector`` applies only to the collector whose +``pipeline.log_collection.collectors[].name`` is ``dga_collector``. .. code-block:: yaml @@ -140,10 +152,169 @@ that run several configured instances can also override an individual instance u data_analysis.detector: executor: process processes: 2 + pipeline.alerter: + executor: hybrid + processes: 2 + threads_per_process: 4 + +.. list-table:: Scaling options + :header-rows: 1 + :widths: 25 20 55 + + * - Parameter + - Default + - Description + * - ``executor`` + - ``thread`` + - Worker model. Valid values are ``thread``, ``process``, and ``hybrid``. + * - ``threads`` + - ``1`` + - Number of thread workers for ``executor: thread``. In ``executor: hybrid``, this is accepted as an alias for ``threads_per_process``. + * - ``threads_per_process`` + - ``1`` + - Number of thread workers inside each process for ``executor: hybrid``. + * - ``processes`` + - ``1`` + - Number of worker processes for ``executor: process`` or ``executor: hybrid``. + * - ``max_workers`` + - ``1`` + - Backwards-compatible worker-count alias. For ``thread`` it maps to ``threads``; for pure ``process`` it maps to ``processes``. + * - ``workers`` + - ``1`` + - Alias for ``max_workers``. + * - ``instances`` + - none + - Per-configured-instance overrides. The nested keys must match the instance names listed below. + +``thread`` mode starts ``threads`` independent workers in the service process. ``process`` mode starts +``processes`` worker processes with one worker each. ``hybrid`` mode starts ``processes`` processes with +``threads_per_process`` worker threads inside each process. + +If ``executor`` is omitted, HAMSTRING infers it from the worker-count keys: + +* ``threads`` only: ``thread`` +* ``processes`` only: ``process`` +* ``processes`` and ``threads`` or ``threads_per_process``: ``hybrid`` + +For example, this starts two processes with four Kafka-consuming workers in each process: + +.. code-block:: yaml + + pipeline: + scaling: + modules: + data_analysis.detector: + executor: hybrid + processes: 2 + threads_per_process: 4 -``executor`` may be ``thread`` or ``process``. Worker counts can be configured with ``max_workers`` or -``workers``. The aliases ``threads`` and ``processes`` also set the worker count and infer the executor -type when ``executor`` is omitted. +This is equivalent, because ``threads`` is an alias for ``threads_per_process`` in hybrid mode: + +.. code-block:: yaml + + pipeline: + scaling: + modules: + data_analysis.detector: + processes: 2 + threads: 4 + +Per-instance overrides are useful when one configured stage is more expensive than another. This example +uses hybrid mode for all log collectors, but gives the ``dga_collector`` fewer workers and the +``domainator_collector`` pure process workers: + +.. code-block:: yaml + + pipeline: + scaling: + modules: + log_collection.collector: + executor: hybrid + processes: 2 + threads_per_process: 4 + instances: + dga_collector: + processes: 1 + threads_per_process: 2 + domainator_collector: + executor: process + processes: 3 + +The effective number of Kafka consumers for one configured pipeline instance is: + +.. code-block:: text + + Docker service replicas * processes * threads_per_process + +For ``thread`` mode, ``processes`` is ``1``. For pure ``process`` mode, ``threads_per_process`` is ``1``. +The consumed Kafka topic needs at least that many partitions to keep every worker busy. HAMSTRING requests +at least the local worker count when creating or expanding topics; set ``NUMBER_OF_INSTANCES`` on the +service when Docker Compose replicas are used so topic creation can account for the replica count as well. + +.. list-table:: Module and instance keys + :header-rows: 1 + :widths: 30 30 40 + + * - Module key + - Instance keys + - Example + * - ``log_storage.logserver`` + - Full consumed input topic name. Without an instance override, the module setting applies to every logserver protocol topic. + - ``pipeline-logserver_in-dns`` when the ``logserver_in`` topic prefix is ``pipeline-logserver_in`` and the protocol is ``dns``. + * - ``log_collection.collector`` + - ``pipeline.log_collection.collectors[].name`` + - ``dga_collector``, ``domainator_collector`` + * - ``log_filtering.prefilter`` + - ``pipeline.log_filtering[].name`` + - ``dga_filter``, ``domainator_filter`` + * - ``data_inspection.inspector`` + - ``pipeline.data_inspection[].name`` + - ``dga_inspector``, ``domainator_inspector`` + * - ``data_analysis.detector`` + - ``pipeline.data_analysis[].name`` + - ``RF-dga_detector``, ``domainator`` + * - ``pipeline.alerter`` + - ``generic`` and ``pipeline.alerting.plugins[].name`` + - ``generic``, ``attributor`` + * - ``monitoring.agent`` + - No per-instance key by default. + - Configure the module key directly. + +Docker Compose service replicas are configured separately from ``pipeline.scaling``. Compose replicas add +more containers; ``pipeline.scaling`` adds more workers inside each container. Both forms of scaling use the +same Kafka consumer group for the same stage/topic. + +For local Docker Compose runs, scale services with ``docker compose up --scale``: + +.. code-block:: console + + $ HOST_IP=127.0.0.1 docker compose -f docker/docker-compose.yml --profile prod up --scale logcollector=3 --scale detector=2 + +For the development profile, use the ``-dev`` service names from ``docker/docker-compose.yml``: + +.. code-block:: console + + $ HOST_IP=127.0.0.1 docker compose -f docker/docker-compose.yml --profile dev up --scale logcollector-dev=3 --scale detector-dev=2 + +The compose fragments under ``docker/docker-compose/dev`` and ``docker/docker-compose/prod`` also contain +``deploy.replicas`` fields. Those fields document the intended replica count and are used by orchestrators +that honor Compose ``deploy`` settings. For portable local Compose usage, prefer the explicit ``--scale`` +flag and keep ``NUMBER_OF_INSTANCES`` aligned with the replica count: + +.. code-block:: yaml + + services: + detector: + environment: + - GROUP_ID=data_analysis + - NUMBER_OF_INSTANCES=2 + +.. code-block:: console + + $ HOST_IP=127.0.0.1 docker compose -f docker/docker-compose.yml --profile prod up --scale detector=2 + +With this example and the hybrid detector config shown above, the detector starts +``2 Docker replicas * 2 processes * 4 threads_per_process = 16`` Kafka consumers. ``pipeline.log_storage`` ^^^^^^^^^^^^^^^^^^^^^^^^ @@ -171,11 +342,11 @@ type when ``executor`` is omitted. * - Parameter - Description * - name - - A unique name amongst the ``collectors``configurations top identify the collector instance. + - A unique name amongst the ``collectors`` configurations top identify the collector instance. * - protocol_base - The lowercase protocol name to ingest data from. Currently supported: ``dns`` and ``http``. * - required_log_information - - Defines the expected format for incoming log lines. See the :ref:`Logline format configuration` section for more + - Defines the expected format for incoming log lines. See the :doc:`configuration` page for more details. Each log_collector has a BatchHandler instance. Default confgurations for all Batch handlers are defined in ``pipeline.log_collection.default_batch_handler_config``. @@ -230,6 +401,7 @@ The following list shows the available configuration options. .. list-table:: ``inspector`` Parameters :header-rows: 1 :widths: 30 70 + * - Parameter - Description * - name @@ -364,6 +536,21 @@ The following parameters control the infrastructure of the software. - Not given here - Kafka topic name prefixes given as strings. These prefix name are used to construct the actual topic names based on the instance name (e.g. a collector instance name) that produces for the given stage. (e.g. a prefilter instance name is added as suffix to the prefilter_to_inspector prefix for the inspector to know where to consume.) + * - kafka_consumer.max_poll_interval_ms + - ``1800000`` + - Maximum time in milliseconds between Kafka consumer polls before Kafka removes the consumer from its group. Increase this for long-running detector batches. + * - kafka_topics.replication_factor + - ``3`` + - Replication factor used when creating new Kafka topics. At runtime this is capped to the number of configured Kafka brokers. + * - kafka_topics.auto_expand_partitions + - ``true`` + - If enabled, existing HAMSTRING topics with fewer than the desired partition count are automatically expanded on consumer startup. Kafka does not support shrinking partition counts, so topics that are already larger are left unchanged. + * - kafka_topics.stages + - See ``config.yaml`` + - Per-pipeline-stage topic settings. Keys match ``environment.kafka_topics_prefix.pipeline`` keys. Each stage can set ``partitions`` and ``replication_factor`` for topics whose names use that stage prefix. + * - kafka_topics.topics + - See ``config.yaml`` + - Exact per-topic settings for topics that are not represented by a pipeline prefix, for example external alert topics. Topics without a stage or exact entry use 12 partitions and the default replication factor. * - monitoring.clickhouse_server.hostname - ``clickhouse-server`` - Hostname of the ClickHouse server. Used by Grafana. diff --git a/docs/index.rst b/docs/index.rst index 242f0d30..91d73fd2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,5 +1,5 @@ Welcome to HAMSTRING's documentation! -=================================== +====================================== **HAMSTRING** is a machine learning-based DNS classifier for detecting Domain Generation Algorithms (DGAs), tunneling, and data exfiltration by malicious actors. @@ -26,4 +26,5 @@ Contents training developer_guide api/index + sources references diff --git a/docs/monitoring.rst b/docs/monitoring.rst index 83eac713..356e7aa9 100644 --- a/docs/monitoring.rst +++ b/docs/monitoring.rst @@ -11,7 +11,7 @@ Overview The software includes a monitoring functionality that stores relevant information in a database (`ClickHouse`). The collected data is then visualized using multiple `Grafana` dashboard views. -.. image:: media/monitoring_pipeline.png +.. image:: ../assets/readme_assets/overview.png Setup diff --git a/docs/pipeline.rst b/docs/pipeline.rst index e7bac388..66667d46 100644 --- a/docs/pipeline.rst +++ b/docs/pipeline.rst @@ -29,10 +29,10 @@ Main Classes ------------ .. py:currentmodule:: src.zeek.zeek_config_handler -.. autoclass:: ZeekConfigurationHandler +.. py:class:: ZeekConfigurationHandler .. py:currentmodule:: src.zeek.zeek_analysis_handler -.. autoclass:: ZeekAnalysisHandler +.. py:class:: ZeekAnalysisHandler Usage and configuration ----------------------- @@ -59,6 +59,8 @@ Necessary attributes are: Stage 2: Log Storage ==================== +.. _stage-log-storage: + This stage serves as the central ingestion point for all data. Overview @@ -93,7 +95,7 @@ batches based on subnet IDs, and forwards them to the next pipeline stage for fu Core Functionality ------------------ -The `Log Collection` stage is responsible for retrieving loglines from the :ref:`Log Storage`, +The `Log Collection` stage is responsible for retrieving loglines from the :ref:`Log Storage`, parsing their information fields, and validating the data. Each field is checked to ensure it is of the correct type and format. This stage ensures that all data is accurate, reducing the need for further verification in subsequent stages. @@ -180,7 +182,7 @@ allowing for multiprocessing and threading. As the log information differs for each protocol, there is a default format per protocol. This can be either adapted or a completely new one can be added as well. For more information - please reffer to section :ref:`Logline format configuration`. + please reffer to the :doc:`configuration` page. .. code-block:: @@ -490,7 +492,7 @@ We currently support the following relevance methods: +---------------------------+-------------------------------------------------------------+ | **Name** | **Description** | +===========================+=============================================================+ - | ``no_relevance_check `` | Skip the relevance check of the prefilters entirely. | + | ``no_relevance_check`` | Skip the relevance check of the prefilters entirely. | +---------------------------+-------------------------------------------------------------+ | ``check_dga_relevance`` | Function to filter requests based on LisItems in the | | | logcollector configuration. Using the fourth item in the | @@ -722,6 +724,12 @@ Main Classes .. py:currentmodule:: src.detector.plugins.dga_detector .. autoclass:: DGADetector +.. py:currentmodule:: src.detector.plugins.domainator_detector +.. autoclass:: DomainatorDetector + +.. py:currentmodule:: src.detector.plugins.domainator_attributor +.. autoclass:: DomainatorAttributor + The :class:`DetectorBase` is the primary class for Detectors. It holds common functionalities and is responsible for data ingesting, triggering alerts, logging, etc.. Any Detector is build on top of this class and needs to implement the methods specified by :class:`DetectorAbstractBase`. The class implementations need to go into ``"/src/detector/plugins"`` @@ -765,6 +773,40 @@ Detector instances can be chained by setting ``next_detectors`` on the upstream For intermediary detectors that should only forward suspicious output to another detector, set ``send_to_alerter: false`` or ``produce_topics: []``. +Domainator Detector Models +-------------------------- + +The Domainator detection and attribution pipeline is divided into four model variants. They can be used as a multi-stage chain, where the binary detector forwards suspicious windows to attribution models, or as individual detectors when a deployment already has an upstream selection mechanism. The models use the same feature family and training data, but differ in their ground-truth labels and therefore in the task they solve: detection, tool identification, behavior analysis, or combined identification and behavior analysis. + +The current Domainator model family was trained with a combination of real malware samples from Petrov et al. :cite:p:`petrov_domainator_2025`, tunneling-tool datasets from Chen et al. :cite:p:`chen_dns_lstm_2021` and Gao et al. :cite:p:`gao_graphtunnel_2024`, and benign DNS traffic from Žiža et al. :cite:p:`ziza_dns_exfiltration_2023`. The feature extraction procedure follows Petrov et al. :cite:p:`petrov_domainator_2025`: HAMSTRING calculates features over the subdomain portions of DNS requests sent by the client or victim side of the conversation. + +The shared feature set contains: + +- Levenshtein distance +- Jaro similarity metric +- Jaro-Winkler similarity metric +- Jaro similarity metric on reversed strings +- Jaro-Winkler similarity metric on reversed strings +- Longest common string +- Longest common substring + +The pure detection model, ``1779955108_SPRING-detector``, classifies traffic as legitimate or malicious. The identification model, ``1779955108_SPRING-attributor-identification``, distinguishes between individual DNS tunneling tools or malware families. Unknown tools can therefore be misclassified as one of the known tools or labelled as legitimate. The combined identification and behavior model, ``1779955108_SPRING-attributor-identification-behaviour``, predicts both the tool and observed behavior, such as download, upload, or idle activity. Since this requires behavior-specific training labels, it supports a subset of the identification classes. The generalized behavior model, ``1779955108_SPRING-attributor-behaviour``, omits tool names from the labels and predicts only the behavior, which can help classify behavior from tools not present in the training data. + +.. list-table:: Domainator model labels + :header-rows: 1 + :widths: 35 65 + + * - Model + - Classes + * - ``1779955108_SPRING-detector`` + - ``legitimate``, ``malicious`` + * - ``1779955108_SPRING-attributor-behaviour`` + - ``legitimate``, ``download``, ``idle``, ``upload`` + * - ``1779955108_SPRING-attributor-identification`` + - ``legitimate``, ``cobaltstrike``, ``det``, ``dns2tcp``, ``dnscat``, ``dnsexfiltrator``, ``dnspot``, ``dnsshell``, ``iodine``, ``ozymandns``, ``roguerobin-net``, ``roguerobin-ps``, ``saitama``, ``symbiote``, ``symbiote-dnscat``, ``tcpoverdns``, ``tuns`` + * - ``1779955108_SPRING-attributor-identification-behaviour`` + - ``legitimate_legitimate``, ``cobaltstrike_download``, ``cobaltstrike_upload``, ``det_upload``, ``dns2tcp_download``, ``dns2tcp_upload``, ``dnscat_download``, ``dnscat_idle``, ``dnscat_upload``, ``dnsexfiltrator_upload``, ``dnsshell_download``, ``dnsshell_upload``, ``iodine_download``, ``iodine_idle``, ``iodine_upload``, ``ozymandns_download``, ``ozymandns_upload``, ``roguerobin-net_download``, ``roguerobin-net_idle``, ``roguerobin-net_upload``, ``roguerobin-ps_download``, ``roguerobin-ps_idle``, ``roguerobin-ps_upload``, ``saitama_download``, ``saitama_idle``, ``saitama_upload``, ``symbiote-dnscat_download``, ``symbiote-dnscat_idle``, ``symbiote-dnscat_upload``, ``symbiote_download``, ``symbiote_upload`` + Stage 7: Alerter ================ @@ -796,7 +838,7 @@ Main Classes .. py:currentmodule:: src.alerter.alerter .. autoclass:: AlerterBase -.. py:currentmodule:: src.alerter.alerter +.. py:currentmodule:: src.alerter.plugins.generic_alerter .. autoclass:: GenericAlerter The :class:`AlerterBase` provides the foundation for all alerter instances, handling the base logging and Kafka forwarding logic. Custom plugins should not necessarily inherit from it but are loaded dynamically by the framework. @@ -850,3 +892,7 @@ The :class:`DomainatorDetector` consumes anomalous batches of requests. It identifies potential data exfiltration and command & control on the subdomain level by analyzing characteristics of the subdomains. Messages are grouped by domain into fixed-size windows to allow for sequential anomaly detection. The detector leverages machine learning based on statistical and linguistic features from the domain name including label lengths, character frequencies, entropy measures, and counts of different character types across domain name levels. + +Domainator Attributor +..................... +The :class:`DomainatorAttributor` uses the same subdomain-window feature extraction as the :class:`DomainatorDetector`, but its labels describe the likely tool, malware family, behavior, or tool-behavior combination. It can run downstream of ``DomainatorDetector`` through detector chaining, or independently when its input already contains DNS windows that should be attributed. diff --git a/docs/refs.bib b/docs/refs.bib index 2c72fea3..c1f26a69 100644 --- a/docs/refs.bib +++ b/docs/refs.bib @@ -12,3 +12,47 @@ @inproceedings{schuppen_fanci_2018 year = {2018}, pages = {1165--1181}, } + +@inproceedings{petrov_domainator_2025, + address = {Cham}, + title = {Domainator: Detecting and Identifying DNS-Tunneling Malware Using Metadata Sequences}, + booktitle = {Availability, Reliability and Security}, + series = {Lecture Notes in Computer Science}, + publisher = {Springer Nature Switzerland}, + author = {Petrov, Denis and Ruffing, Pascal and Zillien, Sebastian and Wendzel, Steffen}, + year = {2025}, + pages = {118--140}, + doi = {10.1007/978-3-032-00624-0_6}, + url = {https://doi.org/10.1007/978-3-032-00624-0_6}, +} + +@article{chen_dns_lstm_2021, + title = {DNS covert channel detection method using the LSTM model}, + volume = {104}, + journal = {Computers \& Security}, + author = {Chen, Shaojie and Lang, Bo and Liu, Hongyu and Li, Duokun and Gao, Chuan}, + year = {2021}, + pages = {102095}, + doi = {10.1016/j.cose.2020.102095}, + url = {https://doi.org/10.1016/j.cose.2020.102095}, +} + +@article{gao_graphtunnel_2024, + title = {GraphTunnel: Robust DNS Tunnel Detection Based on DNS Recursive Resolution Graph}, + journal = {IEEE Transactions on Information Forensics and Security}, + author = {Gao, Guangyuan and Niu, Weina and Gong, Jiacheng and Gu, Dujuan and Li, Song and Zhang, Mingxue and Zhang, Xiaosong}, + year = {2024}, + doi = {10.1109/TIFS.2024.3443596}, + url = {https://doi.org/10.1109/TIFS.2024.3443596}, +} + +@article{ziza_dns_exfiltration_2023, + title = {DNS exfiltration detection in the presence of adversarial attacks and modified exfiltrator behaviour}, + volume = {22}, + journal = {International Journal of Information Security}, + author = {Žiža, Kristijan and Tadić, Predrag and Vuletić, Pavle}, + year = {2023}, + pages = {1865--1880}, + doi = {10.1007/s10207-023-00723-w}, + url = {https://doi.org/10.1007/s10207-023-00723-w}, +} diff --git a/docs/sources.rst b/docs/sources.rst new file mode 100644 index 00000000..5790e37c --- /dev/null +++ b/docs/sources.rst @@ -0,0 +1,36 @@ +Sources and Attribution +~~~~~~~~~~~~~~~~~~~~~~~ + +This page tracks scientific sources, dataset origins, and model attribution notes for HAMSTRING detector models. + +Domainator Model Sources +======================== + +The Domainator detector and attributor models use the same subdomain-level feature family and differ by their label space. The current model family combines malware, tunneling-tool, and benign DNS traffic sources: + +.. list-table:: Domainator training sources + :header-rows: 1 + :widths: 25 45 30 + + * - Source + - Contribution + - Citation + * - Domainator malware samples + - Real DNS-tunneling malware samples and the feature processing procedure used for subdomain sequence metadata. + - Petrov et al. :cite:p:`petrov_domainator_2025` + * - LSTM DNS covert-channel dataset + - DNS tunneling/covert-channel tool traffic used as malicious tunneling examples. + - Chen et al. :cite:p:`chen_dns_lstm_2021` + * - GraphTunnel dataset + - DNS tunneling samples used to broaden tunneling-tool coverage. + - Gao et al. :cite:p:`gao_graphtunnel_2024` + * - Benign DNS traffic + - Real DNS traffic used as legitimate traffic for training and evaluation. + - Žiža et al. :cite:p:`ziza_dns_exfiltration_2023` + +Attribution Notes +================= + +- Keep new detector or attributor model descriptions in :ref:`detection_stage`. +- Add scientific publications to ``docs/refs.bib`` and cite them from the relevant model documentation. +- Record dataset provenance here whenever a model release changes training sources, label definitions, or intended attribution semantics. diff --git a/docs/usage.rst b/docs/usage.rst index 4c7968ce..a9c10dde 100644 --- a/docs/usage.rst +++ b/docs/usage.rst @@ -28,6 +28,58 @@ If you want to run containers individually, use: Make sure you set the environment variable ``HOST_IP`` to your host's IP address, so that the services can communicate with each other. +Scaling With Docker Compose +--------------------------- + +HAMSTRING has two scaling axes: + +* Docker Compose replicas start more containers for a service. +* ``pipeline.scaling`` in ``config.yaml`` starts more workers inside each service container. + +Use Docker Compose replicas when you want horizontal service scaling across containers. For the production +profile, scale the production service names: + +.. code-block:: console + + $ HOST_IP=127.0.0.1 docker compose -f docker/docker-compose.yml --profile prod up --scale logcollector=3 --scale detector=2 + +For the development profile, scale the ``-dev`` service names: + +.. code-block:: console + + $ HOST_IP=127.0.0.1 docker compose -f docker/docker-compose.yml --profile dev up --scale logcollector-dev=3 --scale detector-dev=2 + +When using Compose replicas, set ``NUMBER_OF_INSTANCES`` for the scaled service to the same replica count so +Kafka topic creation can request enough partitions for the whole consumer group: + +.. code-block:: yaml + + services: + detector: + environment: + - GROUP_ID=data_analysis + - NUMBER_OF_INSTANCES=2 + +The compose fragments also contain ``deploy.replicas`` fields. Use them for orchestrators that honor Compose +``deploy`` settings; for local ``docker compose up`` runs, the explicit ``--scale`` flag is the clearest option. + +For worker scaling inside a container, configure ``pipeline.scaling``. For example, this starts two detector +processes with four worker threads each in every detector container: + +.. code-block:: yaml + + pipeline: + scaling: + modules: + data_analysis.detector: + executor: hybrid + processes: 2 + threads_per_process: 4 + +With ``--scale detector=2``, that configuration creates ``2 Docker replicas * 2 processes * 4 threads``: +16 Kafka consumers for the detector stage. See :ref:`configuration` for the full scaling option reference and +per-instance override examples. + Installation ------------ diff --git a/requirements/requirements.detector.txt b/requirements/requirements.detector.txt index 5ca06bee..14457ab6 100644 --- a/requirements/requirements.detector.txt +++ b/requirements/requirements.detector.txt @@ -1,5 +1,6 @@ xgboost scikit-learn~=1.5.2 +pandas requests colorlog~=6.8.2 PyYAML~=6.0.1 diff --git a/src/alerter/alerter.py b/src/alerter/alerter.py index 8092ad4a..a04f5fca 100644 --- a/src/alerter/alerter.py +++ b/src/alerter/alerter.py @@ -2,17 +2,26 @@ import os import sys import asyncio +import datetime +import uuid from abc import ABC, abstractmethod import importlib +from pathlib import Path sys.path.append(os.getcwd()) -from confluent_kafka.admin import AdminClient, NewTopic +from confluent_kafka.admin import AdminClient +from src.base.clickhouse_kafka_sender import ClickHouseKafkaSender from src.base.utils import setup_config, ensure_directory -from src.base.execution import create_pipeline_executor +from src.base.execution import ( + create_pipeline_executor, + run_thread_worker_pool, + start_pipeline_worker_replicas, +) from src.base.kafka_handler import ( ExactlyOnceKafkaConsumeHandler, ExactlyOnceKafkaProduceHandler, KafkaMessageFetchException, + ensure_topics, ) from src.base.log_config import get_logger @@ -62,12 +71,21 @@ def __init__(self, alerter_config, consume_topic) -> None: self.key = None self.kafka_consume_handler = ExactlyOnceKafkaConsumeHandler(self.consume_topic) + self.server_log_terminal_events = ClickHouseKafkaSender( + "server_log_terminal_events" + ) # Base actions config self.log_to_file = ALERTING_CONFIG.get("log_to_file", False) self.log_file_path = ALERTING_CONFIG.get( "log_file_path", "/opt/logs/alerts.txt" ) + self.log_rotation_config = ALERTING_CONFIG.get("log_rotation", {}) + self.log_rotation_enabled = self.log_rotation_config.get("enabled", False) + self.log_retention_days = self._parse_log_retention_days( + self.log_rotation_config.get("retention_days", 7) + ) + self._last_log_cleanup_date = None self.log_to_kafka = ALERTING_CONFIG.get("log_to_kafka", False) self.external_kafka_topic = ALERTING_CONFIG.get( "external_kafka_topic", "external_alerts_topic" @@ -94,9 +112,8 @@ def _setup_kafka_output_topics(self): ] ) admin_client = AdminClient({"bootstrap.servers": brokers}) - # Attempt to create topic (will do nothing if it already exists) try: - admin_client.create_topics([NewTopic(self.external_kafka_topic, 1, 1)]) + ensure_topics(admin_client, [self.external_kafka_topic]) except Exception as e: logger.warning( f"Could not auto-create topic {self.external_kafka_topic}: {e}" @@ -104,6 +121,78 @@ def _setup_kafka_output_topics(self): self.kafka_produce_handler = ExactlyOnceKafkaProduceHandler() + @staticmethod + def _parse_log_retention_days(retention_days) -> int | None: + """ + Parse the configured rotated log retention period. + """ + if retention_days is None: + return None + try: + retention_days = int(retention_days) + except (TypeError, ValueError): + logger.warning( + "Invalid alert log retention_days '%s'. Keeping rotated logs for 7 days.", + retention_days, + ) + return 7 + if retention_days < 1: + logger.warning( + "Invalid alert log retention_days '%s'. Keeping rotated logs for 1 day.", + retention_days, + ) + return 1 + return retention_days + + def _get_active_log_file_path( + self, timestamp: datetime.datetime | None = None + ) -> str: + if not self.log_rotation_enabled: + return self.log_file_path + + timestamp = timestamp or datetime.datetime.now() + log_path = Path(self.log_file_path) + rotated_name = f"{log_path.stem}-{timestamp:%Y-%m-%d}{log_path.suffix}" + return str(log_path.with_name(rotated_name)) + + def _cleanup_rotated_logs(self, today: datetime.date | None = None) -> None: + if not self.log_rotation_enabled or self.log_retention_days is None: + return + + today = today or datetime.date.today() + if self._last_log_cleanup_date == today: + return + + log_path = Path(self.log_file_path) + cutoff_date = today - datetime.timedelta(days=self.log_retention_days - 1) + for candidate in log_path.parent.glob(f"{log_path.stem}-*{log_path.suffix}"): + log_date = self._extract_rotated_log_date(candidate) + if log_date is None or log_date >= cutoff_date: + continue + try: + candidate.unlink() + logger.info("%s: Removed expired alert log %s", self.name, candidate) + except OSError as e: + logger.warning( + "%s: Could not remove expired alert log %s: %s", + self.name, + candidate, + e, + ) + + self._last_log_cleanup_date = today + + def _extract_rotated_log_date(self, log_path: Path) -> datetime.date | None: + stem_prefix = f"{Path(self.log_file_path).stem}-" + if not log_path.stem.startswith(stem_prefix): + return None + + date_value = log_path.stem[len(stem_prefix) :] + try: + return datetime.datetime.strptime(date_value, "%Y-%m-%d").date() + except ValueError: + return None + def get_and_fill_data(self) -> None: if self.alert_data: logger.warning( @@ -130,9 +219,13 @@ def _log_to_file_action(self): if not self.log_to_file: return - logger.info(f"{self.name}: Logging alert to file {self.log_file_path}") + active_log_file_path = self._get_active_log_file_path() + ensure_directory(active_log_file_path) + self._cleanup_rotated_logs() + + logger.info(f"{self.name}: Logging alert to file {active_log_file_path}") try: - with open(self.log_file_path, "a+") as f: + with open(active_log_file_path, "a+") as f: json.dump(self.alert_data, f) f.write("\n") except IOError as e: @@ -159,6 +252,57 @@ def _log_to_kafka_action(self): logger.error(f"{self.name}: Error forwarding alert: {e}") raise + def _extract_server_message_ids(self) -> set[uuid.UUID]: + server_message_ids = set() + + def visit(value): + if isinstance(value, dict): + if value.get("server_message_id"): + self._add_server_message_id( + server_message_ids, value["server_message_id"] + ) + if isinstance(value.get("server_message_ids"), list): + for server_message_id in value["server_message_ids"]: + self._add_server_message_id( + server_message_ids, server_message_id + ) + for nested_value in value.values(): + visit(nested_value) + elif isinstance(value, list): + for item in value: + visit(item) + + visit(self.alert_data) + return server_message_ids + + @staticmethod + def _add_server_message_id( + server_message_ids: set[uuid.UUID], + server_message_id, + ) -> None: + try: + server_message_ids.add(uuid.UUID(str(server_message_id))) + except (TypeError, ValueError): + logger.warning( + "Ignoring non-UUID LogServer message id '%s'.", server_message_id + ) + + def _record_alerter_terminal_events( + self, server_message_ids: set[uuid.UUID] + ) -> None: + if not server_message_ids: + return + timestamp = datetime.datetime.now() + for server_message_id in server_message_ids: + self.server_log_terminal_events.insert( + dict( + message_id=server_message_id, + stage=module_name, + status="processed", + timestamp=timestamp, + ) + ) + def bootstrap_alerter_instance(self): """ Main loop for the alerter instance. @@ -169,11 +313,14 @@ def bootstrap_alerter_instance(self): try: self.get_and_fill_data() if self.alert_data: + server_message_ids = self._extract_server_message_ids() # 1. Process specific action self.process_alert() # 2. Executing Base Logging Actions self._log_to_file_action() self._log_to_kafka_action() + self._record_alerter_terminal_events(server_message_ids) + self.kafka_consume_handler.commit() except KafkaMessageFetchException as e: logger.debug(e) @@ -199,36 +346,106 @@ async def start(self): executor.shutdown(wait=False, cancel_futures=True) +def build_alerter_worker(alerter_config, consume_topic, worker_id=None): + class_name = alerter_config.get("alerter_class_name", "GenericAlerter") + alerter_module_name = alerter_config.get( + "alerter_module_name", "generic_alerter" + ) + plugin_module_name = f"{PLUGIN_PATH}.{alerter_module_name}" + plugin_module = importlib.import_module(plugin_module_name) + alerter_class = getattr(plugin_module, class_name) + worker = alerter_class(alerter_config=alerter_config, consume_topic=consume_topic) + worker.worker_id = worker_id + return worker + + +def run_alerter_worker_process( + process_index, + threads_per_process, + alerter_config, + consume_topic, +): + def worker_factory(worker_id): + return build_alerter_worker( + alerter_config=alerter_config, + consume_topic=consume_topic, + worker_id=worker_id, + ) + + run_thread_worker_pool( + worker_factory=worker_factory, + target_name="bootstrap_alerter_instance", + module_name=module_name, + instance_name=alerter_config.get("name", "generic"), + process_index=process_index, + threads_per_process=threads_per_process, + ) + + async def main(): tasks = [] # Setup Generic Alerter Task generic_topic = f"{CONSUME_TOPIC_PREFIX}-generic" logger.info("Initializing Generic Alerter") - class_name = "GenericAlerter" - mod_name = f"{PLUGIN_PATH}.generic_alerter" - module = importlib.import_module(mod_name) - AlerterClass = getattr(module, class_name) - generic_alerter = AlerterClass( - alerter_config={"name": "generic"}, consume_topic=generic_topic + generic_config = {"name": "generic"} + + def generic_worker_factory( + worker_id, + generic_config=generic_config, + generic_topic=generic_topic, + ): + return build_alerter_worker( + alerter_config=generic_config, + consume_topic=generic_topic, + worker_id=worker_id, + ) + + tasks.append( + asyncio.create_task( + start_pipeline_worker_replicas( + config=config, + module_name=module_name, + instance_name="generic", + worker_factory=generic_worker_factory, + target_name="bootstrap_alerter_instance", + process_entrypoint=run_alerter_worker_process, + process_args=(generic_config, generic_topic), + ) + ) ) - tasks.append(asyncio.create_task(generic_alerter.start())) # Setup Specific Custom Alerter Tasks if ALTERTERS: for alerter_config in ALTERTERS: logger.info(f"Initializing Custom Alerter: {alerter_config['name']}") consume_topic = f"{CONSUME_TOPIC_PREFIX}-{alerter_config['name']}" - class_name = alerter_config["alerter_class_name"] - mod_name = f"{PLUGIN_PATH}.{alerter_config['alerter_module_name']}" - module = importlib.import_module(mod_name) - AlerterClass = getattr(module, class_name) - alerter_instance = AlerterClass( - alerter_config=alerter_config, consume_topic=consume_topic + def worker_factory( + worker_id, + alerter_config=alerter_config, + consume_topic=consume_topic, + ): + return build_alerter_worker( + alerter_config=alerter_config, + consume_topic=consume_topic, + worker_id=worker_id, + ) + + tasks.append( + asyncio.create_task( + start_pipeline_worker_replicas( + config=config, + module_name=module_name, + instance_name=alerter_config["name"], + worker_factory=worker_factory, + target_name="bootstrap_alerter_instance", + process_entrypoint=run_alerter_worker_process, + process_args=(alerter_config, consume_topic), + ) + ) ) - tasks.append(asyncio.create_task(alerter_instance.start())) await asyncio.gather(*tasks) diff --git a/src/base/data_classes/clickhouse_connectors.py b/src/base/data_classes/clickhouse_connectors.py index d1fa831a..6c91ed6b 100644 --- a/src/base/data_classes/clickhouse_connectors.py +++ b/src/base/data_classes/clickhouse_connectors.py @@ -34,6 +34,30 @@ class ServerLogsTimestamps: ) +@dataclass +class ServerLogToLogline: + message_id: uuid.UUID = field( + metadata={"marshmallow_field": marshmallow.fields.UUID()} + ) + logline_id: uuid.UUID = field( + metadata={"marshmallow_field": marshmallow.fields.UUID()} + ) + + +@dataclass +class ServerLogTerminalEvents: + message_id: uuid.UUID = field( + metadata={"marshmallow_field": marshmallow.fields.UUID()} + ) + stage: str = field(metadata={"marshmallow_field": marshmallow.fields.String()}) + status: str = field(metadata={"marshmallow_field": marshmallow.fields.String()}) + timestamp: datetime.datetime = field( + metadata={ + "marshmallow_field": marshmallow.fields.DateTime("%Y-%m-%d %H:%M:%S.%f") + } + ) + + @dataclass class FailedLoglines: message_text: str = field( @@ -215,6 +239,8 @@ class FillLevels: TABLE_NAME_TO_TYPE = { "server_logs": ServerLogs, "server_logs_timestamps": ServerLogsTimestamps, + "server_log_to_logline": ServerLogToLogline, + "server_log_terminal_events": ServerLogTerminalEvents, "failed_loglines": FailedLoglines, "logline_to_batches": LoglineToBatches, "loglines": Loglines, diff --git a/src/base/execution.py b/src/base/execution.py index d79f182b..317012b0 100644 --- a/src/base/execution.py +++ b/src/base/execution.py @@ -1,14 +1,31 @@ from __future__ import annotations -from concurrent.futures import Executor, ProcessPoolExecutor, ThreadPoolExecutor +import asyncio +import os +from concurrent.futures import ( + FIRST_EXCEPTION, + Executor, + ProcessPoolExecutor, + ThreadPoolExecutor, + wait, +) from dataclasses import dataclass -from typing import Any +from typing import Any, Callable @dataclass(frozen=True) class PipelineExecutorConfig: executor: str = "thread" - max_workers: int = 1 + processes: int = 1 + threads_per_process: int = 1 + + @property + def max_workers(self) -> int: + return self.total_workers + + @property + def total_workers(self) -> int: + return self.processes * self.threads_per_process def get_pipeline_executor_config( @@ -42,16 +59,135 @@ def create_pipeline_executor( module_name=module_name, instance_name=instance_name, ) - if executor_config.executor == "process": - return ProcessPoolExecutor(max_workers=executor_config.max_workers) + if executor_config.executor in {"process", "hybrid"}: + return ProcessPoolExecutor(max_workers=executor_config.processes) prefix = _thread_name_prefix(module_name, instance_name) return ThreadPoolExecutor( - max_workers=executor_config.max_workers, + max_workers=executor_config.threads_per_process, thread_name_prefix=prefix, ) +async def start_pipeline_worker_replicas( + config: dict[str, Any], + module_name: str, + instance_name: str | None, + worker_factory: Callable[[str], Any], + target_name: str, + process_entrypoint: Callable[..., None] | None = None, + process_args: tuple[Any, ...] = (), +) -> None: + executor_config = get_pipeline_executor_config( + config=config, + module_name=module_name, + instance_name=instance_name, + ) + _set_topic_min_partitions(executor_config.total_workers) + + if executor_config.executor == "thread": + await _start_thread_workers( + worker_factory=worker_factory, + target_name=target_name, + module_name=module_name, + instance_name=instance_name, + process_index=0, + threads_per_process=executor_config.threads_per_process, + ) + return + + if process_entrypoint is None: + raise ValueError("process_entrypoint is required for process and hybrid scaling") + + loop = asyncio.get_running_loop() + executor = ProcessPoolExecutor(max_workers=executor_config.processes) + try: + futures = [ + loop.run_in_executor( + executor, + process_entrypoint, + process_index, + executor_config.threads_per_process, + *process_args, + ) + for process_index in range(executor_config.processes) + ] + await asyncio.gather(*futures) + finally: + executor.shutdown(wait=False, cancel_futures=True) + + +def run_thread_worker_pool( + worker_factory: Callable[[str], Any], + target_name: str, + module_name: str, + instance_name: str | None, + process_index: int, + threads_per_process: int, +) -> None: + executor = ThreadPoolExecutor( + max_workers=threads_per_process, + thread_name_prefix=_thread_name_prefix(module_name, instance_name), + ) + futures = [] + try: + for thread_index in range(threads_per_process): + worker_id = _worker_id(process_index, thread_index) + worker = worker_factory(worker_id) + futures.append(executor.submit(getattr(worker, target_name))) + + done, _ = wait(futures, return_when=FIRST_EXCEPTION) + for future in done: + future.result() + finally: + executor.shutdown(wait=False, cancel_futures=True) + + +async def _start_thread_workers( + worker_factory: Callable[[str], Any], + target_name: str, + module_name: str, + instance_name: str | None, + process_index: int, + threads_per_process: int, +) -> None: + loop = asyncio.get_running_loop() + executor = ThreadPoolExecutor( + max_workers=threads_per_process, + thread_name_prefix=_thread_name_prefix(module_name, instance_name), + ) + try: + futures = [] + for thread_index in range(threads_per_process): + worker_id = _worker_id(process_index, thread_index) + worker = worker_factory(worker_id) + futures.append(loop.run_in_executor(executor, getattr(worker, target_name))) + + await asyncio.gather(*futures) + finally: + executor.shutdown(wait=False, cancel_futures=True) + + +def _worker_id(process_index: int, thread_index: int) -> str: + return f"p{process_index}-t{thread_index}" + + +def _set_topic_min_partitions(total_workers: int) -> None: + try: + service_instances = int(os.getenv("NUMBER_OF_INSTANCES", "1")) + except ValueError: + service_instances = 1 + + requested_partitions = max(1, total_workers * max(1, service_instances)) + try: + current_partitions = int(os.getenv("KAFKA_TOPIC_MIN_PARTITIONS", "1")) + except ValueError: + current_partitions = 1 + os.environ["KAFKA_TOPIC_MIN_PARTITIONS"] = str( + max(current_partitions, requested_partitions) + ) + + def _without_instances(config: dict[str, Any]) -> dict[str, Any]: return {key: value for key, value in config.items() if key != "instances"} @@ -62,9 +198,23 @@ def _parse_executor_config(config: dict[str, Any]) -> PipelineExecutorConfig: ) if executor is None: executor = _infer_executor(config) + elif ( + executor == "process" + and _has_explicit_process_count(config) + and _read_threads_per_process(config, default=1) > 1 + ): + executor = "hybrid" - max_workers = _read_max_workers(config, executor) - return PipelineExecutorConfig(executor=executor, max_workers=max_workers) + processes = _read_process_count(config, executor) + threads_per_process = _read_threads_per_process( + config, + default=1 if executor == "process" else None, + ) + return PipelineExecutorConfig( + executor=executor, + processes=processes, + threads_per_process=threads_per_process, + ) def _normalize_executor_name(value: Any) -> str | None: @@ -76,35 +226,71 @@ def _normalize_executor_name(value: Any) -> str | None: return "thread" if normalized in {"process", "processes", "process-pool", "processpool"}: return "process" + if normalized in {"hybrid", "mixed", "process-thread", "process-thread-pool"}: + return "hybrid" raise ValueError( "Pipeline executor must be one of: thread, threads, thread-pool, " - "process, processes, process-pool" + "process, processes, process-pool, hybrid" ) def _infer_executor(config: dict[str, Any]) -> str: - if "processes" in config: + if _has_explicit_process_count(config) and _has_explicit_thread_count(config): + return "hybrid" + if _has_explicit_process_count(config): return "process" return "thread" -def _read_max_workers(config: dict[str, Any], executor: str) -> int: - executor_specific_key = "processes" if executor == "process" else "threads" - worker_value = config.get(executor_specific_key) - if worker_value is None: +def _read_process_count(config: dict[str, Any], executor: str) -> int: + if executor == "thread": + return 1 + + worker_value = config.get("processes") + if worker_value is None and executor == "process": worker_value = config.get("max_workers", config.get("workers")) if worker_value is None: worker_value = 1 + return _read_positive_int(worker_value) + + +def _read_threads_per_process( + config: dict[str, Any], + default: int | None, +) -> int: + worker_value = config.get("threads_per_process") + if worker_value is None: + worker_value = config.get("threads") + if worker_value is None: + if default is None: + worker_value = config.get("max_workers", config.get("workers")) + else: + worker_value = default + if worker_value is None: + worker_value = 1 + + return _read_positive_int(worker_value) + + +def _read_positive_int(value: Any) -> int: try: - max_workers = int(worker_value) + parsed_value = int(value) except (TypeError, ValueError) as exc: raise ValueError("Pipeline executor worker count must be an integer") from exc - if max_workers < 1: + if parsed_value < 1: raise ValueError("Pipeline executor worker count must be at least 1") - return max_workers + return parsed_value + + +def _has_explicit_process_count(config: dict[str, Any]) -> bool: + return "processes" in config + + +def _has_explicit_thread_count(config: dict[str, Any]) -> bool: + return "threads" in config or "threads_per_process" in config def _thread_name_prefix(module_name: str, instance_name: str | None) -> str: diff --git a/src/base/kafka_handler.py b/src/base/kafka_handler.py index 7a9d4a7f..5ddf3c2c 100644 --- a/src/base/kafka_handler.py +++ b/src/base/kafka_handler.py @@ -20,7 +20,7 @@ KafkaException, Producer, ) -from confluent_kafka.admin import AdminClient, NewTopic +from confluent_kafka.admin import AdminClient, NewPartitions, NewTopic sys.path.append(os.getcwd()) from src.base.data_classes.batch import Batch @@ -36,6 +36,249 @@ config = setup_config() KAFKA_BROKERS = config["environment"]["kafka_brokers"] +KAFKA_CONSUMER_CONFIG = config["environment"].get("kafka_consumer", {}) +KAFKA_CONSUMER_MAX_POLL_INTERVAL_MS = int( + KAFKA_CONSUMER_CONFIG.get("max_poll_interval_ms", 1800000) +) +KAFKA_TOPIC_CONFIG = config["environment"].get("kafka_topics", {}) +KAFKA_TOPIC_DEFAULT_PARTITIONS = int(os.getenv("KAFKA_TOPIC_PARTITIONS", 12)) +KAFKA_TOPIC_REPLICATION_FACTOR = int( + os.getenv( + "KAFKA_TOPIC_REPLICATION_FACTOR", + KAFKA_TOPIC_CONFIG.get("replication_factor", len(KAFKA_BROKERS) or 1), + ) +) +KAFKA_TOPIC_AUTO_EXPAND_PARTITIONS = KAFKA_TOPIC_CONFIG.get( + "auto_expand_partitions", True +) +KAFKA_TOPIC_STAGE_CONFIG = KAFKA_TOPIC_CONFIG.get("stages", {}) +KAFKA_TOPIC_EXACT_CONFIG = KAFKA_TOPIC_CONFIG.get("topics", {}) +KAFKA_PIPELINE_TOPIC_PREFIXES = config["environment"].get("kafka_topics_prefix", {}).get( + "pipeline", {} +) + + +def _normalize_topics(topics: str | list[str]) -> list[str]: + if isinstance(topics, str): + return [topics] + return topics + + +def _as_bool(value) -> bool: + if isinstance(value, bool): + return value + if isinstance(value, str): + return value.lower() in {"1", "true", "yes", "on"} + return bool(value) + + +def _topic_config(topic: str | None) -> dict: + if topic is None: + return {} + + exact_config = KAFKA_TOPIC_EXACT_CONFIG.get(topic) + if exact_config is not None: + return exact_config + + matched_stage = None + matched_prefix_length = -1 + for stage_name, topic_prefix in KAFKA_PIPELINE_TOPIC_PREFIXES.items(): + if not topic_prefix: + continue + if topic == topic_prefix or topic.startswith(f"{topic_prefix}-"): + if len(topic_prefix) > matched_prefix_length: + matched_stage = stage_name + matched_prefix_length = len(topic_prefix) + + if matched_stage is None: + return {} + + return KAFKA_TOPIC_STAGE_CONFIG.get(matched_stage, {}) + + +def _desired_topic_partitions( + topic: str | None = None, override: int | None = None +) -> int: + topic_config = _topic_config(topic) + configured_partitions = override + if configured_partitions is None: + configured_partitions = topic_config.get( + "partitions", KAFKA_TOPIC_DEFAULT_PARTITIONS + ) + return max( + 1, + NUMBER_OF_INSTANCES, + _runtime_min_topic_partitions(), + int(configured_partitions), + ) + + +def _runtime_min_topic_partitions() -> int: + try: + return int(os.getenv("KAFKA_TOPIC_MIN_PARTITIONS", "1")) + except ValueError: + return 1 + + +def _topic_replication_factor( + topic: str | None = None, override: int | None = None +) -> int: + broker_count = max(1, len(KAFKA_BROKERS)) + topic_config = _topic_config(topic) + configured_replication_factor = override + if configured_replication_factor is None: + configured_replication_factor = topic_config.get( + "replication_factor", KAFKA_TOPIC_REPLICATION_FACTOR + ) + configured_replication_factor = max(1, int(configured_replication_factor)) + return min(configured_replication_factor, broker_count) + + +def _topic_partition_count(cluster_metadata, topic: str) -> int | None: + topics_metadata = getattr(cluster_metadata, "topics", {}) + + if isinstance(topics_metadata, dict): + topic_metadata = topics_metadata.get(topic) + if topic_metadata is None: + return None + + partitions = getattr(topic_metadata, "partitions", None) + if partitions is None: + return 1 + return len(partitions) + + if topic in topics_metadata: + return 1 + + return None + + +def _is_topic_already_created(exception: Exception) -> bool: + kafka_error = exception.args[0] if getattr(exception, "args", None) else None + topic_already_exists_code = getattr(KafkaError, "TOPIC_ALREADY_EXISTS", None) + if ( + topic_already_exists_code is not None + and hasattr(kafka_error, "code") + and kafka_error.code() == topic_already_exists_code + ): + return True + + return "already exists" in str(exception).lower() + + +def _is_partition_count_already_satisfied(exception: Exception) -> bool: + message = str(exception).lower() + return "already has" in message or "smaller than current" in message + + +def _wait_for_admin_futures(futures: dict, operation: str) -> None: + for topic, future in futures.items(): + try: + future.result() + except KafkaException as exception: + if operation == "create topic" and _is_topic_already_created(exception): + logger.info("Kafka topic '%s' already exists.", topic) + continue + if ( + operation == "expand partitions" + and _is_partition_count_already_satisfied(exception) + ): + logger.info("Kafka topic '%s' already has enough partitions.", topic) + continue + raise + + +def ensure_topics( + admin_client: AdminClient, + topics: str | list[str], + target_partitions: int | None = None, + replication_factor: int | None = None, + auto_expand_partitions: bool | None = None, +) -> dict[str, int]: + normalized_topics = _normalize_topics(topics) + target_partitions_by_topic = { + topic: _desired_topic_partitions(topic, target_partitions) + for topic in normalized_topics + } + replication_factor_by_topic = { + topic: _topic_replication_factor(topic, replication_factor) + for topic in normalized_topics + } + auto_expand_partitions = ( + _as_bool(KAFKA_TOPIC_AUTO_EXPAND_PARTITIONS) + if auto_expand_partitions is None + else _as_bool(auto_expand_partitions) + ) + + cluster_metadata = admin_client.list_topics(timeout=10) + topics_metadata = getattr(cluster_metadata, "topics", {}) + existing_topics = ( + set(topics_metadata.keys()) + if isinstance(topics_metadata, dict) + else set(topics_metadata) + ) + missing_topics = [ + topic for topic in normalized_topics if topic not in existing_topics + ] + + if missing_topics: + logger.info( + "Creating Kafka topics %s.", + missing_topics, + ) + futures = admin_client.create_topics( + [ + NewTopic( + topic, + target_partitions_by_topic[topic], + replication_factor_by_topic[topic], + ) + for topic in missing_topics + ] + ) + _wait_for_admin_futures(futures, "create topic") + + if not auto_expand_partitions: + return target_partitions_by_topic + + cluster_metadata = admin_client.list_topics(timeout=10) + topics_to_expand = [] + for topic in normalized_topics: + current_partition_count = _topic_partition_count(cluster_metadata, topic) + if current_partition_count is None: + continue + target_partitions = target_partitions_by_topic[topic] + if current_partition_count < target_partitions: + logger.info( + "Expanding Kafka topic '%s' from %d to %d partition(s).", + topic, + current_partition_count, + target_partitions, + ) + topics_to_expand.append(NewPartitions(topic, target_partitions)) + + if topics_to_expand: + futures = admin_client.create_partitions(topics_to_expand) + _wait_for_admin_futures(futures, "expand partitions") + + return target_partitions_by_topic + + +def _sanitize_consumer_group_part(value: str) -> str: + return "".join( + character if character.isalnum() or character in "._-" else "_" + for character in value + ) + + +def build_consumer_group_id(topics: str | list[str]) -> str: + normalized_topics = sorted(_normalize_topics(topics)) + topic_suffix = "__".join( + _sanitize_consumer_group_part(topic) for topic in normalized_topics + ) + if not topic_suffix: + return CONSUMER_GROUP_ID + return f"{CONSUMER_GROUP_ID}.{topic_suffix}" class TooManyFailedAttemptsError(Exception): @@ -325,6 +568,10 @@ def __init__(self, topics: str | list[str]) -> None: KafkaException: If consumer creation or subscription fails. """ super().__init__() + self._last_consumed_message = None + + if isinstance(topics, str): + topics = [topics] # get brokers self.brokers = ",".join( @@ -337,33 +584,35 @@ def __init__(self, topics: str | list[str]) -> None: # create consumer conf = { "bootstrap.servers": self.brokers, - "group.id": f"{CONSUMER_GROUP_ID}", + "group.id": build_consumer_group_id(topics), "enable.auto.commit": False, "auto.offset.reset": "earliest", "enable.partition.eof": True, + "max.poll.interval.ms": KAFKA_CONSUMER_MAX_POLL_INTERVAL_MS, } self.consumer = Consumer(conf) - if isinstance(topics, str): - topics = [topics] - # create topics admin_client = AdminClient( { "bootstrap.servers": self.brokers, } ) - admin_client.create_topics( - [NewTopic(topic, NUMBER_OF_INSTANCES, 1) for topic in topics] - ) + target_partitions_by_topic = ensure_topics(admin_client, topics) # check if topics are created - if not self._all_topics_created(topics): + if not self._all_topics_created(topics, target_partitions_by_topic): raise TooManyFailedAttemptsError("Not all topics were created.") # subscribe to the topics self.consumer.subscribe(topics) + def commit(self) -> None: + """Commit the last message returned by ``consume``.""" + if self.consumer and self._last_consumed_message is not None: + self.consumer.commit(self._last_consumed_message) + self._last_consumed_message = None + @abstractmethod def consume(self, *args, **kwargs): """Abstract method for consuming messages from Kafka topics @@ -410,7 +659,9 @@ def consume_as_json(self) -> tuple[Optional[str], dict]: except Exception: raise ValueError("Unknown data format") - def _all_topics_created(self, topics: list[str]) -> bool: + def _all_topics_created( + self, topics: list[str], min_partitions: int | dict[str, int] = 1 + ) -> bool: """Verify that all specified topics have been created successfully. Polls the Kafka cluster to check if each topic in the provided list @@ -430,7 +681,13 @@ def _all_topics_created(self, topics: list[str]) -> bool: all_topics_created = True for topic in topics: - if topic not in assigned_topics.topics: + partition_count = _topic_partition_count(assigned_topics, topic) + required_partitions = ( + min_partitions.get(topic, 1) + if isinstance(min_partitions, dict) + else min_partitions + ) + if partition_count is None or partition_count < required_partitions: all_topics_created = False if not all_topics_created: @@ -456,6 +713,23 @@ def __del__(self) -> None: def _is_dicts(obj): return isinstance(obj, list) and all(isinstance(item, dict) for item in obj) + @staticmethod + def _decode_batch_data(data): + if data is None: + return [] + if not isinstance(data, list): + raise ValueError("Batch data must be a list.") + + decoded_data = [] + for item in data: + if isinstance(item, str): + decoded_data.append(json.loads(item)) + elif isinstance(item, (dict, list)): + decoded_data.append(item) + else: + raise ValueError("Batch data contains unsupported item type.") + return decoded_data + def consume_as_object(self) -> tuple[None | str, Batch]: """ Consumes available messages on the specified topic. Decodes the data and converts it to a Batch @@ -472,10 +746,7 @@ def consume_as_object(self) -> tuple[None | str, Batch]: # TODO: Change return value to fit the type, maybe switch to raise return None, {} eval_data: dict = json.loads(value) - if self._is_dicts(eval_data.get("data")): - eval_data["data"] = eval_data.get("data") - else: - eval_data["data"] = [json.loads(item) for item in eval_data.get("data")] + eval_data["data"] = self._decode_batch_data(eval_data.get("data")) batch_schema = marshmallow_dataclass.class_schema(Batch)() eval_data: Batch = batch_schema.load(eval_data) if isinstance(eval_data, Batch): @@ -543,6 +814,7 @@ def consume(self) -> tuple[Optional[str], Optional[str], Optional[str]]: key = msg.key().decode("utf-8") if msg.key() else None value = msg.value().decode("utf-8") if msg.value() else None topic = msg.topic() if msg.topic() else None + self._last_consumed_message = msg return key, value, topic except KeyboardInterrupt: logger.info("Stopping KafkaConsumeHandler...") @@ -608,8 +880,7 @@ def consume(self) -> tuple[Optional[str], Optional[str], Optional[str]]: key = msg.key().decode("utf-8") if msg.key() else None value = msg.value().decode("utf-8") if msg.value() else None topic = msg.topic() if msg.topic() else None - - self.consumer.commit(msg) + self._last_consumed_message = msg return key, value, topic except KeyboardInterrupt: diff --git a/src/detector/detector.py b/src/detector/detector.py index c865b0b9..011bec23 100644 --- a/src/detector/detector.py +++ b/src/detector/detector.py @@ -12,6 +12,7 @@ from numpy import median from abc import ABC, abstractmethod import importlib +from typing import Any sys.path.append(os.getcwd()) from src.base.clickhouse_kafka_sender import ClickHouseKafkaSender @@ -23,7 +24,11 @@ KafkaMessageFetchException, ) from src.base.log_config import get_logger -from src.base.execution import create_pipeline_executor +from src.base.execution import ( + create_pipeline_executor, + run_thread_worker_pool, + start_pipeline_worker_replicas, +) module_name = "data_analysis.detector" logger = get_logger(module_name) @@ -183,6 +188,7 @@ def __init__( self.model = self.model_name self.checksum = detector_config["checksum"] self.threshold = detector_config["threshold"] + self.use_scaler = detector_config["use_scaler"] if "use_scaler" in detector_config.keys() else False self.consume_topic = consume_topic if produce_topics is None: @@ -332,6 +338,43 @@ def _sha256sum(self, file_path: str) -> str: return h.hexdigest() + + def get_model_download_url(self): + """ + Generate the complete URL for downloading the Domainator detection model. + + Constructs the URL using the base URL from configuration and appends the + specific model filename with checksum for verification. + + Returns: + str: Fully qualified URL where the model can be downloaded. + """ + self.model_base_url = ( + self.model_base_url[:-1] + if self.model_base_url[-1] == "/" + else self.model_base_url + ) + return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1" + + def get_scaler_download_url(self): + """ + Generate the complete URL for downloading the Domainator detection models scaler. + + Constructs the URL using the base URL from configuration and appends the + specific model filename with checksum for verification. + + Returns: + str: Fully qualified URL where the model can be downloaded. + """ + self.model_base_url = ( + self.model_base_url[:-1] + if self.model_base_url[-1] == "/" + else self.model_base_url + ) + return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2Fscaler.pickle&dl=1" + + + def _get_model(self): """ Download and validate the detection model. @@ -353,7 +396,7 @@ def _get_model(self): requests.HTTPError: If there's an error downloading the model. """ logger.info(f"Get model: {self.model_name} with checksum {self.checksum}") - scaler_download_url = self.get_scaler_download_url() + scaler_download_url = self.get_scaler_download_url() if self.use_scaler else None if not os.path.isfile(self.model_path): model_download_url = self.get_model_download_url() @@ -396,16 +439,14 @@ def detect(self) -> None: """ Process messages to detect malicious requests. - This method applies the detection model to each message in the current batch, - identifies potential threats based on the model's predictions, and collects - warnings for further processing. - - The detection uses a threshold to determine if a prediction indicates - malicious activity, and only warnings exceeding this threshold are retained. + This method has to be overwritten in the child classes for detectors. + The implementation below is tried to be generic. If a detector needs a different approach (e.g. predict on more than one message, buffer messages, etc.) + then just overwrite the method and append to `self.warnings` a warning if a given messag is to be regarded malicious. Note: - This method relies on the implementation of ``predict``of the rspective subclass + """ + logger.info("general") logger.info("Start detecting malicious requests.") for message in self.messages: y_pred = self.predict(message) @@ -418,13 +459,11 @@ def detect(self) -> None: warning = { "request": message, "probability": float(y_pred[0][1]), - # TODO: what is the use of this? not even json serializabel ? - # "model": self.model, "name": self.name, "sha256": self.checksum, } self.warnings.append(warning) - + def clear_data(self): """Clears the data in the internal data structures.""" self.messages = [] @@ -445,10 +484,10 @@ def send_warning(self) -> None: The method updates multiple database tables to maintain the pipeline's state tracking and provides detailed information about detected threats. """ - logger.info("Store alert.") row_id = generate_collisions_resistant_uuid() downstream_messages = [] if len(self.warnings) > 0: + logger.info("Begin warning computation...") overall_score = median( [warning["probability"] for warning in self.warnings] ) @@ -463,7 +502,8 @@ def send_warning(self) -> None: } if self.produce_topics: - logger.info(f"Producing alert to Kafka: {alert}") + kafka_alert = self._build_kafka_alert(alert) + logger.debug(f"Producing compact alert to Kafka: {kafka_alert}") if self.kafka_produce_handler is None: self.kafka_produce_handler = ExactlyOnceKafkaProduceHandler() @@ -471,7 +511,7 @@ def send_warning(self) -> None: for topic in self.produce_topics: self.kafka_produce_handler.produce( topic=topic, - data=json.dumps(alert), + data=json.dumps(kafka_alert), key=self.key, ) else: @@ -484,7 +524,7 @@ def send_warning(self) -> None: suspicious_batch_id=self.suspicious_batch_id, overall_score=overall_score, domain_names=json.dumps(self._get_warning_requests()), - result=json.dumps(self.warnings), + result=json.dumps(self._build_persisted_warnings()), ) ) @@ -501,11 +541,7 @@ def send_warning(self) -> None: ) ) - logline_ids = set() - for message in self.messages: - logline_ids.add(message["logline_id"]) - - for logline_id in logline_ids: + for logline_id in self._get_message_logline_ids(): self.logline_timestamps.insert( dict( logline_id=logline_id, @@ -531,11 +567,7 @@ def send_warning(self) -> None: ) ) - logline_ids = set() - for message in self.messages: - logline_ids.add(message["logline_id"]) - - for logline_id in logline_ids: + for logline_id in self._get_message_logline_ids(): self.logline_timestamps.insert( dict( logline_id=logline_id, @@ -574,7 +606,7 @@ def _send_detector_batch(self, parent_row_id, messages) -> None: if not self.downstream_detector_topics: return - logger.info( + logger.debug( f"Producing detector output to Kafka topics: {self.downstream_detector_topics}" ) data_to_send = { @@ -596,12 +628,181 @@ def _send_detector_batch(self, parent_row_id, messages) -> None: key=self.key, ) - def _get_warning_requests(self) -> list: + def _build_persisted_warnings(self) -> list[dict]: return [ - warning.get("request", warning.get("request_domain", warning)) + self._normalize_warning_for_storage(warning) for warning in self.warnings ] + def _normalize_warning_for_storage(self, warning: dict) -> dict: + raw_detector_output = { + key: value for key, value in warning.items() if key != "request" + } + request = warning.get("request") + request_messages = ( + self._flatten_messages(request) if request is not None else [] + ) + detector_name = ( + warning.get("detector_name") + or warning.get("name") + or self.name + ) + score = warning.get("score", warning.get("probability")) + domains = self._extract_warning_domains(warning) + logline_ids = self._extract_message_values(request_messages, "logline_id") + server_message_ids = self._extract_message_values( + request_messages, "server_message_id" + ) + + normalized_warning = { + "detector_name": detector_name, + "name": detector_name, + "score": score, + "probability": score, + "predicted_class": warning.get("predicted_class", ""), + "attributes": warning.get("attributes", []), + "domains": domains, + "domain_names": domains, + "logline_ids": logline_ids, + "server_message_ids": server_message_ids, + "request_count": len(request_messages), + "raw_detector_output": raw_detector_output, + "request": request, + } + + for optional_field in ("model", "sha256", "class_probabilities"): + if optional_field in warning: + normalized_warning[optional_field] = warning[optional_field] + + return normalized_warning + + @staticmethod + def _extract_message_values(messages: list, field_name: str) -> list[str]: + return sorted( + { + str(message[field_name]) + for message in messages + if isinstance(message, dict) and field_name in message + } + ) + + def _get_warning_requests(self) -> list[str]: + domains = [] + for warning in self.warnings: + domains.extend(self._extract_warning_domains(warning)) + return sorted(set(domains)) + + def _extract_warning_domains(self, warning) -> list[str]: + warnings = self.warnings + + if warning is not None: + warnings = warning + + if isinstance(warnings, str): + try: + warnings = json.loads(warnings) + except json.JSONDecodeError: + return [warnings] + + domains: list[str] = [] + stack: list[Any] = [warnings] + + while stack: + item = stack.pop() + + if isinstance(item, str): + try: + stack.append(json.loads(item)) + except json.JSONDecodeError: + domains.append(item) + + elif isinstance(item, list): + stack.extend(reversed(item)) + + elif isinstance(item, dict): + # If request/request_domain contains nested objects, descend into it + if "request" in item: + stack.append(item["request"]) + continue + + if "request_domain" in item: + stack.append(item["request_domain"]) + continue + + # Actual DNS warning object + if "domain_name" in item: + domains.append(item["domain_name"]) + + for field_name in ("domains", "domain_names"): + field_value = item.get(field_name) + if isinstance(field_value, list): + domains.extend( + domain + for domain in field_value + if isinstance(domain, str) + ) + + return sorted(set(domains)) + + def _build_kafka_alert(self, alert: dict) -> dict: + compact_alert = dict(alert) + compact_alert["result"] = [ + self._compact_warning_for_kafka(warning) + for warning in alert.get("result", []) + ] + return compact_alert + + def _compact_warning_for_kafka(self, warning: dict) -> dict: + compact_warning = { + key: value for key, value in warning.items() if key != "request" + } + request = warning.get("request") + if request is None: + return compact_warning + + request_messages = self._flatten_messages(request) + compact_warning["request_count"] = len(request_messages) + compact_warning["domain_names"] = sorted( + { + message["domain_name"] + for message in request_messages + if isinstance(message, dict) and "domain_name" in message + } + )[:50] + compact_warning["logline_ids"] = sorted( + { + message["logline_id"] + for message in request_messages + if isinstance(message, dict) and "logline_id" in message + } + )[:100] + compact_warning["server_message_ids"] = sorted( + { + message["server_message_id"] + for message in request_messages + if isinstance(message, dict) and "server_message_id" in message + } + )[:100] + return compact_warning + + def _flatten_messages(self, messages) -> list: + flattened_messages = [] + pending_messages = [messages] + while pending_messages: + message = pending_messages.pop() + if isinstance(message, list): + pending_messages.extend(message) + else: + flattened_messages.append(message) + return flattened_messages + + def _get_message_logline_ids(self) -> set: + return { + message["logline_id"] + for message in self._flatten_messages(self.messages) + if isinstance(message, dict) and "logline_id" in message + } + def _get_downstream_messages(self) -> list: messages = [ warning["request"] for warning in self.warnings if "request" in warning @@ -635,6 +836,7 @@ def bootstrap_detector_instance(self): self.detect() logger.debug("Send warnings") self.send_warning() + self.kafka_consume_handler.commit() except KafkaMessageFetchException as e: # pragma: no cover logger.debug(e) except IOError as e: @@ -663,6 +865,54 @@ async def start(self): # pragma: no cover executor.shutdown(wait=False, cancel_futures=True) +def build_detector_worker( + detector_config, + consume_topic, + produce_topics, + downstream_detector_topics, + worker_id=None, +): + class_name = detector_config["detector_class_name"] + plugin_module_name = f"{PLUGIN_PATH}.{detector_config['detector_module_name']}" + plugin_module = importlib.import_module(plugin_module_name) + detector_class = getattr(plugin_module, class_name) + worker = detector_class( + detector_config=detector_config, + consume_topic=consume_topic, + produce_topics=produce_topics, + downstream_detector_topics=downstream_detector_topics, + ) + worker.worker_id = worker_id + return worker + + +def run_detector_worker_process( + process_index, + threads_per_process, + detector_config, + consume_topic, + produce_topics, + downstream_detector_topics, +): + def worker_factory(worker_id): + return build_detector_worker( + detector_config=detector_config, + consume_topic=consume_topic, + produce_topics=produce_topics, + downstream_detector_topics=downstream_detector_topics, + worker_id=worker_id, + ) + + run_thread_worker_pool( + worker_factory=worker_factory, + target_name="bootstrap_detector_instance", + module_name=module_name, + instance_name=detector_config["name"], + process_index=process_index, + threads_per_process=threads_per_process, + ) + + async def main(): # pragma: no cover """ Initialize and start all detector instances defined in the configuration. @@ -682,23 +932,46 @@ async def main(): # pragma: no cover produce_topics = build_alerter_topics(detector_config) downstream_detector_topics = build_downstream_detector_topics(detector_config) logger.info( - "Detector %s configured with alerter topics %s and downstream detector topics %s", + "Detector %s configured with consume topic %s, alerter topics %s and downstream detector topics %s", detector_config["name"], + consume_topic, produce_topics, downstream_detector_topics, ) - class_name = detector_config["detector_class_name"] - module_name = f"{PLUGIN_PATH}.{detector_config['detector_module_name']}" - module = importlib.import_module(module_name) - DetectorClass = getattr(module, class_name) - detector = DetectorClass( + def worker_factory( + worker_id, detector_config=detector_config, consume_topic=consume_topic, produce_topics=produce_topics, downstream_detector_topics=downstream_detector_topics, + ): + return build_detector_worker( + detector_config=detector_config, + consume_topic=consume_topic, + produce_topics=produce_topics, + downstream_detector_topics=downstream_detector_topics, + worker_id=worker_id, + ) + + tasks.append( + asyncio.create_task( + start_pipeline_worker_replicas( + config=config, + module_name=module_name, + instance_name=detector_config["name"], + worker_factory=worker_factory, + target_name="bootstrap_detector_instance", + process_entrypoint=run_detector_worker_process, + process_args=( + detector_config, + consume_topic, + produce_topics, + downstream_detector_topics, + ), + ) + ) ) - tasks.append(asyncio.create_task(detector.start())) await asyncio.gather(*tasks) diff --git a/src/detector/plugins/dga_detector.py b/src/detector/plugins/dga_detector.py index cbb3e076..c3cebfcd 100644 --- a/src/detector/plugins/dga_detector.py +++ b/src/detector/plugins/dga_detector.py @@ -42,35 +42,6 @@ def __init__( detector_config, consume_topic, produce_topics, downstream_detector_topics ) - def get_model_download_url(self): - """ - Generate the complete URL for downloading the DGA detection model. - - Constructs the URL using the base URL from configuration and appends the - specific model filename with checksum for verification. - - Returns: - str: Fully qualified URL where the model can be downloaded. - """ - self.model_base_url = ( - self.model_base_url[:-1] - if self.model_base_url[-1] == "/" - else self.model_base_url - ) - return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1" - - def get_scaler_download_url(self): - """ - Generate the complete URL for downloading the DGA detection models scaler. - - Constructs the URL using the base URL from configuration and appends the - specific model filename with checksum for verification. - - Returns: - str: Fully qualified URL where the model can be downloaded. - """ - return None - def predict(self, message): """ Process a message and predict if the domain is likely generated by a DGA. @@ -180,3 +151,31 @@ def calculate_entropy(s: str) -> float: logger.debug("Finished data transformation") return all_features.reshape(1, -1) + + + def detect(self) -> None: + """ + Process messages to detect malicious requests. + + This method applies the detection model to each message in the current batch, + identifies potential threats based on the model's predictions, and collects + warnings for further processing. + + The detection uses a threshold to determine if a prediction indicates + malicious activity, and only warnings exceeding this threshold are retained. + + Note: + This method relies on the implementation of ``predict``of the rspective subclass + """ + for message in self.messages: + y_pred = self.predict(message) + logger.info(f"Prediction: {y_pred}") + if np.argmax(y_pred, axis=1) == 1 and y_pred[0][1] > self.threshold: + logger.debug("Append malicious request to warning.") + warning = { + "request": message, + "probability": float(y_pred[0][1]), + "name": self.name, + "sha256": self.checksum, + } + self.warnings.append(warning) \ No newline at end of file diff --git a/src/detector/plugins/domainator_attributor.py b/src/detector/plugins/domainator_attributor.py index 08ef6277..87f097fb 100644 --- a/src/detector/plugins/domainator_attributor.py +++ b/src/detector/plugins/domainator_attributor.py @@ -8,12 +8,14 @@ from src.base.log_config import get_logger from src.detector.plugins.domainator_utils import ( strip_domain, - get_domainator_features + get_domainator_features, ) module_name = "data_analysis.detector" logger = get_logger(module_name) +LEGITIMATE_ATTRIBUTE_LABELS = {"benign", "legit", "legitimate", "legitimate_legitimate"} + class DomainatorAttributor(DetectorBase): """ @@ -50,46 +52,12 @@ def __init__( """ self.model_base_url = detector_config["base_url"] self.message_queues = defaultdict(list) - + super().__init__( detector_config, consume_topic, produce_topics, downstream_detector_topics ) - - self.labels = self.model.classes_ - - def get_model_download_url(self): - """ - Generate the complete URL for downloading the Domainator detection model. - - Constructs the URL using the base URL from configuration and appends the - specific model filename with checksum for verification. - - Returns: - str: Fully qualified URL where the model can be downloaded. - """ - self.model_base_url = ( - self.model_base_url[:-1] - if self.model_base_url[-1] == "/" - else self.model_base_url - ) - return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1" - - def get_scaler_download_url(self): - """ - Generate the complete URL for downloading the Domainator identification models scaler. - - Constructs the URL using the base URL from configuration and appends the - specific model filename with checksum for verification. - Returns: - str: Fully qualified URL where the model can be downloaded. - """ - self.model_base_url = ( - self.model_base_url[:-1] - if self.model_base_url[-1] == "/" - else self.model_base_url - ) - return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2Fscaler.pickle&dl=1" + self.labels = self.model.classes_ def predict(self, messages): """ @@ -109,7 +77,7 @@ def predict(self, messages): """ queries = [message["domain_name"] for message in messages] - y_pred = self.model.predict_proba(get_domainator_features(queries)) + y_pred = self.model.predict_proba(get_domainator_features(queries)) return y_pred def detect(self): @@ -131,18 +99,24 @@ def detect(self): y_pred = self.predict(self.message_queues[message_domain]) logger.info(f"Prediction: {y_pred}") + winning_index = int(np.argmax(y_pred, axis=1)[0]) + winning_label = self.labels[winning_index] + winning_probability = float(y_pred[0][winning_index]) y_pred_labelled = [ {"attribute": label, "probability": float(score)} for label, score in zip(self.labels, y_pred[0]) if score >= self.threshold ] - logger.info(f"Prediction with labels: {y_pred_labelled}") + logger.debug(f"Prediction with labels: {y_pred_labelled}") - if np.argmax(y_pred, axis=1) == 1 and len(y_pred_labelled) > 0: - logger.info("Append malicious request domain to warning.") + is_legitimate = winning_label in LEGITIMATE_ATTRIBUTE_LABELS + if not is_legitimate and winning_probability >= self.threshold: + logger.debug("Append malicious request domain to warning.") warning = { "request": self.message_queues[message_domain], - "probability": y_pred_labelled, + "probability": winning_probability, + "predicted_class": winning_label, + "attributes": y_pred_labelled, "name": self.name, "sha256": self.checksum, } @@ -150,4 +124,3 @@ def detect(self): if len(self.message_queues[message_domain]) >= 10: del self.message_queues[message_domain][0] - diff --git a/src/detector/plugins/domainator_detector.py b/src/detector/plugins/domainator_detector.py index 6810049b..b9c4d829 100644 --- a/src/detector/plugins/domainator_detector.py +++ b/src/detector/plugins/domainator_detector.py @@ -8,7 +8,7 @@ from src.base.log_config import get_logger from src.detector.plugins.domainator_utils import ( strip_domain, - get_domainator_features + get_domainator_features, ) module_name = "data_analysis.detector" @@ -53,40 +53,6 @@ def __init__( detector_config, consume_topic, produce_topics, downstream_detector_topics ) - def get_model_download_url(self): - """ - Generate the complete URL for downloading the Domainator detection model. - - Constructs the URL using the base URL from configuration and appends the - specific model filename with checksum for verification. - - Returns: - str: Fully qualified URL where the model can be downloaded. - """ - self.model_base_url = ( - self.model_base_url[:-1] - if self.model_base_url[-1] == "/" - else self.model_base_url - ) - return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2F{self.model_name}.pickle&dl=1" - - def get_scaler_download_url(self): - """ - Generate the complete URL for downloading the Domainator detection models scaler. - - Constructs the URL using the base URL from configuration and appends the - specific model filename with checksum for verification. - - Returns: - str: Fully qualified URL where the model can be downloaded. - """ - self.model_base_url = ( - self.model_base_url[:-1] - if self.model_base_url[-1] == "/" - else self.model_base_url - ) - return f"{self.model_base_url}/files/?p=%2F{self.model_name}%2F{self.checksum}%2Fscaler.pickle&dl=1" - def predict(self, messages): """ Process a window of messages and predict if the domain is likely to be used @@ -129,4 +95,3 @@ def detect(self): if len(self.message_queues[message_domain]) >= 10: del self.message_queues[message_domain][0] - diff --git a/src/detector/plugins/domainator_utils.py b/src/detector/plugins/domainator_utils.py index ba51d5f9..80671e70 100644 --- a/src/detector/plugins/domainator_utils.py +++ b/src/detector/plugins/domainator_utils.py @@ -1,8 +1,24 @@ import numpy as np import itertools +import pandas as pd import pylcs import Levenshtein +from src.base.log_config import get_logger + +module_name = "data_analysis.detector" +logger = get_logger(module_name) + +DOMAINATOR_FEATURE_COLUMNS = [ + "levenshtein", + "jaro", + "jaro_reversed", + "jaro_winkler", + "jaro_winkler_reversed", + "lcs_sequence", + "lcs_string", +] + def strip_domain(query: str): """Extract the domain name from the message for the window grouping @@ -24,7 +40,7 @@ def strip_domain(query: str): return domain -def get_domainator_features(queries: list) -> np.ndarray: +def get_domainator_features(queries: list) -> pd.DataFrame: """Extracts feature vector from domain name for ML model inference. Computes various statistical and linguistic features from the domain name @@ -35,23 +51,13 @@ def get_domainator_features(queries: list) -> np.ndarray: queries (list): List of query strings to extract features from. Returns: - numpy.ndarray: Feature vector ready for ML model prediction. + pandas.DataFrame: Feature vector ready for ML model prediction. """ queries = [query.strip(".") for query in queries] subdomains = [".".join(domain.split(".")[:-2]) for domain in queries] - # Values can be put directly into an array, as the return converts them anyway, - # but this slightly improves readability - metrics = { - "levenshtein": [], - "jaro": [], - "rev_jaro": [], - "jaro_winkler": [], - "rev_jaro_wink": [], - "lcs_seq": [], - "lcs_str": [], - } + metrics = {column: [] for column in DOMAINATOR_FEATURE_COLUMNS} # if subdomains: cartesian = list(itertools.combinations(subdomains, 2)) @@ -62,19 +68,19 @@ def get_domainator_features(queries: list) -> np.ndarray: metrics["jaro"] = np.mean( [Levenshtein.jaro(product[0], product[1]) for product in cartesian] ) - metrics["jaro_winkler"] = np.mean( + metrics["jaro_reversed"] = np.mean( [ - Levenshtein.jaro_winkler(product[0], product[1], prefix_weight=0.2) + Levenshtein.jaro(product[0][::-1], product[1][::-1]) for product in cartesian ] ) - metrics["rev_jaro"] = np.mean( + metrics["jaro_winkler"] = np.mean( [ - Levenshtein.jaro(product[0][::-1], product[1][::-1]) + Levenshtein.jaro_winkler(product[0], product[1], prefix_weight=0.2) for product in cartesian ] ) - metrics["rev_jaro_wink"] = np.mean( + metrics["jaro_winkler_reversed"] = np.mean( [ Levenshtein.jaro_winkler( product[0][::-1], product[1][::-1], prefix_weight=0.2 @@ -83,7 +89,7 @@ def get_domainator_features(queries: list) -> np.ndarray: ] ) - metrics["lcs_seq"] = np.mean( + metrics["lcs_sequence"] = np.mean( [ ( pylcs.lcs_sequence_length(product[0], product[1]) @@ -94,7 +100,7 @@ def get_domainator_features(queries: list) -> np.ndarray: for product in cartesian ] ) - metrics["lcs_str"] = np.mean( + metrics["lcs_string"] = np.mean( [ ( pylcs.lcs_string_length(product[0], product[1]) @@ -106,4 +112,8 @@ def get_domainator_features(queries: list) -> np.ndarray: ] ) - return np.fromiter(metrics.values(), dtype=float).reshape(1, -1) \ No newline at end of file + return pd.DataFrame( + [[metrics[column] for column in DOMAINATOR_FEATURE_COLUMNS]], + columns=DOMAINATOR_FEATURE_COLUMNS, + ) + diff --git a/src/inspector/inspector.py b/src/inspector/inspector.py index 9386c75e..c4d95523 100644 --- a/src/inspector/inspector.py +++ b/src/inspector/inspector.py @@ -24,7 +24,11 @@ KafkaMessageFetchException, ) from src.base.log_config import get_logger -from src.base.execution import create_pipeline_executor +from src.base.execution import ( + create_pipeline_executor, + run_thread_worker_pool, + start_pipeline_worker_replicas, +) module_name = "data_inspection.inspector" logger = get_logger(module_name) @@ -232,6 +236,7 @@ def send_data(self): self.suspicious_batches_to_batch.insert( dict( + timestamp=datetime.now(), suspicious_batch_id=suspicious_batch_id, batch_id=self.batch_id, ) @@ -365,6 +370,7 @@ def bootstrap_inspection_process(self): self.get_and_fill_data() self.inspect() self.send_data() + self.kafka_consume_handler.commit() except KafkaMessageFetchException as e: # pragma: no cover logger.debug(e) except IOError as e: @@ -394,6 +400,50 @@ async def start(self): # pragma: no cover executor.shutdown(wait=False, cancel_futures=True) +def build_inspector_worker( + inspector_config, + consume_topic, + produce_topics, + worker_id=None, +): + class_name = inspector_config["inspector_class_name"] + plugin_module_name = f"{PLUGIN_PATH}.{inspector_config['inspector_module_name']}" + plugin_module = importlib.import_module(plugin_module_name) + inspector_class = getattr(plugin_module, class_name) + worker = inspector_class( + consume_topic=consume_topic, + produce_topics=produce_topics, + config=inspector_config, + ) + worker.worker_id = worker_id + return worker + + +def run_inspector_worker_process( + process_index, + threads_per_process, + inspector_config, + consume_topic, + produce_topics, +): + def worker_factory(worker_id): + return build_inspector_worker( + inspector_config=inspector_config, + consume_topic=consume_topic, + produce_topics=produce_topics, + worker_id=worker_id, + ) + + run_thread_worker_pool( + worker_factory=worker_factory, + target_name="bootstrap_inspection_process", + module_name=module_name, + instance_name=inspector_config["name"], + process_index=process_index, + threads_per_process=threads_per_process, + ) + + async def main(): """ Entry point for the Inspector module. @@ -420,14 +470,35 @@ async def main(): and str(detector.get("consume_from", "")).strip().lower() != "detector" ] class_name = inspector["inspector_class_name"] - module_name = f"{PLUGIN_PATH}.{inspector['inspector_module_name']}" - module = importlib.import_module(module_name) - InspectorClass = getattr(module, class_name) - logger.info(f"using {class_name} and {module_name}") - inspector_instance = InspectorClass( - consume_topic=consume_topic, produce_topics=produce_topics, config=inspector + plugin_module_name = f"{PLUGIN_PATH}.{inspector['inspector_module_name']}" + logger.info(f"using {class_name} and {plugin_module_name}") + + def worker_factory( + worker_id, + inspector=inspector, + consume_topic=consume_topic, + produce_topics=produce_topics, + ): + return build_inspector_worker( + inspector_config=inspector, + consume_topic=consume_topic, + produce_topics=produce_topics, + worker_id=worker_id, + ) + + tasks.append( + asyncio.create_task( + start_pipeline_worker_replicas( + config=config, + module_name=module_name, + instance_name=inspector["name"], + worker_factory=worker_factory, + target_name="bootstrap_inspection_process", + process_entrypoint=run_inspector_worker_process, + process_args=(inspector, consume_topic, produce_topics), + ) + ) ) - tasks.append(asyncio.create_task(inspector_instance.start())) await asyncio.gather(*tasks) diff --git a/src/logcollector/batch_handler.py b/src/logcollector/batch_handler.py index 46931207..2ffdf63c 100644 --- a/src/logcollector/batch_handler.py +++ b/src/logcollector/batch_handler.py @@ -80,6 +80,7 @@ def add_message(self, key: str, logline_id: uuid.UUID, message: str) -> None: batch_id = self.batch_id.get(key) self.logline_to_batches.insert( dict( + timestamp=datetime.datetime.now(), logline_id=logline_id, batch_id=batch_id, ) @@ -105,6 +106,7 @@ def add_message(self, key: str, logline_id: uuid.UUID, message: str) -> None: self.logline_to_batches.insert( dict( + timestamp=datetime.datetime.now(), logline_id=logline_id, batch_id=new_batch_id, ) diff --git a/src/logcollector/collector.py b/src/logcollector/collector.py index 51f4e135..7dd1d738 100644 --- a/src/logcollector/collector.py +++ b/src/logcollector/collector.py @@ -11,7 +11,11 @@ from src.base.kafka_handler import ExactlyOnceKafkaConsumeHandler from src.base.logline_handler import LoglineHandler from src.base import utils -from src.base.execution import create_pipeline_executor +from src.base.execution import ( + create_pipeline_executor, + run_thread_worker_pool, + start_pipeline_worker_replicas, +) from src.logcollector.batch_handler import BufferedBatchSender from src.base.log_config import get_logger from collections import defaultdict @@ -74,6 +78,10 @@ def __init__( self.failed_protocol_loglines = ClickHouseKafkaSender("failed_loglines") self.protocol_loglines = ClickHouseKafkaSender("loglines") self.logline_timestamps = ClickHouseKafkaSender("logline_timestamps") + self.server_log_to_logline = ClickHouseKafkaSender("server_log_to_logline") + self.server_log_terminal_events = ClickHouseKafkaSender( + "server_log_terminal_events" + ) self.fill_levels = ClickHouseKafkaSender("fill_levels") self.fill_levels.insert( @@ -121,9 +129,15 @@ def fetch(self) -> None: while True: key, value, topic = self.kafka_consume_handler.consume() logger.debug(f"From Kafka: '{value}'") - self.send(datetime.datetime.now(), value) - - def send(self, timestamp_in: datetime.datetime, message: str) -> None: + self.send(datetime.datetime.now(), value, server_message_id=key) + self.kafka_consume_handler.commit() + + def send( + self, + timestamp_in: datetime.datetime, + message: str, + server_message_id: str | uuid.UUID | None = None, + ) -> None: """Processes and sends a log line to the batch handler after validation. This method: @@ -135,20 +149,33 @@ def send(self, timestamp_in: datetime.datetime, message: str) -> None: Args: timestamp_in (datetime.datetime): Timestamp when the log line entered the pipeline message (str): Raw log line message in JSON format + server_message_id (str | uuid.UUID | None): Optional LogServer message id + received as the Kafka key. """ + server_message_uuid = self._parse_server_message_id(server_message_id) try: fields = self.logline_handler.validate_logline_and_get_fields_as_json( message ) except ValueError: + timestamp_failed = datetime.datetime.now() self.failed_protocol_loglines.insert( dict( message_text=message, timestamp_in=timestamp_in, - timestamp_failed=datetime.datetime.now(), + timestamp_failed=timestamp_failed, reason_for_failure=None, # TODO: Add actual reason ) ) + if server_message_uuid: + self.server_log_terminal_events.insert( + dict( + message_id=server_message_uuid, + stage=module_name, + status="failed", + timestamp=timestamp_failed, + ) + ) return additional_fields = fields.copy() for field in REQUIRED_FIELDS: @@ -164,6 +191,14 @@ def send(self, timestamp_in: datetime.datetime, message: str) -> None: additional_fields=json.dumps(additional_fields), ) ) + if server_message_uuid: + self.server_log_to_logline.insert( + dict( + timestamp=datetime.datetime.now(), + message_id=server_message_uuid, + logline_id=logline_id, + ) + ) self.logline_timestamps.insert( dict( logline_id=logline_id, @@ -175,6 +210,8 @@ def send(self, timestamp_in: datetime.datetime, message: str) -> None: ) message_fields = fields.copy() message_fields["logline_id"] = str(logline_id) + if server_message_uuid: + message_fields["server_message_id"] = str(server_message_uuid) self.logline_timestamps.insert( dict( @@ -188,6 +225,22 @@ def send(self, timestamp_in: datetime.datetime, message: str) -> None: self.batch_handler.add_message(subnet_id, json.dumps(message_fields)) logger.debug(f"Sent: {message}") + @staticmethod + def _parse_server_message_id( + server_message_id: str | uuid.UUID | None, + ) -> uuid.UUID | None: + if not server_message_id: + return None + if isinstance(server_message_id, uuid.UUID): + return server_message_id + try: + return uuid.UUID(str(server_message_id)) + except (TypeError, ValueError): + logger.warning( + "Ignoring non-UUID LogServer message id '%s'.", server_message_id + ) + return None + def _get_subnet_id( self, address: ipaddress.IPv4Address | ipaddress.IPv6Address ) -> str: @@ -221,6 +274,54 @@ def _get_subnet_id( return f"{normalized_ip_address}_{prefix_length}" +def build_logcollector_worker( + collector_name, + protocol, + consume_topic, + produce_topics, + validation_config, + worker_id=None, +): + worker = LogCollector( + collector_name=collector_name, + protocol=protocol, + consume_topic=consume_topic, + produce_topics=produce_topics, + validation_config=validation_config, + ) + worker.worker_id = worker_id + return worker + + +def run_logcollector_worker_process( + process_index, + threads_per_process, + collector_name, + protocol, + consume_topic, + produce_topics, + validation_config, +): + def worker_factory(worker_id): + return build_logcollector_worker( + collector_name=collector_name, + protocol=protocol, + consume_topic=consume_topic, + produce_topics=produce_topics, + validation_config=validation_config, + worker_id=worker_id, + ) + + run_thread_worker_pool( + worker_factory=worker_factory, + target_name="fetch", + module_name=module_name, + instance_name=collector_name, + process_index=process_index, + threads_per_process=threads_per_process, + ) + + async def main() -> None: """Creates and starts all configured LogCollector instances. @@ -242,14 +343,43 @@ async def main() -> None: if collector["name"] == prefilter["collector_name"] ] validation_config = collector["required_log_information"] - collector_instance = LogCollector( - collector_name=collector["name"], + + def worker_factory( + worker_id, + collector=collector, protocol=protocol, consume_topic=consume_topic, produce_topics=produce_topics, validation_config=validation_config, + ): + return build_logcollector_worker( + collector_name=collector["name"], + protocol=protocol, + consume_topic=consume_topic, + produce_topics=produce_topics, + validation_config=validation_config, + worker_id=worker_id, + ) + + tasks.append( + asyncio.create_task( + start_pipeline_worker_replicas( + config=config, + module_name=module_name, + instance_name=collector["name"], + worker_factory=worker_factory, + target_name="fetch", + process_entrypoint=run_logcollector_worker_process, + process_args=( + collector["name"], + protocol, + consume_topic, + produce_topics, + validation_config, + ), + ) + ) ) - tasks.append(asyncio.create_task(collector_instance.start())) await asyncio.gather(*tasks) diff --git a/src/logserver/server.py b/src/logserver/server.py index 01bb73d9..8a48a2f6 100644 --- a/src/logserver/server.py +++ b/src/logserver/server.py @@ -13,7 +13,11 @@ ) from src.base.clickhouse_kafka_sender import ClickHouseKafkaSender from src.base.utils import setup_config, get_zeek_sensor_topic_base_names -from src.base.execution import create_pipeline_executor +from src.base.execution import ( + create_pipeline_executor, + run_thread_worker_pool, + start_pipeline_worker_replicas, +) from src.base.log_config import get_logger module_name = "log_storage.logserver" @@ -82,7 +86,11 @@ def send(self, message_id: uuid.UUID, message: str) -> None: message (str): Message to be sent. """ for topic in self.produce_topics: - self.kafka_produce_handler.produce(topic=topic, data=message) + self.kafka_produce_handler.produce( + topic=topic, + data=message, + key=str(message_id), + ) logger.debug(f"Sent: '{message}' to topic {topic}") self.server_logs_timestamps.insert( @@ -114,6 +122,33 @@ def fetch_from_kafka(self) -> None: ) self.send(message_id, value) + self.kafka_consume_handler.commit() + + +def build_logserver_worker(consume_topic, produce_topics, worker_id=None): + worker = LogServer(consume_topic=consume_topic, produce_topics=produce_topics) + worker.worker_id = worker_id + return worker + + +def run_logserver_worker_process( + process_index, threads_per_process, consume_topic, produce_topics +): + def worker_factory(worker_id): + return build_logserver_worker( + consume_topic=consume_topic, + produce_topics=produce_topics, + worker_id=worker_id, + ) + + run_thread_worker_pool( + worker_factory=worker_factory, + target_name="fetch_from_kafka", + module_name=module_name, + instance_name=consume_topic, + process_index=process_index, + threads_per_process=threads_per_process, + ) async def main() -> None: @@ -128,10 +163,31 @@ async def main() -> None: for collector in COLLECTORS if collector["protocol_base"] == protocol ] - server_instance = LogServer( - consume_topic=consume_topic, produce_topics=produce_topics + + def worker_factory( + worker_id, + consume_topic=consume_topic, + produce_topics=produce_topics, + ): + return build_logserver_worker( + consume_topic=consume_topic, + produce_topics=produce_topics, + worker_id=worker_id, + ) + + tasks.append( + asyncio.create_task( + start_pipeline_worker_replicas( + config=config, + module_name=module_name, + instance_name=consume_topic, + worker_factory=worker_factory, + target_name="fetch_from_kafka", + process_entrypoint=run_logserver_worker_process, + process_args=(consume_topic, produce_topics), + ) + ) ) - tasks.append(asyncio.create_task(server_instance.start())) await asyncio.gather(*tasks) diff --git a/src/monitoring/clickhouse_batch_sender.py b/src/monitoring/clickhouse_batch_sender.py index 027cbc7e..4df8c85f 100644 --- a/src/monitoring/clickhouse_batch_sender.py +++ b/src/monitoring/clickhouse_batch_sender.py @@ -98,6 +98,22 @@ def __init__(self): "event_timestamp": datetime.datetime, }, ), + "server_log_to_logline": Table( + "server_log_to_logline", + { + "message_id": uuid.UUID, + "logline_id": uuid.UUID, + }, + ), + "server_log_terminal_events": Table( + "server_log_terminal_events", + { + "message_id": uuid.UUID, + "stage": str, + "status": str, + "timestamp": datetime.datetime, + }, + ), "failed_loglines": Table( "failed_loglines", { diff --git a/src/monitoring/monitoring_agent.py b/src/monitoring/monitoring_agent.py index e2281b2d..948a9e31 100644 --- a/src/monitoring/monitoring_agent.py +++ b/src/monitoring/monitoring_agent.py @@ -73,6 +73,8 @@ def __init__(self): self.table_names = [ "server_logs", "server_logs_timestamps", + "server_log_to_logline", + "server_log_terminal_events", "failed_loglines", "logline_to_batches", "loglines", diff --git a/src/prefilter/prefilter.py b/src/prefilter/prefilter.py index 60ebc28f..4abc2fd1 100644 --- a/src/prefilter/prefilter.py +++ b/src/prefilter/prefilter.py @@ -16,7 +16,11 @@ KafkaMessageFetchException, ) from src.base.log_config import get_logger -from src.base.execution import create_pipeline_executor +from src.base.execution import ( + create_pipeline_executor, + run_thread_worker_pool, + start_pipeline_worker_replicas, +) from src.base.utils import ( setup_config, get_zeek_sensor_topic_base_names, @@ -306,6 +310,7 @@ def bootstrap_prefiltering_process(self): self.get_and_fill_data() self.check_data_relevance_using_rules() self.send_filtered_data() + self.kafka_consume_handler.commit() async def start(self): # pragma: no cover """Starts the ``Prefilter`` processing loop. @@ -333,6 +338,54 @@ async def start(self): # pragma: no cover self.clear_data() +def build_prefilter_worker( + validation_config, + consume_topic, + produce_topics, + relevance_function_name, + prefilter_name, + worker_id=None, +): + worker = Prefilter( + validation_config=validation_config, + consume_topic=consume_topic, + produce_topics=produce_topics, + relevance_function_name=relevance_function_name, + ) + worker.name = prefilter_name + worker.worker_id = worker_id + return worker + + +def run_prefilter_worker_process( + process_index, + threads_per_process, + validation_config, + consume_topic, + produce_topics, + relevance_function_name, + prefilter_name, +): + def worker_factory(worker_id): + return build_prefilter_worker( + validation_config=validation_config, + consume_topic=consume_topic, + produce_topics=produce_topics, + relevance_function_name=relevance_function_name, + prefilter_name=prefilter_name, + worker_id=worker_id, + ) + + run_thread_worker_pool( + worker_factory=worker_factory, + target_name="bootstrap_prefiltering_process", + module_name=module_name, + instance_name=prefilter_name, + process_index=process_index, + threads_per_process=threads_per_process, + ) + + async def main() -> None: """Creates and starts all configured Prefilter instances. @@ -361,14 +414,43 @@ async def main() -> None: for inspector in INSPECTORS if prefilter["name"] == inspector["prefilter_name"] ] - prefilter_instance = Prefilter( + + def worker_factory( + worker_id, validation_config=validation_config, consume_topic=consume_topic, produce_topics=produce_topics, relevance_function_name=relevance_function_name, + prefilter_name=prefilter["name"], + ): + return build_prefilter_worker( + validation_config=validation_config, + consume_topic=consume_topic, + produce_topics=produce_topics, + relevance_function_name=relevance_function_name, + prefilter_name=prefilter_name, + worker_id=worker_id, + ) + + tasks.append( + asyncio.create_task( + start_pipeline_worker_replicas( + config=config, + module_name=module_name, + instance_name=prefilter["name"], + worker_factory=worker_factory, + target_name="bootstrap_prefiltering_process", + process_entrypoint=run_prefilter_worker_process, + process_args=( + validation_config, + consume_topic, + produce_topics, + relevance_function_name, + prefilter["name"], + ), + ) + ) ) - prefilter_instance.name = prefilter["name"] - tasks.append(asyncio.create_task(prefilter_instance.start())) await asyncio.gather(*tasks) diff --git a/tests/alerter/__init__.py b/tests/alerter/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/tests/alerter/__init__.py @@ -0,0 +1 @@ + diff --git a/tests/alerter/test_alerter.py b/tests/alerter/test_alerter.py new file mode 100644 index 00000000..01bfd01d --- /dev/null +++ b/tests/alerter/test_alerter.py @@ -0,0 +1,84 @@ +import datetime +import tempfile +import unittest +from pathlib import Path +from unittest.mock import MagicMock, patch + +from src.alerter.alerter import AlerterBase + + +class TestAlerter(AlerterBase): + def process_alert(self): + pass + + +class TestAlerterLogRotation(unittest.TestCase): + def _build_alerter(self, log_file_path: str, retention_days=7): + alerting_config = { + "log_to_file": True, + "log_file_path": log_file_path, + "log_rotation": { + "enabled": True, + "retention_days": retention_days, + }, + "log_to_kafka": False, + } + patches = [ + patch("src.alerter.alerter.ALERTING_CONFIG", alerting_config), + patch("src.alerter.alerter.ExactlyOnceKafkaConsumeHandler"), + patch("src.alerter.alerter.ClickHouseKafkaSender"), + ] + for active_patch in patches: + active_patch.start() + self.addCleanup(active_patch.stop) + + return TestAlerter({"name": "test"}, "alerts") + + def test_active_log_file_path_uses_current_day(self): + alerter = self._build_alerter("/tmp/alerts.txt") + + active_path = alerter._get_active_log_file_path( + datetime.datetime(2026, 6, 25, 12, 30) + ) + + self.assertEqual("/tmp/alerts-2026-06-25.txt", active_path) + + def test_cleanup_rotated_logs_keeps_configured_number_of_days(self): + with tempfile.TemporaryDirectory() as temp_dir: + base_log_path = Path(temp_dir) / "alerts.txt" + for filename in ( + "alerts-2026-06-22.txt", + "alerts-2026-06-23.txt", + "alerts-2026-06-24.txt", + "alerts-2026-06-25.txt", + "alerts-other.txt", + ): + (Path(temp_dir) / filename).write_text("{}\n") + + alerter = self._build_alerter(str(base_log_path), retention_days=3) + alerter.name = "test" + + alerter._cleanup_rotated_logs(today=datetime.date(2026, 6, 25)) + + self.assertFalse((Path(temp_dir) / "alerts-2026-06-22.txt").exists()) + self.assertTrue((Path(temp_dir) / "alerts-2026-06-23.txt").exists()) + self.assertTrue((Path(temp_dir) / "alerts-2026-06-24.txt").exists()) + self.assertTrue((Path(temp_dir) / "alerts-2026-06-25.txt").exists()) + self.assertTrue((Path(temp_dir) / "alerts-other.txt").exists()) + + def test_log_to_file_writes_to_rotated_file(self): + with tempfile.TemporaryDirectory() as temp_dir: + alerter = self._build_alerter(str(Path(temp_dir) / "alerts.txt")) + alerter.alert_data = {"src_ip": "192.0.2.1"} + alerter._cleanup_rotated_logs = MagicMock() + + with patch( + "src.alerter.alerter.datetime.datetime", + wraps=datetime.datetime, + ) as mock_datetime: + mock_datetime.now.return_value = datetime.datetime(2026, 6, 25, 12, 30) + alerter._log_to_file_action() + + log_file = Path(temp_dir) / "alerts-2026-06-25.txt" + self.assertEqual('{"src_ip": "192.0.2.1"}\n', log_file.read_text()) + diff --git a/tests/detector/test_detector.py b/tests/detector/test_detector.py index 94f8f89b..5eb7acc5 100644 --- a/tests/detector/test_detector.py +++ b/tests/detector/test_detector.py @@ -379,6 +379,128 @@ def test_save_warning( self.assertEqual("test-detector", alert["detector_name"]) self.assertEqual("192.168.1.1", alert["src_ip"]) self.assertEqual(2, len(alert["result"])) + self.assertNotIn("request", alert["result"][0]) + + @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler") + @patch("src.detector.detector.ClickHouseKafkaSender") + @patch("src.detector.detector.DetectorBase._get_model") + def test_save_warning_with_nested_message_windows( + self, mock_get_model, mock_clickhouse, mock_kafka_consume_handler + ): + mock_get_model.return_value = (MagicMock(), MagicMock()) + mock_kafka_consume_handler.return_value = MagicMock() + + sut = TestDetector( + consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG + ) + request_window = [ + {"logline_id": "test_id_1", "domain_name": "one.example"}, + {"logline_id": "test_id_2", "domain_name": "two.example"}, + ] + sut.warnings = [ + { + "request": request_window, + "probability": 0.8765, + "model": "rf", + "sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", + } + ] + sut.parent_row_id = f"{uuid.uuid4()}-{uuid.uuid4()}" + sut.key = "192.168.1.1" + sut.suspicious_batch_id = uuid.uuid4() + sut.messages = [request_window] + sut.kafka_produce_handler = MagicMock() + + sut.send_warning() + + sut.kafka_produce_handler.produce.assert_called_once() + alert = json.loads(sut.kafka_produce_handler.produce.call_args.kwargs["data"]) + self.assertEqual(2, alert["result"][0]["request_count"]) + self.assertEqual( + ["one.example", "two.example"], + alert["result"][0]["domain_names"], + ) + self.assertNotIn("request", alert["result"][0]) + + @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler") + @patch("src.detector.detector.ClickHouseKafkaSender") + @patch("src.detector.detector.DetectorBase._get_model") + def test_normalize_warning_for_storage_keeps_detector_output_separate( + self, mock_get_model, mock_clickhouse, mock_kafka_consume_handler + ): + mock_get_model.return_value = (MagicMock(), MagicMock()) + mock_kafka_consume_handler.return_value = MagicMock() + + sut = TestDetector( + consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG + ) + warning = { + "request": [ + { + "logline_id": "logline-1", + "server_message_id": "server-message-1", + "domain_name": "one.example", + }, + { + "logline_id": "logline-2", + "server_message_id": "server-message-2", + "domain_name": "two.example", + }, + ], + "probability": 0.8765, + "predicted_class": "cobaltstrike", + "attributes": [ + {"attribute": "cobaltstrike", "probability": 0.8765}, + ], + "name": "domainator_attributor", + "sha256": "021af76b2385ddbc76f6e3ad10feb0bb081f9cf05cff2e52333e31040bbf36cc", + } + + normalized = sut._normalize_warning_for_storage(warning) + + self.assertEqual("domainator_attributor", normalized["detector_name"]) + self.assertEqual(0.8765, normalized["score"]) + self.assertEqual("cobaltstrike", normalized["predicted_class"]) + self.assertEqual(["one.example", "two.example"], normalized["domains"]) + self.assertEqual(["logline-1", "logline-2"], normalized["logline_ids"]) + self.assertEqual( + ["server-message-1", "server-message-2"], + normalized["server_message_ids"], + ) + self.assertEqual(warning["request"], normalized["request"]) + self.assertNotIn("request", normalized["raw_detector_output"]) + self.assertEqual( + "cobaltstrike", + normalized["raw_detector_output"]["predicted_class"], + ) + + @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler") + @patch("src.detector.detector.ClickHouseKafkaSender") + @patch("src.detector.detector.DetectorBase._get_model") + def test_normalize_warning_for_storage_handles_sparse_detector_output( + self, mock_get_model, mock_clickhouse, mock_kafka_consume_handler + ): + mock_get_model.return_value = (MagicMock(), MagicMock()) + mock_kafka_consume_handler.return_value = MagicMock() + + sut = TestDetector( + consume_topic="test_topic", detector_config=MINIMAL_DETECTOR_CONFIG + ) + normalized = sut._normalize_warning_for_storage( + { + "request_domain": "malicious.example", + "probability": 0.7, + "name": "sparse_detector", + } + ) + + self.assertEqual("sparse_detector", normalized["detector_name"]) + self.assertEqual(0.7, normalized["score"]) + self.assertEqual("", normalized["predicted_class"]) + self.assertEqual([], normalized["attributes"]) + self.assertEqual(["malicious.example"], normalized["domains"]) + self.assertEqual([], normalized["logline_ids"]) + self.assertIn("request_domain", normalized["raw_detector_output"]) @patch("src.detector.detector.ExactlyOnceKafkaConsumeHandler") @patch("src.detector.detector.ClickHouseKafkaSender") diff --git a/tests/detector/test_domainator_attributor.py b/tests/detector/test_domainator_attributor.py index da81dc29..2084f01f 100644 --- a/tests/detector/test_domainator_attributor.py +++ b/tests/detector/test_domainator_attributor.py @@ -2,6 +2,7 @@ import numpy as np import unittest from unittest.mock import MagicMock, patch, call +from pandas.testing import assert_frame_equal import os import sys @@ -11,6 +12,7 @@ from src.detector.plugins.domainator_attributor import DomainatorAttributor from src.base.data_classes.batch import Batch from src.detector.plugins.domainator_utils import ( + DOMAINATOR_FEATURE_COLUMNS, get_domainator_features ) @@ -91,6 +93,45 @@ def test_detect(self): sut.detect() self.assertNotEqual([], sut.warnings) + def test_detect_emits_for_attribution_class_other_than_index_one(self): + mock_kafka = MagicMock() + mock_ch = MagicMock() + sut = self._create_detector(mock_kafka, mock_ch) + sut.threshold = 0.5 + sut.labels = ["benign", "tool-A", "tool-B"] + for _ in range(0, 4, 1): + sut.messages.append(DEFAULT_DATA) + + with patch( + "src.detector.plugins.domainator_attributor.DomainatorAttributor.predict", + return_value=np.array([[0.01, 0.02, 0.97]]), + ): + sut.detect() + + self.assertNotEqual([], sut.warnings) + self.assertEqual(0.97, sut.warnings[0]["probability"]) + self.assertEqual( + [{"attribute": "tool-B", "probability": 0.97}], + sut.warnings[0]["attributes"], + ) + + def test_detect_warning_probability_is_numeric(self): + mock_kafka = MagicMock() + mock_ch = MagicMock() + sut = self._create_detector(mock_kafka, mock_ch) + sut.threshold = 0.5 + sut.labels = ["tool-A", "tool-B"] + for _ in range(0, 4, 1): + sut.messages.append(DEFAULT_DATA) + + with patch( + "src.detector.plugins.domainator_attributor.DomainatorAttributor.predict", + return_value=np.array([[0.01, 0.99]]), + ): + sut.detect() + + self.assertIsInstance(sut.warnings[0]["probability"], float) + def test_detect_message_list(self): mock_kafka = MagicMock() mock_ch = MagicMock() @@ -125,11 +166,37 @@ def test_predict_calls_model(self): # Verify the argument was correct called_features = detector.model.predict_proba.call_args[0][0] expected_features = get_domainator_features(["google.com", "google.com"]) - np.testing.assert_array_equal(called_features, expected_features) + assert_frame_equal(called_features, expected_features) # Verify prediction result np.testing.assert_array_equal(result, mock_prediction) + def test_predict_aligns_features_to_model_order(self): + """Test that predict reorders features to match sklearn fit-time order.""" + mock_kafka = MagicMock() + mock_ch = MagicMock() + detector = self._create_detector(mock_kafka, mock_ch) + + mock_prediction = np.array([[0.2, 0.8]]) + detector.model.predict_proba.return_value = mock_prediction + detector.model.feature_names_in_ = np.array( + list(reversed(DOMAINATOR_FEATURE_COLUMNS)) + ) + + message = [ + {"domain_name": "a.example.com"}, + {"domain_name": "b.example.com"}, + {"domain_name": "c.example.com"}, + ] + result = detector.predict(message) + + called_features = detector.model.predict_proba.call_args[0][0] + self.assertEqual( + called_features.columns.tolist(), + list(reversed(DOMAINATOR_FEATURE_COLUMNS)), + ) + np.testing.assert_array_equal(result, mock_prediction) + def test_get_features_basic_attributes(self): """Test basic label features calculation.""" mock_kafka = MagicMock() @@ -142,9 +209,9 @@ def test_get_features_basic_attributes(self): ) # Basic features: label_length, label_max, label_average - leven_dist = features[0][0] # Levenshtein distance - jaro_dist = features[0][1] # Jaro distance - lcs = features[0][6] # Longest common string + leven_dist = features.iloc[0, 0] # Levenshtein distance + jaro_dist = features.iloc[0, 1] # Jaro distance + lcs = features.iloc[0, 6] # Longest common string self.assertEqual(leven_dist, 0.75) self.assertAlmostEqual(jaro_dist, 0.833, 3) # Rounded to 3 decimal places @@ -160,23 +227,23 @@ def test_get_features_empty_domains(self): # Basic features self.assertEqual( - features[0][0], 1.0 + features.iloc[0, 0], 1.0 ) # Levenshtein distance of empty strings is 1 - self.assertEqual(features[0][1], 1.0) # Jaro distance of empty strings is 1 + self.assertEqual(features.iloc[0, 1], 1.0) # Jaro distance of empty strings is 1 self.assertEqual( - features[0][2], 1.0 + features.iloc[0, 2], 1.0 ) # Jaro distance on the reverse empty strings is 1 self.assertEqual( - features[0][3], 1.0 + features.iloc[0, 3], 1.0 ) # Jaro-Winkler distance of empty strings is 1 self.assertEqual( - features[0][4], 1.0 + features.iloc[0, 4], 1.0 ) # Jaro-Winkler distance on the reverse empty strings is 1 self.assertEqual( - features[0][5], 0.0 + features.iloc[0, 5], 0.0 ) # Longest common sequence of empty strings is 0 self.assertEqual( - features[0][6], 0.0 + features.iloc[0, 6], 0.0 ) # Longest common string of empty strings is 0 def test_get_features_single_same_character(self): @@ -189,23 +256,23 @@ def test_get_features_single_same_character(self): # Basic features self.assertEqual( - features[0][0], 1.0 + features.iloc[0, 0], 1.0 ) # Levenshtein distance of same strings is 1 - self.assertEqual(features[0][1], 1.0) # Jaro distance of same strings is 1 + self.assertEqual(features.iloc[0, 1], 1.0) # Jaro distance of same strings is 1 self.assertEqual( - features[0][2], 1.0 + features.iloc[0, 2], 1.0 ) # Jaro distance on the reverse same strings is 1 self.assertEqual( - features[0][3], 1.0 + features.iloc[0, 3], 1.0 ) # Jaro-Winkler distance of same strings is 1 self.assertEqual( - features[0][4], 1.0 + features.iloc[0, 4], 1.0 ) # Jaro-Winkler distance on the reverse same strings is 1 self.assertEqual( - features[0][5], 0.0 + features.iloc[0, 5], 0.0 ) # Longest common sequence of same strings is 0 self.assertEqual( - features[0][6], 0.0 + features.iloc[0, 6], 0.0 ) # Longest common string of same strings is 0 def test_get_features_feature_vector_shape(self): @@ -221,6 +288,7 @@ def test_get_features_feature_vector_shape(self): expected_entropy = 7 self.assertEqual(features.shape, (1, expected_entropy)) + self.assertEqual(features.columns.tolist(), DOMAINATOR_FEATURE_COLUMNS) def test_get_features_case_insensitivity(self): """Test that the statistical comparison is case-insensitive.""" @@ -237,7 +305,7 @@ def test_get_features_case_insensitivity(self): # The comparison features should be identical regardless of case np.testing.assert_array_almost_equal( - features_upper[0][0:], - features_lower[0][0:], + features_upper.to_numpy()[0], + features_lower.to_numpy()[0], decimal=5, - ) \ No newline at end of file + ) diff --git a/tests/detector/test_domainator_detector.py b/tests/detector/test_domainator_detector.py index ce92a317..4bbdffac 100644 --- a/tests/detector/test_domainator_detector.py +++ b/tests/detector/test_domainator_detector.py @@ -2,6 +2,7 @@ import numpy as np import unittest from unittest.mock import MagicMock, patch, call +from pandas.testing import assert_frame_equal import os import sys @@ -11,6 +12,7 @@ from src.detector.plugins.domainator_detector import DomainatorDetector from src.base.data_classes.batch import Batch from src.detector.plugins.domainator_utils import ( + DOMAINATOR_FEATURE_COLUMNS, get_domainator_features ) @@ -110,11 +112,37 @@ def test_predict_calls_model(self): # Verify the argument was correct called_features = detector.model.predict_proba.call_args[0][0] expected_features = get_domainator_features(["google.com", "google.com"]) - np.testing.assert_array_equal(called_features, expected_features) + assert_frame_equal(called_features, expected_features) # Verify prediction result np.testing.assert_array_equal(result, mock_prediction) + def test_predict_aligns_features_to_model_order(self): + """Test that predict reorders features to match sklearn fit-time order.""" + mock_kafka = MagicMock() + mock_ch = MagicMock() + detector = self._create_detector(mock_kafka, mock_ch) + + mock_prediction = np.array([[0.2, 0.8]]) + detector.model.predict_proba.return_value = mock_prediction + detector.model.feature_names_in_ = np.array( + list(reversed(DOMAINATOR_FEATURE_COLUMNS)) + ) + + message = [ + {"domain_name": "a.example.com"}, + {"domain_name": "b.example.com"}, + {"domain_name": "c.example.com"}, + ] + result = detector.predict(message) + + called_features = detector.model.predict_proba.call_args[0][0] + self.assertEqual( + called_features.columns.tolist(), + list(reversed(DOMAINATOR_FEATURE_COLUMNS)), + ) + np.testing.assert_array_equal(result, mock_prediction) + def test_get_features_basic_attributes(self): """Test basic label features calculation.""" mock_kafka = MagicMock() @@ -127,9 +155,9 @@ def test_get_features_basic_attributes(self): ) # Basic features: label_length, label_max, label_average - leven_dist = features[0][0] # Levenshtein distance - jaro_dist = features[0][1] # Jaro distance - lcs = features[0][6] # Longest common string + leven_dist = features.iloc[0, 0] # Levenshtein distance + jaro_dist = features.iloc[0, 1] # Jaro distance + lcs = features.iloc[0, 6] # Longest common string self.assertEqual(leven_dist, 0.75) self.assertAlmostEqual(jaro_dist, 0.833, 3) # Rounded to 3 decimal places @@ -145,23 +173,23 @@ def test_get_features_empty_domains(self): # Basic features self.assertEqual( - features[0][0], 1.0 + features.iloc[0, 0], 1.0 ) # Levenshtein distance of empty strings is 1 - self.assertEqual(features[0][1], 1.0) # Jaro distance of empty strings is 1 + self.assertEqual(features.iloc[0, 1], 1.0) # Jaro distance of empty strings is 1 self.assertEqual( - features[0][2], 1.0 + features.iloc[0, 2], 1.0 ) # Jaro distance on the reverse empty strings is 1 self.assertEqual( - features[0][3], 1.0 + features.iloc[0, 3], 1.0 ) # Jaro-Winkler distance of empty strings is 1 self.assertEqual( - features[0][4], 1.0 + features.iloc[0, 4], 1.0 ) # Jaro-Winkler distance on the reverse empty strings is 1 self.assertEqual( - features[0][5], 0.0 + features.iloc[0, 5], 0.0 ) # Longest common sequence of empty strings is 0 self.assertEqual( - features[0][6], 0.0 + features.iloc[0, 6], 0.0 ) # Longest common string of empty strings is 0 def test_get_features_single_same_character(self): @@ -174,23 +202,23 @@ def test_get_features_single_same_character(self): # Basic features self.assertEqual( - features[0][0], 1.0 + features.iloc[0, 0], 1.0 ) # Levenshtein distance of same strings is 1 - self.assertEqual(features[0][1], 1.0) # Jaro distance of same strings is 1 + self.assertEqual(features.iloc[0, 1], 1.0) # Jaro distance of same strings is 1 self.assertEqual( - features[0][2], 1.0 + features.iloc[0, 2], 1.0 ) # Jaro distance on the reverse same strings is 1 self.assertEqual( - features[0][3], 1.0 + features.iloc[0, 3], 1.0 ) # Jaro-Winkler distance of same strings is 1 self.assertEqual( - features[0][4], 1.0 + features.iloc[0, 4], 1.0 ) # Jaro-Winkler distance on the reverse same strings is 1 self.assertEqual( - features[0][5], 0.0 + features.iloc[0, 5], 0.0 ) # Longest common sequence of same strings is 0 self.assertEqual( - features[0][6], 0.0 + features.iloc[0, 6], 0.0 ) # Longest common string of same strings is 0 def test_get_features_feature_vector_shape(self): @@ -206,6 +234,7 @@ def test_get_features_feature_vector_shape(self): expected_entropy = 7 self.assertEqual(features.shape, (1, expected_entropy)) + self.assertEqual(features.columns.tolist(), DOMAINATOR_FEATURE_COLUMNS) def test_get_features_case_insensitivity(self): """Test that the statistical comparison is case-insensitive.""" @@ -222,7 +251,7 @@ def test_get_features_case_insensitivity(self): # The comparison features should be identical regardless of case np.testing.assert_array_almost_equal( - features_upper[0][0:], - features_lower[0][0:], + features_upper.to_numpy()[0], + features_lower.to_numpy()[0], decimal=5, - ) \ No newline at end of file + ) diff --git a/tests/inspector/test_inspector.py b/tests/inspector/test_inspector.py index 2a65097c..e5f288ff 100644 --- a/tests/inspector/test_inspector.py +++ b/tests/inspector/test_inspector.py @@ -643,16 +643,20 @@ def setUp(self): ] @patch("src.inspector.inspector.logger") - @patch("src.inspector.plugins.no_inspector.NoInspector") + @patch( + "src.inspector.inspector.start_pipeline_worker_replicas", + new_callable=AsyncMock, + ) @patch("asyncio.create_task") @patch("asyncio.run") async def test_main_succesful_start( - self, mock_asyncio_run, mock_asyncio_create_task, mock_inspector, mock_logger + self, + mock_asyncio_run, + mock_asyncio_create_task, + mock_start_workers, + mock_logger, ): # Arrange - mock_inspector_instance = MagicMock() - mock_inspector_instance.start = AsyncMock() - mock_inspector.return_value = mock_inspector_instance mock_asyncio_create_task.side_effect = lambda coro: coro # Act @@ -660,7 +664,7 @@ async def test_main_succesful_start( await main() # Assert - mock_inspector_instance.start.assert_called_once() + mock_start_workers.assert_awaited_once() if __name__ == "__main__": diff --git a/tests/kafka/test_exactly_once_kafka_consume_handler.py b/tests/kafka/test_exactly_once_kafka_consume_handler.py index 440ee32b..55a5c431 100644 --- a/tests/kafka/test_exactly_once_kafka_consume_handler.py +++ b/tests/kafka/test_exactly_once_kafka_consume_handler.py @@ -10,7 +10,7 @@ from src.base.data_classes.batch import Batch from src.base.kafka_handler import ExactlyOnceKafkaConsumeHandler -CONSUMER_GROUP_ID = "default_gid" +CONSUMER_GROUP_ID = "default_gid.test_topic" class TestInit(unittest.TestCase): @@ -47,6 +47,7 @@ def test_init(self, mock_consumer, mock_admin_client, mock_all_topics_created): "enable.auto.commit": False, "auto.offset.reset": "earliest", "enable.partition.eof": True, + "max.poll.interval.ms": 1800000, } sut = ExactlyOnceKafkaConsumeHandler(topics="test_topic") @@ -89,6 +90,7 @@ def test_init_fail(self, mock_consumer, mock_admin_client, mock_all_topics_creat "enable.auto.commit": False, "auto.offset.reset": "earliest", "enable.partition.eof": True, + "max.poll.interval.ms": 1800000, } with patch.object( @@ -192,9 +194,13 @@ def test_message_processing(self): except StopIteration: pass - self.sut.consumer.commit.assert_called_once() + self.sut.consumer.commit.assert_not_called() self.assertEqual((key, value, topic), result) + self.sut.commit() + + self.sut.consumer.commit.assert_called_once_with(msg) + def test_consumer_raises_keyboard_interrupt(self): self.sut.consumer.poll.side_effect = [KeyboardInterrupt] @@ -407,6 +413,36 @@ def test_consume_as_object_valid_data_with_inner_strings(self): self.assertEqual(result[0], key) self.assertIsInstance(result[1], Batch) + def test_consume_as_object_valid_data_with_nested_request_windows(self): + key = "valid_key" + batch_schema = marshmallow_dataclass.class_schema(Batch)() + value = batch_schema.dumps( + { + "batch_id": uuid.uuid4(), + "batch_tree_row_id": uuid.uuid4(), + "begin_timestamp": datetime.datetime.now(), + "end_timestamp": datetime.datetime.now(), + "data": [ + [ + {"domain_name": "one.example.org"}, + {"domain_name": "two.example.org"}, + ] + ], + } + ) + topic = "test_topic" + + with patch( + "src.base.kafka_handler.ExactlyOnceKafkaConsumeHandler.consume" + ) as mock_consume: + mock_consume.return_value = [key, value, topic] + + result = self.sut.consume_as_object() + + self.assertEqual(result[0], key) + self.assertIsInstance(result[1], Batch) + self.assertEqual("one.example.org", result[1].data[0][0]["domain_name"]) + def test_consume_as_object_invalid_data(self): key = "invalid_key" value = json.dumps( diff --git a/tests/kafka/test_kafka_consume_handler.py b/tests/kafka/test_kafka_consume_handler.py index cc36e66b..0f258ebd 100644 --- a/tests/kafka/test_kafka_consume_handler.py +++ b/tests/kafka/test_kafka_consume_handler.py @@ -3,12 +3,171 @@ from unittest.mock import patch, MagicMock from src.base.kafka_handler import ( + build_consumer_group_id, + ensure_topics, KafkaConsumeHandler, KafkaMessageFetchException, TooManyFailedAttemptsError, + _desired_topic_partitions, + _topic_replication_factor, + _topic_config, ) +def _metadata(partitions_by_topic: dict[str, int]): + metadata = MagicMock() + metadata.topics = {} + for topic, partition_count in partitions_by_topic.items(): + topic_metadata = MagicMock() + topic_metadata.partitions = { + partition: MagicMock() for partition in range(partition_count) + } + metadata.topics[topic] = topic_metadata + return metadata + + +class TestConsumerGroupId(unittest.TestCase): + @patch("src.base.kafka_handler.CONSUMER_GROUP_ID", "test_group_id") + def test_build_consumer_group_id_for_single_topic(self): + self.assertEqual( + "test_group_id.test_topic", + build_consumer_group_id("test_topic"), + ) + + @patch("src.base.kafka_handler.CONSUMER_GROUP_ID", "test_group_id") + def test_build_consumer_group_id_for_multiple_topics(self): + self.assertEqual( + "test_group_id.test_topic_1__test_topic_2", + build_consumer_group_id(["test_topic_2", "test_topic_1"]), + ) + + +class TestTopicReconciliation(unittest.TestCase): + @patch("src.base.kafka_handler.NewTopic") + def test_missing_topic_is_created_with_target_partitions(self, mock_new_topic): + admin_client = MagicMock() + admin_client.list_topics.side_effect = [ + _metadata({}), + _metadata({"test_topic": 4}), + ] + admin_client.create_topics.return_value = {"test_topic": MagicMock()} + mock_new_topic.side_effect = ( + lambda topic, partitions, replication_factor: ( + topic, + partitions, + replication_factor, + ) + ) + + target_partitions_by_topic = ensure_topics( + admin_client, + ["test_topic"], + target_partitions=4, + replication_factor=2, + ) + + self.assertEqual({"test_topic": 4}, target_partitions_by_topic) + admin_client.create_topics.assert_called_once_with([("test_topic", 4, 2)]) + admin_client.create_partitions.assert_not_called() + + @patch("src.base.kafka_handler.NewPartitions") + def test_existing_topic_with_too_few_partitions_is_expanded( + self, mock_new_partitions + ): + admin_client = MagicMock() + admin_client.list_topics.side_effect = [ + _metadata({"test_topic": 2}), + _metadata({"test_topic": 2}), + ] + admin_client.create_partitions.return_value = {"test_topic": MagicMock()} + mock_new_partitions.side_effect = ( + lambda topic, total_count: (topic, total_count) + ) + + ensure_topics( + admin_client, + ["test_topic"], + target_partitions=4, + replication_factor=2, + ) + + admin_client.create_topics.assert_not_called() + admin_client.create_partitions.assert_called_once_with([("test_topic", 4)]) + + def test_existing_topic_with_enough_partitions_is_left_unchanged(self): + admin_client = MagicMock() + admin_client.list_topics.side_effect = [ + _metadata({"test_topic": 8}), + _metadata({"test_topic": 8}), + ] + + ensure_topics( + admin_client, + ["test_topic"], + target_partitions=4, + replication_factor=2, + ) + + admin_client.create_topics.assert_not_called() + admin_client.create_partitions.assert_not_called() + + def test_auto_expand_can_be_disabled(self): + admin_client = MagicMock() + admin_client.list_topics.return_value = _metadata({"test_topic": 2}) + + ensure_topics( + admin_client, + ["test_topic"], + target_partitions=4, + replication_factor=2, + auto_expand_partitions=False, + ) + + admin_client.create_topics.assert_not_called() + admin_client.create_partitions.assert_not_called() + + @patch("src.base.kafka_handler.NUMBER_OF_INSTANCES", 6) + @patch("src.base.kafka_handler.KAFKA_TOPIC_DEFAULT_PARTITIONS", 3) + def test_desired_partitions_uses_highest_scale_value(self): + self.assertEqual(6, _desired_topic_partitions()) + + @patch("src.base.kafka_handler.KAFKA_TOPIC_REPLICATION_FACTOR", 3) + @patch( + "src.base.kafka_handler.KAFKA_BROKERS", + [{"hostname": "127.0.0.1", "internal_port": 9999}], + ) + def test_replication_factor_is_capped_to_configured_brokers(self): + self.assertEqual(1, _topic_replication_factor()) + + @patch( + "src.base.kafka_handler.KAFKA_PIPELINE_TOPIC_PREFIXES", + {"inspector_to_detector": "pipeline-inspector_to_detector"}, + ) + @patch( + "src.base.kafka_handler.KAFKA_TOPIC_STAGE_CONFIG", + {"inspector_to_detector": {"partitions": 7, "replication_factor": 2}}, + ) + def test_stage_topic_config_is_resolved_from_prefix(self): + topic = "pipeline-inspector_to_detector-domainator" + self.assertEqual( + {"partitions": 7, "replication_factor": 2}, _topic_config(topic) + ) + self.assertEqual(7, _desired_topic_partitions(topic)) + self.assertEqual(2, _topic_replication_factor(topic)) + + @patch( + "src.base.kafka_handler.KAFKA_TOPIC_EXACT_CONFIG", + {"hamstring_alerts": {"partitions": 5, "replication_factor": 2}}, + ) + def test_exact_topic_config_wins(self): + self.assertEqual( + {"partitions": 5, "replication_factor": 2}, + _topic_config("hamstring_alerts"), + ) + self.assertEqual(5, _desired_topic_partitions("hamstring_alerts")) + self.assertEqual(2, _topic_replication_factor("hamstring_alerts")) + + class TestInit(unittest.TestCase): @patch("src.base.kafka_handler.CONSUMER_GROUP_ID", "test_group_id") @patch( @@ -43,10 +202,11 @@ def test_init_successful( expected_conf = { "bootstrap.servers": "127.0.0.1:9999,127.0.0.2:9998,127.0.0.3:9997", - "group.id": "test_group_id", + "group.id": "test_group_id.test_topic", "enable.auto.commit": False, "auto.offset.reset": "earliest", "enable.partition.eof": True, + "max.poll.interval.ms": 1800000, } # Act @@ -91,10 +251,11 @@ def test_init_unsuccessful( expected_conf = { "bootstrap.servers": "127.0.0.1:9999,127.0.0.2:9998,127.0.0.3:9997", - "group.id": "test_group_id", + "group.id": "test_group_id.test_topic", "enable.auto.commit": False, "auto.offset.reset": "earliest", "enable.partition.eof": True, + "max.poll.interval.ms": 1800000, } # Act @@ -138,10 +299,11 @@ def test_init_successful_with_list( expected_conf = { "bootstrap.servers": "127.0.0.1:9999,127.0.0.2:9998,127.0.0.3:9997", - "group.id": "test_group_id", + "group.id": "test_group_id.test_topic_1__test_topic_2", "enable.auto.commit": False, "auto.offset.reset": "earliest", "enable.partition.eof": True, + "max.poll.interval.ms": 1800000, } # Act @@ -324,15 +486,16 @@ def setUp(self, mock_consumer, mock_admin_client, mock_all_topics_created): @patch("src.base.kafka_handler.Consumer") def test_with_all_created(self, mock_consumer): # Arrange - mock_topics = MagicMock() - mock_topics.topics = ["test_topic", "another_topic"] - self.sut.consumer = MagicMock() - self.sut.consumer.list_topics.return_value = mock_topics + self.sut.consumer.list_topics.return_value = _metadata( + {"test_topic": 3, "another_topic": 3} + ) # Act and Assert self.assertTrue( - self.sut._all_topics_created(topics=["test_topic", "another_topic"]) + self.sut._all_topics_created( + topics=["test_topic", "another_topic"], min_partitions=3 + ) ) @patch("src.base.kafka_handler.time.sleep") @@ -350,6 +513,22 @@ def test_with_none_created(self, mock_consumer, mock_sleep): self.sut._all_topics_created(topics=["test_topic", "another_topic"]) ) + @patch("src.base.kafka_handler.time.sleep") + @patch("src.base.kafka_handler.Consumer") + def test_with_too_few_partitions(self, mock_consumer, mock_sleep): + # Arrange + self.sut.consumer = MagicMock() + self.sut.consumer.list_topics.return_value = _metadata( + {"test_topic": 1, "another_topic": 3} + ) + + # Act and Assert + self.assertFalse( + self.sut._all_topics_created( + topics=["test_topic", "another_topic"], min_partitions=3 + ) + ) + if __name__ == "__main__": unittest.main() diff --git a/tests/kafka/test_simple_kafka_consume_handler.py b/tests/kafka/test_simple_kafka_consume_handler.py index 41900480..eb454e9b 100644 --- a/tests/kafka/test_simple_kafka_consume_handler.py +++ b/tests/kafka/test_simple_kafka_consume_handler.py @@ -38,10 +38,11 @@ def test_init_successful( expected_conf = { "bootstrap.servers": "127.0.0.1:9999,127.0.0.2:9998,127.0.0.3:9997", - "group.id": "test_group_id", + "group.id": "test_group_id.test_topic", "enable.auto.commit": False, "auto.offset.reset": "earliest", "enable.partition.eof": True, + "max.poll.interval.ms": 1800000, } # Act diff --git a/tests/logcollector/test_collector.py b/tests/logcollector/test_collector.py index c5a47b84..967f9d53 100644 --- a/tests/logcollector/test_collector.py +++ b/tests/logcollector/test_collector.py @@ -138,6 +138,7 @@ def fetch_wrapper(*args, **kwargs): self.sut.fetch() mock_send.assert_called_once() + mock_consume_handler.commit.assert_called_once() class TestSend(unittest.TestCase): @@ -174,6 +175,29 @@ def test_valid_logline(self): # Assert self.sut.batch_handler.add_message.assert_not_called() + def test_invalid_logline_records_failed_terminal_event_for_logserver_message(self): + timestamp = datetime.datetime(2026, 2, 14, 16, 38, 6, 184006) + message = "test_message" + server_message_id = uuid.UUID("bd72ccb4-0ef2-4100-aa22-e787122d6875") + + self.sut.failed_protocol_loglines = MagicMock() + self.sut.server_log_terminal_events = MagicMock() + self.sut.logline_handler.validate_logline_and_get_fields_as_json.side_effect = [ + ValueError + ] + + self.sut.send( + timestamp_in=timestamp, + message=message, + server_message_id=str(server_message_id), + ) + + self.sut.server_log_terminal_events.insert.assert_called_once() + terminal_event = self.sut.server_log_terminal_events.insert.call_args.args[0] + self.assertEqual(server_message_id, terminal_event["message_id"]) + self.assertEqual("log_collection.collector", terminal_event["stage"]) + self.assertEqual("failed", terminal_event["status"]) + def test_invalid_logline(self): timestamp = datetime.datetime(2026, 2, 14, 16, 38, 6, 184006) message = "test_message" @@ -181,6 +205,7 @@ def test_invalid_logline(self): # Arrange mock_logline_handler = Mock() self.sut.logline_handler = mock_logline_handler.return_value + self.sut.server_log_to_logline = MagicMock() self.sut.logline_handler.validate_logline_and_get_fields_as_json.return_value = { "ts": str(timestamp), "status_code": "test_status", @@ -195,12 +220,22 @@ def test_invalid_logline(self): return_value=uuid.UUID("da3aec7f-b355-4a2c-a2f4-2066d49431a5"), ), ): - self.sut.send(timestamp_in=timestamp, message=message) + self.sut.send( + timestamp_in=timestamp, + message=message, + server_message_id="bd72ccb4-0ef2-4100-aa22-e787122d6875", + ) # Assert self.sut.batch_handler.add_message.assert_called_once_with( "192.168.3.0_24", - '{"ts": "2026-02-14 16:38:06.184006", "status_code": "test_status", "src_ip": "192.168.3.141", "record_type": "test_record_type", "logline_id": "da3aec7f-b355-4a2c-a2f4-2066d49431a5"}', + '{"ts": "2026-02-14 16:38:06.184006", "status_code": "test_status", "src_ip": "192.168.3.141", "record_type": "test_record_type", "logline_id": "da3aec7f-b355-4a2c-a2f4-2066d49431a5", "server_message_id": "bd72ccb4-0ef2-4100-aa22-e787122d6875"}', + ) + self.sut.server_log_to_logline.insert.assert_called_once_with( + { + "message_id": uuid.UUID("bd72ccb4-0ef2-4100-aa22-e787122d6875"), + "logline_id": uuid.UUID("da3aec7f-b355-4a2c-a2f4-2066d49431a5"), + } ) @@ -474,23 +509,27 @@ def setUp(self): ] @patch("src.logcollector.collector.logger") - @patch("src.logcollector.collector.LogCollector") + @patch( + "src.logcollector.collector.start_pipeline_worker_replicas", + new_callable=AsyncMock, + ) @patch("asyncio.create_task") @patch("asyncio.run") async def test_main( - self, mock_asyncio_run, mock_asyncio_create_task, mock_instance, mock_logger + self, + mock_asyncio_run, + mock_asyncio_create_task, + mock_start_workers, + mock_logger, ): # Arrange - mock_instance_obj = MagicMock() - mock_instance.return_value = mock_instance_obj - mock_instance_obj.start = AsyncMock() mock_asyncio_create_task.side_effect = lambda coro: coro with patch("src.logcollector.collector.COLLECTORS", self.cs): await main() - mock_instance_obj.start.assert_called_once() + mock_start_workers.assert_awaited_once() args, kwargs = mock_asyncio_create_task.call_args_list[0] expected_call = args[0] mock_asyncio_create_task.assert_called_once_with(expected_call) diff --git a/tests/logserver/test_server.py b/tests/logserver/test_server.py index fffab2ae..c59fc44f 100644 --- a/tests/logserver/test_server.py +++ b/tests/logserver/test_server.py @@ -96,13 +96,16 @@ def test_send( message = "test_message" sut = LogServer(consume_topic="consume_topic1", produce_topics=["test_topic"]) + message_id = uuid.UUID("bd72ccb4-0ef2-4100-aa22-e787122d6875") + # Act - sut.send(uuid.uuid4(), message) + sut.send(message_id, message) # Assert mock_kafka_produce_handler_instance.produce.assert_called_once_with( topic="test_topic", data=message, + key=str(message_id), ) @@ -156,6 +159,7 @@ def fetch_wrapper(*args, **kwargs): mock_send.assert_called_once_with( UUID("bd72ccb4-0ef2-4100-aa22-e787122d6875"), "value1" ) + mock_consume_handler.commit.assert_called_once() # class TestFetchFromFile(unittest.IsolatedAsyncioTestCase): @@ -208,7 +212,7 @@ def fetch_wrapper(*args, **kwargs): class TestMain(unittest.IsolatedAsyncioTestCase): @patch("src.logserver.server.logger") - @patch("src.logserver.server.LogServer") + @patch("src.logserver.server.start_pipeline_worker_replicas", new_callable=AsyncMock) @patch("asyncio.create_task") @patch("asyncio.run") @patch("src.logserver.server.SENSOR_PROTOCOLS", ["dns"]) @@ -219,35 +223,37 @@ class TestMain(unittest.IsolatedAsyncioTestCase): [{"name": "test-collector", "protocol_base": "dns"}], ) async def test_main( - self, mock_asyncio_run, mock_asyncio_create_task, mock_instance, mock_logger + self, + mock_asyncio_run, + mock_asyncio_create_task, + mock_start_workers, + mock_logger, ): # Arrange - mock_instance_obj = MagicMock() - mock_instance.return_value = mock_instance_obj - mock_instance_obj.start = AsyncMock() mock_asyncio_create_task.side_effect = lambda coro: coro # Act await main() # Assert - mock_instance_obj.start.assert_called_once() + mock_start_workers.assert_awaited_once() args, kwargs = mock_asyncio_create_task.call_args_list[0] expected_call = args[0] mock_asyncio_create_task.assert_called_once_with(expected_call) @patch("src.logserver.server.SENSOR_PROTOCOLS", ["dns", "http"]) @patch("src.logserver.server.logger") - @patch("src.logserver.server.LogServer") + @patch("src.logserver.server.start_pipeline_worker_replicas", new_callable=AsyncMock) @patch("asyncio.create_task") @patch("asyncio.run") async def test_main_multiple_protocols( - self, mock_asyncio_run, mock_asyncio_create_task, mock_instance, mock_logger + self, + mock_asyncio_run, + mock_asyncio_create_task, + mock_start_workers, + mock_logger, ): # Arrange - mock_instance_obj = MagicMock() - mock_instance.return_value = mock_instance_obj - mock_instance_obj.start = AsyncMock() mock_asyncio_create_task.side_effect = lambda coro: coro mock_asyncio_run.side_effect = RuntimeError("simulated failure") @@ -258,6 +264,7 @@ async def test_main_multiple_protocols( args, kwargs = mock_asyncio_create_task.call_args_list[0] expected_call = args[0] assert mock_asyncio_create_task.call_count == 2 + assert mock_start_workers.await_count == 2 if __name__ == "__main__": diff --git a/tests/miscellaneous/test_execution.py b/tests/miscellaneous/test_execution.py index c4ccedd1..dc495908 100644 --- a/tests/miscellaneous/test_execution.py +++ b/tests/miscellaneous/test_execution.py @@ -1,9 +1,13 @@ import unittest +import asyncio +import os from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from unittest.mock import patch from src.base.execution import ( create_pipeline_executor, get_pipeline_executor_config, + start_pipeline_worker_replicas, ) @@ -24,6 +28,8 @@ def test_defaults_are_used_when_module_is_missing(self): self.assertEqual("thread", result.executor) self.assertEqual(3, result.max_workers) + self.assertEqual(1, result.processes) + self.assertEqual(3, result.threads_per_process) def test_module_overrides_defaults(self): config = { @@ -44,6 +50,8 @@ def test_module_overrides_defaults(self): self.assertEqual("process", result.executor) self.assertEqual(4, result.max_workers) + self.assertEqual(4, result.processes) + self.assertEqual(1, result.threads_per_process) def test_instance_overrides_module(self): config = { @@ -70,6 +78,50 @@ def test_instance_overrides_module(self): self.assertEqual("thread", result.executor) self.assertEqual(5, result.max_workers) + self.assertEqual(1, result.processes) + self.assertEqual(5, result.threads_per_process) + + def test_hybrid_executor_uses_processes_and_threads(self): + config = { + "pipeline": { + "scaling": { + "modules": { + "data_analysis.detector": { + "executor": "hybrid", + "processes": 2, + "threads_per_process": 4, + } + } + } + } + } + + result = get_pipeline_executor_config(config, "data_analysis.detector") + + self.assertEqual("hybrid", result.executor) + self.assertEqual(2, result.processes) + self.assertEqual(4, result.threads_per_process) + self.assertEqual(8, result.max_workers) + + def test_processes_and_threads_infer_hybrid(self): + config = { + "pipeline": { + "scaling": { + "modules": { + "data_analysis.detector": { + "processes": 2, + "threads": 4, + } + } + } + } + } + + result = get_pipeline_executor_config(config, "data_analysis.detector") + + self.assertEqual("hybrid", result.executor) + self.assertEqual(2, result.processes) + self.assertEqual(4, result.threads_per_process) def test_process_executor_is_created(self): config = { @@ -111,3 +163,41 @@ def test_invalid_worker_count_raises(self): with self.assertRaises(ValueError): get_pipeline_executor_config(config, "log_collection.collector") + + def test_thread_worker_replicas_create_one_worker_per_thread(self): + config = { + "pipeline": { + "scaling": { + "modules": { + "log_filtering.prefilter": { + "executor": "thread", + "threads": 3, + } + } + } + } + } + worker_ids = [] + + class Worker: + def __init__(self, worker_id): + self.worker_id = worker_id + + def run_once(self): + worker_ids.append(self.worker_id) + + async def run_workers(): + await start_pipeline_worker_replicas( + config=config, + module_name="log_filtering.prefilter", + instance_name=None, + worker_factory=lambda worker_id: Worker(worker_id), + target_name="run_once", + ) + + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("KAFKA_TOPIC_MIN_PARTITIONS", None) + asyncio.run(run_workers()) + self.assertEqual("3", os.environ["KAFKA_TOPIC_MIN_PARTITIONS"]) + + self.assertEqual(["p0-t0", "p0-t1", "p0-t2"], sorted(worker_ids)) diff --git a/tests/prefilter/test_prefilter.py b/tests/prefilter/test_prefilter.py index a407a3dd..2b648de9 100644 --- a/tests/prefilter/test_prefilter.py +++ b/tests/prefilter/test_prefilter.py @@ -857,26 +857,26 @@ def setUp(self): # mock_prefilter_instance.clear_data.assert_called() @patch("src.prefilter.prefilter.logger") - @patch("src.prefilter.prefilter.Prefilter") + @patch( + "src.prefilter.prefilter.start_pipeline_worker_replicas", + new_callable=AsyncMock, + ) @patch("asyncio.create_task") @patch("asyncio.run") async def test_main_normal_flow( self, mock_asyncio_run, mock_asyncio_create_task, - mock_prefilter_cls, + mock_start_workers, mock_logger, ): # Arrange - mock_prefilter_instance = MagicMock() - mock_prefilter_instance.start = AsyncMock() - mock_prefilter_cls.return_value = mock_prefilter_instance mock_asyncio_create_task.side_effect = lambda coro: coro with patch("src.prefilter.prefilter.PREFILTERS", self.pf): await main() - mock_prefilter_instance.start.assert_called_once() + mock_start_workers.assert_awaited_once() # @patch("src.prefilter.prefilter.logger") # @patch("src.prefilter.prefilter.Prefilter")