Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 33 additions & 3 deletions pathwaysutils/experimental/shared_pathways_service/isc_pathways.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,38 @@ class ProxyOptions:
Attributes:
use_insecure_credentials: Whether to use insecure gRPC credentials for the
proxy server.
xla_flags: A list of XLA flags to pass to the proxy server.
"""
use_insecure_credentials: bool = False
xla_flags: list[str] = dataclasses.field(default_factory=list)

@classmethod
def from_list(cls, options: Iterable[str] | None) -> "ProxyOptions":
  """Creates a ProxyOptions object from a list of 'key:value' strings.

  Keys are matched case-insensitively with surrounding whitespace ignored.
  Recognized keys:
    use_insecure_credentials: enabled only when the value is "true" (any
      case); every other value disables it.
    xla_flags: whitespace-separated XLA flags, optionally wrapped in one pair
      of matching single or double quotes.

  Entries without a ':' and entries with an unrecognized key are ignored.
  When a key appears more than once, the last occurrence wins.

  Args:
    options: The 'key:value' strings to parse, or None for all defaults.

  Returns:
    A ProxyOptions populated from the recognized entries.

  Raises:
    Whatever validators.validate_xla_flags raises when a parsed XLA flag is
    rejected.
  """
  use_insecure = False
  xla_flags: list[str] = []

  for option in options or []:
    key, separator, value = option.partition(":")
    if not separator:
      continue  # Not a 'key:value' entry; nothing to parse.

    key = key.strip().lower()
    value = value.strip()
    if key == "use_insecure_credentials":
      use_insecure = value.lower() == "true"
    elif key == "xla_flags":
      # Shells often keep the quotes around a multi-flag value; drop one
      # matching pair before splitting on whitespace.
      if value[:1] in ('"', "'") and value.endswith(value[0]):
        value = value[1:-1]
      xla_flags = value.split()

  # Validate once, after the loop, so only the effective flag list is checked.
  if xla_flags:
    validators.validate_xla_flags(xla_flags)

  return cls(use_insecure_credentials=use_insecure, xla_flags=xla_flags)


def _deploy_pathways_proxy_server(
Expand Down Expand Up @@ -108,6 +127,13 @@ def _deploy_pathways_proxy_server(
' value: "true"\n'
)

proxy_args_str = ""
if proxy_options.xla_flags:
proxy_args_str = "\n".join(
f" - {flag}" for flag in proxy_options.xla_flags
)
proxy_args_str = "\n" + proxy_args_str

template = string.Template(yaml_template)
substituted_yaml = template.substitute(
PROXY_JOB_NAME=proxy_job_name,
Expand All @@ -118,6 +144,7 @@ def _deploy_pathways_proxy_server(
GCS_SCRATCH_LOCATION=gcs_scratch_location,
PROXY_SERVER_IMAGE=proxy_server_image,
PROXY_ENV=proxy_env_str,
PROXY_ARGS=proxy_args_str,
)

_logger.info("Deploying Pathways proxy: %s", proxy_job_name)
Expand Down Expand Up @@ -423,6 +450,7 @@ def connect(
validators.validate_pathways_service(pathways_service)
validators.validate_tpu_instances(expected_tpu_instances)
validators.validate_proxy_server_image(proxy_server_image)
validators.validate_proxy_options(proxy_options)
_logger.info("Validation complete.")
gke_utils.fetch_cluster_credentials(
cluster_name=cluster, project_id=project, location=region
Expand All @@ -433,6 +461,8 @@ def connect(
)}"
)

proxy_options_obj = ProxyOptions.from_list(proxy_options)

_logger.info("Starting ISCPathways context.")
with _ISCPathways(
cluster=cluster,
Expand All @@ -443,7 +473,7 @@ def connect(
expected_tpu_instances=expected_tpu_instances,
proxy_job_name=proxy_job_name,
proxy_server_image=proxy_server_image,
proxy_options=proxy_options,
proxy_options=proxy_options_obj,
collect_service_metrics=collect_service_metrics,
) as t:
if t.proxy_pod_name:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@
"proxy_options",
None,
"Configuration options for the Pathways proxy. Specify entries in the form"
' "key:value". For example: --proxy_options=use_insecure_credentials:true',
' "key:value". For example: --proxy_options=use_insecure_credentials:true'
' or --proxy_options=xla_flags:"--xla_flag1 --xla_flag2"',
)

flags.DEFINE_bool(
"collect_service_metrics",
False,
Expand All @@ -60,8 +62,6 @@ def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")

proxy_options = isc_pathways.ProxyOptions.from_list(FLAGS.proxy_options)

with isc_pathways.connect(
cluster=FLAGS.cluster,
project=FLAGS.project,
Expand All @@ -72,7 +72,7 @@ def main(argv: Sequence[str]) -> None:
proxy_job_name=FLAGS.proxy_job_name,
proxy_server_image=FLAGS.proxy_server_image
or isc_pathways.DEFAULT_PROXY_IMAGE,
proxy_options=proxy_options,
proxy_options=FLAGS.proxy_options,
collect_service_metrics=FLAGS.collect_service_metrics,
):
orig_matrix = jnp.zeros(5)
Expand Down
20 changes: 4 additions & 16 deletions pathwaysutils/experimental/shared_pathways_service/run_workload.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from absl import logging
from pathwaysutils.experimental.shared_pathways_service import isc_pathways


_CLUSTER = flags.DEFINE_string(
"cluster", None, "The name of the GKE cluster.", required=True
)
Expand Down Expand Up @@ -62,7 +61,8 @@
"proxy_options",
[],
"Configuration options for the Pathways proxy. Specify entries in the form"
' "key:value". For example: --proxy_options=use_insecure_credentials:true',
' "key:value". For example: --proxy_options=use_insecure_credentials:true'
' or --proxy_options=xla_flags:"--xla_flag1 --xla_flag2"',
)
_COMMAND = flags.DEFINE_string(
"command", None, "The command to run on TPUs.", required=True
Expand All @@ -76,17 +76,7 @@
" stored in Cloud Monitoring.",
)

flags.register_validator(
"proxy_options",
lambda value: all(
":" in item
and len(item.split(":")) > 1
and item.split(":", 1)[0]
and item.split(":", 1)[1]
for item in value
),
message='--proxy_options must be in the format "key:value".',
)



def run_command(
Expand Down Expand Up @@ -125,8 +115,6 @@ def run_command(
Raises:
subprocess.CalledProcessError: If the workload command fails.
"""
parsed_proxy_options = isc_pathways.ProxyOptions.from_list(proxy_options)

logging.info("Connecting to Shared Pathways Service...")
with connect_fn(
cluster=cluster,
Expand All @@ -140,7 +128,7 @@ def run_command(
if proxy_server_image
else isc_pathways.DEFAULT_PROXY_IMAGE
),
proxy_options=parsed_proxy_options,
proxy_options=proxy_options,
collect_service_metrics=collect_service_metrics,
):
logging.info("Connection established. Running command: %r", command)
Expand Down
30 changes: 29 additions & 1 deletion pathwaysutils/experimental/shared_pathways_service/validators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,30 @@
"""Validation functions for Shared Pathways Service."""

from collections.abc import Mapping
from collections.abc import Iterable, Mapping
import logging
import re
from typing import Any
from absl import flags

_logger = logging.getLogger(__name__)


def validate_proxy_options(proxy_options: Iterable[str] | None) -> None:
"""Validates that proxy options are in the format 'key:value'."""
if not proxy_options:
return
for item in proxy_options:
if (
":" not in item
or len(item.split(":")) <= 1
or not item.split(":", 1)[0]
or not item.split(":", 1)[1]
):
raise flags.ValidationError(
f'--proxy_options must be in the format "key:value". Got: {item}'
)


def validate_pathways_service(pathways_service: str) -> None:
"""Validates the Pathways service name and port."""
if not pathways_service:
Expand Down Expand Up @@ -105,3 +122,14 @@ def validate_proxy_server_image(proxy_server_image: str) -> None:
f"Proxy server image '{proxy_server_image}' must contain a tag with ':'"
" or a digest with '@'."
)


def validate_xla_flags(xla_flags: Iterable[str] | None) -> None:
"""Validates that all XLA flags start with '--xla_'."""
if not xla_flags:
return
for flag in xla_flags:
if not flag.startswith("--xla_"):
raise flags.ValidationError(
f"XLA flag '{flag}' must start with '--xla_'."
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ spec:
- --server_port=${PROXY_SERVER_PORT}
- --resource_manager_address=${PATHWAYS_HEAD_HOSTNAME}:${PATHWAYS_HEAD_PORT}
- --gcs_scratch_location=${GCS_SCRATCH_LOCATION}
- --virtual_slices=${EXPECTED_INSTANCES}
- --virtual_slices=${EXPECTED_INSTANCES}${PROXY_ARGS}
env:
${PROXY_ENV}
ports:
Expand Down
Loading