diff --git a/src/google/adk/models/apigee_llm.py b/src/google/adk/models/apigee_llm.py index 65c4156744..2667ae8e00 100644 --- a/src/google/adk/models/apigee_llm.py +++ b/src/google/adk/models/apigee_llm.py @@ -40,6 +40,7 @@ from .llm_response import LlmResponse if TYPE_CHECKING: + from google.auth.credentials import Credentials from google.genai import Client from .llm_request import LlmRequest @@ -92,6 +93,7 @@ def __init__( custom_headers: dict[str, str] | None = None, retry_options: Optional[types.HttpRetryOptions] = None, api_type: ApiType | str = ApiType.UNKNOWN, + credentials: Optional[Credentials] = None, ): """Initializes the Apigee LLM backend. @@ -123,6 +125,11 @@ def __init__( authorization headers in Vertex AI and Gemini API calls. retry_options: Allow google-genai to retry failed responses. api_type: The type of API to use. One of `ApiType` or string. + credentials: Optional google-auth credentials passed through to the + underlying `genai.Client`. Use this when the Apigee proxy requires + additional OAuth scopes (e.g., `userinfo.email` for tokeninfo-based + caller identification). When omitted, the default `genai.Client` + authentication flow is used. """ # fmt: skip super().__init__(model=model, retry_options=retry_options) @@ -165,6 +172,7 @@ def __init__( ) self._custom_headers = custom_headers or {} self._user_agent = f'google-adk/{adk_version.__version__}' + self._credentials = credentials @classmethod @override @@ -239,6 +247,8 @@ def api_client(self) -> Client: if self._isvertexai: kwargs_for_client['project'] = self._project kwargs_for_client['location'] = self._location + if self._credentials is not None: + kwargs_for_client['credentials'] = self._credentials return Client( http_options=http_options, diff --git a/tests/unittests/models/test_apigee_llm.py b/tests/unittests/models/test_apigee_llm.py index ecbb61d18f..1e371e8aa1 100644 --- a/tests/unittests/models/test_apigee_llm.py +++ b/tests/unittests/models/test_apigee_llm.py @@ -651,6 +651,71 @@ def test_parse_response_usage_metadata(): assert llm_response.usage_metadata.thoughts_token_count == 4 +@pytest.mark.asyncio +@mock.patch('google.genai.Client') +async def test_api_client_passes_credentials_when_provided( + mock_client_constructor, llm_request +): + """Tests that credentials passed to __init__ are forwarded to genai.Client.""" + mock_credentials = mock.Mock() + + mock_client_instance = mock.Mock() + mock_client_instance.aio.models.generate_content = AsyncMock( + return_value=types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + parts=[Part.from_text(text='Test response')], + role='model', + ) + ) + ] + ) + ) + mock_client_constructor.return_value = mock_client_instance + + apigee_llm = ApigeeLlm( + model=APIGEE_GEMINI_MODEL_ID, + proxy_url=PROXY_URL, + credentials=mock_credentials, + ) + _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)] + + _, kwargs = mock_client_constructor.call_args + assert kwargs['credentials'] is mock_credentials + + +@pytest.mark.asyncio +@mock.patch('google.genai.Client') +async def test_api_client_omits_credentials_when_not_provided( + mock_client_constructor, llm_request +): + """Tests that credentials kwarg is not forwarded when not supplied.""" + mock_client_instance = mock.Mock() + mock_client_instance.aio.models.generate_content = AsyncMock( + return_value=types.GenerateContentResponse( + candidates=[ + types.Candidate( + content=Content( + parts=[Part.from_text(text='Test response')], + role='model', + ) + ) + ] + ) + ) + mock_client_constructor.return_value = mock_client_instance + + apigee_llm = ApigeeLlm( + model=APIGEE_GEMINI_MODEL_ID, + proxy_url=PROXY_URL, + ) + _ = [resp async for resp in apigee_llm.generate_content_async(llm_request)] + + _, kwargs = mock_client_constructor.call_args + assert 'credentials' not in kwargs + + def test_parse_response_with_refusal(): """Tests that CompletionsHTTPClient parses refusal correctly.""" client = CompletionsHTTPClient(base_url='http://test')