From 37a870b45c8d2ffae81a4c4b94524bbd195ad813 Mon Sep 17 00:00:00 2001 From: Akanksha Gupta Date: Wed, 29 Apr 2026 16:15:43 -0700 Subject: [PATCH] Add support for passing XLA flags to the Pathways proxy This change introduces an `xla_flags` option within `ProxyOptions`, allowing users to specify XLA flags via the `--proxy_options` command-line argument. The flags are parsed, validated to ensure they start with "--xla_", and then injected into the Kubernetes YAML for the proxy server. New validator functions are added to handle the parsing and validation logic. PiperOrigin-RevId: 907831230 --- .../shared_pathways_service/isc_pathways.py | 36 +++++++++++++++++-- .../run_connect_example.py | 8 ++--- .../shared_pathways_service/run_workload.py | 20 +++-------- .../shared_pathways_service/validators.py | 30 +++++++++++++++- .../yamls/pw-proxy.yaml | 2 +- 5 files changed, 71 insertions(+), 25 deletions(-) diff --git a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py index ed3a4cc..9af27e4 100644 --- a/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py +++ b/pathwaysutils/experimental/shared_pathways_service/isc_pathways.py @@ -47,19 +47,38 @@ class ProxyOptions: Attributes: use_insecure_credentials: Whether to use insecure gRPC credentials for the proxy server. + xla_flags: A list of XLA flags to pass to the proxy server. """ use_insecure_credentials: bool = False + xla_flags: list[str] = dataclasses.field(default_factory=list) @classmethod def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions": """Creates a ProxyOptions object from a list of 'key:value' strings.""" use_insecure = False + xla_flags = [] for option in options or []: if ":" in option: key, value = option.split(":", 1) - if key.strip().lower() == "use_insecure_credentials": + key_strip = key.strip().lower() + if key_strip == "use_insecure_credentials": use_insecure = value.strip().lower() == "true" - return cls(use_insecure_credentials=use_insecure) + elif key_strip == "xla_flags": + val_strip = value.strip() + if ( + val_strip + and val_strip.startswith(('"', "'")) + and val_strip.endswith(val_strip[0]) + ): + val_to_split = val_strip[1:-1] + else: + val_to_split = val_strip + xla_flags = val_to_split.split() + + if xla_flags: + validators.validate_xla_flags(xla_flags) + + return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags) def _deploy_pathways_proxy_server( @@ -108,6 +127,13 @@ def _deploy_pathways_proxy_server( ' value: "true"\n' ) + proxy_args_str = "" + if proxy_options.xla_flags: + proxy_args_str = "\n".join( + f" - {flag}" for flag in proxy_options.xla_flags + ) + proxy_args_str = "\n" + proxy_args_str + template = string.Template(yaml_template) substituted_yaml = template.substitute( PROXY_JOB_NAME=proxy_job_name, @@ -118,6 +144,7 @@ def _deploy_pathways_proxy_server( GCS_SCRATCH_LOCATION=gcs_scratch_location, PROXY_SERVER_IMAGE=proxy_server_image, PROXY_ENV=proxy_env_str, + PROXY_ARGS=proxy_args_str, ) _logger.info("Deploying Pathways proxy: %s", proxy_job_name) @@ -423,6 +450,7 @@ def connect( validators.validate_pathways_service(pathways_service) validators.validate_tpu_instances(expected_tpu_instances) validators.validate_proxy_server_image(proxy_server_image) + validators.validate_proxy_options(proxy_options) _logger.info("Validation complete.") gke_utils.fetch_cluster_credentials( cluster_name=cluster, project_id=project, location=region @@ -433,6 +461,8 @@ def connect( )}" ) + proxy_options_obj = ProxyOptions.from_list(proxy_options) + _logger.info("Starting ISCPathways context.") with _ISCPathways( cluster=cluster, @@ -443,7 +473,7 @@ def connect( expected_tpu_instances=expected_tpu_instances, proxy_job_name=proxy_job_name, proxy_server_image=proxy_server_image, - proxy_options=proxy_options, + proxy_options=proxy_options_obj, collect_service_metrics=collect_service_metrics, ) as t: if t.proxy_pod_name: diff --git a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py index f07220e..f14df73 100644 --- a/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py +++ b/pathwaysutils/experimental/shared_pathways_service/run_connect_example.py @@ -39,8 +39,10 @@ "proxy_options", None, "Configuration options for the Pathways proxy. Specify entries in the form" - ' "key:value". For example: --proxy_options=use_insecure_credentials:true', + ' "key:value". For example: --proxy_options=use_insecure_credentials:true' + ' or --proxy_options=xla_flags:"--xla_flag1 --xla_flag2"', ) + flags.DEFINE_bool( "collect_service_metrics", False, @@ -60,8 +62,6 @@ def main(argv: Sequence[str]) -> None: if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") - proxy_options = isc_pathways.ProxyOptions.from_list(FLAGS.proxy_options) - with isc_pathways.connect( cluster=FLAGS.cluster, project=FLAGS.project, @@ -72,7 +72,7 @@ def main(argv: Sequence[str]) -> None: proxy_job_name=FLAGS.proxy_job_name, proxy_server_image=FLAGS.proxy_server_image or isc_pathways.DEFAULT_PROXY_IMAGE, - proxy_options=proxy_options, + proxy_options=FLAGS.proxy_options, collect_service_metrics=FLAGS.collect_service_metrics, ): orig_matrix = jnp.zeros(5) diff --git a/pathwaysutils/experimental/shared_pathways_service/run_workload.py b/pathwaysutils/experimental/shared_pathways_service/run_workload.py index c666f52..be2e4bf 100644 --- a/pathwaysutils/experimental/shared_pathways_service/run_workload.py +++ b/pathwaysutils/experimental/shared_pathways_service/run_workload.py @@ -28,7 +28,6 @@ from absl import logging from pathwaysutils.experimental.shared_pathways_service import isc_pathways - _CLUSTER = flags.DEFINE_string( "cluster", None, "The name of the GKE cluster.", required=True ) @@ -62,7 +61,8 @@ "proxy_options", [], "Configuration options for the Pathways proxy. Specify entries in the form" - ' "key:value". For example: --proxy_options=use_insecure_credentials:true', + ' "key:value". For example: --proxy_options=use_insecure_credentials:true' + ' or --proxy_options=xla_flags:"--xla_flag1 --xla_flag2"', ) _COMMAND = flags.DEFINE_string( "command", None, "The command to run on TPUs.", required=True @@ -76,17 +76,7 @@ " stored in Cloud Monitoring.", ) -flags.register_validator( - "proxy_options", - lambda value: all( - ":" in item - and len(item.split(":")) > 1 - and item.split(":", 1)[0] - and item.split(":", 1)[1] - for item in value - ), - message='--proxy_options must be in the format "key:value".', -) + def run_command( @@ -125,8 +115,6 @@ def run_command( Raises: subprocess.CalledProcessError: If the workload command fails. """ - parsed_proxy_options = isc_pathways.ProxyOptions.from_list(proxy_options) - logging.info("Connecting to Shared Pathways Service...") with connect_fn( cluster=cluster, @@ -140,7 +128,7 @@ def run_command( if proxy_server_image else isc_pathways.DEFAULT_PROXY_IMAGE ), - proxy_options=parsed_proxy_options, + proxy_options=proxy_options, collect_service_metrics=collect_service_metrics, ): logging.info("Connection established. Running command: %r", command) diff --git a/pathwaysutils/experimental/shared_pathways_service/validators.py b/pathwaysutils/experimental/shared_pathways_service/validators.py index f5a6216..818aaba 100644 --- a/pathwaysutils/experimental/shared_pathways_service/validators.py +++ b/pathwaysutils/experimental/shared_pathways_service/validators.py @@ -1,13 +1,30 @@ """Validation functions for Shared Pathways Service.""" -from collections.abc import Mapping +from collections.abc import Iterable, Mapping import logging import re from typing import Any +from absl import flags _logger = logging.getLogger(__name__) +def validate_proxy_options(proxy_options: Iterable[str] | None) -> None: + """Validates that proxy options are in the format 'key:value'.""" + if not proxy_options: + return + for item in proxy_options: + if ( + ":" not in item + or len(item.split(":")) <= 1 + or not item.split(":", 1)[0] + or not item.split(":", 1)[1] + ): + raise flags.ValidationError( + f'--proxy_options must be in the format "key:value". Got: {item}' + ) + + def validate_pathways_service(pathways_service: str) -> None: """Validates the Pathways service name and port.""" if not pathways_service: @@ -105,3 +122,14 @@ def validate_proxy_server_image(proxy_server_image: str) -> None: f"Proxy server image '{proxy_server_image}' must contain a tag with ':'" " or a digest with '@'." ) + + +def validate_xla_flags(xla_flags: Iterable[str] | None) -> None: + """Validates that all XLA flags start with '--xla_'.""" + if not xla_flags: + return + for flag in xla_flags: + if not flag.startswith("--xla_"): + raise flags.ValidationError( + f"XLA flag '{flag}' must start with '--xla_'." + ) diff --git a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml index 91d6aa4..e0ee244 100644 --- a/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml +++ b/pathwaysutils/experimental/shared_pathways_service/yamls/pw-proxy.yaml @@ -20,7 +20,7 @@ spec: - --server_port=${PROXY_SERVER_PORT} - --resource_manager_address=${PATHWAYS_HEAD_HOSTNAME}:${PATHWAYS_HEAD_PORT} - --gcs_scratch_location=${GCS_SCRATCH_LOCATION} - - --virtual_slices=${EXPECTED_INSTANCES} + - --virtual_slices=${EXPECTED_INSTANCES}${PROXY_ARGS} env: ${PROXY_ENV} ports: