diff --git a/cmd/ateapi/internal/credbundle/credbundle.go b/cmd/ateapi/internal/credbundle/credbundle.go index 58d86df76..281840adc 100644 --- a/cmd/ateapi/internal/credbundle/credbundle.go +++ b/cmd/ateapi/internal/credbundle/credbundle.go @@ -20,6 +20,7 @@ package credbundle import ( + "crypto" "crypto/tls" "crypto/x509" "encoding/pem" @@ -47,6 +48,7 @@ func Parse(bundlePath string) (*tls.Certificate, error) { } var leafKeyBytes []byte + var leafKeyBlockType string var chainBytes [][]byte for { @@ -59,8 +61,9 @@ func Parse(bundlePath string) (*tls.Certificate, error) { switch block.Type { case "CERTIFICATE": chainBytes = append(chainBytes, block.Bytes) - case "PRIVATE KEY": + case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY": leafKeyBytes = block.Bytes + leafKeyBlockType = block.Type default: return nil, fmt.Errorf("unknown PEM block type %q", block.Type) } @@ -74,7 +77,7 @@ func Parse(bundlePath string) (*tls.Certificate, error) { return nil, fmt.Errorf("no CERTIFICATE blocks found") } - leafKey, err := x509.ParsePKCS8PrivateKey(leafKeyBytes) + leafKey, err := parsePrivateKey(leafKeyBlockType, leafKeyBytes) if err != nil { return nil, fmt.Errorf("while parsing private key: %w", err) } @@ -90,3 +93,16 @@ func Parse(bundlePath string) (*tls.Certificate, error) { PrivateKey: leafKey, }, nil } + +func parsePrivateKey(blockType string, keyBytes []byte) (crypto.PrivateKey, error) { + switch blockType { + case "PRIVATE KEY": + return x509.ParsePKCS8PrivateKey(keyBytes) + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(keyBytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(keyBytes) + default: + return nil, fmt.Errorf("unsupported private key block type %q", blockType) + } +} diff --git a/cmd/ateapi/internal/credbundle/credbundle_test.go b/cmd/ateapi/internal/credbundle/credbundle_test.go new file mode 100644 index 000000000..14d6ff4b1 --- /dev/null +++ b/cmd/ateapi/internal/credbundle/credbundle_test.go @@ -0,0 +1,127 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package credbundle + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "testing" + "time" +) + +func TestParsePrivateKeyBlockTypes(t *testing.T) { + for _, tt := range []struct { + name string + blockType string + keyDER func(t *testing.T) []byte + }{ + { + name: "pkcs8", + blockType: "PRIVATE KEY", + keyDER: func(t *testing.T) []byte { + key := generateRSAKey(t) + der, err := x509.MarshalPKCS8PrivateKey(key) + if err != nil { + t.Fatalf("marshal PKCS8 key: %v", err) + } + return der + }, + }, + { + name: "rsa", + blockType: "RSA PRIVATE KEY", + keyDER: func(t *testing.T) []byte { + return x509.MarshalPKCS1PrivateKey(generateRSAKey(t)) + }, + }, + { + name: "ec", + blockType: "EC PRIVATE KEY", + keyDER: func(t *testing.T) []byte { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("generate EC key: %v", err) + } + der, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("marshal EC key: %v", err) + } + return der + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + certDER := generateCertificate(t) + bundle := append(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), pem.EncodeToMemory(&pem.Block{Type: tt.blockType, Bytes: tt.keyDER(t)})...) + + bundlePath := writeBundle(t, bundle) + cert, err := Parse(bundlePath) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if len(cert.Certificate) != 1 { + t.Fatalf("Parse() certificate chain length = %d, want 1", len(cert.Certificate)) + } + if cert.PrivateKey == nil { + t.Fatalf("Parse() private key is nil") + } + if cert.Leaf == nil { + t.Fatalf("Parse() leaf certificate is nil") + } + }) + } +} + +func generateRSAKey(t *testing.T) *rsa.PrivateKey { + t.Helper() + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("generate RSA key: %v", err) + } + return key +} + +func generateCertificate(t *testing.T) []byte { + t.Helper() + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "api.ate-system.svc"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + DNSNames: []string{"api.ate-system.svc"}, + } + key := generateRSAKey(t) + der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("create certificate: %v", err) + } + return der +} + +func writeBundle(t *testing.T, bundle []byte) string { + t.Helper() + path := t.TempDir() + "/bundle.pem" + if err := os.WriteFile(path, bundle, 0o600); err != nil { + t.Fatalf("write bundle: %v", err) + } + return path +} diff --git a/cmd/ateapi/internal/k8sjwt/k8sjwt.go b/cmd/ateapi/internal/k8sjwt/k8sjwt.go index e4403d0df..5b343ec4a 100644 --- a/cmd/ateapi/internal/k8sjwt/k8sjwt.go +++ b/cmd/ateapi/internal/k8sjwt/k8sjwt.go @@ -119,7 +119,9 @@ var permittedSkew = 5 * time.Minute // the object binding claims. If needed for your use case, you will need check the object bindings // by connecting to the cluster and seeing if the object(s) the bindings name still exist within the // cluster. -func Verify(ctx context.Context, jwt string, expectedIssuer, expectedAudience string, now time.Time) (*KubernetesClaims, error) { +// +// httpClient is used for OIDC discovery and JWKS fetches; nil uses http.DefaultClient. +func Verify(ctx context.Context, httpClient *http.Client, jwt string, expectedIssuer, expectedAudience string, now time.Time) (*KubernetesClaims, error) { segments := strings.Split(jwt, ".") if len(segments) != 3 { return nil, fmt.Errorf("malformed JWT") @@ -169,7 +171,7 @@ func Verify(ctx context.Context, jwt string, expectedIssuer, expectedAudience st } // TODO: Cache keys, and only fetch new keys if the JWT's key ID is not in the cache. - keys, err := discoverKeysForIssuer(ctx, rawClaims.Issuer) + keys, err := discoverKeysForIssuer(ctx, httpClient, rawClaims.Issuer) if err != nil { return nil, fmt.Errorf("while discovering keys from issuer: %w", err) } @@ -358,7 +360,7 @@ type jwkT struct { RSAE string `json:"e"` } -func discoverKeysForIssuer(ctx context.Context, issuer string) ([]*KeyAndID, error) { +func discoverKeysForIssuer(ctx context.Context, httpClient *http.Client, issuer string) ([]*KeyAndID, error) { var discoveryDocURL string if strings.HasSuffix(issuer, "/") { discoveryDocURL = issuer + ".well-known/openid-configuration" @@ -366,14 +368,14 @@ func discoverKeysForIssuer(ctx context.Context, issuer string) ([]*KeyAndID, err discoveryDocURL = issuer + "/.well-known/openid-configuration" } - oidcConfig, err := fetchJSON[oidcConfigT](discoveryDocURL) + oidcConfig, err := fetchJSON[oidcConfigT](httpClient, discoveryDocURL) if err != nil { return nil, fmt.Errorf("while fetching OIDC Discovery document: %w", err) } slog.InfoContext(ctx, "Fetched discovery doc", slog.Any("doc", oidcConfig)) - jwkSet, err := fetchJSON[jwkSetT](oidcConfig.JWKSURI) + jwkSet, err := fetchJSON[jwkSetT](httpClient, oidcConfig.JWKSURI) if err != nil { return nil, fmt.Errorf("while fetching JWKS: %w", err) } @@ -424,10 +426,12 @@ func discoverKeysForIssuer(ctx context.Context, issuer string) ([]*KeyAndID, err return ret, nil } -func fetchJSON[T any](url string) (T, error) { +func fetchJSON[T any](httpClient *http.Client, url string) (T, error) { var parsedBody T - - resp, err := http.Get(url) + if httpClient == nil { + httpClient = http.DefaultClient + } + resp, err := httpClient.Get(url) if err != nil { return parsedBody, fmt.Errorf("while making HTTP request: %w", err) } diff --git a/cmd/ateapi/internal/sessionidentity/sessionidentity.go b/cmd/ateapi/internal/sessionidentity/sessionidentity.go index b2bf43f29..a9a581271 100644 --- a/cmd/ateapi/internal/sessionidentity/sessionidentity.go +++ b/cmd/ateapi/internal/sessionidentity/sessionidentity.go @@ -21,6 +21,7 @@ import ( "crypto/x509/pkix" "fmt" "log/slog" + "net/http" "net/url" "os" "path" @@ -51,17 +52,19 @@ type Server struct { sessionIDCAPoolFile string workerCACerts string + httpClient *http.Client } var _ ateapipb.SessionIdentityServer = (*Server)(nil) -func New(clientJWTIssuer, clientJWTAudience, sessionIDJWTPoolFile, sessionIDCAPoolFile, workerCACerts string) *Server { +func New(clientJWTIssuer, clientJWTAudience, sessionIDJWTPoolFile, sessionIDCAPoolFile, workerCACerts string, httpClient *http.Client) *Server { return &Server{ clientJWTIssuer: clientJWTIssuer, clientJWTAudience: clientJWTAudience, sessionIDJWTPoolFile: sessionIDJWTPoolFile, sessionIDCAPoolFile: sessionIDCAPoolFile, workerCACerts: workerCACerts, + httpClient: httpClient, } } @@ -78,7 +81,7 @@ func (s *Server) MintJWT(ctx context.Context, req *ateapipb.MintJWTRequest) (*at clientJWT := strings.TrimPrefix(authorization[0], "Bearer ") - clientClaims, err := k8sjwt.Verify(ctx, clientJWT, s.clientJWTIssuer, s.clientJWTAudience, time.Now()) + clientClaims, err := k8sjwt.Verify(ctx, s.httpClient, clientJWT, s.clientJWTIssuer, s.clientJWTAudience, time.Now()) if err != nil { slog.ErrorContext(ctx, "Error while verifying client JWT", slog.Any("err", err)) return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated") diff --git a/cmd/ateapi/main.go b/cmd/ateapi/main.go index 34101bb12..223d0bb2b 100644 --- a/cmd/ateapi/main.go +++ b/cmd/ateapi/main.go @@ -21,13 +21,17 @@ import ( "fmt" "log/slog" "net" + "net/http" "os" + "strings" "time" "github.com/agent-substrate/substrate/cmd/ateapi/internal/controlapi" "github.com/agent-substrate/substrate/cmd/ateapi/internal/credbundle" + "github.com/agent-substrate/substrate/cmd/ateapi/internal/k8sjwt" "github.com/agent-substrate/substrate/cmd/ateapi/internal/sessionidentity" "github.com/agent-substrate/substrate/cmd/ateapi/internal/store/ateredis" + "github.com/agent-substrate/substrate/internal/ateapiauth" "github.com/agent-substrate/substrate/internal/ateinterceptors" "github.com/agent-substrate/substrate/internal/serverboot" "github.com/agent-substrate/substrate/internal/version" @@ -64,7 +68,9 @@ var ( sessionIDCAPoolFile = pflag.String("session-id-ca-pool", "", "The file that contains the CA pool for signing session JWTs") workerpoolCACerts = pflag.String("workerpool-ca-certs", "", "The file that contains the CA for verifying workerpool client certificates.") - showVersion = pflag.Bool("version", false, "Print version and exit.") + showVersion = pflag.Bool("version", false, "Print version and exit.") + authMode = pflag.String("auth-mode", "mtls", "Auth mode for incoming gRPC: mtls|jwt. 'mtls' (default) relies on transport-level mTLS for client identity. 'jwt' additionally requires a Kubernetes ServiceAccount Bearer token on every RPC.") + clientJWTCAFile = pflag.String("client-jwt-ca-cert", ateapiauth.DefaultServiceAccountCAFile, "CA cert file used to verify TLS when fetching the OIDC discovery document and JWKS for JWT authentication. Defaults to the in-cluster service account CA.") ) func main() { @@ -94,6 +100,11 @@ func main() { loadFlagsFromEnv() logFlagValues(ctx) + authModeParsed, err := ateapiauth.ParseMode(*authMode) + if err != nil { + serverboot.Fatal(ctx, "Invalid --auth-mode", err) + } + redisClient, err := connectRedis(ctx) if err != nil { serverboot.Fatal(ctx, "Failed to set up Redis/Valkey", err) @@ -133,7 +144,9 @@ func main() { dialer := controlapi.NewAteletDialer(workerPodInformer.GetIndexer(), ateletPodInformer.GetIndexer()) sm := controlapi.NewService(redisPersistence, actorTemplateLister, dialer, clientset) - sessionIdentitySrv := sessionidentity.New(*clientJWTIssuer, *clientJWTAudience, *sessionIDJWTPoolFile, *sessionIDCAPoolFile, *workerpoolCACerts) + jwtHTTPClient := buildJWTHTTPClient(ctx, *clientJWTCAFile) + + sessionIdentitySrv := sessionidentity.New(*clientJWTIssuer, *clientJWTAudience, *sessionIDJWTPoolFile, *sessionIDCAPoolFile, *workerpoolCACerts, jwtHTTPClient) lisCfg := &net.ListenConfig{} lis, err := lisCfg.Listen(ctx, "tcp", *listenAddr) @@ -141,10 +154,27 @@ func main() { serverboot.Fatal(ctx, "Failed to start listener", err) } + authCfg := ateapiauth.ServerConfig{ + Mode: authModeParsed, + VerifyBearerToken: func(ctx context.Context, bearer string) error { + _, err := k8sjwt.Verify(ctx, jwtHTTPClient, bearer, *clientJWTIssuer, *clientJWTAudience, time.Now()) + return err + }, + } + if err := ateapiauth.ValidateServerConfig(authCfg); err != nil { + serverboot.Fatal(ctx, "Invalid auth config", err) + } + mux := grpc.NewServer( grpc.Creds(serverCreds), grpc.StatsHandler(otelgrpc.NewServerHandler()), - grpc.UnaryInterceptor(ateinterceptors.ServerUnaryInterceptor), + grpc.ChainUnaryInterceptor( + ateapiauth.UnaryServerInterceptor(authCfg), + ateinterceptors.ServerUnaryInterceptor, + ), + grpc.ChainStreamInterceptor( + ateapiauth.StreamServerInterceptor(authCfg), + ), ) reflection.Register(mux) ateapipb.RegisterControlServer(mux, sm) @@ -196,6 +226,7 @@ func logFlagValues(ctx context.Context) { slog.String("session-id-jwt-pool", *sessionIDJWTPoolFile), slog.String("session-id-ca-pool", *sessionIDCAPoolFile), slog.String("workerpool-ca-certs", *workerpoolCACerts), + slog.String("auth-mode", *authMode), ) } @@ -325,3 +356,46 @@ func buildServerCreds(ctx context.Context) (credentials.TransportCredentials, er ClientCAs: clientCAs, }), nil } + +// buildJWTHTTPClient returns an *http.Client that trusts caFile for TLS +// verification and injects the pod's ServiceAccount Bearer token, used when +// fetching the OIDC discovery document and JWKS from the in-cluster Kubernetes +// API server. Returns nil (use http.DefaultClient) if caFile is empty or unreadable. +func buildJWTHTTPClient(ctx context.Context, caFile string) *http.Client { + if caFile == "" { + return nil + } + ca, err := os.ReadFile(caFile) + if err != nil { + slog.WarnContext(ctx, "Could not read JWT CA cert file; OIDC discovery will use system trust", slog.String("path", caFile), slog.Any("err", err)) + return nil + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(ca) { + slog.WarnContext(ctx, "Could not parse JWT CA cert file; OIDC discovery will use system trust", slog.String("path", caFile)) + return nil + } + return &http.Client{ + Transport: &saTokenTransport{ + base: &http.Transport{ + TLSClientConfig: &tls.Config{RootCAs: pool}, + }, + }, + } +} + +// saTokenTransport injects the pod's ServiceAccount Bearer token on every +// request. Reads the token file fresh on each request so token rotation is +// handled automatically. +type saTokenTransport struct { + base http.RoundTripper +} + +func (t *saTokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { + token, err := os.ReadFile(ateapiauth.DefaultServiceAccountTokenFile) + if err == nil && len(token) > 0 { + req = req.Clone(req.Context()) + req.Header.Set("Authorization", "Bearer "+strings.TrimSpace(string(token))) + } + return t.base.RoundTrip(req) +} diff --git a/cmd/atecontroller/main.go b/cmd/atecontroller/main.go index ab9244cab..e96bdb242 100644 --- a/cmd/atecontroller/main.go +++ b/cmd/atecontroller/main.go @@ -14,15 +14,14 @@ package main import ( - "crypto/tls" "os" "github.com/agent-substrate/substrate/cmd/atecontroller/internal/controllers" + "github.com/agent-substrate/substrate/internal/ateapiauth" clientv1alpha1 "github.com/agent-substrate/substrate/pkg/api/v1alpha1" "github.com/agent-substrate/substrate/pkg/proto/ateapipb" "github.com/spf13/pflag" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" clientgoscheme "k8s.io/client-go/kubernetes/scheme" @@ -39,6 +38,11 @@ var ( setupLog = ctrl.Log.WithName("setup") ateAPIConnSpec = pflag.String("ateapi-conn-spec", "dns:///api.ate-system.svc:443", "") + + ateapiAuthMode = pflag.String("ateapi-auth", "mtls", "Client auth to ateapi: mtls|jwt. 'mtls' (default) dials with insecure TLS and relies on pod-projected mTLS credentials for identity. 'jwt' verifies the server cert and sends a Bearer SA token.") + ateapiCAFile = pflag.String("ateapi-ca-file", ateapiauth.DefaultServiceAccountCAFile, "PEM file with CAs trusted to verify the ateapi server cert. Required for jwt.") + ateapiServerName = pflag.String("ateapi-server-name", "", "SNI / hostname expected on the ateapi server cert. Optional.") + ateapiTokenFile = pflag.String("ateapi-token-file", ateapiauth.DefaultServiceAccountTokenFile, "Projected SA token file used as Bearer credential. Required for jwt.") ) func init() { @@ -47,15 +51,27 @@ func init() { } func main() { + pflag.Parse() ctrl.SetLogger(zap.New(zap.UseDevMode(true))) - // TODO: Verify server certificate, pass client certificate. - clientTLSConfig := &tls.Config{ - InsecureSkipVerify: true, // Temporarily bypass standard checks + mode, err := ateapiauth.ParseMode(*ateapiAuthMode) + if err != nil { + setupLog.Error(err, "invalid --ateapi-auth") + os.Exit(1) + } + + dialOpts, err := ateapiauth.DialOptions(ateapiauth.ClientConfig{ + Mode: mode, + CAFile: *ateapiCAFile, + ServerName: *ateapiServerName, + TokenFile: *ateapiTokenFile, + }) + if err != nil { + setupLog.Error(err, "building ateapi dial options") + os.Exit(1) } - clientCreds := credentials.NewTLS(clientTLSConfig) - ateapiConn, err := grpc.NewClient(*ateAPIConnSpec, grpc.WithTransportCredentials(clientCreds)) + ateapiConn, err := grpc.NewClient(*ateAPIConnSpec, dialOpts...) if err != nil { setupLog.Error(err, "Error creating grpc connection to ate api") os.Exit(1) diff --git a/cmd/atenet/internal/router.go b/cmd/atenet/internal/router.go index 5cff97233..aaf5c1877 100644 --- a/cmd/atenet/internal/router.go +++ b/cmd/atenet/internal/router.go @@ -21,6 +21,7 @@ import ( "github.com/spf13/cobra" "github.com/agent-substrate/substrate/cmd/atenet/internal/router" + "github.com/agent-substrate/substrate/internal/ateapiauth" ) func NewRouterCmd() *cobra.Command { @@ -56,6 +57,10 @@ func NewRouterCmd() *cobra.Command { cmd.Flags().DurationVar(&cfg.HealthInterval, "health-interval", 1*time.Second, "Interval for checking health of dependent services") cmd.Flags().IntVar(&cfg.HttpsPort, "port-https", 8443, "TCP port for HTTPS workload traffic entering through the Envoy Router") cmd.Flags().StringVar(&cfg.EnvoyCertPath, "envoy-cert-path", "", "Path to the Envoy certificate file (if empty, a self-signed cert will be generated for testing)") + cmd.Flags().StringVar(&cfg.AteapiAuthMode, "ateapi-auth", "mtls", "Client auth to ateapi: mtls|jwt. 'mtls' (default) dials with insecure TLS and relies on pod-projected mTLS credentials for identity. 'jwt' verifies the server cert and sends a Bearer SA token.") + cmd.Flags().StringVar(&cfg.AteapiCAFile, "ateapi-ca-file", ateapiauth.DefaultServiceAccountCAFile, "PEM file with CAs trusted to verify the ateapi server cert. Required for jwt.") + cmd.Flags().StringVar(&cfg.AteapiServerName, "ateapi-server-name", "", "SNI / hostname expected on the ateapi server cert. Optional.") + cmd.Flags().StringVar(&cfg.AteapiTokenFile, "ateapi-token-file", ateapiauth.DefaultServiceAccountTokenFile, "Projected SA token file used as Bearer credential. Required for jwt.") return cmd } diff --git a/cmd/atenet/internal/router/router.go b/cmd/atenet/internal/router/router.go index 797ce36c0..129e46dfe 100644 --- a/cmd/atenet/internal/router/router.go +++ b/cmd/atenet/internal/router/router.go @@ -18,7 +18,6 @@ import ( "context" "crypto/rand" "crypto/rsa" - "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/pem" @@ -37,7 +36,6 @@ import ( "github.com/spf13/cobra" "golang.org/x/sync/errgroup" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" "k8s.io/apimachinery/pkg/runtime" utilruntime "k8s.io/apimachinery/pkg/util/runtime" "k8s.io/client-go/kubernetes" @@ -46,6 +44,7 @@ import ( "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/config" + "github.com/agent-substrate/substrate/internal/ateapiauth" "github.com/agent-substrate/substrate/internal/serverboot" v1alpha1 "github.com/agent-substrate/substrate/pkg/api/v1alpha1" "github.com/agent-substrate/substrate/pkg/proto/ateapipb" @@ -78,6 +77,11 @@ type RouterConfig struct { EnvoyCertPath string LogLevel string MetricsAddr string + + AteapiAuthMode string + AteapiCAFile string + AteapiServerName string + AteapiTokenFile string } // RouterServer instantiates and coordinates runtime threads executing system modules. @@ -125,11 +129,24 @@ func NewRouterServer(cfg RouterConfig) (*RouterServer, error) { } } - conn, err := grpc.NewClient(cfg.AteapiAddr, grpc.WithTransportCredentials(credentials.NewTLS(&tls.Config{InsecureSkipVerify: true}))) + authMode, err := ateapiauth.ParseMode(cfg.AteapiAuthMode) + if err != nil { + return nil, fmt.Errorf("invalid --ateapi-auth: %w", err) + } + dialOpts, err := ateapiauth.DialOptions(ateapiauth.ClientConfig{ + Mode: authMode, + CAFile: cfg.AteapiCAFile, + ServerName: cfg.AteapiServerName, + TokenFile: cfg.AteapiTokenFile, + }) + if err != nil { + return nil, fmt.Errorf("building ateapi dial options: %w", err) + } + conn, err := grpc.NewClient(cfg.AteapiAddr, dialOpts...) if err != nil { return nil, fmt.Errorf("failed to establish grpc channel to ateapi client: %w", err) } - slog.Info("Connecting to ateapi", slog.String("address", cfg.AteapiAddr)) + slog.Info("Connecting to ateapi", slog.String("address", cfg.AteapiAddr), slog.String("auth", string(authMode))) apiClient := ateapipb.NewControlClient(conn) diff --git a/hack/install-ate.sh b/hack/install-ate.sh index 6c7f46056..4cb05757d 100755 --- a/hack/install-ate.sh +++ b/hack/install-ate.sh @@ -63,6 +63,7 @@ function usage() { echo " --deploy-ate-system Deploy core system (CRDs, atelet, apiserver)" echo " --delete-ate-system Delete core system" echo " --delete-all Delete core system and all registered demos" + echo " --auth-mode=mtls|jwt Select ateapi auth mode for --deploy-ate-system (default: mtls)" echo "" echo "Infrastructure components:" echo "" @@ -116,6 +117,40 @@ run_ko() { esac } +ate_auth_mode() { + case "${ATE_API_AUTH_MODE:-mtls}" in + mtls|jwt) + echo "${ATE_API_AUTH_MODE:-mtls}" + ;; + *) + echo "Error: ATE_API_AUTH_MODE must be mtls or jwt, got '${ATE_API_AUTH_MODE}'" >&2 + exit 1 + ;; + esac +} + +render_ate_system_manifests() { + local auth_mode="" + auth_mode="$(ate_auth_mode)" + + if [[ "${auth_mode}" == "jwt" ]]; then + local overlay="manifests/ate-install/jwt" + if [[ "${ATE_INSTALL_KIND:-false}" == "true" ]]; then + overlay="manifests/ate-install/kind-jwt" + fi + kubectl kustomize "${overlay}" --load-restrictor LoadRestrictionsNone | run_ko resolve -f - + return + fi + + if [[ "${ATE_INSTALL_KIND:-false}" == "true" ]]; then + # Build everything resolved with Kustomize for Kind + kubectl kustomize manifests/ate-install/kind --load-restrictor LoadRestrictionsNone | run_ko resolve -f - + else + # Build everything resolved with base manifests for GKE + run_ko resolve -f manifests/ate-install + fi +} + create_valkey_ca_certs_secret() { log_step "create_valkey_ca_certs_secret" local ca_certs="" @@ -238,13 +273,7 @@ deploy_ate_system() { done local manifests="" - if [[ "${ATE_INSTALL_KIND:-false}" == "true" ]]; then - # Build everything resolved with Kustomize for Kind - manifests=$(kubectl kustomize manifests/ate-install/kind --load-restrictor LoadRestrictionsNone | run_ko resolve -f -) - else - # Build everything resolved with base manifests for GKE - manifests=$(run_ko resolve -f manifests/ate-install) - fi + manifests="$(render_ate_system_manifests)" echo "${manifests}" | run_kubectl apply -f - log_step "Waiting for ATE system components to be ready..." @@ -360,6 +389,20 @@ for arg in "$@"; do esac done +args=("$@") +for ((i = 0; i < ${#args[@]}; i++)); do + case "${args[$i]}" in + --auth-mode=*) ATE_API_AUTH_MODE="${args[$i]#*=}" ;; + --auth-mode) + if (( i + 1 >= ${#args[@]} )); then + echo "Error: --auth-mode requires mtls or jwt" >&2 + exit 1 + fi + ATE_API_AUTH_MODE="${args[$((i + 1))]}" + ;; + esac +done + while [[ "$#" -gt 0 ]]; do # Run ${demo}_cmdline if it exists. If it returns 0, then we successfully # handled this argument and can continue. Otherwise, fallthrough to check @@ -374,6 +417,16 @@ while [[ "$#" -gt 0 ]]; do done case $1 in + --auth-mode=*) ATE_API_AUTH_MODE="${1#*=}" ;; + --auth-mode) + shift + if [[ "$#" -eq 0 ]]; then + echo "Error: --auth-mode requires mtls or jwt" >&2 + exit 1 + fi + ATE_API_AUTH_MODE="$1" + ;; + --deploy-ate-system) deploy_ate_system ;; --delete-ate-system) delete_ate_system ;; --delete-all) delete_all ;; diff --git a/internal/ateapiauth/client.go b/internal/ateapiauth/client.go new file mode 100644 index 000000000..345d23233 --- /dev/null +++ b/internal/ateapiauth/client.go @@ -0,0 +1,115 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ateapiauth + +import ( + "context" + "crypto/tls" + "crypto/x509" + "fmt" + "os" + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/credentials" +) + +const ( + DefaultServiceAccountCAFile = "/var/run/secrets/kubernetes.io/serviceaccount/ca.crt" + DefaultServiceAccountTokenFile = "/var/run/secrets/kubernetes.io/serviceaccount/token" +) + +// ClientConfig configures how to dial the ateapi gRPC server. +// +// - Mode=ModeMTLS: insecure TLS dial (InsecureSkipVerify=true). Client +// identity is expected to come from mTLS credentials projected into +// the pod (servicedns.podcert.ate.dev). No app-level credentials. +// - Mode=ModeJWT: validates the server cert against CAFile, sends a Bearer +// token from TokenFile as per-RPC credentials. +type ClientConfig struct { + Mode Mode + + // CAFile is a PEM file containing CA certs that sign the server cert. + // Required for ModeJWT. Ignored for ModeMTLS. + CAFile string + + // ServerName overrides SNI / hostname verification. Optional. + ServerName string + + // TokenFile is a path to a Kubernetes projected ServiceAccount token used + // as a Bearer credential. Required for ModeJWT. + TokenFile string +} + +// DialOptions returns the grpc.DialOption set described by cfg, suitable to +// pass to grpc.NewClient. +func DialOptions(cfg ClientConfig) ([]grpc.DialOption, error) { + switch cfg.Mode { + case "", ModeMTLS: + tlsCfg := &tls.Config{InsecureSkipVerify: true} //nolint:gosec // explicit opt-in + return []grpc.DialOption{ + grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)), + }, nil + + case ModeJWT: + if cfg.CAFile == "" { + return nil, fmt.Errorf("ateapiauth: jwt mode requires CAFile") + } + if cfg.TokenFile == "" { + return nil, fmt.Errorf("ateapiauth: jwt mode requires TokenFile") + } + caPEM, err := os.ReadFile(cfg.CAFile) + if err != nil { + return nil, fmt.Errorf("ateapiauth: reading CA file: %w", err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caPEM) { + return nil, fmt.Errorf("ateapiauth: no certificates found in CA file %q", cfg.CAFile) + } + tlsCfg := &tls.Config{ + MinVersion: tls.VersionTLS12, + RootCAs: pool, + ServerName: cfg.ServerName, + } + return []grpc.DialOption{ + grpc.WithTransportCredentials(credentials.NewTLS(tlsCfg)), + grpc.WithPerRPCCredentials(&fileTokenCreds{path: cfg.TokenFile}), + }, nil + + default: + return nil, fmt.Errorf("ateapiauth: unknown client mode %q", cfg.Mode) + } +} + +// fileTokenCreds reads a Kubernetes projected SA token from disk for every +// RPC. Kubernetes refreshes the file in place; reading it each time picks up +// rotations. +type fileTokenCreds struct { + path string +} + +func (c *fileTokenCreds) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + b, err := os.ReadFile(c.path) + if err != nil { + return nil, fmt.Errorf("ateapiauth: reading token file %q: %w", c.path, err) + } + tok := strings.TrimSpace(string(b)) + if tok == "" { + return nil, fmt.Errorf("ateapiauth: token file %q is empty", c.path) + } + return map[string]string{"authorization": "Bearer " + tok}, nil +} + +func (c *fileTokenCreds) RequireTransportSecurity() bool { return true } diff --git a/internal/ateapiauth/server.go b/internal/ateapiauth/server.go new file mode 100644 index 000000000..17604b2db --- /dev/null +++ b/internal/ateapiauth/server.go @@ -0,0 +1,182 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package ateapiauth adds optional Kubernetes ServiceAccount JWT +// authentication on top of the ateapi gRPC server, and a matching client +// dial helper. It does not replace the existing TLS / mTLS path — the +// server's transport credentials still apply unchanged. Set Mode=ModeJWT +// on the server to require an `authorization: Bearer ` header +// on every RPC; Mode=ModeMTLS (the default) leaves identity to the +// transport-layer mTLS credentials. +package ateapiauth + +import ( + "context" + "fmt" + "net/http" + "strings" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +// Mode selects whether the JWT interceptor enforces a Bearer token. +type Mode string + +const ( + ModeMTLS Mode = "mtls" + ModeJWT Mode = "jwt" +) + +// ParseMode parses a flag value into a Mode, defaulting to ModeMTLS on empty. +// ModeMTLS means identity is established by the transport-layer mTLS +// credentials; the interceptor performs no app-level checks. ModeJWT +// additionally requires a Kubernetes SA Bearer token on every RPC. +func ParseMode(s string) (Mode, error) { + switch Mode(s) { + case "", ModeMTLS: + return ModeMTLS, nil + case ModeJWT: + return ModeJWT, nil + default: + return "", fmt.Errorf("unknown auth mode %q (want mtls|jwt)", s) + } +} + +func ValidateServerConfig(cfg ServerConfig) error { + switch cfg.Mode { + case "", ModeMTLS: + return nil + case ModeJWT: + if cfg.VerifyBearerToken == nil { + return fmt.Errorf("jwt mode requires bearer token verifier") + } + return nil + default: + return fmt.Errorf("unknown auth mode %q", cfg.Mode) + } +} + +// ServerConfig configures the server-side auth interceptor. +type ServerConfig struct { + Mode Mode + + // VerifyBearerToken verifies a Bearer token presented by a client. Required + // for ModeJWT and ignored for ModeMTLS. + VerifyBearerToken func(context.Context, string) error + + // HTTPClient is kept for source compatibility with older callers. Token + // verification belongs in VerifyBearerToken. + HTTPClient *http.Client +} + +// UnaryServerInterceptor returns a gRPC unary interceptor enforcing cfg. +func UnaryServerInterceptor(cfg ServerConfig) grpc.UnaryServerInterceptor { + auth := serverAuthenticatorFor(cfg) + return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + newCtx, err := auth.authenticate(ctx) + if err != nil { + return nil, err + } + return handler(newCtx, req) + } +} + +// StreamServerInterceptor returns a gRPC stream interceptor enforcing cfg. +func StreamServerInterceptor(cfg ServerConfig) grpc.StreamServerInterceptor { + auth := serverAuthenticatorFor(cfg) + return func(srv any, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + newCtx, err := auth.authenticate(ss.Context()) + if err != nil { + return err + } + return handler(srv, &wrappedStream{ServerStream: ss, ctx: newCtx}) + } +} + +type wrappedStream struct { + grpc.ServerStream + ctx context.Context +} + +func (w *wrappedStream) Context() context.Context { return w.ctx } + +type serverAuthenticator interface { + authenticate(context.Context) (context.Context, error) +} + +func serverAuthenticatorFor(cfg ServerConfig) serverAuthenticator { + switch cfg.Mode { + case "", ModeMTLS: + return mtlsServerAuthenticator{} + case ModeJWT: + return jwtServerAuthenticator{ + verifyBearerToken: cfg.VerifyBearerToken, + } + } + + return invalidServerAuthenticator{mode: cfg.Mode} +} + +type mtlsServerAuthenticator struct{} + +func (mtlsServerAuthenticator) authenticate(ctx context.Context) (context.Context, error) { + return ctx, nil +} + +type jwtServerAuthenticator struct { + verifyBearerToken func(context.Context, string) error +} + +func (a jwtServerAuthenticator) authenticate(ctx context.Context) (context.Context, error) { + bearer, ok := bearerToken(ctx) + if !ok { + return nil, status.Error(codes.Unauthenticated, "missing bearer token") + } + if err := a.verifyBearerToken(ctx, bearer); err != nil { + return nil, status.Errorf(codes.Unauthenticated, "invalid bearer token: %v", err) + } + return ctx, nil +} + +type invalidServerAuthenticator struct { + mode Mode +} + +func (a invalidServerAuthenticator) authenticate(context.Context) (context.Context, error) { + return nil, status.Errorf(codes.Internal, "invalid auth mode %q", a.mode) +} + +func bearerToken(ctx context.Context) (string, bool) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", false + } + vals := md.Get("authorization") + if len(vals) == 0 { + return "", false + } + const prefix = "Bearer " + v := vals[0] + if !strings.HasPrefix(v, prefix) { + return "", false + } + tok := strings.TrimSpace(strings.TrimPrefix(v, prefix)) + if tok == "" { + return "", false + } + return tok, true +} diff --git a/internal/ateapiauth/server_test.go b/internal/ateapiauth/server_test.go new file mode 100644 index 000000000..5a0c0631e --- /dev/null +++ b/internal/ateapiauth/server_test.go @@ -0,0 +1,131 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ateapiauth + +import ( + "context" + "fmt" + "testing" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" +) + +func TestParseMode(t *testing.T) { + cases := []struct { + in string + want Mode + wantErr bool + }{ + {"", ModeMTLS, false}, + {"mtls", ModeMTLS, false}, + {"jwt", ModeJWT, false}, + {"none", "", true}, + {"bogus", "", true}, + } + for _, tc := range cases { + got, err := ParseMode(tc.in) + if (err != nil) != tc.wantErr { + t.Errorf("ParseMode(%q) err=%v wantErr=%v", tc.in, err, tc.wantErr) + } + if !tc.wantErr && got != tc.want { + t.Errorf("ParseMode(%q)=%v want %v", tc.in, got, tc.want) + } + } +} + +func TestValidateServerConfig(t *testing.T) { + tests := []struct { + name string + cfg ServerConfig + wantErr bool + }{ + {name: "mtls zero config", cfg: ServerConfig{Mode: ModeMTLS}}, + {name: "empty mode zero config", cfg: ServerConfig{}}, + {name: "jwt valid", cfg: ServerConfig{Mode: ModeJWT, VerifyBearerToken: func(context.Context, string) error { return nil }}}, + {name: "jwt missing verifier", cfg: ServerConfig{Mode: ModeJWT}, wantErr: true}, + {name: "unknown mode", cfg: ServerConfig{Mode: Mode("bogus")}, wantErr: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := ValidateServerConfig(tt.cfg) + if (err != nil) != tt.wantErr { + t.Fatalf("ValidateServerConfig(%+v) err=%v, wantErr=%v", tt.cfg, err, tt.wantErr) + } + }) + } +} + +func TestMTLSServerAuthenticatorAllowsAnonymous(t *testing.T) { + _, err := (mtlsServerAuthenticator{}).authenticate(context.Background()) + if err != nil { + t.Fatalf("ModeMTLS should not error: %v", err) + } +} + +func TestJWTServerAuthenticatorRequiresBearer(t *testing.T) { + auth := jwtServerAuthenticator{ + verifyBearerToken: func(context.Context, string) error { + return fmt.Errorf("bad token") + }, + } + + // Missing header -> Unauthenticated. + _, err := auth.authenticate(context.Background()) + if code := status.Code(err); code != codes.Unauthenticated { + t.Fatalf("missing bearer: want Unauthenticated, got %v (err=%v)", code, err) + } + + // Garbage bearer -> Unauthenticated (k8sjwt.Verify will fail). + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorization", "Bearer not-a-jwt")) + _, err = auth.authenticate(ctx) + if code := status.Code(err); code != codes.Unauthenticated { + t.Fatalf("bad bearer: want Unauthenticated, got %v (err=%v)", code, err) + } +} + +func TestBearerToken(t *testing.T) { + cases := []struct { + name string + hdr string + want string + found bool + }{ + {"missing", "", "", false}, + {"no prefix", "abc", "", false}, + {"prefix", "Bearer abc", "abc", true}, + {"prefix with spaces", "Bearer abc ", "abc", true}, + {"empty after prefix", "Bearer ", "", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + ctx := context.Background() + if tc.hdr != "" { + ctx = metadata.NewIncomingContext(ctx, metadata.Pairs("authorization", tc.hdr)) + } + got, ok := bearerToken(ctx) + if ok != tc.found || got != tc.want { + t.Errorf("bearerToken=(%q,%v) want (%q,%v)", got, ok, tc.want, tc.found) + } + }) + } +} + +// Build-time check. +var _ grpc.UnaryServerInterceptor = UnaryServerInterceptor(ServerConfig{}) +var _ grpc.StreamServerInterceptor = StreamServerInterceptor(ServerConfig{}) diff --git a/internal/ateclient/builder.go b/internal/ateclient/builder.go index 81ca867ec..4e6139364 100644 --- a/internal/ateclient/builder.go +++ b/internal/ateclient/builder.go @@ -21,6 +21,7 @@ import ( "io" "net/http" "os" + "strings" "sync" "github.com/agent-substrate/substrate/pkg/proto/ateapipb" @@ -33,6 +34,7 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/credentials" + authv1 "k8s.io/api/authentication/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/labels" "k8s.io/client-go/kubernetes" @@ -209,6 +211,12 @@ func dialPortForward(ctx context.Context, kubeconfigPath, k8sContext string, tra var opts []grpc.DialOption opts = append(opts, grpc.WithTransportCredentials(transportCreds)) opts = append(opts, grpc.WithStatsHandler(otelgrpc.NewClientHandler())) + jwtOpts, err := jwtDialOptions(ctx, clientset) + if err != nil { + close(stopCh) + return nil, err + } + opts = append(opts, jwtOpts...) if traceEnabled { opts = append(opts, grpc.WithUnaryInterceptor(newTraceInterceptor())) @@ -230,6 +238,74 @@ func dialPortForward(ctx context.Context, kubeconfigPath, k8sContext string, tra }, nil } +func jwtDialOptions(ctx context.Context, clientset *kubernetes.Clientset) ([]grpc.DialOption, error) { + jwtMode, err := isJWTMode(ctx, clientset) + if err != nil { + return nil, err + } + if !jwtMode { + return nil, nil + } + + expirationSeconds := int64(3600) + tokenRequest := &authv1.TokenRequest{ + Spec: authv1.TokenRequestSpec{ + Audiences: []string{"api.ate-system.svc"}, + ExpirationSeconds: &expirationSeconds, + }, + } + token, err := clientset.CoreV1().ServiceAccounts("ate-system").CreateToken(ctx, "ate-controller", tokenRequest, metav1.CreateOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to request ateapi bearer token: %w", err) + } + if token.Status.Token == "" { + return nil, fmt.Errorf("failed to request ateapi bearer token: token response was empty") + } + return []grpc.DialOption{grpc.WithPerRPCCredentials(bearerTokenCreds(token.Status.Token))}, nil +} + +func isJWTMode(ctx context.Context, clientset *kubernetes.Clientset) (bool, error) { + // TODO: Replace deployment introspection with an explicit client-readable + // config file once ateapi auth mode is part of install/runtime config. + deployment, err := clientset.AppsV1().Deployments("ate-system").Get(ctx, "ate-api-server-deployment", metav1.GetOptions{}) + if err != nil { + return false, fmt.Errorf("failed to get ate-api-server deployment: %w", err) + } + for _, container := range deployment.Spec.Template.Spec.Containers { + if container.Name != "ate-api-server" { + continue + } + return isJWTAuthModeArg(container.Args), nil + } + return false, fmt.Errorf("failed to find ate-api-server container in deployment") +} + +func isJWTAuthModeArg(args []string) bool { + for i, arg := range args { + if arg == "--auth-mode=jwt" { + return true + } + if strings.HasPrefix(arg, "--auth-mode=") { + return strings.TrimPrefix(arg, "--auth-mode=") == "jwt" + } + if arg == "--auth-mode" && i+1 < len(args) { + return args[i+1] == "jwt" + } + } + return false +} + +type bearerTokenCreds string + +func (c bearerTokenCreds) GetRequestMetadata(_ context.Context, _ ...string) (map[string]string, error) { + if c == "" { + return nil, fmt.Errorf("bearer token is empty") + } + return map[string]string{"authorization": "Bearer " + string(c)}, nil +} + +func (c bearerTokenCreds) RequireTransportSecurity() bool { return true } + func initTracing(ctx context.Context, enabled bool) (*sdktrace.TracerProvider, error) { res, err := resource.New(ctx, resource.WithAttributes( diff --git a/internal/ateclient/builder_test.go b/internal/ateclient/builder_test.go new file mode 100644 index 000000000..62c44feb9 --- /dev/null +++ b/internal/ateclient/builder_test.go @@ -0,0 +1,40 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package ateclient + +import "testing" + +func TestIsJWTAuthModeArg(t *testing.T) { + tests := []struct { + name string + args []string + want bool + }{ + {name: "equals jwt", args: []string{"--auth-mode=jwt"}, want: true}, + {name: "split jwt", args: []string{"--auth-mode", "jwt"}, want: true}, + {name: "equals mtls", args: []string{"--auth-mode=mtls"}, want: false}, + {name: "split mtls", args: []string{"--auth-mode", "mtls"}, want: false}, + {name: "missing value", args: []string{"--auth-mode"}, want: false}, + {name: "unrelated", args: []string{"--foo=bar"}, want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isJWTAuthModeArg(tt.args); got != tt.want { + t.Fatalf("isJWTAuthModeArg(%v) = %v, want %v", tt.args, got, tt.want) + } + }) + } +} diff --git a/internal/localca/localca.go b/internal/localca/localca.go index eb8370905..3078435ed 100644 --- a/internal/localca/localca.go +++ b/internal/localca/localca.go @@ -22,6 +22,7 @@ import ( "crypto/rand" "crypto/x509" "encoding/json" + "encoding/pem" "fmt" "time" ) @@ -43,7 +44,9 @@ type serializedPool struct { type serializedCA struct { ID string SigningKeyPKCS8 []byte + SigningKeyPEM string RootCertificateDER []byte + RootCertificatePEM string IntermediateCertificatesDER [][]byte } @@ -92,12 +95,12 @@ func Unmarshal(wireBytes []byte) (*Pool, error) { ID: wireCA.ID, } - ca.SigningKey, err = x509.ParsePKCS8PrivateKey(wireCA.SigningKeyPKCS8) + ca.SigningKey, err = parsePrivateKey(wireCA.SigningKeyPKCS8, wireCA.SigningKeyPEM) if err != nil { return nil, fmt.Errorf("while parsing signing key: %w", err) } - ca.RootCertificate, err = x509.ParseCertificate(wireCA.RootCertificateDER) + ca.RootCertificate, err = parseCertificate(wireCA.RootCertificateDER, wireCA.RootCertificatePEM) if err != nil { return nil, fmt.Errorf("while parsing root certificate: %w", err) } @@ -116,6 +119,43 @@ func Unmarshal(wireBytes []byte) (*Pool, error) { return pool, nil } +func parsePrivateKey(pkcs8 []byte, pemData string) (crypto.PrivateKey, error) { + if len(pkcs8) != 0 { + return x509.ParsePKCS8PrivateKey(pkcs8) + } + + block, _ := pem.Decode([]byte(pemData)) + if block == nil { + return nil, fmt.Errorf("missing PEM block") + } + + if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return key, nil + } + return nil, fmt.Errorf("unsupported private key PEM type %q", block.Type) +} + +func parseCertificate(der []byte, pemData string) (*x509.Certificate, error) { + if len(der) != 0 { + return x509.ParseCertificate(der) + } + + block, _ := pem.Decode([]byte(pemData)) + if block == nil { + return nil, fmt.Errorf("missing PEM block") + } + if block.Type != "CERTIFICATE" { + return nil, fmt.Errorf("unsupported certificate PEM type %q", block.Type) + } + return x509.ParseCertificate(block.Bytes) +} + func GenerateED25519CA(id string) (*CA, error) { rootPubKey, rootPrivKey, err := ed25519.GenerateKey(rand.Reader) if err != nil { diff --git a/internal/localca/localca_test.go b/internal/localca/localca_test.go index 9b5d6247f..c470a4d48 100644 --- a/internal/localca/localca_test.go +++ b/internal/localca/localca_test.go @@ -18,8 +18,12 @@ import ( "bytes" "crypto/ed25519" "crypto/rand" + "crypto/rsa" "crypto/x509" + "crypto/x509/pkix" "encoding/json" + "encoding/pem" + "math/big" "strings" "testing" "time" @@ -198,6 +202,53 @@ func TestMarshalUnmarshalWithIntermediates(t *testing.T) { } } +func TestUnmarshalPEMPool(t *testing.T) { + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("GenerateKey(): %v", err) + } + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "session-id-ca"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(24 * time.Hour), + IsCA: true, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + } + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatalf("CreateCertificate(): %v", err) + } + keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})) + certPEM := string(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + + data, err := json.Marshal(&serializedPool{ + CAs: []*serializedCA{{ + ID: "1", + SigningKeyPEM: keyPEM, + RootCertificatePEM: certPEM, + }}, + }) + if err != nil { + t.Fatalf("Marshal(): %v", err) + } + + pool, err := Unmarshal(data) + if err != nil { + t.Fatalf("Unmarshal(): %v", err) + } + if len(pool.CAs) != 1 { + t.Fatalf("CAs length = %d, want 1", len(pool.CAs)) + } + if _, ok := pool.CAs[0].SigningKey.(*rsa.PrivateKey); !ok { + t.Fatalf("SigningKey type = %T, want *rsa.PrivateKey", pool.CAs[0].SigningKey) + } + if pool.CAs[0].RootCertificate.Subject.CommonName != "session-id-ca" { + t.Fatalf("RootCertificate CN = %q, want session-id-ca", pool.CAs[0].RootCertificate.Subject.CommonName) + } +} + func TestUnmarshalErrors(t *testing.T) { ca, err := GenerateED25519CA("err-test") if err != nil { diff --git a/internal/localjwtauthority/localjwtauthority.go b/internal/localjwtauthority/localjwtauthority.go index 62f647b55..97021fb2d 100644 --- a/internal/localjwtauthority/localjwtauthority.go +++ b/internal/localjwtauthority/localjwtauthority.go @@ -22,6 +22,7 @@ import ( "crypto/rand" "crypto/x509" "encoding/json" + "encoding/pem" "fmt" ) @@ -43,6 +44,7 @@ type serializedAuthority struct { ID string Algorithm string SigningKeyPKCS8 []byte + SigningKeyPEM string } // Marshal serializes a Pool to JSON. @@ -86,7 +88,7 @@ func Unmarshal(wireBytes []byte) (*Pool, error) { Algorithm: wireAuthority.Algorithm, } - signingKey, err := x509.ParsePKCS8PrivateKey(wireAuthority.SigningKeyPKCS8) + signingKey, err := parsePrivateKey(wireAuthority.SigningKeyPKCS8, wireAuthority.SigningKeyPEM) if err != nil { return nil, fmt.Errorf("while parsing signing key: %w", err) } @@ -98,6 +100,28 @@ func Unmarshal(wireBytes []byte) (*Pool, error) { return pool, nil } +func parsePrivateKey(pkcs8 []byte, pemData string) (crypto.PrivateKey, error) { + if len(pkcs8) != 0 { + return x509.ParsePKCS8PrivateKey(pkcs8) + } + + block, _ := pem.Decode([]byte(pemData)) + if block == nil { + return nil, fmt.Errorf("missing PEM block") + } + + if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParseECPrivateKey(block.Bytes); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil { + return key, nil + } + return nil, fmt.Errorf("unsupported private key PEM type %q", block.Type) +} + // GenerateECDSAP256Authority generates an ECDSA P256 JWT signing key. func GenerateECDSAP256Authority(id string) (*Authority, error) { privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) diff --git a/internal/localjwtauthority/localjwtauthority_test.go b/internal/localjwtauthority/localjwtauthority_test.go new file mode 100644 index 000000000..7f25a72ea --- /dev/null +++ b/internal/localjwtauthority/localjwtauthority_test.go @@ -0,0 +1,62 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package localjwtauthority + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "encoding/json" + "encoding/pem" + "testing" +) + +func TestUnmarshalPEMSigningKey(t *testing.T) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatalf("GenerateKey(): %v", err) + } + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + t.Fatalf("MarshalECPrivateKey(): %v", err) + } + keyPEM := string(pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + + data, err := json.Marshal(&serializedPool{ + Authorities: []*serializedAuthority{{ + ID: "1", + Algorithm: "ES256", + SigningKeyPEM: keyPEM, + }}, + }) + if err != nil { + t.Fatalf("Marshal(): %v", err) + } + + pool, err := Unmarshal(data) + if err != nil { + t.Fatalf("Unmarshal(): %v", err) + } + if len(pool.Authorities) != 1 { + t.Fatalf("Authorities length = %d, want 1", len(pool.Authorities)) + } + if pool.Authorities[0].Algorithm != "ES256" { + t.Fatalf("Algorithm = %q, want ES256", pool.Authorities[0].Algorithm) + } + if _, ok := pool.Authorities[0].SigningKey.(*ecdsa.PrivateKey); !ok { + t.Fatalf("SigningKey type = %T, want *ecdsa.PrivateKey", pool.Authorities[0].SigningKey) + } +} diff --git a/manifests/ate-install/jwt/kustomization.yaml b/manifests/ate-install/jwt/kustomization.yaml new file mode 100644 index 000000000..6fa433e4c --- /dev/null +++ b/manifests/ate-install/jwt/kustomization.yaml @@ -0,0 +1,54 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +resources: + - ../ate-api-server.yaml + - ../ate-controller.yaml + - ../atelet.yaml + - ../atenet-dns.yaml + - ../atenet-router.yaml + - ../valkey.yaml + - ../pod-certificate-controller.yaml + +patches: + - path: patches.yaml + - target: + group: apps + version: v1 + kind: Deployment + name: ate-api-server-deployment + namespace: ate-system + patch: |- + - op: add + path: /spec/template/spec/containers/0/args/- + value: --auth-mode=jwt + - target: + group: apps + version: v1 + kind: Deployment + name: atenet-router + namespace: ate-system + patch: |- + - op: add + path: /spec/template/spec/containers/0/args/- + value: --ateapi-auth=jwt + - op: add + path: /spec/template/spec/containers/0/args/- + value: --ateapi-ca-file=/run/servicedns-ca/trust-bundle.pem + - op: add + path: /spec/template/spec/containers/0/args/- + value: --ateapi-token-file=/run/ateapi-token/token diff --git a/manifests/ate-install/jwt/patches.yaml b/manifests/ate-install/jwt/patches.yaml new file mode 100644 index 000000000..0be856899 --- /dev/null +++ b/manifests/ate-install/jwt/patches.yaml @@ -0,0 +1,87 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +apiVersion: apps/v1 +kind: Deployment +metadata: + name: ate-controller + namespace: ate-system +spec: + template: + spec: + containers: + - name: ate-controller + args: + - --ateapi-auth=jwt + - --ateapi-ca-file=/run/servicedns-ca/trust-bundle.pem + - --ateapi-token-file=/run/ateapi-token/token + volumeMounts: + - name: ateapi-token + mountPath: /run/ateapi-token + readOnly: true + - name: servicedns-ca + mountPath: /run/servicedns-ca + readOnly: true + volumes: + - name: ateapi-token + projected: + sources: + - serviceAccountToken: + audience: api.ate-system.svc + expirationSeconds: 3600 + path: token + - name: servicedns-ca + projected: + sources: + - clusterTrustBundle: + signerName: servicedns.podcert.ate.dev/identity + labelSelector: + matchLabels: + podcert.ate.dev/canarying: live + path: trust-bundle.pem +--- +apiVersion: apps/v1 +kind: Deployment +metadata: + name: atenet-router + namespace: ate-system +spec: + template: + spec: + containers: + - name: atenet-router + volumeMounts: + - name: ateapi-token + mountPath: /run/ateapi-token + readOnly: true + - name: servicedns-ca + mountPath: /run/servicedns-ca + readOnly: true + volumes: + - name: ateapi-token + projected: + sources: + - serviceAccountToken: + audience: api.ate-system.svc + expirationSeconds: 3600 + path: token + - name: servicedns-ca + projected: + sources: + - clusterTrustBundle: + signerName: servicedns.podcert.ate.dev/identity + labelSelector: + matchLabels: + podcert.ate.dev/canarying: live + path: trust-bundle.pem diff --git a/manifests/ate-install/kind-jwt/kustomization.yaml b/manifests/ate-install/kind-jwt/kustomization.yaml new file mode 100644 index 000000000..dc9c98d7c --- /dev/null +++ b/manifests/ate-install/kind-jwt/kustomization.yaml @@ -0,0 +1,48 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +apiVersion: kustomize.config.k8s.io/v1beta1 +kind: Kustomization + +resources: + - ../kind + +patches: + - path: ../jwt/patches.yaml + - target: + group: apps + version: v1 + kind: Deployment + name: ate-api-server-deployment + namespace: ate-system + patch: |- + - op: add + path: /spec/template/spec/containers/0/args/- + value: --auth-mode=jwt + - target: + group: apps + version: v1 + kind: Deployment + name: atenet-router + namespace: ate-system + patch: |- + - op: add + path: /spec/template/spec/containers/0/args/- + value: --ateapi-auth=jwt + - op: add + path: /spec/template/spec/containers/0/args/- + value: --ateapi-ca-file=/run/servicedns-ca/trust-bundle.pem + - op: add + path: /spec/template/spec/containers/0/args/- + value: --ateapi-token-file=/run/ateapi-token/token