diff --git a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py index 32e21f61ebed3..2f3bf3a40308c 100644 --- a/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/microsoft/azure/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -31,7 +31,7 @@ from urllib.parse import quote, urljoin, urlparse import httpx -from azure.identity import CertificateCredential, ClientSecretCredential +from azure.identity.aio import CertificateCredential, ClientSecretCredential from httpx import AsyncHTTPTransport, Response, Timeout from kiota_abstractions.api_error import APIError from kiota_abstractions.method import Method @@ -50,18 +50,16 @@ from airflow.exceptions import AirflowBadRequest, AirflowConfigException, AirflowProviderDeprecationWarning from airflow.providers.common.compat.connection import get_async_connection -from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook +from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook, redact if TYPE_CHECKING: - from azure.identity._internal.client_credential_base import ClientCredentialBase + from azure.core.credentials_async import AsyncTokenCredential from kiota_abstractions.request_adapter import RequestAdapter from kiota_abstractions.response_handler import NativeResponseType from kiota_abstractions.serialization import ParsableFactory from airflow.providers.common.compat.sdk import Connection -from airflow.providers.common.compat.sdk import redact - PaginationCallable = Callable[..., tuple[str, dict[str, Any] | None]] @@ -366,7 +364,6 @@ def _build_request_adapter(self, connection) -> tuple[str, RequestAdapter]: http_client=http_client, base_url=base_url, ) - self.cached_request_adapters[self.conn_id] = (api_version, request_adapter) return api_version, request_adapter def get_conn(self) -> RequestAdapter: @@ -374,7 +371,7 @@ def get_conn(self) -> RequestAdapter: Initiate a new RequestAdapter connection. .. warning:: - This method is deprecated. + This method is deprecated. Use :meth:`get_async_conn` instead. """ if not self.conn_id: raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!") @@ -390,9 +387,15 @@ def get_conn(self) -> RequestAdapter: if not request_adapter: connection = self.get_connection(conn_id=self.conn_id) api_version, request_adapter = self._build_request_adapter(connection) + self.cached_request_adapters[self.conn_id] = (api_version, request_adapter) self.api_version = api_version return request_adapter + @staticmethod + def _is_http_client_closed(request_adapter: RequestAdapter) -> bool: + """Return True when the underlying httpx AsyncClient has been closed.""" + return cast("HttpxRequestAdapter", request_adapter)._http_client.is_closed + async def get_async_conn(self) -> RequestAdapter: """Initiate a new RequestAdapter connection asynchronously.""" if not self.conn_id: @@ -400,9 +403,19 @@ async def get_async_conn(self) -> RequestAdapter: api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None)) + if request_adapter and self._is_http_client_closed(request_adapter): + self.log.warning( + "Cached request adapter for conn_id '%s' has a closed HTTP client. Rebuilding.", + self.conn_id, + ) + self.cached_request_adapters.pop(self.conn_id, None) + request_adapter = None + if not request_adapter: connection = await get_async_connection(conn_id=self.conn_id) api_version, request_adapter = self._build_request_adapter(connection) + self.cached_request_adapters[self.conn_id] = (api_version, request_adapter) + self.api_version = api_version return request_adapter @@ -433,7 +446,7 @@ def get_credentials( authority: str | None, verify: bool, proxies: dict | None, - ) -> ClientCredentialBase: + ) -> AsyncTokenCredential: tenant_id = config.get("tenant_id") or config.get("tenantId") certificate_path = config.get("certificate_path") certificate_data = config.get("certificate_data") @@ -582,16 +595,25 @@ async def run( async def send_request(self, request_info: RequestInformation, response_type: str | None = None): conn = await self.get_async_conn() - if response_type: - return await conn.send_primitive_async( + try: + if response_type: + return await conn.send_primitive_async( + request_info=request_info, + response_type=response_type, + error_map=self.error_mapping(), + ) + return await conn.send_no_response_content_async( request_info=request_info, - response_type=response_type, error_map=self.error_mapping(), ) - return await conn.send_no_response_content_async( - request_info=request_info, - error_map=self.error_mapping(), - ) + except Exception as e: + self.log.warning( + "Request failed for conn_id '%s': %s. Invalidating cached request adapter.", + self.conn_id, + e, + ) + self.cached_request_adapters.pop(self.conn_id, None) + raise def request_information( self, diff --git a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py index 96443b5a67de0..dab96656d43b4 100644 --- a/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py +++ b/providers/microsoft/azure/tests/unit/microsoft/azure/hooks/test_msgraph.py @@ -18,10 +18,11 @@ import asyncio import inspect +from contextlib import AbstractAsyncContextManager from json import JSONDecodeError from os.path import dirname -from typing import TYPE_CHECKING, cast -from unittest.mock import Mock, patch +from typing import cast +from unittest.mock import AsyncMock, Mock, patch import pytest from httpx import Response @@ -52,31 +53,8 @@ patch_hook_and_request_adapter, ) -if TYPE_CHECKING: - from azure.identity._internal.msal_credentials import MsalCredential - from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider - from kiota_abstractions.request_adapter import RequestAdapter - from kiota_authentication_azure.azure_identity_access_token_provider import ( - AzureIdentityAccessTokenProvider, - ) - class TestKiotaRequestAdapterHook: - @staticmethod - def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str): - adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter", request_adapter) - auth_provider: BaseBearerTokenAuthenticationProvider = cast( - "BaseBearerTokenAuthenticationProvider", - adapter._authentication_provider, - ) - access_token_provider: AzureIdentityAccessTokenProvider = cast( - "AzureIdentityAccessTokenProvider", - auth_provider.access_token_provider, - ) - credentials: MsalCredential = cast("MsalCredential", access_token_provider._credentials) - tenant_id = credentials._tenant_id - assert tenant_id == expected_tenant_id - def test_get_conn(self): with patch_hook(): hook = KiotaRequestAdapterHook(conn_id="msgraph_api") @@ -276,10 +254,15 @@ def test_execute_callable_when_required_parameter_is_missing(self): @pytest.mark.asyncio async def test_tenant_id(self): with patch_hook(): - hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - actual = await hook.get_async_conn() + with patch( + "airflow.providers.microsoft.azure.hooks.msgraph.ClientSecretCredential", + autospec=True, + ) as mock_credential_cls: + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + await hook.get_async_conn() - self.assert_tenant_id(actual, "tenant-id") + mock_credential_cls.assert_called_once() + assert mock_credential_cls.call_args.kwargs.get("tenant_id") == "tenant-id" @pytest.mark.asyncio async def test_azure_tenant_id(self): @@ -289,10 +272,15 @@ async def test_azure_tenant_id(self): azure_tenant_id="azure-tenant-id", ) ): - hook = KiotaRequestAdapterHook(conn_id="msgraph_api") - actual = await hook.get_async_conn() + with patch( + "airflow.providers.microsoft.azure.hooks.msgraph.ClientSecretCredential", + autospec=True, + ) as mock_credential_cls: + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + await hook.get_async_conn() - self.assert_tenant_id(actual, "azure-tenant-id") + mock_credential_cls.assert_called_once() + assert mock_credential_cls.call_args.kwargs.get("tenant_id") == "azure-tenant-id" @pytest.mark.asyncio async def test_proxies(self): @@ -472,6 +460,116 @@ def test_msal_returns_proxies_when_no_authority_with_proxy_key(self): assert result == proxies + def test_get_credentials_returns_async_client_secret_credential(self): + """get_credentials must return an async context manager (azure.identity.aio credential).""" + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + config = {"tenant_id": "tenant-id"} + + credentials = hook.get_credentials( + login="client_id", + password="client_secret", + config=config, + authority=None, + verify=True, + proxies=None, + ) + + assert isinstance(credentials, AbstractAsyncContextManager) + + def test_get_credentials_returns_async_certificate_credential(self): + """get_credentials must return an async context manager when certificate_data is set.""" + import datetime + + from cryptography import x509 + from cryptography.hazmat.primitives import hashes, serialization + from cryptography.hazmat.primitives.asymmetric import rsa + from cryptography.x509.oid import NameOID + + private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048) + name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")]) + cert = ( + x509.CertificateBuilder() + .subject_name(name) + .issuer_name(name) + .public_key(private_key.public_key()) + .serial_number(x509.random_serial_number()) + .not_valid_before(datetime.datetime.now(datetime.timezone.utc)) + .not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1)) + .sign(private_key, hashes.SHA256()) + ) + pem = private_key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.TraditionalOpenSSL, + serialization.NoEncryption(), + ) + cert.public_bytes(serialization.Encoding.PEM) + + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + config = { + "tenant_id": "tenant-id", + "certificate_data": pem.decode(), + } + + credentials = hook.get_credentials( + login="client_id", + password=None, + config=config, + authority=None, + verify=True, + proxies=None, + ) + + assert isinstance(credentials, AbstractAsyncContextManager) + + @pytest.mark.asyncio + async def test_get_async_conn_uses_async_credentials(self): + """get_async_conn must build a request adapter backed by async credentials.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + request_adapter = await hook.get_async_conn() + + adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter", request_adapter) + # Reach into the auth provider chain to retrieve the underlying credential object. + access_token_provider = adapter._authentication_provider.access_token_provider + credentials = access_token_provider._credentials + + assert isinstance(credentials, AbstractAsyncContextManager) + + @pytest.mark.asyncio + async def test_get_async_conn_rebuilds_adapter_when_http_client_is_closed(self): + """get_async_conn evicts and rebuilds the adapter when the cached HTTP client is already closed.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + + stale_adapter = Mock(spec=HttpxRequestAdapter) + stale_adapter._http_client = Mock(is_closed=True) + hook.cached_request_adapters[hook.conn_id] = (hook.api_version, stale_adapter) + + fresh_adapter = Mock(spec=HttpxRequestAdapter) + fresh_adapter._http_client = Mock(is_closed=False) + + with patch.object(hook, "_build_request_adapter", return_value=("v1.0", fresh_adapter)): + result = await hook.get_async_conn() + + assert result is fresh_adapter + assert hook.cached_request_adapters[hook.conn_id] == ("v1.0", fresh_adapter) + + @pytest.mark.asyncio + async def test_send_request_invalidates_cache_and_raises_on_any_error(self): + """send_request evicts the cached adapter and re-raises on any request error.""" + with patch_hook(): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + + adapter = Mock(spec=HttpxRequestAdapter) + adapter._http_client = Mock(is_closed=False) + adapter.send_no_response_content_async = AsyncMock(side_effect=RuntimeError("some error")) + hook.cached_request_adapters[hook.conn_id] = (hook.api_version, adapter) + + with pytest.raises(RuntimeError, match="some error"): + await hook.run(url="users") + + adapter.send_no_response_content_async.assert_called_once() + assert hook.conn_id not in hook.cached_request_adapters + class TestKiotaRequestAdapterHookProtocol: """Test protocol handling in KiotaRequestAdapterHook."""