diff --git a/sdk/storage/azure-storage-blob/CHANGELOG.md b/sdk/storage/azure-storage-blob/CHANGELOG.md index 52e79b9fe146..2d7e7ab037f6 100644 --- a/sdk/storage/azure-storage-blob/CHANGELOG.md +++ b/sdk/storage/azure-storage-blob/CHANGELOG.md @@ -3,6 +3,11 @@ ## 12.31.0b1 (Unreleased) ### Features Added +- Added opt-in session-based authentication for `ContainerClient` via the new +`use_session` keyword argument. When enabled, it must be used with a +`TokenCredential`. Eligible GET requests issued through the client +are authenticated using a short-lived session credential obtained from the +service instead of the bearer token. ## 12.29.0 (2026-05-14) diff --git a/sdk/storage/azure-storage-blob/azure/apiview-properties.json b/sdk/storage/azure-storage-blob/azure/apiview-properties.json index 64ad80365c09..217e34631107 100644 --- a/sdk/storage/azure-storage-blob/azure/apiview-properties.json +++ b/sdk/storage/azure-storage-blob/azure/apiview-properties.json @@ -26,6 +26,8 @@ "azure.storage.blob.models.CorsRule": null, "azure.storage.blob.models.CpkInfo": null, "azure.storage.blob.models.CpkScopeInfo": null, + "azure.storage.blob.models.CreateSessionConfiguration": null, + "azure.storage.blob.models.CreateSessionResponse": null, "azure.storage.blob.models.DelimitedTextConfiguration": null, "azure.storage.blob.models.FilterBlobItem": null, "azure.storage.blob.models.FilterBlobSegment": null, @@ -46,6 +48,7 @@ "azure.storage.blob.models.QuerySerialization": null, "azure.storage.blob.models.RetentionPolicy": null, "azure.storage.blob.models.SequenceNumberAccessConditions": null, + "azure.storage.blob.models.SessionCredentials": null, "azure.storage.blob.models.SignedIdentifier": null, "azure.storage.blob.models.SourceCpkInfo": null, "azure.storage.blob.models.SourceModifiedAccessConditions": null, @@ -68,6 +71,7 @@ "azure.storage.blob.models.RehydratePriority": null, "azure.storage.blob.models.BlobImmutabilityPolicyMode": null, "azure.storage.blob.models.GeoReplicationStatusType": null, + "azure.storage.blob.models.AuthenticationType": null, "azure.storage.blob.models.PremiumPageBlobAccessTier": null, "azure.storage.blob.models.AccessTierOptional": null, "azure.storage.blob.models.FileShareTokenIntent": null, @@ -134,6 +138,8 @@ "azure.storage.blob.aio.operations.ContainerOperations.list_blob_hierarchy_segment": null, "azure.storage.blob.operations.ContainerOperations.get_account_info": null, "azure.storage.blob.aio.operations.ContainerOperations.get_account_info": null, + "azure.storage.blob.operations.ContainerOperations.create_session": null, + "azure.storage.blob.aio.operations.ContainerOperations.create_session": null, "azure.storage.blob.operations.BlobOperations.download": null, "azure.storage.blob.aio.operations.BlobOperations.download": null, "azure.storage.blob.operations.BlobOperations.get_properties": null, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py index 63cd584ad193..b2d3329a2307 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py @@ -90,6 +90,11 @@ class BlobServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. + :keyword bool use_session: If True, enable session-based authentication for this container. + When enabled, eligible GET requests issued by this client will be authenticated using + a short-lived session credential obtained from the service instead + of the provided TokenCredential. Only supported with a TokenCredential; + ValueError is raised otherwise. Defaults to False. .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.pyi index 526c2bfae18a..715ce9ee0c82 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.pyi @@ -56,6 +56,7 @@ class BlobServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): max_single_get_size: int = 32 * 1024 * 1024, max_chunk_get_size: int = 4 * 1024 * 1024, audience: Optional[str] = None, + use_session: bool = False, **kwargs: Any ) -> None: ... def __enter__(self) -> Self: ... diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index 90c151a9ddc7..2d4fc84c2a60 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -99,6 +99,11 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): # pyli :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. + :keyword bool use_session: If True, enable session-based authentication for this container. + When enabled, eligible GET requests issued by this client will be authenticated using + a short-lived session credential obtained from the service instead + of the provided TokenCredential. Only supported with a TokenCredential; + ValueError is raised otherwise. Defaults to False. .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi index fe3e0142beef..5c4273ecf970 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi @@ -71,6 +71,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): max_single_get_size: int = 32 * 1024 * 1024, min_large_block_upload_threshold: int = 4 * 1024 * 1024 + 1, use_byte_buffer: Optional[bool] = None, + use_session: bool = False, **kwargs: Any, ) -> None: ... def __enter__(self) -> Self: ... diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py index 09bb123a20af..2528775823d8 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/aio/operations/_container_operations.py @@ -32,6 +32,7 @@ build_break_lease_request, build_change_lease_request, build_create_request, + build_create_session_request, build_delete_request, build_filter_blobs_request, build_get_access_policy_request, @@ -1861,3 +1862,82 @@ async def get_account_info( if cls: return cls(pipeline_response, None, response_headers) # type: ignore + + @distributed_trace_async + async def create_session( + self, + create_session_configuration: _models.CreateSessionConfiguration, + timeout: Optional[int] = None, + request_id_parameter: Optional[str] = None, + **kwargs: Any + ) -> _models.CreateSessionResponse: + """The Create Session operation enables users to create a session scoped to a container. + + :param create_session_configuration: Required. + :type create_session_configuration: ~azure.storage.blob.models.CreateSessionConfiguration + :param timeout: The timeout parameter is expressed in seconds. For more information, see + :code:`Setting + Timeouts for Blob Service Operations.`. Default value is None. + :type timeout: int + :param request_id_parameter: Provides a client-generated, opaque value with a 1 KB character + limit that is recorded in the analytics logs when storage analytics logging is enabled. Default + value is None. + :type request_id_parameter: str + :return: CreateSessionResponse or the result of cls(response) + :rtype: ~azure.storage.blob.models.CreateSessionResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + restype: Literal["container"] = kwargs.pop("restype", _params.pop("restype", "container")) + comp: Literal["session"] = kwargs.pop("comp", _params.pop("comp", "session")) + content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/xml")) + cls: ClsType[_models.CreateSessionResponse] = kwargs.pop("cls", None) + + _content = self._serialize.body(create_session_configuration, "CreateSessionConfiguration", is_xml=True) + + _request = build_create_session_request( + url=self._config.url, + version=self._config.version, + timeout=timeout, + request_id_parameter=request_id_parameter, + restype=restype, + comp=comp, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [201]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = self._deserialize.failsafe_deserialize( + _models.StorageError, + pipeline_response, + ) + raise HttpResponseError(response=response, model=error) + + deserialized = self._deserialize("CreateSessionResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/__init__.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/__init__.py index 95e38c268f1b..c3a60746ee63 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/__init__.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/__init__.py @@ -39,6 +39,8 @@ CorsRule, CpkInfo, CpkScopeInfo, + CreateSessionConfiguration, + CreateSessionResponse, DelimitedTextConfiguration, FilterBlobItem, FilterBlobSegment, @@ -59,6 +61,7 @@ QuerySerialization, RetentionPolicy, SequenceNumberAccessConditions, + SessionCredentials, SignedIdentifier, SourceCpkInfo, SourceModifiedAccessConditions, @@ -75,6 +78,7 @@ AccessTierRequired, AccountKind, ArchiveStatus, + AuthenticationType, BlobCopySourceTags, BlobExpiryOptions, BlobImmutabilityPolicyMode, @@ -129,6 +133,8 @@ "CorsRule", "CpkInfo", "CpkScopeInfo", + "CreateSessionConfiguration", + "CreateSessionResponse", "DelimitedTextConfiguration", "FilterBlobItem", "FilterBlobSegment", @@ -149,6 +155,7 @@ "QuerySerialization", "RetentionPolicy", "SequenceNumberAccessConditions", + "SessionCredentials", "SignedIdentifier", "SourceCpkInfo", "SourceModifiedAccessConditions", @@ -162,6 +169,7 @@ "AccessTierRequired", "AccountKind", "ArchiveStatus", + "AuthenticationType", "BlobCopySourceTags", "BlobExpiryOptions", "BlobImmutabilityPolicyMode", diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/_azure_blob_storage_enums.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/_azure_blob_storage_enums.py index 8938f0fd9e97..c9f286990572 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/_azure_blob_storage_enums.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/_azure_blob_storage_enums.py @@ -93,6 +93,14 @@ class ArchiveStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): REHYDRATE_PENDING_TO_SMART = "rehydrate-pending-to-smart" +class AuthenticationType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """The type of authentication required to create the session. The only type currently supported is + HMAC. + """ + + HMAC = "HMAC" + + class BlobCopySourceTags(str, Enum, metaclass=CaseInsensitiveEnumMeta): """BlobCopySourceTags.""" diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/_models_py3.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/_models_py3.py index 3534891fba4e..1f4793b4829c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/_models_py3.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/models/_models_py3.py @@ -1456,6 +1456,84 @@ def __init__(self, *, encryption_scope: Optional[str] = None, **kwargs: Any) -> self.encryption_scope = encryption_scope +class CreateSessionConfiguration(_serialization.Model): + """CreateSessionConfiguration. + + All required parameters must be populated in order to send to server. + + :ivar authentication_type: The type of authentication required to create the session. The only + type currently supported is HMAC. Required. "HMAC" + :vartype authentication_type: str or ~azure.storage.blob.models.AuthenticationType + """ + + _validation = { + "authentication_type": {"required": True}, + } + + _attribute_map = { + "authentication_type": {"key": "AuthenticationType", "type": "str"}, + } + _xml_map = {"name": "CreateSessionRequest"} + + def __init__(self, *, authentication_type: Union[str, "_models.AuthenticationType"], **kwargs: Any) -> None: + """ + :keyword authentication_type: The type of authentication required to create the session. The + only type currently supported is HMAC. Required. "HMAC" + :paramtype authentication_type: str or ~azure.storage.blob.models.AuthenticationType + """ + super().__init__(**kwargs) + self.authentication_type = authentication_type + + +class CreateSessionResponse(_serialization.Model): + """CreateSessionResponse. + + :ivar id: A unique identifier for the created session. + :vartype id: str + :ivar expiration: The time when the session will expire. The format follows RFC 1123. + :vartype expiration: ~datetime.datetime + :ivar authentication_type: The type of authentication required to create the session. The only + type currently supported is HMAC. "HMAC" + :vartype authentication_type: str or ~azure.storage.blob.models.AuthenticationType + :ivar credentials: + :vartype credentials: ~azure.storage.blob.models.SessionCredentials + """ + + _attribute_map = { + "id": {"key": "Id", "type": "str"}, + "expiration": {"key": "Expiration", "type": "rfc-1123"}, + "authentication_type": {"key": "AuthenticationType", "type": "str"}, + "credentials": {"key": "Credentials", "type": "SessionCredentials"}, + } + _xml_map = {"name": "CreateSessionResult"} + + def __init__( + self, + *, + id: Optional[str] = None, # pylint: disable=redefined-builtin + expiration: Optional[datetime.datetime] = None, + authentication_type: Optional[Union[str, "_models.AuthenticationType"]] = None, + credentials: Optional["_models.SessionCredentials"] = None, + **kwargs: Any + ) -> None: + """ + :keyword id: A unique identifier for the created session. + :paramtype id: str + :keyword expiration: The time when the session will expire. The format follows RFC 1123. + :paramtype expiration: ~datetime.datetime + :keyword authentication_type: The type of authentication required to create the session. The + only type currently supported is HMAC. "HMAC" + :paramtype authentication_type: str or ~azure.storage.blob.models.AuthenticationType + :keyword credentials: + :paramtype credentials: ~azure.storage.blob.models.SessionCredentials + """ + super().__init__(**kwargs) + self.id = id + self.expiration = expiration + self.authentication_type = authentication_type + self.credentials = credentials + + class DelimitedTextConfiguration(_serialization.Model): """Groups the settings used for interpreting the blob data if the blob is delimited text formatted. @@ -2483,6 +2561,39 @@ def __init__( self.if_sequence_number_equal_to = if_sequence_number_equal_to +class SessionCredentials(_serialization.Model): + """SessionCredentials. + + :ivar session_token: An opaque token used to authorize subsequent requests in the session. Must + be treated as a security credential. + :vartype session_token: str + :ivar session_key: Only returned when AuthenticationType is HMAC. A symmetric encryption key + used to sign requests in the session using the Shared Key protocol. + :vartype session_key: str + """ + + _attribute_map = { + "session_token": {"key": "SessionToken", "type": "str"}, + "session_key": {"key": "SessionKey", "type": "str"}, + } + _xml_map = {"name": "Credentials"} + + def __init__( + self, *, session_token: Optional[str] = None, session_key: Optional[str] = None, **kwargs: Any + ) -> None: + """ + :keyword session_token: An opaque token used to authorize subsequent requests in the session. + Must be treated as a security credential. + :paramtype session_token: str + :keyword session_key: Only returned when AuthenticationType is HMAC. A symmetric encryption key + used to sign requests in the session using the Shared Key protocol. + :paramtype session_key: str + """ + super().__init__(**kwargs) + self.session_token = session_token + self.session_key = session_key + + class SignedIdentifier(_serialization.Model): """signed identifier. diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py index ec2deb0de1c0..b4212745c052 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_generated/operations/_container_operations.py @@ -888,6 +888,48 @@ def build_get_account_info_request( return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) +def build_create_session_request( + url: str, + *, + content: Any, + version: str, + timeout: Optional[int] = None, + request_id_parameter: Optional[str] = None, + **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + restype: Literal["container"] = kwargs.pop("restype", _params.pop("restype", "container")) + comp: Literal["session"] = kwargs.pop("comp", _params.pop("comp", "session")) + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/xml") + + # Construct URL + _url = kwargs.pop("template_url", "{url}") + path_format_arguments = { + "url": _SERIALIZER.url("url", url, "str", skip_quote=True), + } + + _url: str = _url.format(**path_format_arguments) # type: ignore + + # Construct parameters + _params["restype"] = _SERIALIZER.query("restype", restype, "str") + _params["comp"] = _SERIALIZER.query("comp", comp, "str") + if timeout is not None: + _params["timeout"] = _SERIALIZER.query("timeout", timeout, "int", minimum=0) + + # Construct headers + _headers["x-ms-version"] = _SERIALIZER.header("version", version, "str") + if request_id_parameter is not None: + _headers["x-ms-client-request-id"] = _SERIALIZER.header("request_id_parameter", request_id_parameter, "str") + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, content=content, **kwargs) + + class ContainerOperations: """ .. warning:: @@ -2696,3 +2738,82 @@ def get_account_info( # pylint: disable=inconsistent-return-statements if cls: return cls(pipeline_response, None, response_headers) # type: ignore + + @distributed_trace + def create_session( + self, + create_session_configuration: _models.CreateSessionConfiguration, + timeout: Optional[int] = None, + request_id_parameter: Optional[str] = None, + **kwargs: Any + ) -> _models.CreateSessionResponse: + """The Create Session operation enables users to create a session scoped to a container. + + :param create_session_configuration: Required. + :type create_session_configuration: ~azure.storage.blob.models.CreateSessionConfiguration + :param timeout: The timeout parameter is expressed in seconds. For more information, see + :code:`Setting + Timeouts for Blob Service Operations.`. Default value is None. + :type timeout: int + :param request_id_parameter: Provides a client-generated, opaque value with a 1 KB character + limit that is recorded in the analytics logs when storage analytics logging is enabled. Default + value is None. + :type request_id_parameter: str + :return: CreateSessionResponse or the result of cls(response) + :rtype: ~azure.storage.blob.models.CreateSessionResponse + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + restype: Literal["container"] = kwargs.pop("restype", _params.pop("restype", "container")) + comp: Literal["session"] = kwargs.pop("comp", _params.pop("comp", "session")) + content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/xml")) + cls: ClsType[_models.CreateSessionResponse] = kwargs.pop("cls", None) + + _content = self._serialize.body(create_session_configuration, "CreateSessionConfiguration", is_xml=True) + + _request = build_create_session_request( + url=self._config.url, + version=self._config.version, + timeout=timeout, + request_id_parameter=request_id_parameter, + restype=restype, + comp=comp, + content_type=content_type, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [201]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + error = self._deserialize.failsafe_deserialize( + _models.StorageError, + pipeline_response, + ) + raise HttpResponseError(response=response, model=error) + + deserialized = self._deserialize("CreateSessionResponse", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index fb62552c15b4..20398952eea7 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -55,10 +55,12 @@ StorageLoggingPolicy, StorageRequestHook, StorageResponseHook, + StorageSessionPolicy, ) from .request_handlers import serialize_batch_body, _get_batch_request_delimiter from .response_handlers import PartialBatchErrorException, process_storage_error from .shared_access_signature import QueryStringConstants +from .._generated import AzureBlobStorage from .._version import VERSION from .._shared_access_signature import _is_credential_sastoken @@ -327,11 +329,41 @@ def _create_pipeline( config.headers_policy, StorageRequestHook(**kwargs), self._credential_policy, - config.logging_policy, - StorageResponseHook(**kwargs), - DistributedTracingPolicy(**kwargs), - HttpLoggingPolicy(**kwargs), ] + use_session = bool(kwargs.pop("use_session", False)) + if use_session: + if not hasattr(credential, "get_token"): + raise ValueError( + "use_session=True requires a TokenCredential; received " + f"{type(credential).__name__ if credential is not None else 'None'}." + ) + + def _session_client_factory(container_url: str) -> AzureBlobStorage: + sub_kwargs = dict(kwargs) + sub_kwargs.pop("_configuration", None) + sub_kwargs.pop("pipeline", None) + sub_kwargs["transport"] = transport + + _, session_pipeline = self._create_pipeline(credential, **sub_kwargs) + generated = AzureBlobStorage( + container_url, self.api_version, base_url=container_url, pipeline=session_pipeline + ) + return generated + + policies.append( + StorageSessionPolicy( + account_name=self.account_name, + session_client_factory=_session_client_factory, + ) + ) + policies.extend( + [ + config.logging_policy, + StorageResponseHook(**kwargs), + DistributedTracingPolicy(**kwargs), + HttpLoggingPolicy(**kwargs), + ] + ) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 7169ac25464c..98e961a3fad0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -44,8 +44,10 @@ AsyncStorageBearerTokenCredentialPolicy, AsyncContentValidationPolicy, AsyncStorageResponseHook, + AsyncStorageSessionPolicy, ) from .response_handlers import PartialBatchErrorException, process_storage_error +from .._generated.aio import AzureBlobStorage from .._shared_access_signature import _is_credential_sastoken if TYPE_CHECKING: @@ -141,11 +143,41 @@ def _create_pipeline( config.headers_policy, StorageRequestHook(**kwargs), self._credential_policy, - config.logging_policy, - AsyncStorageResponseHook(**kwargs), - DistributedTracingPolicy(**kwargs), - HttpLoggingPolicy(**kwargs), ] + use_session = bool(kwargs.pop("use_session", False)) + if use_session: + if not hasattr(credential, "get_token"): + raise ValueError( + "use_session=True requires a TokenCredential; received " + f"{type(credential).__name__ if credential is not None else 'None'}." + ) + + def _session_client_factory(container_url: str) -> AzureBlobStorage: + sub_kwargs = dict(kwargs) + sub_kwargs.pop("_configuration", None) + sub_kwargs.pop("pipeline", None) + sub_kwargs["transport"] = transport + + _, session_pipeline = self._create_pipeline(credential, **sub_kwargs) + generated = AzureBlobStorage( + container_url, self.api_version, base_url=container_url, pipeline=session_pipeline + ) + return generated + + policies.append( + AsyncStorageSessionPolicy( + account_name=self.account_name, + session_client_factory=_session_client_factory, + ) + ) + policies.extend( + [ + config.logging_policy, + AsyncStorageResponseHook(**kwargs), + DistributedTracingPolicy(**kwargs), + HttpLoggingPolicy(**kwargs), + ] + ) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py index 23786baef24b..4d455df080e2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py @@ -34,6 +34,7 @@ class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): CONDITION_HEADERS_NOT_SUPPORTED = "ConditionHeadersNotSupported" CONDITION_NOT_MET = "ConditionNotMet" EMPTY_METADATA_KEY = "EmptyMetadataKey" + FEATURE_NOT_ENABLED = "FeatureNotEnabled" INSUFFICIENT_ACCOUNT_PERMISSIONS = "InsufficientAccountPermissions" INTERNAL_ERROR = "InternalError" INVALID_AUTHENTICATION_INFO = "InvalidAuthenticationInfo" @@ -64,6 +65,7 @@ class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESOURCE_ALREADY_EXISTS = "ResourceAlreadyExists" RESOURCE_NOT_FOUND = "ResourceNotFound" SERVER_BUSY = "ServerBusy" + SESSIONS_UNAVAILABLE = "SessionOperationsTemporarilyUnavailable" UNSUPPORTED_HEADER = "UnsupportedHeader" UNSUPPORTED_XML_NODE = "UnsupportedXmlNode" UNSUPPORTED_QUERY_PARAMETER = "UnsupportedQueryParameter" diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 3a5f0b9d662f..f7963b428f8d 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -9,11 +9,14 @@ import random import re import uuid +from datetime import datetime, timedelta, timezone from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from threading import Lock +from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, + unquote, urlencode, urlparse, urlunparse, @@ -30,7 +33,8 @@ SansIOHTTPPolicy, ) -from .authentication import AzureSigningError, StorageHttpChallenge +from . import sign_string +from .authentication import AzureSigningError, _storage_header_sort, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode from .streams import ( @@ -45,6 +49,7 @@ is_crc64_validation, is_md5_validation, ) +from .._generated.models import CreateSessionConfiguration if TYPE_CHECKING: from azure.core.credentials import TokenCredential @@ -55,12 +60,28 @@ _LOGGER = logging.getLogger(__name__) +_SIGNED_HEADERS = ( + "content-encoding", + "content-language", + "content-length", + "content-md5", + "content-type", + "date", + "if-modified-since", + "if-match", + "if-none-match", + "if-unmodified-since", + "byte_range", +) + CONTENT_LENGTH_HEADER = "Content-Length" MD5_HEADER = "Content-MD5" CRC64_HEADER = "x-ms-content-crc64" SM_HEADER = "x-ms-structured-body" SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" SM_LENGTH_HEADER = "x-ms-structured-content-length" +SESSION_RETRIED_CONTEXT_KEY = "_session_retried" +UTC = timezone.utc def encode_base64(data: Union[bytes, str]) -> str: @@ -70,6 +91,88 @@ def encode_base64(data: Union[bytes, str]) -> str: return encoded.decode("utf-8") +def _extract_session(response: Any) -> Tuple[str, str, datetime]: + creds = getattr(response, "credentials", None) + if not creds or not getattr(creds, "session_token", None) or not getattr(creds, "session_key", None): + raise ValueError("CreateSession response missing SessionToken/SessionKey") + session_token: str = creds.session_token + session_key: str = creds.session_key + expires_at = getattr(response, "expiration", None) + if expires_at is None: + expires_at = datetime.now(UTC) + timedelta(minutes=5) + elif expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + return session_token, session_key, expires_at + + +def _analyze_request(request: "PipelineRequest") -> Optional[Tuple[str, str]]: + http_request = request.http_request + if http_request.method != "GET": + return None + parsed = urlparse(http_request.url) + segments = [seg for seg in parsed.path.split("/") if seg] + if len(segments) < 2: + return None + query = http_request.query + if "comp" in query or query.get("restype") == "container": + return None + container_name = segments[0] + container_url = f"{parsed.scheme}://{parsed.netloc}/{container_name}" + return container_name, container_url + + +def _used_session_token(request: "PipelineRequest") -> Optional[str]: + """Use for distinguishing between a successful concurrent refresh and invalidated token.""" + auth = request.http_request.headers.get("Authorization", "") + if not auth.startswith("Session "): + return None + return auth[len("Session ") :].split(":", 1)[0] or None + + +def _apply_session_auth( + request: "PipelineRequest", session_token: str, session_key: str, account_name: str +) -> None: + """Sign an eligible request with the SharedKey protocol under the Session scheme. + + Shared by the sync and async session policies; ``account_name`` is passed in + rather than read from ``self`` so neither policy needs to instantiate the other. + + :param ~azure.core.pipeline.PipelineRequest request: The request to sign in place. + :param str session_token: The session token to embed in the Authorization header. + :param str session_key: The HMAC signing key for the session. + :param str account_name: Storage account name; the signer identity. + :raises ~azure.storage.blob._shared.authentication.AzureSigningError: if signing fails. + """ + http_request = request.http_request + http_request.headers["x-ms-date"] = format_date_time(time()) + + # 1) Standard headers. Storage omits content-length when it is "0". + headers = {name.lower(): value for name, value in http_request.headers.items() if value} + if headers.get("content-length") == "0": + del headers["content-length"] + signed_headers = "\n".join(headers.get(h, "") for h in _SIGNED_HEADERS) + "\n" + + # 2) Canonicalized x-ms-* headers, sorted by the service-emulating comparator. + x_ms_headers = _storage_header_sort( + [(n.lower(), v) for n, v in http_request.headers.items() if n.lower().startswith("x-ms-")] + ) + canonicalized_headers = "".join(f"{n}:{v}\n" for n, v in x_ms_headers if v is not None) + + # 3) Canonicalized resource + query (query values must be url-decoded). + canonicalized_resource = "/" + account_name + urlparse(http_request.url).path + canonicalized_resource += "".join( + f"\n{n.lower()}:{unquote(v)}" for n, v in sorted(http_request.query.items()) if v is not None + ) + + string_to_sign = http_request.method + "\n" + signed_headers + canonicalized_headers + canonicalized_resource + + try: + signature = sign_string(session_key, string_to_sign) + except Exception as ex: # pylint: disable=broad-except + raise AzureSigningError(str(ex)) from ex + http_request.headers["Authorization"] = f"Session {session_token}:{signature}" + + # Are we out of retries? def is_exhausted(settings): retry_counts = ( @@ -843,3 +946,267 @@ def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") self.authorize_request(request, scope, tenant_id=challenge.tenant_id) return True + + +class Session: + """A session entry.""" + + __slots__ = ("session_token", "session_key", "expires_at", "is_fallback") + + REFRESH_BUFFER: timedelta = timedelta(seconds=30) + """Buffer before proactive refresh is initiated.""" + + def __init__( + self, + session_token: Optional[str], + session_key: Optional[str], + expires_at: datetime, + is_fallback: bool = False, + ) -> None: + self.session_token = session_token + self.session_key = session_key + self.expires_at = expires_at + self.is_fallback = is_fallback + + def expired(self, now: Optional[datetime] = None) -> bool: + now = now if now is not None else datetime.now(UTC) + diff = timedelta(seconds=0) if self.is_fallback else Session.REFRESH_BUFFER + return now >= self.expires_at - diff + + +class SessionCache: + """Thread-safe, container-level session cache for the sync stack. + + Concurrency model + ----------------- + * Reads (`get`) are lock-free. They perform a single dict.get and never + mutate the cache, so concurrent readers never need to coordinate. + * Writes (`put` / `put_fallback`) and the CreateSession single-flight are + serialized per-container via the lock returned by :meth:`lock_container`. + * A single _locks_guard serializes only the *creation* of per-container + locks, so two threads racing on a brand-new container can't build two + different lock objects. + """ + + FALLBACK_COOLDOWN: timedelta = timedelta(minutes=5) + """Cooldown applied to the fallback-to-bearer sentinel after an eligible create session failure.""" + + def __init__(self) -> None: + self._locks: Dict[str, Lock] = {} + self._locks_guard: Lock = Lock() + self._entry: Dict[str, Session] = {} + + def lock_container(self, container_name: str) -> Lock: + """Return the per-container lock, creating it exactly once. + + :param str container_name: The container to get the lock for. + :return: The single lock instance associated with the container. + :rtype: ~threading.Lock + """ + # Easy path: lock already exists, and on free threads it falls to slow path + existing_lock = self._locks.get(container_name) + if existing_lock is not None: + return existing_lock + # Slow path: create exactly one lock per container + with self._locks_guard: + return self._locks.setdefault(container_name, Lock()) + + def get(self, container_name: str) -> Optional[Session]: + """Return a live session for the container, or None. + + Lock-free and non-mutating. Expired entries are NOT deleted. + Instead, they are simply treated as a cache miss and overwritten on the next refresh. + + :param str container_name: The container to look up. + :return: A live (non-expired) session, or None on miss/expiry. + :rtype: ~azure.storage.blob._shared.policies.Session or None + """ + cached = self._entry.get(container_name, None) + if cached is None or cached.expired(): + return None + return cached + + def put(self, container_name: str, session_token: str, session_key: str, expires_at: datetime) -> None: + """Install a real session entry. + + Caller must hold the lock at the container-level. + + :param str container_name: The container the session belongs to. + :param str session_token: The session token to send as a header. + :param str session_key: The HMAC signing key for the session. + :param ~datetime.datetime expires_at: When the session expires. + """ + self._entry[container_name] = Session(session_token, session_key, expires_at, is_fallback=False) + + def put_fallback(self, container_name: str) -> None: + """Install a fallback-to-bearer sentinel for the cooldown window. + + Caller must hold the lock at the container-level. + + :param str container_name: The container to mark for bearer fallback. + """ + self._entry[container_name] = Session( + None, None, datetime.now(UTC) + SessionCache.FALLBACK_COOLDOWN, is_fallback=True + ) + + def invalidate(self, container_name: str, session_token: Optional[str] = None) -> None: + cached = self._entry.get(container_name, None) + if cached is not None and cached.session_token == session_token: + self._entry.pop(container_name, None) + + +class StorageSessionPolicy(HTTPPolicy): + """ + A pipeline policy that selects between session token and bearer token authentication. + + When enabled, eligible requests are authenticated with a session token. + The session token is cached to the container. + + When disabled, all requests are delegated to the bearer token policy. + """ + + + def __init__( + self, + *, + account_name: str, + session_client_factory: Callable[[str], Any], + ) -> None: + """Constructs a StorageSessionPolicy. + + :keyword str account_name: Storage account name; used as the signer + identity when signing session-authenticated requests. + :keyword session_client_factory: A callable that, given a container URL, + returns a session-disabled generated client (AzureBlobStorage) + whose pipeline uses OAuth/bearer auth. Invoked to issue CreateSession. + :paramtype session_client_factory: Callable[[str], Any] + :raises ValueError: if `account_name` or `session_client_factory` is `None`. + """ + if account_name is None or session_client_factory is None: + raise ValueError("account_name and session_client_factory are required.") + super().__init__() + self._account_name = account_name + self._session_client_factory = session_client_factory + self._enabled = True + self._cache = SessionCache() + + + def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: + config = CreateSessionConfiguration(authentication_type="HMAC") + client = self._session_client_factory(container_url) + response = client.container.create_session(create_session_configuration=config) + return _extract_session(response) + + def _refresh_session_token(self, container_name: str, container_url: str) -> Optional[Session]: + """Acquire (or re-use) a session for the container under per-container single-flight. + + :param str container_name: The container key for the cache and lock. + :param str container_url: The container-scoped URL for the CreateSession call. + :return: A live session, a fallback sentinel, or `None` if unusable. + :rtype: ~azure.storage.blob._shared.policies.Session or None + """ + with self._cache.lock_container(container_name): + existing = self._cache.get(container_name) + if existing is not None and not existing.expired(): + return existing + try: + token, key, expires_at = self._create_session(container_url) + self._cache.put(container_name, token, key, expires_at) + except (AzureError, ValueError): + _LOGGER.warning( + "CreateSession failed for container '%s'; falling back to bearer for %d seconds.", + container_name, + int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), + exc_info=True, + ) + self._cache.put_fallback(container_name) + return self._cache.get(container_name) + + def send(self, request: "PipelineRequest") -> "PipelineResponse": + """Orchestrate session auth. + + :param ~azure.core.pipeline.PipelineRequest request: The outgoing request. + :return: The pipeline response. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + container_name = self.on_request(request) + response = self.next.send(request) + return self.on_response(request, response, container_name) + + def on_request(self, request: "PipelineRequest") -> Optional[str]: + """Stamp session auth if eligible, otherwise leave the bearer header intact. + + :param ~azure.core.pipeline.PipelineRequest request: The request to (maybe) sign. + :return: The container name if a session was applied, else None. + :rtype: str or None + """ + if not self._enabled: + return None + analysis = _analyze_request(request) + if analysis is None: + return None + container_name, container_url = analysis + + session = self._cache.get(container_name) + if session is None: + # True miss/expiry (a live fallback sentinel is returned by get(), + # so we never reach refresh while the cooldown is active). + session = self._refresh_session_token(container_name, container_url) + + if session is None or session.is_fallback or not session.session_token or not session.session_key: + return None + + _apply_session_auth(request, session.session_token, session.session_key, self._account_name) + return container_name + + def on_response( + self, + request: "PipelineRequest", + response: "PipelineResponse", + container_name: Optional[str], + ) -> "PipelineResponse": + """React to session-related failures: cooldown sentinel or one-shot re-acquire. + + :param ~azure.core.pipeline.PipelineRequest request: The original request. + :param ~azure.core.pipeline.PipelineResponse response: The response to inspect. + :param container_name: Container that was session-signed, or `None` if bearer was used. + :type container_name: str or None + :return: The final response (possibly from a one-shot retry). + :rtype: ~azure.core.pipeline.PipelineResponse + """ + if container_name is None: + return response # bearer was used; nothing session-related to react to + + status = response.http_response.status_code + error_code = response.http_response.headers.get("x-ms-error-code", "") + + if error_code == StorageErrorCode.FEATURE_NOT_ENABLED: + _LOGGER.info("Session feature not enabled on this account; disabling session auth.") + self._enabled = False + return response + + # Unavailable / 5xx → negative-cache cooldown. + if error_code == StorageErrorCode.SESSIONS_UNAVAILABLE or status >= 500: + _LOGGER.warning( + "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", + error_code or "5xx", + status, + container_name, + int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), + ) + with self._cache.lock_container(container_name): + self._cache.put_fallback(container_name) + return response + + # 401 → invalidate + re-acquire ONCE, then resend. + if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): + _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) + used_token = _used_session_token(request) + with self._cache.lock_container(container_name): + self._cache.invalidate(container_name, used_token) + request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True + retried_container = self.on_request(request) + retried_response = self.next.send(request) + return self.on_response(request, retried_response, retried_container) + + return response diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index e1d13b1a83fa..7d825ccb7928 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -8,7 +8,8 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging import random -from typing import Any, Dict, TYPE_CHECKING +from datetime import datetime +from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError from azure.core.pipeline.policies import ( @@ -18,12 +19,21 @@ from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE +from .models import StorageErrorCode from .policies import ( + _analyze_request, + _apply_session_auth, + _extract_session, _prepare_content_validation, + _used_session_token, _validate_content_response, + CreateSessionConfiguration, encode_base64, is_retry, + Session, + SessionCache, StorageRetryPolicy, + SESSION_RETRIED_CONTEXT_KEY, ) from .streams_async import AsyncStructuredMessageDecoder from .validation import ( @@ -222,7 +232,7 @@ def __init__( retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs + **kwargs, ) -> None: """ Constructs an Exponential retry object. The initial_backoff is used for @@ -282,7 +292,7 @@ def __init__( retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs a Linear retry object. @@ -338,3 +348,185 @@ async def on_challenge(self, request: "PipelineRequest", response: "PipelineResp await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) return True + + +class AsyncSessionCache(SessionCache): + """Async variant of :class:`SessionCache`. + + Reuses the lock-free, non-mutating read path and the immutable + Session snapshots from the sync cache, but serializes the + per-container CreateSession single-flight with asynchronous locks. + """ + + def __init__(self) -> None: + super().__init__() + self._async_locks: Dict[str, asyncio.Lock] = {} + + def lock_container_async(self, container_name: str) -> asyncio.Lock: + """Return the per-container asyncio lock, creating it exactly once. + + Lock creation is not awaited, so it is safe to do without holding a + lock: the event loop guarantees this runs to completion without + interleaving other coroutines. + + :param str container_name: The container to get the lock for. + :return: The single asyncio lock associated with the container. + :rtype: ~asyncio.Lock + """ + existing = self._async_locks.get(container_name) + if existing is not None: + return existing + return self._async_locks.setdefault(container_name, asyncio.Lock()) + + +class AsyncStorageSessionPolicy(AsyncHTTPPolicy): + """Constructs an AsyncStorageSessionPolicy. + + Eligible blob download GETs are authenticated with a per-container session + token; everything else is delegated unchanged to the bearer credential + policy that sits earlier in the pipeline. + """ + + def __init__( + self, + *, + account_name: str, + session_client_factory: Callable[[str], Any], + ) -> None: + """Constructs an AsyncStorageSessionPolicy. + + :keyword str account_name: Storage account name; used as the signer + identity when signing session-authenticated requests. + :keyword session_client_factory: A callable that, given a container URL, + returns a session-disabled generated async client whose pipeline + uses OAuth/bearer auth. Invoked (and awaited) to issue CreateSession. + :paramtype session_client_factory: Callable[[str], Any] + :raises ValueError: if account_name or session_client_factory is None. + """ + if account_name is None or session_client_factory is None: + raise ValueError("account_name and session_client_factory are required.") + super().__init__() + self._account_name = account_name + self._session_client_factory = session_client_factory + self._enabled = True + self._cache = AsyncSessionCache() + + async def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: + config = CreateSessionConfiguration(authentication_type="HMAC") + client = self._session_client_factory(container_url) + response = await client.container.create_session(create_session_configuration=config) + return _extract_session(response) + + async def _refresh_session_token(self, container_name: str, container_url: str) -> Optional[Session]: + """Acquire (or re-use) a session under per-container async single-flight. + + :param str container_name: The container key for the cache and lock. + :param str container_url: The container-scoped URL for the CreateSession call. + :return: A live session, a fallback sentinel, or None if unusable. + :rtype: ~azure.storage.blob._shared.policies.Session or None + """ + async with self._cache.lock_container_async(container_name): + existing = self._cache.get(container_name) + if existing is not None and not existing.expired(): + return existing + try: + token, key, expires_at = await self._create_session(container_url) + self._cache.put(container_name, token, key, expires_at) + except (AzureError, ValueError): + _LOGGER.warning( + "CreateSession failed for container '%s'; falling back to bearer for %d seconds.", + container_name, + int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), + exc_info=True, + ) + self._cache.put_fallback(container_name) + return self._cache.get(container_name) + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + """Orchestrate session auth. + + :param ~azure.core.pipeline.PipelineRequest request: The outgoing request. + :return: The pipeline response. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + container_name = await self.on_request(request) + response = await self.next.send(request) + return await self.on_response(request, response, container_name) + + async def on_request(self, request: "PipelineRequest") -> Optional[str]: + """Stamp session auth if eligible, otherwise leave the bearer header intact. + + :param ~azure.core.pipeline.PipelineRequest request: The request to (maybe) sign. + :return: The container name if a session was applied, else None. + :rtype: str or None + """ + if not self._enabled: + return None + analysis = _analyze_request(request) + if analysis is None: + return None + container_name, container_url = analysis + + session = self._cache.get(container_name) + if session is None: + # True miss/expiry (a live fallback sentinel is returned by get(), + # so we never reach refresh while the cooldown is active). + session = await self._refresh_session_token(container_name, container_url) + + if session is None or session.is_fallback or not session.session_token or not session.session_key: + return None + + _apply_session_auth(request, session.session_token, session.session_key, self._account_name) + return container_name + + async def on_response( + self, + request: "PipelineRequest", + response: "PipelineResponse", + container_name: Optional[str], + ) -> "PipelineResponse": + """React to session-related failures: cooldown sentinel or one-shot re-acquire. + + :param ~azure.core.pipeline.PipelineRequest request: The original request. + :param ~azure.core.pipeline.PipelineResponse response: The response to inspect. + :param container_name: Container that was session-signed, or None if bearer was used. + :type container_name: str or None + :return: The final response (possibly from a one-shot retry). + :rtype: ~azure.core.pipeline.PipelineResponse + """ + if container_name is None: + return response # bearer was used; nothing session-related to react to + + status = response.http_response.status_code + error_code = response.http_response.headers.get("x-ms-error-code", "") + + if error_code == StorageErrorCode.FEATURE_NOT_ENABLED: + _LOGGER.info("Session feature not enabled on this account; disabling session auth.") + self._enabled = False + return response + + # Unavailable / 5xx → negative-cache cooldown. + if error_code == StorageErrorCode.SESSIONS_UNAVAILABLE or status >= 500: + _LOGGER.warning( + "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", + error_code or "5xx", + status, + container_name, + int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), + ) + async with self._cache.lock_container_async(container_name): + self._cache.put_fallback(container_name) + return response + + # 401 → invalidate + re-acquire ONCE, then resend. + if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): + _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) + used_token = _used_session_token(request) + async with self._cache.lock_container_async(container_name): + self._cache.invalidate(container_name, used_token) + request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True + retried_container = await self.on_request(request) + retried_response = await self.next.send(request) + return await self.on_response(request, retried_response, retried_container) + + return response diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index 44c63abee2e0..75eee2562419 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -157,6 +157,11 @@ class BlobClient( # type: ignore [misc] # pylint: disable=too-many-public-metho :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. + :keyword bool use_session: If True, enable session-based authentication for this container. + When enabled, eligible GET requests issued by this client will be authenticated using + a short-lived session credential obtained from the service instead + of the provided AsyncTokenCredential. Only supported with an AsyncTokenCredential; + ValueError is raised otherwise. Defaults to False. .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi index 71bdd6c7410c..c4a9179a4726 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi @@ -81,6 +81,7 @@ class BlobClient( # type: ignore[misc] max_single_get_size: int = 32 * 1024 * 1024, min_large_block_upload_threshold: int = 4 * 1024 * 1024 + 1, use_byte_buffer: Optional[bool] = None, + use_session: bool = False, **kwargs: Any ) -> None: ... async def __aenter__(self) -> Self: ... diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index 97927c5367cf..78322ed6f791 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -112,6 +112,11 @@ class ContainerClient( # type: ignore [misc] # pylint: disable=too-many-public :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://.blob.core.windows.net. + :keyword bool use_session: If True, enable session-based authentication for this container. + When enabled, eligible GET requests issued by this client will be authenticated using + a short-lived session credential obtained from the service instead + of the provided AsyncTokenCredential. Only supported with an AsyncTokenCredential; + ValueError is raised otherwise. Defaults to False. .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi index 07cca4ad5f39..507d885b1721 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi @@ -79,6 +79,7 @@ class ContainerClient( # type: ignore[misc] max_single_get_size: int = 32 * 1024 * 1024, min_large_block_upload_threshold: int = 4 * 1024 * 1024 + 1, use_byte_buffer: Optional[bool] = None, + use_session: bool = False, **kwargs: Any ) -> None: ... async def __aenter__(self) -> Self: ... diff --git a/sdk/storage/azure-storage-blob/tests/conftest.py b/sdk/storage/azure-storage-blob/tests/conftest.py index 01daca17a5c9..be6bb2be485a 100644 --- a/sdk/storage/azure-storage-blob/tests/conftest.py +++ b/sdk/storage/azure-storage-blob/tests/conftest.py @@ -29,6 +29,11 @@ def add_sanitizers(test_proxy): add_header_regex_sanitizer(key="x-ms-copy-source-authorization", value="Sanitized") add_header_regex_sanitizer(key="x-ms-encryption-key", value="Sanitized") + add_header_regex_sanitizer(key="x-ms-session-token", value="Sanitized") + add_general_regex_sanitizer( + regex=r"[^<]*", value="Sanitized" + ) + add_general_regex_sanitizer(regex=r"[^<]*", value="U2FuaXRpemVk") add_general_regex_sanitizer(regex=r'"EncryptionLibrary": "Python .*?"', value='"EncryptionLibrary": "Python x.x.x"') add_uri_regex_sanitizer(regex=r"\.preprod\.", value=".") diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index 519bd4de862b..a547f93f9d24 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -14,6 +14,7 @@ from devtools_testutils import recorded_by_proxy, set_custom_default_matcher from devtools_testutils.storage import StorageRecordedTestCase from settings.testcase import BlobPreparer +from test_helpers import CaptureAuthHeader, _find_session_policy, _parse_session_token from azure.core import MatchConditions from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceModifiedError, ResourceNotFoundError @@ -2744,3 +2745,115 @@ def recursive_walk(prefix): # Assert assert blobs is not None assert blobs == ["a/b/blob2", "a/b/blob3", "a/b/blob4", "a/blob1"] + + @BlobPreparer() + @recorded_by_proxy + def test_create_session(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + credential = self.get_credential(BlobServiceClient) + capture_auth_header = CaptureAuthHeader() + + service = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + credential=credential, + use_session=True, + ) + container1_name = self.get_resource_name("utcontainer1") + container1 = service.get_container_client(container1_name) + try: + container1.create_container() + except ResourceExistsError: + pass + + blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" + container1.upload_blob( + blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header.hook("c1_upload") + ) + assert capture_auth_header["c1_upload"].startswith("Bearer ") + + blob1_actual = container1.download_blob( + blob1_name, raw_response_hook=capture_auth_header.hook("c1_download") + ).readall() + assert blob1_data == blob1_actual + assert capture_auth_header["c1_download"].startswith("Session ") + session1 = _parse_session_token(capture_auth_header["c1_download"]) + + container2_name = self.get_resource_name("utcontainer2") + container2 = service.get_container_client(container2_name) + try: + container2.create_container() + except ResourceExistsError: + pass + + blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" + container2.upload_blob( + blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header.hook("c2_upload") + ) + assert capture_auth_header["c2_upload"].startswith("Bearer ") + + blob2_actual = container2.download_blob( + blob2_name, raw_response_hook=capture_auth_header.hook("c2_download") + ).readall() + assert blob2_data == blob2_actual + assert capture_auth_header["c2_download"].startswith("Session ") + session2 = _parse_session_token(capture_auth_header["c2_download"]) + + assert session1 != session2 + + blob1_actual = container1.download_blob( + blob1_name, raw_response_hook=capture_auth_header.hook("c1_download2") + ).readall() + assert blob1_data == blob1_actual + assert capture_auth_header["c1_download2"].startswith("Session ") + assert session1 == _parse_session_token(capture_auth_header["c1_download2"]) + + blob2_actual = container2.download_blob( + blob2_name, raw_response_hook=capture_auth_header.hook("c2_download2") + ).readall() + assert blob2_data == blob2_actual + assert capture_auth_header["c2_download2"].startswith("Session ") + assert session2 == _parse_session_token(capture_auth_header["c2_download2"]) + + policy = _find_session_policy(service._pipeline) + cached = policy._cache._entry[container1_name] + cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) + + blob1_actual = container1.download_blob( + blob1_name, raw_response_hook=capture_auth_header.hook("c1_download3") + ).readall() + assert blob1_data == blob1_actual + assert capture_auth_header["c1_download3"].startswith("Session ") + assert session1 != _parse_session_token(capture_auth_header["c1_download3"]) + assert session2 != _parse_session_token(capture_auth_header["c1_download3"]) + + @BlobPreparer() + @recorded_by_proxy + def test_sessions_disabled(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + credential = self.get_credential(BlobServiceClient) + capture_auth_header = CaptureAuthHeader() + + service = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + credential=credential, + use_session=False, + ) + container = service.get_container_client(self.get_resource_name("utcontainer")) + try: + container.create_container() + except ResourceExistsError: + pass + + blob_name, blob_data = self.get_resource_name("blob"), b"abc123" + container.upload_blob( + blob_name, blob_data, overwrite=True, raw_response_hook=capture_auth_header.hook("upload") + ) + assert capture_auth_header["upload"].startswith("Bearer ") + + blob_actual = container.download_blob( + blob_name, raw_response_hook=capture_auth_header.hook("download") + ).readall() + assert blob_data == blob_actual + assert capture_auth_header["download"].startswith("Bearer ") diff --git a/sdk/storage/azure-storage-blob/tests/test_container_async.py b/sdk/storage/azure-storage-blob/tests/test_container_async.py index 0a83f825b08e..aecc3876298f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_container_async.py @@ -16,6 +16,7 @@ from devtools_testutils.storage import LogCaptured from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase from settings.testcase import BlobPreparer +from test_helpers import CaptureAuthHeader, _find_session_policy, _parse_session_token from azure.core import MatchConditions from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceModifiedError, ResourceNotFoundError @@ -2653,3 +2654,113 @@ async def recursive_walk(prefix): # Assert assert blobs is not None assert blobs == ["a/b/blob2", "a/b/blob3", "a/b/blob4", "a/blob1"] + + @BlobPreparer() + @recorded_by_proxy_async + async def test_create_session(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + credential = self.get_credential(BlobServiceClient, is_async=True) + capture_auth_header = CaptureAuthHeader() + + service = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + credential=credential, + use_session=True, + ) + container1_name = self.get_resource_name("utcontainer1") + container1 = service.get_container_client(container1_name) + try: + await container1.create_container() + except ResourceExistsError: + pass + + blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" + await container1.upload_blob( + blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header.hook("c1_upload") + ) + assert capture_auth_header["c1_upload"].startswith("Bearer ") + + blob1_actual = await ( + await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header.hook("c1_download")) + ).readall() + assert blob1_data == blob1_actual + assert capture_auth_header["c1_download"].startswith("Session ") + session1 = _parse_session_token(capture_auth_header["c1_download"]) + + container2_name = self.get_resource_name("utcontainer2") + container2 = service.get_container_client(container2_name) + try: + await container2.create_container() + except ResourceExistsError: + pass + + blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" + await container2.upload_blob( + blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header.hook("c2_upload") + ) + assert capture_auth_header["c2_upload"].startswith("Bearer ") + + blob2_actual = await ( + await container2.download_blob(blob2_name, raw_response_hook=capture_auth_header.hook("c2_download")) + ).readall() + assert blob2_data == blob2_actual + assert capture_auth_header["c2_download"].startswith("Session ") + session2 = _parse_session_token(capture_auth_header["c2_download"]) + + assert session1 != session2 + + blob1_actual = await ( + await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header.hook("c1_download2")) + ).readall() + assert blob1_data == blob1_actual + assert capture_auth_header["c1_download2"].startswith("Session ") + assert session1 == _parse_session_token(capture_auth_header["c1_download2"]) + + blob2_actual = await ( + await container2.download_blob(blob2_name, raw_response_hook=capture_auth_header.hook("c2_download2")) + ).readall() + assert blob2_data == blob2_actual + assert capture_auth_header["c2_download2"].startswith("Session ") + assert session2 == _parse_session_token(capture_auth_header["c2_download2"]) + + policy = _find_session_policy(service._pipeline, "AsyncStorageSessionPolicy") + cached = policy._cache._entry[container1_name] + cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) + + blob1_actual = await ( + await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header.hook("c1_download3")) + ).readall() + assert blob1_data == blob1_actual + assert capture_auth_header["c1_download3"].startswith("Session ") + assert session1 != _parse_session_token(capture_auth_header["c1_download3"]) + assert session2 != _parse_session_token(capture_auth_header["c1_download3"]) + + @BlobPreparer() + @recorded_by_proxy_async + async def test_sessions_disabled(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + credential = self.get_credential(BlobServiceClient, is_async=True) + capture = CaptureAuthHeader() + + service = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + credential=credential, + use_session=False, + ) + container = service.get_container_client(self.get_resource_name("utcontainer")) + try: + await container.create_container() + except ResourceExistsError: + pass + + blob_name, blob_data = self.get_resource_name("blob"), b"abc123" + await container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=capture.hook("upload")) + assert capture["upload"].startswith("Bearer ") + + blob_actual = await ( + await container.download_blob(blob_name, raw_response_hook=capture.hook("download")) + ).readall() + assert blob_data == blob_actual + assert capture["download"].startswith("Bearer ") diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers.py b/sdk/storage/azure-storage-blob/tests/test_helpers.py index c1ec6f3cb4af..d97a7602ac7c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers.py @@ -58,6 +58,63 @@ def _create_file_share_oauth( return file_name, base_url +def _parse_session_token(auth: str) -> str: + """Extract the token from a "Session {token}:{signature}" Authorization header. + + :param str auth: The raw Authorization header value. + :return: The session token portion (before the ':'). + :rtype: str + """ + assert auth.startswith("Session ") + return auth[len("Session ") :].split(":", 1)[0] + + +def _find_session_policy(pipeline: Any, policy_name: str = "StorageSessionPolicy") -> Any: + """Return the session policy instance on a client pipeline, matched by class name. + + Matching by name avoids importing SDK internals into the test modules. + + :param pipeline: The client pipeline to search (e.g. ``client._pipeline``). + :type pipeline: Any + :param str policy_name: The policy class name to find. Use "StorageSessionPolicy" + for the sync stack and "AsyncStorageSessionPolicy" for the async stack. + :return: The matching policy instance. + :rtype: Any + """ + for policy in getattr(pipeline, "_impl_policies", []): + if type(policy).__name__ == policy_name: + return policy + raise AssertionError(f"{policy_name} not found on the pipeline") + + +class CaptureAuthHeader: + """Captures per-label Authorization headers via ``raw_response_hook`` callbacks. + + Encapsulates the captured-headers dict so the hook factory doesn't need a + closure over a test-local variable. Works for both sync and async clients, + since the response hook is invoked as a plain callable in both stacks. + """ + + def __init__(self) -> None: + self.captured: Dict[str, str] = {} + + def hook(self, label: str): + """Return a ``raw_response_hook`` that records the request's Authorization header. + + :param str label: The key under which to store the captured header. + :return: A callable suitable for ``raw_response_hook``. + :rtype: callable + """ + + def _hook(response): + self.captured[label] = response.http_request.headers.get("Authorization", "") + + return _hook + + def __getitem__(self, label: str) -> str: + return self.captured[label] + + class ProgressTracker: def __init__(self, total: int, step: int): self.total = total