Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,183 @@ def execute_resumable(self, context):
from airflow.providers.common.compat.sdk import Context


class _SparkSubmitDeploymentBackend:
"""Base interface for private Spark submit deployment backends."""

def __init__(self, operator: SparkSubmitOperator, hook: SparkSubmitHook) -> None:
self.operator = operator
self.hook = hook

def submit_job(self, context: Context) -> str | None:
raise NotImplementedError()

def get_job_status(self, external_id: str, context: Context) -> str:
raise NotImplementedError()

def is_job_active(self, status: str) -> bool:
raise NotImplementedError()

def is_job_succeeded(self, status: str) -> bool:
raise NotImplementedError()

def poll_until_complete(self, external_id: str, context: Context) -> None:
raise NotImplementedError()

def on_kill(self) -> None:
raise NotImplementedError()


class _KubernetesSparkSubmitBackend(_SparkSubmitDeploymentBackend):
"""Logic for tracking Spark driver pods in Kubernetes."""

def submit_job(self, context: Context) -> str | None:
self.hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false"
self.hook.submit(self.operator.application)
pod_name = self.hook._kubernetes_driver_pod
namespace = self.hook._connection["namespace"]
if not pod_name:
raise RuntimeError("spark-submit did not capture a K8s driver pod name")
external_id = f"{namespace}:{pod_name}"
self.operator.log.info("Spark K8s driver pod submitted: %s", external_id)
return external_id

def get_job_status(self, external_id: str, context: Context) -> str:
if (task_state_store := context.get("task_state_store")) is not None:
if (cached := task_state_store.get(self.operator._K8S_DRIVER_STATUS_KEY)) is not None:
if not isinstance(cached, str):
raise ValueError(f"Cached K8s driver status is not a string: {cached!r}")
return cached
if kube_client is None:
raise RuntimeError("apache-airflow-providers-cncf-kubernetes is required to query K8s pod status")
namespace, pod_name = self.operator._parse_k8s_external_id(external_id)
try:
client = kube_client.get_kube_client()
pod = client.read_namespaced_pod(pod_name, namespace)
return pod.status.phase or "Pending"
except kube_client.ApiException as e:
if e.status == 404:
return "NotFound"
raise

def is_job_active(self, status: str) -> bool:
return status.upper() in ("PENDING", "RUNNING")

def is_job_succeeded(self, status: str) -> bool:
return status.upper() == "SUCCEEDED"

def poll_until_complete(self, external_id: str, context: Context) -> None:
if external_id is not None:
_, pod_name = self.operator._parse_k8s_external_id(external_id)
self.hook._kubernetes_driver_pod = pod_name
terminal_phase = self.hook._poll_k8s_driver_via_api()
# Cache only when the pod actually reached Succeeded, the 404/vanished path
# returns None for cases like: pod deleted by on_kill or garbage collected after failure)
# and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission.
if terminal_phase == "Succeeded" and self.operator.reconnect_on_retry:
if (task_state_store := context.get("task_state_store")) is not None:
task_state_store.set(self.operator._K8S_DRIVER_STATUS_KEY, "Succeeded")

def on_kill(self) -> None:
self.hook.on_kill()


class _YarnSparkSubmitBackend(_SparkSubmitDeploymentBackend):
"""Logic for tracking Spark applications in YARN cluster mode."""

def submit_job(self, context: Context) -> str | None:
if self.hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true":
raise ValueError(
"spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts"
"with the need to exit spark-submit immediately to persist the application ID for tracking. "
"Either remove the explicit conf or set reconnect_on_retry=False."
)
self.hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
self.hook.submit(self.operator.application)
app_id = self.hook._yarn_application_id
if not app_id:
raise RuntimeError("spark-submit did not produce a YARN application ID")
self.operator.log.info("YARN application submitted: %s", app_id)
return app_id

def get_job_status(self, external_id: str, context: Context) -> str:
return self.hook.query_yarn_application_status(external_id)

def is_job_active(self, status: str) -> bool:
# https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
return status.upper() in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"}

def is_job_succeeded(self, status: str) -> bool:
return status.upper() == "SUCCEEDED"

def poll_until_complete(self, external_id: str, context: Context) -> None:
try:
self.hook._start_yarn_application_status_tracking(external_id)
finally:
self.hook._run_post_submit_commands()

def on_kill(self) -> None:
if self.hook._yarn_application_id:
# spark-submit has already exited (waitAppCompletion=false), so the hook's
# CLI-based kill has nothing to terminate. Kill the YARN app via REST API instead.
self.hook._kill_yarn_application(self.hook._yarn_application_id)
else:
self.hook.on_kill()


class _StandaloneSparkSubmitBackend(_SparkSubmitDeploymentBackend):
"""Logic for tracking Spark driver status in Spark standalone mode."""

def submit_job(self, context: Context) -> str | None:
driver_id = self.hook.submit(self.operator.application)
if not driver_id:
raise RuntimeError("spark-submit did not return a driver ID")
self.operator.log.info("Spark driver submitted: %s", driver_id)
return driver_id

def get_job_status(self, external_id: str, context: Context) -> str:
scheme = self.hook._connection.get("rest_scheme", "http")
rest_port = self.hook._connection.get("rest_port", 6066)
# HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
# The master URL port (e.g. 7077) is the RPC port — not the REST API port.
# Use rest-port connection extra to override spark.master.rest.port (default 6066).
master_urls = self.hook._connection["master"].replace("spark://", "").split(",")
last_exc: Exception = RuntimeError("No Spark masters to query")
for m in master_urls:
host = m.strip().split(":")[0]
url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
try:
status = self.operator._fetch_driver_status(url, external_id)
return status
except Exception as e:
self.operator.log.warning("Could not reach Spark master %s: %s", host, e)
last_exc = e
raise last_exc

def is_job_active(self, status: str) -> bool:
# RELAUNCHING: driver is being restarted after a failure, still alive.
# UNKNOWN: master is in failure recovery, state is temporarily unavailable.
# https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
return status.upper() in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")

def is_job_succeeded(self, status: str) -> bool:
# standalone and YARN both use FINISHED
return status.upper() == "FINISHED"

def poll_until_complete(self, external_id: str, context: Context) -> None:
self.operator.log.info("Polling driver %s until completion", external_id)
self.hook._driver_id = external_id
try:
self.hook._start_driver_status_tracking()
if self.hook._driver_status != "FINISHED":
raise RuntimeError(f"Driver {external_id} exited with status {self.hook._driver_status}")
finally:
# post-submit commands must fire whether the job succeeded or failed.
self.hook._run_post_submit_commands()

def on_kill(self) -> None:
self.hook.on_kill()


class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
"""
Wrap the spark-submit binary to kick off a spark-submit job; requires "spark-submit" binary in the PATH.
Expand Down Expand Up @@ -301,82 +478,26 @@ def execute(self, context: Context) -> None:
return self.get_job_result(driver_id, context)
hook.submit(self.application)

def submit_job(self, context: Context) -> str | None:
@property
def _backend(self) -> _SparkSubmitDeploymentBackend:
if self._hook is None:
self._hook = self._get_hook()
if self._hook._is_kubernetes:
self._hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false"
self._hook.submit(self.application)
pod_name = self._hook._kubernetes_driver_pod
namespace = self._hook._connection["namespace"]
if not pod_name:
raise RuntimeError("spark-submit did not capture a K8s driver pod name")
external_id = f"{namespace}:{pod_name}"
self.log.info("Spark K8s driver pod submitted: %s", external_id)
return external_id
if self._hook._is_yarn_cluster_mode:
if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", "").strip().lower() == "true":
raise ValueError(
"spark.yarn.submit.waitAppCompletion=true cannot be set for cluster mode as it conflicts"
"with the need to exit spark-submit immediately to persist the application ID for tracking. "
"Either remove the explicit conf or set reconnect_on_retry=False."
)
self._hook._conf["spark.yarn.submit.waitAppCompletion"] = "false"
self._hook.submit(self.application)
app_id = self._hook._yarn_application_id
if not app_id:
raise RuntimeError("spark-submit did not produce a YARN application ID")
self.log.info("YARN application submitted: %s", app_id)
return app_id
driver_id = self._hook.submit(self.application)
if not driver_id:
raise RuntimeError("spark-submit did not return a driver ID")
self.log.info("Spark driver submitted: %s", driver_id)
return driver_id
if not hasattr(self, "__backend") or self.__backend is None or self.__backend.hook is not self._hook:
if self._hook._is_yarn_cluster_mode:
self.__backend = _YarnSparkSubmitBackend(self, self._hook)
elif self._hook._is_kubernetes:
self.__backend = _KubernetesSparkSubmitBackend(self, self._hook)
else:
self.__backend = _StandaloneSparkSubmitBackend(self, self._hook)
return self.__backend

def submit_job(self, context: Context) -> str | None:
return self._backend.submit_job(context)

def get_job_status(self, external_id: JsonValue, context: Context) -> str:
# called from submit_job which always returns a str (Spark driver IDs are strings)
external_id = cast("str", external_id)
if self._hook is None:
self._hook = self._get_hook()
if self._hook._is_yarn_cluster_mode:
return self._hook.query_yarn_application_status(external_id)
if self._hook._is_kubernetes:
if (task_state_store := context.get("task_state_store")) is not None:
if (cached := task_state_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None:
if not isinstance(cached, str):
raise ValueError(f"Cached K8s driver status is not a string: {cached!r}")
return cached
if kube_client is None:
raise RuntimeError(
"apache-airflow-providers-cncf-kubernetes is required to query K8s pod status"
)
namespace, pod_name = self._parse_k8s_external_id(external_id)
try:
client = kube_client.get_kube_client()
pod = client.read_namespaced_pod(pod_name, namespace)
return pod.status.phase or "Pending"
except kube_client.ApiException as e:
if e.status == 404:
return "NotFound"
raise
scheme = self._hook._connection.get("rest_scheme", "http")
rest_port = self._hook._connection.get("rest_port", 6066)
# HA master URLs can look like spark://m1:7077,m2:7077 — try each host in order.
# The master URL port (e.g. 7077) is the RPC port — not the REST API port.
# Use rest-port connection extra to override spark.master.rest.port (default 6066).
master_urls = self._hook._connection["master"].replace("spark://", "").split(",")
last_exc: Exception = RuntimeError("No Spark masters to query")
for m in master_urls:
host = m.strip().split(":")[0]
url = f"{scheme}://{host}:{rest_port}/v1/submissions/status/{external_id}"
try:
status = self._fetch_driver_status(url, external_id)
return status
except Exception as e:
self.log.warning("Could not reach Spark master %s: %s", host, e)
last_exc = e
raise last_exc
return self._backend.get_job_status(external_id, context)

@staticmethod
def _parse_k8s_external_id(external_id: str) -> tuple[str, str]:
Expand All @@ -402,76 +523,21 @@ def _fetch_driver_status(self, url: str, external_id: str) -> str:
return status

def is_job_active(self, status: str) -> bool:
if self._hook is None:
self._hook = self._get_hook()
status = status.upper()
if self._hook._is_yarn_cluster_mode:
# https://hadoop.apache.org/docs/stable/hadoop-yarn/hadoop-yarn-site/ResourceManagerRest.html
return status in {"NEW", "NEW_SAVING", "SUBMITTED", "ACCEPTED", "RUNNING"}
if self._hook._is_kubernetes:
return status in ("PENDING", "RUNNING")
# RELAUNCHING: driver is being restarted after a failure, still alive.
# UNKNOWN: master is in failure recovery, state is temporarily unavailable.
# https://github.com/apache/spark/blob/master/core/src/main/scala/org/apache/spark/deploy/master/DriverState.scala
return status in ("SUBMITTED", "RUNNING", "RELAUNCHING", "UNKNOWN")
return self._backend.is_job_active(status)

def is_job_succeeded(self, status: str) -> bool:
if self._hook is None:
self._hook = self._get_hook()
status = status.upper()
if self._hook._is_yarn_cluster_mode:
return status == "SUCCEEDED"
if self._hook._is_kubernetes:
return status == "SUCCEEDED"
# standalone and YARN both use FINISHED
return status == "FINISHED"
return self._backend.is_job_succeeded(status)

def poll_until_complete(self, external_id: JsonValue, context: Context) -> None:
# called from submit_job which always returns a str (Spark driver IDs are strings)
external_id = cast("str", external_id)
if self._hook is None:
self._hook = self._get_hook()
if self._hook._is_yarn_cluster_mode:
try:
self._hook._start_yarn_application_status_tracking(external_id)
finally:
self._hook._run_post_submit_commands()
return
if self._hook._is_kubernetes:
if external_id is not None:
_, pod_name = self._parse_k8s_external_id(external_id)
self._hook._kubernetes_driver_pod = pod_name
terminal_phase = self._hook._poll_k8s_driver_via_api()
# Cache only when the pod actually reached Succeeded, the 404/vanished path
# returns None for cases like: pod deleted by on_kill or garbage collected after failure)
# and must not be cached, otherwise a retry would see "Succeeded" and skip resubmission.
if terminal_phase == "Succeeded" and self.reconnect_on_retry:
if (task_state_store := context.get("task_state_store")) is not None:
task_state_store.set(self._K8S_DRIVER_STATUS_KEY, "Succeeded")
return

self.log.info("Polling driver %s until completion", external_id)
self._hook._driver_id = external_id
try:
self._hook._start_driver_status_tracking()
if self._hook._driver_status != "FINISHED":
raise RuntimeError(f"Driver {external_id} exited with status {self._hook._driver_status}")
finally:
# post-submit commands must fire whether the job succeeded or failed.
self._hook._run_post_submit_commands()
self._backend.poll_until_complete(external_id, context)

def get_job_result(self, external_id: JsonValue, context: Context) -> None:
return None

def on_kill(self) -> None:
if self._hook is None:
self._hook = self._get_hook()
if self._hook._is_yarn_cluster_mode and self._hook._yarn_application_id:
# spark-submit has already exited (waitAppCompletion=false), so the hook's
# CLI-based kill has nothing to terminate. Kill the YARN app via REST API instead.
self._hook._kill_yarn_application(self._hook._yarn_application_id)
else:
self._hook.on_kill()
self._backend.on_kill()

def _get_hook(self) -> SparkSubmitHook:
return SparkSubmitHook(
Expand Down
Loading