Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 54 additions & 3 deletions lib/guest/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ func recordGuestExecWait(span trace.Span, start time.Time, attempts, retryableAt
}

func startGuestExecSpan(ctx context.Context, opts ExecOptions) (context.Context, trace.Span) {
return otel.Tracer("hypeman/guest").Start(ctx, "guest.exec", trace.WithAttributes(
return guestTracer().Start(ctx, "guest.exec", trace.WithAttributes(
attribute.String("command_name", execCommandName(opts.Command)),
attribute.Bool("tty", opts.TTY),
attribute.Bool("wait_for_agent", opts.WaitForAgent > 0),
Expand All @@ -251,6 +251,30 @@ func startGuestExecSpan(ctx context.Context, opts ExecOptions) (context.Context,
))
}

func guestTracer() trace.Tracer {
return otel.Tracer("hypeman/guest")
}

func startGuestExecStep(ctx context.Context, opts ExecOptions, name string, attrs ...attribute.KeyValue) (trace.Span, func(error)) {
if opts.WaitForAgent == 0 {
return nil, func(error) {}
}
_, span := guestTracer().Start(ctx, name, trace.WithAttributes(attrs...))
return span, func(err error) {
finishGuestExecStepSpan(span, err)
}
}

func finishGuestExecStepSpan(span trace.Span, err error) {
if err != nil {
span.RecordError(err)
span.SetStatus(otelcodes.Error, err.Error())
} else {
span.SetStatus(otelcodes.Ok, "")
}
span.End()
}

func finishGuestExecSpan(span trace.Span, exit *ExitStatus, err error) {
if exit != nil {
span.SetAttributes(attribute.Int("exit_code", exit.Code))
Expand Down Expand Up @@ -307,22 +331,27 @@ func execIntoInstanceOnce(ctx context.Context, dialer hypervisor.VsockDialer, op
var bytesSent int64

// Get or create a reusable gRPC connection for this vsock dialer
_, finishGetConn := startGuestExecStep(ctx, opts, "guest.exec.get_conn")
grpcConn, err := GetOrCreateConn(ctx, dialer)
finishGetConn(err)
if err != nil {
return nil, fmt.Errorf("get grpc connection: %w", err)
}
// Note: Don't close the connection - it's pooled and reused

// Create guest client
client := NewGuestServiceClient(grpcConn)
_, finishOpenStream := startGuestExecStep(ctx, opts, "guest.exec.open_stream")
stream, err := client.Exec(ctx)
finishOpenStream(err)
if err != nil {
return nil, fmt.Errorf("start exec stream: %w", err)
}
// Ensure stream is properly closed when we're done
defer stream.CloseSend()

// Send start request with initial window size
_, finishSendStart := startGuestExecStep(ctx, opts, "guest.exec.send_start")
if err := stream.Send(&ExecRequest{
Request: &ExecRequest_Start{
Start: &ExecStart{
Expand All @@ -336,8 +365,10 @@ func execIntoInstanceOnce(ctx context.Context, dialer hypervisor.VsockDialer, op
},
},
}); err != nil {
finishSendStart(err)
return nil, fmt.Errorf("send start request: %w", err)
}
finishSendStart(nil)

// Mutex to protect concurrent stream.Send/CloseSend calls (gRPC streams are not thread-safe)
var streamMu sync.Mutex
Expand Down Expand Up @@ -383,13 +414,32 @@ func execIntoInstanceOnce(ctx context.Context, dialer hypervisor.VsockDialer, op

// Receive responses
var totalStdout, totalStderr int
recvSpan, finishRecv := startGuestExecStep(ctx, opts, "guest.exec.recv_until_exit")
finishReceive := func(err error, exitCode *int) {
if recvSpan != nil {
attrs := []attribute.KeyValue{
attribute.Int("stdout_bytes", totalStdout),
attribute.Int("stderr_bytes", totalStderr),
attribute.Int64("bytes_sent", atomic.LoadInt64(&bytesSent)),
}
if exitCode != nil {
attrs = append(attrs, attribute.Int("exit_code", *exitCode))
}
recvSpan.SetAttributes(attrs...)
}
finishRecv(err)
}
for {
resp, err := stream.Recv()
if err == io.EOF {
return nil, fmt.Errorf("stream closed without exit code (stdout=%d, stderr=%d)", totalStdout, totalStderr)
err := fmt.Errorf("stream closed without exit code (stdout=%d, stderr=%d)", totalStdout, totalStderr)
finishReceive(err, nil)
return nil, err
}
if err != nil {
return nil, fmt.Errorf("receive response (stdout=%d, stderr=%d): %w", totalStdout, totalStderr, err)
err := fmt.Errorf("receive response (stdout=%d, stderr=%d): %w", totalStdout, totalStderr, err)
finishReceive(err, nil)
return nil, err
}

switch r := resp.Response.(type) {
Expand All @@ -410,6 +460,7 @@ func execIntoInstanceOnce(ctx context.Context, dialer hypervisor.VsockDialer, op
bytesReceived := int64(totalStdout + totalStderr)
GuestMetrics.RecordExecSession(ctx, start, exitCode, atomic.LoadInt64(&bytesSent), bytesReceived)
}
finishReceive(nil, &exitCode)
return &ExitStatus{Code: exitCode}, nil
}
}
Expand Down
117 changes: 117 additions & 0 deletions lib/guest/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ import (
"errors"
"io"
"net"
"strings"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/kernel/hypeman/lib/hypervisor"
"go.opentelemetry.io/otel"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/sdk/trace/tracetest"
"google.golang.org/grpc"
)

Expand Down Expand Up @@ -88,6 +92,7 @@ func TestExecIntoInstanceRetriesWithFreshConnections(t *testing.T) {
key: "retry-fresh-connection-test",
readyAt: time.Now().Add(100 * time.Millisecond),
}
t.Cleanup(func() { CloseConn(dialer.Key()) })

start := time.Now()
exit, err := ExecIntoInstance(context.Background(), dialer, ExecOptions{
Expand All @@ -109,6 +114,88 @@ func TestExecIntoInstanceRetriesWithFreshConnections(t *testing.T) {
}
}

func TestExecIntoInstanceTracesDetailedWaitForAgentPath(t *testing.T) {
recorder, cleanup := newGuestTraceRecorder(t)
defer cleanup()

dialer := &delayedDialer{
key: "trace-wait-for-agent-test",
readyAt: time.Now(),
}
t.Cleanup(func() { CloseConn(dialer.Key()) })

exit, err := ExecIntoInstance(context.Background(), dialer, ExecOptions{
Command: []string{"true"},
WaitForAgent: time.Second,
})
if err != nil {
t.Fatalf("ExecIntoInstance failed: %v", err)
}
if exit.Code != 0 {
t.Fatalf("exit code = %d, want 0", exit.Code)
}

for _, name := range []string{
"guest.exec",
"guest.exec.get_conn",
"guest.exec.open_stream",
"guest.exec.send_start",
"guest.exec.recv_until_exit",
} {
if findEndedSpan(recorder, name) == nil {
t.Fatalf("missing span %q", name)
}
}

parent := findEndedSpan(recorder, "guest.exec")
parentAttrs := spanAttributes(parent)
if parentAttrs["attempts"] != "1" {
t.Fatalf("attempts = %s, want 1", parentAttrs["attempts"])
}
if parentAttrs["retryable_attempts"] != "0" {
t.Fatalf("retryable_attempts = %s, want 0", parentAttrs["retryable_attempts"])
}

recv := findEndedSpan(recorder, "guest.exec.recv_until_exit")
recvAttrs := spanAttributes(recv)
if recvAttrs["exit_code"] != "0" {
t.Fatalf("recv exit_code = %s, want 0", recvAttrs["exit_code"])
}
if recvAttrs["stdout_bytes"] != "0" {
t.Fatalf("stdout_bytes = %s, want 0", recvAttrs["stdout_bytes"])
}
if recvAttrs["stderr_bytes"] != "0" {
t.Fatalf("stderr_bytes = %s, want 0", recvAttrs["stderr_bytes"])
}
}

func TestExecIntoInstanceSkipsDetailedTraceWhenNotWaiting(t *testing.T) {
recorder, cleanup := newGuestTraceRecorder(t)
defer cleanup()

dialer := &delayedDialer{
key: "trace-no-wait-test",
readyAt: time.Now(),
}
t.Cleanup(func() { CloseConn(dialer.Key()) })

exit, err := ExecIntoInstance(context.Background(), dialer, ExecOptions{
Command: []string{"true"},
})
if err != nil {
t.Fatalf("ExecIntoInstance failed: %v", err)
}
if exit.Code != 0 {
t.Fatalf("exit code = %d, want 0", exit.Code)
}

for _, span := range recorder.Ended() {
if strings.HasPrefix(span.Name(), "guest.exec") {
t.Fatalf("unexpected detailed guest exec span %q", span.Name())
}
}
}

func TestCloseConnClosesPooledConnection(t *testing.T) {
dialer := &trackingDialer{
key: "close-conn-test",
Expand Down Expand Up @@ -143,6 +230,36 @@ func waitForTrackedConn(t *testing.T, conns <-chan *closeTrackingConn) *closeTra
}
}

func newGuestTraceRecorder(t *testing.T) (*tracetest.SpanRecorder, func()) {
t.Helper()

recorder := tracetest.NewSpanRecorder()
provider := sdktrace.NewTracerProvider(sdktrace.WithSpanProcessor(recorder))
previous := otel.GetTracerProvider()
otel.SetTracerProvider(provider)
return recorder, func() {
otel.SetTracerProvider(previous)
_ = provider.Shutdown(context.Background())
}
}

func findEndedSpan(recorder *tracetest.SpanRecorder, name string) sdktrace.ReadOnlySpan {
for _, span := range recorder.Ended() {
if span.Name() == name {
return span
}
}
return nil
}

func spanAttributes(span sdktrace.ReadOnlySpan) map[string]string {
attrs := make(map[string]string, len(span.Attributes()))
for _, attr := range span.Attributes() {
attrs[string(attr.Key)] = attr.Value.Emit()
}
return attrs
}

type delayedDialer struct {
key string
readyAt time.Time
Expand Down
Loading