Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions cmd/ateapi/internal/credbundle/credbundle.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
package credbundle

import (
"crypto"
"crypto/tls"
"crypto/x509"
"encoding/pem"
Expand Down Expand Up @@ -47,6 +48,7 @@ func Parse(bundlePath string) (*tls.Certificate, error) {
}

var leafKeyBytes []byte
var leafKeyBlockType string
var chainBytes [][]byte

for {
Expand All @@ -59,8 +61,9 @@ func Parse(bundlePath string) (*tls.Certificate, error) {
switch block.Type {
case "CERTIFICATE":
chainBytes = append(chainBytes, block.Bytes)
case "PRIVATE KEY":
case "PRIVATE KEY", "RSA PRIVATE KEY", "EC PRIVATE KEY":
leafKeyBytes = block.Bytes
leafKeyBlockType = block.Type
default:
return nil, fmt.Errorf("unknown PEM block type %q", block.Type)
}
Expand All @@ -74,7 +77,7 @@ func Parse(bundlePath string) (*tls.Certificate, error) {
return nil, fmt.Errorf("no CERTIFICATE blocks found")
}

leafKey, err := x509.ParsePKCS8PrivateKey(leafKeyBytes)
leafKey, err := parsePrivateKey(leafKeyBlockType, leafKeyBytes)
if err != nil {
return nil, fmt.Errorf("while parsing private key: %w", err)
}
Expand All @@ -90,3 +93,16 @@ func Parse(bundlePath string) (*tls.Certificate, error) {
PrivateKey: leafKey,
}, nil
}

func parsePrivateKey(blockType string, keyBytes []byte) (crypto.PrivateKey, error) {
switch blockType {
case "PRIVATE KEY":
return x509.ParsePKCS8PrivateKey(keyBytes)
case "RSA PRIVATE KEY":
return x509.ParsePKCS1PrivateKey(keyBytes)
case "EC PRIVATE KEY":
return x509.ParseECPrivateKey(keyBytes)
default:
return nil, fmt.Errorf("unsupported private key block type %q", blockType)
}
}
127 changes: 127 additions & 0 deletions cmd/ateapi/internal/credbundle/credbundle_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
// Copyright 2026 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package credbundle

import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"math/big"
"os"
"testing"
"time"
)

func TestParsePrivateKeyBlockTypes(t *testing.T) {
for _, tt := range []struct {
name string
blockType string
keyDER func(t *testing.T) []byte
}{
{
name: "pkcs8",
blockType: "PRIVATE KEY",
keyDER: func(t *testing.T) []byte {
key := generateRSAKey(t)
der, err := x509.MarshalPKCS8PrivateKey(key)
if err != nil {
t.Fatalf("marshal PKCS8 key: %v", err)
}
return der
},
},
{
name: "rsa",
blockType: "RSA PRIVATE KEY",
keyDER: func(t *testing.T) []byte {
return x509.MarshalPKCS1PrivateKey(generateRSAKey(t))
},
},
{
name: "ec",
blockType: "EC PRIVATE KEY",
keyDER: func(t *testing.T) []byte {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
t.Fatalf("generate EC key: %v", err)
}
der, err := x509.MarshalECPrivateKey(key)
if err != nil {
t.Fatalf("marshal EC key: %v", err)
}
return der
},
},
} {
t.Run(tt.name, func(t *testing.T) {
certDER := generateCertificate(t)
bundle := append(pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}), pem.EncodeToMemory(&pem.Block{Type: tt.blockType, Bytes: tt.keyDER(t)})...)

bundlePath := writeBundle(t, bundle)
cert, err := Parse(bundlePath)
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
if len(cert.Certificate) != 1 {
t.Fatalf("Parse() certificate chain length = %d, want 1", len(cert.Certificate))
}
if cert.PrivateKey == nil {
t.Fatalf("Parse() private key is nil")
}
if cert.Leaf == nil {
t.Fatalf("Parse() leaf certificate is nil")
}
})
}
}

func generateRSAKey(t *testing.T) *rsa.PrivateKey {
t.Helper()
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("generate RSA key: %v", err)
}
return key
}

func generateCertificate(t *testing.T) []byte {
t.Helper()
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{CommonName: "api.ate-system.svc"},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
DNSNames: []string{"api.ate-system.svc"},
}
key := generateRSAKey(t)
der, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
t.Fatalf("create certificate: %v", err)
}
return der
}

func writeBundle(t *testing.T, bundle []byte) string {
t.Helper()
path := t.TempDir() + "/bundle.pem"
if err := os.WriteFile(path, bundle, 0o600); err != nil {
t.Fatalf("write bundle: %v", err)
}
return path
}
20 changes: 12 additions & 8 deletions cmd/ateapi/internal/k8sjwt/k8sjwt.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ var permittedSkew = 5 * time.Minute
// the object binding claims. If needed for your use case, you will need check the object bindings
// by connecting to the cluster and seeing if the object(s) the bindings name still exist within the
// cluster.
func Verify(ctx context.Context, jwt string, expectedIssuer, expectedAudience string, now time.Time) (*KubernetesClaims, error) {
//
// httpClient is used for OIDC discovery and JWKS fetches; nil uses http.DefaultClient.
func Verify(ctx context.Context, httpClient *http.Client, jwt string, expectedIssuer, expectedAudience string, now time.Time) (*KubernetesClaims, error) {
segments := strings.Split(jwt, ".")
if len(segments) != 3 {
return nil, fmt.Errorf("malformed JWT")
Expand Down Expand Up @@ -169,7 +171,7 @@ func Verify(ctx context.Context, jwt string, expectedIssuer, expectedAudience st
}

// TODO: Cache keys, and only fetch new keys if the JWT's key ID is not in the cache.
keys, err := discoverKeysForIssuer(ctx, rawClaims.Issuer)
keys, err := discoverKeysForIssuer(ctx, httpClient, rawClaims.Issuer)
if err != nil {
return nil, fmt.Errorf("while discovering keys from issuer: %w", err)
}
Expand Down Expand Up @@ -358,22 +360,22 @@ type jwkT struct {
RSAE string `json:"e"`
}

func discoverKeysForIssuer(ctx context.Context, issuer string) ([]*KeyAndID, error) {
func discoverKeysForIssuer(ctx context.Context, httpClient *http.Client, issuer string) ([]*KeyAndID, error) {
var discoveryDocURL string
if strings.HasSuffix(issuer, "/") {
discoveryDocURL = issuer + ".well-known/openid-configuration"
} else {
discoveryDocURL = issuer + "/.well-known/openid-configuration"
}

oidcConfig, err := fetchJSON[oidcConfigT](discoveryDocURL)
oidcConfig, err := fetchJSON[oidcConfigT](httpClient, discoveryDocURL)
if err != nil {
return nil, fmt.Errorf("while fetching OIDC Discovery document: %w", err)
}

slog.InfoContext(ctx, "Fetched discovery doc", slog.Any("doc", oidcConfig))

jwkSet, err := fetchJSON[jwkSetT](oidcConfig.JWKSURI)
jwkSet, err := fetchJSON[jwkSetT](httpClient, oidcConfig.JWKSURI)
if err != nil {
return nil, fmt.Errorf("while fetching JWKS: %w", err)
}
Expand Down Expand Up @@ -424,10 +426,12 @@ func discoverKeysForIssuer(ctx context.Context, issuer string) ([]*KeyAndID, err
return ret, nil
}

func fetchJSON[T any](url string) (T, error) {
func fetchJSON[T any](httpClient *http.Client, url string) (T, error) {
var parsedBody T

resp, err := http.Get(url)
if httpClient == nil {
httpClient = http.DefaultClient
}
resp, err := httpClient.Get(url)
if err != nil {
return parsedBody, fmt.Errorf("while making HTTP request: %w", err)
}
Expand Down
7 changes: 5 additions & 2 deletions cmd/ateapi/internal/sessionidentity/sessionidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"crypto/x509/pkix"
"fmt"
"log/slog"
"net/http"
"net/url"
"os"
"path"
Expand Down Expand Up @@ -51,17 +52,19 @@ type Server struct {
sessionIDCAPoolFile string

workerCACerts string
httpClient *http.Client
}

var _ ateapipb.SessionIdentityServer = (*Server)(nil)

func New(clientJWTIssuer, clientJWTAudience, sessionIDJWTPoolFile, sessionIDCAPoolFile, workerCACerts string) *Server {
func New(clientJWTIssuer, clientJWTAudience, sessionIDJWTPoolFile, sessionIDCAPoolFile, workerCACerts string, httpClient *http.Client) *Server {
return &Server{
clientJWTIssuer: clientJWTIssuer,
clientJWTAudience: clientJWTAudience,
sessionIDJWTPoolFile: sessionIDJWTPoolFile,
sessionIDCAPoolFile: sessionIDCAPoolFile,
workerCACerts: workerCACerts,
httpClient: httpClient,
}
}

Expand All @@ -78,7 +81,7 @@ func (s *Server) MintJWT(ctx context.Context, req *ateapipb.MintJWTRequest) (*at

clientJWT := strings.TrimPrefix(authorization[0], "Bearer ")

clientClaims, err := k8sjwt.Verify(ctx, clientJWT, s.clientJWTIssuer, s.clientJWTAudience, time.Now())
clientClaims, err := k8sjwt.Verify(ctx, s.httpClient, clientJWT, s.clientJWTIssuer, s.clientJWTAudience, time.Now())
if err != nil {
slog.ErrorContext(ctx, "Error while verifying client JWT", slog.Any("err", err))
return nil, status.Errorf(codes.Unauthenticated, "Unauthenticated")
Expand Down
Loading
Loading