# Async method of ContainerOperations (the class header lies outside this hunk).
# Generated operation code: the statement order of the kwargs/_params/_headers
# pops below is significant and is intentionally left untouched.
@distributed_trace_async
async def create_session(
    self,
    create_session_configuration: _models.CreateSessionConfiguration,
    timeout: Optional[int] = None,
    request_id_parameter: Optional[str] = None,
    **kwargs: Any
) -> _models.CreateSessionResponse:
    """The Create Session operation enables users to create a session scoped to a container.

    Serializes ``create_session_configuration`` as an XML body, sends it through the
    client's pipeline, and deserializes the response body as ``CreateSessionResponse``.
    Any status code other than 201 raises.

    :param create_session_configuration: Required.
    :type create_session_configuration: ~azure.storage.blob.models.CreateSessionConfiguration
    :param timeout: The timeout parameter is expressed in seconds. For more information, see
     "Setting Timeouts for Blob Service Operations" in the service documentation. Default value
     is None.
    :type timeout: int
    :param request_id_parameter: Provides a client-generated, opaque value with a 1 KB character
     limit that is recorded in the analytics logs when storage analytics logging is enabled. Default
     value is None.
    :type request_id_parameter: str
    :return: CreateSessionResponse or the result of cls(response)
    :rtype: ~azure.storage.blob.models.CreateSessionResponse
    :raises ~azure.core.exceptions.HttpResponseError:
    """
    # Default status-code -> exception mapping; a caller-supplied "error_map"
    # kwarg is merged on top and therefore overrides these entries.
    error_map: MutableMapping = {
        401: ClientAuthenticationError,
        404: ResourceNotFoundError,
        409: ResourceExistsError,
        304: ResourceNotModifiedError,
    }
    error_map.update(kwargs.pop("error_map", {}) or {})

    _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    _params = case_insensitive_dict(kwargs.pop("params", {}) or {})

    # Constant query/header values. An explicit kwarg wins, then a value already
    # present in params/headers, then the literal default.
    restype: Literal["container"] = kwargs.pop("restype", _params.pop("restype", "container"))
    comp: Literal["session"] = kwargs.pop("comp", _params.pop("comp", "session"))
    content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/xml"))
    # Optional callback; when set, its return value replaces the deserialized model.
    cls: ClsType[_models.CreateSessionResponse] = kwargs.pop("cls", None)

    # The request body is the XML serialization of the configuration model.
    _content = self._serialize.body(create_session_configuration, "CreateSessionConfiguration", is_xml=True)

    _request = build_create_session_request(
        url=self._config.url,
        version=self._config.version,
        timeout=timeout,
        request_id_parameter=request_id_parameter,
        restype=restype,
        comp=comp,
        content_type=content_type,
        content=_content,
        headers=_headers,
        params=_params,
    )
    _request.url = self._client.format_url(_request.url)

    # The response is not streamed: the body is read so it can be deserialized below.
    _stream = False
    pipeline_response: PipelineResponse = await self._client._pipeline.run(  # pylint: disable=protected-access
        _request, stream=_stream, **kwargs
    )

    response = pipeline_response.http_response

    # 201 is the only status treated as success.
    if response.status_code not in [201]:
        # NOTE(review): map_error presumably raises the mapped exception when the
        # status code has an entry in error_map; otherwise execution falls through
        # to the generic HttpResponseError below — confirm against azure-core.
        map_error(status_code=response.status_code, response=response, error_map=error_map)
        error = self._deserialize.failsafe_deserialize(
            _models.StorageError,
            pipeline_response,
        )
        raise HttpResponseError(response=response, model=error)

    deserialized = self._deserialize("CreateSessionResponse", pipeline_response.http_response)

    if cls:
        # No response headers are surfaced for this operation; an empty dict is passed.
        return cls(pipeline_response, deserialized, {})  # type: ignore

    return deserialized  # type: ignore
# Generated enum (from _azure_blob_storage_enums.py in this patch).
class AuthenticationType(str, Enum, metaclass=CaseInsensitiveEnumMeta):
    """The type of authentication required to create the session. The only type currently supported is
    HMAC.
    """

    HMAC = "HMAC"


# Generated models (from _models_py3.py in this patch). The _attribute_map and
# _xml_map tables define the wire names used by the serializer and must not be
# edited by hand.
class CreateSessionConfiguration(_serialization.Model):
    """Request body for the Create Session operation.

    All required parameters must be populated in order to send to server.

    :ivar authentication_type: The type of authentication required to create the session. The only
     type currently supported is HMAC. Required. "HMAC"
    :vartype authentication_type: str or ~azure.storage.blob.models.AuthenticationType
    """

    _validation = {
        "authentication_type": {"required": True},
    }

    _attribute_map = {
        "authentication_type": {"key": "AuthenticationType", "type": "str"},
    }
    # The XML root element is "CreateSessionRequest", which differs from the class name.
    _xml_map = {"name": "CreateSessionRequest"}

    def __init__(self, *, authentication_type: Union[str, "_models.AuthenticationType"], **kwargs: Any) -> None:
        """
        :keyword authentication_type: The type of authentication required to create the session. The
         only type currently supported is HMAC. Required. "HMAC"
        :paramtype authentication_type: str or ~azure.storage.blob.models.AuthenticationType
        """
        super().__init__(**kwargs)
        self.authentication_type = authentication_type


class CreateSessionResponse(_serialization.Model):
    """Response body of the Create Session operation.

    :ivar id: A unique identifier for the created session.
    :vartype id: str
    :ivar expiration: The time when the session will expire. The format follows RFC 1123.
    :vartype expiration: ~datetime.datetime
    :ivar authentication_type: The type of authentication required to create the session. The only
     type currently supported is HMAC. "HMAC"
    :vartype authentication_type: str or ~azure.storage.blob.models.AuthenticationType
    :ivar credentials: The credentials issued for the session (session token and session key).
    :vartype credentials: ~azure.storage.blob.models.SessionCredentials
    """

    _attribute_map = {
        "id": {"key": "Id", "type": "str"},
        "expiration": {"key": "Expiration", "type": "rfc-1123"},
        "authentication_type": {"key": "AuthenticationType", "type": "str"},
        "credentials": {"key": "Credentials", "type": "SessionCredentials"},
    }
    # The XML root element is "CreateSessionResult", which differs from the class name.
    _xml_map = {"name": "CreateSessionResult"}

    def __init__(
        self,
        *,
        id: Optional[str] = None,  # pylint: disable=redefined-builtin
        expiration: Optional[datetime.datetime] = None,
        authentication_type: Optional[Union[str, "_models.AuthenticationType"]] = None,
        credentials: Optional["_models.SessionCredentials"] = None,
        **kwargs: Any
    ) -> None:
        """
        :keyword id: A unique identifier for the created session.
        :paramtype id: str
        :keyword expiration: The time when the session will expire. The format follows RFC 1123.
        :paramtype expiration: ~datetime.datetime
        :keyword authentication_type: The type of authentication required to create the session. The
         only type currently supported is HMAC. "HMAC"
        :paramtype authentication_type: str or ~azure.storage.blob.models.AuthenticationType
        :keyword credentials: The credentials issued for the session (session token and session key).
        :paramtype credentials: ~azure.storage.blob.models.SessionCredentials
        """
        super().__init__(**kwargs)
        self.id = id
        self.expiration = expiration
        self.authentication_type = authentication_type
        self.credentials = credentials


class SessionCredentials(_serialization.Model):
    """Credentials issued for a session; both fields are secrets.

    :ivar session_token: An opaque token used to authorize subsequent requests in the session. Must
     be treated as a security credential.
    :vartype session_token: str
    :ivar session_key: Only returned when AuthenticationType is HMAC. A symmetric encryption key
     used to sign requests in the session using the Shared Key protocol.
    :vartype session_key: str
    """

    _attribute_map = {
        "session_token": {"key": "SessionToken", "type": "str"},
        "session_key": {"key": "SessionKey", "type": "str"},
    }
    # Serialized as the <Credentials> element.
    _xml_map = {"name": "Credentials"}

    def __init__(
        self, *, session_token: Optional[str] = None, session_key: Optional[str] = None, **kwargs: Any
    ) -> None:
        """
        :keyword session_token: An opaque token used to authorize subsequent requests in the session.
         Must be treated as a security credential.
        :paramtype session_token: str
        :keyword session_key: Only returned when AuthenticationType is HMAC. A symmetric encryption key
         used to sign requests in the session using the Shared Key protocol.
        :paramtype session_key: str
        """
        super().__init__(**kwargs)
        self.session_token = session_token
        self.session_key = session_key
def build_create_session_request(
    url: str,
    *,
    content: Any,
    version: str,
    timeout: Optional[int] = None,
    request_id_parameter: Optional[str] = None,
    **kwargs: Any
) -> HttpRequest:
    """Build the POST request for the Create Session operation.

    :param str url: Value substituted into the ``{url}`` template.
    :keyword content: The request body, passed through to ``HttpRequest`` unchanged.
    :keyword str version: Sent as the ``x-ms-version`` header.
    :keyword timeout: Optional ``timeout`` query parameter; serialized with ``minimum=0``.
    :paramtype timeout: int or None
    :keyword request_id_parameter: Optional ``x-ms-client-request-id`` header value.
    :paramtype request_id_parameter: str or None
    :return: The request with ``restype``/``comp`` query parameters and headers populated.
    :rtype: HttpRequest
    """
    _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    _params = case_insensitive_dict(kwargs.pop("params", {}) or {})

    # An explicit kwarg wins, then a value already in params/headers, then the default.
    restype: Literal["container"] = kwargs.pop("restype", _params.pop("restype", "container"))
    comp: Literal["session"] = kwargs.pop("comp", _params.pop("comp", "session"))
    content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None))
    accept = _headers.pop("Accept", "application/xml")

    # Construct URL
    _url = kwargs.pop("template_url", "{url}")
    path_format_arguments = {
        "url": _SERIALIZER.url("url", url, "str", skip_quote=True),
    }

    _url: str = _url.format(**path_format_arguments)  # type: ignore

    # Construct parameters
    _params["restype"] = _SERIALIZER.query("restype", restype, "str")
    _params["comp"] = _SERIALIZER.query("comp", comp, "str")
    if timeout is not None:
        _params["timeout"] = _SERIALIZER.query("timeout", timeout, "int", minimum=0)

    # Construct headers
    _headers["x-ms-version"] = _SERIALIZER.header("version", version, "str")
    if request_id_parameter is not None:
        _headers["x-ms-client-request-id"] = _SERIALIZER.header("request_id_parameter", request_id_parameter, "str")
    # Content-Type is only sent when one was supplied; Accept is always sent.
    if content_type is not None:
        _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str")
    _headers["Accept"] = _SERIALIZER.header("accept", accept, "str")

    return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, content=content, **kwargs)


# Sync method of ContainerOperations (the class header lies outside this hunk).
# Generated operation code: the statement order of the kwargs/_params/_headers
# pops below is significant and is intentionally left untouched.
@distributed_trace
def create_session(
    self,
    create_session_configuration: _models.CreateSessionConfiguration,
    timeout: Optional[int] = None,
    request_id_parameter: Optional[str] = None,
    **kwargs: Any
) -> _models.CreateSessionResponse:
    """The Create Session operation enables users to create a session scoped to a container.

    Serializes ``create_session_configuration`` as an XML body, sends it through the
    client's pipeline, and deserializes the response body as ``CreateSessionResponse``.
    Any status code other than 201 raises.

    :param create_session_configuration: Required.
    :type create_session_configuration: ~azure.storage.blob.models.CreateSessionConfiguration
    :param timeout: The timeout parameter is expressed in seconds. For more information, see
     "Setting Timeouts for Blob Service Operations" in the service documentation. Default value
     is None.
    :type timeout: int
    :param request_id_parameter: Provides a client-generated, opaque value with a 1 KB character
     limit that is recorded in the analytics logs when storage analytics logging is enabled. Default
     value is None.
    :type request_id_parameter: str
    :return: CreateSessionResponse or the result of cls(response)
    :rtype: ~azure.storage.blob.models.CreateSessionResponse
    :raises ~azure.core.exceptions.HttpResponseError:
    """
    # Default status-code -> exception mapping; a caller-supplied "error_map"
    # kwarg is merged on top and therefore overrides these entries.
    error_map: MutableMapping = {
        401: ClientAuthenticationError,
        404: ResourceNotFoundError,
        409: ResourceExistsError,
        304: ResourceNotModifiedError,
    }
    error_map.update(kwargs.pop("error_map", {}) or {})

    _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {})
    _params = case_insensitive_dict(kwargs.pop("params", {}) or {})

    restype: Literal["container"] = kwargs.pop("restype", _params.pop("restype", "container"))
    comp: Literal["session"] = kwargs.pop("comp", _params.pop("comp", "session"))
    content_type: str = kwargs.pop("content_type", _headers.pop("Content-Type", "application/xml"))
    # Optional callback; when set, its return value replaces the deserialized model.
    cls: ClsType[_models.CreateSessionResponse] = kwargs.pop("cls", None)

    # The request body is the XML serialization of the configuration model.
    _content = self._serialize.body(create_session_configuration, "CreateSessionConfiguration", is_xml=True)

    _request = build_create_session_request(
        url=self._config.url,
        version=self._config.version,
        timeout=timeout,
        request_id_parameter=request_id_parameter,
        restype=restype,
        comp=comp,
        content_type=content_type,
        content=_content,
        headers=_headers,
        params=_params,
    )
    _request.url = self._client.format_url(_request.url)

    # The response is not streamed: the body is read so it can be deserialized below.
    _stream = False
    pipeline_response: PipelineResponse = self._client._pipeline.run(  # pylint: disable=protected-access
        _request, stream=_stream, **kwargs
    )

    response = pipeline_response.http_response

    # 201 is the only status treated as success.
    if response.status_code not in [201]:
        # NOTE(review): map_error presumably raises the mapped exception when the
        # status code has an entry in error_map; otherwise execution falls through
        # to the generic HttpResponseError below — confirm against azure-core.
        map_error(status_code=response.status_code, response=response, error_map=error_map)
        error = self._deserialize.failsafe_deserialize(
            _models.StorageError,
            pipeline_response,
        )
        raise HttpResponseError(response=response, model=error)

    deserialized = self._deserialize("CreateSessionResponse", pipeline_response.http_response)

    if cls:
        # No response headers are surfaced for this operation; an empty dict is passed.
        return cls(pipeline_response, deserialized, {})  # type: ignore

    return deserialized  # type: ignore
## 12.29.0 (2026-05-14) diff --git a/sdk/storage/azure-storage-blob/tests/conftest.py b/sdk/storage/azure-storage-blob/tests/conftest.py index 3142da18cd5d..a023f261285b 100644 --- a/sdk/storage/azure-storage-blob/tests/conftest.py +++ b/sdk/storage/azure-storage-blob/tests/conftest.py @@ -29,6 +29,13 @@ def add_sanitizers(test_proxy): add_header_regex_sanitizer(key="x-ms-copy-source-authorization", value="Sanitized") add_header_regex_sanitizer(key="x-ms-encryption-key", value="Sanitized") + add_header_regex_sanitizer(key="x-ms-session-token", value="Sanitized") + add_general_regex_sanitizer( + regex=r"[^<]*", value="Sanitized" + ) + add_general_regex_sanitizer( + regex=r"[^<]*", value="Sanitized" + ) add_general_regex_sanitizer(regex=r'"EncryptionLibrary": "Python .*?"', value='"EncryptionLibrary": "Python x.x.x"') add_uri_regex_sanitizer(regex=r"\.preprod\.", value=".") From 5eeb3928119304d0fc2b249b1c032271129bf9cb Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sun, 31 May 2026 18:02:02 -0400 Subject: [PATCH 03/29] Rewrite --- .../azure/storage/blob/_shared/base_client.py | 34 +- .../azure/storage/blob/_shared/policies.py | 318 +++++++++++++++++- .../storage/blob/_shared/policies_async.py | 14 + 3 files changed, 363 insertions(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 8fd641acd2c2..d081f037b541 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -55,10 +55,12 @@ StorageLoggingPolicy, StorageRequestHook, StorageResponseHook, + StorageSessionPolicy, ) from .request_handlers import serialize_batch_body, _get_batch_request_delimiter from .response_handlers import PartialBatchErrorException, process_storage_error from .shared_access_signature import QueryStringConstants +from .._generated import AzureBlobStorage 
from .._version import VERSION from .._shared_access_signature import _is_credential_sastoken @@ -327,11 +329,41 @@ def _create_pipeline( config.headers_policy, StorageRequestHook(**kwargs), self._credential_policy, + ] + use_session = bool(kwargs.pop("use_session", False)) + if use_session: + if not hasattr(credential, "get_token"): + raise ValueError( + "use_session=True requires a TokenCredential; received " + f"{type(credential).__name__ if credential is not None else 'None'}." + ) + + api_version = kwargs.get("version") or VERSION + + def _session_client_factory(container_url: str) -> AzureBlobStorage: + sub_kwargs = dict(kwargs) + sub_kwargs["use_session"] = False + sub_kwargs["transport"] = transport # reuse the same transport + _, session_pipeline = self._create_pipeline( + credential, sdk_moniker=self._sdk_moniker, **sub_kwargs + ) + return AzureBlobStorage( + container_url, api_version, base_url=container_url, pipeline=session_pipeline + ) + + policies.append( + StorageSessionPolicy( + account_name=self.account_name, + session_client_factory=_session_client_factory, + use_session=True, + ) + ) + policies.extend([ config.logging_policy, StorageResponseHook(**kwargs), DistributedTracingPolicy(**kwargs), HttpLoggingPolicy(**kwargs), - ] + ]) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 3a5f0b9d662f..21016c887be0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -9,9 +9,11 @@ import random import re import uuid +from datetime import datetime, timedelta, UTC from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time -from typing import Any, Dict, 
# ``datetime.UTC`` only exists on Python 3.11+. ``timezone.utc`` is the same
# singleton and is importable on every Python 3 release, so the session classes
# use it rather than depending on the newer alias.
from datetime import timezone as _timezone  # pylint: disable=wrong-import-position

_UTC = _timezone.utc


def _as_aware_utc(value: datetime) -> datetime:
    """Return *value* unchanged when it is timezone-aware, else tag it as UTC.

    Comparing a naive datetime with an aware one raises ``TypeError``; normalising
    at the boundary keeps every expiry comparison below well-defined.

    :param ~datetime.datetime value: The datetime to normalise.
    :return: A timezone-aware datetime.
    :rtype: ~datetime.datetime
    """
    if value.utcoffset() is None:
        return value.replace(tzinfo=_UTC)
    return value


class Session:
    """A cached session entry: either real session credentials or a bearer-fallback sentinel.

    :param session_token: The session token, or None for a fallback sentinel.
    :type session_token: str or None
    :param session_key: The signing key for the session, or None for a fallback sentinel.
    :type session_key: str or None
    :param ~datetime.datetime expires_at: When the entry stops being valid. A naive
        value is interpreted as UTC.
    :param bool is_fallback: True when the entry only marks "use bearer auth until
        ``expires_at``" and carries no credentials.
    """

    __slots__ = ("session_token", "session_key", "expires_at", "is_fallback")

    # Real sessions are reported as expired this long before ``expires_at`` so a
    # refresh is initiated proactively. Not applied to fallback sentinels.
    REFRESH_BUFFER: timedelta = timedelta(seconds=30)

    def __init__(
        self,
        session_token: Optional[str],
        session_key: Optional[str],
        expires_at: datetime,
        is_fallback: bool = False,
    ) -> None:
        self.session_token = session_token
        self.session_key = session_key
        self.expires_at = _as_aware_utc(expires_at)
        self.is_fallback = is_fallback

    def expired(self, now: Optional[datetime] = None) -> bool:
        """Report whether the entry should no longer be used.

        :param now: The instant to evaluate against; defaults to the current UTC
            time. A naive value is interpreted as UTC.
        :type now: ~datetime.datetime or None
        :return: True once ``now`` reaches ``expires_at`` (minus ``REFRESH_BUFFER``
            for real sessions; fallback sentinels use the exact expiry).
        :rtype: bool
        """
        current = datetime.now(_UTC) if now is None else _as_aware_utc(now)
        margin = timedelta(seconds=0) if self.is_fallback else Session.REFRESH_BUFFER
        return current >= self.expires_at - margin


class SessionCache:
    """Container-level session cache for the sync stack.

    Concurrency model
    -----------------
    * Reads (`get`) take no lock. They perform a single ``dict.get`` and never
      mutate the cache.
    * Writes (`put` / `put_fallback`) do not lock by themselves; callers are
      expected to hold the per-container lock returned by :meth:`lock_container`,
      which is also what serializes the CreateSession call per container.
    * A single ``_locks_guard`` serializes only the *creation* of per-container
      locks, so two threads racing on a brand-new container cannot end up with
      two different lock objects.
    """

    # How long a fallback-to-bearer sentinel stays in effect after a failed
    # create-session attempt.
    FALLBACK_COOLDOWN: timedelta = timedelta(minutes=5)

    def __init__(self) -> None:
        self._locks: Dict[str, Lock] = {}
        self._locks_guard: Lock = Lock()
        self._entry: Dict[str, Session] = {}

    def lock_container(self, container_name: str) -> Lock:
        """Return the per-container lock, creating it exactly once.

        :param str container_name: The container to get the lock for.
        :return: The single lock instance associated with the container.
        :rtype: ~threading.Lock
        """
        # Fast path: the lock already exists, so no guard is needed to read it.
        existing_lock = self._locks.get(container_name)
        if existing_lock is not None:
            return existing_lock
        # Slow path: under the guard, setdefault ensures that racing creators all
        # receive the same lock object.
        with self._locks_guard:
            return self._locks.setdefault(container_name, Lock())

    def get(self, container_name: str) -> Optional[Session]:
        """Return the live entry for the container, or ``None``.

        Takes no lock and does not mutate the cache. Expired entries are NOT
        deleted; they are reported as a miss and overwritten by the next
        ``put`` / ``put_fallback``. The returned entry may be a fallback sentinel
        (``is_fallback`` is True), which callers must check.

        :param str container_name: The container to look up.
        :return: A non-expired entry, or None on miss/expiry.
        :rtype: ~azure.storage.blob._shared.policies.Session or None
        """
        cached = self._entry.get(container_name)
        if cached is None or cached.expired():
            return None
        return cached

    def put(self, container_name: str, session_token: str, session_key: str, expires_at: datetime) -> None:
        """Install a real session entry. Caller must hold ``lock_container``.

        :param str container_name: The container the session belongs to.
        :param str session_token: The session token to send as a header.
        :param str session_key: The HMAC signing key for the session.
        :param ~datetime.datetime expires_at: When the session expires. A naive
            value is interpreted as UTC.
        """
        self._entry[container_name] = Session(session_token, session_key, expires_at, is_fallback=False)

    def put_fallback(self, container_name: str) -> None:
        """Install a fallback-to-bearer sentinel for the cooldown window.

        Caller must hold ``lock_container``.

        :param str container_name: The container to mark for bearer fallback.
        """
        self._entry[container_name] = Session(
            None, None, datetime.now(_UTC) + SessionCache.FALLBACK_COOLDOWN, is_fallback=True
        )
+ + :keyword str account_name: Storage account name; used as the signer + identity when signing session-authenticated requests. + :keyword session_client_factory: A callable that, given a container URL, + returns a session-disabled generated client (AzureBlobStorage) + whose pipeline uses OAuth/bearer auth. Invoked to issue CreateSession. + :paramtype session_client_factory: Callable[[str], Any] + :keyword bool use_session: Whether session authentication is enabled. + When set to False, the policy is a pass-through no-op. + :raises ValueError: if `account_name` or `session_client_factory` is `None`. + """ + if account_name is None or session_client_factory is None: + raise ValueError("account_name and session_client_factory are required.") + super().__init__() + self._account_name = account_name + self._session_client_factory = session_client_factory + self._use_session = use_session + self._cache = SessionCache() + + @staticmethod + def _parse_container(url: str) -> Optional[str]: + """Extract the container name (first path segment) from a request URL. + + :param str url: The request URL. + :return: The container name, or `None` for service-level URLs. + :rtype: str or None + """ + path = urlparse(url).path + segments = [seg for seg in path.split("/") if seg] + return segments[0] if segments else None + + @staticmethod + def _container_url(request_url: str) -> str: + """Build the container-scoped URL (scheme://host/container) for CreateSession. + + :param str request_url: The originating request URL. + :return: A URL pointing at the container root. 
+ :rtype: str + """ + parsed = urlparse(request_url) + segments = [seg for seg in parsed.path.split("/") if seg] + container = segments[0] if segments else "" + return f"{parsed.scheme}://{parsed.netloc}/{container}" + + @staticmethod + def _extract_session(response: Any) -> Tuple[str, str, datetime]: + creds = getattr(response, "credentials", None) + if not creds or not getattr(creds, "session_token", None) or not getattr(creds, "session_key", None): + raise ValueError("CreateSession response missing SessionToken/SessionKey") + session_token: str = creds.session_token + session_key: str = creds.session_key + expires_at = getattr(response, "expiration", None) + if expires_at is None: + expires_at = datetime.now(UTC) + timedelta(minutes=5) + elif expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + return session_token, session_key, expires_at + + def _apply_session_auth(self, request: "PipelineRequest", session_token: str, session_key: str) -> None: + # Stamp the session token BEFORE signing so it participates in the + # canonicalized headers, then sign with SharedKey using the session key. + request.http_request.headers[SESSION_TOKEN_HEADER] = session_token + SharedKeyCredentialPolicy(self._account_name, session_key).on_request(request) + + def _is_eligible(self, request: "PipelineRequest") -> bool: + if not self._use_session: + return False + if request.http_request.method != "GET": + return False + return bool(request.context.options.get(SESSION_ELIGIBLE_CONTEXT_KEY)) + + def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: + # The factory returns a session-DISABLED generated client bound to the + # container URL; its pipeline uses OAuth/bearer, so this call authenticates + # without re-entering this policy. 
+ config = CreateSessionConfiguration(authentication_type="HMAC") + client = self._session_client_factory(container_url) + response = client.container.create_session(create_session_configuration=config) + return self._extract_session(response) + + def _refresh_session_token(self, container_name: str, container_url: str) -> Optional[Session]: + """Acquire (or re-use) a session for the container under per-container single-flight. + + :param str container_name: The container key for the cache and lock. + :param str container_url: The container-scoped URL for the CreateSession call. + :return: A live session, a fallback sentinel, or `None` if unusable. + :rtype: ~azure.storage.blob._shared.policies.Session or None + """ + with self._cache.lock_container(container_name): + existing = self._cache.get(container_name) + if existing is not None and not existing.expired(): + return existing + try: + token, key, expires_at = self._create_session(container_url) + self._cache.put(container_name, token, key, expires_at) + except Exception: # pylint: disable=broad-except + _LOGGER.warning( + "CreateSession failed for container '%s'; falling back to bearer for %d seconds.", + container_name, + int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), + exc_info=True, + ) + self._cache.put_fallback(container_name) + return self._cache.get(container_name) + + def send(self, request: "PipelineRequest") -> "PipelineResponse": + """Orchestrate session auth. + + :param ~azure.core.pipeline.PipelineRequest request: The outgoing request. + :return: The pipeline response. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + container_name = self.on_request(request) + response = self.next.send(request) + return self.on_response(request, response, container_name) + + def on_request(self, request: "PipelineRequest") -> Optional[str]: + """Stamp session auth if eligible, otherwise leave the bearer header intact. 
+ + :param ~azure.core.pipeline.PipelineRequest request: The request to (maybe) sign. + :return: The container name if a session was applied, else ``None``. + :rtype: str or None + """ + if not self._is_eligible(request): + return None + container_name = self._parse_container(request.http_request.url) + if not container_name: + return None + + session = self._cache.get(container_name) + if session is None: + container_url = self._container_url(request.http_request.url) + session = self._refresh_session_token(container_name, container_url) + + if session is None or session.is_fallback or not session.session_token or not session.session_key: + return None + + self._apply_session_auth(request, session.session_token, session.session_key) + return container_name + + def on_response( + self, + request: "PipelineRequest", + response: "PipelineResponse", + container_name: Optional[str], + ) -> "PipelineResponse": + """React to session-related failures: cooldown sentinel or one-shot re-acquire. + + :param ~azure.core.pipeline.PipelineRequest request: The original request. + :param ~azure.core.pipeline.PipelineResponse response: The response to inspect. + :param container_name: Container that was session-signed, or `None` if bearer was used. + :type container_name: str or None + :return: The final response (possibly from a one-shot retry). + :rtype: ~azure.core.pipeline.PipelineResponse + """ + if container_name is None: + return response # bearer was used; nothing session-related to react to + + status = response.http_response.status_code + error_code = response.http_response.headers.get("x-ms-error-code", "") + + # Unavailable / feature-off / 5xx → negative-cache cooldown. 
+ if error_code in (self.SESSIONS_UNAVAILABLE, self.FEATURE_NOT_ENABLED) or status >= 500: + _LOGGER.warning( + "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", + error_code or "5xx", status, container_name, + int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), + ) + with self._cache.lock_container(container_name): + self._cache.put_fallback(container_name) + return response + + # 401 → invalidate + re-acquire ONCE, then resend. + if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): + _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) + with self._cache.lock_container(container_name): + self._cache.put_fallback(container_name) # drop the stale entry + request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True + retried_container = self.on_request(request) # re-stamp (or fall to bearer) + retried_response = self.next.send(request) + return self.on_response(request, retried_response, retried_container) + + return response diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index e1d13b1a83fa..ee89c7c3dc9b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -338,3 +338,17 @@ async def on_challenge(self, request: "PipelineRequest", response: "PipelineResp await self.authorize_request(request, scope, tenant_id=challenge.tenant_id) return True + +class AsyncStorageSessionPolicy(AsyncHTTPPolicy): + + def __init__(self): + pass + + def on_request(self): + pass + + def send(self, request): + pass + + def on_response(self): + pass From a4b0379da99173a756ee729bbc04b956182505db Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Mon, 1 Jun 2026 12:23:35 -0400 Subject: [PATCH 04/29] WIP --- .../azure/storage/blob/_shared/base_client.py 
| 12 ++-- .../azure/storage/blob/_shared/policies.py | 51 +++++++++++++--- .../tests/test_container.py | 59 +++++++++++++++++++ 3 files changed, 111 insertions(+), 11 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index d081f037b541..8be4762c3841 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -342,14 +342,18 @@ def _create_pipeline( def _session_client_factory(container_url: str) -> AzureBlobStorage: sub_kwargs = dict(kwargs) + sub_kwargs.pop("_pipeline", None) + sub_kwargs.pop("_configuration", None) + sub_kwargs.pop("pipeline", None) sub_kwargs["use_session"] = False - sub_kwargs["transport"] = transport # reuse the same transport + sub_kwargs["transport"] = transport + _, session_pipeline = self._create_pipeline( credential, sdk_moniker=self._sdk_moniker, **sub_kwargs ) - return AzureBlobStorage( - container_url, api_version, base_url=container_url, pipeline=session_pipeline - ) + generated = AzureBlobStorage(container_url, api_version, base_url=container_url) + generated._client._pipeline = session_pipeline # pylint: disable=protected-access + return generated policies.append( StorageSessionPolicy( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 21016c887be0..60c00c1d1cc7 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -32,7 +32,8 @@ SansIOHTTPPolicy, ) -from .authentication import AzureSigningError, SharedKeyCredentialPolicy, StorageHttpChallenge +from . 
import sign_string +from .authentication import AzureSigningError, _storage_header_sort, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE, DATA_BLOCK_SIZE from .models import LocationMode, StorageErrorCode from .streams import ( @@ -965,6 +966,11 @@ class StorageSessionPolicy(HTTPPolicy): """Service-reported code: session operations are temporarily unavailable.""" FEATURE_NOT_ENABLED: str = "FeatureNotEnabled" """Service-reported code: the session feature is not enabled on the scale unit.""" + _SIGNED_HEADERS = ( + "content-encoding", "content-language", "content-length", "content-md5", + "content-type", "date", "if-modified-since", "if-match", "if-none-match", + "if-unmodified-since", "byte_range", + ) def __init__( self, @@ -1033,17 +1039,48 @@ def _extract_session(response: Any) -> Tuple[str, str, datetime]: return session_token, session_key, expires_at def _apply_session_auth(self, request: "PipelineRequest", session_token: str, session_key: str) -> None: - # Stamp the session token BEFORE signing so it participates in the - # canonicalized headers, then sign with SharedKey using the session key. - request.http_request.headers[SESSION_TOKEN_HEADER] = session_token - SharedKeyCredentialPolicy(self._account_name, session_key).on_request(request) + http_request = request.http_request + http_request.headers["x-ms-date"] = format_date_time(time()) + + # Lowercased non-empty headers; Storage omits content-length when it is "0". 
+ headers = { + name.lower(): value + for name, value in http_request.headers.items() + if value and not (name.lower() == "content-length" and value == "0") + } + x_ms = _storage_header_sort( + [(n.lower(), v) for n, v in http_request.headers.items() if n.lower().startswith("x-ms-")] + ) + string_to_sign = "\n".join( + ( + http_request.method, + *(headers.get(h, "") for h in self._SIGNED_HEADERS), + *(f"{n}:{v}" for n, v in x_ms if v is not None), + "/" + self._account_name + urlparse(http_request.url).path + + "".join(f"\n{n.lower()}:{v}" for n, v in sorted(http_request.query.items()) if v is not None), + ) + ) + + try: + signature = sign_string(session_key, string_to_sign) + except Exception as ex: # pylint: disable=broad-except + raise AzureSigningError(str(ex)) from ex + http_request.headers["Authorization"] = f"Session {session_token}:{signature}" def _is_eligible(self, request: "PipelineRequest") -> bool: if not self._use_session: return False - if request.http_request.method != "GET": + http_request = request.http_request + if http_request.method != "GET": + return False + parsed = urlparse(http_request.url) + segments = [s for s in parsed.path.split("/") if s] + if len(segments) < 2: + return False + query = http_request.query + if "comp" in query or query.get("restype") == "container": return False - return bool(request.context.options.get(SESSION_ELIGIBLE_CONTEXT_KEY)) + return True def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: # The factory returns a session-DISABLED generated client bound to the diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index 2d0d77e73ecb..caea9954f31a 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2804,3 +2804,62 @@ def recursive_walk(prefix): # Assert assert blobs is not None assert blobs == ["a/b/blob2", "a/b/blob3", "a/b/blob4", 
"a/blob1"] + + @BlobPreparer() + @recorded_by_proxy + def test_create_session_same_policy(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + credential = self.get_credential(BlobServiceClient) + captured = {} + def make_capture(label): + def _hook(response): + auth = response.http_request.headers.get("Authorization", "") + captured[label] = auth + return _hook + + def session_token_from(auth): + # "Session {token}:{signature}" -> token + assert auth.startswith("Session ") + return auth[len("Session "):].split(":", 1)[0] + + service = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + credential=credential, + use_session=True, + ) + container1 = service.get_container_client("container1") + try: + container1.create_container() + except ResourceExistsError: + pass + + c1b1_name, c1b1_data = "c1b1", b"abc123" + container1.upload_blob(c1b1_name, c1b1_data, overwrite=True, raw_response_hook=make_capture("c1_upload")) + assert captured["c1_upload"].startswith("Bearer ") + + # Download is an eligible GET → session scheme. 
+ c1b1_actual = container1.download_blob(c1b1_name, raw_response_hook=make_capture("c1_download")).readall() + assert c1b1_data == c1b1_actual + assert captured["c1_download"].startswith("Session ") + c1_token = session_token_from(captured["c1_download"]) + + container2 = service.get_container_client("container2") + try: + container2.create_container() + except ResourceExistsError: + pass + + c2b2_name, c2b2_data = "c2b2", b"def456" + container2.upload_blob(c2b2_name, c2b2_data, overwrite=True, raw_response_hook=make_capture("c2_upload")) + assert captured["c2_upload"].startswith("Bearer ") + + c2b2_actual = container2.download_blob( + c2b2_name, raw_response_hook=make_capture("c2_download"), + ).readall() + assert c2b2_data == c2b2_actual + assert captured["c2_download"].startswith("Session ") + c2_token = session_token_from(captured["c2_download"]) + + # Different containers must not share the same session token. + assert c1_token != c2_token From 49ad7dc2c27dafbb7972620293ee84d72ef9e4fa Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Mon, 1 Jun 2026 12:36:51 -0400 Subject: [PATCH 05/29] GG --- .../azure/storage/blob/_shared/base_client.py | 5 +-- .../azure/storage/blob/_shared/policies.py | 37 ++++++++++--------- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 8be4762c3841..df5f9a44dd0b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -338,20 +338,19 @@ def _create_pipeline( f"{type(credential).__name__ if credential is not None else 'None'}." 
) - api_version = kwargs.get("version") or VERSION - def _session_client_factory(container_url: str) -> AzureBlobStorage: sub_kwargs = dict(kwargs) sub_kwargs.pop("_pipeline", None) sub_kwargs.pop("_configuration", None) sub_kwargs.pop("pipeline", None) + sub_kwargs.pop("sdk_moniker", None) sub_kwargs["use_session"] = False sub_kwargs["transport"] = transport _, session_pipeline = self._create_pipeline( credential, sdk_moniker=self._sdk_moniker, **sub_kwargs ) - generated = AzureBlobStorage(container_url, api_version, base_url=container_url) + generated = AzureBlobStorage(container_url, self.api_version, base_url=container_url) generated._client._pipeline = session_pipeline # pylint: disable=protected-access return generated diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 60c00c1d1cc7..51e3bd091deb 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING, Union from urllib.parse import ( parse_qsl, + unquote, urlencode, urlparse, urlunparse, @@ -1042,23 +1043,26 @@ def _apply_session_auth(self, request: "PipelineRequest", session_token: str, se http_request = request.http_request http_request.headers["x-ms-date"] = format_date_time(time()) - # Lowercased non-empty headers; Storage omits content-length when it is "0". - headers = { - name.lower(): value - for name, value in http_request.headers.items() - if value and not (name.lower() == "content-length" and value == "0") - } - x_ms = _storage_header_sort( + # 1) Standard headers. Storage omits content-length when it is "0". 
+ headers = {name.lower(): value for name, value in http_request.headers.items() if value} + if headers.get("content-length") == "0": + del headers["content-length"] + signed_headers = "\n".join(headers.get(h, "") for h in self._SIGNED_HEADERS) + "\n" + + # 2) Canonicalized x-ms-* headers, sorted by the service-emulating comparator. + x_ms_headers = _storage_header_sort( [(n.lower(), v) for n, v in http_request.headers.items() if n.lower().startswith("x-ms-")] ) - string_to_sign = "\n".join( - ( - http_request.method, - *(headers.get(h, "") for h in self._SIGNED_HEADERS), - *(f"{n}:{v}" for n, v in x_ms if v is not None), - "/" + self._account_name + urlparse(http_request.url).path - + "".join(f"\n{n.lower()}:{v}" for n, v in sorted(http_request.query.items()) if v is not None), - ) + canonicalized_headers = "".join(f"{n}:{v}\n" for n, v in x_ms_headers if v is not None) + + # 3) Canonicalized resource + query (query values must be url-decoded). + canonicalized_resource = "/" + self._account_name + urlparse(http_request.url).path + canonicalized_resource += "".join( + f"\n{n.lower()}:{unquote(v)}" for n, v in sorted(http_request.query.items()) if v is not None + ) + + string_to_sign = ( + http_request.method + "\n" + signed_headers + canonicalized_headers + canonicalized_resource ) try: @@ -1083,9 +1087,6 @@ def _is_eligible(self, request: "PipelineRequest") -> bool: return True def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: - # The factory returns a session-DISABLED generated client bound to the - # container URL; its pipeline uses OAuth/bearer, so this call authenticates - # without re-entering this policy. 
config = CreateSessionConfiguration(authentication_type="HMAC") client = self._session_client_factory(container_url) response = client.container.create_session(create_session_configuration=config) From 19b3f7593d33bd0d5ee2f021ac84227e5f95b516 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Mon, 1 Jun 2026 12:37:20 -0400 Subject: [PATCH 06/29] Removed comments --- sdk/storage/azure-storage-blob/tests/test_container.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index caea9954f31a..673ceb9c1fbb 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2838,7 +2838,6 @@ def session_token_from(auth): container1.upload_blob(c1b1_name, c1b1_data, overwrite=True, raw_response_hook=make_capture("c1_upload")) assert captured["c1_upload"].startswith("Bearer ") - # Download is an eligible GET → session scheme. c1b1_actual = container1.download_blob(c1b1_name, raw_response_hook=make_capture("c1_download")).readall() assert c1b1_data == c1b1_actual assert captured["c1_download"].startswith("Session ") @@ -2861,5 +2860,4 @@ def session_token_from(auth): assert captured["c2_download"].startswith("Session ") c2_token = session_token_from(captured["c2_download"]) - # Different containers must not share the same session token. 
assert c1_token != c2_token From 4a8ef92bf25e5e2af4aa8b80d112c6587c95b396 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Mon, 1 Jun 2026 12:53:11 -0400 Subject: [PATCH 07/29] More stuff --- .../azure-storage-blob/tests/conftest.py | 2 +- .../tests/test_container.py | 20 +++++++++++++------ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/sdk/storage/azure-storage-blob/tests/conftest.py b/sdk/storage/azure-storage-blob/tests/conftest.py index a023f261285b..125d748cc84c 100644 --- a/sdk/storage/azure-storage-blob/tests/conftest.py +++ b/sdk/storage/azure-storage-blob/tests/conftest.py @@ -34,7 +34,7 @@ def add_sanitizers(test_proxy): regex=r"[^<]*", value="Sanitized" ) add_general_regex_sanitizer( - regex=r"[^<]*", value="Sanitized" + regex=r"[^<]*", value="U2FuaXRpemVk" ) add_general_regex_sanitizer(regex=r'"EncryptionLibrary": "Python .*?"', value='"EncryptionLibrary": "Python x.x.x"') diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index 673ceb9c1fbb..ca2e94e4aa34 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2841,7 +2841,7 @@ def session_token_from(auth): c1b1_actual = container1.download_blob(c1b1_name, raw_response_hook=make_capture("c1_download")).readall() assert c1b1_data == c1b1_actual assert captured["c1_download"].startswith("Session ") - c1_token = session_token_from(captured["c1_download"]) + session1 = session_token_from(captured["c1_download"]) container2 = service.get_container_client("container2") try: @@ -2853,11 +2853,19 @@ def session_token_from(auth): container2.upload_blob(c2b2_name, c2b2_data, overwrite=True, raw_response_hook=make_capture("c2_upload")) assert captured["c2_upload"].startswith("Bearer ") - c2b2_actual = container2.download_blob( - c2b2_name, raw_response_hook=make_capture("c2_download"), - ).readall() + c2b2_actual = container2.download_blob(c2b2_name, 
raw_response_hook=make_capture("c2_download")).readall() assert c2b2_data == c2b2_actual assert captured["c2_download"].startswith("Session ") - c2_token = session_token_from(captured["c2_download"]) + session2 = session_token_from(captured["c2_download"]) - assert c1_token != c2_token + assert session1 != session2 + + c1b1_actual = container1.download_blob(c1b1_name, raw_response_hook=make_capture("c1_download")).readall() + assert c1b1_data == c1b1_actual + assert captured["c1_download"].startswith("Session ") + assert session1 == session_token_from(captured["c1_download"]) + + c2b2_actual = container2.download_blob(c2b2_name, raw_response_hook=make_capture("c2_download")).readall() + assert c2b2_data == c2b2_actual + assert captured["c2_download"].startswith("Session ") + assert session2 == session_token_from(captured["c2_download"]) From 75744d1f39f383f0c29edf657e2b0f2da51b1a19 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Tue, 2 Jun 2026 12:31:38 -0400 Subject: [PATCH 08/29] More tests --- .../azure/storage/blob/_shared/policies.py | 4 +- .../tests/test_container.py | 73 ++++++++++++++----- 2 files changed, 56 insertions(+), 21 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 51e3bd091deb..0501d88c3b23 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -1061,9 +1061,7 @@ def _apply_session_auth(self, request: "PipelineRequest", session_token: str, se f"\n{n.lower()}:{unquote(v)}" for n, v in sorted(http_request.query.items()) if v is not None ) - string_to_sign = ( - http_request.method + "\n" + signed_headers + canonicalized_headers + canonicalized_resource - ) + string_to_sign = http_request.method + "\n" + signed_headers + canonicalized_headers + canonicalized_resource try: signature = sign_string(session_key, string_to_sign) 
diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index ca2e94e4aa34..1a1f8105842e 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2812,6 +2812,7 @@ def test_create_session_same_policy(self, **kwargs): credential = self.get_credential(BlobServiceClient) captured = {} + def make_capture(label): def _hook(response): auth = response.http_request.headers.get("Authorization", "") @@ -2828,44 +2829,80 @@ def session_token_from(auth): credential=credential, use_session=True, ) - container1 = service.get_container_client("container1") + container1 = service.get_container_client(self.get_resource_name("utcontainer1")) try: container1.create_container() except ResourceExistsError: pass - c1b1_name, c1b1_data = "c1b1", b"abc123" - container1.upload_blob(c1b1_name, c1b1_data, overwrite=True, raw_response_hook=make_capture("c1_upload")) + blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" + container1.upload_blob( + blob1_name, blob1_data, overwrite=True, raw_response_hook=make_capture("c1_upload") + ) assert captured["c1_upload"].startswith("Bearer ") - c1b1_actual = container1.download_blob(c1b1_name, raw_response_hook=make_capture("c1_download")).readall() - assert c1b1_data == c1b1_actual + blob1_actual = container1.download_blob( + blob1_name, raw_response_hook=make_capture("c1_download") + ).readall() + assert blob1_data == blob1_actual assert captured["c1_download"].startswith("Session ") session1 = session_token_from(captured["c1_download"]) - container2 = service.get_container_client("container2") + container2 = service.get_container_client(self.get_resource_name("utcontainer2")) try: container2.create_container() except ResourceExistsError: pass - c2b2_name, c2b2_data = "c2b2", b"def456" - container2.upload_blob(c2b2_name, c2b2_data, overwrite=True, raw_response_hook=make_capture("c2_upload")) 
+ blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" + container2.upload_blob(blob2_name, blob2_data, overwrite=True, raw_response_hook=make_capture("c2_upload")) assert captured["c2_upload"].startswith("Bearer ") - c2b2_actual = container2.download_blob(c2b2_name, raw_response_hook=make_capture("c2_download")).readall() - assert c2b2_data == c2b2_actual + blob2_actual = container2.download_blob(blob2_name, raw_response_hook=make_capture("c2_download")).readall() + assert blob2_data == blob2_actual assert captured["c2_download"].startswith("Session ") session2 = session_token_from(captured["c2_download"]) assert session1 != session2 - c1b1_actual = container1.download_blob(c1b1_name, raw_response_hook=make_capture("c1_download")).readall() - assert c1b1_data == c1b1_actual - assert captured["c1_download"].startswith("Session ") - assert session1 == session_token_from(captured["c1_download"]) + blob1_actual = container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download2")).readall() + assert blob1_data == blob1_actual + assert captured["c1_download2"].startswith("Session ") + assert session1 == session_token_from(captured["c1_download2"]) - c2b2_actual = container2.download_blob(c2b2_name, raw_response_hook=make_capture("c2_download")).readall() - assert c2b2_data == c2b2_actual - assert captured["c2_download"].startswith("Session ") - assert session2 == session_token_from(captured["c2_download"]) + blob2_actual = container2.download_blob(blob2_name, raw_response_hook=make_capture("c2_download2")).readall() + assert blob2_data == blob2_actual + assert captured["c2_download2"].startswith("Session ") + assert session2 == session_token_from(captured["c2_download2"]) + + @BlobPreparer() + def test_sessions_disabled(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + credential = self.get_credential(BlobServiceClient) + captured = {} + + def make_capture(label): + def _hook(response): + auth = 
response.http_request.headers.get("Authorization", "") + captured[label] = auth + return _hook + + service = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + credential=credential, + use_session=False, + ) + container = service.get_container_client(self.get_resource_name("utcontainer")) + try: + container.create_container() + except ResourceExistsError: + pass + + blob_name, blob_data = self.get_resource_name("blob"), b"abc123" + container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=make_capture("upload")) + assert captured["upload"].startswith("Bearer ") + + blob_actual = container.download_blob(blob_name, raw_response_hook=make_capture("download")).readall() + assert blob_data == blob_actual + assert captured["download"].startswith("Bearer ") From 75a2bd7799f65c0d9fdb8e8991f445a129dd50b1 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Tue, 2 Jun 2026 14:51:33 -0400 Subject: [PATCH 09/29] Format --- sdk/storage/azure-storage-blob/tests/test_container.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index 1a1f8105842e..71a81202b016 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2841,9 +2841,7 @@ def session_token_from(auth): ) assert captured["c1_upload"].startswith("Bearer ") - blob1_actual = container1.download_blob( - blob1_name, raw_response_hook=make_capture("c1_download") - ).readall() + blob1_actual = container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download")).readall() assert blob1_data == blob1_actual assert captured["c1_download"].startswith("Session ") session1 = session_token_from(captured["c1_download"]) From 0eb242ae4a0b0c43f4a5cc90f9b28d39f278ea3d Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Tue, 2 Jun 2026 17:46:34 -0400 Subject: [PATCH 10/29] Added one more test 
--- .../tests/test_container.py | 22 +++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index 71a81202b016..d7e35d4e0b92 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2824,12 +2824,20 @@ def session_token_from(auth): assert auth.startswith("Session ") return auth[len("Session "):].split(":", 1)[0] + def find_session_policy(pipeline): + # Match by class name to avoid importing internals into the test module. + for p in getattr(pipeline, "_impl_policies", []): + if type(p).__name__ == "StorageSessionPolicy": + return p + raise AssertionError("StorageSessionPolicy not found on the pipeline") + service = BlobServiceClient( self.account_url(storage_account_name, "blob"), credential=credential, use_session=True, ) - container1 = service.get_container_client(self.get_resource_name("utcontainer1")) + container1_name = self.get_resource_name("utcontainer1") + container1 = service.get_container_client(container1_name) try: container1.create_container() except ResourceExistsError: @@ -2846,7 +2854,8 @@ def session_token_from(auth): assert captured["c1_download"].startswith("Session ") session1 = session_token_from(captured["c1_download"]) - container2 = service.get_container_client(self.get_resource_name("utcontainer2")) + container2_name = self.get_resource_name("utcontainer2") + container2 = service.get_container_client(container2_name) try: container2.create_container() except ResourceExistsError: @@ -2873,6 +2882,15 @@ def session_token_from(auth): assert captured["c2_download2"].startswith("Session ") assert session2 == session_token_from(captured["c2_download2"]) + policy = find_session_policy(service._pipeline) + cached = policy._cache._entry[container1_name] + cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) + + 
blob1_actual = container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download3")).readall() + assert blob1_data == blob1_actual + assert captured["c1_download3"].startswith("Session ") + assert session1 != session_token_from(captured["c1_download3"]) + @BlobPreparer() def test_sessions_disabled(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") From b7c4ef9ae6fd80f7b609f334be3c248c1115f05c Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Tue, 2 Jun 2026 17:47:23 -0400 Subject: [PATCH 11/29] One more assert --- sdk/storage/azure-storage-blob/tests/test_container.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index d7e35d4e0b92..c861fa0390e3 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2890,6 +2890,7 @@ def find_session_policy(pipeline): assert blob1_data == blob1_actual assert captured["c1_download3"].startswith("Session ") assert session1 != session_token_from(captured["c1_download3"]) + assert session2 != session_token_from(captured["c1_download3"]) @BlobPreparer() def test_sessions_disabled(self, **kwargs): From 48faaec28a29f1e7bbe769d80f423db5573a3b7e Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Tue, 2 Jun 2026 17:54:32 -0400 Subject: [PATCH 12/29] API changes --- .../azure/storage/blob/_blob_service_client.py | 5 +++++ .../azure/storage/blob/_blob_service_client.pyi | 1 + .../azure/storage/blob/_container_client.py | 5 +++++ .../azure/storage/blob/_container_client.pyi | 1 + .../azure/storage/blob/aio/_blob_client_async.py | 5 +++++ .../azure/storage/blob/aio/_blob_client_async.pyi | 1 + .../azure/storage/blob/aio/_container_client_async.py | 5 +++++ .../azure/storage/blob/aio/_container_client_async.pyi | 1 + 8 files changed, 24 insertions(+) diff --git 
a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py index 2333d9558d11..460ded703b9d 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.py @@ -100,6 +100,11 @@ class BlobServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://<account>.blob.core.windows.net. + :keyword bool use_session: If True, enable session-based authentication for this client. + When enabled, eligible GET requests issued by this client will be authenticated using + a short-lived session credential obtained from the service instead + of the provided TokenCredential. Only supported with a TokenCredential; + ValueError is raised otherwise. Defaults to False. .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.pyi index 526c2bfae18a..715ce9ee0c82 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_blob_service_client.pyi @@ -56,6 +56,7 @@ class BlobServiceClient(StorageAccountHostsMixin, StorageEncryptionMixin): max_single_get_size: int = 32 * 1024 * 1024, max_chunk_get_size: int = 4 * 1024 * 1024, audience: Optional[str] = None, + use_session: bool = False, **kwargs: Any ) -> None: ... def __enter__(self) -> Self: ... 
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py index dfba8f54eb63..cb5b4aabb1a0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.py @@ -116,6 +116,11 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): # py :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://<account>.blob.core.windows.net. + :keyword bool use_session: If True, enable session-based authentication for this container. + When enabled, eligible GET requests issued by this client will be authenticated using + a short-lived session credential obtained from the service instead + of the provided TokenCredential. Only supported with a TokenCredential; + ValueError is raised otherwise. Defaults to False. .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi index 2ad0514290d3..56c2d5372082 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_container_client.pyi @@ -71,6 +71,7 @@ class ContainerClient(StorageAccountHostsMixin, StorageEncryptionMixin): max_single_get_size: int = 32 * 1024 * 1024, min_large_block_upload_threshold: int = 4 * 1024 * 1024 + 1, use_byte_buffer: Optional[bool] = None, + use_session: bool = False, **kwargs: Any, ) -> None: ... def __enter__(self) -> Self: ... 
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py index a050d8206df2..90bd75a2662c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.py @@ -152,6 +152,11 @@ class BlobClient( # type: ignore [misc] # pylint: disable=too-many-public-metho :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://<account>.blob.core.windows.net. + :keyword bool use_session: If True, enable session-based authentication for this client. + When enabled, eligible GET requests issued by this client will be authenticated using + a short-lived session credential obtained from the service instead + of the provided AsyncTokenCredential. Only supported with an AsyncTokenCredential; + ValueError is raised otherwise. Defaults to False. .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi index ba9a7460425c..4c1d59cd7f53 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_blob_client_async.pyi @@ -81,6 +81,7 @@ class BlobClient( # type: ignore[misc] max_single_get_size: int = 32 * 1024 * 1024, min_large_block_upload_threshold: int = 4 * 1024 * 1024 + 1, use_byte_buffer: Optional[bool] = None, + use_session: bool = False, **kwargs: Any ) -> None: ... async def __aenter__(self) -> Self: ... 
diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py index 413ba2cd0f0c..43f253f0f78a 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.py @@ -112,6 +112,11 @@ class ContainerClient( # type: ignore [misc] # pylint: disable=too-many-public :keyword str audience: The audience to use when requesting tokens for Azure Active Directory authentication. Only has an effect when credential is of type TokenCredential. The value could be https://storage.azure.com/ (default) or https://<account>.blob.core.windows.net. + :keyword bool use_session: If True, enable session-based authentication for this container. + When enabled, eligible GET requests issued by this client will be authenticated using + a short-lived session credential obtained from the service instead + of the provided AsyncTokenCredential. Only supported with an AsyncTokenCredential; + ValueError is raised otherwise. Defaults to False. .. admonition:: Example: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi index 49362aac8058..2cff3516be37 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_container_client_async.pyi @@ -79,6 +79,7 @@ class ContainerClient( # type: ignore[misc] max_single_get_size: int = 32 * 1024 * 1024, min_large_block_upload_threshold: int = 4 * 1024 * 1024 + 1, use_byte_buffer: Optional[bool] = None, + use_session: bool = False, **kwargs: Any ) -> None: ... async def __aenter__(self) -> Self: ... 
From a63e45ef9a80cf492b892acf5aa298441cbda28f Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Tue, 2 Jun 2026 19:06:45 -0400 Subject: [PATCH 13/29] BOOM BOOM --- sdk/storage/azure-storage-blob/tests/test_container.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index c861fa0390e3..6c84d1046529 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2807,7 +2807,7 @@ def recursive_walk(prefix): @BlobPreparer() @recorded_by_proxy - def test_create_session_same_policy(self, **kwargs): + def test_create_session(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") credential = self.get_credential(BlobServiceClient) @@ -2893,6 +2893,7 @@ def find_session_policy(pipeline): assert session2 != session_token_from(captured["c1_download3"]) @BlobPreparer() + @recorded_by_proxy def test_sessions_disabled(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") From 24ffd974f3bc8956adcf5c207cde8c021a944262 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Tue, 2 Jun 2026 19:51:33 -0400 Subject: [PATCH 14/29] Async --- sdk/storage/azure-storage-blob/CHANGELOG.md | 2 +- .../storage/blob/_shared/base_client_async.py | 37 +++- .../azure/storage/blob/_shared/policies.py | 21 +- .../storage/blob/_shared/policies_async.py | 200 +++++++++++++++++- 4 files changed, 240 insertions(+), 20 deletions(-) diff --git a/sdk/storage/azure-storage-blob/CHANGELOG.md b/sdk/storage/azure-storage-blob/CHANGELOG.md index 28e28f659a2d..2d7e7ab037f6 100644 --- a/sdk/storage/azure-storage-blob/CHANGELOG.md +++ b/sdk/storage/azure-storage-blob/CHANGELOG.md @@ -5,7 +5,7 @@ ### Features Added - Added opt-in session-based authentication for `ContainerClient` via the new `use_session` keyword argument. When enabled, it must be used with a -`TokenCredential`. 
GET blob download operations issued through the client +`TokenCredential`. Eligible GET requests issued through the client are authenticated using a short-lived session credential obtained from the service instead of the bearer token. diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 2e023b1cc8d9..18dc54f47e0e 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -44,8 +44,10 @@ AsyncStorageBearerTokenCredentialPolicy, AsyncContentValidationPolicy, AsyncStorageResponseHook, + AsyncStorageSessionPolicy, ) from .response_handlers import PartialBatchErrorException, process_storage_error +from .._generated.aio import AzureBlobStorage from .._shared_access_signature import _is_credential_sastoken if TYPE_CHECKING: @@ -141,11 +143,44 @@ def _create_pipeline( config.headers_policy, StorageRequestHook(**kwargs), self._credential_policy, + ] + use_session = bool(kwargs.pop("use_session", False)) + if use_session: + if not hasattr(credential, "get_token"): + raise ValueError( + "use_session=True requires a TokenCredential; received " + f"{type(credential).__name__ if credential is not None else 'None'}." 
+ ) + + def _session_client_factory(container_url: str) -> AzureBlobStorage: + sub_kwargs = dict(kwargs) + sub_kwargs.pop("_pipeline", None) + sub_kwargs.pop("_configuration", None) + sub_kwargs.pop("pipeline", None) + sub_kwargs.pop("sdk_moniker", None) + sub_kwargs["use_session"] = False + sub_kwargs["transport"] = transport + + _, session_pipeline = self._create_pipeline( + credential, sdk_moniker=self._sdk_moniker, **sub_kwargs + ) + generated = AzureBlobStorage(container_url, self.api_version, base_url=container_url) + generated._client._pipeline = session_pipeline # pylint: disable=protected-access + return generated + + policies.append( + AsyncStorageSessionPolicy( + account_name=self.account_name, + session_client_factory=_session_client_factory, + use_session=True, + ) + ) + policies.extend([ config.logging_policy, AsyncStorageResponseHook(**kwargs), DistributedTracingPolicy(**kwargs), HttpLoggingPolicy(**kwargs), - ] + ]) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 0501d88c3b23..80769ba63f13 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -9,7 +9,7 @@ import random import re import uuid -from datetime import datetime, timedelta, UTC +from datetime import datetime, timedelta, timezone from io import BytesIO, SEEK_SET, UnsupportedOperation from time import time from threading import Lock @@ -69,6 +69,7 @@ SESSION_ELIGIBLE_CONTEXT_KEY = "_session_eligible" SESSION_RETRIED_CONTEXT_KEY = "_session_retried" SESSION_TOKEN_HEADER = "x-ms-session-token" +UTC = timezone.utc def encode_base64(data: Union[bytes, str]) -> str: @@ -884,11 +885,11 @@ class SessionCache: 
Concurrency model ----------------- - * Reads (`get`) are lock-free. They perform a single ``dict.get`` and never + * Reads (`get`) are lock-free. They perform a single dict.get and never mutate the cache, so concurrent readers never need to coordinate. * Writes (`put` / `put_fallback`) and the CreateSession single-flight are serialized per-container via the lock returned by :meth:`lock_container`. - * A single ``_locks_guard`` serializes only the *creation* of per-container + * A single _locks_guard serializes only the *creation* of per-container locks, so two threads racing on a brand-new container can't build two different lock objects. """ @@ -917,7 +918,7 @@ def lock_container(self, container_name: str) -> Lock: return self._locks.setdefault(container_name, Lock()) def get(self, container_name: str) -> Optional[Session]: - """Return a live session for the container, or ``None``. + """Return a live session for the container, or None. Lock-free and non-mutating. Expired entries are NOT deleted. Instead, they are simply treated as a cache miss and overwritten on the next refresh. @@ -932,7 +933,9 @@ def get(self, container_name: str) -> Optional[Session]: return cached def put(self, container_name: str, session_token: str, session_key: str, expires_at: datetime) -> None: - """Install a real session entry. Caller must hold ``lock_container``. + """Install a real session entry. + + Caller must hold the lock at the container-level. :param str container_name: The container the session belongs to. :param str session_token: The session token to send as a header. @@ -944,7 +947,7 @@ def put(self, container_name: str, session_token: str, session_key: str, expires def put_fallback(self, container_name: str) -> None: """Install a fallback-to-bearer sentinel for the cooldown window. - Caller must hold SessionCache.lock_container(). + Caller must hold the lock at the container-level. :param str container_name: The container to mark for bearer fallback. 
""" @@ -1130,7 +1133,7 @@ def on_request(self, request: "PipelineRequest") -> Optional[str]: """Stamp session auth if eligible, otherwise leave the bearer header intact. :param ~azure.core.pipeline.PipelineRequest request: The request to (maybe) sign. - :return: The container name if a session was applied, else ``None``. + :return: The container name if a session was applied, else None. :rtype: str or None """ if not self._is_eligible(request): @@ -1186,9 +1189,9 @@ def on_response( if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) with self._cache.lock_container(container_name): - self._cache.put_fallback(container_name) # drop the stale entry + self._cache.put_fallback(container_name) request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True - retried_container = self.on_request(request) # re-stamp (or fall to bearer) + retried_container = self.on_request(request) retried_response = self.next.send(request) return self.on_response(request, retried_response, retried_container) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index ee89c7c3dc9b..3833ed042585 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -8,7 +8,8 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging import random -from typing import Any, Dict, TYPE_CHECKING +from datetime import datetime +from typing import Any, Callable, Dict, Optional, Tuple, TYPE_CHECKING from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError from azure.core.pipeline.policies import ( @@ -21,9 +22,14 @@ from .policies import ( _prepare_content_validation, _validate_content_response, + CreateSessionConfiguration, 
encode_base64, is_retry, + Session, + SessionCache, StorageRetryPolicy, + StorageSessionPolicy, + SESSION_RETRIED_CONTEXT_KEY, ) from .streams_async import AsyncStructuredMessageDecoder from .validation import ( @@ -339,16 +345,192 @@ async def on_challenge(self, request: "PipelineRequest", response: "PipelineResp return True + +class AsyncSessionCache(SessionCache): + """Async variant of :class:`SessionCache`. + + Reuses the lock-free, non-mutating read path and the immutable + Session snapshots from the sync cache, but serializes the + per-container CreateSession single-flight with asynchronous locks. + """ + + def __init__(self) -> None: + super().__init__() + self._async_locks: Dict[str, asyncio.Lock] = {} + + def lock_container_async(self, container_name: str) -> asyncio.Lock: + """Return the per-container asyncio lock, creating it exactly once. + + Lock creation is not awaited, so it is safe to do without holding a + lock: the event loop guarantees this runs to completion without + interleaving other coroutines. + + :param str container_name: The container to get the lock for. + :return: The single asyncio lock associated with the container. + :rtype: ~asyncio.Lock + """ + existing = self._async_locks.get(container_name) + if existing is not None: + return existing + return self._async_locks.setdefault(container_name, asyncio.Lock()) + + class AsyncStorageSessionPolicy(AsyncHTTPPolicy): + """Constructs an AsyncStorageSessionPolicy. + + Eligible blob download GETs are authenticated with a per-container session + token; everything else is delegated unchanged to the bearer credential + policy that sits earlier in the pipeline. 
+ """ + + SESSIONS_UNAVAILABLE: str = StorageSessionPolicy.SESSIONS_UNAVAILABLE + """Service-reported code: session operations are temporarily unavailable.""" + FEATURE_NOT_ENABLED: str = StorageSessionPolicy.FEATURE_NOT_ENABLED + """Service-reported code: the session feature is not enabled on the scale unit.""" + + def __init__( + self, + *, + account_name: str, + session_client_factory: Callable[[str], Any], + use_session: bool = False, + ) -> None: + """Constructs an AsyncStorageSessionPolicy. + + :keyword str account_name: Storage account name; used as the signer + identity when signing session-authenticated requests. + :keyword session_client_factory: A callable that, given a container URL, + returns a session-disabled generated async client whose pipeline + uses OAuth/bearer auth. Invoked (and awaited) to issue CreateSession. + :paramtype session_client_factory: Callable[[str], Any] + :keyword bool use_session: Whether session authentication is enabled. + :raises ValueError: if account_name or session_client_factory is None. 
+ """ + if account_name is None or session_client_factory is None: + raise ValueError("account_name and session_client_factory are required.") + super().__init__() + self._account_name = account_name + self._session_client_factory = session_client_factory + self._use_session = use_session + self._cache = AsyncSessionCache() + self._signer = StorageSessionPolicy( + account_name=account_name, + session_client_factory=session_client_factory, + use_session=use_session, + ) + + async def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: + config = CreateSessionConfiguration(authentication_type="HMAC") + client = self._session_client_factory(container_url) + response = await client.container.create_session(create_session_configuration=config) + return StorageSessionPolicy._extract_session(response) # pylint: disable=protected-access + + async def _refresh_session_token(self, container_name: str, container_url: str) -> Optional[Session]: + """Acquire (or re-use) a session under per-container async single-flight. + + :param str container_name: The container key for the cache and lock. + :param str container_url: The container-scoped URL for the CreateSession call. + :return: A live session, a fallback sentinel, or None if unusable. 
+ :rtype: ~azure.storage.blob._shared.policies.Session or None + """ + async with self._cache.lock_container_async(container_name): + existing = self._cache.get(container_name) + if existing is not None and not existing.expired(): + return existing + try: + token, key, expires_at = await self._create_session(container_url) + self._cache.put(container_name, token, key, expires_at) + except Exception: # pylint: disable=broad-except + _LOGGER.warning( + "CreateSession failed for container '%s'; falling back to bearer for %d seconds.", + container_name, + int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), + exc_info=True, + ) + self._cache.put_fallback(container_name) + return self._cache.get(container_name) + + async def send(self, request: "PipelineRequest") -> "PipelineResponse": + """Orchestrate session auth. + + :param ~azure.core.pipeline.PipelineRequest request: The outgoing request. + :return: The pipeline response. + :rtype: ~azure.core.pipeline.PipelineResponse + """ + container_name = await self.on_request(request) + response = await self.next.send(request) + return await self.on_response(request, response, container_name) + + async def on_request(self, request: "PipelineRequest") -> Optional[str]: + """Stamp session auth if eligible, otherwise leave the bearer header intact. + + :param ~azure.core.pipeline.PipelineRequest request: The request to (maybe) sign. + :return: The container name if a session was applied, else None. 
+ :rtype: str or None + """ + if not self._signer._is_eligible(request): # pylint: disable=protected-access + return None + container_name = StorageSessionPolicy._parse_container( # pylint: disable=protected-access + request.http_request.url + ) + if not container_name: + return None - def __init__(self): - pass + session = self._cache.get(container_name) + if session is None: + container_url = StorageSessionPolicy._container_url( # pylint: disable=protected-access + request.http_request.url + ) + session = await self._refresh_session_token(container_name, container_url) - def on_request(self): - pass + if session is None or session.is_fallback or not session.session_token or not session.session_key: + return None - def send(self, request): - pass + self._signer._apply_session_auth( # pylint: disable=protected-access + request, session.session_token, session.session_key + ) + return container_name - def on_response(self): - pass + async def on_response( + self, + request: "PipelineRequest", + response: "PipelineResponse", + container_name: Optional[str], + ) -> "PipelineResponse": + """React to session-related failures: cooldown sentinel or one-shot re-acquire. + + :param ~azure.core.pipeline.PipelineRequest request: The original request. + :param ~azure.core.pipeline.PipelineResponse response: The response to inspect. + :param container_name: Container that was session-signed, or None if bearer was used. + :type container_name: str or None + :return: The final response (possibly from a one-shot retry). + :rtype: ~azure.core.pipeline.PipelineResponse + """ + if container_name is None: + return response # bearer was used; nothing session-related to react to + + status = response.http_response.status_code + error_code = response.http_response.headers.get("x-ms-error-code", "") + + # Unavailable / feature-off / 5xx → negative-cache cooldown. 
+ if error_code in (self.SESSIONS_UNAVAILABLE, self.FEATURE_NOT_ENABLED) or status >= 500: + _LOGGER.warning( + "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", + error_code or "5xx", status, container_name, + int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), + ) + async with self._cache.lock_container_async(container_name): + self._cache.put_fallback(container_name) + return response + + # 401 → invalidate + re-acquire ONCE, then resend. + if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): + _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) + async with self._cache.lock_container_async(container_name): + self._cache.put_fallback(container_name) + request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True + retried_container = await self.on_request(request) + retried_response = await self.next.send(request) + return await self.on_response(request, retried_response, retried_container) + + return response From 1e9171be0658e6e9d2a8ac777a3c043db11bc14e Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Tue, 2 Jun 2026 19:56:36 -0400 Subject: [PATCH 15/29] Async tests --- .../tests/test_container_async.py | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) diff --git a/sdk/storage/azure-storage-blob/tests/test_container_async.py b/sdk/storage/azure-storage-blob/tests/test_container_async.py index 2b97a45f4936..12e627cc8f29 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_container_async.py @@ -2687,3 +2687,130 @@ async def recursive_walk(prefix): # Assert assert blobs is not None assert blobs == ["a/b/blob2", "a/b/blob3", "a/b/blob4", "a/blob1"] + + @BlobPreparer() + @recorded_by_proxy_async + async def test_create_session(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + credential = self.get_credential(BlobServiceClient, is_async=True) + 
captured = {} + + def make_capture(label): + def _hook(response): + auth = response.http_request.headers.get("Authorization", "") + captured[label] = auth + return _hook + + def session_token_from(auth): + # "Session {token}:{signature}" -> token + assert auth.startswith("Session ") + return auth[len("Session "):].split(":", 1)[0] + + def find_session_policy(pipeline): + # Match by class name to avoid importing internals into the test module. + for p in getattr(pipeline, "_impl_policies", []): + if type(p).__name__ == "AsyncStorageSessionPolicy": + return p + raise AssertionError("AsyncStorageSessionPolicy not found on the pipeline") + + service = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + credential=credential, + use_session=True, + ) + container1_name = self.get_resource_name("utcontainer1") + container1 = service.get_container_client(container1_name) + try: + await container1.create_container() + except ResourceExistsError: + pass + + blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" + await container1.upload_blob( + blob1_name, blob1_data, overwrite=True, raw_response_hook=make_capture("c1_upload") + ) + assert captured["c1_upload"].startswith("Bearer ") + + blob1_actual = await (await container1.download_blob( + blob1_name, raw_response_hook=make_capture("c1_download"))).readall() + assert blob1_data == blob1_actual + assert captured["c1_download"].startswith("Session ") + session1 = session_token_from(captured["c1_download"]) + + container2_name = self.get_resource_name("utcontainer2") + container2 = service.get_container_client(container2_name) + try: + await container2.create_container() + except ResourceExistsError: + pass + + blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" + await container2.upload_blob( + blob2_name, blob2_data, overwrite=True, raw_response_hook=make_capture("c2_upload")) + assert captured["c2_upload"].startswith("Bearer ") + + blob2_actual = await (await 
container2.download_blob( + blob2_name, raw_response_hook=make_capture("c2_download"))).readall() + assert blob2_data == blob2_actual + assert captured["c2_download"].startswith("Session ") + session2 = session_token_from(captured["c2_download"]) + + assert session1 != session2 + + blob1_actual = await (await container1.download_blob( + blob1_name, raw_response_hook=make_capture("c1_download2"))).readall() + assert blob1_data == blob1_actual + assert captured["c1_download2"].startswith("Session ") + assert session1 == session_token_from(captured["c1_download2"]) + + blob2_actual = await (await container2.download_blob( + blob2_name, raw_response_hook=make_capture("c2_download2"))).readall() + assert blob2_data == blob2_actual + assert captured["c2_download2"].startswith("Session ") + assert session2 == session_token_from(captured["c2_download2"]) + + policy = find_session_policy(service._pipeline) + cached = policy._cache._entry[container1_name] + cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) + + blob1_actual = await (await container1.download_blob( + blob1_name, raw_response_hook=make_capture("c1_download3"))).readall() + assert blob1_data == blob1_actual + assert captured["c1_download3"].startswith("Session ") + assert session1 != session_token_from(captured["c1_download3"]) + assert session2 != session_token_from(captured["c1_download3"]) + + @BlobPreparer() + @recorded_by_proxy_async + async def test_sessions_disabled(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + + credential = self.get_credential(BlobServiceClient, is_async=True) + captured = {} + + def make_capture(label): + def _hook(response): + auth = response.http_request.headers.get("Authorization", "") + captured[label] = auth + return _hook + + service = BlobServiceClient( + self.account_url(storage_account_name, "blob"), + credential=credential, + use_session=False, + ) + container = 
service.get_container_client(self.get_resource_name("utcontainer")) + try: + await container.create_container() + except ResourceExistsError: + pass + + blob_name, blob_data = self.get_resource_name("blob"), b"abc123" + await container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=make_capture("upload")) + assert captured["upload"].startswith("Bearer ") + + blob_actual = await (await container.download_blob( + blob_name, raw_response_hook=make_capture("download"))).readall() + assert blob_data == blob_actual + assert captured["download"].startswith("Bearer ") From c7c77af4986b5fbc911635875e1809df397b0c30 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Wed, 3 Jun 2026 13:27:45 -0400 Subject: [PATCH 16/29] Feature not enabled --- .../azure/storage/blob/_shared/policies.py | 9 +++++++-- .../azure/storage/blob/_shared/policies_async.py | 7 ++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 80769ba63f13..708a867b3231 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -1174,13 +1174,18 @@ def on_response( status = response.http_response.status_code error_code = response.http_response.headers.get("x-ms-error-code", "") + if error_code == self.FEATURE_NOT_ENABLED: + _LOGGER.info("Session feature not enabled on this account; disabling session auth.") + self._use_session = False + return response + # Unavailable / feature-off / 5xx → negative-cache cooldown. 
- if error_code in (self.SESSIONS_UNAVAILABLE, self.FEATURE_NOT_ENABLED) or status >= 500: + if error_code == self.SESSIONS_UNAVAILABLE or status >= 500: _LOGGER.warning( "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", error_code or "5xx", status, container_name, int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), - ) + ) with self._cache.lock_container(container_name): self._cache.put_fallback(container_name) return response diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 3833ed042585..57c154ba186d 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -512,8 +512,13 @@ async def on_response( status = response.http_response.status_code error_code = response.http_response.headers.get("x-ms-error-code", "") + if error_code == self.FEATURE_NOT_ENABLED: + _LOGGER.info("Session feature not enabled on this account; disabling session auth.") + self._use_session = self._signer._use_session = False # pylint: disable=protected-access + return response + # Unavailable / feature-off / 5xx → negative-cache cooldown. 
- if error_code in (self.SESSIONS_UNAVAILABLE, self.FEATURE_NOT_ENABLED) or status >= 500: + if error_code in self.SESSIONS_UNAVAILABLE or status >= 500: _LOGGER.warning( "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", error_code or "5xx", status, container_name, From 6a1492df79b73d15ba7cebebe428edae68de7d76 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Wed, 3 Jun 2026 16:02:24 -0400 Subject: [PATCH 17/29] Feedback --- .../azure/storage/blob/_shared/policies.py | 79 +++++++------------ .../storage/blob/_shared/policies_async.py | 18 ++--- 2 files changed, 37 insertions(+), 60 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 708a867b3231..f334d2298313 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -66,9 +66,7 @@ SM_HEADER = "x-ms-structured-body" SM_HEADER_V1_CRC64 = "XSM/1.0; properties=crc64" SM_LENGTH_HEADER = "x-ms-structured-content-length" -SESSION_ELIGIBLE_CONTEXT_KEY = "_session_eligible" SESSION_RETRIED_CONTEXT_KEY = "_session_retried" -SESSION_TOKEN_HEADER = "x-ms-session-token" UTC = timezone.utc @@ -1004,29 +1002,20 @@ def __init__( self._cache = SessionCache() @staticmethod - def _parse_container(url: str) -> Optional[str]: - """Extract the container name (first path segment) from a request URL. - - :param str url: The request URL. - :return: The container name, or `None` for service-level URLs. - :rtype: str or None - """ - path = urlparse(url).path - segments = [seg for seg in path.split("/") if seg] - return segments[0] if segments else None - - @staticmethod - def _container_url(request_url: str) -> str: - """Build the container-scoped URL (scheme://host/container) for CreateSession. - - :param str request_url: The originating request URL. 
- :return: A URL pointing at the container root. - :rtype: str - """ - parsed = urlparse(request_url) + def _analyze_request(request: "PipelineRequest") -> Optional[Tuple[str, str]]: + http_request = request.http_request + if http_request.method != "GET": + return None + parsed = urlparse(http_request.url) segments = [seg for seg in parsed.path.split("/") if seg] - container = segments[0] if segments else "" - return f"{parsed.scheme}://{parsed.netloc}/{container}" + if len(segments) < 2: + return None + query = http_request.query + if "comp" in query or query.get("restype") == "container": + return None + container_name = segments[0] + container_url = f"{parsed.scheme}://{parsed.netloc}/{container_name}" + return container_name, container_url @staticmethod def _extract_session(response: Any) -> Tuple[str, str, datetime]: @@ -1072,20 +1061,13 @@ def _apply_session_auth(self, request: "PipelineRequest", session_token: str, se raise AzureSigningError(str(ex)) from ex http_request.headers["Authorization"] = f"Session {session_token}:{signature}" - def _is_eligible(self, request: "PipelineRequest") -> bool: - if not self._use_session: - return False - http_request = request.http_request - if http_request.method != "GET": - return False - parsed = urlparse(http_request.url) - segments = [s for s in parsed.path.split("/") if s] - if len(segments) < 2: - return False - query = http_request.query - if "comp" in query or query.get("restype") == "container": - return False - return True + def _resolve_session(self, container_name: str, container_url: str) -> Optional[Session]: + session = self._cache.get(container_name) + if session is None: + session = self._refresh_session_token(container_name, container_url) + if session is None or session.is_fallback: + return None + return session def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: config = CreateSessionConfiguration(authentication_type="HMAC") @@ -1136,18 +1118,15 @@ def on_request(self, 
request: "PipelineRequest") -> Optional[str]: :return: The container name if a session was applied, else None. :rtype: str or None """ - if not self._is_eligible(request): + if not self._use_session: return None - container_name = self._parse_container(request.http_request.url) - if not container_name: + analysis = self._analyze_request(request) + if analysis is None: return None + container_name, container_url = analysis - session = self._cache.get(container_name) - if session is None: - container_url = self._container_url(request.http_request.url) - session = self._refresh_session_token(container_name, container_url) - - if session is None or session.is_fallback or not session.session_token or not session.session_key: + session = self._resolve_session(container_name, container_url) + if session is None or not session.session_token or not session.session_key: return None self._apply_session_auth(request, session.session_token, session.session_key) @@ -1179,13 +1158,13 @@ def on_response( self._use_session = False return response - # Unavailable / feature-off / 5xx → negative-cache cooldown. - if error_code in self.SESSIONS_UNAVAILABLE or status >= 500: + # Unavailable / 5xx → negative-cache cooldown. 
+ if error_code == self.SESSIONS_UNAVAILABLE or status >= 500: _LOGGER.warning( "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", error_code or "5xx", status, container_name, int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), - ) + ) with self._cache.lock_container(container_name): self._cache.put_fallback(container_name) return response diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 57c154ba186d..6e1b652ebb26 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -468,19 +468,17 @@ async def on_request(self, request: "PipelineRequest") -> Optional[str]: :return: The container name if a session was applied, else None. :rtype: str or None """ - if not self._signer._is_eligible(request): # pylint: disable=protected-access + if not self._use_session: return None - container_name = StorageSessionPolicy._parse_container( # pylint: disable=protected-access - request.http_request.url - ) - if not container_name: + analysis = StorageSessionPolicy._analyze_request(request) # pylint: disable=protected-access + if analysis is None: return None + container_name, container_url = analysis session = self._cache.get(container_name) if session is None: - container_url = StorageSessionPolicy._container_url( # pylint: disable=protected-access - request.http_request.url - ) + # True miss/expiry (a live fallback sentinel is returned by get(), + # so we never reach refresh while the cooldown is active). 
session = await self._refresh_session_token(container_name, container_url) if session is None or session.is_fallback or not session.session_token or not session.session_key: @@ -517,8 +515,8 @@ async def on_response( self._use_session = False return response - # Unavailable / feature-off / 5xx → negative-cache cooldown. - if error_code in self.SESSIONS_UNAVAILABLE or status >= 500: + # Unavailable / 5xx → negative-cache cooldown. + if error_code == self.SESSIONS_UNAVAILABLE or status >= 500: _LOGGER.warning( "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", error_code or "5xx", status, container_name, From d86b2d96d9d2d524142f2824972bbf336340e314 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Wed, 3 Jun 2026 17:33:10 -0400 Subject: [PATCH 18/29] Got rid of broad except --- .../azure-storage-blob/azure/storage/blob/_shared/policies.py | 2 +- .../azure/storage/blob/_shared/policies_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index f334d2298313..204a7904db05 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -1090,7 +1090,7 @@ def _refresh_session_token(self, container_name: str, container_url: str) -> Opt try: token, key, expires_at = self._create_session(container_url) self._cache.put(container_name, token, key, expires_at) - except Exception: # pylint: disable=broad-except + except (AzureError, ValueError): _LOGGER.warning( "CreateSession failed for container '%s'; falling back to bearer for %d seconds.", container_name, diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 6e1b652ebb26..bf97044813f7 100644 --- 
a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -440,7 +440,7 @@ async def _refresh_session_token(self, container_name: str, container_url: str) try: token, key, expires_at = await self._create_session(container_url) self._cache.put(container_name, token, key, expires_at) - except Exception: # pylint: disable=broad-except + except (AzureError, ValueError): _LOGGER.warning( "CreateSession failed for container '%s'; falling back to bearer for %d seconds.", container_name, From 27003ebf015032f92a8c1d490e9b96bfd97b9270 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Wed, 3 Jun 2026 17:34:53 -0400 Subject: [PATCH 19/29] black --- .../azure/storage/blob/_shared/base_client.py | 18 ++++----- .../storage/blob/_shared/base_client_async.py | 18 ++++----- .../azure/storage/blob/_shared/policies.py | 18 +++++++-- .../storage/blob/_shared/policies_async.py | 8 ++-- .../azure-storage-blob/tests/conftest.py | 4 +- .../tests/test_container.py | 8 ++-- .../tests/test_container_async.py | 37 ++++++++++++------- 7 files changed, 65 insertions(+), 46 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 2a6682415932..98ffdea90fb0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -347,9 +347,7 @@ def _session_client_factory(container_url: str) -> AzureBlobStorage: sub_kwargs["use_session"] = False sub_kwargs["transport"] = transport - _, session_pipeline = self._create_pipeline( - credential, sdk_moniker=self._sdk_moniker, **sub_kwargs - ) + _, session_pipeline = self._create_pipeline(credential, sdk_moniker=self._sdk_moniker, **sub_kwargs) generated = AzureBlobStorage(container_url, self.api_version, base_url=container_url) 
generated._client._pipeline = session_pipeline # pylint: disable=protected-access return generated @@ -361,12 +359,14 @@ def _session_client_factory(container_url: str) -> AzureBlobStorage: use_session=True, ) ) - policies.extend([ - config.logging_policy, - StorageResponseHook(**kwargs), - DistributedTracingPolicy(**kwargs), - HttpLoggingPolicy(**kwargs), - ]) + policies.extend( + [ + config.logging_policy, + StorageResponseHook(**kwargs), + DistributedTracingPolicy(**kwargs), + HttpLoggingPolicy(**kwargs), + ] + ) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 1ae5e1861808..26e2afb9209f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -161,9 +161,7 @@ def _session_client_factory(container_url: str) -> AzureBlobStorage: sub_kwargs["use_session"] = False sub_kwargs["transport"] = transport - _, session_pipeline = self._create_pipeline( - credential, sdk_moniker=self._sdk_moniker, **sub_kwargs - ) + _, session_pipeline = self._create_pipeline(credential, sdk_moniker=self._sdk_moniker, **sub_kwargs) generated = AzureBlobStorage(container_url, self.api_version, base_url=container_url) generated._client._pipeline = session_pipeline # pylint: disable=protected-access return generated @@ -175,12 +173,14 @@ def _session_client_factory(container_url: str) -> AzureBlobStorage: use_session=True, ) ) - policies.extend([ - config.logging_policy, - AsyncStorageResponseHook(**kwargs), - DistributedTracingPolicy(**kwargs), - HttpLoggingPolicy(**kwargs), - ]) + policies.extend( + [ + config.logging_policy, + AsyncStorageResponseHook(**kwargs), + 
DistributedTracingPolicy(**kwargs), + HttpLoggingPolicy(**kwargs), + ] + ) if kwargs.get("_additional_pipeline_policies"): policies = policies + kwargs.get("_additional_pipeline_policies") # type: ignore config.transport = transport # type: ignore diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 204a7904db05..9195666a355d 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -969,9 +969,17 @@ class StorageSessionPolicy(HTTPPolicy): FEATURE_NOT_ENABLED: str = "FeatureNotEnabled" """Service-reported code: the session feature is not enabled on the scale unit.""" _SIGNED_HEADERS = ( - "content-encoding", "content-language", "content-length", "content-md5", - "content-type", "date", "if-modified-since", "if-match", "if-none-match", - "if-unmodified-since", "byte_range", + "content-encoding", + "content-language", + "content-length", + "content-md5", + "content-type", + "date", + "if-modified-since", + "if-match", + "if-none-match", + "if-unmodified-since", + "byte_range", ) def __init__( @@ -1162,7 +1170,9 @@ def on_response( if error_code == self.SESSIONS_UNAVAILABLE or status >= 500: _LOGGER.warning( "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", - error_code or "5xx", status, container_name, + error_code or "5xx", + status, + container_name, int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), ) with self._cache.lock_container(container_name): diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index bf97044813f7..83498988a7b0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ 
b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -228,7 +228,7 @@ def __init__( retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs + **kwargs, ) -> None: """ Constructs an Exponential retry object. The initial_backoff is used for @@ -288,7 +288,7 @@ def __init__( retry_total: int = 3, retry_to_secondary: bool = False, random_jitter_range: int = 3, - **kwargs: Any + **kwargs: Any, ) -> None: """ Constructs a Linear retry object. @@ -519,7 +519,9 @@ async def on_response( if error_code == self.SESSIONS_UNAVAILABLE or status >= 500: _LOGGER.warning( "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", - error_code or "5xx", status, container_name, + error_code or "5xx", + status, + container_name, int(SessionCache.FALLBACK_COOLDOWN.total_seconds()), ) async with self._cache.lock_container_async(container_name): diff --git a/sdk/storage/azure-storage-blob/tests/conftest.py b/sdk/storage/azure-storage-blob/tests/conftest.py index 205e8ae2a6fa..be6bb2be485a 100644 --- a/sdk/storage/azure-storage-blob/tests/conftest.py +++ b/sdk/storage/azure-storage-blob/tests/conftest.py @@ -33,9 +33,7 @@ def add_sanitizers(test_proxy): add_general_regex_sanitizer( regex=r"[^<]*", value="Sanitized" ) - add_general_regex_sanitizer( - regex=r"[^<]*", value="U2FuaXRpemVk" - ) + add_general_regex_sanitizer(regex=r"[^<]*", value="U2FuaXRpemVk") add_general_regex_sanitizer(regex=r'"EncryptionLibrary": "Python .*?"', value='"EncryptionLibrary": "Python x.x.x"') add_uri_regex_sanitizer(regex=r"\.preprod\.", value=".") diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index 8a430fb2090c..854d485ecd01 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2757,12 +2757,13 @@ def make_capture(label): def _hook(response): auth = 
response.http_request.headers.get("Authorization", "") captured[label] = auth + return _hook def session_token_from(auth): # "Session {token}:{signature}" -> token assert auth.startswith("Session ") - return auth[len("Session "):].split(":", 1)[0] + return auth[len("Session ") :].split(":", 1)[0] def find_session_policy(pipeline): # Match by class name to avoid importing internals into the test module. @@ -2784,9 +2785,7 @@ def find_session_policy(pipeline): pass blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" - container1.upload_blob( - blob1_name, blob1_data, overwrite=True, raw_response_hook=make_capture("c1_upload") - ) + container1.upload_blob(blob1_name, blob1_data, overwrite=True, raw_response_hook=make_capture("c1_upload")) assert captured["c1_upload"].startswith("Bearer ") blob1_actual = container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download")).readall() @@ -2844,6 +2843,7 @@ def make_capture(label): def _hook(response): auth = response.http_request.headers.get("Authorization", "") captured[label] = auth + return _hook service = BlobServiceClient( diff --git a/sdk/storage/azure-storage-blob/tests/test_container_async.py b/sdk/storage/azure-storage-blob/tests/test_container_async.py index 83e6b93fed7e..d4aba74c92f2 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_container_async.py @@ -2666,12 +2666,13 @@ def make_capture(label): def _hook(response): auth = response.http_request.headers.get("Authorization", "") captured[label] = auth + return _hook def session_token_from(auth): # "Session {token}:{signature}" -> token assert auth.startswith("Session ") - return auth[len("Session "):].split(":", 1)[0] + return auth[len("Session ") :].split(":", 1)[0] def find_session_policy(pipeline): # Match by class name to avoid importing internals into the test module. 
@@ -2698,8 +2699,9 @@ def find_session_policy(pipeline): ) assert captured["c1_upload"].startswith("Bearer ") - blob1_actual = await (await container1.download_blob( - blob1_name, raw_response_hook=make_capture("c1_download"))).readall() + blob1_actual = await ( + await container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download")) + ).readall() assert blob1_data == blob1_actual assert captured["c1_download"].startswith("Session ") session1 = session_token_from(captured["c1_download"]) @@ -2713,25 +2715,29 @@ def find_session_policy(pipeline): blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" await container2.upload_blob( - blob2_name, blob2_data, overwrite=True, raw_response_hook=make_capture("c2_upload")) + blob2_name, blob2_data, overwrite=True, raw_response_hook=make_capture("c2_upload") + ) assert captured["c2_upload"].startswith("Bearer ") - blob2_actual = await (await container2.download_blob( - blob2_name, raw_response_hook=make_capture("c2_download"))).readall() + blob2_actual = await ( + await container2.download_blob(blob2_name, raw_response_hook=make_capture("c2_download")) + ).readall() assert blob2_data == blob2_actual assert captured["c2_download"].startswith("Session ") session2 = session_token_from(captured["c2_download"]) assert session1 != session2 - blob1_actual = await (await container1.download_blob( - blob1_name, raw_response_hook=make_capture("c1_download2"))).readall() + blob1_actual = await ( + await container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download2")) + ).readall() assert blob1_data == blob1_actual assert captured["c1_download2"].startswith("Session ") assert session1 == session_token_from(captured["c1_download2"]) - blob2_actual = await (await container2.download_blob( - blob2_name, raw_response_hook=make_capture("c2_download2"))).readall() + blob2_actual = await ( + await container2.download_blob(blob2_name, raw_response_hook=make_capture("c2_download2")) + 
).readall() assert blob2_data == blob2_actual assert captured["c2_download2"].startswith("Session ") assert session2 == session_token_from(captured["c2_download2"]) @@ -2740,8 +2746,9 @@ def find_session_policy(pipeline): cached = policy._cache._entry[container1_name] cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) - blob1_actual = await (await container1.download_blob( - blob1_name, raw_response_hook=make_capture("c1_download3"))).readall() + blob1_actual = await ( + await container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download3")) + ).readall() assert blob1_data == blob1_actual assert captured["c1_download3"].startswith("Session ") assert session1 != session_token_from(captured["c1_download3"]) @@ -2759,6 +2766,7 @@ def make_capture(label): def _hook(response): auth = response.http_request.headers.get("Authorization", "") captured[label] = auth + return _hook service = BlobServiceClient( @@ -2776,7 +2784,8 @@ def _hook(response): await container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=make_capture("upload")) assert captured["upload"].startswith("Bearer ") - blob_actual = await (await container.download_blob( - blob_name, raw_response_hook=make_capture("download"))).readall() + blob_actual = await ( + await container.download_blob(blob_name, raw_response_hook=make_capture("download")) + ).readall() assert blob_data == blob_actual assert captured["download"].startswith("Bearer ") From e13d8d5e9b1092c900125b40a1896f8d9870ebaa Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 02:29:02 -0400 Subject: [PATCH 20/29] Rename --- .../tests/test_container.py | 36 +++++++++---------- .../tests/test_container_async.py | 36 +++++++++---------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index 854d485ecd01..68f900dff2c1 100644 --- 
a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2753,14 +2753,14 @@ def test_create_session(self, **kwargs): credential = self.get_credential(BlobServiceClient) captured = {} - def make_capture(label): + def capture_auth_header(label): def _hook(response): auth = response.http_request.headers.get("Authorization", "") captured[label] = auth return _hook - def session_token_from(auth): + def parse_session_token(auth): # "Session {token}:{signature}" -> token assert auth.startswith("Session ") return auth[len("Session ") :].split(":", 1)[0] @@ -2785,13 +2785,13 @@ def find_session_policy(pipeline): pass blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" - container1.upload_blob(blob1_name, blob1_data, overwrite=True, raw_response_hook=make_capture("c1_upload")) + container1.upload_blob(blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header("c1_upload")) assert captured["c1_upload"].startswith("Bearer ") - blob1_actual = container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download")).readall() + blob1_actual = container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download")).readall() assert blob1_data == blob1_actual assert captured["c1_download"].startswith("Session ") - session1 = session_token_from(captured["c1_download"]) + session1 = parse_session_token(captured["c1_download"]) container2_name = self.get_resource_name("utcontainer2") container2 = service.get_container_client(container2_name) @@ -2801,35 +2801,35 @@ def find_session_policy(pipeline): pass blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" - container2.upload_blob(blob2_name, blob2_data, overwrite=True, raw_response_hook=make_capture("c2_upload")) + container2.upload_blob(blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header("c2_upload")) assert captured["c2_upload"].startswith("Bearer ") - 
blob2_actual = container2.download_blob(blob2_name, raw_response_hook=make_capture("c2_download")).readall() + blob2_actual = container2.download_blob(blob2_name, raw_response_hook=capture_auth_header("c2_download")).readall() assert blob2_data == blob2_actual assert captured["c2_download"].startswith("Session ") - session2 = session_token_from(captured["c2_download"]) + session2 = parse_session_token(captured["c2_download"]) assert session1 != session2 - blob1_actual = container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download2")).readall() + blob1_actual = container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download2")).readall() assert blob1_data == blob1_actual assert captured["c1_download2"].startswith("Session ") - assert session1 == session_token_from(captured["c1_download2"]) + assert session1 == parse_session_token(captured["c1_download2"]) - blob2_actual = container2.download_blob(blob2_name, raw_response_hook=make_capture("c2_download2")).readall() + blob2_actual = container2.download_blob(blob2_name, raw_response_hook=capture_auth_header("c2_download2")).readall() assert blob2_data == blob2_actual assert captured["c2_download2"].startswith("Session ") - assert session2 == session_token_from(captured["c2_download2"]) + assert session2 == parse_session_token(captured["c2_download2"]) policy = find_session_policy(service._pipeline) cached = policy._cache._entry[container1_name] cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) - blob1_actual = container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download3")).readall() + blob1_actual = container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download3")).readall() assert blob1_data == blob1_actual assert captured["c1_download3"].startswith("Session ") - assert session1 != session_token_from(captured["c1_download3"]) - assert session2 != session_token_from(captured["c1_download3"]) + assert 
session1 != parse_session_token(captured["c1_download3"]) + assert session2 != parse_session_token(captured["c1_download3"]) @BlobPreparer() @recorded_by_proxy @@ -2839,7 +2839,7 @@ def test_sessions_disabled(self, **kwargs): credential = self.get_credential(BlobServiceClient) captured = {} - def make_capture(label): + def capture_auth_header(label): def _hook(response): auth = response.http_request.headers.get("Authorization", "") captured[label] = auth @@ -2858,9 +2858,9 @@ def _hook(response): pass blob_name, blob_data = self.get_resource_name("blob"), b"abc123" - container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=make_capture("upload")) + container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=capture_auth_header("upload")) assert captured["upload"].startswith("Bearer ") - blob_actual = container.download_blob(blob_name, raw_response_hook=make_capture("download")).readall() + blob_actual = container.download_blob(blob_name, raw_response_hook=capture_auth_header("download")).readall() assert blob_data == blob_actual assert captured["download"].startswith("Bearer ") diff --git a/sdk/storage/azure-storage-blob/tests/test_container_async.py b/sdk/storage/azure-storage-blob/tests/test_container_async.py index d4aba74c92f2..402be270ffb6 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_container_async.py @@ -2662,14 +2662,14 @@ async def test_create_session(self, **kwargs): credential = self.get_credential(BlobServiceClient, is_async=True) captured = {} - def make_capture(label): + def capture_auth_header(label): def _hook(response): auth = response.http_request.headers.get("Authorization", "") captured[label] = auth return _hook - def session_token_from(auth): + def parse_session_token(auth): # "Session {token}:{signature}" -> token assert auth.startswith("Session ") return auth[len("Session ") :].split(":", 1)[0] @@ -2695,16 +2695,16 @@ def 
find_session_policy(pipeline): blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" await container1.upload_blob( - blob1_name, blob1_data, overwrite=True, raw_response_hook=make_capture("c1_upload") + blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header("c1_upload") ) assert captured["c1_upload"].startswith("Bearer ") blob1_actual = await ( - await container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download")) + await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download")) ).readall() assert blob1_data == blob1_actual assert captured["c1_download"].startswith("Session ") - session1 = session_token_from(captured["c1_download"]) + session1 = parse_session_token(captured["c1_download"]) container2_name = self.get_resource_name("utcontainer2") container2 = service.get_container_client(container2_name) @@ -2715,44 +2715,44 @@ def find_session_policy(pipeline): blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" await container2.upload_blob( - blob2_name, blob2_data, overwrite=True, raw_response_hook=make_capture("c2_upload") + blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header("c2_upload") ) assert captured["c2_upload"].startswith("Bearer ") blob2_actual = await ( - await container2.download_blob(blob2_name, raw_response_hook=make_capture("c2_download")) + await container2.download_blob(blob2_name, raw_response_hook=capture_auth_header("c2_download")) ).readall() assert blob2_data == blob2_actual assert captured["c2_download"].startswith("Session ") - session2 = session_token_from(captured["c2_download"]) + session2 = parse_session_token(captured["c2_download"]) assert session1 != session2 blob1_actual = await ( - await container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download2")) + await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download2")) ).readall() assert blob1_data 
== blob1_actual assert captured["c1_download2"].startswith("Session ") - assert session1 == session_token_from(captured["c1_download2"]) + assert session1 == parse_session_token(captured["c1_download2"]) blob2_actual = await ( - await container2.download_blob(blob2_name, raw_response_hook=make_capture("c2_download2")) + await container2.download_blob(blob2_name, raw_response_hook=capture_auth_header("c2_download2")) ).readall() assert blob2_data == blob2_actual assert captured["c2_download2"].startswith("Session ") - assert session2 == session_token_from(captured["c2_download2"]) + assert session2 == parse_session_token(captured["c2_download2"]) policy = find_session_policy(service._pipeline) cached = policy._cache._entry[container1_name] cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) blob1_actual = await ( - await container1.download_blob(blob1_name, raw_response_hook=make_capture("c1_download3")) + await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download3")) ).readall() assert blob1_data == blob1_actual assert captured["c1_download3"].startswith("Session ") - assert session1 != session_token_from(captured["c1_download3"]) - assert session2 != session_token_from(captured["c1_download3"]) + assert session1 != parse_session_token(captured["c1_download3"]) + assert session2 != parse_session_token(captured["c1_download3"]) @BlobPreparer() @recorded_by_proxy_async @@ -2762,7 +2762,7 @@ async def test_sessions_disabled(self, **kwargs): credential = self.get_credential(BlobServiceClient, is_async=True) captured = {} - def make_capture(label): + def capture_auth_header(label): def _hook(response): auth = response.http_request.headers.get("Authorization", "") captured[label] = auth @@ -2781,11 +2781,11 @@ def _hook(response): pass blob_name, blob_data = self.get_resource_name("blob"), b"abc123" - await container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=make_capture("upload")) + 
await container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=capture_auth_header("upload")) assert captured["upload"].startswith("Bearer ") blob_actual = await ( - await container.download_blob(blob_name, raw_response_hook=make_capture("download")) + await container.download_blob(blob_name, raw_response_hook=capture_auth_header("download")) ).readall() assert blob_data == blob_actual assert captured["download"].startswith("Bearer ") From 47dfd8cdb7220985c4fa0379e58b06d0cea18f23 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 10:44:25 -0400 Subject: [PATCH 21/29] More feedback --- .../azure/storage/blob/_shared/base_client.py | 5 +---- .../azure/storage/blob/_shared/base_client_async.py | 5 +---- .../azure/storage/blob/_shared/models.py | 2 ++ .../azure/storage/blob/_shared/policies.py | 8 ++------ .../azure/storage/blob/_shared/policies_async.py | 10 +++------- 5 files changed, 9 insertions(+), 21 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 98ffdea90fb0..64e6ddff465f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -340,14 +340,11 @@ def _create_pipeline( def _session_client_factory(container_url: str) -> AzureBlobStorage: sub_kwargs = dict(kwargs) - sub_kwargs.pop("_pipeline", None) sub_kwargs.pop("_configuration", None) sub_kwargs.pop("pipeline", None) - sub_kwargs.pop("sdk_moniker", None) - sub_kwargs["use_session"] = False sub_kwargs["transport"] = transport - _, session_pipeline = self._create_pipeline(credential, sdk_moniker=self._sdk_moniker, **sub_kwargs) + _, session_pipeline = self._create_pipeline(credential, **sub_kwargs) generated = AzureBlobStorage(container_url, self.api_version, base_url=container_url) generated._client._pipeline = session_pipeline # pylint: 
disable=protected-access return generated diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 26e2afb9209f..6287bd416510 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -154,14 +154,11 @@ def _create_pipeline( def _session_client_factory(container_url: str) -> AzureBlobStorage: sub_kwargs = dict(kwargs) - sub_kwargs.pop("_pipeline", None) sub_kwargs.pop("_configuration", None) sub_kwargs.pop("pipeline", None) - sub_kwargs.pop("sdk_moniker", None) - sub_kwargs["use_session"] = False sub_kwargs["transport"] = transport - _, session_pipeline = self._create_pipeline(credential, sdk_moniker=self._sdk_moniker, **sub_kwargs) + _, session_pipeline = self._create_pipeline(credential, **sub_kwargs) generated = AzureBlobStorage(container_url, self.api_version, base_url=container_url) generated._client._pipeline = session_pipeline # pylint: disable=protected-access return generated diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py index 23786baef24b..4d455df080e2 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/models.py @@ -34,6 +34,7 @@ class StorageErrorCode(str, Enum, metaclass=CaseInsensitiveEnumMeta): CONDITION_HEADERS_NOT_SUPPORTED = "ConditionHeadersNotSupported" CONDITION_NOT_MET = "ConditionNotMet" EMPTY_METADATA_KEY = "EmptyMetadataKey" + FEATURE_NOT_ENABLED = "FeatureNotEnabled" INSUFFICIENT_ACCOUNT_PERMISSIONS = "InsufficientAccountPermissions" INTERNAL_ERROR = "InternalError" INVALID_AUTHENTICATION_INFO = "InvalidAuthenticationInfo" @@ -64,6 +65,7 @@ class StorageErrorCode(str, Enum, 
metaclass=CaseInsensitiveEnumMeta): RESOURCE_ALREADY_EXISTS = "ResourceAlreadyExists" RESOURCE_NOT_FOUND = "ResourceNotFound" SERVER_BUSY = "ServerBusy" + SESSIONS_UNAVAILABLE = "SessionOperationsTemporarilyUnavailable" UNSUPPORTED_HEADER = "UnsupportedHeader" UNSUPPORTED_XML_NODE = "UnsupportedXmlNode" UNSUPPORTED_QUERY_PARAMETER = "UnsupportedQueryParameter" diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 9195666a355d..b3fd7195bf1c 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -964,10 +964,6 @@ class StorageSessionPolicy(HTTPPolicy): When disabled, all requests are delegated to the bearer token policy. """ - SESSIONS_UNAVAILABLE: str = "SessionOperationsTemporarilyUnavailable" - """Service-reported code: session operations are temporarily unavailable.""" - FEATURE_NOT_ENABLED: str = "FeatureNotEnabled" - """Service-reported code: the session feature is not enabled on the scale unit.""" _SIGNED_HEADERS = ( "content-encoding", "content-language", @@ -1161,13 +1157,13 @@ def on_response( status = response.http_response.status_code error_code = response.http_response.headers.get("x-ms-error-code", "") - if error_code == self.FEATURE_NOT_ENABLED: + if error_code == StorageErrorCode.FEATURE_NOT_ENABLED: _LOGGER.info("Session feature not enabled on this account; disabling session auth.") self._use_session = False return response # Unavailable / 5xx → negative-cache cooldown. 
- if error_code == self.SESSIONS_UNAVAILABLE or status >= 500: + if error_code == StorageErrorCode.SESSIONS_UNAVAILABLE or status >= 500: _LOGGER.warning( "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", error_code or "5xx", diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 83498988a7b0..e37118ff5288 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -19,6 +19,7 @@ from .authentication import AzureSigningError, StorageHttpChallenge from .constants import DEFAULT_OAUTH_SCOPE +from .models import StorageErrorCode from .policies import ( _prepare_content_validation, _validate_content_response, @@ -383,11 +384,6 @@ class AsyncStorageSessionPolicy(AsyncHTTPPolicy): policy that sits earlier in the pipeline. """ - SESSIONS_UNAVAILABLE: str = StorageSessionPolicy.SESSIONS_UNAVAILABLE - """Service-reported code: session operations are temporarily unavailable.""" - FEATURE_NOT_ENABLED: str = StorageSessionPolicy.FEATURE_NOT_ENABLED - """Service-reported code: the session feature is not enabled on the scale unit.""" - def __init__( self, *, @@ -510,13 +506,13 @@ async def on_response( status = response.http_response.status_code error_code = response.http_response.headers.get("x-ms-error-code", "") - if error_code == self.FEATURE_NOT_ENABLED: + if error_code == StorageErrorCode.FEATURE_NOT_ENABLED: _LOGGER.info("Session feature not enabled on this account; disabling session auth.") self._use_session = False return response # Unavailable / 5xx → negative-cache cooldown. 
- if error_code == self.SESSIONS_UNAVAILABLE or status >= 500: + if error_code == StorageErrorCode.SESSIONS_UNAVAILABLE or status >= 500: _LOGGER.warning( "Session authentication: '%s' (HTTP %d) on container '%s'; bearer fallback for %d seconds.", error_code or "5xx", From 9f3796caafcc4d67cf9fc350c1503b0975277ec0 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 13:08:35 -0400 Subject: [PATCH 22/29] Passing pipeline directly --- .../azure/storage/blob/_shared/base_client.py | 5 ++-- .../storage/blob/_shared/base_client_async.py | 5 ++-- .../tests/test_container.py | 28 ++++++++++++++----- .../tests/test_container_async.py | 4 ++- 4 files changed, 30 insertions(+), 12 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 64e6ddff465f..99cfd7c1ca80 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -345,8 +345,9 @@ def _session_client_factory(container_url: str) -> AzureBlobStorage: sub_kwargs["transport"] = transport _, session_pipeline = self._create_pipeline(credential, **sub_kwargs) - generated = AzureBlobStorage(container_url, self.api_version, base_url=container_url) - generated._client._pipeline = session_pipeline # pylint: disable=protected-access + generated = AzureBlobStorage( + container_url, self.api_version, base_url=container_url, pipeline=session_pipeline + ) return generated policies.append( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 6287bd416510..23c0f655d391 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -159,8 +159,9 @@ def 
_session_client_factory(container_url: str) -> AzureBlobStorage: sub_kwargs["transport"] = transport _, session_pipeline = self._create_pipeline(credential, **sub_kwargs) - generated = AzureBlobStorage(container_url, self.api_version, base_url=container_url) - generated._client._pipeline = session_pipeline # pylint: disable=protected-access + generated = AzureBlobStorage( + container_url, self.api_version, base_url=container_url, pipeline=session_pipeline + ) return generated policies.append( diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index 68f900dff2c1..bb413ca151c3 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -2785,10 +2785,14 @@ def find_session_policy(pipeline): pass blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" - container1.upload_blob(blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header("c1_upload")) + container1.upload_blob( + blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header("c1_upload") + ) assert captured["c1_upload"].startswith("Bearer ") - blob1_actual = container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download")).readall() + blob1_actual = container1.download_blob( + blob1_name, raw_response_hook=capture_auth_header("c1_download") + ).readall() assert blob1_data == blob1_actual assert captured["c1_download"].startswith("Session ") session1 = parse_session_token(captured["c1_download"]) @@ -2801,22 +2805,30 @@ def find_session_policy(pipeline): pass blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" - container2.upload_blob(blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header("c2_upload")) + container2.upload_blob( + blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header("c2_upload") + ) assert 
captured["c2_upload"].startswith("Bearer ") - blob2_actual = container2.download_blob(blob2_name, raw_response_hook=capture_auth_header("c2_download")).readall() + blob2_actual = container2.download_blob( + blob2_name, raw_response_hook=capture_auth_header("c2_download") + ).readall() assert blob2_data == blob2_actual assert captured["c2_download"].startswith("Session ") session2 = parse_session_token(captured["c2_download"]) assert session1 != session2 - blob1_actual = container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download2")).readall() + blob1_actual = container1.download_blob( + blob1_name, raw_response_hook=capture_auth_header("c1_download2") + ).readall() assert blob1_data == blob1_actual assert captured["c1_download2"].startswith("Session ") assert session1 == parse_session_token(captured["c1_download2"]) - blob2_actual = container2.download_blob(blob2_name, raw_response_hook=capture_auth_header("c2_download2")).readall() + blob2_actual = container2.download_blob( + blob2_name, raw_response_hook=capture_auth_header("c2_download2") + ).readall() assert blob2_data == blob2_actual assert captured["c2_download2"].startswith("Session ") assert session2 == parse_session_token(captured["c2_download2"]) @@ -2825,7 +2837,9 @@ def find_session_policy(pipeline): cached = policy._cache._entry[container1_name] cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) - blob1_actual = container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download3")).readall() + blob1_actual = container1.download_blob( + blob1_name, raw_response_hook=capture_auth_header("c1_download3") + ).readall() assert blob1_data == blob1_actual assert captured["c1_download3"].startswith("Session ") assert session1 != parse_session_token(captured["c1_download3"]) diff --git a/sdk/storage/azure-storage-blob/tests/test_container_async.py b/sdk/storage/azure-storage-blob/tests/test_container_async.py index 
402be270ffb6..88f2a242a2d3 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_container_async.py @@ -2781,7 +2781,9 @@ def _hook(response): pass blob_name, blob_data = self.get_resource_name("blob"), b"abc123" - await container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=capture_auth_header("upload")) + await container.upload_blob( + blob_name, blob_data, overwrite=True, raw_response_hook=capture_auth_header("upload") + ) assert captured["upload"].startswith("Bearer ") blob_actual = await ( From edb67d3bb563194e9d4cafa1c1e4d76c3064bef9 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 13:36:53 -0400 Subject: [PATCH 23/29] Removed the use_session kwarg in policies --- .../azure/storage/blob/_shared/base_client.py | 1 - .../azure/storage/blob/_shared/base_client_async.py | 1 - .../azure/storage/blob/_shared/policies.py | 9 +++------ .../azure/storage/blob/_shared/policies_async.py | 9 +++------ 4 files changed, 6 insertions(+), 14 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py index 99cfd7c1ca80..20398952eea7 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client.py @@ -354,7 +354,6 @@ def _session_client_factory(container_url: str) -> AzureBlobStorage: StorageSessionPolicy( account_name=self.account_name, session_client_factory=_session_client_factory, - use_session=True, ) ) policies.extend( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py index 23c0f655d391..98e961a3fad0 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py +++ 
b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/base_client_async.py @@ -168,7 +168,6 @@ def _session_client_factory(container_url: str) -> AzureBlobStorage: AsyncStorageSessionPolicy( account_name=self.account_name, session_client_factory=_session_client_factory, - use_session=True, ) ) policies.extend( diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index b3fd7195bf1c..bbb5fa459485 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -983,7 +983,6 @@ def __init__( *, account_name: str, session_client_factory: Callable[[str], Any], - use_session: bool = False, ) -> None: """Constructs a StorageSessionPolicy. @@ -993,8 +992,6 @@ def __init__( returns a session-disabled generated client (AzureBlobStorage) whose pipeline uses OAuth/bearer auth. Invoked to issue CreateSession. :paramtype session_client_factory: Callable[[str], Any] - :keyword bool use_session: Whether session authentication is enabled. - When set to False, the policy is a pass-through no-op. :raises ValueError: if `account_name` or `session_client_factory` is `None`. """ if account_name is None or session_client_factory is None: @@ -1002,7 +999,7 @@ def __init__( super().__init__() self._account_name = account_name self._session_client_factory = session_client_factory - self._use_session = use_session + self._enabled = True self._cache = SessionCache() @staticmethod @@ -1122,7 +1119,7 @@ def on_request(self, request: "PipelineRequest") -> Optional[str]: :return: The container name if a session was applied, else None. 
:rtype: str or None """ - if not self._use_session: + if not self._enabled: return None analysis = self._analyze_request(request) if analysis is None: @@ -1159,7 +1156,7 @@ def on_response( if error_code == StorageErrorCode.FEATURE_NOT_ENABLED: _LOGGER.info("Session feature not enabled on this account; disabling session auth.") - self._use_session = False + self._enabled = False return response # Unavailable / 5xx → negative-cache cooldown. diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index e37118ff5288..dbfa617536a6 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -389,7 +389,6 @@ def __init__( *, account_name: str, session_client_factory: Callable[[str], Any], - use_session: bool = False, ) -> None: """Constructs an AsyncStorageSessionPolicy. @@ -399,7 +398,6 @@ def __init__( returns a session-disabled generated async client whose pipeline uses OAuth/bearer auth. Invoked (and awaited) to issue CreateSession. :paramtype session_client_factory: Callable[[str], Any] - :keyword bool use_session: Whether session authentication is enabled. :raises ValueError: if account_name or session_client_factory is None. 
""" if account_name is None or session_client_factory is None: @@ -407,12 +405,11 @@ def __init__( super().__init__() self._account_name = account_name self._session_client_factory = session_client_factory - self._use_session = use_session + self._enabled = True self._cache = AsyncSessionCache() self._signer = StorageSessionPolicy( account_name=account_name, session_client_factory=session_client_factory, - use_session=use_session, ) async def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: @@ -464,7 +461,7 @@ async def on_request(self, request: "PipelineRequest") -> Optional[str]: :return: The container name if a session was applied, else None. :rtype: str or None """ - if not self._use_session: + if not self._enabled: return None analysis = StorageSessionPolicy._analyze_request(request) # pylint: disable=protected-access if analysis is None: @@ -508,7 +505,7 @@ async def on_response( if error_code == StorageErrorCode.FEATURE_NOT_ENABLED: _LOGGER.info("Session feature not enabled on this account; disabling session auth.") - self._use_session = False + self._enabled = False return response # Unavailable / 5xx → negative-cache cooldown. 
From 8cce52442aacb0d1818d02bcbc9c71143a4f285c Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 14:07:51 -0400 Subject: [PATCH 24/29] Actually retry on 401 --- .../azure/storage/blob/_shared/policies.py | 19 ++++++++++++++++++- .../storage/blob/_shared/policies_async.py | 3 ++- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index bbb5fa459485..593e02a481f1 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -953,6 +953,14 @@ def put_fallback(self, container_name: str) -> None: None, None, datetime.now(UTC) + SessionCache.FALLBACK_COOLDOWN, is_fallback=True ) + def invalidate(self, container_name: str, session_token: Optional[str] = None) -> None: + if session_token is None: + self._entry.pop(container_name, None) + return + cached = self._entry.get(container_name) + if cached is not None and cached.session_token == session_token: + self._entry.pop(container_name, None) + class StorageSessionPolicy(HTTPPolicy): """ @@ -1018,6 +1026,14 @@ def _analyze_request(request: "PipelineRequest") -> Optional[Tuple[str, str]]: container_url = f"{parsed.scheme}://{parsed.netloc}/{container_name}" return container_name, container_url + @staticmethod + def _used_session_token(request: "PipelineRequest") -> Optional[str]: + """Use for distinguishing between a successful concurrent refresh and invalidated token.""" + auth = request.http_request.headers.get("Authorization", "") + if not auth.startswith("Session "): + return None + return auth[len("Session ") :].split(":", 1)[0] or None + @staticmethod def _extract_session(response: Any) -> Tuple[str, str, datetime]: creds = getattr(response, "credentials", None) @@ -1175,8 +1191,9 @@ def on_response( # 401 → invalidate + re-acquire ONCE, then resend. 
if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) + used_token = self._used_session_token(request) with self._cache.lock_container(container_name): - self._cache.put_fallback(container_name) + self._cache.invalidate(container_name, used_token) request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True retried_container = self.on_request(request) retried_response = self.next.send(request) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index dbfa617536a6..1017f9eb6a66 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -524,8 +524,9 @@ async def on_response( # 401 → invalidate + re-acquire ONCE, then resend. if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) + used_token = StorageSessionPolicy._used_session_token(request) # pylint: disable=protected-access async with self._cache.lock_container_async(container_name): - self._cache.put_fallback(container_name) + self._cache.invalidate(container_name, used_token) request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True retried_container = await self.on_request(request) retried_response = await self.next.send(request) From e5c7ee7bbb8a55773707ec478faebfc9e0838201 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 14:52:22 -0400 Subject: [PATCH 25/29] Factored out helper methods in test files --- .../tests/test_container.py | 85 +++++++------------ .../tests/test_container_async.py | 83 ++++++------------ .../azure-storage-blob/tests/test_helpers.py | 57 +++++++++++++ 3 files changed, 117 insertions(+), 108 deletions(-) 
diff --git a/sdk/storage/azure-storage-blob/tests/test_container.py b/sdk/storage/azure-storage-blob/tests/test_container.py index bb413ca151c3..a547f93f9d24 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container.py +++ b/sdk/storage/azure-storage-blob/tests/test_container.py @@ -14,6 +14,7 @@ from devtools_testutils import recorded_by_proxy, set_custom_default_matcher from devtools_testutils.storage import StorageRecordedTestCase from settings.testcase import BlobPreparer +from test_helpers import CaptureAuthHeader, _find_session_policy, _parse_session_token from azure.core import MatchConditions from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceModifiedError, ResourceNotFoundError @@ -2751,26 +2752,7 @@ def test_create_session(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") credential = self.get_credential(BlobServiceClient) - captured = {} - - def capture_auth_header(label): - def _hook(response): - auth = response.http_request.headers.get("Authorization", "") - captured[label] = auth - - return _hook - - def parse_session_token(auth): - # "Session {token}:{signature}" -> token - assert auth.startswith("Session ") - return auth[len("Session ") :].split(":", 1)[0] - - def find_session_policy(pipeline): - # Match by class name to avoid importing internals into the test module. 
- for p in getattr(pipeline, "_impl_policies", []): - if type(p).__name__ == "StorageSessionPolicy": - return p - raise AssertionError("StorageSessionPolicy not found on the pipeline") + capture_auth_header = CaptureAuthHeader() service = BlobServiceClient( self.account_url(storage_account_name, "blob"), @@ -2786,16 +2768,16 @@ def find_session_policy(pipeline): blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" container1.upload_blob( - blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header("c1_upload") + blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header.hook("c1_upload") ) - assert captured["c1_upload"].startswith("Bearer ") + assert capture_auth_header["c1_upload"].startswith("Bearer ") blob1_actual = container1.download_blob( - blob1_name, raw_response_hook=capture_auth_header("c1_download") + blob1_name, raw_response_hook=capture_auth_header.hook("c1_download") ).readall() assert blob1_data == blob1_actual - assert captured["c1_download"].startswith("Session ") - session1 = parse_session_token(captured["c1_download"]) + assert capture_auth_header["c1_download"].startswith("Session ") + session1 = _parse_session_token(capture_auth_header["c1_download"]) container2_name = self.get_resource_name("utcontainer2") container2 = service.get_container_client(container2_name) @@ -2806,44 +2788,44 @@ def find_session_policy(pipeline): blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" container2.upload_blob( - blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header("c2_upload") + blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header.hook("c2_upload") ) - assert captured["c2_upload"].startswith("Bearer ") + assert capture_auth_header["c2_upload"].startswith("Bearer ") blob2_actual = container2.download_blob( - blob2_name, raw_response_hook=capture_auth_header("c2_download") + blob2_name, 
raw_response_hook=capture_auth_header.hook("c2_download") ).readall() assert blob2_data == blob2_actual - assert captured["c2_download"].startswith("Session ") - session2 = parse_session_token(captured["c2_download"]) + assert capture_auth_header["c2_download"].startswith("Session ") + session2 = _parse_session_token(capture_auth_header["c2_download"]) assert session1 != session2 blob1_actual = container1.download_blob( - blob1_name, raw_response_hook=capture_auth_header("c1_download2") + blob1_name, raw_response_hook=capture_auth_header.hook("c1_download2") ).readall() assert blob1_data == blob1_actual - assert captured["c1_download2"].startswith("Session ") - assert session1 == parse_session_token(captured["c1_download2"]) + assert capture_auth_header["c1_download2"].startswith("Session ") + assert session1 == _parse_session_token(capture_auth_header["c1_download2"]) blob2_actual = container2.download_blob( - blob2_name, raw_response_hook=capture_auth_header("c2_download2") + blob2_name, raw_response_hook=capture_auth_header.hook("c2_download2") ).readall() assert blob2_data == blob2_actual - assert captured["c2_download2"].startswith("Session ") - assert session2 == parse_session_token(captured["c2_download2"]) + assert capture_auth_header["c2_download2"].startswith("Session ") + assert session2 == _parse_session_token(capture_auth_header["c2_download2"]) - policy = find_session_policy(service._pipeline) + policy = _find_session_policy(service._pipeline) cached = policy._cache._entry[container1_name] cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) blob1_actual = container1.download_blob( - blob1_name, raw_response_hook=capture_auth_header("c1_download3") + blob1_name, raw_response_hook=capture_auth_header.hook("c1_download3") ).readall() assert blob1_data == blob1_actual - assert captured["c1_download3"].startswith("Session ") - assert session1 != parse_session_token(captured["c1_download3"]) - assert session2 != 
parse_session_token(captured["c1_download3"]) + assert capture_auth_header["c1_download3"].startswith("Session ") + assert session1 != _parse_session_token(capture_auth_header["c1_download3"]) + assert session2 != _parse_session_token(capture_auth_header["c1_download3"]) @BlobPreparer() @recorded_by_proxy @@ -2851,14 +2833,7 @@ def test_sessions_disabled(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") credential = self.get_credential(BlobServiceClient) - captured = {} - - def capture_auth_header(label): - def _hook(response): - auth = response.http_request.headers.get("Authorization", "") - captured[label] = auth - - return _hook + capture_auth_header = CaptureAuthHeader() service = BlobServiceClient( self.account_url(storage_account_name, "blob"), @@ -2872,9 +2847,13 @@ def _hook(response): pass blob_name, blob_data = self.get_resource_name("blob"), b"abc123" - container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=capture_auth_header("upload")) - assert captured["upload"].startswith("Bearer ") + container.upload_blob( + blob_name, blob_data, overwrite=True, raw_response_hook=capture_auth_header.hook("upload") + ) + assert capture_auth_header["upload"].startswith("Bearer ") - blob_actual = container.download_blob(blob_name, raw_response_hook=capture_auth_header("download")).readall() + blob_actual = container.download_blob( + blob_name, raw_response_hook=capture_auth_header.hook("download") + ).readall() assert blob_data == blob_actual - assert captured["download"].startswith("Bearer ") + assert capture_auth_header["download"].startswith("Bearer ") diff --git a/sdk/storage/azure-storage-blob/tests/test_container_async.py b/sdk/storage/azure-storage-blob/tests/test_container_async.py index 88f2a242a2d3..aecc3876298f 100644 --- a/sdk/storage/azure-storage-blob/tests/test_container_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_container_async.py @@ -16,6 +16,7 @@ from devtools_testutils.storage import 
LogCaptured from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase from settings.testcase import BlobPreparer +from test_helpers import CaptureAuthHeader, _find_session_policy, _parse_session_token from azure.core import MatchConditions from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceModifiedError, ResourceNotFoundError @@ -2660,26 +2661,7 @@ async def test_create_session(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") credential = self.get_credential(BlobServiceClient, is_async=True) - captured = {} - - def capture_auth_header(label): - def _hook(response): - auth = response.http_request.headers.get("Authorization", "") - captured[label] = auth - - return _hook - - def parse_session_token(auth): - # "Session {token}:{signature}" -> token - assert auth.startswith("Session ") - return auth[len("Session ") :].split(":", 1)[0] - - def find_session_policy(pipeline): - # Match by class name to avoid importing internals into the test module. 
- for p in getattr(pipeline, "_impl_policies", []): - if type(p).__name__ == "AsyncStorageSessionPolicy": - return p - raise AssertionError("AsyncStorageSessionPolicy not found on the pipeline") + capture_auth_header = CaptureAuthHeader() service = BlobServiceClient( self.account_url(storage_account_name, "blob"), @@ -2695,16 +2677,16 @@ def find_session_policy(pipeline): blob1_name, blob1_data = self.get_resource_name("blob1"), b"abc123" await container1.upload_blob( - blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header("c1_upload") + blob1_name, blob1_data, overwrite=True, raw_response_hook=capture_auth_header.hook("c1_upload") ) - assert captured["c1_upload"].startswith("Bearer ") + assert capture_auth_header["c1_upload"].startswith("Bearer ") blob1_actual = await ( - await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download")) + await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header.hook("c1_download")) ).readall() assert blob1_data == blob1_actual - assert captured["c1_download"].startswith("Session ") - session1 = parse_session_token(captured["c1_download"]) + assert capture_auth_header["c1_download"].startswith("Session ") + session1 = _parse_session_token(capture_auth_header["c1_download"]) container2_name = self.get_resource_name("utcontainer2") container2 = service.get_container_client(container2_name) @@ -2715,44 +2697,44 @@ def find_session_policy(pipeline): blob2_name, blob2_data = self.get_resource_name("blob2"), b"def456" await container2.upload_blob( - blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header("c2_upload") + blob2_name, blob2_data, overwrite=True, raw_response_hook=capture_auth_header.hook("c2_upload") ) - assert captured["c2_upload"].startswith("Bearer ") + assert capture_auth_header["c2_upload"].startswith("Bearer ") blob2_actual = await ( - await container2.download_blob(blob2_name, 
raw_response_hook=capture_auth_header("c2_download")) + await container2.download_blob(blob2_name, raw_response_hook=capture_auth_header.hook("c2_download")) ).readall() assert blob2_data == blob2_actual - assert captured["c2_download"].startswith("Session ") - session2 = parse_session_token(captured["c2_download"]) + assert capture_auth_header["c2_download"].startswith("Session ") + session2 = _parse_session_token(capture_auth_header["c2_download"]) assert session1 != session2 blob1_actual = await ( - await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download2")) + await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header.hook("c1_download2")) ).readall() assert blob1_data == blob1_actual - assert captured["c1_download2"].startswith("Session ") - assert session1 == parse_session_token(captured["c1_download2"]) + assert capture_auth_header["c1_download2"].startswith("Session ") + assert session1 == _parse_session_token(capture_auth_header["c1_download2"]) blob2_actual = await ( - await container2.download_blob(blob2_name, raw_response_hook=capture_auth_header("c2_download2")) + await container2.download_blob(blob2_name, raw_response_hook=capture_auth_header.hook("c2_download2")) ).readall() assert blob2_data == blob2_actual - assert captured["c2_download2"].startswith("Session ") - assert session2 == parse_session_token(captured["c2_download2"]) + assert capture_auth_header["c2_download2"].startswith("Session ") + assert session2 == _parse_session_token(capture_auth_header["c2_download2"]) - policy = find_session_policy(service._pipeline) + policy = _find_session_policy(service._pipeline, "AsyncStorageSessionPolicy") cached = policy._cache._entry[container1_name] cached.expires_at = datetime.fromtimestamp(0, tz=cached.expires_at.tzinfo) blob1_actual = await ( - await container1.download_blob(blob1_name, raw_response_hook=capture_auth_header("c1_download3")) + await container1.download_blob(blob1_name, 
raw_response_hook=capture_auth_header.hook("c1_download3")) ).readall() assert blob1_data == blob1_actual - assert captured["c1_download3"].startswith("Session ") - assert session1 != parse_session_token(captured["c1_download3"]) - assert session2 != parse_session_token(captured["c1_download3"]) + assert capture_auth_header["c1_download3"].startswith("Session ") + assert session1 != _parse_session_token(capture_auth_header["c1_download3"]) + assert session2 != _parse_session_token(capture_auth_header["c1_download3"]) @BlobPreparer() @recorded_by_proxy_async @@ -2760,14 +2742,7 @@ async def test_sessions_disabled(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") credential = self.get_credential(BlobServiceClient, is_async=True) - captured = {} - - def capture_auth_header(label): - def _hook(response): - auth = response.http_request.headers.get("Authorization", "") - captured[label] = auth - - return _hook + capture = CaptureAuthHeader() service = BlobServiceClient( self.account_url(storage_account_name, "blob"), @@ -2781,13 +2756,11 @@ def _hook(response): pass blob_name, blob_data = self.get_resource_name("blob"), b"abc123" - await container.upload_blob( - blob_name, blob_data, overwrite=True, raw_response_hook=capture_auth_header("upload") - ) - assert captured["upload"].startswith("Bearer ") + await container.upload_blob(blob_name, blob_data, overwrite=True, raw_response_hook=capture.hook("upload")) + assert capture["upload"].startswith("Bearer ") blob_actual = await ( - await container.download_blob(blob_name, raw_response_hook=capture_auth_header("download")) + await container.download_blob(blob_name, raw_response_hook=capture.hook("download")) ).readall() assert blob_data == blob_actual - assert captured["download"].startswith("Bearer ") + assert capture["download"].startswith("Bearer ") diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers.py b/sdk/storage/azure-storage-blob/tests/test_helpers.py index 
c1ec6f3cb4af..d97a7602ac7c 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers.py @@ -58,6 +58,63 @@ def _create_file_share_oauth( return file_name, base_url +def _parse_session_token(auth: str) -> str: + """Extract the token from a "Session {token}:{signature}" Authorization header. + + :param str auth: The raw Authorization header value. + :return: The session token portion (before the ':'). + :rtype: str + """ + assert auth.startswith("Session ") + return auth[len("Session ") :].split(":", 1)[0] + + +def _find_session_policy(pipeline: Any, policy_name: str = "StorageSessionPolicy") -> Any: + """Return the session policy instance on a client pipeline, matched by class name. + + Matching by name avoids importing SDK internals into the test modules. + + :param pipeline: The client pipeline to search (e.g. ``client._pipeline``). + :type pipeline: Any + :param str policy_name: The policy class name to find. Use "StorageSessionPolicy" + for the sync stack and "AsyncStorageSessionPolicy" for the async stack. + :return: The matching policy instance. + :rtype: Any + """ + for policy in getattr(pipeline, "_impl_policies", []): + if type(policy).__name__ == policy_name: + return policy + raise AssertionError(f"{policy_name} not found on the pipeline") + + +class CaptureAuthHeader: + """Captures per-label Authorization headers via ``raw_response_hook`` callbacks. + + Encapsulates the captured-headers dict so the hook factory doesn't need a + closure over a test-local variable. Works for both sync and async clients, + since the response hook is invoked as a plain callable in both stacks. + """ + + def __init__(self) -> None: + self.captured: Dict[str, str] = {} + + def hook(self, label: str): + """Return a ``raw_response_hook`` that records the request's Authorization header. + + :param str label: The key under which to store the captured header. + :return: A callable suitable for ``raw_response_hook``. 
+ :rtype: callable + """ + + def _hook(response): + self.captured[label] = response.http_request.headers.get("Authorization", "") + + return _hook + + def __getitem__(self, label: str) -> str: + return self.captured[label] + + class ProgressTracker: def __init__(self, total: int, step: int): self.total = total From ffe44dbdf1d806a761b8394f092ec911ba6fdd62 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 16:30:09 -0400 Subject: [PATCH 26/29] Refactoring --- .../azure/storage/blob/_shared/policies.py | 186 ++++++++++-------- .../storage/blob/_shared/policies_async.py | 19 +- 2 files changed, 109 insertions(+), 96 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 593e02a481f1..9a7c81b8f56a 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -77,6 +77,103 @@ def encode_base64(data: Union[bytes, str]) -> str: return encoded.decode("utf-8") +def _extract_session(response: Any) -> Tuple[str, str, datetime]: + creds = getattr(response, "credentials", None) + if not creds or not getattr(creds, "session_token", None) or not getattr(creds, "session_key", None): + raise ValueError("CreateSession response missing SessionToken/SessionKey") + session_token: str = creds.session_token + session_key: str = creds.session_key + expires_at = getattr(response, "expiration", None) + if expires_at is None: + expires_at = datetime.now(UTC) + timedelta(minutes=5) + elif expires_at.tzinfo is None: + expires_at = expires_at.replace(tzinfo=UTC) + return session_token, session_key, expires_at + + +def _analyze_request(request: "PipelineRequest") -> Optional[Tuple[str, str]]: + http_request = request.http_request + if http_request.method != "GET": + return None + parsed = urlparse(http_request.url) + segments = [seg for seg in parsed.path.split("/") if 
seg] + if len(segments) < 2: + return None + query = http_request.query + if "comp" in query or query.get("restype") == "container": + return None + container_name = segments[0] + container_url = f"{parsed.scheme}://{parsed.netloc}/{container_name}" + return container_name, container_url + + +def _used_session_token(request: "PipelineRequest") -> Optional[str]: + """Use for distinguishing between a successful concurrent refresh and invalidated token.""" + auth = request.http_request.headers.get("Authorization", "") + if not auth.startswith("Session "): + return None + return auth[len("Session ") :].split(":", 1)[0] or None + + +_SIGNED_HEADERS = ( + "content-encoding", + "content-language", + "content-length", + "content-md5", + "content-type", + "date", + "if-modified-since", + "if-match", + "if-none-match", + "if-unmodified-since", + "byte_range", +) + + +def _apply_session_auth( + request: "PipelineRequest", session_token: str, session_key: str, account_name: str +) -> None: + """Sign an eligible request with the SharedKey protocol under the Session scheme. + + Shared by the sync and async session policies; ``account_name`` is passed in + rather than read from ``self`` so neither policy needs to instantiate the other. + + :param ~azure.core.pipeline.PipelineRequest request: The request to sign in place. + :param str session_token: The session token to embed in the Authorization header. + :param str session_key: The HMAC signing key for the session. + :param str account_name: Storage account name; the signer identity. + :raises ~azure.storage.blob._shared.authentication.AzureSigningError: if signing fails. + """ + http_request = request.http_request + http_request.headers["x-ms-date"] = format_date_time(time()) + + # 1) Standard headers. Storage omits content-length when it is "0". 
+ headers = {name.lower(): value for name, value in http_request.headers.items() if value} + if headers.get("content-length") == "0": + del headers["content-length"] + signed_headers = "\n".join(headers.get(h, "") for h in _SIGNED_HEADERS) + "\n" + + # 2) Canonicalized x-ms-* headers, sorted by the service-emulating comparator. + x_ms_headers = _storage_header_sort( + [(n.lower(), v) for n, v in http_request.headers.items() if n.lower().startswith("x-ms-")] + ) + canonicalized_headers = "".join(f"{n}:{v}\n" for n, v in x_ms_headers if v is not None) + + # 3) Canonicalized resource + query (query values must be url-decoded). + canonicalized_resource = "/" + account_name + urlparse(http_request.url).path + canonicalized_resource += "".join( + f"\n{n.lower()}:{unquote(v)}" for n, v in sorted(http_request.query.items()) if v is not None + ) + + string_to_sign = http_request.method + "\n" + signed_headers + canonicalized_headers + canonicalized_resource + + try: + signature = sign_string(session_key, string_to_sign) + except Exception as ex: # pylint: disable=broad-except + raise AzureSigningError(str(ex)) from ex + http_request.headers["Authorization"] = f"Session {session_token}:{signature}" + + # Are we out of retries? def is_exhausted(settings): retry_counts = ( @@ -972,19 +1069,6 @@ class StorageSessionPolicy(HTTPPolicy): When disabled, all requests are delegated to the bearer token policy. 
""" - _SIGNED_HEADERS = ( - "content-encoding", - "content-language", - "content-length", - "content-md5", - "content-type", - "date", - "if-modified-since", - "if-match", - "if-none-match", - "if-unmodified-since", - "byte_range", - ) def __init__( self, @@ -1010,74 +1094,6 @@ def __init__( self._enabled = True self._cache = SessionCache() - @staticmethod - def _analyze_request(request: "PipelineRequest") -> Optional[Tuple[str, str]]: - http_request = request.http_request - if http_request.method != "GET": - return None - parsed = urlparse(http_request.url) - segments = [seg for seg in parsed.path.split("/") if seg] - if len(segments) < 2: - return None - query = http_request.query - if "comp" in query or query.get("restype") == "container": - return None - container_name = segments[0] - container_url = f"{parsed.scheme}://{parsed.netloc}/{container_name}" - return container_name, container_url - - @staticmethod - def _used_session_token(request: "PipelineRequest") -> Optional[str]: - """Use for distinguishing between a successful concurrent refresh and invalidated token.""" - auth = request.http_request.headers.get("Authorization", "") - if not auth.startswith("Session "): - return None - return auth[len("Session ") :].split(":", 1)[0] or None - - @staticmethod - def _extract_session(response: Any) -> Tuple[str, str, datetime]: - creds = getattr(response, "credentials", None) - if not creds or not getattr(creds, "session_token", None) or not getattr(creds, "session_key", None): - raise ValueError("CreateSession response missing SessionToken/SessionKey") - session_token: str = creds.session_token - session_key: str = creds.session_key - expires_at = getattr(response, "expiration", None) - if expires_at is None: - expires_at = datetime.now(UTC) + timedelta(minutes=5) - elif expires_at.tzinfo is None: - expires_at = expires_at.replace(tzinfo=UTC) - return session_token, session_key, expires_at - - def _apply_session_auth(self, request: "PipelineRequest", 
session_token: str, session_key: str) -> None: - http_request = request.http_request - http_request.headers["x-ms-date"] = format_date_time(time()) - - # 1) Standard headers. Storage omits content-length when it is "0". - headers = {name.lower(): value for name, value in http_request.headers.items() if value} - if headers.get("content-length") == "0": - del headers["content-length"] - signed_headers = "\n".join(headers.get(h, "") for h in self._SIGNED_HEADERS) + "\n" - - # 2) Canonicalized x-ms-* headers, sorted by the service-emulating comparator. - x_ms_headers = _storage_header_sort( - [(n.lower(), v) for n, v in http_request.headers.items() if n.lower().startswith("x-ms-")] - ) - canonicalized_headers = "".join(f"{n}:{v}\n" for n, v in x_ms_headers if v is not None) - - # 3) Canonicalized resource + query (query values must be url-decoded). - canonicalized_resource = "/" + self._account_name + urlparse(http_request.url).path - canonicalized_resource += "".join( - f"\n{n.lower()}:{unquote(v)}" for n, v in sorted(http_request.query.items()) if v is not None - ) - - string_to_sign = http_request.method + "\n" + signed_headers + canonicalized_headers + canonicalized_resource - - try: - signature = sign_string(session_key, string_to_sign) - except Exception as ex: # pylint: disable=broad-except - raise AzureSigningError(str(ex)) from ex - http_request.headers["Authorization"] = f"Session {session_token}:{signature}" - def _resolve_session(self, container_name: str, container_url: str) -> Optional[Session]: session = self._cache.get(container_name) if session is None: @@ -1090,7 +1106,7 @@ def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: config = CreateSessionConfiguration(authentication_type="HMAC") client = self._session_client_factory(container_url) response = client.container.create_session(create_session_configuration=config) - return self._extract_session(response) + return _extract_session(response) def _refresh_session_token(self, 
container_name: str, container_url: str) -> Optional[Session]: """Acquire (or re-use) a session for the container under per-container single-flight. @@ -1137,7 +1153,7 @@ def on_request(self, request: "PipelineRequest") -> Optional[str]: """ if not self._enabled: return None - analysis = self._analyze_request(request) + analysis = _analyze_request(request) if analysis is None: return None container_name, container_url = analysis @@ -1146,7 +1162,7 @@ def on_request(self, request: "PipelineRequest") -> Optional[str]: if session is None or not session.session_token or not session.session_key: return None - self._apply_session_auth(request, session.session_token, session.session_key) + _apply_session_auth(request, session.session_token, session.session_key, self._account_name) return container_name def on_response( @@ -1191,7 +1207,7 @@ def on_response( # 401 → invalidate + re-acquire ONCE, then resend. if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) - used_token = self._used_session_token(request) + used_token = _used_session_token(request) with self._cache.lock_container(container_name): self._cache.invalidate(container_name, used_token) request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 1017f9eb6a66..7d825ccb7928 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -21,7 +21,11 @@ from .constants import DEFAULT_OAUTH_SCOPE from .models import StorageErrorCode from .policies import ( + _analyze_request, + _apply_session_auth, + _extract_session, _prepare_content_validation, + _used_session_token, _validate_content_response, 
CreateSessionConfiguration, encode_base64, @@ -29,7 +33,6 @@ Session, SessionCache, StorageRetryPolicy, - StorageSessionPolicy, SESSION_RETRIED_CONTEXT_KEY, ) from .streams_async import AsyncStructuredMessageDecoder @@ -407,16 +410,12 @@ def __init__( self._session_client_factory = session_client_factory self._enabled = True self._cache = AsyncSessionCache() - self._signer = StorageSessionPolicy( - account_name=account_name, - session_client_factory=session_client_factory, - ) async def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: config = CreateSessionConfiguration(authentication_type="HMAC") client = self._session_client_factory(container_url) response = await client.container.create_session(create_session_configuration=config) - return StorageSessionPolicy._extract_session(response) # pylint: disable=protected-access + return _extract_session(response) async def _refresh_session_token(self, container_name: str, container_url: str) -> Optional[Session]: """Acquire (or re-use) a session under per-container async single-flight. @@ -463,7 +462,7 @@ async def on_request(self, request: "PipelineRequest") -> Optional[str]: """ if not self._enabled: return None - analysis = StorageSessionPolicy._analyze_request(request) # pylint: disable=protected-access + analysis = _analyze_request(request) if analysis is None: return None container_name, container_url = analysis @@ -477,9 +476,7 @@ async def on_request(self, request: "PipelineRequest") -> Optional[str]: if session is None or session.is_fallback or not session.session_token or not session.session_key: return None - self._signer._apply_session_auth( # pylint: disable=protected-access - request, session.session_token, session.session_key - ) + _apply_session_auth(request, session.session_token, session.session_key, self._account_name) return container_name async def on_response( @@ -524,7 +521,7 @@ async def on_response( # 401 → invalidate + re-acquire ONCE, then resend. 
if status == 401 and not request.context.options.get(SESSION_RETRIED_CONTEXT_KEY): _LOGGER.info("Session authentication: HTTP 401 on '%s'; re-acquiring once.", container_name) - used_token = StorageSessionPolicy._used_session_token(request) # pylint: disable=protected-access + used_token = _used_session_token(request) async with self._cache.lock_container_async(container_name): self._cache.invalidate(container_name, used_token) request.context.options[SESSION_RETRIED_CONTEXT_KEY] = True From 7f9d2c0456c974f1173fa3c5133fd8d84acdc6eb Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 16:32:04 -0400 Subject: [PATCH 27/29] More refactoring --- .../azure/storage/blob/_shared/policies.py | 29 +++++++++---------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 9a7c81b8f56a..b9ff5d43816f 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -60,6 +60,20 @@ _LOGGER = logging.getLogger(__name__) +_SIGNED_HEADERS = ( + "content-encoding", + "content-language", + "content-length", + "content-md5", + "content-type", + "date", + "if-modified-since", + "if-match", + "if-none-match", + "if-unmodified-since", + "byte_range", +) + CONTENT_LENGTH_HEADER = "Content-Length" MD5_HEADER = "Content-MD5" CRC64_HEADER = "x-ms-content-crc64" @@ -115,21 +129,6 @@ def _used_session_token(request: "PipelineRequest") -> Optional[str]: return auth[len("Session ") :].split(":", 1)[0] or None -_SIGNED_HEADERS = ( - "content-encoding", - "content-language", - "content-length", - "content-md5", - "content-type", - "date", - "if-modified-since", - "if-match", - "if-none-match", - "if-unmodified-since", - "byte_range", -) - - def _apply_session_auth( request: "PipelineRequest", session_token: str, session_key: str, 
account_name: str ) -> None: From e247fe75c1144fdf2da8c4d43cdfd971ffaac0a1 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 17:18:36 -0400 Subject: [PATCH 28/29] invalidate update --- .../azure/storage/blob/_shared/policies.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index b9ff5d43816f..0a3e96b33ba5 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -1050,10 +1050,7 @@ def put_fallback(self, container_name: str) -> None: ) def invalidate(self, container_name: str, session_token: Optional[str] = None) -> None: - if session_token is None: - self._entry.pop(container_name, None) - return - cached = self._entry.get(container_name) + cached = self._entry.get(container_name, None) if cached is not None and cached.session_token == session_token: self._entry.pop(container_name, None) From 5eacfb890b939ca5f9a992bfca4bce6aa6905a0f Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Thu, 4 Jun 2026 17:29:34 -0400 Subject: [PATCH 29/29] Inline resolve session --- .../azure/storage/blob/_shared/policies.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py index 0a3e96b33ba5..f7963b428f8d 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies.py @@ -1090,13 +1090,6 @@ def __init__( self._enabled = True self._cache = SessionCache() - def _resolve_session(self, container_name: str, container_url: str) -> Optional[Session]: - session = self._cache.get(container_name) - if session is None: - session = 
self._refresh_session_token(container_name, container_url) - if session is None or session.is_fallback: - return None - return session def _create_session(self, container_url: str) -> Tuple[str, str, datetime]: config = CreateSessionConfiguration(authentication_type="HMAC") @@ -1154,8 +1147,13 @@ def on_request(self, request: "PipelineRequest") -> Optional[str]: return None container_name, container_url = analysis - session = self._resolve_session(container_name, container_url) - if session is None or not session.session_token or not session.session_key: + session = self._cache.get(container_name) + if session is None: + # True miss/expiry (a live fallback sentinel is returned by get(), + # so we never reach refresh while the cooldown is active). + session = self._refresh_session_token(container_name, container_url) + + if session is None or session.is_fallback or not session.session_token or not session.session_key: return None _apply_session_auth(request, session.session_token, session.session_key, self._account_name)