Skip to content

Commit b82fcad

Browse files
committed
fix(go/genkit): improve Google AI API key validation to prevent silent failures
1 parent 35d8898 commit b82fcad

File tree

2 files changed

+277
-1
lines changed

2 files changed

+277
-1
lines changed

go/plugins/googlegenai/googlegenai.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@ import (
88
"errors"
99
"fmt"
1010
"os"
11+
"strings"
1112
"sync"
13+
"time"
1214

1315
"github.com/firebase/genkit/go/ai"
1416
"github.com/firebase/genkit/go/genkit"
@@ -81,6 +83,11 @@ func (ga *GoogleAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
8183
}
8284
}
8385

86+
// Validate API key format - do basic validation before even creating the client
87+
if len(apiKey) < 30 || !strings.HasPrefix(apiKey, "AI") {
88+
return fmt.Errorf("invalid Google AI API key format: keys should start with 'AI' and be at least 30 characters long")
89+
}
90+
8491
gc := genai.ClientConfig{
8592
Backend: genai.BackendGeminiAPI,
8693
APIKey: apiKey,
@@ -91,8 +98,14 @@ func (ga *GoogleAI) Init(ctx context.Context, g *genkit.Genkit) (err error) {
9198

9299
client, err := genai.NewClient(ctx, &gc)
93100
if err != nil {
94-
return err
101+
return fmt.Errorf("failed to create Google AI client: %w", err)
95102
}
103+
104+
// Validate API key by making a simple call
105+
if err := validateAPIKey(ctx, client); err != nil {
106+
return fmt.Errorf("API key validation failed: %w", err)
107+
}
108+
96109
ga.gclient = client
97110
ga.initted = true
98111

@@ -310,3 +323,43 @@ func GoogleAIEmbedder(g *genkit.Genkit, name string) ai.Embedder {
310323
func VertexAIEmbedder(g *genkit.Genkit, name string) ai.Embedder {
311324
return genkit.LookupEmbedder(g, vertexAIProvider, name)
312325
}
326+
327+
// validateAPIKey performs an API call to verify the API key is valid
328+
func validateAPIKey(ctx context.Context, client *genai.Client) error {
329+
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
330+
defer cancel()
331+
332+
content := &genai.Content{
333+
Parts: []*genai.Part{
334+
{Text: "Say hello"},
335+
},
336+
}
337+
338+
_, err := client.Models.GenerateContent(ctx, "gemini-2.0-flash", []*genai.Content{content}, nil)
339+
if err != nil {
340+
return extractAndFormatAPIError(err)
341+
}
342+
343+
return nil
344+
}
345+
346+
// extractAndFormatAPIError extracts useful information from API errors
347+
func extractAndFormatAPIError(err error) error {
348+
// Extract HTTP status code if possible
349+
if err != nil {
350+
errMsg := err.Error()
351+
if strings.Contains(errMsg, "401") {
352+
return fmt.Errorf("unauthorized (HTTP 401): invalid API key or the key doesn't have permission to access the API")
353+
} else if strings.Contains(errMsg, "403") {
354+
return fmt.Errorf("forbidden (HTTP 403): the API key is valid but doesn't have sufficient permissions")
355+
} else if strings.Contains(errMsg, "429") {
356+
return fmt.Errorf("rate limit exceeded (HTTP 429): the API key has reached its quota limit")
357+
} else if strings.Contains(errMsg, "400") {
358+
return fmt.Errorf("bad request (HTTP 400): the request was malformed or invalid parameters were provided")
359+
} else if strings.Contains(errMsg, "500") {
360+
return fmt.Errorf("server error (HTTP 500): an error occurred on the Google AI service")
361+
}
362+
}
363+
364+
return fmt.Errorf("error validating API key: %w", err)
365+
}
Lines changed: 223 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,223 @@
1+
// googlegenai_test.go
2+
package googlegenai
3+
4+
import (
5+
"context"
6+
"fmt"
7+
"os"
8+
"strings"
9+
"testing"
10+
"time"
11+
12+
"google.golang.org/genai"
13+
)
14+
15+
func TestAPIErrors(t *testing.T) {
16+
ctx := context.Background()
17+
apiKey := os.Getenv("GEMINI_API_KEY")
18+
if apiKey == "" {
19+
t.Skip("GEMINI_API_KEY environment variable not set")
20+
}
21+
22+
// 1. Test valid API key
23+
t.Run("ValidAPIKey", func(t *testing.T) {
24+
invalidAPIKeys := []string{
25+
"invalid-key-123",
26+
"",
27+
"AI" + strings.Repeat("x", 30),
28+
}
29+
30+
for _, invalidAPIKey := range invalidAPIKeys {
31+
t.Logf("Testing invalid key: %s", invalidAPIKey)
32+
33+
gc := genai.ClientConfig{
34+
Backend: genai.BackendGeminiAPI,
35+
APIKey: invalidAPIKey,
36+
}
37+
38+
client, err := genai.NewClient(ctx, &gc)
39+
t.Logf("Client creation result: %v", err)
40+
41+
if err != nil {
42+
continue // Skip to next key if we can't even create a client
43+
}
44+
45+
// Try to make an API call
46+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
47+
defer cancel()
48+
49+
_, err = client.Chats.Create(ctx, "gemini-2.0-flash", nil, nil)
50+
t.Logf("API call result: %v", err)
51+
52+
if err == nil {
53+
t.Logf("Unexpected: No error returned with invalid key!")
54+
continue
55+
56+
}
57+
58+
t.Logf("Error type: %T", err)
59+
t.Logf("Error contains '401': %v", strings.Contains(err.Error(), "401"))
60+
t.Logf("Error contains 'unauthorized': %v", strings.Contains(strings.ToLower(err.Error()), "unauthorized"))
61+
t.Logf("Error contains 'invalid': %v", strings.Contains(strings.ToLower(err.Error()), "invalid"))
62+
}
63+
})
64+
65+
// 2. Test 401 (Unauthorized) - Invalid API Key
66+
t.Run("Unauthorized401", func(t *testing.T) {
67+
gc := genai.ClientConfig{
68+
Backend: genai.BackendGeminiAPI,
69+
APIKey: "invalid-key-deliberately-wrong",
70+
}
71+
72+
badClient, err := genai.NewClient(ctx, &gc)
73+
if err != nil {
74+
t.Logf("Failed at client creation: %v", err)
75+
if strings.Contains(err.Error(), "401") {
76+
return // Test passed - caught at client creation
77+
}
78+
}
79+
80+
// If client creation succeeded, try to use it
81+
_, err = badClient.Chats.Create(ctx, "gemini-2.0-flash", nil, nil)
82+
if err == nil {
83+
t.Fatal("Expected 401 error but got none")
84+
}
85+
86+
if !strings.Contains(err.Error(), "401") {
87+
t.Fatalf("Expected 401 error but got: %v", err)
88+
}
89+
})
90+
91+
// 3. Test 403 (Forbidden) - Access to restricted model
92+
t.Run("Forbidden403", func(t *testing.T) {
93+
gc := genai.ClientConfig{
94+
Backend: genai.BackendGeminiAPI,
95+
APIKey: apiKey,
96+
}
97+
98+
client, err := genai.NewClient(ctx, &gc)
99+
if err != nil {
100+
t.Fatalf("Failed to create client: %v", err)
101+
}
102+
103+
// Try to access a potentially restricted model
104+
_, err = client.Chats.Create(ctx, "restricted-model-name", nil, nil)
105+
if err == nil {
106+
t.Skip("No 403 error - either model exists or permissions are sufficient")
107+
}
108+
109+
if strings.Contains(err.Error(), "403") {
110+
t.Logf("Successfully triggered 403: %v", err)
111+
} else {
112+
t.Logf("Got error but not 403: %v", err)
113+
}
114+
})
115+
116+
// 4. For the 400 (Bad Request) test
117+
t.Run("BadRequest400", func(t *testing.T) {
118+
gc := genai.ClientConfig{
119+
Backend: genai.BackendGeminiAPI,
120+
APIKey: apiKey,
121+
}
122+
123+
client, err := genai.NewClient(ctx, &gc)
124+
if err != nil {
125+
t.Fatalf("Failed to create client: %v", err)
126+
}
127+
128+
// Try more obviously invalid parameters
129+
// 1. Try an empty model name
130+
t.Logf("Testing with empty model name...")
131+
_, err = client.Chats.Create(ctx, "", nil, nil)
132+
if err != nil {
133+
t.Logf("Error with empty model: %v", err)
134+
}
135+
136+
// 2. Try with extremely long invalid model name
137+
t.Logf("Testing with extremely long model name...")
138+
longModelName := strings.Repeat("invalid-model", 50)
139+
_, err = client.Chats.Create(ctx, longModelName, nil, nil)
140+
if err != nil {
141+
t.Logf("Error with long model name: %v", err)
142+
}
143+
144+
// 3. Try with special characters
145+
t.Logf("Testing with special characters in model name...")
146+
_, err = client.Chats.Create(ctx, "$$$$^^^%%%", nil, nil)
147+
if err != nil {
148+
t.Logf("Error with special chars: %v", err)
149+
}
150+
151+
if err == nil {
152+
t.Skip("Could not trigger a 400 error with any of the invalid inputs")
153+
}
154+
})
155+
156+
// 5. Test 429 (Rate Limit) - Too many requests
157+
t.Run("RateLimit429", func(t *testing.T) {
158+
gc := genai.ClientConfig{
159+
Backend: genai.BackendGeminiAPI,
160+
APIKey: apiKey,
161+
}
162+
163+
client, err := genai.NewClient(ctx, &gc)
164+
if err != nil {
165+
t.Fatalf("Failed to create client: %v", err)
166+
}
167+
168+
// Try to trigger rate limit with multiple rapid requests
169+
for i := 0; i < 10; i++ {
170+
_, err = client.Chats.Create(ctx, "gemini-2.0-flash", nil, nil)
171+
if err != nil && strings.Contains(err.Error(), "429") {
172+
t.Logf("Successfully triggered 429: %v", err)
173+
return
174+
}
175+
}
176+
177+
t.Skip("Could not trigger 429 rate limit error with 10 requests")
178+
})
179+
}
180+
181+
// Add this to test your error extraction function
182+
func TestExtractAndFormatAPIError(t *testing.T) {
183+
testCases := []struct {
184+
name string
185+
err error
186+
expected string
187+
}{
188+
{
189+
name: "401 Error",
190+
err: fmt.Errorf("error: status code 401: Unauthorized"),
191+
expected: "unauthorized (HTTP 401)",
192+
},
193+
{
194+
name: "403 Error",
195+
err: fmt.Errorf("error: status code 403: Forbidden"),
196+
expected: "forbidden (HTTP 403)",
197+
},
198+
{
199+
name: "429 Error",
200+
err: fmt.Errorf("error: status code 429: Too Many Requests"),
201+
expected: "rate limit exceeded (HTTP 429)",
202+
},
203+
{
204+
name: "400 Error",
205+
err: fmt.Errorf("error: status code 400: Bad Request"),
206+
expected: "bad request (HTTP 400)",
207+
},
208+
{
209+
name: "500 Error",
210+
err: fmt.Errorf("error: status code 500: Internal Server Error"),
211+
expected: "server error (HTTP 500)",
212+
},
213+
}
214+
215+
for _, tc := range testCases {
216+
t.Run(tc.name, func(t *testing.T) {
217+
formatted := extractAndFormatAPIError(tc.err)
218+
if !strings.Contains(formatted.Error(), tc.expected) {
219+
t.Errorf("Expected error to contain '%s', got: %v", tc.expected, formatted)
220+
}
221+
})
222+
}
223+
}

0 commit comments

Comments
 (0)