From 25e442395fa317cd7b3392084ba1b78d5a8a07f7 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Wed, 15 Apr 2026 20:28:43 -0700 Subject: [PATCH 1/3] feat(server,sandbox): move SSH connect and exec onto supervisor session relay Introduce a persistent supervisor-to-gateway session (ConnectSupervisor bidirectional gRPC RPC) and migrate /connect/ssh and ExecSandbox onto relay channels coordinated through it. Architecture: - gRPC control plane: carries session lifecycle (hello, heartbeat) and relay lifecycle (RelayOpen, RelayOpenResult, RelayClose) - HTTP data plane: for each relay, the supervisor opens a reverse HTTP CONNECT to /relay/{channel_id} on the gateway; the gateway bridges the client stream with the supervisor stream - The supervisor is a dumb byte bridge with no SSH/NSSH1 awareness; the gateway sends the NSSH1 preface through the relay Key changes: - Add ConnectSupervisor RPC and session/relay proto messages - Add gateway session registry (SupervisorSessionRegistry) with pending-relay map for channel correlation - Add /relay/{channel_id} HTTP CONNECT endpoint - Rewire /connect/ssh: session lookup + RelayOpen instead of direct TCP dial to sandbox:2222 - Rewire ExecSandbox: relay-based proxy instead of direct sandbox dial - Add supervisor session client with reconnect and relay bridge - Remove ResolveSandboxEndpoint from proto, gateway, and K8s driver Closes OS-86 --- Cargo.lock | 4 + .../tests/ensure_providers_integration.rs | 24 +- .../openshell-cli/tests/mtls_integration.rs | 10 + .../tests/provider_commands_integration.rs | 24 +- .../sandbox_create_lifecycle_integration.rs | 25 +- .../sandbox_name_fallback_integration.rs | 22 +- .../openshell-driver-kubernetes/src/driver.rs | 69 +-- .../openshell-driver-kubernetes/src/grpc.rs | 18 +- crates/openshell-sandbox/Cargo.toml | 7 + crates/openshell-sandbox/src/grpc_client.rs | 5 + crates/openshell-sandbox/src/lib.rs | 16 + .../src/supervisor_session.rs | 351 ++++++++++++++ 
crates/openshell-server/src/compute/mod.rs | 98 +--- crates/openshell-server/src/grpc/mod.rs | 25 +- crates/openshell-server/src/grpc/sandbox.rs | 441 ++++++------------ crates/openshell-server/src/http.rs | 1 + crates/openshell-server/src/lib.rs | 6 + crates/openshell-server/src/relay.rs | 67 +++ crates/openshell-server/src/ssh_tunnel.rs | 233 ++++----- .../src/supervisor_session.rs | 440 +++++++++++++++++ .../tests/auth_endpoint_integration.rs | 10 + .../tests/edge_tunnel_auth.rs | 23 +- .../tests/multiplex_integration.rs | 23 +- .../tests/multiplex_tls_integration.rs | 23 +- .../tests/ws_tunnel_integration.rs | 23 +- proto/compute_driver.proto | 25 - proto/openshell.proto | 89 ++++ 27 files changed, 1408 insertions(+), 694 deletions(-) create mode 100644 crates/openshell-sandbox/src/supervisor_session.rs create mode 100644 crates/openshell-server/src/relay.rs create mode 100644 crates/openshell-server/src/supervisor_session.rs diff --git a/Cargo.lock b/Cargo.lock index e4057f75c..31144fc41 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3162,6 +3162,10 @@ dependencies = [ "futures", "hex", "hmac", + "http", + "http-body-util", + "hyper", + "hyper-util", "ipnet", "landlock", "libc", diff --git a/crates/openshell-cli/tests/ensure_providers_integration.rs b/crates/openshell-cli/tests/ensure_providers_integration.rs index 2cd362023..d5e813931 100644 --- a/crates/openshell-cli/tests/ensure_providers_integration.rs +++ b/crates/openshell-cli/tests/ensure_providers_integration.rs @@ -11,13 +11,14 @@ use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - 
GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, Provider, ProviderResponse, - RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, - ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + Provider, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -298,6 +299,8 @@ impl OpenShell for TestOpenShell { tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; async fn watch_sandbox( &self, @@ -423,6 +426,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // ── TLS helpers ────────────────────────────────────────────────────── diff --git a/crates/openshell-cli/tests/mtls_integration.rs b/crates/openshell-cli/tests/mtls_integration.rs index 5d04239bf..c98b7eae4 100644 --- a/crates/openshell-cli/tests/mtls_integration.rs +++ b/crates/openshell-cli/tests/mtls_integration.rs @@ 
-200,6 +200,9 @@ impl OpenShell for TestOpenShell { >; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; async fn watch_sandbox( &self, @@ -325,6 +328,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn build_ca() -> (Certificate, KeyPair) { diff --git a/crates/openshell-cli/tests/provider_commands_integration.rs b/crates/openshell-cli/tests/provider_commands_integration.rs index c5476afee..1d1323371 100644 --- a/crates/openshell-cli/tests/provider_commands_integration.rs +++ b/crates/openshell-cli/tests/provider_commands_integration.rs @@ -7,13 +7,14 @@ use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, Provider, ProviderResponse, - RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, - ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + 
GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + Provider, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -252,6 +253,8 @@ impl OpenShell for TestOpenShell { tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; async fn watch_sandbox( &self, @@ -377,6 +380,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn install_rustls_provider() { diff --git a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs index d5d39f082..e4c658b7b 100644 --- a/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs +++ b/crates/openshell-cli/tests/sandbox_create_lifecycle_integration.rs @@ -8,14 +8,14 @@ use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, 
HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, PlatformEvent, ProviderResponse, - RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, SandboxPhase, SandboxResponse, - SandboxStreamEvent, ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, - sandbox_stream_event, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + PlatformEvent, ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, Sandbox, + SandboxPhase, SandboxResponse, SandboxStreamEvent, ServiceStatus, SupervisorMessage, + UpdateProviderRequest, WatchSandboxRequest, sandbox_stream_event, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -242,6 +242,8 @@ impl OpenShell for TestOpenShell { tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; async fn watch_sandbox( &self, @@ -403,6 +405,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } fn install_rustls_provider() { diff --git a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs index fbadec4c3..7824d141a 100644 --- a/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs +++ 
b/crates/openshell-cli/tests/sandbox_name_fallback_integration.rs @@ -8,12 +8,13 @@ use openshell_core::proto::open_shell_server::{OpenShell, OpenShellServer}; use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, Sandbox, SandboxResponse, - SandboxStreamEvent, ServiceStatus, UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, Sandbox, SandboxResponse, SandboxStreamEvent, ServiceStatus, + SupervisorMessage, UpdateProviderRequest, WatchSandboxRequest, }; use rcgen::{ BasicConstraints, Certificate, CertificateParams, ExtendedKeyUsagePurpose, IsCa, KeyPair, @@ -210,6 +211,8 @@ impl OpenShell for TestOpenShell { tokio_stream::wrappers::ReceiverStream>; type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream>; + type ConnectSupervisorStream = + tokio_stream::wrappers::ReceiverStream>; async fn watch_sandbox( &self, @@ -335,6 +338,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + 
_request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // ── helpers ─────────────────────────────────────────────────────────── diff --git a/crates/openshell-driver-kubernetes/src/driver.rs b/crates/openshell-driver-kubernetes/src/driver.rs index 440703af5..3e0240d0f 100644 --- a/crates/openshell-driver-kubernetes/src/driver.rs +++ b/crates/openshell-driver-kubernetes/src/driver.rs @@ -5,7 +5,7 @@ use crate::config::KubernetesComputeConfig; use futures::{Stream, StreamExt, TryStreamExt}; -use k8s_openapi::api::core::v1::{Event as KubeEventObj, Node, Pod}; +use k8s_openapi::api::core::v1::{Event as KubeEventObj, Node}; use kube::api::{Api, ApiResource, DeleteParams, ListParams, PostParams}; use kube::core::gvk::GroupVersionKind; use kube::core::{DynamicObject, ObjectMeta}; @@ -15,12 +15,10 @@ use openshell_core::proto::compute::v1::{ DriverCondition as SandboxCondition, DriverPlatformEvent as PlatformEvent, DriverSandbox as Sandbox, DriverSandboxSpec as SandboxSpec, DriverSandboxStatus as SandboxStatus, DriverSandboxTemplate as SandboxTemplate, - GetCapabilitiesResponse, ResolveSandboxEndpointResponse, SandboxEndpoint, - WatchSandboxesDeletedEvent, WatchSandboxesEvent, WatchSandboxesPlatformEvent, - WatchSandboxesSandboxEvent, sandbox_endpoint, watch_sandboxes_event, + GetCapabilitiesResponse, WatchSandboxesDeletedEvent, WatchSandboxesEvent, + WatchSandboxesPlatformEvent, WatchSandboxesSandboxEvent, watch_sandboxes_event, }; use std::collections::BTreeMap; -use std::net::IpAddr; use std::pin::Pin; use std::time::Duration; use tokio::sync::mpsc; @@ -271,21 +269,6 @@ impl KubernetesComputeDriver { &self.config.ssh_handshake_secret } - async fn agent_pod_ip(&self, pod_name: &str) -> Result, KubeError> { - let api: Api = Api::namespaced(self.client.clone(), &self.config.namespace); - match api.get(pod_name).await { - Ok(pod) => { - let ip = pod - .status - .and_then(|status| status.pod_ip) - .and_then(|ip| 
ip.parse().ok()); - Ok(ip) - } - Err(KubeError::Api(err)) if err.code == 404 => Ok(None), - Err(err) => Err(err), - } - } - pub async fn create_sandbox(&self, sandbox: &Sandbox) -> Result<(), KubernetesDriverError> { let name = sandbox.name.as_str(); info!( @@ -407,52 +390,6 @@ impl KubernetesComputeDriver { } } - pub async fn resolve_sandbox_endpoint( - &self, - sandbox: &Sandbox, - ) -> Result { - if let Some(status) = sandbox.status.as_ref() - && !status.instance_id.is_empty() - { - match self.agent_pod_ip(&status.instance_id).await { - Ok(Some(ip)) => { - return Ok(ResolveSandboxEndpointResponse { - endpoint: Some(SandboxEndpoint { - target: Some(sandbox_endpoint::Target::Ip(ip.to_string())), - port: u32::from(self.config.ssh_port), - }), - }); - } - Ok(None) => { - return Err(KubernetesDriverError::Precondition( - "sandbox agent pod IP is not available".to_string(), - )); - } - Err(err) => { - return Err(KubernetesDriverError::Message(format!( - "failed to resolve agent pod IP: {err}" - ))); - } - } - } - - if sandbox.name.is_empty() { - return Err(KubernetesDriverError::Precondition( - "sandbox has no name".to_string(), - )); - } - - Ok(ResolveSandboxEndpointResponse { - endpoint: Some(SandboxEndpoint { - target: Some(sandbox_endpoint::Target::Host(format!( - "{}.{}.svc.cluster.local", - sandbox.name, self.config.namespace - ))), - port: u32::from(self.config.ssh_port), - }), - }) - } - pub async fn watch_sandboxes(&self) -> Result { let namespace = self.config.namespace.clone(); let sandbox_api = self.api(); diff --git a/crates/openshell-driver-kubernetes/src/grpc.rs b/crates/openshell-driver-kubernetes/src/grpc.rs index 2c5a94467..75e131d41 100644 --- a/crates/openshell-driver-kubernetes/src/grpc.rs +++ b/crates/openshell-driver-kubernetes/src/grpc.rs @@ -5,8 +5,7 @@ use futures::{Stream, StreamExt}; use openshell_core::proto::compute::v1::{ CreateSandboxRequest, CreateSandboxResponse, DeleteSandboxRequest, DeleteSandboxResponse, GetCapabilitiesRequest, 
GetCapabilitiesResponse, GetSandboxRequest, GetSandboxResponse, - ListSandboxesRequest, ListSandboxesResponse, ResolveSandboxEndpointRequest, - ResolveSandboxEndpointResponse, StopSandboxRequest, StopSandboxResponse, + ListSandboxesRequest, ListSandboxesResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateRequest, ValidateSandboxCreateResponse, WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_server::ComputeDriver, }; @@ -128,21 +127,6 @@ impl ComputeDriver for ComputeDriverService { Ok(Response::new(DeleteSandboxResponse { deleted })) } - async fn resolve_sandbox_endpoint( - &self, - request: Request, - ) -> Result, Status> { - let sandbox = request - .into_inner() - .sandbox - .ok_or_else(|| Status::invalid_argument("sandbox is required"))?; - self.driver - .resolve_sandbox_endpoint(&sandbox) - .await - .map(Response::new) - .map_err(status_from_driver_error) - } - type WatchSandboxesStream = Pin> + Send + 'static>>; diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 541784ee6..b21b1948f 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -51,8 +51,15 @@ rcgen = { workspace = true } webpki-roots = { workspace = true } # HTTP +hyper = { workspace = true } +hyper-util = { workspace = true } +http = "1" +http-body-util = "0.1" bytes = { workspace = true } +# UUID +uuid = { workspace = true } + # Encoding base64 = { workspace = true } diff --git a/crates/openshell-sandbox/src/grpc_client.rs b/crates/openshell-sandbox/src/grpc_client.rs index 5503637ee..09e7b607d 100644 --- a/crates/openshell-sandbox/src/grpc_client.rs +++ b/crates/openshell-sandbox/src/grpc_client.rs @@ -74,6 +74,11 @@ async fn connect_channel(endpoint: &str) -> Result { .wrap_err("failed to connect to OpenShell server") } +/// Create a channel to the OpenShell server (public for use by supervisor_session). 
+pub async fn connect_channel_pub(endpoint: &str) -> Result { + connect_channel(endpoint).await +} + /// Connect to the OpenShell server (mTLS or plaintext based on endpoint scheme). async fn connect(endpoint: &str) -> Result> { let channel = connect_channel(endpoint).await?; diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index b81dd4a6c..76da6bb3f 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -21,6 +21,7 @@ pub mod proxy; mod sandbox; mod secrets; mod ssh; +mod supervisor_session; use miette::{IntoDiagnostic, Result}; #[cfg(target_os = "linux")] @@ -676,6 +677,21 @@ pub async fn run_sandbox( } } + // Spawn the persistent supervisor session if we have a gateway endpoint + // and sandbox identity. The session provides relay channels for SSH + // connect and ExecSandbox through the gateway. + if let (Some(endpoint), Some(id)) = (openshell_endpoint.as_ref(), sandbox_id.as_ref()) { + // The SSH listen address was consumed above, so we use the configured + // SSH port (default 2222) for loopback connections from the relay. + let ssh_port = std::env::var("OPENSHELL_SSH_PORT") + .ok() + .and_then(|p| p.parse::().ok()) + .unwrap_or(2222); + + supervisor_session::spawn(endpoint.clone(), id.clone(), ssh_port); + info!("supervisor session task spawned"); + } + #[cfg(target_os = "linux")] let mut handle = ProcessHandle::spawn( program, diff --git a/crates/openshell-sandbox/src/supervisor_session.rs b/crates/openshell-sandbox/src/supervisor_session.rs new file mode 100644 index 000000000..2b571df08 --- /dev/null +++ b/crates/openshell-sandbox/src/supervisor_session.rs @@ -0,0 +1,351 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! Persistent supervisor-to-gateway session. +//! +//! Maintains a long-lived `ConnectSupervisor` bidirectional gRPC stream to the +//! gateway. 
When the gateway sends `RelayOpen`, the supervisor opens a reverse +//! HTTP CONNECT tunnel back to the gateway and bridges it to the local SSH +//! daemon. The supervisor is a dumb byte bridge — it has no protocol awareness +//! of the SSH or NSSH1 bytes flowing through the tunnel. + +use std::time::Duration; + +use openshell_core::proto::open_shell_client::OpenShellClient; +use openshell_core::proto::{ + GatewayMessage, SupervisorHeartbeat, SupervisorHello, SupervisorMessage, gateway_message, + supervisor_message, +}; +use tokio::sync::mpsc; +use tonic::transport::Channel; +use tracing::{info, warn}; + +use crate::grpc_client; + +const INITIAL_BACKOFF: Duration = Duration::from_secs(1); +const MAX_BACKOFF: Duration = Duration::from_secs(30); + +/// Spawn the supervisor session task. +/// +/// The task runs for the lifetime of the sandbox process, reconnecting with +/// exponential backoff on failures. +pub fn spawn( + endpoint: String, + sandbox_id: String, + ssh_listen_port: u16, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(run_session_loop(endpoint, sandbox_id, ssh_listen_port)) +} + +async fn run_session_loop(endpoint: String, sandbox_id: String, ssh_listen_port: u16) { + let mut backoff = INITIAL_BACKOFF; + let mut attempt: u64 = 0; + + loop { + attempt += 1; + + match run_single_session(&endpoint, &sandbox_id, ssh_listen_port).await { + Ok(()) => { + info!(sandbox_id = %sandbox_id, "supervisor session ended cleanly"); + break; + } + Err(e) => { + warn!( + sandbox_id = %sandbox_id, + attempt = attempt, + backoff_ms = backoff.as_millis() as u64, + error = %e, + "supervisor session failed, reconnecting" + ); + tokio::time::sleep(backoff).await; + backoff = (backoff * 2).min(MAX_BACKOFF); + } + } + } +} + +async fn run_single_session( + endpoint: &str, + sandbox_id: &str, + ssh_listen_port: u16, +) -> Result<(), Box> { + // Connect to the gateway. 
+ let channel = grpc_client::connect_channel_pub(endpoint) + .await + .map_err(|e| format!("connect failed: {e}"))?; + let mut client = OpenShellClient::new(channel.clone()); + + // Create the outbound message stream. + let (tx, rx) = mpsc::channel::(64); + let outbound = tokio_stream::wrappers::ReceiverStream::new(rx); + + // Send hello as the first message. + let instance_id = uuid::Uuid::new_v4().to_string(); + tx.send(SupervisorMessage { + payload: Some(supervisor_message::Payload::Hello(SupervisorHello { + sandbox_id: sandbox_id.to_string(), + instance_id: instance_id.clone(), + })), + }) + .await + .map_err(|_| "failed to queue hello")?; + + // Open the bidirectional stream. + let response = client + .connect_supervisor(outbound) + .await + .map_err(|e| format!("connect_supervisor RPC failed: {e}"))?; + let mut inbound = response.into_inner(); + + // Wait for SessionAccepted. + let accepted = match inbound.message().await? { + Some(msg) => match msg.payload { + Some(gateway_message::Payload::SessionAccepted(a)) => a, + Some(gateway_message::Payload::SessionRejected(r)) => { + return Err(format!("session rejected: {}", r.reason).into()); + } + _ => return Err("expected SessionAccepted or SessionRejected".into()), + }, + None => return Err("stream closed before session accepted".into()), + }; + + let heartbeat_secs = accepted.heartbeat_interval_secs.max(5); + info!( + sandbox_id = %sandbox_id, + session_id = %accepted.session_id, + instance_id = %instance_id, + heartbeat_secs = heartbeat_secs, + "supervisor session established" + ); + + // Main loop: receive gateway messages + send heartbeats. + let mut heartbeat_interval = + tokio::time::interval(Duration::from_secs(u64::from(heartbeat_secs))); + heartbeat_interval.tick().await; // skip immediate tick + + loop { + tokio::select! 
{ + msg = inbound.message() => { + match msg { + Ok(Some(msg)) => { + handle_gateway_message( + &msg, + sandbox_id, + &endpoint, + ssh_listen_port, + &channel, + ).await; + } + Ok(None) => { + info!(sandbox_id = %sandbox_id, "supervisor session: gateway closed stream"); + return Ok(()); + } + Err(e) => { + return Err(format!("stream error: {e}").into()); + } + } + } + _ = heartbeat_interval.tick() => { + let hb = SupervisorMessage { + payload: Some(supervisor_message::Payload::Heartbeat( + SupervisorHeartbeat {}, + )), + }; + if tx.send(hb).await.is_err() { + return Err("outbound channel closed".into()); + } + } + } + } +} + +async fn handle_gateway_message( + msg: &GatewayMessage, + sandbox_id: &str, + endpoint: &str, + ssh_listen_port: u16, + _channel: &Channel, +) { + match &msg.payload { + Some(gateway_message::Payload::Heartbeat(_)) => { + // Gateway heartbeat — nothing to do. + } + Some(gateway_message::Payload::RelayOpen(open)) => { + let channel_id = open.channel_id.clone(); + let endpoint = endpoint.to_string(); + let sandbox_id = sandbox_id.to_string(); + + info!( + sandbox_id = %sandbox_id, + channel_id = %channel_id, + "supervisor session: relay open request, spawning bridge" + ); + + tokio::spawn(async move { + if let Err(e) = handle_relay_open(&channel_id, &endpoint, ssh_listen_port).await { + warn!( + sandbox_id = %sandbox_id, + channel_id = %channel_id, + error = %e, + "supervisor session: relay bridge failed" + ); + } + }); + } + Some(gateway_message::Payload::RelayClose(close)) => { + info!( + sandbox_id = %sandbox_id, + channel_id = %close.channel_id, + reason = %close.reason, + "supervisor session: relay close from gateway" + ); + } + _ => { + warn!(sandbox_id = %sandbox_id, "supervisor session: unexpected gateway message"); + } + } +} + +/// Handle a RelayOpen by opening a reverse HTTP CONNECT to the gateway and +/// bridging it to the local SSH daemon. 
+async fn handle_relay_open( + channel_id: &str, + endpoint: &str, + ssh_listen_port: u16, +) -> Result<(), Box> { + // Build the relay URL from the gateway endpoint. + // The endpoint is like "https://gateway:8080" or "http://gateway:8080". + let relay_url = format!("{endpoint}/relay/{channel_id}"); + + // Open a reverse HTTP CONNECT to the gateway's relay endpoint. + let mut relay_stream = open_reverse_connect(&relay_url).await?; + + // Connect to the local SSH daemon on loopback. + let mut ssh_conn = tokio::net::TcpStream::connect(("127.0.0.1", ssh_listen_port)).await?; + + info!(channel_id = %channel_id, "relay bridge: connected to local SSH daemon, bridging"); + + // Bridge the relay stream to the local SSH connection. + // The gateway sends NSSH1 preface + SSH bytes through the relay. + // The SSH daemon receives them as if the gateway connected directly. + let _ = tokio::io::copy_bidirectional(&mut relay_stream, &mut ssh_conn).await; + + Ok(()) +} + +/// Open an HTTP CONNECT tunnel to the given URL and return the upgraded stream. +/// +/// This uses a raw hyper HTTP/1.1 client to send a CONNECT request and upgrade +/// the connection to a raw byte stream. +async fn open_reverse_connect( + url: &str, +) -> Result< + hyper_util::rt::TokioIo, + Box, +> { + let uri: http::Uri = url.parse()?; + let host = uri.host().ok_or("missing host")?; + let port = uri + .port_u16() + .unwrap_or(if uri.scheme_str() == Some("https") { + 443 + } else { + 80 + }); + let authority = format!("{host}:{port}"); + let path = uri.path().to_string(); + let use_tls = uri.scheme_str() == Some("https"); + + // Connect TCP. + let tcp = tokio::net::TcpStream::connect(&authority).await?; + tcp.set_nodelay(true)?; + + if use_tls { + // Build TLS connector using the same env-var certs as the gRPC client. 
+ let tls_stream = connect_tls(tcp, host).await?; + send_connect_request(tls_stream, &authority, &path).await + } else { + send_connect_request(tcp, &authority, &path).await + } +} + +async fn send_connect_request( + io: IO, + authority: &str, + path: &str, +) -> Result< + hyper_util::rt::TokioIo, + Box, +> +where + IO: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static, +{ + use http::Method; + + let (mut sender, conn) = + hyper::client::conn::http1::handshake(hyper_util::rt::TokioIo::new(io)).await?; + + // Spawn the connection driver. + tokio::spawn(async move { + if let Err(e) = conn.with_upgrades().await { + warn!(error = %e, "relay CONNECT connection driver error"); + } + }); + + let req = http::Request::builder() + .method(Method::CONNECT) + .uri(path) + .header(http::header::HOST, authority) + .body(http_body_util::Empty::::new())?; + + let resp = sender.send_request(req).await?; + + if resp.status() != http::StatusCode::OK + && resp.status() != http::StatusCode::SWITCHING_PROTOCOLS + { + return Err(format!("relay CONNECT failed: {}", resp.status()).into()); + } + + let upgraded = hyper::upgrade::on(resp).await?; + Ok(hyper_util::rt::TokioIo::new(upgraded)) +} + +/// Connect TLS using the same cert env vars as the gRPC client. 
+async fn connect_tls( + tcp: tokio::net::TcpStream, + host: &str, +) -> Result< + tokio_rustls::client::TlsStream, + Box, +> { + use rustls::pki_types::ServerName; + use std::sync::Arc; + + let ca_path = std::env::var("OPENSHELL_TLS_CA")?; + let cert_path = std::env::var("OPENSHELL_TLS_CERT")?; + let key_path = std::env::var("OPENSHELL_TLS_KEY")?; + + let ca_pem = std::fs::read(&ca_path)?; + let cert_pem = std::fs::read(&cert_path)?; + let key_pem = std::fs::read(&key_path)?; + + let mut root_store = rustls::RootCertStore::empty(); + for cert in rustls_pemfile::certs(&mut ca_pem.as_slice()) { + root_store.add(cert?)?; + } + + let certs: Vec<_> = + rustls_pemfile::certs(&mut cert_pem.as_slice()).collect::>()?; + let key = + rustls_pemfile::private_key(&mut key_pem.as_slice())?.ok_or("no private key found")?; + + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_client_auth_cert(certs, key)?; + + let connector = tokio_rustls::TlsConnector::from(Arc::new(config)); + let server_name = ServerName::try_from(host.to_string())?; + let tls_stream = connector.connect(server_name, tcp).await?; + + Ok(tls_stream) +} diff --git a/crates/openshell-server/src/compute/mod.rs b/crates/openshell-server/src/compute/mod.rs index 846782c65..181a5a819 100644 --- a/crates/openshell-server/src/compute/mod.rs +++ b/crates/openshell-server/src/compute/mod.rs @@ -13,9 +13,8 @@ use openshell_core::proto::compute::v1::{ CreateSandboxRequest, DeleteSandboxRequest, DriverCondition, DriverPlatformEvent, DriverResourceRequirements, DriverSandbox, DriverSandboxSpec, DriverSandboxStatus, DriverSandboxTemplate, GetCapabilitiesRequest, GetSandboxRequest, ListSandboxesRequest, - ResolveSandboxEndpointRequest, ResolveSandboxEndpointResponse, ValidateSandboxCreateRequest, - WatchSandboxesEvent, WatchSandboxesRequest, compute_driver_server::ComputeDriver, - sandbox_endpoint, watch_sandboxes_event, + ValidateSandboxCreateRequest, WatchSandboxesEvent, 
WatchSandboxesRequest, + compute_driver_server::ComputeDriver, watch_sandboxes_event, }; use openshell_core::proto::{ PlatformEvent, Sandbox, SandboxCondition, SandboxPhase, SandboxSpec, SandboxStatus, @@ -26,7 +25,6 @@ use openshell_driver_kubernetes::{ }; use prost::Message; use std::fmt; -use std::net::IpAddr; use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -55,12 +53,6 @@ pub enum ComputeError { Message(String), } -#[derive(Debug)] -pub enum ResolvedEndpoint { - Ip(IpAddr, u16), - Host(String, u16), -} - #[derive(Clone)] pub struct ComputeRuntime { driver: SharedComputeDriver, @@ -243,29 +235,6 @@ impl ComputeRuntime { Ok(deleted) } - pub async fn resolve_sandbox_endpoint( - &self, - sandbox: &Sandbox, - ) -> Result { - let driver_sandbox = driver_sandbox_from_public(sandbox); - self.driver - .resolve_sandbox_endpoint(Request::new(ResolveSandboxEndpointRequest { - sandbox: Some(driver_sandbox), - })) - .await - .map(|response| response.into_inner()) - .map_err(|status| match status.code() { - Code::FailedPrecondition => { - Status::failed_precondition(status.message().to_string()) - } - _ => Status::internal(status.message().to_string()), - }) - .and_then(|response| { - resolved_endpoint_from_response(&response) - .map_err(|err| Status::internal(err.to_string())) - }) - } - pub fn spawn_watchers(&self) { let runtime = Arc::new(self.clone()); let watch_runtime = runtime.clone(); @@ -813,30 +782,6 @@ fn decode_sandbox_record(record: &ObjectRecord) -> Result { Sandbox::decode(record.payload.as_slice()).map_err(|e| e.to_string()) } -fn resolved_endpoint_from_response( - response: &ResolveSandboxEndpointResponse, -) -> Result { - let endpoint = response - .endpoint - .as_ref() - .ok_or_else(|| ComputeError::Message("compute driver returned no endpoint".to_string()))?; - let port = u16::try_from(endpoint.port) - .map_err(|_| ComputeError::Message("compute driver returned invalid port".to_string()))?; - - match endpoint.target.as_ref() { - 
Some(sandbox_endpoint::Target::Ip(ip)) => ip - .parse() - .map(|ip| ResolvedEndpoint::Ip(ip, port)) - .map_err(|e| ComputeError::Message(format!("invalid endpoint IP: {e}"))), - Some(sandbox_endpoint::Target::Host(host)) => { - Ok(ResolvedEndpoint::Host(host.clone(), port)) - } - None => Err(ComputeError::Message( - "compute driver returned endpoint without target".to_string(), - )), - } -} - fn public_status_from_driver(status: &DriverSandboxStatus) -> SandboxStatus { SandboxStatus { sandbox_name: status.sandbox_name.clone(), @@ -929,8 +874,7 @@ mod tests { use futures::stream; use openshell_core::proto::compute::v1::{ CreateSandboxResponse, DeleteSandboxResponse, GetCapabilitiesResponse, GetSandboxRequest, - GetSandboxResponse, ResolveSandboxEndpointResponse, SandboxEndpoint, StopSandboxRequest, - StopSandboxResponse, ValidateSandboxCreateResponse, sandbox_endpoint, + GetSandboxResponse, StopSandboxRequest, StopSandboxResponse, ValidateSandboxCreateResponse, }; use std::sync::Arc; @@ -938,7 +882,6 @@ mod tests { struct TestDriver { listed_sandboxes: Vec, current_sandboxes: Vec, - resolve_precondition: Option, } #[tonic::async_trait] @@ -1031,24 +974,6 @@ mod tests { })) } - async fn resolve_sandbox_endpoint( - &self, - _request: Request, - ) -> Result, Status> { - if let Some(message) = &self.resolve_precondition { - return Err(Status::failed_precondition(message.clone())); - } - - Ok(tonic::Response::new(ResolveSandboxEndpointResponse { - endpoint: Some(SandboxEndpoint { - target: Some(sandbox_endpoint::Target::Host( - "sandbox.default.svc.cluster.local".to_string(), - )), - port: 2222, - }), - })) - } - async fn watch_sandboxes( &self, _request: Request, @@ -1322,23 +1247,6 @@ mod tests { ); } - #[tokio::test] - async fn resolve_sandbox_endpoint_preserves_precondition_errors() { - let runtime = test_runtime(Arc::new(TestDriver { - resolve_precondition: Some("sandbox agent pod IP is not available".to_string()), - ..Default::default() - })) - .await; - - let err 
= runtime - .resolve_sandbox_endpoint(&sandbox_record("sb-1", "sandbox-a", SandboxPhase::Ready)) - .await - .expect_err("endpoint resolution should preserve failed-precondition errors"); - - assert_eq!(err.code(), Code::FailedPrecondition); - assert_eq!(err.message(), "sandbox agent pod IP is not available"); - } - #[tokio::test] async fn reconcile_store_with_backend_applies_driver_snapshot() { let runtime = test_runtime(Arc::new(TestDriver { diff --git a/crates/openshell-server/src/grpc/mod.rs b/crates/openshell-server/src/grpc/mod.rs index af60897d1..8a5516c6b 100644 --- a/crates/openshell-server/src/grpc/mod.rs +++ b/crates/openshell-server/src/grpc/mod.rs @@ -14,10 +14,10 @@ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, EditDraftChunkRequest, EditDraftChunkResponse, ExecSandboxEvent, ExecSandboxRequest, - GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, GetDraftPolicyResponse, - GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, - GetSandboxConfigResponse, GetSandboxLogsRequest, GetSandboxLogsResponse, - GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, + GatewayMessage, GetDraftHistoryRequest, GetDraftHistoryResponse, GetDraftPolicyRequest, + GetDraftPolicyResponse, GetGatewayConfigRequest, GetGatewayConfigResponse, GetProviderRequest, + GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxLogsRequest, + GetSandboxLogsResponse, GetSandboxPolicyStatusRequest, GetSandboxPolicyStatusResponse, GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, ListSandboxPoliciesRequest, ListSandboxPoliciesResponse, ListSandboxesRequest, @@ -25,11 +25,12 @@ use openshell_core::proto::{ 
    RejectDraftChunkRequest, RejectDraftChunkResponse, ReportPolicyStatusRequest,
    ReportPolicyStatusResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse,
    SandboxStreamEvent, ServiceStatus, SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse,
-    UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest, UpdateConfigResponse,
-    UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell,
+    SupervisorMessage, UndoDraftChunkRequest, UndoDraftChunkResponse, UpdateConfigRequest,
+    UpdateConfigResponse, UpdateProviderRequest, WatchSandboxRequest, open_shell_server::OpenShell,
 };
 use serde::{Deserialize, Serialize};
 use std::collections::BTreeMap;
+use std::pin::Pin;
 use std::sync::Arc;
 use tokio_stream::wrappers::ReceiverStream;
 use tonic::{Request, Response, Status};
@@ -383,6 +384,18 @@ impl OpenShell for OpenShellService {
     ) -> Result<Response<GetDraftHistoryResponse>, Status> {
         policy::handle_get_draft_history(&self.state, request).await
     }
+
+    // --- Supervisor session ---
+
+    type ConnectSupervisorStream =
+        Pin<Box<dyn tokio_stream::Stream<Item = Result<GatewayMessage, Status>> + Send + 'static>>;
+
+    async fn connect_supervisor(
+        &self,
+        request: Request<tonic::Streaming<SupervisorMessage>>,
+    ) -> Result<Response<Self::ConnectSupervisorStream>, Status> {
+        crate::supervisor_session::handle_connect_supervisor(&self.state, request).await
+    }
 }
 
 // ---------------------------------------------------------------------------
diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs
index 8e5930826..cdc7b51dd 100644
--- a/crates/openshell-server/src/grpc/sandbox.rs
+++ b/crates/openshell-server/src/grpc/sandbox.rs
@@ -22,13 +22,11 @@ use openshell_core::proto::{
 use openshell_core::proto::{Sandbox, SandboxPhase, SandboxTemplate, SshSession};
 use prost::Message;
 use std::sync::Arc;
-use tokio::io::AsyncReadExt;
-use tokio::io::AsyncWriteExt;
 use tokio::net::{TcpListener, TcpStream};
 use tokio::sync::mpsc;
 use tokio_stream::wrappers::ReceiverStream;
 use tonic::{Request, Response, Status};
-use tracing::{debug, info, warn};
+use tracing::{info, warn};
 use
russh::ChannelMsg; use russh::client::AuthResult; @@ -438,7 +436,13 @@ pub(super) async fn handle_exec_sandbox( return Err(Status::failed_precondition("sandbox is not ready")); } - let (target_host, target_port) = resolve_sandbox_exec_target(state, &sandbox).await?; + // Open a relay channel through the supervisor session. + let (channel_id, relay_rx) = state + .supervisor_sessions + .open_relay(&sandbox.id) + .await + .map_err(|e| Status::unavailable(format!("supervisor relay failed: {e}")))?; + let command_str = build_remote_exec_command(&req) .map_err(|e| Status::invalid_argument(format!("command construction failed: {e}")))?; let stdin_payload = req.stdin; @@ -449,11 +453,32 @@ pub(super) async fn handle_exec_sandbox( let (tx, rx) = mpsc::channel::>(256); tokio::spawn(async move { - if let Err(err) = stream_exec_over_ssh( + // Wait for the supervisor's reverse CONNECT to deliver the relay stream. + let relay_stream = match tokio::time::timeout(std::time::Duration::from_secs(10), relay_rx) + .await + { + Ok(Ok(stream)) => stream, + Ok(Err(_)) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay channel dropped"); + let _ = tx + .send(Err(Status::unavailable("relay channel dropped"))) + .await; + return; + } + Err(_) => { + warn!(sandbox_id = %sandbox_id, channel_id = %channel_id, "ExecSandbox: relay open timed out"); + let _ = tx + .send(Err(Status::deadline_exceeded("relay open timed out"))) + .await; + return; + } + }; + + if let Err(err) = stream_exec_over_relay( tx.clone(), &sandbox_id, - &target_host, - target_port, + &channel_id, + relay_stream, &command_str, stdin_payload, timeout_seconds, @@ -584,16 +609,6 @@ fn resolve_gateway(config: &openshell_core::Config) -> (String, u16) { (host, port) } -async fn resolve_sandbox_exec_target( - state: &ServerState, - sandbox: &Sandbox, -) -> Result<(String, u16), Status> { - match state.compute.resolve_sandbox_endpoint(sandbox).await? 
{ - crate::compute::ResolvedEndpoint::Ip(ip, port) => Ok((ip.to_string(), port)), - crate::compute::ResolvedEndpoint::Host(host, port) => Ok((host, port)), - } -} - /// Shell-escape a value for embedding in a POSIX shell command. /// /// Wraps unsafe values in single quotes with the standard `'\''` idiom for @@ -646,34 +661,18 @@ fn build_remote_exec_command(req: &ExecSandboxRequest) -> Result Ok(result) } -/// Maximum number of attempts when establishing the SSH transport to a sandbox. -const SSH_CONNECT_MAX_ATTEMPTS: u32 = 6; - -/// Initial backoff duration between SSH connection retries. -const SSH_CONNECT_INITIAL_BACKOFF: std::time::Duration = std::time::Duration::from_millis(250); - -/// Maximum backoff duration between SSH connection retries. -const SSH_CONNECT_MAX_BACKOFF: std::time::Duration = std::time::Duration::from_secs(2); - -/// Returns `true` if the gRPC status represents a transient SSH connection error. -fn is_retryable_ssh_error(status: &Status) -> bool { - if status.code() != tonic::Code::Internal { - return false; - } - let msg = status.message(); - msg.contains("Connection reset by peer") - || msg.contains("Connection refused") - || msg.contains("failed to establish ssh transport") - || msg.contains("failed to connect to ssh proxy") - || msg.contains("failed to start ssh proxy") -} - +/// Execute a command over an SSH transport relayed through a supervisor session. +/// +/// This is the relay equivalent of `stream_exec_over_ssh`. Instead of dialing a +/// sandbox endpoint directly, the SSH transport runs over a `DuplexStream` that +/// is bridged to the supervisor's local SSH daemon via a reverse HTTP CONNECT +/// tunnel. 
#[allow(clippy::too_many_arguments)] -async fn stream_exec_over_ssh( +async fn stream_exec_over_relay( tx: mpsc::Sender>, sandbox_id: &str, - target_host: &str, - target_port: u16, + channel_id: &str, + relay_stream: tokio::io::DuplexStream, command: &str, stdin_payload: Vec, timeout_seconds: u32, @@ -683,96 +682,53 @@ async fn stream_exec_over_ssh( let command_preview: String = command.chars().take(120).collect(); info!( sandbox_id = %sandbox_id, - target_host = %target_host, - target_port, + channel_id = %channel_id, command_len = command.len(), stdin_len = stdin_payload.len(), command_preview = %command_preview, - "ExecSandbox command started" + "ExecSandbox (relay): command started" ); - let (exit_code, proxy_task) = { - let mut last_err: Option = None; - - let mut result = None; - for attempt in 0..SSH_CONNECT_MAX_ATTEMPTS { - if attempt > 0 { - let backoff = (SSH_CONNECT_INITIAL_BACKOFF * 2u32.pow(attempt - 1)) - .min(SSH_CONNECT_MAX_BACKOFF); - warn!( - sandbox_id = %sandbox_id, - attempt = attempt + 1, - backoff_ms = %backoff.as_millis(), - error = %last_err.as_ref().unwrap(), - "Retrying SSH transport establishment" - ); - tokio::time::sleep(backoff).await; - } - - let (local_proxy_port, proxy_task) = match start_single_use_ssh_proxy( - target_host, - target_port, - handshake_secret, - ) + let (local_proxy_port, proxy_task) = + start_single_use_ssh_proxy_over_relay(relay_stream, handshake_secret) .await - { - Ok(v) => v, - Err(e) => { - last_err = Some(Status::internal(format!("failed to start ssh proxy: {e}"))); - continue; - } - }; - - let exec = run_exec_with_russh( - local_proxy_port, - command, - stdin_payload.clone(), - request_tty, - tx.clone(), - ); + .map_err(|e| Status::internal(format!("failed to start relay proxy: {e}")))?; + + let exec = run_exec_with_russh( + local_proxy_port, + command, + stdin_payload, + request_tty, + tx.clone(), + ); - let exec_result = if timeout_seconds == 0 { - exec.await - } else if let Ok(r) = tokio::time::timeout( - 
std::time::Duration::from_secs(u64::from(timeout_seconds)), - exec, - ) - .await - { - r - } else { - let _ = tx - .send(Ok(ExecSandboxEvent { - payload: Some(openshell_core::proto::exec_sandbox_event::Payload::Exit( - ExecSandboxExit { exit_code: 124 }, - )), - })) - .await; - let _ = proxy_task.await; - return Ok(()); - }; - - match exec_result { - Ok(exit_code) => { - result = Some((exit_code, proxy_task)); - break; - } - Err(status) => { - let _ = proxy_task.await; - if is_retryable_ssh_error(&status) && attempt + 1 < SSH_CONNECT_MAX_ATTEMPTS { - last_err = Some(status); - continue; - } - return Err(status); - } - } - } + let exec_result = if timeout_seconds == 0 { + exec.await + } else if let Ok(r) = tokio::time::timeout( + std::time::Duration::from_secs(u64::from(timeout_seconds)), + exec, + ) + .await + { + r + } else { + let _ = tx + .send(Ok(ExecSandboxEvent { + payload: Some(openshell_core::proto::exec_sandbox_event::Payload::Exit( + ExecSandboxExit { exit_code: 124 }, + )), + })) + .await; + let _ = proxy_task.await; + return Ok(()); + }; - result.ok_or_else(|| { - last_err.unwrap_or_else(|| { - Status::internal("ssh connection failed after exhausting retries") - }) - })? + let exit_code = match exec_result { + Ok(code) => code, + Err(status) => { + let _ = proxy_task.await; + return Err(status); + } }; let _ = proxy_task.await; @@ -788,6 +744,75 @@ async fn stream_exec_over_ssh( Ok(()) } +/// Create a localhost SSH proxy that bridges to a relay DuplexStream. +/// +/// The proxy sends the NSSH1 handshake preface through the relay (which flows +/// to the supervisor and on to the embedded SSH daemon), waits for "OK", then +/// bridges the russh client connection with the relay stream. 
+async fn start_single_use_ssh_proxy_over_relay(
+    relay_stream: tokio::io::DuplexStream,
+    handshake_secret: &str,
+) -> Result<(u16, tokio::task::JoinHandle<()>), Box<dyn std::error::Error + Send + Sync>> {
+    let listener = TcpListener::bind(("127.0.0.1", 0)).await?;
+    let port = listener.local_addr()?.port();
+    let handshake_secret = handshake_secret.to_string();
+
+    let task = tokio::spawn(async move {
+        let Ok((mut client_conn, _)) = listener.accept().await else {
+            warn!("SSH relay proxy: failed to accept local connection");
+            return;
+        };
+
+        let (mut relay_read, mut relay_write) = tokio::io::split(relay_stream);
+
+        // Send NSSH1 handshake through the relay to the SSH daemon.
+        let Ok(preface) = build_preface(&uuid::Uuid::new_v4().to_string(), &handshake_secret)
+        else {
+            warn!("SSH relay proxy: failed to build handshake preface");
+            return;
+        };
+        if let Err(e) =
+            tokio::io::AsyncWriteExt::write_all(&mut relay_write, preface.as_bytes()).await
+        {
+            warn!(error = %e, "SSH relay proxy: failed to send handshake preface");
+            return;
+        }
+
+        // Read handshake response from the relay.
+        let mut response_buf = Vec::new();
+        loop {
+            let mut byte = [0u8; 1];
+            match tokio::io::AsyncReadExt::read(&mut relay_read, &mut byte).await {
+                Ok(0) => break,
+                Ok(_) => {
+                    if byte[0] == b'\n' {
+                        break;
+                    }
+                    response_buf.push(byte[0]);
+                    if response_buf.len() > 1024 {
+                        break;
+                    }
+                }
+                Err(e) => {
+                    warn!(error = %e, "SSH relay proxy: failed to read handshake response");
+                    return;
+                }
+            }
+        }
+        let response = String::from_utf8_lossy(&response_buf);
+        if response.trim() != "OK" {
+            warn!(response = %response.trim(), "SSH relay proxy: handshake rejected");
+            return;
+        }
+
+        // Reunite the split halves for copy_bidirectional.
+ let mut relay = relay_read.unsplit(relay_write); + let _ = tokio::io::copy_bidirectional(&mut client_conn, &mut relay).await; + }); + + Ok((port, task)) +} + #[derive(Debug, Clone, Copy)] struct SandboxSshClientHandler; @@ -914,98 +939,6 @@ async fn run_exec_with_russh( Ok(exit_code.unwrap_or(1)) } -/// Check whether an IP address is safe to use as an SSH proxy target. -fn is_safe_ssh_proxy_target(ip: std::net::IpAddr) -> bool { - match ip { - std::net::IpAddr::V4(v4) => !v4.is_loopback() && !v4.is_link_local(), - std::net::IpAddr::V6(v6) => { - if v6.is_loopback() { - return false; - } - if let Some(v4) = v6.to_ipv4_mapped() { - return !v4.is_loopback() && !v4.is_link_local(); - } - true - } - } -} - -async fn start_single_use_ssh_proxy( - target_host: &str, - target_port: u16, - handshake_secret: &str, -) -> Result<(u16, tokio::task::JoinHandle<()>), Box> { - let listener = TcpListener::bind(("127.0.0.1", 0)).await?; - let port = listener.local_addr()?.port(); - let target_host = target_host.to_string(); - let handshake_secret = handshake_secret.to_string(); - - let task = tokio::spawn(async move { - let Ok((mut client_conn, _)) = listener.accept().await else { - warn!("SSH proxy: failed to accept local connection"); - return; - }; - - let addr_str = format!("{target_host}:{target_port}"); - let resolved = match tokio::net::lookup_host(&addr_str).await { - Ok(mut addrs) => { - if let Some(addr) = addrs.next() { - addr - } else { - warn!(target_host = %target_host, "SSH proxy: DNS resolution returned no addresses"); - return; - } - } - Err(e) => { - warn!(target_host = %target_host, error = %e, "SSH proxy: DNS resolution failed"); - return; - } - }; - - if !is_safe_ssh_proxy_target(resolved.ip()) { - warn!( - target_host = %target_host, - resolved_ip = %resolved.ip(), - "SSH proxy: target resolved to blocked IP range (loopback or link-local)" - ); - return; - } - - debug!( - target_host = %target_host, - resolved_ip = %resolved.ip(), - target_port, - "SSH proxy: 
connecting to validated target" - ); - - let Ok(mut sandbox_conn) = TcpStream::connect(resolved).await else { - warn!(target_host = %target_host, resolved_ip = %resolved.ip(), target_port, "SSH proxy: failed to connect to sandbox"); - return; - }; - let Ok(preface) = build_preface(&uuid::Uuid::new_v4().to_string(), &handshake_secret) - else { - warn!("SSH proxy: failed to build handshake preface"); - return; - }; - if let Err(e) = sandbox_conn.write_all(preface.as_bytes()).await { - warn!(error = %e, "SSH proxy: failed to send handshake preface"); - return; - } - let mut response = String::new(); - if let Err(e) = read_line(&mut sandbox_conn, &mut response).await { - warn!(error = %e, "SSH proxy: failed to read handshake response"); - return; - } - if response.trim() != "OK" { - warn!(response = %response.trim(), "SSH proxy: handshake rejected by sandbox"); - return; - } - let _ = tokio::io::copy_bidirectional(&mut client_conn, &mut sandbox_conn).await; - }); - - Ok((port, task)) -} - fn build_preface( token: &str, secret: &str, @@ -1023,29 +956,6 @@ fn build_preface( Ok(format!("NSSH1 {token} {timestamp} {nonce} {signature}\n")) } -async fn read_line( - stream: &mut TcpStream, - buf: &mut String, -) -> Result<(), Box> { - let mut bytes = Vec::new(); - loop { - let mut byte = [0_u8; 1]; - let n = stream.read(&mut byte).await?; - if n == 0 { - break; - } - if byte[0] == b'\n' { - break; - } - bytes.push(byte[0]); - if bytes.len() > 1024 { - break; - } - } - *buf = String::from_utf8_lossy(&bytes).to_string(); - Ok(()) -} - fn hmac_sha256(key: &[u8], data: &[u8]) -> String { use hmac::{Hmac, Mac}; use sha2::Sha256; @@ -1161,59 +1071,6 @@ mod tests { assert!(build_remote_exec_command(&req).is_err()); } - // ---- is_safe_ssh_proxy_target ---- - - #[test] - fn ssh_proxy_target_allows_pod_network_ips() { - use std::net::{IpAddr, Ipv4Addr}; - assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 10, 0, 0, 5 - )))); - 
assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 172, 16, 0, 1 - )))); - assert!(is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 192, 168, 1, 100 - )))); - } - - #[test] - fn ssh_proxy_target_blocks_loopback() { - use std::net::{IpAddr, Ipv4Addr, Ipv6Addr}; - assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 127, 0, 0, 1 - )))); - assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 127, 0, 0, 2 - )))); - assert!(!is_safe_ssh_proxy_target(IpAddr::V6(Ipv6Addr::LOCALHOST))); - } - - #[test] - fn ssh_proxy_target_blocks_link_local() { - use std::net::{IpAddr, Ipv4Addr}; - assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 169, 254, 169, 254 - )))); - assert!(!is_safe_ssh_proxy_target(IpAddr::V4(Ipv4Addr::new( - 169, 254, 0, 1 - )))); - } - - #[test] - fn ssh_proxy_target_blocks_ipv4_mapped_ipv6_loopback() { - use std::net::IpAddr; - let ip: IpAddr = "::ffff:127.0.0.1".parse().unwrap(); - assert!(!is_safe_ssh_proxy_target(ip)); - } - - #[test] - fn ssh_proxy_target_blocks_ipv4_mapped_ipv6_link_local() { - use std::net::IpAddr; - let ip: IpAddr = "::ffff:169.254.169.254".parse().unwrap(); - assert!(!is_safe_ssh_proxy_target(ip)); - } - // ---- petname / generate_name ---- #[test] diff --git a/crates/openshell-server/src/http.rs b/crates/openshell-server/src/http.rs index afe7edc1b..aefe4181b 100644 --- a/crates/openshell-server/src/http.rs +++ b/crates/openshell-server/src/http.rs @@ -49,6 +49,7 @@ pub fn health_router() -> Router { pub fn http_router(state: Arc) -> Router { health_router() .merge(crate::ssh_tunnel::router(state.clone())) + .merge(crate::relay::router(state.clone())) .merge(crate::ws_tunnel::router(state.clone())) .merge(crate::auth::router(state)) } diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index a8d820b4d..346aaa172 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -16,9 +16,11 @@ mod http; mod inference; mod multiplex; 
mod persistence; +mod relay; mod sandbox_index; mod sandbox_watch; mod ssh_tunnel; +pub(crate) mod supervisor_session; mod tls; pub mod tracing_bus; mod ws_tunnel; @@ -73,6 +75,9 @@ pub struct ServerState { /// set/delete operation, including the precedence check on sandbox /// mutations that reads global state. pub settings_mutex: tokio::sync::Mutex<()>, + + /// Registry of active supervisor sessions and pending relay channels. + pub supervisor_sessions: supervisor_session::SupervisorSessionRegistry, } fn is_benign_tls_handshake_failure(error: &std::io::Error) -> bool { @@ -103,6 +108,7 @@ impl ServerState { ssh_connections_by_token: Mutex::new(HashMap::new()), ssh_connections_by_sandbox: Mutex::new(HashMap::new()), settings_mutex: tokio::sync::Mutex::new(()), + supervisor_sessions: supervisor_session::SupervisorSessionRegistry::new(), } } } diff --git a/crates/openshell-server/src/relay.rs b/crates/openshell-server/src/relay.rs new file mode 100644 index 000000000..662fe4d99 --- /dev/null +++ b/crates/openshell-server/src/relay.rs @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! HTTP CONNECT relay endpoint for supervisor-initiated reverse tunnels. +//! +//! When the gateway sends a `RelayOpen` message over the supervisor's gRPC +//! session, the supervisor opens `CONNECT /relay/{channel_id}` back to this +//! endpoint. The gateway then bridges the supervisor's upgraded stream with +//! the client's SSH tunnel or exec proxy. 
+ +use axum::{ + Router, extract::Path, extract::State, http::Method, response::IntoResponse, routing::any, +}; +use http::StatusCode; +use hyper::upgrade::OnUpgrade; +use hyper_util::rt::TokioIo; +use std::sync::Arc; +use tokio::io::AsyncWriteExt; +use tracing::{info, warn}; + +use crate::ServerState; + +pub fn router(state: Arc) -> Router { + Router::new() + .route("/relay/{channel_id}", any(relay_connect)) + .with_state(state) +} + +async fn relay_connect( + State(state): State>, + Path(channel_id): Path, + req: hyper::Request, +) -> impl IntoResponse { + if req.method() != Method::CONNECT { + return StatusCode::METHOD_NOT_ALLOWED.into_response(); + } + + // Claim the pending relay. This consumes the entry — it cannot be reused. + let supervisor_stream = match state.supervisor_sessions.claim_relay(&channel_id) { + Ok(stream) => stream, + Err(_) => { + warn!(channel_id = %channel_id, "relay: unknown or expired channel"); + return StatusCode::NOT_FOUND.into_response(); + } + }; + + info!(channel_id = %channel_id, "relay: supervisor connected, upgrading"); + + // Upgrade the HTTP connection to a raw byte stream and bridge it to + // the DuplexStream that connects to the gateway-side waiter. 
+ let on_upgrade: OnUpgrade = hyper::upgrade::on(req); + tokio::spawn(async move { + match on_upgrade.await { + Ok(upgraded) => { + let mut upgraded = TokioIo::new(upgraded); + let mut supervisor = supervisor_stream; + let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut supervisor).await; + let _ = AsyncWriteExt::shutdown(&mut upgraded).await; + } + Err(e) => { + warn!(channel_id = %channel_id, error = %e, "relay: upgrade failed"); + } + } + }); + + StatusCode::SWITCHING_PROTOCOLS.into_response() +} diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs index 536513ccd..de14976ac 100644 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ b/crates/openshell-server/src/ssh_tunnel.rs @@ -6,15 +6,12 @@ use axum::{Router, extract::State, http::Method, response::IntoResponse, routing::any}; use http::StatusCode; use hyper::Request; -use hyper::upgrade::Upgraded; use hyper_util::rt::TokioIo; use openshell_core::proto::{Sandbox, SandboxPhase, SshSession}; use prost::Message; -use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpStream; use tracing::{info, warn}; use uuid::Uuid; @@ -23,7 +20,6 @@ use crate::persistence::{ObjectId, ObjectName, ObjectType, Store}; const HEADER_SANDBOX_ID: &str = "x-sandbox-id"; const HEADER_TOKEN: &str = "x-sandbox-token"; -const PREFACE_MAGIC: &str = "NSSH1"; /// Maximum concurrent SSH tunnel connections per session token. 
const MAX_CONNECTIONS_PER_TOKEN: u32 = 3; @@ -100,19 +96,15 @@ async fn ssh_connect( return StatusCode::PRECONDITION_FAILED.into_response(); } - let connect_target = match state.compute.resolve_sandbox_endpoint(&sandbox).await { - Ok(crate::compute::ResolvedEndpoint::Ip(ip, port)) => { - ConnectTarget::Ip(SocketAddr::new(ip, port)) - } - Ok(crate::compute::ResolvedEndpoint::Host(host, port)) => ConnectTarget::Host(host, port), - Err(status) if status.code() == tonic::Code::FailedPrecondition => { - return StatusCode::PRECONDITION_FAILED.into_response(); - } - Err(err) => { - warn!(error = %err, "Failed to resolve sandbox endpoint"); + // Open a relay channel through the supervisor session. + let (channel_id, relay_rx) = match state.supervisor_sessions.open_relay(&sandbox_id).await { + Ok(pair) => pair, + Err(status) => { + warn!(sandbox_id = %sandbox_id, error = %status.message(), "SSH tunnel: supervisor session not available"); return StatusCode::BAD_GATEWAY.into_response(); } }; + // Enforce per-token concurrent connection limit. { let mut counts = state.ssh_connections_by_token.lock().unwrap(); @@ -150,20 +142,97 @@ async fn ssh_connect( let upgrade = hyper::upgrade::on(req); tokio::spawn(async move { - match upgrade.await { - Ok(mut upgraded) => { - if let Err(err) = handle_tunnel( - &mut upgraded, - connect_target, - &token_clone, - &handshake_secret, + // Wait for the supervisor's reverse CONNECT to arrive and claim the relay. 
+ let relay_stream = match tokio::time::timeout(Duration::from_secs(10), relay_rx).await { + Ok(Ok(stream)) => stream, + Ok(Err(_)) => { + warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay channel dropped"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count( + &state_clone.ssh_connections_by_sandbox, &sandbox_id_clone, - ) - .await - { - warn!(error = %err, "SSH tunnel failure"); + ); + return; + } + Err(_) => { + warn!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: relay open timed out"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count( + &state_clone.ssh_connections_by_sandbox, + &sandbox_id_clone, + ); + return; + } + }; + + // Send NSSH1 handshake through the relay to the SSH daemon before + // bridging the client's SSH bytes. The relay carries bytes to the + // supervisor which bridges them to the local SSH daemon on loopback. + let (mut relay_read, mut relay_write) = tokio::io::split(relay_stream); + let preface = match build_preface(&token_clone, &handshake_secret) { + Ok(p) => p, + Err(e) => { + warn!(error = %e, "SSH tunnel: failed to build NSSH1 preface"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count( + &state_clone.ssh_connections_by_sandbox, + &sandbox_id_clone, + ); + return; + } + }; + if let Err(e) = relay_write.write_all(preface.as_bytes()).await { + warn!(error = %e, "SSH tunnel: failed to send NSSH1 preface through relay"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count(&state_clone.ssh_connections_by_sandbox, &sandbox_id_clone); + return; + } + + // Read handshake response from the SSH daemon through the relay. 
+ let mut response_buf = Vec::new(); + loop { + let mut byte = [0u8; 1]; + match relay_read.read(&mut byte).await { + Ok(0) => break, + Ok(_) => { + if byte[0] == b'\n' { + break; + } + response_buf.push(byte[0]); + if response_buf.len() > 1024 { + break; + } + } + Err(e) => { + warn!(error = %e, "SSH tunnel: failed to read NSSH1 response from relay"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count( + &state_clone.ssh_connections_by_sandbox, + &sandbox_id_clone, + ); + return; } } + } + let response = String::from_utf8_lossy(&response_buf); + if response.trim() != "OK" { + warn!(response = %response.trim(), "SSH tunnel: NSSH1 handshake rejected by sandbox"); + decrement_connection_count(&state_clone.ssh_connections_by_token, &token_clone); + decrement_connection_count(&state_clone.ssh_connections_by_sandbox, &sandbox_id_clone); + return; + } + + info!(sandbox_id = %sandbox_id_clone, channel_id = %channel_id, "SSH tunnel: NSSH1 handshake OK, bridging client"); + + // Reunite the split relay halves and bridge with the client's upgraded stream. + let mut relay = relay_read.unsplit(relay_write); + + match upgrade.await { + Ok(upgraded) => { + let mut upgraded = TokioIo::new(upgraded); + let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut relay).await; + let _ = AsyncWriteExt::shutdown(&mut upgraded).await; + } Err(err) => { warn!(error = %err, "SSH upgrade failed"); } @@ -177,90 +246,6 @@ async fn ssh_connect( StatusCode::OK.into_response() } -async fn handle_tunnel( - upgraded: &mut Upgraded, - target: ConnectTarget, - token: &str, - secret: &str, - sandbox_id: &str, -) -> Result<(), Box> { - // The sandbox pod may not be network-reachable immediately after the CRD - // reports Ready (DNS propagation, pod IP assignment, SSH server startup). - // Retry the TCP connection with exponential backoff. 
- let mut upstream = None; - let mut last_err = None; - let delays = [ - Duration::from_millis(100), - Duration::from_millis(250), - Duration::from_millis(500), - Duration::from_secs(1), - Duration::from_secs(2), - Duration::from_secs(5), - Duration::from_secs(10), - Duration::from_secs(15), - ]; - let target_desc = match &target { - ConnectTarget::Ip(addr) => format!("{addr}"), - ConnectTarget::Host(host, port) => format!("{host}:{port}"), - }; - info!(sandbox_id = %sandbox_id, target = %target_desc, "SSH tunnel: connecting to sandbox"); - for (attempt, delay) in std::iter::once(&Duration::ZERO) - .chain(delays.iter()) - .enumerate() - { - if !delay.is_zero() { - info!(sandbox_id = %sandbox_id, attempt = attempt + 1, delay_ms = delay.as_millis() as u64, "SSH tunnel: retrying TCP connect"); - tokio::time::sleep(*delay).await; - } - let result = match &target { - ConnectTarget::Ip(addr) => TcpStream::connect(addr).await, - ConnectTarget::Host(host, port) => TcpStream::connect((host.as_str(), *port)).await, - }; - match result { - Ok(stream) => { - info!( - sandbox_id = %sandbox_id, - attempts = attempt + 1, - "SSH tunnel: TCP connected to sandbox" - ); - upstream = Some(stream); - break; - } - Err(err) => { - info!(sandbox_id = %sandbox_id, attempt = attempt + 1, error = %err, "SSH tunnel: TCP connect failed"); - last_err = Some(err); - } - } - } - let mut upstream = upstream.ok_or_else(|| { - let err = last_err.unwrap(); - format!("failed to connect to sandbox after retries: {err}") - })?; - upstream.set_nodelay(true)?; - info!(sandbox_id = %sandbox_id, "SSH tunnel: sending NSSH1 handshake preface"); - let preface = build_preface(token, secret)?; - upstream.write_all(preface.as_bytes()).await?; - - info!(sandbox_id = %sandbox_id, "SSH tunnel: waiting for handshake response"); - let mut response = String::new(); - read_line(&mut upstream, &mut response).await?; - info!(sandbox_id = %sandbox_id, response = %response.trim(), "SSH tunnel: handshake response received"); 
- if response.trim() != "OK" { - return Err("sandbox handshake rejected".into()); - } - - info!(sandbox_id = %sandbox_id, "SSH tunnel established"); - let mut upgraded = TokioIo::new(upgraded); - // Discard the result entirely – connection-close errors are expected when - // the SSH session ends and do not represent a failure worth propagating. - let _ = tokio::io::copy_bidirectional(&mut upgraded, &mut upstream).await; - // Gracefully shut down the write-half of the upgraded connection so the - // client receives a clean EOF instead of a TCP RST. This gives SSH time - // to read any remaining protocol data (e.g. exit-status) from its buffer. - let _ = AsyncWriteExt::shutdown(&mut upgraded).await; - Ok(()) -} - fn header_value(headers: &http::HeaderMap, name: &str) -> Result { let value = headers .get(name) @@ -275,6 +260,8 @@ fn header_value(headers: &http::HeaderMap, name: &str) -> Result Result<(), Box> { - let mut bytes = Vec::new(); - loop { - let mut byte = [0u8; 1]; - let n = stream.read(&mut byte).await?; - if n == 0 { - break; - } - if byte[0] == b'\n' { - break; - } - bytes.push(byte[0]); - if bytes.len() > 1024 { - break; - } - } - *buf = String::from_utf8_lossy(&bytes).to_string(); - Ok(()) -} - fn hmac_sha256(key: &[u8], data: &[u8]) -> String { use hmac::{Hmac, Mac}; use sha2::Sha256; @@ -345,11 +309,6 @@ impl ObjectName for SshSession { } } -enum ConnectTarget { - Ip(SocketAddr), - Host(String, u16), -} - /// Decrement a connection count entry, removing it if it reaches zero. fn decrement_connection_count( counts: &std::sync::Mutex>, diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs new file mode 100644 index 000000000..ed33f8e15 --- /dev/null +++ b/crates/openshell-server/src/supervisor_session.rs @@ -0,0 +1,440 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0
+
+use std::collections::HashMap;
+use std::pin::Pin;
+use std::sync::{Arc, Mutex};
+use std::time::{Duration, Instant};
+
+use tokio::sync::{mpsc, oneshot};
+use tokio_stream::wrappers::ReceiverStream;
+use tonic::{Request, Response, Status};
+use tracing::{info, warn};
+use uuid::Uuid;
+
+use openshell_core::proto::{
+    GatewayMessage, RelayOpen, SessionAccepted, SupervisorMessage, gateway_message,
+    supervisor_message,
+};
+
+use crate::ServerState;
+
+const HEARTBEAT_INTERVAL_SECS: u32 = 15;
+const RELAY_PENDING_TIMEOUT: Duration = Duration::from_secs(10);
+
+// ---------------------------------------------------------------------------
+// Session registry
+// ---------------------------------------------------------------------------
+
+/// A live supervisor session handle.
+struct LiveSession {
+    #[allow(dead_code)]
+    sandbox_id: String,
+    tx: mpsc::Sender<GatewayMessage>,
+    #[allow(dead_code)]
+    connected_at: Instant,
+}
+
+/// Holds a oneshot sender that will deliver the upgraded relay stream.
+type RelayStreamSender = oneshot::Sender<tokio::io::DuplexStream>;
+
+/// Registry of active supervisor sessions and pending relay channels.
+#[derive(Default)]
+pub struct SupervisorSessionRegistry {
+    /// sandbox_id -> live session handle.
+    sessions: Mutex<HashMap<String, LiveSession>>,
+    /// channel_id -> oneshot sender for the reverse CONNECT stream.
+    pending_relays: Mutex<HashMap<String, PendingRelay>>,
+}
+
+struct PendingRelay {
+    sender: RelayStreamSender,
+    created_at: Instant,
+}
+
+impl std::fmt::Debug for SupervisorSessionRegistry {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let session_count = self.sessions.lock().unwrap().len();
+        let pending_count = self.pending_relays.lock().unwrap().len();
+        f.debug_struct("SupervisorSessionRegistry")
+            .field("sessions", &session_count)
+            .field("pending_relays", &pending_count)
+            .finish()
+    }
+}
+
+impl SupervisorSessionRegistry {
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Register a live supervisor session for the given sandbox.
+    ///
+    /// Returns the previous session's sender (if any) so the caller can close it.
+    fn register(
+        &self,
+        sandbox_id: String,
+        tx: mpsc::Sender<GatewayMessage>,
+    ) -> Option<mpsc::Sender<GatewayMessage>> {
+        let mut sessions = self.sessions.lock().unwrap();
+        let previous = sessions.remove(&sandbox_id).map(|s| s.tx);
+        sessions.insert(
+            sandbox_id.clone(),
+            LiveSession {
+                sandbox_id,
+                tx,
+                connected_at: Instant::now(),
+            },
+        );
+        previous
+    }
+
+    /// Remove the session for a sandbox.
+    fn remove(&self, sandbox_id: &str) {
+        self.sessions.lock().unwrap().remove(sandbox_id);
+    }
+
+    /// Open a relay channel: sends RelayOpen to the supervisor and returns a
+    /// stream that will be connected once the supervisor's reverse HTTP CONNECT
+    /// arrives.
+    ///
+    /// Returns `(channel_id, receiver_for_relay_stream)`.
+    pub async fn open_relay(
+        &self,
+        sandbox_id: &str,
+    ) -> Result<(String, oneshot::Receiver<tokio::io::DuplexStream>), Status> {
+        let channel_id = Uuid::new_v4().to_string();
+
+        // Look up the session and send RelayOpen.
+        let tx = {
+            let sessions = self.sessions.lock().unwrap();
+            let session = sessions
+                .get(sandbox_id)
+                .ok_or_else(|| Status::unavailable("supervisor session not connected"))?;
+            session.tx.clone()
+        };
+
+        // Register the pending relay before sending RelayOpen to avoid a race.
+        let (relay_tx, relay_rx) = oneshot::channel();
+        {
+            let mut pending = self.pending_relays.lock().unwrap();
+            pending.insert(
+                channel_id.clone(),
+                PendingRelay {
+                    sender: relay_tx,
+                    created_at: Instant::now(),
+                },
+            );
+        }
+
+        let msg = GatewayMessage {
+            payload: Some(gateway_message::Payload::RelayOpen(RelayOpen {
+                channel_id: channel_id.clone(),
+            })),
+        };
+
+        if tx.send(msg).await.is_err() {
+            // Session dropped between our lookup and send.
+            self.pending_relays.lock().unwrap().remove(&channel_id);
+            return Err(Status::unavailable("supervisor session disconnected"));
+        }
+
+        Ok((channel_id, relay_rx))
+    }
+
+    /// Claim a pending relay channel. Called by the /relay/{channel_id} HTTP handler
+    /// when the supervisor's reverse CONNECT arrives.
+    ///
+    /// Returns the DuplexStream half that the supervisor side should read/write.
+    pub fn claim_relay(&self, channel_id: &str) -> Result<tokio::io::DuplexStream, Status> {
+        let pending = {
+            let mut map = self.pending_relays.lock().unwrap();
+            map.remove(channel_id)
+                .ok_or_else(|| Status::not_found("unknown or expired relay channel"))?
+        };
+
+        if pending.created_at.elapsed() > RELAY_PENDING_TIMEOUT {
+            return Err(Status::deadline_exceeded("relay channel timed out"));
+        }
+
+        // Create a duplex stream pair: one end for the gateway bridge, one for
+        // the supervisor HTTP CONNECT handler.
+        let (gateway_stream, supervisor_stream) = tokio::io::duplex(64 * 1024);
+
+        // Send the gateway-side stream to the waiter (ssh_tunnel or exec handler).
+        if pending.sender.send(gateway_stream).is_err() {
+            return Err(Status::internal("relay requester dropped"));
+        }
+
+        Ok(supervisor_stream)
+    }
+
+    /// Remove all pending relays that have exceeded the timeout.
+    pub fn reap_expired_relays(&self) {
+        let mut map = self.pending_relays.lock().unwrap();
+        map.retain(|_, pending| pending.created_at.elapsed() <= RELAY_PENDING_TIMEOUT);
+    }
+
+    /// Clean up all state for a sandbox (session + pending relays).
+    pub fn cleanup_sandbox(&self, sandbox_id: &str) {
+        self.remove(sandbox_id);
+    }
+}
+
+// ---------------------------------------------------------------------------
+// ConnectSupervisor gRPC handler
+// ---------------------------------------------------------------------------
+
+pub async fn handle_connect_supervisor(
+    state: &Arc<ServerState>,
+    request: Request<tonic::Streaming<SupervisorMessage>>,
+) -> Result<
+    Response<
+        Pin<Box<dyn tokio_stream::Stream<Item = Result<GatewayMessage, Status>> + Send + 'static>>,
+    >,
+    Status,
+> {
+    let mut inbound = request.into_inner();
+
+    // Step 1: Wait for SupervisorHello.
+    let hello = match inbound.message().await? {
+        Some(msg) => match msg.payload {
+            Some(supervisor_message::Payload::Hello(hello)) => hello,
+            _ => return Err(Status::invalid_argument("expected SupervisorHello")),
+        },
+        None => return Err(Status::invalid_argument("stream closed before hello")),
+    };
+
+    let sandbox_id = hello.sandbox_id.clone();
+    if sandbox_id.is_empty() {
+        return Err(Status::invalid_argument("sandbox_id is required"));
+    }
+
+    let session_id = Uuid::new_v4().to_string();
+    info!(
+        sandbox_id = %sandbox_id,
+        session_id = %session_id,
+        instance_id = %hello.instance_id,
+        "supervisor session: accepted"
+    );
+
+    // Step 2: Create the outbound channel and register the session.
+    let (tx, rx) = mpsc::channel::<GatewayMessage>(64);
+    if let Some(_previous_tx) = state
+        .supervisor_sessions
+        .register(sandbox_id.clone(), tx.clone())
+    {
+        info!(sandbox_id = %sandbox_id, "supervisor session: superseded previous session");
+    }
+
+    // Step 3: Send SessionAccepted.
+    let accepted = GatewayMessage {
+        payload: Some(gateway_message::Payload::SessionAccepted(SessionAccepted {
+            session_id: session_id.clone(),
+            heartbeat_interval_secs: HEARTBEAT_INTERVAL_SECS,
+        })),
+    };
+    if tx.send(accepted).await.is_err() {
+        state.supervisor_sessions.remove(&sandbox_id);
+        return Err(Status::internal("failed to send session accepted"));
+    }
+
+    // Step 4: Spawn the session loop that reads inbound messages.
+    let state_clone = Arc::clone(state);
+    let sandbox_id_clone = sandbox_id.clone();
+    tokio::spawn(async move {
+        run_session_loop(
+            &state_clone,
+            &sandbox_id_clone,
+            &session_id,
+            &tx,
+            &mut inbound,
+        )
+        .await;
+        state_clone.supervisor_sessions.remove(&sandbox_id_clone);
+        info!(sandbox_id = %sandbox_id_clone, session_id = %session_id, "supervisor session: ended");
+    });
+
+    // Return the outbound stream.
+    let stream = ReceiverStream::new(rx);
+    let stream: Pin<
+        Box<dyn tokio_stream::Stream<Item = Result<GatewayMessage, Status>> + Send + 'static>,
+    > = Box::pin(tokio_stream::StreamExt::map(stream, Ok));
+
+    Ok(Response::new(stream))
+}
+
+async fn run_session_loop(
+    _state: &Arc<ServerState>,
+    sandbox_id: &str,
+    session_id: &str,
+    tx: &mpsc::Sender<GatewayMessage>,
+    inbound: &mut tonic::Streaming<SupervisorMessage>,
+) {
+    let heartbeat_interval = Duration::from_secs(u64::from(HEARTBEAT_INTERVAL_SECS));
+    let mut heartbeat_timer = tokio::time::interval(heartbeat_interval);
+    // Skip the first immediate tick.
+    heartbeat_timer.tick().await;
+
+    loop {
+        tokio::select! {
+            msg = inbound.message() => {
+                match msg {
+                    Ok(Some(msg)) => {
+                        handle_supervisor_message(sandbox_id, session_id, msg);
+                    }
+                    Ok(None) => {
+                        info!(sandbox_id = %sandbox_id, session_id = %session_id, "supervisor session: stream closed by supervisor");
+                        break;
+                    }
+                    Err(e) => {
+                        warn!(sandbox_id = %sandbox_id, session_id = %session_id, error = %e, "supervisor session: stream error");
+                        break;
+                    }
+                }
+            }
+            _ = heartbeat_timer.tick() => {
+                let hb = GatewayMessage {
+                    payload: Some(gateway_message::Payload::Heartbeat(
+                        openshell_core::proto::GatewayHeartbeat {},
+                    )),
+                };
+                if tx.send(hb).await.is_err() {
+                    info!(sandbox_id = %sandbox_id, session_id = %session_id, "supervisor session: outbound channel closed");
+                    break;
+                }
+            }
+        }
+    }
+}
+
+fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: SupervisorMessage) {
+    match msg.payload {
+        Some(supervisor_message::Payload::Heartbeat(_)) => {
+            // Heartbeat received — nothing to do for now.
+ } + Some(supervisor_message::Payload::RelayOpenResult(result)) => { + if result.success { + info!( + sandbox_id = %sandbox_id, + session_id = %session_id, + channel_id = %result.channel_id, + "supervisor session: relay opened successfully" + ); + } else { + warn!( + sandbox_id = %sandbox_id, + session_id = %session_id, + channel_id = %result.channel_id, + error = %result.error, + "supervisor session: relay open failed" + ); + } + } + Some(supervisor_message::Payload::RelayClose(close)) => { + info!( + sandbox_id = %sandbox_id, + session_id = %session_id, + channel_id = %close.channel_id, + reason = %close.reason, + "supervisor session: relay closed by supervisor" + ); + } + _ => { + warn!( + sandbox_id = %sandbox_id, + session_id = %session_id, + "supervisor session: unexpected message type" + ); + } + } +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn registry_register_and_lookup() { + let registry = SupervisorSessionRegistry::new(); + let (tx, _rx) = mpsc::channel(1); + + assert!(registry.register("sandbox-1".to_string(), tx).is_none()); + + // Should find the session. 
+ let sessions = registry.sessions.lock().unwrap(); + assert!(sessions.contains_key("sandbox-1")); + } + + #[test] + fn registry_supersedes_previous_session() { + let registry = SupervisorSessionRegistry::new(); + let (tx1, _rx1) = mpsc::channel(1); + let (tx2, _rx2) = mpsc::channel(1); + + assert!(registry.register("sandbox-1".to_string(), tx1).is_none()); + assert!(registry.register("sandbox-1".to_string(), tx2).is_some()); + } + + #[test] + fn registry_remove() { + let registry = SupervisorSessionRegistry::new(); + let (tx, _rx) = mpsc::channel(1); + registry.register("sandbox-1".to_string(), tx); + + registry.remove("sandbox-1"); + let sessions = registry.sessions.lock().unwrap(); + assert!(!sessions.contains_key("sandbox-1")); + } + + #[test] + fn claim_relay_unknown_channel() { + let registry = SupervisorSessionRegistry::new(); + let result = registry.claim_relay("nonexistent"); + assert!(result.is_err()); + } + + #[test] + fn claim_relay_success() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-1".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + + let result = registry.claim_relay("ch-1"); + assert!(result.is_ok()); + // Should be consumed. 
+ assert!(!registry.pending_relays.lock().unwrap().contains_key("ch-1")); + } + + #[test] + fn reap_expired_relays() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-old".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now() - Duration::from_secs(60), + }, + ); + + registry.reap_expired_relays(); + assert!( + !registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-old") + ); + } +} diff --git a/crates/openshell-server/tests/auth_endpoint_integration.rs b/crates/openshell-server/tests/auth_endpoint_integration.rs index 7c6545873..cd2abe157 100644 --- a/crates/openshell-server/tests/auth_endpoint_integration.rs +++ b/crates/openshell-server/tests/auth_endpoint_integration.rs @@ -528,6 +528,9 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { type ExecSandboxStream = tokio_stream::wrappers::ReceiverStream< Result, >; + type ConnectSupervisorStream = tokio_stream::wrappers::ReceiverStream< + Result, + >; async fn watch_sandbox( &self, @@ -663,6 +666,13 @@ impl openshell_core::proto::open_shell_server::OpenShell for TestOpenShell { { Err(tonic::Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, tonic::Status> { + Err(tonic::Status::unimplemented("not implemented in test")) + } } /// Test 7: Plaintext server (no TLS) accepts both gRPC and HTTP. 
diff --git a/crates/openshell-server/tests/edge_tunnel_auth.rs b/crates/openshell-server/tests/edge_tunnel_auth.rs index 22f08434d..a5d6a88e9 100644 --- a/crates/openshell-server/tests/edge_tunnel_auth.rs +++ b/crates/openshell-server/tests/edge_tunnel_auth.rs @@ -37,13 +37,14 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -186,6 +187,7 @@ impl OpenShell for TestOpenShell { type WatchSandboxStream = ReceiverStream>; type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; async fn watch_sandbox( &self, @@ -307,6 +309,13 @@ impl 
OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // --------------------------------------------------------------------------- diff --git a/crates/openshell-server/tests/multiplex_integration.rs b/crates/openshell-server/tests/multiplex_integration.rs index 1957c5b87..8b93b0989 100644 --- a/crates/openshell-server/tests/multiplex_integration.rs +++ b/crates/openshell-server/tests/multiplex_integration.rs @@ -11,13 +11,14 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, 
UpdateProviderRequest, + WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -154,6 +155,7 @@ impl OpenShell for TestOpenShell { type WatchSandboxStream = ReceiverStream>; type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; async fn watch_sandbox( &self, @@ -275,6 +277,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } #[tokio::test] diff --git a/crates/openshell-server/tests/multiplex_tls_integration.rs b/crates/openshell-server/tests/multiplex_tls_integration.rs index 98d5d6256..4d77e8cae 100644 --- a/crates/openshell-server/tests/multiplex_tls_integration.rs +++ b/crates/openshell-server/tests/multiplex_tls_integration.rs @@ -13,13 +13,14 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + 
GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -167,6 +168,7 @@ impl OpenShell for TestOpenShell { type WatchSandboxStream = ReceiverStream>; type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; async fn watch_sandbox( &self, @@ -288,6 +290,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } /// PKI bundle: CA cert, server cert+key, client cert+key. 
diff --git a/crates/openshell-server/tests/ws_tunnel_integration.rs b/crates/openshell-server/tests/ws_tunnel_integration.rs index 54a7354c8..705e9de49 100644 --- a/crates/openshell-server/tests/ws_tunnel_integration.rs +++ b/crates/openshell-server/tests/ws_tunnel_integration.rs @@ -40,13 +40,14 @@ use hyper_util::{ use openshell_core::proto::{ CreateProviderRequest, CreateSandboxRequest, CreateSshSessionRequest, CreateSshSessionResponse, DeleteProviderRequest, DeleteProviderResponse, DeleteSandboxRequest, DeleteSandboxResponse, - ExecSandboxEvent, ExecSandboxRequest, GetGatewayConfigRequest, GetGatewayConfigResponse, - GetProviderRequest, GetSandboxConfigRequest, GetSandboxConfigResponse, - GetSandboxProviderEnvironmentRequest, GetSandboxProviderEnvironmentResponse, GetSandboxRequest, - HealthRequest, HealthResponse, ListProvidersRequest, ListProvidersResponse, - ListSandboxesRequest, ListSandboxesResponse, ProviderResponse, RevokeSshSessionRequest, - RevokeSshSessionResponse, SandboxResponse, SandboxStreamEvent, ServiceStatus, - UpdateProviderRequest, WatchSandboxRequest, + ExecSandboxEvent, ExecSandboxRequest, GatewayMessage, GetGatewayConfigRequest, + GetGatewayConfigResponse, GetProviderRequest, GetSandboxConfigRequest, + GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + GetSandboxProviderEnvironmentResponse, GetSandboxRequest, HealthRequest, HealthResponse, + ListProvidersRequest, ListProvidersResponse, ListSandboxesRequest, ListSandboxesResponse, + ProviderResponse, RevokeSshSessionRequest, RevokeSshSessionResponse, SandboxResponse, + SandboxStreamEvent, ServiceStatus, SupervisorMessage, UpdateProviderRequest, + WatchSandboxRequest, open_shell_client::OpenShellClient, open_shell_server::{OpenShell, OpenShellServer}, }; @@ -180,6 +181,7 @@ impl OpenShell for TestOpenShell { type WatchSandboxStream = ReceiverStream>; type ExecSandboxStream = ReceiverStream>; + type ConnectSupervisorStream = ReceiverStream>; async fn watch_sandbox( &self, @@ 
-301,6 +303,13 @@ impl OpenShell for TestOpenShell { ) -> Result, Status> { Err(Status::unimplemented("not implemented in test")) } + + async fn connect_supervisor( + &self, + _request: tonic::Request>, + ) -> Result, Status> { + Err(Status::unimplemented("not implemented in test")) + } } // --------------------------------------------------------------------------- diff --git a/proto/compute_driver.proto b/proto/compute_driver.proto index 53b0ac27d..68af695e5 100644 --- a/proto/compute_driver.proto +++ b/proto/compute_driver.proto @@ -38,10 +38,6 @@ service ComputeDriver { // Tear down platform resources for a sandbox. rpc DeleteSandbox(DeleteSandboxRequest) returns (DeleteSandboxResponse); - // Resolve the current endpoint for sandbox exec/SSH transport. - rpc ResolveSandboxEndpoint(ResolveSandboxEndpointRequest) - returns (ResolveSandboxEndpointResponse); - // Stream sandbox observations from the platform. rpc WatchSandboxes(WatchSandboxesRequest) returns (stream WatchSandboxesEvent); } @@ -238,27 +234,6 @@ message DeleteSandboxResponse { bool deleted = 1; } -message ResolveSandboxEndpointRequest { - // Sandbox to resolve for exec or SSH connectivity. - DriverSandbox sandbox = 1; -} - -message SandboxEndpoint { - oneof target { - // Direct IP address for the sandbox endpoint. - string ip = 1; - // DNS host name for the sandbox endpoint. - string host = 2; - } - // TCP port for the sandbox endpoint. - uint32 port = 3; -} - -message ResolveSandboxEndpointResponse { - // Current endpoint the gateway should use to reach the sandbox. - SandboxEndpoint endpoint = 1; -} - message WatchSandboxesRequest {} message WatchSandboxesSandboxEvent { diff --git a/proto/openshell.proto b/proto/openshell.proto index 0ee1e8904..53812c977 100644 --- a/proto/openshell.proto +++ b/proto/openshell.proto @@ -91,6 +91,14 @@ service OpenShell { // Push sandbox supervisor logs to the server (client-streaming). 
rpc PushSandboxLogs(stream PushSandboxLogsRequest) returns (PushSandboxLogsResponse); + // Persistent supervisor-to-gateway session (bidirectional streaming). + // + // The supervisor opens this stream at startup and keeps it alive for the + // sandbox lifetime. The gateway uses it to coordinate relay channels for + // SSH connect and ExecSandbox. SSH bytes flow over separate reverse HTTP + // CONNECT tunnels, not over this stream. + rpc ConnectSupervisor(stream SupervisorMessage) returns (stream GatewayMessage); + // Watch a sandbox and stream updates. // // This stream can include: @@ -704,6 +712,87 @@ message GetSandboxLogsResponse { uint32 buffer_total = 2; } +// --------------------------------------------------------------------------- +// Supervisor session messages +// --------------------------------------------------------------------------- + +// Envelope for supervisor-to-gateway messages on the ConnectSupervisor stream. +message SupervisorMessage { + oneof payload { + SupervisorHello hello = 1; + SupervisorHeartbeat heartbeat = 2; + RelayOpenResult relay_open_result = 3; + RelayClose relay_close = 4; + } +} + +// Envelope for gateway-to-supervisor messages on the ConnectSupervisor stream. +message GatewayMessage { + oneof payload { + SessionAccepted session_accepted = 1; + SessionRejected session_rejected = 2; + GatewayHeartbeat heartbeat = 3; + RelayOpen relay_open = 4; + RelayClose relay_close = 5; + } +} + +// Supervisor identifies itself and the sandbox it manages. +message SupervisorHello { + // Sandbox ID this supervisor manages. + string sandbox_id = 1; + // Supervisor instance ID (e.g. boot id or process epoch). + string instance_id = 2; +} + +// Gateway accepts the supervisor session. +message SessionAccepted { + // Gateway-assigned session ID for this connection. + string session_id = 1; + // Recommended heartbeat interval in seconds. + uint32 heartbeat_interval_secs = 2; +} + +// Gateway rejects the supervisor session. 
+message SessionRejected { + // Human-readable rejection reason. + string reason = 1; +} + +// Supervisor heartbeat. +message SupervisorHeartbeat {} + +// Gateway heartbeat. +message GatewayHeartbeat {} + +// Gateway requests the supervisor to open a relay channel. +// +// On receiving this, the supervisor should open a reverse HTTP CONNECT +// to the gateway's /relay/{channel_id} endpoint and bridge it to the +// local SSH daemon. +message RelayOpen { + // Gateway-allocated channel identifier (UUID). + string channel_id = 1; +} + +// Supervisor reports the result of a relay open request. +message RelayOpenResult { + // Channel identifier from the RelayOpen request. + string channel_id = 1; + // True if the relay was successfully established. + bool success = 2; + // Error message if success is false. + string error = 3; +} + +// Either side requests closure of a relay channel. +message RelayClose { + // Channel identifier to close. + string channel_id = 1; + // Optional reason for closure. + string reason = 2; +} + // --------------------------------------------------------------------------- // Service status // --------------------------------------------------------------------------- From 193bacc7ca5425a99873227c597ab7ab56407abe Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Wed, 15 Apr 2026 21:09:53 -0700 Subject: [PATCH 2/3] fix(server): wait for supervisor session before opening relay MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When a sandbox first reports Ready, the supervisor session may not have completed its gRPC handshake yet. Instead of failing immediately with 502 / "supervisor session not connected", the relay open now retries with exponential backoff (100ms → 2s) for up to 15 seconds. This fixes the race between K8s marking the pod Ready and the supervisor establishing its ConnectSupervisor session. 
--- crates/openshell-server/src/grpc/sandbox.rs | 6 ++-- crates/openshell-server/src/ssh_tunnel.rs | 10 +++++-- .../src/supervisor_session.rs | 28 +++++++++++++++++++ 3 files changed, 40 insertions(+), 4 deletions(-) diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index cdc7b51dd..58269b46a 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -436,10 +436,12 @@ pub(super) async fn handle_exec_sandbox( return Err(Status::failed_precondition("sandbox is not ready")); } - // Open a relay channel through the supervisor session. + // Open a relay channel through the supervisor session. The session may + // not be established yet right after the sandbox reports Ready, so wait + // briefly for it to appear. let (channel_id, relay_rx) = state .supervisor_sessions - .open_relay(&sandbox.id) + .open_relay_with_wait(&sandbox.id, std::time::Duration::from_secs(15)) .await .map_err(|e| Status::unavailable(format!("supervisor relay failed: {e}")))?; diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs index de14976ac..8317aa7bb 100644 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ b/crates/openshell-server/src/ssh_tunnel.rs @@ -96,8 +96,14 @@ async fn ssh_connect( return StatusCode::PRECONDITION_FAILED.into_response(); } - // Open a relay channel through the supervisor session. - let (channel_id, relay_rx) = match state.supervisor_sessions.open_relay(&sandbox_id).await { + // Open a relay channel through the supervisor session. The session may + // not be established yet right after the sandbox reports Ready, so wait + // briefly for it to appear. 
+ let (channel_id, relay_rx) = match state + .supervisor_sessions + .open_relay_with_wait(&sandbox_id, Duration::from_secs(15)) + .await + { Ok(pair) => pair, Err(status) => { warn!(sandbox_id = %sandbox_id, error = %status.message(), "SSH tunnel: supervisor session not available"); diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index ed33f8e15..5e3ccf852 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -94,6 +94,34 @@ impl SupervisorSessionRegistry { self.sessions.lock().unwrap().remove(sandbox_id); } + /// Open a relay channel, waiting for the supervisor session to appear. + /// + /// The supervisor session may not be established yet when the sandbox first + /// reports Ready (race between K8s readiness and gRPC session handshake). + /// This method retries the session lookup with short backoff before failing. + pub async fn open_relay_with_wait( + &self, + sandbox_id: &str, + timeout: Duration, + ) -> Result<(String, oneshot::Receiver), Status> { + let deadline = Instant::now() + timeout; + let mut backoff = Duration::from_millis(100); + + loop { + match self.open_relay(sandbox_id).await { + Ok(result) => return Ok(result), + Err(status) if status.code() == tonic::Code::Unavailable => { + if Instant::now() + backoff > deadline { + return Err(status); + } + tokio::time::sleep(backoff).await; + backoff = (backoff * 2).min(Duration::from_secs(2)); + } + Err(status) => return Err(status), + } + } + } + /// Open a relay channel: sends RelayOpen to the supervisor and returns a /// stream that will be connected once the supervisor's reverse HTTP CONNECT /// arrives. 
From c698f5304530fc703d42095beb11156235c49974 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Thu, 16 Apr 2026 11:51:01 -0700 Subject: [PATCH 3/3] refactor(server): harden supervisor session relay lifecycle MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Three related changes: 1. Fold the session-wait into `open_relay` itself via a new `wait_for_session` helper with exponential backoff (100ms → 2s). Callers pass an explicit `session_wait_timeout`: - SSH connect uses 30s — it typically runs right after `sandbox create`, so the timeout has to cover a cold supervisor's TLS + gRPC handshake. - ExecSandbox uses 15s — during normal operation it only needs to cover a transient supervisor reconnect window. This covers both the startup race (pod Ready before the supervisor's ConnectSupervisor stream is up) and mid-lifetime reconnects after a network blip or gateway/supervisor restart — both look identical to the caller. 2. Fix a supersede cleanup race. `LiveSession` now tracks a `session_id`, and `remove_if_current(sandbox_id, session_id)` only evicts when the registered entry still matches. Previously an old session's cleanup could run after a reconnect had already registered the new session, unconditionally removing the live registration. 3. Wire up `spawn_relay_reaper` alongside the existing SSH session reaper so expired pending relay entries (supervisor acknowledged RelayOpen but never opened the reverse CONNECT) are swept every 30s instead of leaking until someone tries to claim them. Adds 12 unit tests covering: open_relay happy path, timeout, mid-wait session appearance, closed-receiver failure, supersede routing; claim_relay unknown/expired/receiver-dropped/round-trip; and the remove_if_current cleanup-race regression. 
--- crates/openshell-server/src/grpc/sandbox.rs | 9 +- crates/openshell-server/src/lib.rs | 1 + crates/openshell-server/src/ssh_tunnel.rs | 10 +- .../src/supervisor_session.rs | 425 ++++++++++++++++-- 4 files changed, 391 insertions(+), 54 deletions(-) diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 58269b46a..bdda63d6a 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -436,12 +436,13 @@ pub(super) async fn handle_exec_sandbox( return Err(Status::failed_precondition("sandbox is not ready")); } - // Open a relay channel through the supervisor session. The session may - // not be established yet right after the sandbox reports Ready, so wait - // briefly for it to appear. + // Open a relay channel through the supervisor session. Use a 15s + // session-wait timeout — enough to cover a transient supervisor + // reconnect, but shorter than `/connect/ssh` since `ExecSandbox` is + // typically called during normal operation (not right after create). 
let (channel_id, relay_rx) = state .supervisor_sessions - .open_relay_with_wait(&sandbox.id, std::time::Duration::from_secs(15)) + .open_relay(&sandbox.id, std::time::Duration::from_secs(15)) .await .map_err(|e| Status::unavailable(format!("supervisor relay failed: {e}")))?; diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 346aaa172..cbef28b0e 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -154,6 +154,7 @@ pub async fn run_server(config: Config, tracing_log_bus: TracingLogBus) -> Resul state.compute.spawn_watchers(); ssh_tunnel::spawn_session_reaper(store.clone(), std::time::Duration::from_secs(3600)); + supervisor_session::spawn_relay_reaper(state.clone(), std::time::Duration::from_secs(30)); // Create the multiplexed service let service = MultiplexService::new(state.clone()); diff --git a/crates/openshell-server/src/ssh_tunnel.rs b/crates/openshell-server/src/ssh_tunnel.rs index 8317aa7bb..8b7d6b48d 100644 --- a/crates/openshell-server/src/ssh_tunnel.rs +++ b/crates/openshell-server/src/ssh_tunnel.rs @@ -96,12 +96,14 @@ async fn ssh_connect( return StatusCode::PRECONDITION_FAILED.into_response(); } - // Open a relay channel through the supervisor session. The session may - // not be established yet right after the sandbox reports Ready, so wait - // briefly for it to appear. + // Open a relay channel through the supervisor session. Use a generous + // 30s session-wait timeout because `/connect/ssh` is typically called + // immediately after `sandbox create`, so we need to cover the supervisor's + // initial TLS + gRPC handshake on a cold-started pod. The old + // direct-connect path tolerated ~34s here for similar reasons. 
let (channel_id, relay_rx) = match state .supervisor_sessions - .open_relay_with_wait(&sandbox_id, Duration::from_secs(15)) + .open_relay(&sandbox_id, Duration::from_secs(30)) .await { Ok(pair) => pair, diff --git a/crates/openshell-server/src/supervisor_session.rs b/crates/openshell-server/src/supervisor_session.rs index 5e3ccf852..d79540d42 100644 --- a/crates/openshell-server/src/supervisor_session.rs +++ b/crates/openshell-server/src/supervisor_session.rs @@ -21,6 +21,10 @@ use crate::ServerState; const HEARTBEAT_INTERVAL_SECS: u32 = 15; const RELAY_PENDING_TIMEOUT: Duration = Duration::from_secs(10); +/// Initial backoff between session-availability polls in `wait_for_session`. +const SESSION_WAIT_INITIAL_BACKOFF: Duration = Duration::from_millis(100); +/// Maximum backoff between session-availability polls in `wait_for_session`. +const SESSION_WAIT_MAX_BACKOFF: Duration = Duration::from_secs(2); // --------------------------------------------------------------------------- // Session registry @@ -30,6 +34,9 @@ const RELAY_PENDING_TIMEOUT: Duration = Duration::from_secs(10); struct LiveSession { #[allow(dead_code)] sandbox_id: String, + /// Uniquely identifies this session instance. Used by cleanup to avoid + /// removing a session that has since been superseded by a reconnect. + session_id: String, tx: mpsc::Sender, #[allow(dead_code)] connected_at: Instant, @@ -74,6 +81,7 @@ impl SupervisorSessionRegistry { fn register( &self, sandbox_id: String, + session_id: String, tx: mpsc::Sender, ) -> Option> { let mut sessions = self.sessions.lock().unwrap(); @@ -82,6 +90,7 @@ impl SupervisorSessionRegistry { sandbox_id.clone(), LiveSession { sandbox_id, + session_id, tx, connected_at: Instant::now(), }, @@ -94,53 +103,87 @@ impl SupervisorSessionRegistry { self.sessions.lock().unwrap().remove(sandbox_id); } - /// Open a relay channel, waiting for the supervisor session to appear. 
+ /// Remove the session only if its `session_id` matches the one we are + /// cleaning up. Returns `true` if the entry was removed. /// - /// The supervisor session may not be established yet when the sandbox first - /// reports Ready (race between K8s readiness and gRPC session handshake). - /// This method retries the session lookup with short backoff before failing. - pub async fn open_relay_with_wait( + /// This guards against the supersede race: an old session's task may + /// finish long after a new session has taken its place. The old task's + /// cleanup must not evict the new registration. + fn remove_if_current(&self, sandbox_id: &str, session_id: &str) -> bool { + let mut sessions = self.sessions.lock().unwrap(); + let is_current = sessions + .get(sandbox_id) + .is_some_and(|s| s.session_id == session_id); + if is_current { + sessions.remove(sandbox_id); + } + is_current + } + + /// Look up the sender for a supervisor session, waiting up to `timeout` + /// for it to appear if absent. + /// + /// Uses exponential backoff (100ms → 2s) while polling the sessions map. 
+ async fn wait_for_session( &self, sandbox_id: &str, timeout: Duration, - ) -> Result<(String, oneshot::Receiver), Status> { + ) -> Result, Status> { let deadline = Instant::now() + timeout; - let mut backoff = Duration::from_millis(100); + let mut backoff = SESSION_WAIT_INITIAL_BACKOFF; loop { - match self.open_relay(sandbox_id).await { - Ok(result) => return Ok(result), - Err(status) if status.code() == tonic::Code::Unavailable => { - if Instant::now() + backoff > deadline { - return Err(status); - } - tokio::time::sleep(backoff).await; - backoff = (backoff * 2).min(Duration::from_secs(2)); - } - Err(status) => return Err(status), + if let Some(tx) = self.lookup_session(sandbox_id) { + return Ok(tx); + } + if Instant::now() + backoff > deadline { + return Err(Status::unavailable("supervisor session not connected")); } + tokio::time::sleep(backoff).await; + backoff = (backoff * 2).min(SESSION_WAIT_MAX_BACKOFF); } } - /// Open a relay channel: sends RelayOpen to the supervisor and returns a - /// stream that will be connected once the supervisor's reverse HTTP CONNECT - /// arrives. + fn lookup_session(&self, sandbox_id: &str) -> Option> { + self.sessions + .lock() + .unwrap() + .get(sandbox_id) + .map(|s| s.tx.clone()) + } + + /// Open a relay channel and return a receiver for the supervisor-side + /// stream. + /// + /// Sends `RelayOpen` over the supervisor's gRPC session and returns a + /// oneshot receiver that resolves once the supervisor opens its reverse + /// HTTP CONNECT to `/relay/{channel_id}`. + /// + /// If the session is not currently registered, this method waits up to + /// `session_wait_timeout` for it to appear. A session may be temporarily + /// absent for several reasons — all of which look identical from here: /// - /// Returns `(channel_id, receiver_for_relay_stream)`. 
+ /// - startup race: the sandbox just reported Ready but the supervisor's + /// `ConnectSupervisor` gRPC handshake hasn't completed yet + /// - transient disconnect: the session was up but got dropped (network + /// blip, gateway restart, supervisor restart) and the supervisor is + /// in its reconnect backoff loop + /// + /// Callers pick the timeout based on how much patience the caller needs. + /// A first `sandbox connect` right after `sandbox create` may need to + /// wait for the supervisor's initial TLS + gRPC handshake (tens of + /// seconds on a slow cluster), while mid-lifetime calls typically just + /// need to cover a short reconnect window. pub async fn open_relay( &self, sandbox_id: &str, + session_wait_timeout: Duration, ) -> Result<(String, oneshot::Receiver), Status> { - let channel_id = Uuid::new_v4().to_string(); + let tx = self + .wait_for_session(sandbox_id, session_wait_timeout) + .await?; - // Look up the session and send RelayOpen. - let tx = { - let sessions = self.sessions.lock().unwrap(); - let session = sessions - .get(sandbox_id) - .ok_or_else(|| Status::unavailable("supervisor session not connected"))?; - session.tx.clone() - }; + let channel_id = Uuid::new_v4().to_string(); // Register the pending relay before sending RelayOpen to avoid a race. let (relay_tx, relay_rx) = oneshot::channel(); @@ -209,6 +252,23 @@ impl SupervisorSessionRegistry { } } +/// Spawn a background task that periodically reaps expired pending relay +/// entries. +/// +/// Pending entries are normally consumed either when the supervisor opens its +/// reverse CONNECT (via `claim_relay`) or by the gateway-side waiter timing +/// out. If neither happens — e.g., the supervisor crashed after acknowledging +/// `RelayOpen` but before dialing back — the entry would otherwise sit in the +/// map indefinitely. This sweeper bounds that leak. 
+pub fn spawn_relay_reaper(state: Arc, interval: Duration) { + tokio::spawn(async move { + loop { + tokio::time::sleep(interval).await; + state.supervisor_sessions.reap_expired_relays(); + } + }); +} + // --------------------------------------------------------------------------- // ConnectSupervisor gRPC handler // --------------------------------------------------------------------------- @@ -248,11 +308,16 @@ pub async fn handle_connect_supervisor( // Step 2: Create the outbound channel and register the session. let (tx, rx) = mpsc::channel::(64); - if let Some(_previous_tx) = state - .supervisor_sessions - .register(sandbox_id.clone(), tx.clone()) + if let Some(_previous_tx) = + state + .supervisor_sessions + .register(sandbox_id.clone(), session_id.clone(), tx.clone()) { - info!(sandbox_id = %sandbox_id, "supervisor session: superseded previous session"); + info!( + sandbox_id = %sandbox_id, + session_id = %session_id, + "supervisor session: superseded previous session" + ); } // Step 3: Send SessionAccepted. @@ -263,7 +328,11 @@ pub async fn handle_connect_supervisor( })), }; if tx.send(accepted).await.is_err() { - state.supervisor_sessions.remove(&sandbox_id); + // Only evict ourselves — a faster reconnect may already have + // superseded this registration. 
+ state + .supervisor_sessions + .remove_if_current(&sandbox_id, &session_id); return Err(Status::internal("failed to send session accepted")); } @@ -279,8 +348,14 @@ pub async fn handle_connect_supervisor( &mut inbound, ) .await; - state_clone.supervisor_sessions.remove(&sandbox_id_clone); - info!(sandbox_id = %sandbox_id_clone, session_id = %session_id, "supervisor session: ended"); + let still_ours = state_clone + .supervisor_sessions + .remove_if_current(&sandbox_id_clone, &session_id); + if still_ours { + info!(sandbox_id = %sandbox_id_clone, session_id = %session_id, "supervisor session: ended"); + } else { + info!(sandbox_id = %sandbox_id_clone, session_id = %session_id, "supervisor session: ended (already superseded)"); + } }); // Return the outbound stream. @@ -385,15 +460,21 @@ fn handle_supervisor_message(sandbox_id: &str, session_id: &str, msg: Supervisor #[cfg(test)] mod tests { use super::*; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + // ---- registry: register / remove ---- #[test] fn registry_register_and_lookup() { let registry = SupervisorSessionRegistry::new(); let (tx, _rx) = mpsc::channel(1); - assert!(registry.register("sandbox-1".to_string(), tx).is_none()); + assert!( + registry + .register("sandbox-1".to_string(), "s1".to_string(), tx) + .is_none() + ); - // Should find the session. 
let sessions = registry.sessions.lock().unwrap(); assert!(sessions.contains_key("sandbox-1")); } @@ -404,26 +485,182 @@ mod tests { let (tx1, _rx1) = mpsc::channel(1); let (tx2, _rx2) = mpsc::channel(1); - assert!(registry.register("sandbox-1".to_string(), tx1).is_none()); - assert!(registry.register("sandbox-1".to_string(), tx2).is_some()); + assert!( + registry + .register("sandbox-1".to_string(), "s1".to_string(), tx1) + .is_none() + ); + assert!( + registry + .register("sandbox-1".to_string(), "s2".to_string(), tx2) + .is_some() + ); } #[test] fn registry_remove() { let registry = SupervisorSessionRegistry::new(); let (tx, _rx) = mpsc::channel(1); - registry.register("sandbox-1".to_string(), tx); + registry.register("sandbox-1".to_string(), "s1".to_string(), tx); registry.remove("sandbox-1"); let sessions = registry.sessions.lock().unwrap(); assert!(!sessions.contains_key("sandbox-1")); } + #[test] + fn remove_if_current_removes_matching_session() { + let registry = SupervisorSessionRegistry::new(); + let (tx, _rx) = mpsc::channel(1); + registry.register("sbx".to_string(), "s1".to_string(), tx); + + assert!(registry.remove_if_current("sbx", "s1")); + assert!(!registry.sessions.lock().unwrap().contains_key("sbx")); + } + + #[test] + fn remove_if_current_ignores_stale_session_id() { + let registry = SupervisorSessionRegistry::new(); + let (tx_old, _rx_old) = mpsc::channel(1); + let (tx_new, _rx_new) = mpsc::channel(1); + + // Old session registers, then is superseded by a new session. + registry.register("sbx".to_string(), "s-old".to_string(), tx_old); + registry.register("sbx".to_string(), "s-new".to_string(), tx_new); + + // Cleanup from the old session task runs late. It must NOT evict the + // newly registered session. 
+ assert!(!registry.remove_if_current("sbx", "s-old")); + let sessions = registry.sessions.lock().unwrap(); + assert!( + sessions.contains_key("sbx"), + "new session must still be registered" + ); + assert_eq!(sessions.get("sbx").unwrap().session_id, "s-new"); + } + + #[test] + fn remove_if_current_unknown_sandbox_is_noop() { + let registry = SupervisorSessionRegistry::new(); + assert!(!registry.remove_if_current("sbx-does-not-exist", "s1")); + } + + // ---- open_relay: happy path and wait semantics ---- + + #[tokio::test] + async fn open_relay_sends_relay_open_to_registered_session() { + let registry = SupervisorSessionRegistry::new(); + let (tx, mut rx) = mpsc::channel(4); + registry.register("sbx".to_string(), "s1".to_string(), tx); + + let (channel_id, _relay_rx) = registry + .open_relay("sbx", Duration::from_secs(1)) + .await + .expect("open_relay should succeed when session is live"); + + let msg = rx.recv().await.expect("relay open should be delivered"); + match msg.payload { + Some(gateway_message::Payload::RelayOpen(open)) => { + assert_eq!(open.channel_id, channel_id); + } + other => panic!("expected RelayOpen, got {other:?}"), + } + } + + #[tokio::test] + async fn open_relay_times_out_without_session() { + let registry = SupervisorSessionRegistry::new(); + let err = registry + .open_relay("missing", Duration::from_millis(50)) + .await + .expect_err("open_relay should time out"); + assert_eq!(err.code(), tonic::Code::Unavailable); + } + + #[tokio::test] + async fn open_relay_waits_for_session_to_appear() { + let registry = Arc::new(SupervisorSessionRegistry::new()); + let registry_for_register = Arc::clone(®istry); + + // Register the session after a small delay, shorter than the wait. + tokio::spawn(async move { + tokio::time::sleep(Duration::from_millis(200)).await; + let (tx, mut rx) = mpsc::channel::(4); + // Keep the receiver alive so the send in open_relay succeeds. 
+ tokio::spawn(async move { while rx.recv().await.is_some() {} }); + registry_for_register.register("sbx".to_string(), "s1".to_string(), tx); + }); + + let result = registry.open_relay("sbx", Duration::from_secs(2)).await; + assert!( + result.is_ok(), + "open_relay should succeed when session arrives mid-wait: {result:?}" + ); + } + + #[tokio::test] + async fn open_relay_fails_when_session_receiver_dropped() { + let registry = SupervisorSessionRegistry::new(); + let (tx, rx) = mpsc::channel::(4); + registry.register("sbx".to_string(), "s1".to_string(), tx); + + // Simulate the supervisor's stream going away between lookup and send: + // the receiver held by `ReceiverStream` is dropped. + drop(rx); + + let err = registry + .open_relay("sbx", Duration::from_secs(1)) + .await + .expect_err("open_relay should fail when mpsc is closed"); + assert_eq!(err.code(), tonic::Code::Unavailable); + // The pending-relay entry must have been cleaned up on failure. + assert!(registry.pending_relays.lock().unwrap().is_empty()); + } + + #[tokio::test] + async fn open_relay_uses_newest_session_after_supersede() { + let registry = SupervisorSessionRegistry::new(); + let (tx_old, mut rx_old) = mpsc::channel::(4); + let (tx_new, mut rx_new) = mpsc::channel(4); + + // Hold a clone of the old sender so supersede doesn't close the old + // channel — that way try_recv distinguishes "no message sent" from + // "channel closed". 
+ let _tx_old_alive = tx_old.clone(); + + registry.register("sbx".to_string(), "s-old".to_string(), tx_old); + registry.register("sbx".to_string(), "s-new".to_string(), tx_new); + + let (_channel_id, _relay_rx) = registry + .open_relay("sbx", Duration::from_secs(1)) + .await + .expect("open_relay should succeed"); + + let msg = rx_new + .recv() + .await + .expect("new session should receive RelayOpen"); + assert!(matches!( + msg.payload, + Some(gateway_message::Payload::RelayOpen(_)) + )); + + // The old session must have received no messages — the channel is + // still open but empty. + use tokio::sync::mpsc::error::TryRecvError; + match rx_old.try_recv() { + Err(TryRecvError::Empty) => {} + other => panic!("expected Empty on superseded session, got {other:?}"), + } + } + + // ---- claim_relay: expiry, drop, wiring ---- + #[test] fn claim_relay_unknown_channel() { let registry = SupervisorSessionRegistry::new(); - let result = registry.claim_relay("nonexistent"); - assert!(result.is_err()); + let err = registry.claim_relay("nonexistent").expect_err("should err"); + assert_eq!(err.code(), tonic::Code::NotFound); } #[test] @@ -440,12 +677,86 @@ mod tests { let result = registry.claim_relay("ch-1"); assert!(result.is_ok()); - // Should be consumed. assert!(!registry.pending_relays.lock().unwrap().contains_key("ch-1")); } #[test] - fn reap_expired_relays() { + fn claim_relay_expired_returns_deadline_exceeded() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-old".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now() - Duration::from_secs(60), + }, + ); + + let err = registry + .claim_relay("ch-old") + .expect_err("expired entry must fail"); + assert_eq!(err.code(), tonic::Code::DeadlineExceeded); + // Entry must have been consumed regardless. 
+ assert!( + !registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-old") + ); + } + + #[test] + fn claim_relay_receiver_dropped_returns_internal() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, relay_rx) = oneshot::channel::(); + drop(relay_rx); // Gateway-side waiter has given up already. + registry.pending_relays.lock().unwrap().insert( + "ch-1".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + + let err = registry + .claim_relay("ch-1") + .expect_err("should err when receiver is gone"); + assert_eq!(err.code(), tonic::Code::Internal); + } + + #[tokio::test] + async fn claim_relay_connects_both_ends() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, relay_rx) = oneshot::channel::(); + registry.pending_relays.lock().unwrap().insert( + "ch-io".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + + let mut supervisor_side = registry.claim_relay("ch-io").expect("claim should succeed"); + let mut gateway_side = relay_rx.await.expect("gateway side should receive stream"); + + // Supervisor side writes → gateway side reads. + supervisor_side.write_all(b"hello").await.unwrap(); + let mut buf = [0u8; 5]; + gateway_side.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"hello"); + + // Gateway side writes → supervisor side reads. 
+ gateway_side.write_all(b"world").await.unwrap(); + let mut buf = [0u8; 5]; + supervisor_side.read_exact(&mut buf).await.unwrap(); + assert_eq!(&buf, b"world"); + } + + // ---- reap_expired_relays ---- + + #[test] + fn reap_expired_relays_removes_old_entries() { let registry = SupervisorSessionRegistry::new(); let (relay_tx, _relay_rx) = oneshot::channel(); registry.pending_relays.lock().unwrap().insert( @@ -465,4 +776,26 @@ mod tests { .contains_key("ch-old") ); } + + #[test] + fn reap_expired_relays_keeps_fresh_entries() { + let registry = SupervisorSessionRegistry::new(); + let (relay_tx, _relay_rx) = oneshot::channel(); + registry.pending_relays.lock().unwrap().insert( + "ch-fresh".to_string(), + PendingRelay { + sender: relay_tx, + created_at: Instant::now(), + }, + ); + + registry.reap_expired_relays(); + assert!( + registry + .pending_relays + .lock() + .unwrap() + .contains_key("ch-fresh") + ); + } }