diff --git a/tests/unittests/models/test_models.py b/tests/unittests/models/test_models.py index 16c9567aa8..b793f8a84a 100644 --- a/tests/unittests/models/test_models.py +++ b/tests/unittests/models/test_models.py @@ -14,6 +14,7 @@ from google.adk import models from google.adk.models.anthropic_llm import Claude +from google.adk.models.gemma_llm import Gemma from google.adk.models.google_llm import Gemini from google.adk.models.lite_llm import LiteLlm import pytest @@ -56,6 +57,19 @@ def test_match_claude_family(model_name): assert models.LLMRegistry.resolve(model_name) is Claude +@pytest.mark.parametrize( + 'model_name', + [ + 'gemma-3-27b-it', + 'gemma-4-27b-it', + 'gemma-4-31b-it', + ], +) +def test_match_gemma_family(model_name): + """Test that Gemma models are resolved correctly.""" + assert models.LLMRegistry.resolve(model_name) is Gemma + + @pytest.mark.parametrize( 'model_name', [