diff --git a/modules/reporting/gcs.py b/modules/reporting/gcs.py index 1af389d0a12..301213fe4a3 100644 --- a/modules/reporting/gcs.py +++ b/modules/reporting/gcs.py @@ -23,6 +23,21 @@ class GCSUploader: """Helper class to upload files to GCS.""" + @staticmethod + def parse_custom_string(custom_str): + if not custom_str: + return {} + + if custom_str.endswith("..."): + custom_str = custom_str[:-3] + parts = custom_str.split(",") + data = {} + for part in parts: + if ":" in part: + key, value = part.split(":", 1) + data[key] = value + return data + def __init__(self, bucket_name=None, auth_by=None, credentials_path=None, exclude_dirs=None, exclude_files=None, mode=None): if not HAVE_GCS: raise ImportError("google-cloud-storage library is missing") @@ -31,9 +46,9 @@ def __init__(self, bucket_name=None, auth_by=None, credentials_path=None, exclud if not bucket_name: cfg = Config("reporting") if not cfg.gcs.enabled: - # If we are initializing purely for manual usage but config is disabled, we might want to allow it if params are passed. - # But if params are missing AND config is disabled/missing, we can't proceed. - pass + # If we are initializing purely for manual usage but config is disabled, we might want to allow it if params are passed. + # But if params are missing AND config is disabled/missing, we can't proceed. + pass bucket_name = cfg.gcs.bucket_name auth_by = cfg.gcs.auth_by @@ -44,7 +59,7 @@ def __init__(self, bucket_name=None, auth_by=None, credentials_path=None, exclud exclude_dirs_str = cfg.gcs.get("exclude_dirs", "") exclude_files_str = cfg.gcs.get("exclude_files", "") - mode = cfg.gcs.get("mode", "file") + mode = cfg.gcs.get("mode", "zip") # Parse exclusion sets self.exclude_dirs = {item.strip() for item in exclude_dirs_str.split(",") if item.strip()} @@ -56,7 +71,7 @@ def __init__(self, bucket_name=None, auth_by=None, credentials_path=None, exclud self.mode = mode if not bucket_name: - raise ValueError("GCS bucket_name is not configured.") + raise ValueError("GCS bucket_name is not configured.") if auth_by == "vm": self.storage_client = storage.Client() @@ -90,13 +105,13 @@ def _iter_files_to_upload(self, source_directory): relative_path = os.path.relpath(local_path, source_directory) yield local_path, relative_path - def upload(self, source_directory, analysis_id, tlp=None): + def upload(self, source_directory, analysis_id, tlp=None, metadata=None): if self.mode == "zip": - self.upload_zip_archive(analysis_id, source_directory, tlp=tlp) + self.upload_zip_archive(analysis_id, source_directory, tlp=tlp, metadata=metadata) else: - self.upload_files_individually(analysis_id, source_directory, tlp=tlp) + self.upload_files_individually(analysis_id, source_directory, tlp=tlp, metadata=metadata) - def upload_zip_archive(self, analysis_id, source_directory, tlp=None): + def upload_zip_archive(self, analysis_id, source_directory, tlp=None, metadata=None): log.debug("Compressing and uploading files for analysis ID %s to GCS", analysis_id) blob_name = f"{analysis_id}_tlp_{tlp}.zip" if tlp else f"{analysis_id}.zip" @@ -104,16 +119,18 @@ def upload_zip_archive(self, analysis_id, source_directory, tlp=None): tmp_zip_file_name = tmp_zip_file.name with zipfile.ZipFile(tmp_zip_file, "w", zipfile.ZIP_DEFLATED) as archive: for local_path, relative_path in self._iter_files_to_upload(source_directory): - archive.write(local_path, relative_path) + archive.write(local_path, os.path.join(str(analysis_id), relative_path)) try: log.debug("Uploading '%s' to '%s'", tmp_zip_file_name, blob_name) blob = self.bucket.blob(blob_name) + if metadata: + blob.metadata = metadata blob.upload_from_filename(tmp_zip_file_name) finally: os.unlink(tmp_zip_file_name) log.info("Successfully uploaded archive for analysis %s to GCS.", analysis_id) - def upload_files_individually(self, analysis_id, source_directory, tlp=None): + def upload_files_individually(self, analysis_id, source_directory, tlp=None, metadata=None): log.debug("Uploading files for analysis ID %s to GCS", analysis_id) folder_name = f"{analysis_id}_tlp_{tlp}" if tlp else str(analysis_id) @@ -121,6 +138,8 @@ def upload_files_individually(self, analysis_id, source_directory, tlp=None): blob_name = f"{folder_name}/{relative_path}" # log.debug("Uploading '%s' to '%s'", local_path, blob_name) blob = self.bucket.blob(blob_name) + if metadata: + blob.metadata = metadata blob.upload_from_filename(local_path) log.info("Successfully uploaded files for analysis %s to GCS.", analysis_id) @@ -154,6 +173,7 @@ def run(self, results): tlp = results.get("info", {}).get("tlp") analysis_id = results.get("info", {}).get("id") + custom = results.get("info", {}).get("custom") # We can now just use the Uploader. # But for backward compatibility with overrides in self.options (e.g. per-module config overrides in Cuckoo), @@ -172,7 +192,7 @@ def run(self, results): credentials_path_str = self.options.get("credentials_path") credentials_path = None if credentials_path_str: - credentials_path = os.path.join(CUCKOO_ROOT, credentials_path_str) + credentials_path = os.path.join(CUCKOO_ROOT, credentials_path_str) mode = self.options.get("mode", "file") try: @@ -182,8 +202,9 @@ def run(self, results): raise CuckooReportError("Could not get analysis ID from results.") source_directory = self.analysis_path + metadata = GCSUploader.parse_custom_string(custom) - uploader.upload(source_directory, analysis_id, tlp) + uploader.upload(source_directory, analysis_id, tlp, metadata=metadata) except Exception as e: raise CuckooReportError(f"Failed to upload report to GCS: {e}") from e diff --git a/utils/dist.py b/utils/dist.py index c71b9e373b6..d6f958d9667 100644 --- a/utils/dist.py +++ b/utils/dist.py @@ -17,17 +17,15 @@ import threading import time import timeit -import zipfile -from contextlib import suppress +from concurrent.futures import ThreadPoolExecutor, as_completed + from datetime import datetime, timedelta -from io import BytesIO from itertools import combinations from logging import handlers -from urllib.parse import urlparse +from urllib.parse import urlparse, urljoin from sqlalchemy import and_, or_, select, func, delete, case from sqlalchemy.exc import OperationalError, SQLAlchemyError -import pyzipper import requests requests.packages.urllib3.disable_warnings() @@ -53,7 +51,6 @@ ) from lib.cuckoo.core.database import ( Database, - Guest, _Database, init_database, ) @@ -85,6 +82,15 @@ logging.getLogger("requests").setLevel(logging.WARNING) logging.getLogger("urllib3").setLevel(logging.WARNING) + +class NoStatusLogFilter(logging.Filter): + def filter(self, record): + return "GET /status" not in record.getMessage() + + +logging.getLogger("werkzeug").addFilter(NoStatusLogFilter()) +logging.getLogger("_internal").addFilter(NoStatusLogFilter()) + dist_ignore_patterns = shutil.ignore_patterns(*[pattern.strip() for pattern in dist_conf.distributed.ignore_patterns.split(",")]) STATUSES = {} ID2NAME = {} @@ -166,11 +172,11 @@ def node_status(url: str, name: str, apikey: str) -> dict: """ try: r = requests.get( - os.path.join(url, "cuckoo", "status/"), headers={"Authorization": f"Token {apikey}"}, verify=False, timeout=300 + os.path.join(url, "cuckoo", "status/"), headers={"Authorization": f"Token {apikey}"}, verify=False, timeout=5 ) return r.json().get("data", {}) except Exception as e: - log.critical("Possible invalid CAPE node (%s): %s", name, e) + log.warning("Possible invalid CAPE node (%s): %s", name, e) return {} @@ -189,7 +195,7 @@ def node_fetch_tasks(status, url, apikey, action="fetch", since=0): list: A list of tasks fetched from the remote server. Returns an empty list if an error occurs. """ try: - url = os.path.join(url, "tasks", "list/") + url = urljoin(url, "tasks/list/") params = dict(status=status, ids=True) if action == "fetch": params["completed_after"] = since @@ -200,7 +206,7 @@ def node_fetch_tasks(status, url, apikey, action="fetch", since=0): return [] return r.json().get("data", []) except Exception as e: - log.critical("Error listing completed tasks (node %s): %s", url, e) + log.warning("Error listing completed tasks (node %s): %s", url, e) return [] @@ -220,7 +226,7 @@ def node_list_machines(url, apikey): HTTPException: If the request to the CAPE node fails or returns an error. """ try: - r = requests.get(os.path.join(url, "machines", "list/"), headers={"Authorization": f"Token {apikey}"}, verify=False) + r = requests.get(urljoin(url, "machines/list/"), headers={"Authorization": f"Token {apikey}"}, verify=False) for machine in r.json()["data"]: yield Machine(name=machine["name"], platform=machine["platform"], tags=machine["tags"]) except Exception as e: @@ -242,7 +248,7 @@ def node_list_exitnodes(url, apikey): HTTPException: If the request fails or the response is invalid. """ try: - r = requests.get(os.path.join(url, "exitnodes/"), headers={"Authorization": f"Token {apikey}"}, verify=False) + r = requests.get(urljoin(url, "exitnodes/"), headers={"Authorization": f"Token {apikey}"}, verify=False) for exitnode in r.json()["data"]: yield exitnode except Exception as e: @@ -342,9 +348,10 @@ def _delete_many(node, ids, nodes, db): if nodes[node].name == main_server_name: return try: - url = os.path.join(nodes[node].url, "tasks", "delete_many/") + url = urljoin(nodes[node].url, "tasks/delete_many/") apikey = nodes[node].apikey - log.debug("Removing task id(s): %s - from node: %s", ids, nodes[node].name) + # log.debug("Removing task id(s): %s - from node: %s", ids, nodes[node].name) + log.info("[REMOVE] %-15s ==> Task(s) %s", nodes[node].name, ids) res = requests.post( url, headers={"Authorization": f"Token {apikey}"}, @@ -356,11 +363,11 @@ def _delete_many(node, ids, nodes, db): db.rollback() except Exception as e: - log.critical("Error deleting task (tasks #%s, node %s): %s", ids, nodes[node].name, e) + log.warning("Error deleting task (tasks #%s, node %s): %s", ids, nodes[node].name, e) db.rollback() -def node_submit_task(task_id, node_id, main_task_id): +def node_submit_task(task_id, node_id, main_task_id, db=None): """ Submits a task to a specified node for processing. @@ -368,6 +375,7 @@ def node_submit_task(task_id, node_id, main_task_id): task_id (int): The ID of the task to be submitted. node_id (int): The ID of the node to which the task will be submitted. main_task_id (int): The ID of the main task associated with this task. + db (Session, optional): The database session to use. Returns: bool: True if the task was successfully submitted, False otherwise. @@ -384,12 +392,21 @@ def node_submit_task(task_id, node_id, main_task_id): 6. Updates the task status in the database based on the submission result. 7. Logs relevant information and errors during the process. """ - db = session() - node = db.scalar(select(Node).where(Node.id == node_id)) - task = db.get(Task, task_id) - check = False + # Use existing session if provided, otherwise create new + if db: + session_managed = False + else: + db = session() + session_managed = True + try: + node = db.scalar(select(Node).where(Node.id == node_id)) + task = db.get(Task, task_id) + check = False + if node.name == main_server_name: + if session_managed: + db.close() return # Remove the earlier appended comma @@ -413,11 +430,15 @@ def node_submit_task(task_id, node_id, main_task_id): tlp=task.tlp, ) + r = None + # Add timeout to requests + request_timeout = 300 # 5 minutes for large files + if task.category in ("file", "pcap"): if task.category == "pcap": - data = {"pcap": 1} + data["pcap"] = 1 - url = os.path.join(node.url, "tasks", "create", "file/") + url = urljoin(node.url, "tasks/create/file/") # If the file does not exist anymore, ignore it and move on # to the next file. if not path_exists(task.path): @@ -429,11 +450,22 @@ def node_submit_task(task_id, node_id, main_task_id): except Exception as e: log.exception(e) db.rollback() + if session_managed: + db.close() return try: - files = dict(file=open(task.path, "rb")) - r = requests.post(url, data=data, files=files, headers={"Authorization": f"Token {apikey}"}, verify=False) - except OSError: + # Use context manager for file + with open(task.path, "rb") as f: + r = requests.post( + url, + data=data, + files={"file": f}, + headers={"Authorization": f"Token {apikey}"}, + verify=False, + timeout=request_timeout, + ) + except (OSError, requests.RequestException) as e: + log.error("Error submitting file task (Main: %d): %s", task.main_task_id, e) task.finished = True task.retrieved = True main_db.set_status(task.main_task_id, TASK_FAILED_REPORTING) @@ -442,19 +474,45 @@ def node_submit_task(task_id, node_id, main_task_id): except Exception as e: log.exception(e) db.rollback() + if session_managed: + db.close() return elif task.category == "url": - url = os.path.join(node.url, "tasks", "create", "url/") - r = requests.post( - url, data={"url": task.path, "options": task.options}, headers={"Authorization": f"Token {apikey}"}, verify=False - ) + url = urljoin(node.url, "tasks/create/url/") + try: + r = requests.post( + url, + data={"url": task.path, "options": task.options}, + headers={"Authorization": f"Token {apikey}"}, + verify=False, + timeout=request_timeout, + ) + except requests.RequestException as e: + log.error("Error submitting url task (Main: %d): %s", task.main_task_id, e) + if session_managed: + db.close() + return elif task.category == "static": - url = os.path.join(node.url, "tasks", "create", "static/") - files = dict(file=open(task.path, "rb")) - r = requests.post(url, data=data, files=files, headers={"Authorization": f"Token {apikey}"}, verify=False) + url = urljoin(node.url, "tasks/create/static/") + try: + with open(task.path, "rb") as f: + r = requests.post( + url, + data=data, + files={"file": f}, + headers={"Authorization": f"Token {apikey}"}, + verify=False, + timeout=request_timeout, + ) + except (OSError, requests.RequestException) as e: + log.error("Error submitting static task (Main: %d): %s", task.main_task_id, e) + if session_managed: + db.close() + return else: log.debug("Target category is: %s", task.category) - db.close() + if session_managed: + db.close() return # encoding problem @@ -486,18 +544,21 @@ def node_submit_task(task_id, node_id, main_task_id): if b"File too big, enable" in r.content: main_db.set_status(task.main_task_id, TASK_BANNED) if task.task_id: - log.debug("Submitted task to worker: %s - %d - %d", node.name, task.task_id, task.main_task_id) - - elif r.status_code == 500: + # log.debug("Submitted task to worker: %s - %d - %d", node.name, task.task_id, task.main_task_id) + log.info("[SUBMIT] %-15s <== Task %-6d (Main: %d)", node.name, task.task_id, task.main_task_id) + elif r and r.status_code == 500: log.info("Saving error to /tmp/dist_error.html") _ = path_write_file("/tmp/dist_error.html", r.content) log.info((r.status_code, r.text[:200])) - elif r.status_code == 429: + elif r and r.status_code == 429: log.info((r.status_code, "see api auth for more details")) else: - log.info("Node: %d - Task submit to worker failed: %d - %s", node.id, r.status_code, r.text) + if r: + log.info("Node: %d - Task submit to worker failed: %d - %s", node.id, r.status_code, r.text) + else: + log.info("Node: %d - Task submit to worker failed: No response", node.id) if check: task.node_id = node.id @@ -515,10 +576,11 @@ def node_submit_task(task_id, node_id, main_task_id): except Exception as e: log.exception(e) - log.critical("Error submitting task (task #%d, node %s): %s", task.id, node.name, e) + log.critical("Error submitting task (task #%d, node %s): %s", task.id if task else -1, node.name if node else "Unknown", e) - db.commit() - db.close() + if session_managed: + db.commit() + db.close() return check @@ -534,10 +596,6 @@ class Retriever(threading.Thread): free_space_mon(): Monitors free disk space and logs an error if space is insufficient. - - notification_loop(): - Sends notifications for completed tasks to configured callback URLs. - failed_cleaner(): Cleans up failed tasks from nodes and updates their status in the database. @@ -550,9 +608,6 @@ class Retriever(threading.Thread): fetch_latest_reports_nfs(): Fetches the latest reports from nodes using NFS and processes them. - fetch_latest_reports(): - Fetches the latest reports from nodes using REST API and processes them. - remove_from_worker(): Removes tasks from worker nodes and updates their status in the database. """ @@ -641,52 +696,6 @@ def free_space_mon(self): free_space_monitor(dir_path, analysis=True) time.sleep(600) - def notification_loop(self): - """ - Continuously checks for completed tasks that have not been notified and sends notifications to specified URLs. - - This method runs an infinite loop that: - 1. Queries the database for tasks that are finished, retrieved, but not yet notified. - 2. For each task, updates the main task status to `TASK_REPORTED`. - 3. Sends a POST request to each URL specified in the configuration with the task ID in the payload. - 4. Marks the task as notified if the POST request is successful. - 5. Logs the status of each notification attempt. - - The loop sleeps for 20 seconds before repeating the process. - - Raises: - requests.exceptions.ConnectionError: If there is a connection error while sending the POST request. - Exception: For any other exceptions that occur during the notification process. - """ - urls = reporting_conf.callback.url.split(",") - headers = {"x-api-key": reporting_conf.callback.key} - - with session() as db: - while True: - stmt = ( - select(Task) - .where(Task.finished.is_(True), Task.retrieved.is_(True), Task.notificated.is_(False)) - .order_by(Task.id.desc()) - ) - - for task in db.scalars(stmt): - with main_db.session.begin(): - main_db.set_status(task.main_task_id, TASK_REPORTED) - log.debug("reporting main_task_id: %d", task.main_task_id) - for url in urls: - try: - res = requests.post(url, headers=headers, data=json.dumps({"task_id": int(task.main_task_id)})) - if res and res.ok: - task.notificated = True - else: - log.info("failed to report: %d - %d", task.main_task_id, res.status_code) - except requests.exceptions.ConnectionError: - log.info("Can't report to callback") - except Exception as e: - log.info("failed to report: %d - %s", task.main_task_id, str(e)) - db.commit() - time.sleep(20) - def failed_cleaner(self): """ Periodically checks for failed tasks on enabled nodes and cleans them up. @@ -700,46 +709,54 @@ def failed_cleaner(self): The method runs indefinitely, sleeping for 600 seconds between each iteration. Attributes: - self.cleaner_queue (Queue): A queue to hold tasks that need to be cleaned. + self.cleaner_queue (queue.Queue): A queue to hold tasks that need to be cleaned. Notes: - - This method acquires and releases a lock (`lock_retriever`) to ensure - thread-safe operations when adding tasks to the cleaner queue. + - This method acquires and releases a lock (`lock_retriever`) to ensure thread-safe operations when adding tasks to the cleaner queue. - The method commits changes to the database after processing each node. - - The method closes the database session before exiting. Raises: Any exceptions raised during database operations or task processing are not explicitly handled within this method. """ - db = session() while True: - nodes = db.execute(select(Node.id, Node.name, Node.url, Node.apikey).where(Node.enabled.is_(True))) - for node in nodes: - log.info("Checking for failed tasks on: %s", node.name) - for task in node_fetch_tasks("failed_analysis|failed_processing", node.url, node.apikey, action="delete"): - task_stmt = select(Task).where(Task.task_id == task["id"], Task.node_id == node.id).order_by(Task.id.desc()) - t = db.scalar(task_stmt) - if t is not None: - log.info("Cleaning failed for id: %d, node: %s: main_task_id: %d", t.id, t.node_id, t.main_task_id) - with main_db.session.begin(): - main_db.set_status(t.main_task_id, TASK_FAILED_REPORTING) - t.finished = True - t.retrieved = True - t.notificated = True - lock_retriever.acquire() - if (t.node_id, t.task_id) not in self.cleaner_queue.queue: - self.cleaner_queue.put((t.node_id, t.task_id)) - lock_retriever.release() - else: - log.debug("failed_cleaner t is None for: %s - node_id: %d", str(task["id"]), node.id) - lock_retriever.acquire() - if (node.id, task["id"]) not in self.cleaner_queue.queue: - self.cleaner_queue.put((node.id, task["id"])) - lock_retriever.release() - db.commit() + with session() as db: + try: + nodes = db.execute(select(Node).where(Node.enabled.is_(True))) + for node in nodes or []: + failed_task_ids = [ + task["id"] + for task in node_fetch_tasks( + "failed_analysis|failed_processing", node.url, node.apikey, action="delete" + ) + ] + + if not failed_task_ids: + continue + + log.info("Found %d failed tasks on node: %s", len(failed_task_ids), node.name) + + # Fetch all relevant tasks from the database in one query + stmt = select(Task).where(Task.node_id == node.id, Task.task_id.in_(failed_task_ids)) + tasks_in_db = {t.task_id: t for t in db.scalars(stmt)} + + for task_id in failed_task_ids: + t = tasks_in_db.get(task_id) + if t: + log.info("Cleaning failed task: main_task_id=%d, node_id=%d", t.main_task_id, t.node_id) + with main_db.session.begin(): + main_db.set_status(t.main_task_id, TASK_FAILED_REPORTING) + t.finished = True + t.retrieved = True + t.notificated = True + + # Always queue for cleaning on the worker, even if not in our DB + self.cleaner_queue.put((node.id, task_id)) + db.commit() + except SQLAlchemyError as e: + log.error("Database error in failed_cleaner: %s", e) + db.rollback() time.sleep(600) - db.close() def fetcher(self): """ @@ -762,68 +779,56 @@ def fetcher(self): Exception: If an error occurs during task processing, it is logged and the status count is incremented """ last_checks = {} - # to not exit till cleaner works - with session() as db: - while True: - if self.stop_dist.is_set(): - time.sleep(60) + # self.thread_local.requests_session = requests.Session() + + while True: + if self.stop_dist.is_set(): + time.sleep(60) + continue + with session() as db: + try: + nodes = db.scalars(select(Node).where(Node.enabled.is_(True))) + except SQLAlchemyError as e: + log.error("Database error in fetcher: %s", e) continue - # .with_entities(Node.id, Node.name, Node.url, Node.apikey, Node.last_check) - nodes = db.scalars(select(Node).where(Node.enabled.is_(True))) for node in nodes: self.status_count.setdefault(node.name, 0) last_checks.setdefault(node.name, 0) last_checks[node.name] += 1 - # reset it every 10 calls + if hasattr(node, "last_check") and node.last_check: - last_check = int(node.last_check.strftime("%s")) + last_check = int(node.last_check.timestamp()) else: last_check = 0 + + # reset it every 10 calls if last_checks[node.name] == 3: last_check = 0 last_checks[node.name] = 0 - limit = 0 - task_ids = [] - for task in node_fetch_tasks("reported", node.url, node.apikey, "fetch", last_check): - task_ids.append(task["id"]) - - if True: - stmt = ( - select(Task) - .where( - Task.finished.is_(False), - Task.retrieved.is_(False), - Task.node_id == node.id, - Task.deleted.is_(False), - Task.task_id.in_(task_ids), - ) - .order_by(Task.id.desc()) + + task_ids = [task["id"] for task in node_fetch_tasks("reported", node.url, node.apikey, "fetch", last_check)] + + if task_ids: + stmt = select(Task.task_id).where( + Task.finished.is_(False), + Task.retrieved.is_(False), + Task.node_id == node.id, + Task.deleted.is_(False), + Task.task_id.in_(task_ids), ) - tasker = db.scalars(stmt) + tasks_to_fetch = db.scalars(stmt).all() - if tasker is None: - # log.debug(f"Node ID: {node.id} - Task ID: {task['id']} - adding to cleaner") - self.cleaner_queue.put((node.id, task["id"])) - continue + processed_task_ids = set(self.current_queue.get(node.id, [])) + try: + queue_task_ids = {t[0]["id"] for t in list(self.fetcher_queue.queue) if t[1] == node.id} + except TypeError: + print("queue_task_ids", self.fetcher_queue.queue) + queue_task_ids = set() - for task in tasker: + for task_id in tasks_to_fetch: try: - if ( - task.task_id not in self.current_queue.get(node.id, []) - and (task.task_id, node.id) not in self.fetcher_queue.queue - ): - limit += 1 - self.fetcher_queue.put(({"id": task.task_id}, node.id)) - # log.debug("%s - %d", task, node.id) - """ - completed_on = datetime.strptime(task["completed_on"], "%Y-%m-%d %H:%M:%S") - if node.last_check is None or completed_on > node.last_check: - node.last_check = completed_on - db.commit() - db.refresh(node) - #if limit == 50: - # break - """ + if task_id not in processed_task_ids and task_id not in queue_task_ids: + self.fetcher_queue.put(({"id": task_id}, node.id)) except Exception as e: self.status_count[node.name] += 1 log.exception(e) @@ -864,46 +869,6 @@ def delete_target_file(self, task_id: int, sample_sha256: str, target: str): if not sample_still_used: path_delete(copy_path) - def inject_guest_info(self, main_task_id: int, report_path: str): - """ - Inject guest information from report.json into the main database. - - Args: - main_task_id (int): The ID of the main task. - report_path (str): The path to the analysis folder. - """ - report_json_path = os.path.join(report_path, "reports", "report.json") - if not path_exists(report_json_path): - return - - try: - with open(report_json_path, "r") as f: - report_data = json.load(f) - machine = report_data.get("info", {}).get("machine", {}) - if machine and isinstance(machine, dict): - with main_db.session.begin(): - # Check if guest already exists - stmt = select(Guest).where(Guest.task_id == main_task_id) - if not main_db.session.scalar(stmt): - guest = Guest( - name=machine.get("name"), - label=machine.get("label"), - platform=machine.get("platform"), - manager=machine.get("manager"), - task_id=main_task_id, - ) - # Set optional fields if they exist - if "started_on" in machine: - with suppress(Exception): - guest.started_on = datetime.strptime(machine["started_on"], "%Y-%m-%d %H:%M:%S") - if "shutdown_on" in machine: - with suppress(Exception): - guest.shutdown_on = datetime.strptime(machine["shutdown_on"], "%Y-%m-%d %H:%M:%S") - - main_db.session.add(guest) - except Exception as e: - log.error("Failed to inject guest info for task %d: %s", main_task_id, e) - # This should be executed as external thread as it generates bottle neck def fetch_latest_reports_nfs(self): """ @@ -930,12 +895,10 @@ def fetch_latest_reports_nfs(self): Raises: Exception: If any error occurs during the processing of tasks. - """ - # db = session() - with session() as db: - # to not exit till cleaner works - while True: + while True: + with session() as db: + # to not exit till cleaner works if self.stop_dist.is_set(): time.sleep(60) continue @@ -959,6 +922,7 @@ def fetch_latest_reports_nfs(self): ) t = db.scalar(stmt) if t is None: + # print(type(self.t_is_none.get(node_id))) self.t_is_none.setdefault(node_id, []).append(task["id"]) # sometime it not deletes tasks in workers of some fails or something @@ -967,7 +931,7 @@ def fetch_latest_reports_nfs(self): if (node_id, task.get("id")) not in self.cleaner_queue.queue: self.cleaner_queue.put((node_id, task.get("id"))) continue - + """ log.debug( "Fetching dist report for: id: %d, task_id: %d, main_task_id: %d from node: %s", t.id, @@ -975,6 +939,9 @@ def fetch_latest_reports_nfs(self): t.main_task_id, ID2NAME[t.node_id] if t.node_id in ID2NAME else t.node_id, ) + """ + dbg_line = f"[FETCH ] {ID2NAME[t.node_id] if t.node_id in ID2NAME else t.node_id:<15} <== Task {t.id:<6} - (Main ID: {t.main_task_id})" + log.debug(dbg_line) with main_db.session.begin(): # set completed_on time main_db.set_status(t.main_task_id, TASK_DISTRIBUTED_COMPLETED) @@ -994,6 +961,7 @@ def fetch_latest_reports_nfs(self): continue timediff = timeit.default_timer() - start_copy + """ log.info( "It took %s seconds to copy report %d from node: %s for task: %d", f"{timediff:.2f}", @@ -1001,17 +969,23 @@ def fetch_latest_reports_nfs(self): node.name, t.main_task_id, ) - - self.inject_guest_info(t.main_task_id, report_path) - + """ + dbg_line = f"[PERF ] Copied Task {t.task_id} in {timediff:.2f}s" + log.debug(dbg_line) + self.cleaner_queue.put((node_id, task.get("id"))) # this doesn't exist for some reason + sample_parent = None + hashes = {} if path_exists(t.path): sample_sha256 = None - sample_parent = None with main_db.session.begin(): samples = main_db.find_sample(task_id=t.main_task_id) if samples: sample_sha256 = samples[0].sample.sha256 + hashes["sha256"] = samples[0].sample.sha256 + hashes["md5"] = samples[0].sample.md5 + hashes["sha1"] = samples[0].sample.sha1 + if hasattr(samples[0].sample, "parent_links"): for parent in samples[0].sample.parent_links: if parent.task_id == t.main_task_id: @@ -1043,8 +1017,9 @@ def fetch_latest_reports_nfs(self): if sample_parent: try: report = load_iocs(t.main_task_id, detail=True) - report["info"].update({"parent_sample": sample_parent}) - dump_iocs(report, t.main_task_id) + if report: + report["info"].update({"parent_sample": sample_parent}) + dump_iocs(report, t.main_task_id) # ToDo insert into mongo mongo_update_one( "analysis", {"info.id": int(t.main_task_id)}, {"$set": {"info.parent_sample": sample_parent}} @@ -1058,7 +1033,10 @@ def fetch_latest_reports_nfs(self): # TLP is not readily available in 't' object without loading report.json or task options. # We can try to get TLP from task options if available, or just pass None. tlp = t.tlp - gcs_uploader.upload(report_path, t.main_task_id, tlp=tlp) + metadata = GCSUploader.parse_custom_string(t.custom) + metadata.update(hashes) + metadata["task_id"] = t.main_task_id + gcs_uploader.upload(report_path, t.main_task_id, tlp=tlp, metadata=metadata) if GCS_DELETE_AFTER_UPLOAD: try: @@ -1079,164 +1057,6 @@ def fetch_latest_reports_nfs(self): self.current_queue[node_id].remove(task["id"]) db.commit() - # This should be executed as external thread as it generates bottle neck - def fetch_latest_reports(self): - """ - Continuously fetches the latest reports from distributed nodes and processes them. - - This method runs in an infinite loop until `self.stop_dist` is set. It retrieves tasks from the `fetcher_queue`, - fetches the corresponding reports from the nodes, and processes them. The reports are saved to the local storage - and the task status is updated in the database. - - The method handles various scenarios such as: - - Task not found or already processed. - - Report retrieval failures. - - Report extraction and saving. - - Handling of sample binaries associated with the tasks. - - The method also manages a cleaner queue to handle tasks that need to be cleaned up. - - Raises: - Exception: If any unexpected error occurs during the report fetching and processing. - """ - db = session() - # to not exit till cleaner works - while True: - if self.stop_dist.is_set(): - time.sleep(60) - continue - task, node_id = self.fetcher_queue.get() - - self.current_queue.setdefault(node_id, []).append(task["id"]) - - try: - # In the case that a Cuckoo node has been reset over time it"s - # possible that there are multiple combinations of - # node-id/task-id, in this case we take the last one available. - # (This makes it possible to re-setup a Cuckoo node). - stmt = ( - select(Task) - .where( - Task.node_id == node_id, - Task.task_id == task["id"], - Task.retrieved.is_(False), - Task.finished.is_(False), - ) - .order_by(Task.id.desc()) - ) - t = db.scalar(stmt) - if t is None: - self.t_is_none.setdefault(node_id, []).append(task["id"]) - - # sometime it not deletes tasks in workers of some fails or something - # this will do the trick - # log.debug("tf else,") - if (node_id, task.get("id")) not in self.cleaner_queue.queue: - self.cleaner_queue.put((node_id, task.get("id"))) - continue - - log.debug( - "Fetching dist report for: id: %d, task_id: %d, main_task_id: %d from node: %s", - t.id, - t.task_id, - t.main_task_id, - ID2NAME[t.node_id] if t.node_id in ID2NAME else t.node_id, - ) - with main_db.session.begin(): - # set completed_on time - main_db.set_status(t.main_task_id, TASK_DISTRIBUTED_COMPLETED) - # set reported time - main_db.set_status(t.main_task_id, TASK_REPORTED) - - # Fetch each requested report. - node = db.scalar(select(Node).where(Node.id == node_id)) - report = node_get_report(t.task_id, "dist/", node.url, node.apikey, stream=True) - - if report is None: - log.info("dist report retrieve failed NONE: task_id: %d from node: %d", t.task_id, node_id) - continue - - if report.status_code != 200: - log.info( - "dist report retrieve failed - status_code %d: task_id: %d from node: %s", - report.status_code, - t.task_id, - node_id, - ) - if report.status_code == 400 and (node_id, task.get("id")) not in self.cleaner_queue.queue: - self.cleaner_queue.put((node_id, task.get("id"))) - log.info("Status code: %d - MSG: %s", report.status_code, report.text) - continue - - log.info( - "Report size for task %s is: %s MB", - t.task_id, - f"{int(report.headers.get('Content-length', 1)) / int(1 << 20):,.0f}", - ) - - report_path = os.path.join(CUCKOO_ROOT, "storage", "analyses", str(t.main_task_id)) - if not path_exists(report_path): - path_mkdir(report_path, mode=0o755) - try: - if report.content: - # with pyzipper.AESZipFile(BytesIO(report.content)) as zf: - # zf.setpassword(zip_pwd) - with zipfile.ZipFile(BytesIO(report.content)) as zf: - try: - zf.extractall(report_path) - self.inject_guest_info(t.main_task_id, report_path) - if (node_id, task.get("id")) not in self.cleaner_queue.queue: - self.cleaner_queue.put((node_id, task.get("id"))) - except OSError: - log.error("Permission denied: %s", report_path) - - if path_exists(t.path): - sample_sha256 = None - with main_db.session.begin(): - samples = main_db.find_sample(task_id=t.main_task_id) - if samples: - sample_sha256 = samples[0].sample.sha256 - if sample_sha256 is None: - # keep fallback for now - with open(t.path, "rb") as f: - sample = f.read() - sample_sha256 = hashlib.sha256(sample).hexdigest() - - destination = os.path.join(CUCKOO_ROOT, "storage", "binaries") - if not path_exists(destination): - path_mkdir(destination, mode=0o755) - - destination = os.path.join(destination, sample_sha256) - if not path_exists(destination) and path_exists(t.path): - shutil.move(t.path, destination) - - # creating link to analysis folder - if path_exists(t.path): - with suppress(Exception): - os.symlink(destination, os.path.join(report_path, "binary")) - - self.delete_target_file(t.main_task_id, sample_sha256, t.path) - - else: - log.debug("%s doesn't exist", t.path) - - t.retrieved = True - t.finished = True - db.commit() - - else: - log.error("Zip file is empty") - except pyzipper.zipfile.BadZipFile: - log.error("File is not a zip file") - except Exception as e: - log.exception("Exception: %s", str(e)) - if path_exists(os.path.join(report_path, "reports", "report.json")): - path_delete(os.path.join(report_path, "reports", "report.json")) - except Exception as e: - log.exception(e) - self.current_queue[node_id].remove(task["id"]) - db.commit() - db.close() def remove_from_worker(self): """ @@ -1254,7 +1074,7 @@ def remove_from_worker(self): 4. Removes tasks from the `t_is_none` dictionary if present. 5. Sends a request to delete tasks from the worker node. 6. Commits the changes to the database. - 7. Sleeps for 20 seconds before processing the next batch of tasks. + 7. Sleeps for 20 seconds before processing the next tasks. Note: The method runs indefinitely until manually stopped. @@ -1267,29 +1087,26 @@ def remove_from_worker(self): with session() as db: for node in db.scalars(select(Node)): nodes.setdefault(node.id, node) - + self.t_is_none.setdefault(node.id, set()) while True: details = {} # print("cleaner size is ", self.cleaner_queue.qsize()) for _ in range(self.cleaner_queue.qsize()): node_id, task_id = self.cleaner_queue.get() - details.setdefault(node_id, []).append(str(task_id)) if task_id in self.t_is_none.get(node_id, []): self.t_is_none[node_id].remove(task_id) - - if len(self.t_is_none[node_id]) > 50: - break - - # ToDo Do we need to do something here? - + details.setdefault(node_id, set()).add(str(task_id)) + if len(self.t_is_none.get(node_id)) > 50: + break + db = session() for node_id in details: node = nodes[node_id] - if node and details[node_id]: + if node and details.get(node_id): ids = ",".join(list(set(details[node_id]))) _delete_many(node_id, ids, nodes, db) - - db.commit() - time.sleep(20) + db.commit() + db.close() + time.sleep(20) class StatusThread(threading.Thread): @@ -1363,192 +1180,226 @@ def submit_tasks(self, node_name, pend_tasks_num, options_like=False, force_push ) if not main_db_tasks: return True - if main_db_tasks: - for t in main_db_tasks: - options = get_options(t.options) - # Check if file exist, if no wipe from db and continue, rare cases - if t.category in ("file", "pcap", "static"): - if not path_exists(t.target): - log.info("Task id: %d - File doesn't exist: %s", t.id, t.target) - main_db.set_status(t.id, TASK_BANNED) - continue - if not web_conf.general.allow_ignore_size and "ignore_size_check" not in options: - # We can't upload size bigger than X to our workers. In case we extract archive that contains bigger file. - file_size = path_get_size(t.target) - if file_size > web_conf.general.max_sample_size: - log.warning( - "File size: %d is bigger than allowed: %d", file_size, web_conf.general.max_sample_size - ) - main_db.set_status(t.id, TASK_BANNED) - continue - force_push = False - try: - # check if node exist and its correct - if options.get("node"): - requested_node = options.get("node") - if requested_node not in STATUSES: - # if the requested node is not available - force_push = True - elif requested_node != node.name: - # otherwise keep looping - continue - if "timeout=" in t.options: - t.timeout = options.get("timeout", 0) - except Exception as e: - log.exception(e) - # wtf are you doing in pendings? - tasks = db.scalars(select(Task).where(Task.main_task_id == t.id)).all() - if tasks: - for task in tasks: - log.info("Deleting incorrectly uploaded file from dist db, main_task_id: %s", t.id) - if node.name == main_server_name: - main_db.set_status(t.id, TASK_RUNNING) - else: - main_db.set_status(t.id, TASK_DISTRIBUTED) - # db.delete(task) - db.commit() + # Prefetch existing distributed tasks for these main tasks to avoid N+1 query + main_task_ids = [t.id for t in main_db_tasks] + existing_dist_tasks = db.scalars(select(Task).where(Task.main_task_id.in_(main_task_ids))).all() + existing_dist_tasks_map = {} + for task in existing_dist_tasks: + existing_dist_tasks_map.setdefault(task.main_task_id, []).append(task) + + tasks_to_push = [] + for t in main_db_tasks: + # somtime big files breaks cape + # print(t.category, t.target) + if t.category in ("file", "pcap", "static"): + if not path_exists(t.target): + log.info("Task id: %d - File doesn't exist: %s", t.id, t.target) + main_db.set_status(t.id, TASK_BANNED) continue - # Convert array of tags into comma separated list - tags = ",".join([tag.name for tag in t.tags]) - # Append a comma, to make LIKE searches more precise - if tags: - tags += "," - - # sanity check - if "x86" in tags and "x64" in tags: - tags = tags.replace("x86,", "") - - if "msoffice-crypt-tmp" in t.target and "password=" in t.options: - # t.options = t.options.replace(f"password={options['password']}", "") - options["password"] - # if options.get("node"): - # t.options = t.options.replace(f"node={options['node']}", "") + if not web_conf.general.allow_ignore_size and "ignore_size_check" not in t.options: + # We can't upload size bigger than X to our workers. In case we extract archive that contains bigger file. + file_size = path_get_size(t.target) + if file_size > web_conf.general.max_sample_size: + log.debug( + "File size: %d is bigger than allowed: %d", file_size, web_conf.general.max_sample_size + ) + main_db.set_status(t.id, TASK_BANNED) + continue + options = get_options(t.options) + # Check if file exist, if no wipe from db and continue, rare cases + force_push = False + try: + # check if node exist and its correct if options.get("node"): - del options["node"] - t.options = ",".join([f"{k}={v}" for k, v in options.items()]) - if t.options: - t.options += "," - t.options += f"main_task_id={t.id}" - args = dict( - package=t.package, - category=t.category, - timeout=t.timeout, - priority=t.priority, - options=t.options, - machine=t.machine, - platform=t.platform, - tags=tags, - custom=t.custom, - memory=t.memory, - clock=t.clock, - enforce_timeout=t.enforce_timeout, - main_task_id=t.id, - route=t.route, - tlp=t.tlp, - ) - task = Task(path=t.target, **args) - db.add(task) - try: - db.commit() - except Exception as e: - log.exception(e) - log.info("TASK_FAILED_REPORTING") - db.rollback() - log.info(e) - continue - if force_push or force_push_push: - # Submit appropriate tasks to node - submitted = node_submit_task(task.id, node.id, t.id) - if submitted: - if node.name == main_server_name: - main_db.set_status(t.id, TASK_RUNNING) - else: - main_db.set_status(t.id, TASK_DISTRIBUTED) - limit += 1 - if limit in (pend_tasks_num, len(main_db_tasks)): - db.commit() - log.info("Pushed all tasks") - return True - - # ToDo not finished - # Only get tasks that have not been pushed yet. - """ - q = db.query(Task).filter(or_(Task.node_id.is_(None), Task.task_id.is_(None)), Task.finished.is_(False)) - if q is None: - db.commit() - return True - - # Order by task priority and task id. - q = q.order_by(-Task.priority, Task.main_task_id) - # if we have node set in options push - - if dist_conf.distributed.enable_tags: - # Create filter query from tasks in ta - tags = [getattr(Task, "tags") == ""] - for tg in SERVER_TAGS[node.name]: - if len(tg.split(",")) == 1: - tags.append(getattr(Task, "tags") == (tg + ",")) + requested_node = options.get("node") + if requested_node not in STATUSES: + # if the requested node is not available + force_push = True + elif requested_node != node.name: + # otherwise keep looping + continue + if "timeout=" in t.options: + t.timeout = options.get("timeout", 0) + except Exception as e: + log.exception(e) + # wtf are you doing in pendings? + tasks = existing_dist_tasks_map.get(t.id, []) + if tasks: + for task in tasks: + log.info("Deleting incorrectly uploaded file from dist db, main_task_id: %s", t.id) + if node.name == main_server_name: + main_db.set_status(t.id, TASK_RUNNING) else: - tg = tg.split(",") - # ie. LIKE "%,%,%," - t_combined = [getattr(Task, "tags").like("%s" % ("%," * len(tg)))] - for tag in tg: - t_combined.append(getattr(Task, "tags").like("%%%s%%" % (tag + ","))) - tags.append(and_(*t_combined)) - # Filter by available tags - q = q.filter(or_(*tags)) - - to_upload = q.limit(pend_tasks_num).all() - """ - # 1. Start with a select() statement and initial filters. - stmt = ( - select(Task) - .where(or_(Task.node_id.is_(None), Task.task_id.is_(None)), Task.finished.is_(False)) - .order_by(Task.priority.desc(), Task.main_task_id) + main_db.set_status(t.id, TASK_DISTRIBUTED) + # db.delete(task) + db.commit() + continue + # Convert array of tags into comma separated list + tags = ",".join([tag.name for tag in t.tags]) + # Append a comma, to make LIKE searches more precise + if tags: + tags += "," + # sanity check + if "x86" in tags and "x64" in tags: + tags = tags.replace("x86,", "") + if "msoffice-crypt-tmp" in t.target and "password=" in t.options: + # t.options = t.options.replace(f"password={options['password']}", "") + options["password"] + # if options.get("node"): + # t.options = t.options.replace(f"node={options['node']}", "") + if options.get("node"): + del options["node"] + t.options = ",".join([f"{k}={v}" for k, v in options.items()]) + if t.options: + t.options += "," + t.options += f"main_task_id={t.id}" + args = dict( + package=t.package, + category=t.category, + timeout=t.timeout, + priority=t.priority, + options=t.options, + machine=t.machine, + platform=t.platform, + tags=tags, + custom=t.custom, + memory=t.memory, + clock=t.clock, + enforce_timeout=t.enforce_timeout, + main_task_id=t.id, + route=t.route, + tlp=t.tlp, ) - # print(stmt, "stmt") - # ToDo broken - """ - # 3. Apply the dynamic tag filter. - if dist_conf.distributed.enable_tags: - tags_conditions = [Task.tags == ""] - for tg in SERVER_TAGS[node.name]: - tags_list = tg.split(",") - if len(tags_list) == 1: - tags_conditions.append(Task.tags == f"{tg},") - else: - # The pattern of building a list of conditions for `and_` or `or_` - # works the same way with the modern .where() clause. - t_combined = [Task.tags.like(f"%{tag},%") for tag in tags_list] - tags_conditions.append(and_(*t_combined)) + task = Task(path=t.target, **args) + db.add(task) + db.commit() + try: + if force_push or force_push_push: + # Add to list for parallel submission + tasks_to_push.append((task, t.id)) + limit += 1 + if limit == pend_tasks_num: + # We will process collected tasks below + break + except Exception as e: + log.exception(e) + log.info("TASK_FAILED_REPORTING") + db.rollback() + log.info(e) + continue - stmt = stmt.where(or_(*tags_conditions)) - """ - # 4. Apply the limit and execute the query. - to_upload = db.scalars(stmt.limit(pend_tasks_num)).all() + # Process collected force_push tasks in parallel + if tasks_to_push: + max_workers = int(dist_conf.distributed.dist_threads) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + # future -> (task, main_task_id) + future_to_info = { + executor.submit(node_submit_task, task.id, node.id, main_task_id, db=None): (task, main_task_id) + for task, main_task_id in tasks_to_push + } + for future in as_completed(future_to_info): + task, main_task_id = future_to_info[future] + try: + submitted = future.result() + if submitted: + if node.name == main_server_name: + main_db.set_status(main_task_id, TASK_RUNNING) + else: + main_db.set_status(main_task_id, TASK_DISTRIBUTED) + except Exception as e: + log.exception("Exception during parallel submission (force_push) for task %d: %s", task.id, e) - if not to_upload: - db.commit() - log.info("nothing to upload? How? o_O") - return False - # Submit appropriate tasks to node - log.debug("going to upload %d tasks to node %s", pend_tasks_num, node.name) - for task in to_upload: - submitted = node_submit_task(task.id, node.id, task.main_task_id) - if submitted: - if node.name == main_server_name: - main_db.set_status(task.main_task_id, TASK_RUNNING) - else: - main_db.set_status(task.main_task_id, TASK_DISTRIBUTED) + db.commit() + log.info("Pushed all tasks") + return True + # ToDo not finished + # Only get tasks that have not been pushed yet. + """ + q = db.query(Task).filter(or_(Task.node_id.is_(None), Task.task_id.is_(None)), Task.finished.is_(False)) + if q is None: + db.commit() + return True + # Order by task priority and task id. + q = q.order_by(-Task.priority, Task.main_task_id) + # if we have node set in options push + if dist_conf.distributed.enable_tags: + # Create filter query from tasks in ta + tags = [getattr(Task, "tags") == ""] + for tg in SERVER_TAGS[node.name]: + if len(tg.split(",")) == 1: + tags.append(getattr(Task, "tags") == (tg + ",")) + else: + tg = tg.split(",") + # ie. LIKE "%,%,%," + t_combined = [getattr(Task, "tags").like("%s" % ("%," * len(tg)))] + for tag in tg: + t_combined.append(getattr(Task, "tags").like("%%%s%%" % (tag + ","))) + tags.append(and_(*t_combined)) + # Filter by available tags + q = q.filter(or_(*tags)) + to_upload = q.limit(pend_tasks_num).all() + """ + # 1. Start with a select() statement and initial filters. + stmt = ( + select(Task) + .where(or_(Task.node_id.is_(None), Task.task_id.is_(None)), Task.finished.is_(False)) + .order_by(Task.priority.desc(), Task.main_task_id) + ) + # print(stmt, "stmt") + # ToDo broken + """ + # 3. Apply the dynamic tag filter. + if dist_conf.distributed.enable_tags: + tags_conditions = [Task.tags == ""] + for tg in SERVER_TAGS[node.name]: + tags_list = tg.split(",") + if len(tags_list) == 1: + tags_conditions.append(Task.tags == f"{tg},") else: - log.info("something is wrong with submission of task: %d", task.id) - db.delete(task) - db.commit() - limit += 1 - if limit == pend_tasks_num: - db.commit() - return True + # The pattern of building a list of conditions for `and_` or `or_` + # works the same way with the modern .where() clause. + t_combined = [Task.tags.like(f"%{tag},%") for tag in tags_list] + tags_conditions.append(and_(*t_combined)) + stmt = stmt.where(or_(*tags_conditions)) + """ + # 4. Apply the limit and execute the query. + to_upload = db.scalars(stmt.limit(pend_tasks_num)).all() + # print(to_upload, node.name, pend_tasks_num) + if not to_upload: + db.commit() + log.info("nothing to upload? How? o_O") + return False + # Parallel execution + max_workers = int(dist_conf.distributed.dist_threads) + with ThreadPoolExecutor(max_workers=max_workers) as executor: + future_to_task = { + executor.submit(node_submit_task, task.id, node.id, task.main_task_id, db=None): task + for task in to_upload + } + count_submitted = 0 + for future in as_completed(future_to_task): + task = future_to_task[future] + try: + submitted = future.result() + if submitted: + if node.name == main_server_name: + main_db.set_status(task.main_task_id, TASK_RUNNING) + else: + main_db.set_status(task.main_task_id, TASK_DISTRIBUTED) + else: + log.info("something is wrong with submission of task: %d", task.id) + # Since task object is bound to 'db' session which is active in this thread, + # we can use it. + main_db.set_status(task.main_task_id, TASK_BANNED) + db.delete(task) + db.commit() + except Exception as e: + log.exception("Exception during parallel submission for task %d: %s", task.id, e) + count_submitted += 1 + # Original logic had early return if limit reached, but we process batch. + db.commit() + return True db.commit() return True @@ -1587,8 +1438,8 @@ def run(self): db = session() master_storage_only = False if not dist_conf.distributed.master_storage_only: - stmt1 = select(Node.id, Node.name, Node.url, Node.apikey).where(Node.name == main_server_name) - master = db.stelar(stmt1) + stmt1 = select(Node).where(Node.name == main_server_name) + master = db.scalar(stmt1) if master is None: master_storage_only = True elif db.scalar(select(func.count(Machine.id)).where(Machine.node_id == master.id)) == 0: @@ -1597,102 +1448,105 @@ def run(self): master_storage_only = True db.close() - # MINIMUMQUEUE but per Node depending of number vms - nodes = db.scalars(select(Node).where(Node.enabled.is_(True))) - for node in nodes: - MINIMUMQUEUE[node.name] = db.scalar(select(func.count(Machine.id)).where(Machine.node_id == node.id)) - ID2NAME[node.id] = node.name - self.load_vm_tags(db, node.id, node.name) + with session() as db: + # MINIMUMQUEUE but per Node depending of number vms + nodes = db.scalars(select(Node).where(Node.enabled.is_(True))) + for node in nodes: + MINIMUMQUEUE[node.name] = db.scalar(select(func.count(Machine.id)).where(Machine.node_id == node.id)) + ID2NAME[node.id] = node.name + self.load_vm_tags(db, node.id, node.name) - db.commit() statuses = {} while True: # HACK: This exception handling here is a big hack as well as db should check if the # there is any issue with the current session (expired or database is down.). try: - # Remove disabled nodes - nodes = db.scalars(select(Node).where(Node.enabled.is_(False))) + with session() as db: + # Remove disabled nodes + nodes = db.scalars(select(Node).where(Node.enabled.is_(False))) + + for node in nodes or []: + if node.name in STATUSES: + STATUSES.pop(node.name) + + # Request a status update on all CAPE nodes. + nodes = db.scalars(select(Node).where(Node.enabled.is_(True))) + for node in nodes: # noqa + status = node_status(node.url, node.name, node.apikey) + if not status: + log.error("No data from node: %s", node.name) + failed_count.setdefault(node.name, 0) + failed_count[node.name] += 1 + # This will declare worker as dead after X failed connections checks + if failed_count[node.name] == dead_count: + log.info("[-] %s dead", node.name) + # node.enabled = False + db.commit() + if node.name in STATUSES: + STATUSES.pop(node.name) - for node in nodes or []: - if node.name in STATUSES: - STATUSES.pop(node.name) + continue + failed_count[node.name] = 0 + # log.info("Status.. %s -> %s", node.name, status["tasks"]) - # Request a status update on all CAPE nodes. - nodes = db.scalars(select(Node).where(Node.enabled.is_(True))) - for node in nodes: - status = node_status(node.url, node.name, node.apikey) - if not status: - failed_count.setdefault(node.name, 0) - failed_count[node.name] += 1 - # This will declare worker as dead after X failed connections checks - if failed_count[node.name] == dead_count: - log.info("[-] %s dead", node.name) - # node.enabled = False - db.commit() - if node.name in STATUSES: - STATUSES.pop(node.name) - continue - failed_count[node.name] = 0 - log.info("Status.. %s -> %s", node.name, status["tasks"]) - statuses[node.name] = status - statuses[node.name]["enabled"] = True - STATUSES = statuses - try: - # first submit tasks with specified node - res = self.submit_tasks( - node.name, - MINIMUMQUEUE[node.name], - options_like=f"node={node.name}", - force_push_push=True, - db=db, - ) - # We return False if nothing uploaded to cicle the nodes in case we have tags related tasks - if not res: + log_line = f"STATUS | {node.name:<15} | Pend: {status['tasks']['pending']:>3} | Run: {status['tasks']['running']:>3} | Done: {status['tasks']['completed']:>3} | Rep: {status['tasks']['reported']:>4}" + log.info(log_line) + + statuses[node.name] = status + statuses[node.name]["enabled"] = True + STATUSES = statuses + try: + # first submit tasks with specified node + res = self.submit_tasks( + node.name, + MINIMUMQUEUE[node.name], + options_like=f"node={node.name}", + force_push_push=True, + db=db, + ) + # We return False if nothing uploaded to cicle the nodes in case we have tags related tasks + if not res: + continue + # Balance the tasks, works fine if no tags are set + node_name = min( + STATUSES, + key=lambda k: STATUSES[k]["tasks"]["completed"] + + STATUSES[k]["tasks"]["pending"] + + STATUSES[k]["tasks"]["running"], + ) + if node_name != node.name: + node = db.scalar(select(Node).where(Node.name == node_name)) + pend_tasks_num = MINIMUMQUEUE[node.name] - ( + STATUSES[node.name]["tasks"]["pending"] + STATUSES[node.name]["tasks"]["running"] + ) + except KeyError: + # servers hotplug + MINIMUMQUEUE[node.name] = db.scalar(select(func.count(Machine.id)).where(Machine.node_id == node.id)) continue - # Balance the tasks, works fine if no tags are set - node_name = min( - STATUSES, - key=lambda k: STATUSES[k]["tasks"]["completed"] - + STATUSES[k]["tasks"]["pending"] - + STATUSES[k]["tasks"]["running"], - ) - if node_name != node.name: - node = db.scalar(select(Node).where(Node.name == node_name)) - pend_tasks_num = MINIMUMQUEUE[node.name] - ( - STATUSES[node.name]["tasks"]["pending"] + STATUSES[node.name]["tasks"]["running"] - ) - except KeyError: - # servers hotplug - MINIMUMQUEUE[node.name] = db.scalar(select(func.count(Machine.id)).where(Machine.node_id == node.id)) - continue - if pend_tasks_num <= 0: - continue - # If - master only used for storage, not check master queue - # elif - master also analyze samples, check master queue - # send tasks to workers if master queue has extra tasks(pending) - if master_storage_only: - res = self.submit_tasks(node.name, pend_tasks_num, db=db) - if not res: + if pend_tasks_num <= 0: continue + # If - master only used for storage, not check master queue + # elif - master also analyze samples, check master queue + # send tasks to workers if master queue has extra tasks(pending) + if master_storage_only: + res = self.submit_tasks(node.name, pend_tasks_num, db=db) + if not res: + continue - elif ( - statuses.get(main_server_name, {}).get("tasks", {}).get("pending", 0) - > MINIMUMQUEUE.get(main_server_name, 0) - and status["tasks"]["pending"] < MINIMUMQUEUE[node.name] - ): - res = self.submit_tasks(node.name, pend_tasks_num, db=db) - if not res: - continue - db.commit() + elif ( + statuses.get(main_server_name, {}).get("tasks", {}).get("pending", 0) + > MINIMUMQUEUE.get(main_server_name, 0) + and status["tasks"]["pending"] < MINIMUMQUEUE[node.name] + ): + res = self.submit_tasks(node.name, pend_tasks_num, db=db) + if not res: + continue + db.commit() except Exception as e: log.exception("Got an exception when trying to check nodes status and submit tasks: %s.", str(e)) - # ToDo hard test this rollback, this normally only happens on db restart and similar - db.rollback() time.sleep(INTERVAL) - db.close() - def output_json(data, code, headers=None): """ @@ -1861,7 +1715,7 @@ def get(self, main_task_id): db = session() task_db = db.scalar(select(Task).where(Task.main_task_id == main_task_id)) if task_db and task_db.node_id: - node_stmt = select(Node.id, Node.name, Node.url, Node.apikey, Node.enabled).where(Node.id == task_db.node_id) + node_stmt = select(Node).where(Node.id == task_db.node_id) node = db.scalar(node_stmt) response = {"status": 1, "task_id": task_db.task_id, "url": node.url, "name": node.name} else: