Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from urllib.parse import quote, urljoin, urlparse

import httpx
from azure.identity import CertificateCredential, ClientSecretCredential
from azure.identity.aio import CertificateCredential, ClientSecretCredential
from httpx import AsyncHTTPTransport, Response, Timeout
from kiota_abstractions.api_error import APIError
from kiota_abstractions.method import Method
Expand All @@ -50,18 +50,16 @@

from airflow.exceptions import AirflowBadRequest, AirflowConfigException, AirflowProviderDeprecationWarning
from airflow.providers.common.compat.connection import get_async_connection
from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook
from airflow.providers.common.compat.sdk import AirflowException, AirflowNotFoundException, BaseHook, redact

if TYPE_CHECKING:
from azure.identity._internal.client_credential_base import ClientCredentialBase
from azure.core.credentials_async import AsyncTokenCredential
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_abstractions.response_handler import NativeResponseType
from kiota_abstractions.serialization import ParsableFactory

from airflow.providers.common.compat.sdk import Connection

from airflow.providers.common.compat.sdk import redact

PaginationCallable = Callable[..., tuple[str, dict[str, Any] | None]]


Expand Down Expand Up @@ -366,15 +364,14 @@ def _build_request_adapter(self, connection) -> tuple[str, RequestAdapter]:
http_client=http_client,
base_url=base_url,
)
self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
return api_version, request_adapter

def get_conn(self) -> RequestAdapter:
"""
Initiate a new RequestAdapter connection.

.. warning::
This method is deprecated.
This method is deprecated. Use :meth:`get_async_conn` instead.
"""
if not self.conn_id:
raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")
Expand All @@ -390,19 +387,35 @@ def get_conn(self) -> RequestAdapter:
if not request_adapter:
connection = self.get_connection(conn_id=self.conn_id)
api_version, request_adapter = self._build_request_adapter(connection)
self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)
self.api_version = api_version
return request_adapter

@staticmethod
def _is_http_client_closed(request_adapter: RequestAdapter) -> bool:
"""Return True when the underlying httpx AsyncClient has been closed."""
return cast("HttpxRequestAdapter", request_adapter)._http_client.is_closed

async def get_async_conn(self) -> RequestAdapter:
"""Initiate a new RequestAdapter connection asynchronously."""
if not self.conn_id:
raise AirflowException("Failed to create the KiotaRequestAdapterHook. No conn_id provided!")

api_version, request_adapter = self.cached_request_adapters.get(self.conn_id, (None, None))

if request_adapter and self._is_http_client_closed(request_adapter):
self.log.warning(
"Cached request adapter for conn_id '%s' has a closed HTTP client. Rebuilding.",
self.conn_id,
)
self.cached_request_adapters.pop(self.conn_id, None)
request_adapter = None

if not request_adapter:
connection = await get_async_connection(conn_id=self.conn_id)
api_version, request_adapter = self._build_request_adapter(connection)
self.cached_request_adapters[self.conn_id] = (api_version, request_adapter)

self.api_version = api_version
return request_adapter

Expand Down Expand Up @@ -433,7 +446,7 @@ def get_credentials(
authority: str | None,
verify: bool,
proxies: dict | None,
) -> ClientCredentialBase:
) -> AsyncTokenCredential:
tenant_id = config.get("tenant_id") or config.get("tenantId")
certificate_path = config.get("certificate_path")
certificate_data = config.get("certificate_data")
Expand Down Expand Up @@ -582,16 +595,25 @@ async def run(
async def send_request(self, request_info: RequestInformation, response_type: str | None = None):
conn = await self.get_async_conn()

if response_type:
return await conn.send_primitive_async(
try:
if response_type:
return await conn.send_primitive_async(
request_info=request_info,
response_type=response_type,
error_map=self.error_mapping(),
)
return await conn.send_no_response_content_async(
request_info=request_info,
response_type=response_type,
error_map=self.error_mapping(),
)
return await conn.send_no_response_content_async(
request_info=request_info,
error_map=self.error_mapping(),
)
except Exception as e:
self.log.warning(
"Request failed for conn_id '%s': %s. Invalidating cached request adapter.",
self.conn_id,
e,
)
self.cached_request_adapters.pop(self.conn_id, None)
raise

def request_information(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

import asyncio
import inspect
from contextlib import AbstractAsyncContextManager
from json import JSONDecodeError
from os.path import dirname
from typing import TYPE_CHECKING, cast
from unittest.mock import Mock, patch
from typing import cast
from unittest.mock import AsyncMock, Mock, patch

import pytest
from httpx import Response
Expand Down Expand Up @@ -52,31 +53,8 @@
patch_hook_and_request_adapter,
)

if TYPE_CHECKING:
from azure.identity._internal.msal_credentials import MsalCredential
from kiota_abstractions.authentication import BaseBearerTokenAuthenticationProvider
from kiota_abstractions.request_adapter import RequestAdapter
from kiota_authentication_azure.azure_identity_access_token_provider import (
AzureIdentityAccessTokenProvider,
)


class TestKiotaRequestAdapterHook:
@staticmethod
def assert_tenant_id(request_adapter: RequestAdapter, expected_tenant_id: str):
adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter", request_adapter)
auth_provider: BaseBearerTokenAuthenticationProvider = cast(
"BaseBearerTokenAuthenticationProvider",
adapter._authentication_provider,
)
access_token_provider: AzureIdentityAccessTokenProvider = cast(
"AzureIdentityAccessTokenProvider",
auth_provider.access_token_provider,
)
credentials: MsalCredential = cast("MsalCredential", access_token_provider._credentials)
tenant_id = credentials._tenant_id
assert tenant_id == expected_tenant_id

def test_get_conn(self):
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
Expand Down Expand Up @@ -276,10 +254,15 @@ def test_execute_callable_when_required_parameter_is_missing(self):
@pytest.mark.asyncio
async def test_tenant_id(self):
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
actual = await hook.get_async_conn()
with patch(
"airflow.providers.microsoft.azure.hooks.msgraph.ClientSecretCredential",
autospec=True,
) as mock_credential_cls:
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
await hook.get_async_conn()

self.assert_tenant_id(actual, "tenant-id")
mock_credential_cls.assert_called_once()
assert mock_credential_cls.call_args.kwargs.get("tenant_id") == "tenant-id"

@pytest.mark.asyncio
async def test_azure_tenant_id(self):
Expand All @@ -289,10 +272,15 @@ async def test_azure_tenant_id(self):
azure_tenant_id="azure-tenant-id",
)
):
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
actual = await hook.get_async_conn()
with patch(
"airflow.providers.microsoft.azure.hooks.msgraph.ClientSecretCredential",
autospec=True,
) as mock_credential_cls:
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
await hook.get_async_conn()

self.assert_tenant_id(actual, "azure-tenant-id")
mock_credential_cls.assert_called_once()
assert mock_credential_cls.call_args.kwargs.get("tenant_id") == "azure-tenant-id"

@pytest.mark.asyncio
async def test_proxies(self):
Expand Down Expand Up @@ -472,6 +460,116 @@ def test_msal_returns_proxies_when_no_authority_with_proxy_key(self):

assert result == proxies

def test_get_credentials_returns_async_client_secret_credential(self):
"""get_credentials must return an async context manager (azure.identity.aio credential)."""
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
config = {"tenant_id": "tenant-id"}

credentials = hook.get_credentials(
login="client_id",
password="client_secret",
config=config,
authority=None,
verify=True,
proxies=None,
)

assert isinstance(credentials, AbstractAsyncContextManager)

def test_get_credentials_returns_async_certificate_credential(self):
"""get_credentials must return an async context manager when certificate_data is set."""
import datetime

from cryptography import x509
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, "test")])
cert = (
x509.CertificateBuilder()
.subject_name(name)
.issuer_name(name)
.public_key(private_key.public_key())
.serial_number(x509.random_serial_number())
.not_valid_before(datetime.datetime.now(datetime.timezone.utc))
.not_valid_after(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1))
.sign(private_key, hashes.SHA256())
)
pem = private_key.private_bytes(
serialization.Encoding.PEM,
serialization.PrivateFormat.TraditionalOpenSSL,
serialization.NoEncryption(),
) + cert.public_bytes(serialization.Encoding.PEM)

hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
config = {
"tenant_id": "tenant-id",
"certificate_data": pem.decode(),
}

credentials = hook.get_credentials(
login="client_id",
password=None,
config=config,
authority=None,
verify=True,
proxies=None,
)

assert isinstance(credentials, AbstractAsyncContextManager)

@pytest.mark.asyncio
async def test_get_async_conn_uses_async_credentials(self):
"""get_async_conn must build a request adapter backed by async credentials."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
request_adapter = await hook.get_async_conn()

adapter: HttpxRequestAdapter = cast("HttpxRequestAdapter", request_adapter)
# Reach into the auth provider chain to retrieve the underlying credential object.
access_token_provider = adapter._authentication_provider.access_token_provider
credentials = access_token_provider._credentials

assert isinstance(credentials, AbstractAsyncContextManager)

@pytest.mark.asyncio
async def test_get_async_conn_rebuilds_adapter_when_http_client_is_closed(self):
"""get_async_conn evicts and rebuilds the adapter when the cached HTTP client is already closed."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")

stale_adapter = Mock(spec=HttpxRequestAdapter)
stale_adapter._http_client = Mock(is_closed=True)
hook.cached_request_adapters[hook.conn_id] = (hook.api_version, stale_adapter)

fresh_adapter = Mock(spec=HttpxRequestAdapter)
fresh_adapter._http_client = Mock(is_closed=False)

with patch.object(hook, "_build_request_adapter", return_value=("v1.0", fresh_adapter)):
result = await hook.get_async_conn()

assert result is fresh_adapter
assert hook.cached_request_adapters[hook.conn_id] == ("v1.0", fresh_adapter)

@pytest.mark.asyncio
async def test_send_request_invalidates_cache_and_raises_on_any_error(self):
"""send_request evicts the cached adapter and re-raises on any request error."""
with patch_hook():
hook = KiotaRequestAdapterHook(conn_id="msgraph_api")

adapter = Mock(spec=HttpxRequestAdapter)
adapter._http_client = Mock(is_closed=False)
adapter.send_no_response_content_async = AsyncMock(side_effect=RuntimeError("some error"))
hook.cached_request_adapters[hook.conn_id] = (hook.api_version, adapter)

with pytest.raises(RuntimeError, match="some error"):
await hook.run(url="users")

adapter.send_no_response_content_async.assert_called_once()
assert hook.conn_id not in hook.cached_request_adapters


class TestKiotaRequestAdapterHookProtocol:
"""Test protocol handling in KiotaRequestAdapterHook."""
Expand Down
Loading