From da4041ba7b8f03e06d132d6b16634033acb128c8 Mon Sep 17 00:00:00 2001 From: Arnav Date: Wed, 17 Jun 2026 20:37:47 +0530 Subject: [PATCH] Decouple SparkSubmitOperator resumable deployment backends for better maintainability The ResumableJobMixin implementation for SparkSubmitOperator previously had YARN, Kubernetes, and Standalone backend logics interleaved directly inside each mixin method. This scattered per-backend logic across multiple methods. Decoupling these by introducing specialized strategy classes per backend isolates the deployment-specific details, making the operator easier to maintain and extend. --- .../apache/spark/operators/spark_submit.py | 324 +++++++++++------- 1 file changed, 195 insertions(+), 129 deletions(-) diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 7ceb95b387a5d..47f07a4cea1f7 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -59,6 +59,183 @@ def execute_resumable(self, context): from airflow.providers.common.compat.sdk import Context +class _SparkSubmitDeploymentBackend: + """Base interface for private Spark submit deployment backends.""" + + def __init__(self, operator: SparkSubmitOperator, hook: SparkSubmitHook) -> None: + self.operator = operator + self.hook = hook + + def submit_job(self, context: Context) -> str | None: + raise NotImplementedError() + + def get_job_status(self, external_id: str, context: Context) -> str: + raise NotImplementedError() + + def is_job_active(self, status: str) -> bool: + raise NotImplementedError() + + def is_job_succeeded(self, status: str) -> bool: + raise NotImplementedError() + + def poll_until_complete(self, external_id: str, context: Context) -> None: + raise NotImplementedError() + + def on_kill(self) -> None: + raise NotImplementedError() + + +class _KubernetesSparkSubmitBackend(_SparkSubmitDeploymentBackend): + """Logic for tracking Spark driver pods in Kubernetes.""" + + def submit_job(self, context: Context) -> str | None: + self.hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false" + self.hook.submit(self.operator.application) + pod_name = self.hook._kubernetes_driver_pod + namespace = self.hook._connection["namespace"] + if not pod_name: + raise RuntimeError("spark-submit did not capture a K8s driver pod name") + external_id = f"{namespace}:{pod_name}" + self.operator.log.info("Spark K8s driver pod submitted: %s", external_id) + return external_id + + def get_job_status(self, external_id: str, context: Context) -> str: + if (task_state_store := context.get("task_state_store")) is not None: + if (cached := task_state_store.get(self.operator._K8S_DRIVER_STATUS_KEY)) is not None: + if not isinstance(cached, str): + raise ValueError(f"Cached K8s driver status is not a string: {cached!r}") + return cached + if kube_client is None: + raise RuntimeError("apache-airflow-providers-cncf-kubernetes is required to query K8s pod status") + namespace, pod_name = self.operator._parse_k8s_external_id(external_id) + try: + client = kube_client.get_kube_client() + pod = client.read_namespaced_pod(pod_name, namespace) + return pod.status.phase or "Pending" + except kube_client.ApiException as e: + if e.status == 404: + return "NotFound" + raise + + def is_job_active(self, status: str) -> bool: + return status.upper() in ("PENDING", "RUNNING") + + def is_job_succeeded(self, status: str) -> bool: + return status.upper() == "SUCCEEDED" + + def poll_until_complete(self, external_id: str, context: Context) -> None: + if external_id is not None: + _, pod_name = self.operator._parse_k8s_external_id(external_id) + self.hook._kubernetes_driver_pod = pod_name + terminal_phase = self.hook._poll_k8s_driver_via_api() + # Cache only when the pod actually reached Succeeded, the 404/vanished path + # returns None for cases like: pod deleted by on_kill or garbage collected after failure) + # and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission. + if terminal_phase == "Succeeded" and self.operator.reconnect_on_retry: + if (task_state_store := context.get("task_state_store")) is not None: + task_state_store.set(self.operator._K8S_DRIVER_STATUS_KEY, "Succeeded") + + def on_kill(self) -> None: + self.hook.on_kill() + + +class _YarnSparkSubmitBackend(_SparkSubmitDeploymentBackend): + """Logic for tracking Spark applications in YARN cluster mode.""" + + def submit_job(self, context: Context) -> str | None: + if self.hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true": + raise ValueError( + "spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts" + "with the need to exit spark-submit immediately to persist the application ID for tracking. " + "Either remove the explicit conf or set reconnect_on_retry=False." + ) + self.hook._conf["spark.yarn.submit.waitAppCompletion"] = "false" + self.hook.submit(self.operator.application) + app_id = self.hook._yarn_application_id + if not app_id: + raise RuntimeError("spark-submit did not produce a YARN application ID") + self.operator.log.info("YARN application submitted: %s", app_id) + return app_id + + def get_job_status(self, external_id: str, context: Context) -> str: + return self.hook.query_yarn_application_status(external_id) + + def is_job_active(self, status: str) -> bool: + # https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html + return status.upper() in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"} + + def is_job_succeeded(self, status: str) -> bool: + return status.upper() == "SUCCEEDED" + + def poll_until_complete(self, external_id: str, context: Context) -> None: + try: + self.hook._start_yarn_application_status_tracking(external_id) + finally: + self.hook._run_post_submit_commands() + + def on_kill(self) -> None: + if self.hook._yarn_application_id: + # spark-submit has already exited (waitAppCompletion=false), so the hook's + # CLI-based kill has nothing to terminate. Kill the YARN app via REST API instead. + self.hook._kill_yarn_application(self.hook._yarn_application_id) + else: + self.hook.on_kill() + + +class _StandaloneSparkSubmitBackend(_SparkSubmitDeploymentBackend): + """Logic for tracking Spark driver status in Spark standalone mode.""" + + def submit_job(self, context: Context) -> str | None: + driver_id = self.hook.submit(self.operator.application) + if not driver_id: + raise RuntimeError("spark-submit did not return a driver ID") + self.operator.log.info("Spark driver submitted: %s", driver_id) + return driver_id + + def get_job_status(self, external_id: str, context: Context) -> str: + scheme = self.hook._connection.get("rest_scheme", "http") + rest_port = self.hook._connection.get("rest_port", 6066) + # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order. + # The master URL port (e.g. 7077) is the RPC port — not the REST API port. + # Use rest-port connection extra to override spark.master.rest.port (default 6066). + master_urls = self.hook._connection["master"].replace("spark://", "").split(",") + last_exc: Exception = RuntimeError("No Spark masters to query") + for m in master_urls: + host = m.strip().split(":")[0] + url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}" + try: + status = self.operator._fetch_driver_status(url, external_id) + return status + except Exception as e: + self.operator.log.warning("Could not reach Spark master %s: %s", host, e) + last_exc = e + raise last_exc + + def is_job_active(self, status: str) -> bool: + # RELAUNCHING: driver is being restarted after a failure, still alive. + # UNKNOWN: master is in failure recovery, state is temporarily unavailable. + # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala + return status.upper() in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN") + + def is_job_succeeded(self, status: str) -> bool: + # standalone and YARN both use FINISHED + return status.upper() == "FINISHED" + + def poll_until_complete(self, external_id: str, context: Context) -> None: + self.operator.log.info("Polling driver %s until completion", external_id) + self.hook._driver_id = external_id + try: + self.hook._start_driver_status_tracking() + if self.hook._driver_status != "FINISHED": + raise RuntimeError(f"Driver {external_id} exited with status {self.hook._driver_status}") + finally: + # post-submit commands must fire whether the job succeeded or failed. + self.hook._run_post_submit_commands() + + def on_kill(self) -> None: + self.hook.on_kill() + + class SparkSubmitOperator(ResumableJobMixin, BaseOperator): """ Wrap the spark-submit binary to kick off a spark-submit job; requires "spark-submit" binary in the PATH. @@ -301,82 +478,26 @@ def execute(self, context: Context) -> None: return self.get_job_result(driver_id, context) hook.submit(self.application) - def submit_job(self, context: Context) -> str | None: + @property + def _backend(self) -> _SparkSubmitDeploymentBackend: if self._hook is None: self._hook = self._get_hook() - if self._hook._is_kubernetes: - self._hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false" - self._hook.submit(self.application) - pod_name = self._hook._kubernetes_driver_pod - namespace = self._hook._connection["namespace"] - if not pod_name: - raise RuntimeError("spark-submit did not capture a K8s driver pod name") - external_id = f"{namespace}:{pod_name}" - self.log.info("Spark K8s driver pod submitted: %s", external_id) - return external_id - if self._hook._is_yarn_cluster_mode: - if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true": - raise ValueError( - "spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts" - "with the need to exit spark-submit immediately to persist the application ID for tracking. " - "Either remove the explicit conf or set reconnect_on_retry=False." - ) - self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false" - self._hook.submit(self.application) - app_id = self._hook._yarn_application_id - if not app_id: - raise RuntimeError("spark-submit did not produce a YARN application ID") - self.log.info("YARN application submitted: %s", app_id) - return app_id - driver_id = self._hook.submit(self.application) - if not driver_id: - raise RuntimeError("spark-submit did not return a driver ID") - self.log.info("Spark driver submitted: %s", driver_id) - return driver_id + if not hasattr(self, "__backend") or self.__backend is None or self.__backend.hook is not self._hook: + if self._hook._is_yarn_cluster_mode: + self.__backend = _YarnSparkSubmitBackend(self, self._hook) + elif self._hook._is_kubernetes: + self.__backend = _KubernetesSparkSubmitBackend(self, self._hook) + else: + self.__backend = _StandaloneSparkSubmitBackend(self, self._hook) + return self.__backend + + def submit_job(self, context: Context) -> str | None: + return self._backend.submit_job(context) def get_job_status(self, external_id: JsonValue, context: Context) -> str: # called from submit_job which always returns a str (Spark driver IDs are strings) external_id = cast("str", external_id) - if self._hook is None: - self._hook = self._get_hook() - if self._hook._is_yarn_cluster_mode: - return self._hook.query_yarn_application_status(external_id) - if self._hook._is_kubernetes: - if (task_state_store := context.get("task_state_store")) is not None: - if (cached := task_state_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None: - if not isinstance(cached, str): - raise ValueError(f"Cached K8s driver status is not a string: {cached!r}") - return cached - if kube_client is None: - raise RuntimeError( - "apache-airflow-providers-cncf-kubernetes is required to query K8s pod status" - ) - namespace, pod_name = self._parse_k8s_external_id(external_id) - try: - client = kube_client.get_kube_client() - pod = client.read_namespaced_pod(pod_name, namespace) - return pod.status.phase or "Pending" - except kube_client.ApiException as e: - if e.status == 404: - return "NotFound" - raise - scheme = self._hook._connection.get("rest_scheme", "http") - rest_port = self._hook._connection.get("rest_port", 6066) - # HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order. - # The master URL port (e.g. 7077) is the RPC port — not the REST API port. - # Use rest-port connection extra to override spark.master.rest.port (default 6066). - master_urls = self._hook._connection["master"].replace("spark://", "").split(",") - last_exc: Exception = RuntimeError("No Spark masters to query") - for m in master_urls: - host = m.strip().split(":")[0] - url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}" - try: - status = self._fetch_driver_status(url, external_id) - return status - except Exception as e: - self.log.warning("Could not reach Spark master %s: %s", host, e) - last_exc = e - raise last_exc + return self._backend.get_job_status(external_id, context) @staticmethod def _parse_k8s_external_id(external_id: str) -> tuple[str, str]: @@ -402,76 +523,21 @@ def _fetch_driver_status(self, url: str, external_id: str) -> str: return status def is_job_active(self, status: str) -> bool: - if self._hook is None: - self._hook = self._get_hook() - status = status.upper() - if self._hook._is_yarn_cluster_mode: - # https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html - return status in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"} - if self._hook._is_kubernetes: - return status in ("PENDING", "RUNNING") - # RELAUNCHING: driver is being restarted after a failure, still alive. - # UNKNOWN: master is in failure recovery, state is temporarily unavailable. - # https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala - return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN") + return self._backend.is_job_active(status) def is_job_succeeded(self, status: str) -> bool: - if self._hook is None: - self._hook = self._get_hook() - status = status.upper() - if self._hook._is_yarn_cluster_mode: - return status == "SUCCEEDED" - if self._hook._is_kubernetes: - return status == "SUCCEEDED" - # standalone and YARN both use FINISHED - return status == "FINISHED" + return self._backend.is_job_succeeded(status) def poll_until_complete(self, external_id: JsonValue, context: Context) -> None: # called from submit_job which always returns a str (Spark driver IDs are strings) external_id = cast("str", external_id) - if self._hook is None: - self._hook = self._get_hook() - if self._hook._is_yarn_cluster_mode: - try: - self._hook._start_yarn_application_status_tracking(external_id) - finally: - self._hook._run_post_submit_commands() - return - if self._hook._is_kubernetes: - if external_id is not None: - _, pod_name = self._parse_k8s_external_id(external_id) - self._hook._kubernetes_driver_pod = pod_name - terminal_phase = self._hook._poll_k8s_driver_via_api() - # Cache only when the pod actually reached Succeeded, the 404/vanished path - # returns None for cases like: pod deleted by on_kill or garbage collected after failure) - # and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission. - if terminal_phase == "Succeeded" and self.reconnect_on_retry: - if (task_state_store := context.get("task_state_store")) is not None: - task_state_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded") - return - - self.log.info("Polling driver %s until completion", external_id) - self._hook._driver_id = external_id - try: - self._hook._start_driver_status_tracking() - if self._hook._driver_status != "FINISHED": - raise RuntimeError(f"Driver {external_id} exited with status {self._hook._driver_status}") - finally: - # post-submit commands must fire whether the job succeeded or failed. - self._hook._run_post_submit_commands() + self._backend.poll_until_complete(external_id, context) def get_job_result(self, external_id: JsonValue, context: Context) -> None: return None def on_kill(self) -> None: - if self._hook is None: - self._hook = self._get_hook() - if self._hook._is_yarn_cluster_mode and self._hook._yarn_application_id: - # spark-submit has already exited (waitAppCompletion=false), so the hook's - # CLI-based kill has nothing to terminate. Kill the YARN app via REST API instead. - self._hook._kill_yarn_application(self._hook._yarn_application_id) - else: - self._hook.on_kill() + self._backend.on_kill() def _get_hook(self) -> SparkSubmitHook: return SparkSubmitHook(