From 69aad932813d7f6ecc85b0823f3d81f40cf865a4 Mon Sep 17 00:00:00 2001 From: Abhinav Singh Date: Thu, 18 Jun 2026 13:41:08 -0700 Subject: [PATCH] Add support for vscode and gcs for sps environments. PiperOrigin-RevId: 934528376 --- .../scripts/jupyter_setup.sh | 11 + .../scripts/vscode_setup.sh | 71 ++++ .../start_vscode_on_cpu_np.py | 350 ++++++++++++++++-- 3 files changed, 408 insertions(+), 24 deletions(-) create mode 100644 pathwaysutils/experimental/shared_pathways_service/scripts/jupyter_setup.sh create mode 100644 pathwaysutils/experimental/shared_pathways_service/scripts/vscode_setup.sh diff --git a/pathwaysutils/experimental/shared_pathways_service/scripts/jupyter_setup.sh b/pathwaysutils/experimental/shared_pathways_service/scripts/jupyter_setup.sh new file mode 100644 index 0000000..5ce12ad --- /dev/null +++ b/pathwaysutils/experimental/shared_pathways_service/scripts/jupyter_setup.sh @@ -0,0 +1,11 @@ +#!/bin/bash +# jupyter_setup.sh + +# Update and install dependencies +sudo apt update > /dev/null +pip3 install jupyterlab > /dev/null + +# Launch Jupyter Lab +# We use {PORT} as a placeholder to be replaced by Python +echo "Starting Jupyter Lab on port {PORT}..." +jupyter lab --allow-root --ip=127.0.0.1 --port={PORT} diff --git a/pathwaysutils/experimental/shared_pathways_service/scripts/vscode_setup.sh b/pathwaysutils/experimental/shared_pathways_service/scripts/vscode_setup.sh new file mode 100644 index 0000000..35ec1a7 --- /dev/null +++ b/pathwaysutils/experimental/shared_pathways_service/scripts/vscode_setup.sh @@ -0,0 +1,71 @@ +#!/bin/bash +# Variables +PORT="{PORT}" +BUCKET_NAME="{BUCKET}" +POD_PATTERN="{WORKLOAD}" + +# Paths +WORKSPACE_DIR="$HOME/vscode_workspace" +mkdir -p "$WORKSPACE_DIR" + +# 1. Install Dependencies +echo "Step 1: Installing dependencies..." +sudo apt-get update -qq >/dev/null +sudo apt-get install -y -qq inotify-tools curl >/dev/null + +if ! command -v code-server &> /dev/null; then + echo "Installing code-server..." + curl -fsSL https://code-server.dev/install.sh | sh >/dev/null 2>&1 +fi + +if ! command -v gsutil &> /dev/null; then + echo "Installing Google Cloud SDK..." + sudo apt-get install -y -qq apt-transport-https ca-certificates gnupg >/dev/null + echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | sudo tee -a /etc/apt/sources.list.d/google-cloud-sdk.list >/dev/null + curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - >/dev/null 2>&1 + sudo apt-get update -qq >/dev/null && sudo apt-get install -y -qq google-cloud-cli >/dev/null +fi + +# 2. Bucket Logic for Syncing Workspace +if [ -n "$BUCKET_NAME" ]; then + GCS_PATH="gs://$BUCKET_NAME/vscode_workspaces/$POD_PATTERN" + echo "Step 2: Syncing with: $GCS_PATH" + + # --- Initial Recovery or Init --- + echo " -> Checking GCS for existing files..." + if gsutil ls "$GCS_PATH" >/dev/null 2>&1; then + echo " ✅ Found files! Downloading..." + gsutil -m rsync -r "$GCS_PATH" "$WORKSPACE_DIR" + else + echo " ⚠️ No existing workspace found. initializing..." + # Create a placeholder and upload immediately so directory appears in GCS + touch "$WORKSPACE_DIR/.workspace_init" + gsutil cp "$WORKSPACE_DIR/.workspace_init" "$GCS_PATH/.workspace_init" + echo " ✅ Initialized GCS directory with placeholder." + fi + # --- Background Watcher Process --- + ( + echo "Watcher started at $(date)" > $HOME/sync.log + while true; do + inotifywait -r -e modify,create,delete,move "$WORKSPACE_DIR" >> $HOME/sync.log 2>&1 + echo "[AutoSync] Change detected. Waiting 5s..." >> $HOME/sync.log + sleep 5 + + echo "[AutoSync] Syncing (Mirroring)..." >> $HOME/sync.log + # Explicitly use -r (recursive) and -d (delete) + gsutil -m rsync -r -d "$WORKSPACE_DIR" "$GCS_PATH" >> $HOME/sync.log 2>&1 + + echo "[AutoSync] Done at $(date)" >> $HOME/sync.log + done + ) & +else + echo "Step 2: No bucket provided. Skipping sync." +fi + +# 3. Launch VS Code +# Kill any existing remote process on this port +kill -9 $(lsof -t -i:$PORT) 2>/dev/null +echo "Step 3: Launching VS Code..." +echo " -> Local URL: http://127.0.0.1:$PORT" + +code-server --bind-addr 127.0.0.1:$PORT --auth none "$WORKSPACE_DIR" diff --git a/pathwaysutils/experimental/shared_pathways_service/start_vscode_on_cpu_np.py b/pathwaysutils/experimental/shared_pathways_service/start_vscode_on_cpu_np.py index 872261b..8d64800 100644 --- a/pathwaysutils/experimental/shared_pathways_service/start_vscode_on_cpu_np.py +++ b/pathwaysutils/experimental/shared_pathways_service/start_vscode_on_cpu_np.py @@ -1,13 +1,22 @@ -"""Deploys VSCode on a GKE CPU node pool and sets up port forwarding.""" +"""Deploys VSCode on a GKE CPU node pool, or connects to an existing running Pathways pod.""" +import multiprocessing import os import random +import re +import select +import signal +import socket import string +import threading import time from absl import app from absl import flags from absl import logging +from kubernetes import client +from kubernetes import config +from kubernetes import stream from pathwaysutils.experimental.shared_pathways_service import gke_utils FLAGS = flags.FLAGS @@ -37,12 +46,209 @@ False, "If true, only print the generated YAML without deploying.", ) +_WORKLOAD = flags.DEFINE_string( + "workload", + None, + "Regex pattern to match the running Pod name. If provided, connects to an " + "existing running pod instead of deploying a fresh one.", +) +_MODE = flags.DEFINE_enum( + "mode", + "vscode", + ["vscode", "jupyter"], + "IDE mode to launch (vscode or jupyter).", +) +_PORT = flags.DEFINE_integer( + "port", + 8888, + "Port to forward for the IDE when connecting to an existing workload.", +) +_BUCKET = flags.DEFINE_string( + "bucket", + "", + "GCS Bucket for syncing state (VS Code only).", +) +_CHECK_ACTIVE_SESSION = flags.DEFINE_boolean( + "check_active_session", + False, + "Check if session exists. If running, skip setup and just tunnel.", +) +_NON_PATHWAYS = flags.DEFINE_boolean( + "non_pathways", + False, + "If true, use workload name directly as search pattern instead of " + "pathways-head pattern.", +) _TEMPLATE_FILE = os.path.join( os.path.dirname(__file__), "yamls/code-server.yaml" ) +def _load_k8s_config() -> None: + try: + config.load_kube_config() + except Exception: # pylint: disable=broad-exception-caught + config.load_incluster_config() + + +def _find_pod(pattern: str) -> str: + _load_k8s_config() + v1 = client.CoreV1Api() + pods = v1.list_namespaced_pod(_NAMESPACE.value) + regex = re.compile(pattern) + + for pod in pods.items: + if regex.search(pod.metadata.name) and pod.status.phase == "Running": + return pod.metadata.name + + raise RuntimeError(f"No running pod found matching pattern: {pattern}") + + +def _is_port_active(pod_name: str, port: int, container_name: str) -> bool: + """Executes a small python snippet inside the pod to check if the port is bound.""" + _load_k8s_config() + v1 = client.CoreV1Api() + check_cmd = [ + "python3", + "-c", + "import socket; s = socket.socket(socket.AF_INET, socket.SOCK_STREAM); " + f"res = s.connect_ex(('127.0.0.1', {port})); " + "print('OPEN' if res == 0 else 'CLOSED'); s.close()", + ] + try: + resp = stream.stream( + v1.connect_get_namespaced_pod_exec, + pod_name, + _NAMESPACE.value, + command=check_cmd, + container=container_name, + stderr=True, + stdin=False, + stdout=True, + tty=False, + _preload_content=True, + ) + return "OPEN" in resp + except Exception as e: # pylint: disable=broad-exception-caught + logging.warning( + "Could not check port status inside pod %s: %s. Assuming closed.", + pod_name, + e, + ) + return False + + +def _load_script( + filename: str, port: int, bucket: str, workload: str +) -> list[str]: + """Reads a bash script from disk and injects variables.""" + try: + with open(filename, "r") as f: + script_content = f.read() + + script_content = script_content.replace("{PORT}", str(port)) + script_content = script_content.replace("{BUCKET}", bucket) + script_content = script_content.replace("{WORKLOAD}", workload) + + return ["/bin/bash", "-c", script_content] + except FileNotFoundError as e: + raise ValueError(f"Could not find script file '{filename}'") from e + + +class PortForwarderServer: + """Custom port forwarder server using Kubernetes API stream.""" + + def __init__( + self, + pod_name: str, + local_port: int, + remote_port: int, + namespace: str = "default", + ): + self.pod_name = pod_name + self.local_port = local_port + self.remote_port = remote_port + self.namespace = namespace + self.running = True + + def run(self): + _load_k8s_config() + v1 = client.CoreV1Api() + server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + try: + server_socket.bind(("127.0.0.1", self.local_port)) + server_socket.listen(5) + logging.info( + "[Tunnel] Forwarding 127.0.0.1:%d -> %s:%d", + self.local_port, + self.pod_name, + self.remote_port, + ) + except OSError as e: + logging.error( + "[Tunnel Error] Cannot bind port %d: %s", self.local_port, e + ) + return + + while self.running: + try: + local_conn, _ = server_socket.accept() + t = threading.Thread(target=self._handle_client, args=(local_conn, v1)) + t.daemon = True + t.start() + except KeyboardInterrupt: + break + except Exception: # pylint: disable=broad-exception-caught + pass + + def _handle_client(self, local_conn, v1): + k8s_socket = None + try: + pf_stream = stream.portforward( + v1.connect_get_namespaced_pod_portforward, + self.pod_name, + self.namespace, + ports=str(self.remote_port), + ) + k8s_socket = pf_stream.socket(self.remote_port) + self._bridge_sockets(local_conn, k8s_socket) + except Exception: # pylint: disable=broad-exception-caught + pass + finally: + local_conn.close() + if k8s_socket: + k8s_socket.close() + + def _bridge_sockets(self, sock1, sock2): + sockets = [sock1, sock2] + buffer_size = 32768 + while True: + r, _, _ = select.select(sockets, [], []) + if sock1 in r: + data = sock1.recv(buffer_size) + if not data: + break + sock2.sendall(data) + if sock2 in r: + data = sock2.recv(buffer_size) + if not data: + break + sock1.sendall(data) + + +def _run_tunnel_process( + pod_name: str, local_port: int, remote_port: int, namespace: str +) -> None: + signal.signal(signal.SIGINT, signal.SIG_IGN) + server = PortForwarderServer( + pod_name, local_port, remote_port, namespace=namespace + ) + server.run() + + def _prepare_deployment_yaml(service_name: str, remote_port: int) -> str: """Prepares the deployment YAML for VS Code.""" context = { @@ -128,34 +334,130 @@ def main(argv): if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") - service_name = "{}".format( - _NAME.value - + f"-{os.environ.get('USER', 'user')}-" - + "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) - ) - logging.info("Service name: %s", service_name) + if _WORKLOAD.value: + logging.info("Fetching cluster credentials...") + gke_utils.fetch_cluster_credentials( + cluster_name=_CLUSTER.value, + project_id=_PROJECT.value, + location=_REGION.value, + ) - remote_port = 8080 + search_pattern = _WORKLOAD.value + if _WORKLOAD.value == os.environ.get("USER", "user"): + search_pattern = f"{_WORKLOAD.value}-pathways-head" + if _NON_PATHWAYS.value: + search_pattern = _WORKLOAD.value - deployment_yaml = _prepare_deployment_yaml(service_name, remote_port) + pod_name = _find_pod(search_pattern) + logging.info("Target Pod: %s", pod_name) - if _DRY_RUN.value: - logging.info( - "Dry run: Would deploy the following YAML:\n%s", deployment_yaml + tunnel_proc = multiprocessing.Process( + target=_run_tunnel_process, + args=(pod_name, _PORT.value, _PORT.value, _NAMESPACE.value), ) - return + tunnel_proc.start() + time.sleep(1) - logging.info("Fetching cluster credentials...") - gke_utils.fetch_cluster_credentials( - cluster_name=_CLUSTER.value, - project_id=_PROJECT.value, - location=_REGION.value, - ) - try: - _deploy_vscode(service_name, deployment_yaml) - _start_port_forwarding(service_name, remote_port) - finally: - _cleanup_gke_resources(service_name, _NAMESPACE.value) + skip_setup = False + if _CHECK_ACTIVE_SESSION.value: + logging.info( + "Checking for existing %s session on port %d...", + _MODE.value, + _PORT.value, + ) + if _is_port_active(pod_name, _PORT.value, "jax-tpu"): + logging.info("Active session detected! Skipping setup script.") + skip_setup = True + else: + logging.info("No active session found. Proceeding with installation.") + + try: + if skip_setup: + logging.info( + "Session ready (Port Forwarding Only). Access at" + " http://127.0.0.1:%d", + _PORT.value, + ) + logging.info("Press Ctrl+C to stop.") + while True: + time.sleep(1) + else: + script_dir = os.path.join(os.path.dirname(__file__), "scripts") + if _MODE.value == "jupyter": + script_file = os.path.join(script_dir, "jupyter_setup.sh") + else: + script_file = os.path.join(script_dir, "vscode_setup.sh") + + cmd = _load_script( + script_file, + _PORT.value, + _BUCKET.value, + _WORKLOAD.value, + ) + + _load_k8s_config() + v1 = client.CoreV1Api() + logging.info( + "Session ready. Access at http://127.0.0.1:%d", _PORT.value + ) + logging.info("Press Ctrl+C to stop.") + resp = stream.stream( + v1.connect_get_namespaced_pod_exec, + pod_name, + _NAMESPACE.value, + command=cmd, + container="jax-tpu", + stderr=True, + stdin=False, + stdout=True, + tty=False, + _preload_content=False, + ) + while resp.is_open(): + resp.update(timeout=1) + if resp.peek_stdout(): + print(resp.read_stdout(), end="") + if resp.peek_stderr(): + print(resp.read_stderr(), end="") + except KeyboardInterrupt: + logging.info("Stopping session...") + except Exception as e: # pylint: disable=broad-exception-caught + logging.error("Execution Error: %s", e) + finally: + if tunnel_proc.is_alive(): + tunnel_proc.terminate() + tunnel_proc.join() + logging.info("Tunnel closed.") + + else: + service_name = "{}".format( + _NAME.value + + f"-{os.environ.get('USER', 'user')}-" + + "".join(random.choices(string.ascii_lowercase + string.digits, k=4)) + ) + logging.info("Service name: %s", service_name) + + remote_port = 8080 + + deployment_yaml = _prepare_deployment_yaml(service_name, remote_port) + + if _DRY_RUN.value: + logging.info( + "Dry run: Would deploy the following YAML:\n%s", deployment_yaml + ) + return + + logging.info("Fetching cluster credentials...") + gke_utils.fetch_cluster_credentials( + cluster_name=_CLUSTER.value, + project_id=_PROJECT.value, + location=_REGION.value, + ) + try: + _deploy_vscode(service_name, deployment_yaml) + _start_port_forwarding(service_name, remote_port) + finally: + _cleanup_gke_resources(service_name, _NAMESPACE.value) if __name__ == "__main__":