diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 1efd6d4841df..dedb6a07202e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -39,21 +39,14 @@ def _clean_location_list(locations: Optional[Sequence[Optional[str]]]) -> list[str]: - """Drop ``None``/empty/whitespace-only entries from a location list. + """Return the list with None, empty, and whitespace-only entries removed. - We filter these out at every intake point because the region normalizer - turns ``None`` (and empty/whitespace input) into an empty string. If such - a value were allowed into an exclusion list, every endpoint that wasn't - found in our by-endpoint lookup map — which also returns an empty string - as its default — would compare equal to it and get silently excluded. - Stripping the bad inputs once, at the boundary, prevents that collision. + Used at every location-list entry point so blank inputs cannot accidentally + match real endpoints during later comparisons. - :param locations: The raw location list to clean, or ``None``. May contain - ``None`` entries and empty/whitespace-only strings. + :param locations: The raw list of region names, or None. :type locations: Optional[Sequence[Optional[str]]] - :return: A new list containing only the non-empty, non-whitespace string - entries from ``locations``. Returns an empty list when ``locations`` - is ``None`` or empty. + :return: The cleaned list. Empty when input is None or empty. :rtype: list[str] """ if not locations: @@ -62,41 +55,15 @@ def _clean_location_list(locations: Optional[Sequence[Optional[str]]]) -> list[s def _normalize_region_name(region_name: Optional[str]) -> str: - """Canonicalize a region name for equality comparison. - - This is the single function every region-name comparison in this module - routes through (exclusion matching, preferred-location resolution, the - by-endpoint normalized lookup, the config-mismatch warning emitter, and - the most-preferred check in ``should_refresh_endpoints``). Two names - normalize to the same string iff the SDK treats them as the same region. - - The rule, in plain terms: - - - **Stripped:** outer and inner whitespace (via ``.strip()`` followed by - ``.split()`` / ``"".join(...)``), case (``.lower()``), and the two - separator characters customers commonly write — ``-`` and ``_``. - - **Preserved:** ASCII letters and **digits**. The digit invariant is the - load-bearing one: it is what keeps ``"East US"`` and ``"East US 2"`` - distinct after normalization (``"eastus"`` vs ``"eastus2"``). Stripping - digits, even as part of a well-meaning cleanup, would silently collide - these regions and route each one's traffic to the other. - - **``None`` / empty / whitespace-only input maps to ``""``.** That empty - string is a sentinel that also appears as the default value of the - by-endpoint name lookup, so callers must filter such inputs out at the - configuration boundary with :func:`_clean_location_list` before they - reach this function. See that function's docstring for the collision - scenario. - - **Idempotent.** ``_normalize_region_name(_normalize_region_name(x)) == - _normalize_region_name(x)`` for every ``x``. Defensive double-calls are - safe. - - :param region_name: A region name to canonicalize, or ``None``. Customer - input (preferred/excluded locations) and account-side names from the - gateway both flow through here. + """Return a canonical form of a region name for equality checks. + + Lowercases the text and strips spaces, hyphens, and underscores so values + like "East US 2", "east-us-2", and "east_us_2" compare equal. Digits are + kept so "East US" and "East US 2" stay distinct. None becomes "". + + :param region_name: A region name, or None. :type region_name: Optional[str] - :return: The canonicalized form, suitable for direct ``==`` / ``in`` - comparison against other canonicalized names. Returns ``""`` for - ``None`` or whitespace-only input. + :return: The canonical form, or "" for None or whitespace-only input. :rtype: str """ if region_name is None: @@ -648,9 +615,8 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self._read_locations_by_normalized, ) - # Config-time visibility for misconfigured region names. Dedupe ensures periodic - # refreshes do not re-emit identical warnings; new mismatches still surface because - # the dedupe key includes the available account regions snapshot. + # Warn once when configured region names do not match any account region. + # Repeated identical warnings are suppressed; a different set of regions emits a new one. if self.connection_policy.PreferredLocations: self._emit_config_mismatch_warning_once( self.connection_policy.PreferredLocations, diff --git a/sdk/cosmos/azure-cosmos/cspell.json b/sdk/cosmos/azure-cosmos/cspell.json index 62b0df0f579f..6997b194f868 100644 --- a/sdk/cosmos/azure-cosmos/cspell.json +++ b/sdk/cosmos/azure-cosmos/cspell.json @@ -13,6 +13,7 @@ "ppcb", "reindexing", "reranker", - "toctou" + "toctou", + "ufffd" ] } diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py index 2f24bf893262..1a45ada2ebf3 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider.py @@ -921,6 +921,131 @@ def refresher_fn(): self.assertEqual(none_seen['count'], 0, "Cache entry should never be None during a refresh — it should be atomically replaced") + # The tests below run through SmartRoutingMapProvider to confirm that a + # bad cache snapshot surfaces as a CosmosHttpResponseError the caller can + # handle, not as a raw ValueError or AssertionError. + + class _SequencedSnapshotClient(object): + """Mock client that returns the next payload from response_sequence on + each fresh read, and an empty page when the If-None-Match matches the + last etag (acts like a 304 reply).""" + + def __init__(self, response_sequence): + self.response_sequence = response_sequence + self.url_connection = "https://mock-sequenced-test.documents.azure.com:443/" + self.call_count = 0 + self._last_etag = None + + def _ReadPartitionKeyRanges(self, _collection_link, _feed_options=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get('If-None-Match') + if inm is not None and inm == self._last_etag: + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = 304 + captured_headers = kwargs.get('_internal_response_headers_capture') + if captured_headers is not None: + captured_headers.clear() + captured_headers.update({'ETag': self._last_etag}) + return [] + idx = min(self.call_count, len(self.response_sequence) - 1) + payload = self.response_sequence[idx] + self.call_count += 1 + etag = f'"etag-{self.call_count}"' + self._last_etag = etag + captured_headers = kwargs.get('_internal_response_headers_capture') + if captured_headers is not None: + captured_headers.clear() + captured_headers.update({'ETag': etag}) + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = 200 + return payload + + _OVERLAP_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': '10', 'minInclusive': '80', 'maxExclusive': 'A0'}, + {'id': '10/0', 'minInclusive': '80', 'maxExclusive': '90'}, + {'id': '10/1', 'minInclusive': '90', 'maxExclusive': 'A0'}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + _GAP_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + _GOOD_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': '10/0', 'minInclusive': '80', 'maxExclusive': '90', 'parents': ['10']}, + {'id': '10/1', 'minInclusive': '90', 'maxExclusive': 'A0', 'parents': ['10']}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + + def _reset_shared_cache_state(self, provider): + """Release the given provider and clear shared cache dicts so the next + sub-test or run starts with a clean slate.""" + provider.release() + with _shared_cache_lock: + _shared_routing_map_cache.clear() + _shared_collection_locks.clear() + _shared_locks_locks.clear() + _shared_cache_refcounts.clear() + + def test_smart_provider_does_not_leak_overlap_value_error_on_persistent_inconsistency(self): + """A persistent overlap or gap snapshot must raise 503 with sub_status + 21015 from SmartRoutingMapProvider.get_overlapping_ranges, not a bare + ValueError or AssertionError.""" + full_range = routing_range.Range("", "FF", True, False) + + for label, payload in (("overlap", self._OVERLAP_PAYLOAD), ("gap", self._GAP_PAYLOAD)): + with self.subTest(snapshot=label): + client = TestRoutingMapProvider._SequencedSnapshotClient([payload]) + provider = SmartRoutingMapProvider(client) + try: + with patch( + 'azure.cosmos._routing.routing_map_provider.time.sleep', + return_value=None, + ): + with self.assertRaises(CosmosHttpResponseError) as ctx: + provider.get_overlapping_ranges( + "dbs/db/colls/container", [full_range] + ) + exc = ctx.exception + self.assertEqual( + exc.status_code, + http_constants.StatusCodes.SERVICE_UNAVAILABLE, + f"Persistent {label} snapshot must surface as 503.", + ) + self.assertEqual( + exc.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + f"503 from a persistent {label} must set sub_status to 21015.", + ) + self.assertNotIsInstance(exc, AssertionError) + self.assertFalse(isinstance(exc, ValueError)) + finally: + self._reset_shared_cache_state(provider) + + def test_smart_provider_recovers_through_full_stack_after_transient_overlap(self): + """A bad overlap response followed by a good one must return the + expected ranges from get_overlapping_ranges.""" + full_range = routing_range.Range("", "FF", True, False) + client = TestRoutingMapProvider._SequencedSnapshotClient( + [self._OVERLAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + try: + with patch( + 'azure.cosmos._routing.routing_map_provider.time.sleep', + return_value=None, + ): + overlapping = provider.get_overlapping_ranges( + "dbs/db/colls/container", [full_range] + ) + ids = [r['id'] for r in overlapping] + self.assertEqual(ids, ['L', '10/0', '10/1', 'R']) + finally: + self._reset_shared_cache_state(provider) + if __name__ == "__main__": # import sys;sys.argv = ['', 'Test.testName'] unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py index 8650e326f800..4d80e6dd076b 100644 --- a/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py +++ b/sdk/cosmos/azure-cosmos/tests/routing/test_routing_map_provider_async.py @@ -1,30 +1,30 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. +import asyncio +import gc import unittest +from typing import Any, Mapping, Optional +from unittest.mock import MagicMock, patch import pytest +from azure.cosmos import _base, http_constants from azure.cosmos._routing import routing_range as routing_range -from azure.cosmos._routing.aio.routing_map_provider import CollectionRoutingMap -from azure.cosmos._routing.aio.routing_map_provider import SmartRoutingMapProvider -from azure.cosmos._routing.aio.routing_map_provider import PartitionKeyRangeCache from azure.cosmos._routing._routing_map_provider_common import ( _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS, ) -from azure.cosmos import http_constants - -from typing import Optional, Mapping, Any -from unittest.mock import MagicMock, patch -import gc -from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos._routing.aio.routing_map_provider import ( - _shared_routing_map_cache, + CollectionRoutingMap, + PartitionKeyRangeCache, + SmartRoutingMapProvider, + _shared_cache_lock, + _shared_cache_refcounts, _shared_collection_locks, _shared_locks_locks, - _shared_cache_refcounts, - _shared_cache_lock, + _shared_routing_map_cache, ) +from azure.cosmos.exceptions import CosmosHttpResponseError @pytest.mark.cosmosEmulator @@ -108,9 +108,6 @@ def _instantiate_smart_routing_map_provider(self, partition_key_ranges): client = TestRoutingMapProviderAsync.MockedCosmosClientConnection(partition_key_ranges) return SmartRoutingMapProvider(client) - # --------------------------------------------------------------- - # SmartRoutingMapProvider.get_overlapping_ranges tests - # --------------------------------------------------------------- async def test_full_range_async(self): pkRange = routing_range.Range("", "FF", True, False) @@ -203,9 +200,6 @@ async def test_complex_async(self): expected = [self.partition_key_ranges[1], self.partition_key_ranges[4]] self.assertEqual(overlapping, expected) - # --------------------------------------------------------------- - # PartitionKeyRangeCache(async) caching tests - # --------------------------------------------------------------- async def test_get_routing_map_caches_on_first_call_async(self): """Initial call to get_routing_map fetches from service and caches the result.""" @@ -218,7 +212,6 @@ async def test_get_routing_map_caches_on_first_call_async(self): self.assertIsNotNone(result) self.assertEqual(len(list(result._orderedPartitionKeyRanges)), 5) - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) self.assertIn(collection_id, provider._collection_routing_map_by_item) @@ -340,7 +333,6 @@ async def test_is_cache_stale_etag_logic_async(self): TestRoutingMapProviderAsync.MockedCosmosClientConnection(self.partition_key_ranges) ) collection_link = "dbs/db/colls/container" - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) cached_map = await provider.get_routing_map(collection_link, feed_options={}) @@ -385,7 +377,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(IncompleteClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -436,7 +427,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(DeltaClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -494,7 +484,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(HeaderCapturingClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -559,7 +548,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -613,7 +601,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -688,7 +675,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(RapidSplitClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -751,7 +737,6 @@ async def _gen(): return _gen() provider = PartitionKeyRangeCache(MergeClient()) - from azure.cosmos import _base collection_link = "dbs/db/colls/container" collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) @@ -819,7 +804,6 @@ async def test_concurrent_refresh_serialized_by_lock_async(self): Verifies that coroutines don't corrupt the cache and all get a valid result. """ - import asyncio call_count = {'count': 0} original_ranges = self.partition_key_ranges fetch_event = asyncio.Event() @@ -872,7 +856,6 @@ async def test_cache_never_none_during_refresh_async(self): The cache entry is atomically replaced, never deleted. """ - import asyncio original_ranges = self.partition_key_ranges call_count = {'count': 0} @@ -893,7 +876,6 @@ async def _gen(): provider = PartitionKeyRangeCache(SlowClient()) collection_link = "dbs/db/colls/container" - from azure.cosmos import _base collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) # Populate cache @@ -924,6 +906,144 @@ async def refresher_fn(): self.assertEqual(none_seen['count'], 0, "Cache entry should never be None during a refresh — it should be atomically replaced") + # The tests below run through SmartRoutingMapProvider to confirm that a + # bad cache snapshot surfaces as a CosmosHttpResponseError the caller can + # handle, not as a raw ValueError or AssertionError. + + class _SequencedSnapshotAsyncClient(object): + """Async mock client that returns the next payload from + response_sequence on each fresh read, and an empty async generator + when the If-None-Match matches the last etag (acts like a 304 reply).""" + + def __init__(self, response_sequence): + self.response_sequence = response_sequence + self.url_connection = "https://mock-async-sequenced-test.documents.azure.com:443/" + self.call_count = 0 + self._last_etag = None + + def _ReadPartitionKeyRanges(self, _collection_link, _feed_options=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get('If-None-Match') + if inm is not None and inm == self._last_etag: + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = 304 + captured_headers = kwargs.get('_internal_response_headers_capture') + if captured_headers is not None: + captured_headers.clear() + captured_headers.update({'ETag': self._last_etag}) + + async def _empty(): + if False: + yield # pragma: no cover + return _empty() + + idx = min(self.call_count, len(self.response_sequence) - 1) + payload = self.response_sequence[idx] + self.call_count += 1 + etag = f'"etag-{self.call_count}"' + self._last_etag = etag + captured_headers = kwargs.get('_internal_response_headers_capture') + if captured_headers is not None: + captured_headers.clear() + captured_headers.update({'ETag': etag}) + status_capture = kwargs.get('_internal_response_status_capture') + if status_capture is not None: + status_capture[0] = 200 + + async def _gen(): + for r in payload: + yield r + return _gen() + + _OVERLAP_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': '10', 'minInclusive': '80', 'maxExclusive': 'A0'}, + {'id': '10/0', 'minInclusive': '80', 'maxExclusive': '90'}, + {'id': '10/1', 'minInclusive': '90', 'maxExclusive': 'A0'}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + _GAP_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + _GOOD_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': '10/0', 'minInclusive': '80', 'maxExclusive': '90', 'parents': ['10']}, + {'id': '10/1', 'minInclusive': '90', 'maxExclusive': 'A0', 'parents': ['10']}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + + @staticmethod + async def _no_sleep(_seconds): + return None + + def _reset_shared_cache_state(self, provider): + """Release the given provider and clear shared cache dicts so the next + sub-test or run starts with a clean slate.""" + provider.release() + with _shared_cache_lock: + _shared_routing_map_cache.clear() + _shared_collection_locks.clear() + _shared_locks_locks.clear() + _shared_cache_refcounts.clear() + + async def test_smart_provider_does_not_leak_overlap_value_error_on_persistent_inconsistency_async(self): + """A persistent overlap or gap snapshot must raise 503 with sub_status + 21015 from SmartRoutingMapProvider.get_overlapping_ranges, not a bare + ValueError or AssertionError.""" + full_range = routing_range.Range("", "FF", True, False) + + for label, payload in (("overlap", self._OVERLAP_PAYLOAD), ("gap", self._GAP_PAYLOAD)): + with self.subTest(snapshot=label): + client = TestRoutingMapProviderAsync._SequencedSnapshotAsyncClient([payload]) + provider = SmartRoutingMapProvider(client) + try: + with patch( + 'azure.cosmos._routing.aio.routing_map_provider.asyncio.sleep', + new=self._no_sleep, + ): + with self.assertRaises(CosmosHttpResponseError) as ctx: + await provider.get_overlapping_ranges( + "dbs/db/colls/container", [full_range] + ) + exc = ctx.exception + self.assertEqual( + exc.status_code, + http_constants.StatusCodes.SERVICE_UNAVAILABLE, + f"Persistent {label} snapshot must surface as 503.", + ) + self.assertEqual( + exc.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + f"503 from a persistent {label} must set sub_status to 21015.", + ) + self.assertNotIsInstance(exc, AssertionError) + self.assertFalse(isinstance(exc, ValueError)) + finally: + self._reset_shared_cache_state(provider) + + async def test_smart_provider_recovers_through_full_stack_after_transient_overlap_async(self): + """A bad overlap response followed by a good one must return the + expected ranges from get_overlapping_ranges.""" + full_range = routing_range.Range("", "FF", True, False) + client = TestRoutingMapProviderAsync._SequencedSnapshotAsyncClient( + [self._OVERLAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + try: + with patch( + 'azure.cosmos._routing.aio.routing_map_provider.asyncio.sleep', + new=self._no_sleep, + ): + overlapping = await provider.get_overlapping_ranges( + "dbs/db/colls/container", [full_range] + ) + ids = [r['id'] for r in overlapping] + self.assertEqual(ids, ['L', '10/0', '10/1', 'R']) + finally: + self._reset_shared_cache_state(provider) + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_aio_extras_packaging_unit.py b/sdk/cosmos/azure-cosmos/tests/test_aio_extras_packaging_unit.py new file mode 100644 index 000000000000..beee8eb7dd9a --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_aio_extras_packaging_unit.py @@ -0,0 +1,66 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Tests that verify the ``aio`` extras are declared on the package.""" + +import re +import unittest +from importlib import metadata as importlib_metadata + +import pytest +from packaging.requirements import Requirement +from packaging.version import Version + +from azure.cosmos.aio import CosmosClient # noqa: F401 + + +@pytest.mark.cosmosEmulator +class TestAioExtrasPackaging(unittest.TestCase): + + def test_aio_extras_declared_in_distribution_metadata(self): + # The installed package must advertise the aio extra and pin it to + # azure-core with the aio extra at version 1.30.0 or newer, so + # installing with the aio extra pulls in the async transport. + try: + dist = importlib_metadata.distribution("azure-cosmos") + except importlib_metadata.PackageNotFoundError: + self.skipTest("azure-cosmos is not installed in this interpreter.") + + provides_extra = dist.metadata.get_all("Provides-Extra") or [] + self.assertIn("aio", provides_extra) + + requires_dist = dist.metadata.get_all("Requires-Dist") or [] + aio_reqs = [ + req for req in requires_dist + if re.search(r"extra\s*==\s*['\"]aio['\"]", req) + ] + self.assertTrue(aio_reqs, "no requirement is tagged for the 'aio' extra") + + joined = " ".join(aio_reqs).lower() + self.assertIn("azure-core", joined) + self.assertIn("[aio]", joined) + + # Check the azure-core[aio] requirement allows version 1.30.0 or newer. + # Asking the specifier whether an older version is allowed keeps the + # check valid across future version bumps and catches any regression. + core_req_str = next( + req for req in aio_reqs if "azure-core" in req.lower() + ) + # Drop the environment marker before parsing as a Requirement. + core_req = Requirement(core_req_str.split(";", 1)[0].strip()) + self.assertNotIn( + Version("1.29.99"), core_req.specifier, + f"azure-core[aio] requirement allows versions older than 1.30.0: " + f"{core_req.specifier!r}", + ) + + def test_azure_cosmos_aio_module_imports(self): + # If the async module cannot be imported the file would already + # have failed to load at the top, so this is a small explicit + # confirmation that the symbol is available. + self.assertTrue(callable(CosmosClient)) + + +if __name__ == "__main__": + unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_availability_strategy.py b/sdk/cosmos/azure-cosmos/tests/test_availability_strategy.py index 6ac32f7a3938..53ef61bb88a4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_availability_strategy.py +++ b/sdk/cosmos/azure-cosmos/tests/test_availability_strategy.py @@ -895,5 +895,51 @@ def test_default_availability_strategy_with_ppaf_enabled(self, operation): retry_write=True) self._clean_up_container(setup['db'].id, setup['col'].id) + # When the client is built with a hedging strategy and a per-call + # surface explicitly passes ``availability_strategy=None``, the + # request must still hedge per the client's strategy. The None + # means "use what the client was configured with." + def test_per_request_none_falls_back_to_client_strategy(self): + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, self.REGION_1) + failed_over_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, self.REGION_2) + + # Inject a 1-second delay on the first region so the hedging + # threshold (150 ms) fires and the second region gets the read. + predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + FaultInjectionTransport.predicate_is_operation_type(r, OperationType.Read) and + FaultInjectionTransport.predicate_targets_region(r, uri_down)) + error_lambda = lambda r: FaultInjectionTransport.error_after_delay( + 1000, + CosmosHttpResponseError(status_code=400, message="Injected Error"), + ) + custom_transport = self._get_custom_transport_with_fault_injection(predicate, error_lambda) + + client_strategy = {'threshold_ms': 150, 'threshold_steps_ms': 50} + setup = self._setup_method_with_custom_transport( + custom_transport, + multiple_write_locations=True, + availability_strategy=client_strategy, + ) + + # Seed the document via a separate fault-free client. + setup_without_fault = self._setup_method_with_custom_transport(None) + doc = _create_doc() + setup_without_fault['col'].create_item(body=doc) + + # Exercise the explicit per-request None path directly; helper + # utilities omit the kwarg when the value is None. + setup['col'].read_item( + item=doc['id'], + partition_key=doc['pk'], + availability_strategy=None, + ) + _validate_response_uris( + [uri_down, failed_over_uri], + [], + operation_type=OperationType.Read, + resource_type=ResourceType.Document, + ) + self._clean_up_container(setup['db'].id, setup['col'].id) + if __name__ == '__main__': unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_availability_strategy_async.py b/sdk/cosmos/azure-cosmos/tests/test_availability_strategy_async.py index cc0352f5b184..2017c6d7a9dd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_availability_strategy_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_availability_strategy_async.py @@ -1050,5 +1050,60 @@ async def test_default_availability_strategy_with_ppaf_enabled_async( await self._clean_up_container(setup['client_without_fault'], setup_with_transport['db'].id, setup_with_transport['col'].id) + # When the async client is built with a hedging strategy and a + # per-call surface explicitly passes ``availability_strategy=None``, + # the request must still hedge per the client's strategy. + @pytest.mark.asyncio + async def test_per_request_none_falls_back_to_client_strategy_async(self, setup): + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, setup['region_1']) + failed_over_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, setup['region_2']) + + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_is_operation_type(r, OperationType.Read) and + FaultInjectionTransportAsync.predicate_targets_region(r, uri_down)) + error_lambda = lambda r: FaultInjectionTransportAsync.error_after_delay( + 1000, + CosmosHttpResponseError(status_code=400, message="Injected Error"), + ) + custom_transport = self._get_custom_transport_with_fault_injection(predicate, error_lambda) + + client_strategy = {'threshold_ms': 150, 'threshold_steps_ms': 50} + setup_with_transport = await self._setup_method_with_custom_transport( + setup['write_locations'], + setup['read_locations'], + custom_transport, + multiple_write_locations=True, + availability_strategy=client_strategy, + ) + setup_without_fault = await self._setup_method_with_custom_transport( + setup['write_locations'], + setup['read_locations'], + None, + ) + + doc = _create_doc() + await setup_without_fault['col'].create_item(doc) + + # Exercise the explicit per-request None path directly; helper + # utilities omit the kwarg when the value is None. + await setup_with_transport['col'].read_item( + item=doc['id'], + partition_key=doc['pk'], + availability_strategy=None, + ) + _validate_response_uris( + [uri_down, failed_over_uri], + [], + operation_type=OperationType.Read, + resource_type=ResourceType.Document, + ) + await setup_with_transport['client'].close() + await setup_without_fault['client'].close() + await self._clean_up_container( + setup['client_without_fault'], + setup_with_transport['db'].id, + setup_with_transport['col'].id, + ) + if __name__ == '__main__': unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py b/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py index 569eec222246..80a2352d3472 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py +++ b/sdk/cosmos/azure-cosmos/tests/test_content_length_encoding.py @@ -1,25 +1,8 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -"""Regression tests for the Content-Length header computation. - -The SDK previously computed ``Content-Length`` from ``len(request.data)`` — -the number of Unicode code points in the JSON string — instead of the -UTF-8 byte length that actually goes on the wire. For any non-ASCII -payload that under-counted the body by the number of multi-byte -characters, which can cause downstream HTTP receivers to truncate the -body, reject the request, or mis-frame the next keep-alive request. - -Every assertion in this file exercises the actual production code path -in ``_synchronized_request.SynchronizedRequest`` or -``_asynchronous_request.AsynchronousRequest`` by patching the retry -layer and inspecting the request object the SDK would have put on the -wire. A previous iteration of this file also contained "mirror" tests -that re-implemented the production formula locally — those have been -removed because they could not catch a production regression (they -only verified that ``len(s.encode("utf-8"))`` works, which is a Python -built-in). -""" +"""Tests that the Content-Length header on outgoing requests is the +UTF-8 byte count of the body, not the number of characters.""" import unittest from unittest import mock @@ -29,22 +12,10 @@ from azure.cosmos.http_constants import HttpHeaders -# Payload matrix covering the four interesting char-vs-byte cases. Each -# str payload is named by the most divergent character it contains. -# The 4-byte emoji case maximizes the difference between ``len(s)`` -# (the old, buggy formula) and ``len(s.encode("utf-8"))`` (the new, -# correct formula), so a regression that reverts the fix will fail -# loudest on that case. -# -# Subtlety on why the payloads are pre-serialized JSON strings rather -# than dicts: the SDK's ``_request_body_from_data`` uses -# ``json.dumps(data, separators=(",", ":"))`` with default -# ``ensure_ascii=True``. That means dicts containing multi-byte chars -# get escaped to pure ASCII (e.g. ``"é"`` -> ``"\\u00e9"``) *before* -# Content-Length is computed — the byte-length code path is never -# actually exercised. The path that matters is when a customer passes -# a pre-serialized string, which ``_request_body_from_data`` returns -# unchanged. That is the path these tests exercise. +# Payloads chosen so the byte count differs from the character count +# by an increasing amount: 1, 2, 3, then 4 bytes per non-ASCII char. +# The string forms are passed as-is (not as dicts) because that's the +# input shape that exercises the byte-count code path. _STR_PAYLOADS = [ ("ascii_baseline", '{"name":"hello"}'), # 1 byte per char ("two_byte_latin", '{"name":"café"}'), # 2-byte 'é' @@ -75,9 +46,8 @@ def __init__(self): class TestContentLengthWiringSync(unittest.TestCase): - """Sync path: ``SynchronizedRequest`` → ``Execute`` should produce a - request whose ``Content-Length`` header equals the UTF-8 byte count - of the serialized body (not the code-point count).""" + """Checks the sync request path sets Content-Length to the byte + count of the body.""" def _capture_outgoing_request(self, request_data): params = _DummyRequestParams() @@ -107,10 +77,9 @@ def _fake_execute(*args, **kwargs): return captured def test_str_bodies_set_utf8_byte_content_length(self): - """For each payload in the byte-divergence matrix, the - ``Content-Length`` header the SDK puts on the wire must equal - the UTF-8 byte count of the JSON-serialized body. For the emoji - case in particular this exceeds the code-point count by 3×.""" + """For each payload, Content-Length equals the UTF-8 byte count + of the body. For the multi-byte cases it must also differ from + the character count.""" for label, payload in _STR_PAYLOADS: with self.subTest(payload=label): captured = self._capture_outgoing_request(payload) @@ -118,22 +87,34 @@ def test_str_bodies_set_utf8_byte_content_length(self): self.assertIsInstance(body, str) expected_bytes = len(body.encode("utf-8")) self.assertEqual(captured["content_length"], expected_bytes) - # Explicitly assert the value differs from the buggy - # formula for the multi-byte cases. ASCII is excluded - # because for ASCII both formulas agree. + # ASCII is the same in both counts, so only check the + # difference for the multi-byte cases. if label != "ascii_baseline": self.assertNotEqual(captured["content_length"], len(body)) def test_none_body_sets_content_length_zero(self): - """Covers the ``elif body is None`` branch in production: a - request with no body should still get ``Content-Length: 0``.""" + """A request with no body should still get Content-Length 0.""" captured = self._capture_outgoing_request(None) self.assertEqual(captured["content_length"], 0) + def test_bytes_body_is_coerced_to_none_and_content_length_zero(self): + """A bytes payload is currently converted to None inside the + SDK, so Content-Length ends up as 0. This test pins the current + behavior so any change to bytes handling is forced to be deliberate.""" + captured = self._capture_outgoing_request(b'{"x":1}') + self.assertIsNone(captured["body"]) + self.assertEqual(captured["content_length"], 0) + + def test_bytearray_body_is_coerced_to_none_and_content_length_zero(self): + """Same contract as for bytes. Bytearray inputs are also + converted to None and get Content-Length 0.""" + captured = self._capture_outgoing_request(bytearray(b'{"x":1}')) + self.assertIsNone(captured["body"]) + self.assertEqual(captured["content_length"], 0) + class TestContentLengthWiringAsync(unittest.IsolatedAsyncioTestCase): - """Async path: same contract as the sync test class, routed through - ``AsynchronousRequest`` → ``ExecuteAsync``.""" + """Async version of the sync class above. Same checks.""" async def _capture_outgoing_request(self, request_data): params = _DummyRequestParams() @@ -179,6 +160,19 @@ async def test_none_body_sets_content_length_zero(self): captured = await self._capture_outgoing_request(None) self.assertEqual(captured["content_length"], 0) + async def test_bytes_body_is_coerced_to_none_and_content_length_zero(self): + """Async version of the bytes-body check. Same contract: bytes + come out as None and Content-Length is set to 0.""" + captured = await self._capture_outgoing_request(b'{"x":1}') + self.assertIsNone(captured["body"]) + self.assertEqual(captured["content_length"], 0) + + async def test_bytearray_body_is_coerced_to_none_and_content_length_zero(self): + """Async version of the bytearray check.""" + captured = await self._capture_outgoing_request(bytearray(b'{"x":1}')) + self.assertIsNone(captured["body"]) + self.assertEqual(captured["content_length"], 0) + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_cosmos_paged_unit.py b/sdk/cosmos/azure-cosmos/tests/test_cosmos_paged_unit.py new file mode 100644 index 000000000000..d0048a4b8dca --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_cosmos_paged_unit.py @@ -0,0 +1,250 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Unit tests for the sync paged response wrappers in azure.cosmos. + +These tests exercise get_response_headers() and the wrapper types directly, +without requiring a live emulator or any network round-trip. +""" + +import unittest + +import pytest +from azure.core.utils import CaseInsensitiveDict + +from azure.cosmos._cosmos_responses import ( + CosmosAsyncItemPaged, + CosmosDict, + CosmosItemPaged, + CosmosList, +) + + +def _new_paged(**kwargs): + """Build a CosmosItemPaged without invoking the real query pipeline.""" + # No-op fetch callables are enough; these tests never iterate the pager. + return CosmosItemPaged( + get_next=lambda _continuation: {"value": [], "nextLink": None}, + extract_data=lambda _response: (None, []), + **kwargs, + ) + + +@pytest.mark.cosmosEmulator +class TestCosmosItemPagedUnit(unittest.TestCase): + """Pure unit tests for CosmosItemPaged.get_response_headers().""" + + def test_get_response_headers_is_empty_before_any_page_fetch(self): + # Before any page is fetched, callers must still be able to safely + # ask for response headers and get back an empty collection rather + # than a missing or null value that would crash their code. + pager = _new_paged() + headers = pager.get_response_headers() + self.assertIsInstance(headers, CaseInsensitiveDict) + self.assertEqual(len(headers), 0) + + def test_default_constructor_creates_fresh_header_dict(self): + # When the caller does not supply a place to store headers, the + # pager must create one on its own so the response-headers feature + # always works out of the box. + pager = _new_paged() + self.assertIsInstance(pager._response_headers, CaseInsensitiveDict) + + def test_explicit_none_response_headers_creates_fresh_dict(self): + # Explicitly opting out of providing a headers container must be + # treated the same as not providing one at all: the pager still + # gives the caller a working, empty headers collection. + pager = _new_paged(response_headers=None) + self.assertIsInstance(pager._response_headers, CaseInsensitiveDict) + self.assertEqual(len(pager._response_headers), 0) + + def test_response_headers_kwarg_is_the_same_instance_used_internally(self): + # When the caller supplies their own headers container, the pager + # must use that exact container (not a copy of it) so headers the + # query pipeline records later are visible to the caller. + shared = CaseInsensitiveDict() + pager = _new_paged(response_headers=shared) + self.assertIs(pager._response_headers, shared) + + def test_external_mutation_of_shared_dict_is_visible_via_getter(self): + # After each page is fetched, the query pipeline records the latest + # response headers. Reading headers back through the pager must + # reflect those updates so callers always see the freshest values. + shared = CaseInsensitiveDict() + pager = _new_paged(response_headers=shared) + + shared["x-ms-request-charge"] = "12.34" + shared["x-ms-activity-id"] = "abc-123" + + headers = pager.get_response_headers() + self.assertEqual(headers["x-ms-request-charge"], "12.34") + self.assertEqual(headers["x-ms-activity-id"], "abc-123") + + def test_get_response_headers_returns_a_copy_not_a_reference(self): + # Each read of the response headers must give the caller an + # independent snapshot they can freely modify without corrupting + # later reads or the pager's own internal state. + shared = CaseInsensitiveDict({"x-ms-request-charge": "1"}) + pager = _new_paged(response_headers=shared) + + first = pager.get_response_headers() + second = pager.get_response_headers() + + self.assertIsNot(first, second) + self.assertIsNot(first, shared) + + first["test-key"] = "test-value" + self.assertNotIn("test-key", second) + self.assertNotIn("test-key", shared) + + def test_returned_dict_is_case_insensitive(self): + # Service response headers can arrive in any letter casing, so + # callers must be able to look them up without having to guess or + # normalize the casing themselves. + shared = CaseInsensitiveDict() + pager = _new_paged(response_headers=shared) + shared["x-ms-request-charge"] = "5.0" + + headers = pager.get_response_headers() + self.assertEqual(headers["X-MS-Request-Charge"], "5.0") + self.assertEqual(headers["x-ms-request-charge"], "5.0") + self.assertEqual(headers["X-Ms-Request-Charge"], "5.0") + + def test_overwriting_simulates_pagination_and_keeps_only_latest_page(self): + # Reading response headers after many pages have been fetched must + # reflect only the most recent page's headers, never silently + # accumulate headers from every earlier page. + shared = CaseInsensitiveDict() + pager = _new_paged(response_headers=shared) + + for i in range(100): + shared.clear() + shared.update({ + "x-ms-request-charge": str(i), + "x-ms-activity-id": f"id-{i}", + "x-ms-item-count": str(i), + }) + + headers = pager.get_response_headers() + # Only the most recent page's headers should remain. + self.assertEqual(len(headers), 3) + self.assertEqual(headers["x-ms-request-charge"], "99") + self.assertEqual(headers["x-ms-activity-id"], "id-99") + + def test_return_type_is_caseinsensitivedict_not_list(self): + # Reading response headers must return the headers for a single + # page, not a collection of headers from many pages. + pager = _new_paged() + headers = pager.get_response_headers() + self.assertIsInstance(headers, CaseInsensitiveDict) + self.assertNotIsInstance(headers, list) + + def test_get_last_response_headers_attribute_does_not_exist(self): + # The old, removed way of reading response headers must not + # silently reappear on either pager type. + pager = _new_paged() + self.assertFalse(hasattr(pager, "get_last_response_headers")) + self.assertFalse(hasattr(CosmosItemPaged, "get_last_response_headers")) + self.assertFalse(hasattr(CosmosAsyncItemPaged, "get_last_response_headers")) + + def test_get_last_response_headers_raises_attribute_error_when_invoked(self): + # Make sure the old method is truly gone — not just hidden — so a + # caller who reaches for it gets a clear failure instead of a + # silent surprise, whether they go through an instance or the class. + pager = _new_paged() + with self.assertRaises(AttributeError): + getattr(pager, "get_last_response_headers")() + with self.assertRaises(AttributeError): + getattr(CosmosItemPaged, "get_last_response_headers") + with self.assertRaises(AttributeError): + getattr(CosmosAsyncItemPaged, "get_last_response_headers") + + def test_two_pagers_do_not_share_their_header_dicts(self): + # Each pager must own its own response headers so two queries + # running side by side never see each other's headers. + p1 = _new_paged() + p2 = _new_paged() + self.assertIsNot(p1._response_headers, p2._response_headers) + + # Updating one pager's headers must not show up on the other. + p1._response_headers["x-ms-request-charge"] = "9.0" + self.assertNotIn("x-ms-request-charge", p2.get_response_headers()) + + +@pytest.mark.cosmosEmulator +class TestCosmosDictAndListHeaders(unittest.TestCase): + """Header API guards for CosmosDict and CosmosList.""" + + def test_cosmos_dict_returns_copy_of_response_headers(self): + original = CaseInsensitiveDict({"x-ms-request-charge": "2.5"}) + wrapper = CosmosDict({"id": "x"}, response_headers=original) + + first = wrapper.get_response_headers() + second = wrapper.get_response_headers() + + self.assertIsNot(first, original) + self.assertIsNot(first, second) + self.assertEqual(first["x-ms-request-charge"], "2.5") + + first["mutated"] = "yes" + self.assertNotIn("mutated", second) + self.assertNotIn("mutated", original) + + def test_cosmos_dict_with_none_payload_behaves_like_empty_dict(self): + wrapper = CosmosDict(None, response_headers=CaseInsensitiveDict()) + self.assertEqual(len(wrapper), 0) + self.assertEqual(len(wrapper.get_response_headers()), 0) + + def test_cosmos_dict_headers_are_case_insensitive(self): + wrapper = CosmosDict( + {"id": "x"}, + response_headers=CaseInsensitiveDict({"x-ms-request-charge": "7.0"}), + ) + headers = wrapper.get_response_headers() + self.assertEqual(headers["X-MS-REQUEST-CHARGE"], "7.0") + + def test_cosmos_list_returns_copy_of_response_headers(self): + original = CaseInsensitiveDict({"x-ms-request-charge": "3.0"}) + wrapper = CosmosList([{"id": "a"}], response_headers=original) + + first = wrapper.get_response_headers() + second = wrapper.get_response_headers() + + self.assertIsNot(first, original) + self.assertIsNot(first, second) + self.assertEqual(first["x-ms-request-charge"], "3.0") + + first["mutated"] = "yes" + self.assertNotIn("mutated", second) + self.assertNotIn("mutated", original) + + def test_cosmos_list_with_none_payload_behaves_like_empty_list(self): + wrapper = CosmosList(None, response_headers=CaseInsensitiveDict()) + self.assertEqual(len(wrapper), 0) + self.assertEqual(len(wrapper.get_response_headers()), 0) + + def test_cosmos_dict_does_not_expose_get_last_response_headers(self): + # The old, removed way of reading response headers must not + # silently reappear on the single-item response wrapper, whether + # accessed through the wrapper itself or its class. + wrapper = CosmosDict({"id": "x"}, response_headers=CaseInsensitiveDict()) + self.assertFalse(hasattr(wrapper, "get_last_response_headers")) + self.assertFalse(hasattr(CosmosDict, "get_last_response_headers")) + with self.assertRaises(AttributeError): + getattr(wrapper, "get_last_response_headers")() + + def test_cosmos_list_does_not_expose_get_last_response_headers(self): + # The same removed method must also stay absent on the list + # response wrapper. + wrapper = CosmosList([{"id": "a"}], response_headers=CaseInsensitiveDict()) + self.assertFalse(hasattr(wrapper, "get_last_response_headers")) + self.assertFalse(hasattr(CosmosList, "get_last_response_headers")) + with self.assertRaises(AttributeError): + getattr(wrapper, "get_last_response_headers")() + + +if __name__ == "__main__": + unittest.main() + + + diff --git a/sdk/cosmos/azure-cosmos/tests/test_cosmos_paged_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_cosmos_paged_unit_async.py new file mode 100644 index 000000000000..4a1abc3b511a --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_cosmos_paged_unit_async.py @@ -0,0 +1,262 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Unit tests for the async paged response wrapper in azure.cosmos. + +These tests exercise get_response_headers() on CosmosAsyncItemPaged directly, +without requiring a live emulator or any network round-trip. +""" + +import unittest + +import pytest +from azure.core.utils import CaseInsensitiveDict + +from azure.cosmos._cosmos_responses import ( + CosmosAsyncItemPaged, + CosmosDict, + CosmosItemPaged, + CosmosList, +) + + +async def _async_get_next(_continuation): + return {"value": [], "nextLink": None} + + +async def _async_extract(_response): + return None, [] + + +def _new_async_paged(**kwargs): + """Build a CosmosAsyncItemPaged without invoking the real query pipeline.""" + return CosmosAsyncItemPaged( + get_next=_async_get_next, + extract_data=_async_extract, + **kwargs, + ) + + +@pytest.mark.cosmosEmulator +class TestCosmosAsyncItemPagedUnit(unittest.TestCase): + """Pure unit tests for CosmosAsyncItemPaged.get_response_headers(). + + The getter just reads a dict, so the tests do not need an event loop. + """ + + def test_get_response_headers_is_empty_before_any_page_fetch(self): + # Before any page is fetched, callers must still be able to safely + # ask for response headers and get back an empty collection rather + # than a missing or null value that would crash their code. + pager = _new_async_paged() + headers = pager.get_response_headers() + self.assertIsInstance(headers, CaseInsensitiveDict) + self.assertEqual(len(headers), 0) + + def test_default_constructor_creates_fresh_header_dict(self): + # When the caller does not supply a place to store headers, the + # pager must create one on its own so the response-headers + # feature always works out of the box. + pager = _new_async_paged() + self.assertIsInstance(pager._response_headers, CaseInsensitiveDict) + + def test_explicit_none_response_headers_creates_fresh_dict(self): + # Explicitly opting out of providing a headers container must be + # treated the same as not providing one at all: the pager still + # gives the caller a working, empty headers collection. + pager = _new_async_paged(response_headers=None) + self.assertIsInstance(pager._response_headers, CaseInsensitiveDict) + self.assertEqual(len(pager._response_headers), 0) + + def test_response_headers_kwarg_is_the_same_instance_used_internally(self): + # When the caller supplies their own headers container, the + # pager must use that exact container (not a copy of it) so + # headers the query pipeline records later are visible to the + # caller. + shared = CaseInsensitiveDict() + pager = _new_async_paged(response_headers=shared) + self.assertIs(pager._response_headers, shared) + + def test_external_mutation_of_shared_dict_is_visible_via_getter(self): + # After each page is fetched, the query pipeline records the + # latest response headers. Reading headers back through the + # pager must reflect those updates so callers always see the + # freshest values. + shared = CaseInsensitiveDict() + pager = _new_async_paged(response_headers=shared) + + shared["x-ms-request-charge"] = "12.34" + shared["x-ms-activity-id"] = "abc-123" + + headers = pager.get_response_headers() + self.assertEqual(headers["x-ms-request-charge"], "12.34") + self.assertEqual(headers["x-ms-activity-id"], "abc-123") + + def test_get_response_headers_returns_a_copy_not_a_reference(self): + # Each read of the response headers must give the caller an + # independent snapshot they can freely modify without corrupting + # later reads or the pager's own internal state. + shared = CaseInsensitiveDict({"x-ms-request-charge": "1"}) + pager = _new_async_paged(response_headers=shared) + + first = pager.get_response_headers() + second = pager.get_response_headers() + + self.assertIsNot(first, second) + self.assertIsNot(first, shared) + + first["test-key"] = "test-value" + self.assertNotIn("test-key", second) + self.assertNotIn("test-key", shared) + + def test_returned_dict_is_case_insensitive(self): + # Service response headers can arrive in any letter casing, so + # callers must be able to look them up without having to guess or + # normalize the casing themselves. + shared = CaseInsensitiveDict() + pager = _new_async_paged(response_headers=shared) + shared["x-ms-request-charge"] = "5.0" + + headers = pager.get_response_headers() + self.assertEqual(headers["X-MS-Request-Charge"], "5.0") + self.assertEqual(headers["x-ms-request-charge"], "5.0") + + def test_overwriting_simulates_pagination_and_keeps_only_latest_page(self): + # Reading response headers after many pages have been fetched must + # reflect only the most recent page's headers, never silently + # accumulate headers from every earlier page. + shared = CaseInsensitiveDict() + pager = _new_async_paged(response_headers=shared) + + for i in range(100): + shared.clear() + shared.update({ + "x-ms-request-charge": str(i), + "x-ms-activity-id": f"id-{i}", + "x-ms-item-count": str(i), + }) + + headers = pager.get_response_headers() + self.assertEqual(len(headers), 3) + self.assertEqual(headers["x-ms-request-charge"], "99") + self.assertEqual(headers["x-ms-activity-id"], "id-99") + + def test_return_type_is_caseinsensitivedict_not_list(self): + # Reading response headers must return the headers for a single + # page, not a collection of headers from many pages. + pager = _new_async_paged() + headers = pager.get_response_headers() + self.assertIsInstance(headers, CaseInsensitiveDict) + self.assertNotIsInstance(headers, list) + + def test_get_last_response_headers_attribute_does_not_exist(self): + # The old, removed way of reading response headers must not + # silently reappear on either pager type. + pager = _new_async_paged() + self.assertFalse(hasattr(pager, "get_last_response_headers")) + self.assertFalse(hasattr(CosmosAsyncItemPaged, "get_last_response_headers")) + self.assertFalse(hasattr(CosmosItemPaged, "get_last_response_headers")) + + def test_get_last_response_headers_raises_attribute_error_when_invoked_async(self): + # Make sure the old method is truly gone — not just hidden — so a + # caller who reaches for it gets a clear failure instead of a + # silent surprise, whether they go through an instance or the class. + pager = _new_async_paged() + with self.assertRaises(AttributeError): + getattr(pager, "get_last_response_headers")() + with self.assertRaises(AttributeError): + getattr(CosmosAsyncItemPaged, "get_last_response_headers") + with self.assertRaises(AttributeError): + getattr(CosmosItemPaged, "get_last_response_headers") + + def test_two_pagers_do_not_share_their_header_dicts(self): + # Each pager must own its own response headers so two queries + # running side by side never see each other's headers. + p1 = _new_async_paged() + p2 = _new_async_paged() + self.assertIsNot(p1._response_headers, p2._response_headers) + + p1._response_headers["x-ms-request-charge"] = "9.0" + self.assertNotIn("x-ms-request-charge", p2.get_response_headers()) + + +@pytest.mark.cosmosEmulator +class TestCosmosDictAndListHeadersAsync(unittest.TestCase): + """Header API guards for CosmosDict and CosmosList.""" + + def test_cosmos_dict_returns_copy_of_response_headers_async(self): + original = CaseInsensitiveDict({"x-ms-request-charge": "2.5"}) + wrapper = CosmosDict({"id": "x"}, response_headers=original) + + first = wrapper.get_response_headers() + second = wrapper.get_response_headers() + + self.assertIsNot(first, original) + self.assertIsNot(first, second) + self.assertEqual(first["x-ms-request-charge"], "2.5") + + first["mutated"] = "yes" + self.assertNotIn("mutated", second) + self.assertNotIn("mutated", original) + + def test_cosmos_dict_with_none_payload_behaves_like_empty_dict_async(self): + wrapper = CosmosDict(None, response_headers=CaseInsensitiveDict()) + self.assertEqual(len(wrapper), 0) + self.assertEqual(len(wrapper.get_response_headers()), 0) + + def test_cosmos_dict_headers_are_case_insensitive_async(self): + wrapper = CosmosDict( + {"id": "x"}, + response_headers=CaseInsensitiveDict({"x-ms-request-charge": "7.0"}), + ) + headers = wrapper.get_response_headers() + self.assertEqual(headers["X-MS-REQUEST-CHARGE"], "7.0") + + def test_cosmos_list_returns_copy_of_response_headers_async(self): + original = CaseInsensitiveDict({"x-ms-request-charge": "3.0"}) + wrapper = CosmosList([{"id": "a"}], response_headers=original) + + first = wrapper.get_response_headers() + second = wrapper.get_response_headers() + + self.assertIsNot(first, original) + self.assertIsNot(first, second) + self.assertEqual(first["x-ms-request-charge"], "3.0") + + first["mutated"] = "yes" + self.assertNotIn("mutated", second) + self.assertNotIn("mutated", original) + + def test_cosmos_list_with_none_payload_behaves_like_empty_list_async(self): + wrapper = CosmosList(None, response_headers=CaseInsensitiveDict()) + self.assertEqual(len(wrapper), 0) + self.assertEqual(len(wrapper.get_response_headers()), 0) + + def test_cosmos_dict_does_not_expose_get_last_response_headers_async(self): + # The old, removed way of reading response headers must not + # silently reappear on the single-item response wrapper, whether + # accessed through the wrapper or its class. + wrapper = CosmosDict({"id": "x"}, response_headers=CaseInsensitiveDict()) + self.assertFalse(hasattr(wrapper, "get_last_response_headers")) + self.assertFalse(hasattr(CosmosDict, "get_last_response_headers")) + with self.assertRaises(AttributeError): + getattr(wrapper, "get_last_response_headers")() + with self.assertRaises(AttributeError): + getattr(CosmosDict, "get_last_response_headers") + + def test_cosmos_list_does_not_expose_get_last_response_headers_async(self): + # The same removed method must also stay absent on the list + # response wrapper. + wrapper = CosmosList([{"id": "a"}], response_headers=CaseInsensitiveDict()) + self.assertFalse(hasattr(wrapper, "get_last_response_headers")) + self.assertFalse(hasattr(CosmosList, "get_last_response_headers")) + with self.assertRaises(AttributeError): + getattr(wrapper, "get_last_response_headers")() + with self.assertRaises(AttributeError): + getattr(CosmosList, "get_last_response_headers") + + +if __name__ == "__main__": + unittest.main() + + diff --git a/sdk/cosmos/azure-cosmos/tests/test_encoding.py b/sdk/cosmos/azure-cosmos/tests/test_encoding.py index 35504d7b8c24..0fd15e0b5e58 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_encoding.py +++ b/sdk/cosmos/azure-cosmos/tests/test_encoding.py @@ -77,6 +77,54 @@ def test_create_stored_procedure_with_line_separator_para_seperator_next_line_un read_sp = self.key_container.scripts.get_stored_procedure(created_sp['id']) self.assertEqual(read_sp['body'], test_string_unicode) + # Round-trip tests for documents that contain 4-byte UTF-8 + # characters (emoji). These exercise both writing and reading + # of multi-byte content. + + def test_round_trip_emoji_document_through_full_sdk_stack(self): + """Writes a document containing emoji, reads it back, and + checks the read content matches the written content exactly.""" + # Mixes 2-byte (é), 3-byte (日), and 4-byte (emoji) characters. + emoji_payload = u'celebration 🎉🎊 — café 日本 🌍' # cspell:disable-line + doc_id = 'emoji-rt-' + str(uuid.uuid4()) + document = { + 'pk': 'pk', + 'id': doc_id, + 'multibyte_content': emoji_payload, + } + + created = self.created_container.create_item(body=document) + # Sanity check on the write response itself. + self.assertEqual(created['multibyte_content'], emoji_payload) + + read = self.created_container.read_item(item=doc_id, partition_key='pk') + self.assertEqual(read['multibyte_content'], emoji_payload) + # Also compare the raw UTF-8 bytes to be sure nothing changed. + self.assertEqual( + read['multibyte_content'].encode('utf-8'), + emoji_payload.encode('utf-8'), + ) + + def test_round_trip_emoji_document_via_query(self): + """Same content as the test above, but pulled back via a SQL + query instead of a point read.""" + emoji_payload = u'query 🎉 — café 日本' # cspell:disable-line + doc_id = 'emoji-q-' + str(uuid.uuid4()) + document = { + 'pk': 'pk', + 'id': doc_id, + 'multibyte_content': emoji_payload, + } + self.created_container.create_item(body=document) + + results = list(self.created_container.query_items( + query="SELECT * FROM c WHERE c.id = @id", + parameters=[{"name": "@id", "value": doc_id}], + partition_key='pk', + )) + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['multibyte_content'], emoji_payload) + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_encoding_async.py b/sdk/cosmos/azure-cosmos/tests/test_encoding_async.py new file mode 100644 index 000000000000..ba6df138af3b --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_encoding_async.py @@ -0,0 +1,127 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Async data-plane versions of the encoding round-trip tests. + +This mirrors the sync document/partition-key checks and emoji round-trips +using the async client. The stored-procedure control-plane check remains in +the sync file because it relies on key-auth script operations. +""" +import unittest +import uuid + +import pytest + +import test_config + + +@pytest.mark.cosmosEmulator +class TestEncodingAsync(unittest.IsolatedAsyncioTestCase): + """Async round-trips for non-ASCII document content.""" + + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + + @classmethod + def setUpClass(cls): + if (cls.masterKey == '[YOUR_KEY_HERE]' + or cls.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' (ACCOUNT_KEY env var) and 'host' (ACCOUNT_HOST " + "env var) to run the tests." + ) + + async def asyncSetUp(self): + # Open the async client and keep a handle to the test container. + # The client is closed in asyncTearDown so we don't leak sockets. + self.client = test_config.TestConfig.create_data_client_async() + await self.client.__aenter__() + self.created_db = self.client.get_database_client( + test_config.TestConfig.TEST_DATABASE_ID + ) + self.created_container = self.created_db.get_container_client( + test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID + ) + + async def asyncTearDown(self): + await self.client.close() + + async def test_unicode_characters_in_partition_key_async(self): + test_string = u'€€ کلید پارتیشن विभाजन कुंजी \t123' # cspell:disable-line + document_definition = { + 'pk': test_string, + 'id': 'myid' + str(uuid.uuid4()), + } + created_doc = await self.created_container.create_item(body=document_definition) + + read_doc = await self.created_container.read_item( + item=created_doc['id'], + partition_key=test_string, + ) + self.assertEqual(read_doc['pk'], test_string) + + async def test_create_document_with_line_separator_para_seperator_next_line_unicodes_async(self): + test_string = u'Line Separator (\u2028) & Paragraph Separator (\u2029) & Next Line (\x85) & نیم\u200cفاصله' # cspell:disable-line + document_definition = { + 'pk': 'pk', + 'id': 'myid' + str(uuid.uuid4()), + 'unicode_content': test_string, + } + created_doc = await self.created_container.create_item(body=document_definition) + + read_doc = await self.created_container.read_item( + item=created_doc['id'], + partition_key='pk', + ) + self.assertEqual(read_doc['unicode_content'], test_string) + + async def test_round_trip_emoji_document_through_full_sdk_stack_async(self): + """Writes a document containing emoji, reads it back, and checks + the read content matches the written content exactly.""" + emoji_payload = u'celebration 🎉🎊 — café 日本 🌍' # cspell:disable-line + doc_id = 'emoji-rt-async-' + str(uuid.uuid4()) + document = { + 'pk': 'pk', + 'id': doc_id, + 'multibyte_content': emoji_payload, + } + + created = await self.created_container.create_item(body=document) + self.assertEqual(created['multibyte_content'], emoji_payload) + + read = await self.created_container.read_item( + item=doc_id, partition_key='pk' + ) + self.assertEqual(read['multibyte_content'], emoji_payload) + self.assertEqual( + read['multibyte_content'].encode('utf-8'), + emoji_payload.encode('utf-8'), + ) + + async def test_round_trip_emoji_document_via_query_async(self): + """Same content as the test above, but pulled back via a SQL + query instead of a point read.""" + emoji_payload = u'query 🎉 — café 日本' # cspell:disable-line + doc_id = 'emoji-q-async-' + str(uuid.uuid4()) + document = { + 'pk': 'pk', + 'id': doc_id, + 'multibyte_content': emoji_payload, + } + await self.created_container.create_item(body=document) + + results = [] + async for item in self.created_container.query_items( + query="SELECT * FROM c WHERE c.id = @id", + parameters=[{"name": "@id", "value": doc_id}], + partition_key='pk'): + results.append(item) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]['multibyte_content'], emoji_payload) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py index b31894f3a1ef..3b3c2d588ca7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_continuation_token.py @@ -141,9 +141,6 @@ def test_known_input_produces_known_murmur_digest(self): ) -# ---------------------------------------------------------------------- # -# Token round-trip -# ---------------------------------------------------------------------- # class TestTokenRoundTrip: """``_encode_token`` -> ``_decode_token`` is structurally lossless and the wire form is base64-encoded JSON containing all seven required @@ -258,9 +255,6 @@ def test_none_and_empty_inputs_decode_to_none(self): assert _decode_token("") is None -# ---------------------------------------------------------------------- # -# Version-mismatch rejection -# ---------------------------------------------------------------------- # class TestVersionMismatchRejected: """A token that decodes as our shape but with a non-current ``v`` raises ``ValueError`` rather than being silently misinterpreted.""" @@ -372,9 +366,6 @@ def test_missing_per_entry_backend_continuation_raises(self): assert "bc" in str(excinfo.value) -# ---------------------------------------------------------------------- # -# Identity-fingerprint mismatch rejection -# ---------------------------------------------------------------------- # class TestIdentityFingerprintMismatch: """A valid v=1 token replayed against a different collection / query / feed_range produces a fingerprint mismatch the call site rejects. @@ -533,9 +524,6 @@ def test_call_site_replay_with_different_feed_range_raises(self): -# ---------------------------------------------------------------------- # -# Explode-on-multi-overlap - post-split fan-out unit contract -# ---------------------------------------------------------------------- # class TestExplodeOnMultiOverlap: """Post-split fan-out contract for the resume path. @@ -729,9 +717,6 @@ def test_three_child_split_slices_into_three(self): ] -# ---------------------------------------------------------------------- # -# max_item_count normalization -# ---------------------------------------------------------------------- # class TestNormalizeMaxItemCount: """``_normalize_max_item_count`` collapses unset / non-numeric / non-positive values to ``None`` (unbounded) and passes positive ints through unchanged. @@ -764,9 +749,6 @@ def test_object_is_treated_as_unbounded(self): assert _normalize_max_item_count(object()) is None -# ---------------------------------------------------------------------- # -# Request-header shaping -# ---------------------------------------------------------------------- # class TestApplyFeedrangeRequestHeaders: """``_apply_feedrange_request_headers`` sets and clears routing/page/token headers correctly for both full-partition and sub-range requests.""" @@ -926,6 +908,69 @@ def test_value_min_max_three_way_merge(self): merged_max = _base._merge_query_results(merged_max, {"Documents": [11]}, max_query) assert merged_max["Documents"] == [11] + # MIN and MAX should still pick the right value when one partition + # returns an int and another a float, or when values are negative. + + def test_value_min_merge_mixed_int_and_float(self): + query = "SELECT VALUE MIN(c.score) FROM c" + merged = _base._merge_query_results( + {"Documents": [7]}, {"Documents": [3.5]}, query, + ) + assert merged["Documents"] == [3.5] + + def test_value_max_merge_with_negative_values(self): + query = "SELECT VALUE MAX(c.score) FROM c" + merged = _base._merge_query_results( + {"Documents": [-1]}, {"Documents": [-5]}, query, + ) + assert merged["Documents"] == [-1] + + # Lowercase query text should still merge as a min/max, + # not by adding the values. + def test_value_min_max_merge_lowercase_keyword_still_merges(self): + min_query = "select value min(c.score) from c" + merged_min = _base._merge_query_results( + {"Documents": [7]}, {"Documents": [3]}, min_query, + ) + assert merged_min["Documents"] == [3] + + max_query = "select value max(c.score) from c" + merged_max = _base._merge_query_results( + {"Documents": [7]}, {"Documents": [3]}, max_query, + ) + assert merged_max["Documents"] == [7] + + # Object-shaped partials carry the aggregate inside an "_aggregate" + # key. Names starting with "min" or "max" (any case) should pick the + # smallest or largest, not add the values. + + def test_object_aggregate_min_branch_uses_min(self): + query = "SELECT MIN(c.score) AS min_score FROM c" + results = {"Documents": [{"_aggregate": {"min_score": 9}}]} + partial = {"Documents": [{"_aggregate": {"min_score": 5}}]} + merged = _base._merge_query_results(results, partial, query) + assert merged["Documents"] == [{"_aggregate": {"min_score": 5}}] + + def test_object_aggregate_max_branch_uses_max(self): + query = "SELECT MAX(c.score) AS max_score FROM c" + results = {"Documents": [{"_aggregate": {"max_score": 9}}]} + partial = {"Documents": [{"_aggregate": {"max_score": 12}}]} + merged = _base._merge_query_results(results, partial, query) + assert merged["Documents"] == [{"_aggregate": {"max_score": 12}}] + + def test_object_aggregate_min_max_branches_are_case_insensitive(self): + query = "SELECT MIN(c.score) AS Min_score, MAX(c.score) AS MAX_score FROM c" + results = { + "Documents": [{"_aggregate": {"Min_score": 9, "MAX_score": 1}}], + } + partial = { + "Documents": [{"_aggregate": {"Min_score": 5, "MAX_score": 7}}], + } + merged = _base._merge_query_results(results, partial, query) + assert merged["Documents"] == [ + {"_aggregate": {"Min_score": 5, "MAX_score": 7}}, + ] + def test_value_boolean_non_aggregate_fragments_are_concatenated(self): query = "SELECT VALUE c.flag FROM c" partial_result = {"Documents": [True]} @@ -1125,6 +1170,31 @@ def test_bare_aggregate_with_surrounding_whitespace_still_classifies(self): query = "SELECT VALUE COUNT(1) FROM c" assert _get_select_value_aggregate_function(query) == "COUNT" + # MIN and MAX should be detected for any case spelling, since the + # user's query text can arrive in any case. + + @pytest.mark.parametrize( + "query,expected_function", + [ + ("SELECT VALUE MIN(c.score) FROM c", "MIN"), + ("select value min(c.score) from c", "MIN"), + ("Select Value Min(c.score) From c", "MIN"), + ("SELECT VALUE MAX(c.score) FROM c", "MAX"), + ("select value max(c.score) from c", "MAX"), + ("Select Value Max(c.score) From c", "MAX"), + ("SELECT VALUE COUNT(1) FROM c", "COUNT"), + ("SELECT VALUE SUM(c.amount) FROM c", "SUM"), + ("SELECT VALUE AVG(c.score) FROM c", "AVG"), + ], + ) + def test_value_aggregate_detection_is_case_insensitive(self, query, expected_function): + assert _get_select_value_aggregate_function(query) == expected_function + + # A property named "min" or "max" is not an aggregate call. + def test_value_min_max_detection_distinguishes_from_column_named_min(self): + assert _get_select_value_aggregate_function("SELECT VALUE c.min FROM c") is None + assert _get_select_value_aggregate_function("SELECT VALUE c.max FROM c") is None + class TestAggregateClassificationHeuristics: def test_block_comment_prefix_does_not_drive_outer_select_value_detection(self): @@ -1217,6 +1287,23 @@ def test_classify_aggregate_partial_treats_non_aggregate_float_as_none(self): docs = [42.5] assert _classify_aggregate_partial(docs, query) == _AggregatePartialClassification.NONE + # Boolean rows should not be treated as numeric, even when the + # query wraps them in MIN, MAX, or SUM. + + def test_classify_aggregate_partial_excludes_boolean_value_rows_for_min(self): + query = "SELECT VALUE MIN(c.flag) FROM c" + assert _classify_aggregate_partial([True], query) == _AggregatePartialClassification.NONE + assert _classify_aggregate_partial([False], query) == _AggregatePartialClassification.NONE + + def test_classify_aggregate_partial_excludes_boolean_value_rows_for_max(self): + query = "SELECT VALUE MAX(c.flag) FROM c" + assert _classify_aggregate_partial([True], query) == _AggregatePartialClassification.NONE + assert _classify_aggregate_partial([False], query) == _AggregatePartialClassification.NONE + + def test_classify_aggregate_partial_excludes_boolean_value_rows_for_sum(self): + query = "SELECT VALUE SUM(c.flag) FROM c" + assert _classify_aggregate_partial([True], query) == _AggregatePartialClassification.NONE + class TestEmptyPageStallCounter: """No-progress guard only counts empty pages that still carry continuation.""" diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index af0861929cbb..d37ee322f10f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -9,6 +9,7 @@ import pytest from azure.cosmos import documents +from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager from azure.cosmos._service_request_retry_policy import ServiceRequestRetryPolicy from azure.cosmos.documents import DatabaseAccount, _OperationType @@ -42,8 +43,10 @@ def create_database_account(enable_multiple_writable_locations): canonical_location1_name = "East US 2" canonical_location2_name = "West US 3" +canonical_location3_name = "Central US" canonical_location1_endpoint = "https://eastus2.documents.azure.com" canonical_location2_endpoint = "https://westus3.documents.azure.com" +canonical_location3_endpoint = "https://centralus.documents.azure.com" def create_database_account_with_canonical_regions(enable_multiple_writable_locations): @@ -60,7 +63,23 @@ def create_database_account_with_canonical_regions(enable_multiple_writable_loca return db_acc -def refresh_location_cache(preferred_locations, use_multiple_write_locations, connection_policy=documents.ConnectionPolicy()): +def create_database_account_with_three_canonical_regions(enable_multiple_writable_locations): + # Builds a three-region account for tests that need a longer preferred list. + db_acc = DatabaseAccount() + regions = [ + {"name": canonical_location1_name, "databaseAccountEndpoint": canonical_location1_endpoint}, + {"name": canonical_location2_name, "databaseAccountEndpoint": canonical_location2_endpoint}, + {"name": canonical_location3_name, "databaseAccountEndpoint": canonical_location3_endpoint}, + ] + db_acc._WritableLocations = list(regions) + db_acc._ReadableLocations = list(regions) + db_acc._EnableMultipleWritableLocations = enable_multiple_writable_locations + return db_acc + + +def refresh_location_cache(preferred_locations, use_multiple_write_locations, connection_policy=None): + if connection_policy is None: + connection_policy = documents.ConnectionPolicy() connection_policy.PreferredLocations = preferred_locations connection_policy.UseMultipleWriteLocations = use_multiple_write_locations lc = LocationCache(default_endpoint=default_endpoint, @@ -714,11 +733,9 @@ def test_resolve_endpoint_without_preferred_locations_supports_normalized_exclus assert lc.resolve_service_endpoint(read_request) == canonical_location1_endpoint def test_preferred_locations_support_normalized_region_names(self, caplog): - # Preferred locations should match account region names even with case/spacing/separator variations. - # Also guards against drift between the two call sites that normalize region names: - # the routing path (get_preferred_regional_routing_contexts) and the diagnostic path - # (_emit_config_mismatch_warning_once). If they ever disagree on normalization, routing - # would still succeed while users got a misleading "did not match" warning. + # Preferred locations should match account region names even when the + # caller uses different case, spacing, hyphens, or underscores. No + # mismatch warning should appear when every entry matches a real region. with caplog.at_level(logging.WARNING): lc = refresh_location_cache(["east-us-2", " west_us_3 "], True) db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) @@ -735,10 +752,9 @@ def test_preferred_locations_support_normalized_region_names(self, caplog): assert read_contexts[1].get_primary() == canonical_location2_endpoint def test_excluded_locations_support_normalized_region_names(self, caplog): - # Excluded locations should filter regions even when normalized names are used. - # Same divergence guard as the preferred-locations test: excluded_locations also flows - # through both the routing path and the diagnostic warning emitter, so a normalization - # mismatch between them would produce correct filtering plus a spurious warning. + # Excluded locations should filter regions even when the caller spells + # them with different case, spacing, hyphens, or underscores. No + # mismatch warning should appear when every entry matches a real region. connection_policy = documents.ConnectionPolicy() connection_policy.ExcludedLocations = ["east-us-2"] @@ -767,20 +783,9 @@ def test_should_refresh_endpoints_handles_normalized_preferred_region(self): assert lc.should_refresh_endpoints() is False def test_should_refresh_endpoints_returns_true_for_normalized_non_primary(self): - # Companion to the False-branch test above: pins the True branch of - # should_refresh_endpoints() against a normalized preferred-location - # input. If normalization regresses on this path, the SDK silently - # stops scheduling background refreshes when the most-preferred - # region is no longer the primary — leaving customer traffic pinned - # to a region they tried to leave. - # - # Engineering the inequality: with two preferred regions - # ["east-us-2", "west-us-3"], the routing list normally puts East US 2 - # first (primary == most-preferred → False). Marking East US 2's read - # endpoint unavailable promotes West US 3 to primary, but the - # normalized lookup for "east-us-2" still resolves to the East US 2 - # context — which now != read_regional_routing_contexts[0], firing - # the True branch. + # When the caller's most preferred region (spelled with a hyphen here) + # is no longer the primary because its endpoint was marked unavailable, + # should_refresh_endpoints must return True so a refresh runs. lc = refresh_location_cache(["east-us-2", "west-us-3"], True) db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) lc.perform_on_database_account_read(db_acc) @@ -791,9 +796,8 @@ def test_should_refresh_endpoints_returns_true_for_normalized_non_primary(self): assert lc.should_refresh_endpoints() is True def test_get_locational_endpoint_normalizes_customer_region_string(self): - # GetLocationalEndpoint is used during bootstrap fallback with the customer-supplied - # preferred region string. It must produce the canonical regional URL for any - # accepted normalization variant. + # The static helper builds a region-specific URL from the account host. + # Any spelling variant of the same region should produce the same URL. default_endpoint_url = "https://contoso.documents.azure.com:443/" expected_endpoint = "https://contoso-eastus2.documents.azure.com:443/" @@ -834,10 +838,9 @@ def test_unmatched_preferred_locations_warning_is_deduped(self, caplog): assert len(unmatched_logs) == 1 def test_excluded_locations_ignore_none_and_empty_entries(self): - # Defensive: None / "" / whitespace-only entries in excluded_locations must - # NOT collide with the "" sentinel produced by _normalize_region_name(None) - # and silently filter out unrelated endpoints. Behavior must equal the - # clean-list equivalent. + # None, empty, and whitespace-only entries in excluded_locations should + # be ignored. They must not accidentally match real endpoints and they + # must not block the valid entries from filtering correctly. connection_policy = documents.ConnectionPolicy() connection_policy.ExcludedLocations = [None, "", " ", "east-us-2"] # type: ignore[list-item] @@ -857,22 +860,313 @@ def test_excluded_locations_ignore_none_and_empty_entries(self): # West US 3 is excluded on the request → writes route to East US 2. assert lc.resolve_service_endpoint(write_request) == canonical_location1_endpoint + def test_preferred_locations_handle_uppercase_and_pascalcase_variants(self): + # The caller may spell preferred regions in all caps or PascalCase. + # Both should resolve to the same canonical endpoints. + lc = refresh_location_cache(["EAST US 2", "WestUs3"], True) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + write_contexts = lc.get_write_regional_routing_contexts() + read_contexts = lc.get_read_regional_routing_contexts() + assert write_contexts[0].get_primary() == canonical_location1_endpoint + assert write_contexts[1].get_primary() == canonical_location2_endpoint + assert read_contexts[0].get_primary() == canonical_location1_endpoint + assert read_contexts[1].get_primary() == canonical_location2_endpoint + + def test_excluded_locations_handle_uppercase_and_mixed_punctuation(self): + # Client-level and per-request excluded entries that mix uppercase and + # different separators should still filter the right regions. + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = ["EAST-US_2"] + lc = refresh_location_cache( + [canonical_location1_name, canonical_location2_name], True, connection_policy, + ) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.excluded_locations = ["WEST-US_3"] + + assert lc.resolve_service_endpoint(read_request) == canonical_location2_endpoint + assert lc.resolve_service_endpoint(write_request) == canonical_location1_endpoint + + def test_duplicate_normalized_entries_in_excluded_list_warn_once(self, caplog): + # When the same region is listed twice in different spellings, the + # filter should still work and no duplicate mismatch warning should fire. + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = ["East US 2", "east-us-2", "EAST_US_2"] + lc = refresh_location_cache( + [canonical_location1_name, canonical_location2_name], True, connection_policy, + ) + db_acc = create_database_account_with_canonical_regions(enable_multiple_writable_locations=True) + + with caplog.at_level("WARNING", logger="azure.cosmos.LocationCache"): + lc.perform_on_database_account_read(db_acc) + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + resolved = lc.resolve_service_endpoint(read_request) + + assert resolved == canonical_location2_endpoint + mismatch_warnings = [ + r for r in caplog.records if "did not match" in r.getMessage() + ] + assert mismatch_warnings == [] + + def test_preferred_locations_mix_normalized_and_canonical_forms(self): + # A three-region preferred list that mixes canonical and messy spellings + # should map every entry to the right account region in the right order. + lc = refresh_location_cache( + ["east-us-2", "WestUs3", " central us "], True, + ) + db_acc = create_database_account_with_three_canonical_regions(enable_multiple_writable_locations=True) + lc.perform_on_database_account_read(db_acc) + + write_endpoints = [ctx.get_primary() for ctx in lc.get_write_regional_routing_contexts()] + read_endpoints = [ctx.get_primary() for ctx in lc.get_read_regional_routing_contexts()] + expected = [ + canonical_location1_endpoint, + canonical_location2_endpoint, + canonical_location3_endpoint, + ] + assert write_endpoints == expected + assert read_endpoints == expected + + def test_global_endpoint_manager_normalizes_preferred_locations_from_policy(self): + # End-to-end check that messy region names set on ConnectionPolicy flow + # through the endpoint manager into the location cache and resolve correctly. + cp = documents.ConnectionPolicy() + cp.PreferredLocations = ["east-us-2", " WEST_US_3 "] + cp.UseMultipleWriteLocations = True + + mock_client = unittest.mock.Mock() + mock_client.connection_policy = cp + mock_client.url_connection = default_endpoint + + gem = _GlobalEndpointManager(mock_client) + gem.location_cache.perform_on_database_account_read( + create_database_account_with_canonical_regions(enable_multiple_writable_locations=True), + ) + + read_endpoints = [ + ctx.get_primary() for ctx in gem.location_cache.get_read_regional_routing_contexts() + ] + assert read_endpoints == [canonical_location1_endpoint, canonical_location2_endpoint] + + """ + Additional sync coverage for keeping unavailable endpoints as + fallback options. Covers the global-endpoint-manager wrapper, + single-write accounts, the health-check probe set, ordering, + circuit-breaker fallback, recovery, and account-refresh preservation. + """ + + def test_sync_global_endpoint_manager_returns_unavailable_as_last_resort(self): + # Sync wrapper around LocationCache should also keep an unavailable + # endpoint at the tail of the routing list so it can be used as fallback. + cp = documents.ConnectionPolicy() + cp.PreferredLocations = [location1_name, location2_name] + cp.UseMultipleWriteLocations = True + mock_client = unittest.mock.Mock() + mock_client.connection_policy = cp + mock_client.url_connection = default_endpoint + + gem = _GlobalEndpointManager(mock_client) + gem.location_cache.perform_on_database_account_read(create_database_account(True)) + + # Mark location1 unavailable for both reads and writes. + gem.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True, context="test") + gem.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_ctxs = gem.get_applicable_read_regional_routing_contexts(read_request) + assert [c.get_primary() for c in read_ctxs] == [location2_endpoint, location1_endpoint], \ + "Sync GEM should keep unavailable read endpoint at the tail, not drop it." + + # If the only healthy region is excluded, the unavailable region + # should still be returned by the sync wrapper, not the global default. + read_request.excluded_locations = [location2_name] + assert gem._resolve_service_endpoint(read_request) == location1_endpoint + + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_ctxs = gem.get_applicable_write_regional_routing_contexts(write_request) + assert [c.get_primary() for c in write_ctxs] == [location2_endpoint, location1_endpoint] + + write_request.excluded_locations = [location2_name] + assert gem._resolve_service_endpoint(write_request) == location1_endpoint + + def test_sync_single_write_account_read_unavailable_and_excluded(self): + # On a single-write account, an excluded healthy region should still + # fall back to the unavailable preferred region rather than the global default. + lc = refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=False + ) + lc.perform_on_database_account_read(create_database_account(False)) + assert lc.can_use_multiple_write_locations() is False, \ + "Test setup must be a single-write account." + + lc.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_request.excluded_locations = [location2_name] + + resolved = lc.resolve_service_endpoint(read_request) + assert resolved == location1_endpoint, \ + "Single-write read path returned the global default instead of " \ + "the unavailable preferred region." + + def test_sync_health_check_set_includes_unavailable_endpoints(self): + # Unavailable read endpoints must remain in the health-check probe set so + # they can be re-marked available once the prober finds them healthy. + # First mark location1 write-unavailable so it is no longer the primary + # write probe endpoint; this isolates the read-unavailable assertion from + # write-endpoint inclusion. + lc = refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True + ) + lc.perform_on_database_account_read(create_database_account(True)) + lc.mark_endpoint_unavailable_for_write( + location1_endpoint, refresh_cache=True, context="test" + ) + assert lc.get_write_regional_routing_contexts()[0].get_primary() == location2_endpoint, \ + "Test precondition failed: location1 must not be the primary write endpoint." + + lc.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True) + endpoints = lc.endpoints_to_health_check() + assert location1_endpoint in endpoints, \ + "Health-check probe set is missing the unavailable read endpoint." + assert location2_endpoint in endpoints + + @pytest.mark.parametrize("unavailable", [[], [location1_name], [location1_name, location2_name]]) + def test_sync_routing_list_has_no_duplicate_endpoints(self, unavailable): + # The routing list should never contain the same endpoint twice, + # regardless of how many regions are marked unavailable. + endpoint_by_loc = {location1_name: location1_endpoint, location2_name: location2_endpoint} + lc = refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True + ) + lc.perform_on_database_account_read(create_database_account(True)) + + for loc in unavailable: + lc.mark_endpoint_unavailable_for_read(endpoint_by_loc[loc], refresh_cache=True) + read_primaries = [c.get_primary() for c in lc.get_read_regional_routing_contexts()] + assert len(read_primaries) == len(set(read_primaries)), \ + f"Read routing list has duplicates: {read_primaries}" + assert set(read_primaries) == {location1_endpoint, location2_endpoint} + + for loc in unavailable: + lc.mark_endpoint_unavailable_for_write( + endpoint_by_loc[loc], refresh_cache=True, context="test" + ) + write_primaries = [c.get_primary() for c in lc.get_write_regional_routing_contexts()] + assert len(write_primaries) == len(set(write_primaries)), \ + f"Write routing list has duplicates: {write_primaries}" + assert set(write_primaries) == {location1_endpoint, location2_endpoint} + + def test_mark_endpoint_available_restores_head_position(self): + # After recovery, a previously-unavailable preferred endpoint should + # return to the head of the routing list, not stay at the tail. + lc = refresh_location_cache( + [location1_name, location2_name, location3_name], + use_multiple_write_locations=True, + ) + lc.perform_on_database_account_read(create_database_account(True)) + + # Initial state: most-preferred is location1. + assert lc.read_regional_routing_contexts[0].get_primary() == location1_endpoint + assert lc.write_regional_routing_contexts[0].get_primary() == location1_endpoint + + # Mark location1 unavailable for both lanes — it should slide to the tail. + lc.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True) + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + assert lc.read_regional_routing_contexts[-1].get_primary() == location1_endpoint + assert lc.write_regional_routing_contexts[-1].get_primary() == location1_endpoint + + # Simulate the health-probe rehabilitating the endpoint. + lc.mark_endpoint_available(location1_endpoint) + lc.update_location_cache() + + assert lc.is_endpoint_unavailable(location1_endpoint, "Read") is False + assert lc.is_endpoint_unavailable(location1_endpoint, "Write") is False + assert lc.read_regional_routing_contexts[0].get_primary() == location1_endpoint, \ + "Recovered endpoint should return to the head of the read routing list." + assert lc.write_regional_routing_contexts[0].get_primary() == location1_endpoint, \ + "Recovered endpoint should return to the head of the write routing list." + + def test_account_topology_refresh_preserves_unavailability_tail_order(self): + # A periodic account-topology refresh must not drop endpoints that + # were marked unavailable, and the tail ordering must be preserved. + lc = refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True + ) + db_acc = create_database_account(True) + lc.perform_on_database_account_read(db_acc) + + # Mark location1 unavailable for writes — it should move to the tail. + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + write_primaries_before = [c.get_primary() for c in lc.get_write_regional_routing_contexts()] + assert write_primaries_before == [location2_endpoint, location1_endpoint] + + # Simulate a periodic background refresh — the same topology comes back. + lc.perform_on_database_account_read(db_acc) + + # The unavailability mark must survive, AND the routing list must + # still contain location1 (as the tail), not drop it. + assert lc.is_endpoint_unavailable(location1_endpoint, "Write"), \ + "Unavailability mark must survive an account-topology refresh." + write_primaries_after = [c.get_primary() for c in lc.get_write_regional_routing_contexts()] + assert write_primaries_after == write_primaries_before, \ + "Account-topology refresh dropped the unavailable endpoint from the routing list." + + def test_circuit_breaker_excluded_read_falls_back_before_global_default(self): + # With the only healthy region user-excluded and the other region + # circuit-breaker-excluded, reads should still resolve to the + # circuit-breaker-excluded region instead of the global default. + lc = refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True + ) + lc.perform_on_database_account_read(create_database_account(True)) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_request.excluded_locations = [location1_name] + read_request.excluded_locations_circuit_breaker = [location2_name] + + resolved = lc.resolve_service_endpoint(read_request) + assert resolved == location2_endpoint, \ + "Read should fall back to the circuit-breaker-excluded region " \ + "instead of dropping to the global default." + + def test_master_resource_appends_user_excluded_to_tail(self): + # For master/metadata requests, user-excluded locations should be + # appended to the tail so the request still has a chance to succeed. + lc = refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True + ) + lc.perform_on_database_account_read(create_database_account(True)) + + # Both regions healthy and both user-excluded — for a metadata request, + # the SDK must still try them as a last resort. + master_request = RequestObject(ResourceType.Database, _OperationType.Read, None) + master_request.excluded_locations = [location1_name, location2_name] + + applicable = lc._get_applicable_read_regional_routing_contexts(master_request) + primaries = [c.get_primary() for c in applicable] + # The fix preserves user-excluded regions as a tail fallback for + # master requests; both location1 and location2 should appear. + assert location1_endpoint in primaries + assert location2_endpoint in primaries + # And neither should be the global default. + assert default_endpoint not in primaries + class TestNormalizeRegionName: - """Pin down the safety invariants of `_normalize_region_name`. - - The normalization rule must satisfy two opposing requirements: - 1. Forgive cosmetic differences in user-supplied region names - (case, whitespace, hyphens, underscores) so that lookups - match the service-reported canonical names. - 2. NEVER collapse two genuinely distinct Azure regions into - the same key — doing so would silently misroute traffic. - Prefix-sharing pairs like "East US" / "East US 2" are the - highest-risk case and are explicitly called out in the PR - description as the primary regression hazard. + """Unit tests for the _normalize_region_name helper. + + The helper must accept cosmetic differences (case, spacing, hyphens, + underscores) but never collapse two genuinely different regions like + "East US" and "East US 2". """ - # --- Collision-safety: distinct regions must stay distinct. --- + # Distinct regions must stay distinct after normalization. def test_does_not_collapse_prefix_sharing_regions(self): assert _normalize_region_name("East US") != _normalize_region_name("East US 2") assert _normalize_region_name("West US") != _normalize_region_name("West US 2") @@ -881,10 +1175,7 @@ def test_does_not_collapse_prefix_sharing_regions(self): assert _normalize_region_name("Central US") != _normalize_region_name("South Central US") assert _normalize_region_name("China East") != _normalize_region_name("China East 2") - # --- Positive normalization: cosmetic variants must collapse. --- - # These pin down that the function isn't a no-op — without them, - # a "fix" that just returned the input unchanged would still pass - # the collision-safety test above. + # Cosmetic differences should collapse to the same canonical form. def test_collapses_case_and_whitespace_variants(self): canonical = _normalize_region_name("East US 2") assert _normalize_region_name("east us 2") == canonical @@ -893,12 +1184,22 @@ def test_collapses_case_and_whitespace_variants(self): assert _normalize_region_name("eastus2") == canonical assert _normalize_region_name("east-us-2") == canonical assert _normalize_region_name("east_us_2") == canonical + # Extra internal whitespace and mixed punctuation should also collapse. + assert _normalize_region_name("East US 2") == canonical + assert _normalize_region_name("East-US_2") == canonical + assert _normalize_region_name("east -_ us -_ 2") == canonical def test_handles_none_and_empty(self): assert _normalize_region_name(None) == "" assert _normalize_region_name("") == "" assert _normalize_region_name(" ") == "" + # Calling the helper on its own output should be a no-op. + def test_is_idempotent(self): + for raw in ("East US 2", " East-US_2 ", "eastus2", "EAST US 2", ""): + once = _normalize_region_name(raw) + assert _normalize_region_name(once) == once + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache_async.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache_async.py new file mode 100644 index 000000000000..f54ccbd69132 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache_async.py @@ -0,0 +1,766 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""LocationCache parity checks in async test classes. + +Most tests in this module assert the same fallback and normalization +invariants as the sync suite while running under async unittest +classes. A small set of tests explicitly awaits async endpoint-manager +APIs to validate startup refresh/fetch behavior under coroutine +execution. All tests use mocks; no live account. +""" + +import asyncio +import logging +import unittest +import unittest.mock + +import pytest + +from azure.cosmos import exceptions +from azure.cosmos import documents +from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager as _AsyncGlobalEndpointManager +from azure.cosmos.documents import _OperationType +from azure.cosmos.http_constants import ResourceType +from azure.cosmos._location_cache import LocationCache, _normalize_region_name +from azure.cosmos._request_object import RequestObject +from azure.cosmos._service_request_retry_policy import ServiceRequestRetryPolicy + + +default_endpoint = "https://default.documents.azure.com" +location1_name = "location1" +location2_name = "location2" +location3_name = "location3" +location4_name = "location4" +location1_endpoint = "https://location1.documents.azure.com" +location2_endpoint = "https://location2.documents.azure.com" +location3_endpoint = "https://location3.documents.azure.com" +location4_endpoint = "https://location4.documents.azure.com" + + +# Canonical regions used by the normalization tests below. These mimic the +# region names the service returns so tests can pass spelling variants and +# confirm they still resolve to the canonical endpoint. +canonical_location1_name = "East US 2" +canonical_location2_name = "West US 3" +canonical_location3_name = "Central US" +canonical_location1_endpoint = "https://eastus2.documents.azure.com" +canonical_location2_endpoint = "https://westus3.documents.azure.com" +canonical_location3_endpoint = "https://centralus.documents.azure.com" + + +def _create_database_account_with_canonical_regions(enable_multiple_writable_locations, three_regions=False): + """Builds a DatabaseAccount whose region names match real Azure regions. + Set three_regions=True to include a third region (Central US).""" + db_acc = documents.DatabaseAccount() + regions = [ + {"name": canonical_location1_name, "databaseAccountEndpoint": canonical_location1_endpoint}, + {"name": canonical_location2_name, "databaseAccountEndpoint": canonical_location2_endpoint}, + ] + if three_regions: + regions.append( + {"name": canonical_location3_name, "databaseAccountEndpoint": canonical_location3_endpoint}, + ) + db_acc._WritableLocations = list(regions) + db_acc._ReadableLocations = list(regions) + db_acc._EnableMultipleWritableLocations = enable_multiple_writable_locations + return db_acc + + +def _refresh_location_cache_with_policy(preferred, excluded, use_multiple_write_locations=True): + """Builds a LocationCache with the given preferred and excluded lists.""" + cp = documents.ConnectionPolicy() + cp.PreferredLocations = list(preferred) + if excluded is not None: + cp.ExcludedLocations = list(excluded) + cp.UseMultipleWriteLocations = use_multiple_write_locations + return LocationCache(default_endpoint=default_endpoint, connection_policy=cp) + + +def _create_database_account(enable_multiple_writable_locations): + """Builds a DatabaseAccount with three write regions and three + read regions so tests can pick which one to mark unavailable.""" + db_acc = documents.DatabaseAccount() + db_acc._WritableLocations = [ + {"name": location1_name, "databaseAccountEndpoint": location1_endpoint}, + {"name": location2_name, "databaseAccountEndpoint": location2_endpoint}, + {"name": location3_name, "databaseAccountEndpoint": location3_endpoint}, + ] + db_acc._ReadableLocations = [ + {"name": location1_name, "databaseAccountEndpoint": location1_endpoint}, + {"name": location2_name, "databaseAccountEndpoint": location2_endpoint}, + {"name": location4_name, "databaseAccountEndpoint": location4_endpoint}, + ] + db_acc._EnableMultipleWritableLocations = enable_multiple_writable_locations + return db_acc + + +def _refresh_location_cache(preferred_locations, use_multiple_write_locations): + """Builds a LocationCache with the given preferred regions.""" + cp = documents.ConnectionPolicy() + cp.PreferredLocations = preferred_locations + cp.UseMultipleWriteLocations = use_multiple_write_locations + return LocationCache(default_endpoint=default_endpoint, connection_policy=cp) + + +@pytest.mark.cosmosEmulator +class TestLocationCacheAsync(unittest.IsolatedAsyncioTestCase): + """Async-context tests for the unavailable-region fallback behavior.""" + + async def test_unavailable_read_endpoint_remains_in_routing_list_async(self): + """Read path: if the only healthy region is excluded by the + caller, routing should still fall back to the unavailable + preferred region instead of dropping to the global default.""" + preferred_locations = [location1_name, location2_name] + lc = _refresh_location_cache(preferred_locations, use_multiple_write_locations=True) + lc.perform_on_database_account_read(_create_database_account(True)) + + lc.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_request.excluded_locations = [location2_name] + + resolved = lc.resolve_service_endpoint(read_request) + self.assertEqual( + resolved, location1_endpoint, + "Expected the unavailable preferred region to be used as a " + "last-resort regional endpoint instead of the global default.", + ) + + async def test_unavailable_write_endpoint_remains_in_routing_list_async(self): + """Write path version of the read test above.""" + preferred_locations = [location1_name, location2_name] + lc = _refresh_location_cache(preferred_locations, use_multiple_write_locations=True) + lc.perform_on_database_account_read(_create_database_account(True)) + + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.excluded_locations = [location2_name] + + resolved = lc.resolve_service_endpoint(write_request) + self.assertEqual( + resolved, location1_endpoint, + "Expected the unavailable preferred region to be used as a " + "last-resort regional endpoint instead of the global default.", + ) + + async def test_async_global_endpoint_manager_returns_unavailable_as_last_resort(self): + """Drives the async endpoint-manager wrapper directly. The + wrapper is a thin pass-through to the shared cache, so this + test checks the wrapper does not lose or re-filter the + unavailable-as-last-resort ordering.""" + cp = documents.ConnectionPolicy() + cp.PreferredLocations = [location1_name, location2_name] + cp.UseMultipleWriteLocations = True + mock_client = unittest.mock.Mock() + mock_client.connection_policy = cp + mock_client.url_connection = default_endpoint + + gem = _AsyncGlobalEndpointManager(mock_client) + gem.location_cache.perform_on_database_account_read(_create_database_account(True)) + + # Mark location1 unavailable for both reads and writes. + gem.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True, context="test") + gem.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + + # Read routing list should include both regions, unavailable one last. + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_ctxs = gem.get_applicable_read_regional_routing_contexts(read_request) + read_endpoints = [c.get_primary() for c in read_ctxs] + self.assertEqual( + read_endpoints, [location2_endpoint, location1_endpoint], + "Unavailable read endpoint should appear at the tail of the list.", + ) + + # If the only healthy region is excluded, the unavailable + # region should still be returned. + read_request.excluded_locations = [location2_name] + resolved = gem._resolve_service_endpoint(read_request) + self.assertEqual( + resolved, location1_endpoint, + "Expected the unavailable preferred region when the only healthy " + "region is excluded.", + ) + + # Same check for writes. + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_ctxs = gem.get_applicable_write_regional_routing_contexts(write_request) + write_endpoints = [c.get_primary() for c in write_ctxs] + self.assertEqual( + write_endpoints, [location2_endpoint, location1_endpoint], + "Unavailable write endpoint should appear at the tail of the list.", + ) + + write_request.excluded_locations = [location2_name] + resolved_write = gem._resolve_service_endpoint(write_request) + self.assertEqual( + resolved_write, location1_endpoint, + "Expected the unavailable preferred region when the only healthy " + "region is excluded.", + ) + + async def test_async_endpoint_manager_get_database_account_uses_preferred_fallback_async(self): + """_GetDatabaseAccount should await the default endpoint first, then + await preferred-region fallback endpoints when default fails.""" + cp = documents.ConnectionPolicy() + cp.PreferredLocations = [location1_name] + cp.UseMultipleWriteLocations = True + mock_client = unittest.mock.Mock() + mock_client.connection_policy = cp + mock_client.url_connection = default_endpoint + + gem = _AsyncGlobalEndpointManager(mock_client) + account = _create_database_account(True) + default_error = exceptions.CosmosHttpResponseError( + status_code=503, + message="Injected default-endpoint failure", + ) + gem._GetDatabaseAccountStub = unittest.mock.AsyncMock( + side_effect=[default_error, account], + ) + + resolved = await gem._GetDatabaseAccount() + self.assertIs(resolved, account) + + locational_endpoint = LocationCache.GetLocationalEndpoint( + default_endpoint, location1_name, + ) + gem._GetDatabaseAccountStub.assert_has_awaits( + [ + unittest.mock.call(default_endpoint), + unittest.mock.call(locational_endpoint), + ], + any_order=False, + ) + + async def test_async_refresh_endpoint_list_concurrent_calls_fetch_once_async(self): + """Concurrent refresh calls should serialize on the async lock and + avoid duplicate account fetches when startup refresh is in flight.""" + cp = documents.ConnectionPolicy() + cp.PreferredLocations = [location1_name, location2_name] + cp.UseMultipleWriteLocations = True + mock_client = unittest.mock.Mock() + mock_client.connection_policy = cp + mock_client.url_connection = default_endpoint + + gem = _AsyncGlobalEndpointManager(mock_client) + gem.startup = True + gem.refresh_needed = True + gem._aenter_used = True + + call_counter = {"count": 0} + account = _create_database_account(True) + + async def _get_account_once(**_kwargs): + call_counter["count"] += 1 + await asyncio.sleep(0.01) + return account + + gem._GetDatabaseAccount = unittest.mock.AsyncMock(side_effect=_get_account_once) + gem._endpoints_health_check = unittest.mock.AsyncMock(return_value=None) + + await asyncio.gather( + gem.refresh_endpoint_list(None), + gem.refresh_endpoint_list(None), + ) + + if gem.refresh_task: + await gem.refresh_task + + self.assertEqual(call_counter["count"], 1) + self.assertFalse(gem.startup) + self.assertGreater(len(gem.location_cache.get_ordered_read_locations()), 0) + + async def test_async_service_request_retry_policy_routes_through_unavailable_as_last_resort(self): # pylint: disable=line-too-long + """Drives the retry policy through the full retry-then-fallback + sequence for a write. After both preferred regions are marked + unavailable and the retry budget is exhausted, the final + resolution must still surface a regional endpoint, not the + global default.""" + preferred_locations = [location1_name, location2_name] + lc = _refresh_location_cache(preferred_locations, use_multiple_write_locations=True) + lc.perform_on_database_account_read(_create_database_account(True)) + + mock_gem = unittest.mock.Mock() + mock_gem.location_cache = lc + mock_gem.resolve_service_endpoint_for_partition.side_effect = [location2_endpoint] + mock_gem.mark_endpoint_unavailable_for_write = lc.mark_endpoint_unavailable_for_write + + mock_connection_policy = unittest.mock.Mock() + mock_connection_policy.EnableEndpointDiscovery = True + mock_pk_range_wrapper = unittest.mock.Mock() + + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + resolved_endpoint = lc.resolve_service_endpoint(write_request) + self.assertEqual(resolved_endpoint, location1_endpoint) + + write_request.location_endpoint_to_route = location1_endpoint + retry_policy = ServiceRequestRetryPolicy( + mock_connection_policy, mock_gem, mock_pk_range_wrapper, write_request, + ) + + # First retry marks location1 unavailable and switches to location2. + self.assertTrue(retry_policy.ShouldRetry()) + self.assertEqual(write_request.location_endpoint_to_route, location2_endpoint) + self.assertTrue(lc.is_endpoint_unavailable(location1_endpoint, "Write")) + + # Second retry exhausts the budget. + self.assertFalse(retry_policy.ShouldRetry()) + self.assertTrue(lc.is_endpoint_unavailable(location2_endpoint, "Write")) + + # Final fallback should surface the unavailable preferred region, + # not the global default. + write_request.clear_route_to_location() + write_request.use_preferred_locations = False + + final_endpoint = lc.resolve_service_endpoint(write_request) + self.assertEqual( + final_endpoint, location1_endpoint, + "Final fallback returned the global default instead of an " + "unavailable preferred region.", + ) + + async def test_async_retry_policy_read_path_routes_through_unavailable_as_last_resort(self): # pylint: disable=line-too-long + """Read-path version of the retry-then-fallback test above.""" + preferred_locations = [location1_name, location2_name] + lc = _refresh_location_cache(preferred_locations, use_multiple_write_locations=True) + lc.perform_on_database_account_read(_create_database_account(True)) + + mock_gem = unittest.mock.Mock() + mock_gem.location_cache = lc + mock_gem.resolve_service_endpoint_for_partition.side_effect = [location2_endpoint] + mock_gem.mark_endpoint_unavailable_for_read = lc.mark_endpoint_unavailable_for_read + + mock_connection_policy = unittest.mock.Mock() + mock_connection_policy.EnableEndpointDiscovery = True + mock_pk_range_wrapper = unittest.mock.Mock() + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + resolved_endpoint = lc.resolve_service_endpoint(read_request) + self.assertEqual(resolved_endpoint, location1_endpoint) + + read_request.location_endpoint_to_route = location1_endpoint + retry_policy = ServiceRequestRetryPolicy( + mock_connection_policy, mock_gem, mock_pk_range_wrapper, read_request, + ) + + self.assertTrue(retry_policy.ShouldRetry()) + self.assertEqual(read_request.location_endpoint_to_route, location2_endpoint) + self.assertTrue(lc.is_endpoint_unavailable(location1_endpoint, "Read")) + + self.assertFalse(retry_policy.ShouldRetry()) + self.assertTrue(lc.is_endpoint_unavailable(location2_endpoint, "Read")) + + read_request.clear_route_to_location() + read_request.use_preferred_locations = False + final_endpoint = lc.resolve_service_endpoint(read_request) + self.assertEqual( + final_endpoint, location1_endpoint, + "Final fallback returned the global default instead of an " + "unavailable preferred region.", + ) + + # The tests below cover topologies and helpers that the existing + # tests don't touch: single-write accounts, the no-duplicates + # invariant, the health-check probe set, and the metadata routing + # path. Each one runs inside an async coroutine to catch any + # event-loop interaction with the shared cache. + + async def test_async_single_write_account_read_unavailable_and_excluded_async(self): + """Single-write account read path. This is the common + topology and the other tests only cover multi-write.""" + preferred_locations = [location1_name, location2_name] + # use_multiple_write_locations=False on the policy plus + # enable_multiple_writable_locations=False on the account = single-write. + lc = _refresh_location_cache(preferred_locations, use_multiple_write_locations=False) + lc.perform_on_database_account_read(_create_database_account(False)) + + self.assertFalse(lc.can_use_multiple_write_locations(), + "Test setup must be a single-write account.") + + lc.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_request.excluded_locations = [location2_name] + + resolved = lc.resolve_service_endpoint(read_request) + self.assertEqual( + resolved, location1_endpoint, + "Single-write read path returned the global default instead of " + "the unavailable preferred region.", + ) + + async def test_async_routing_list_has_no_duplicate_endpoints(self): + """The routing list should never contain the same endpoint + twice, regardless of which regions are marked unavailable.""" + endpoint_by_loc = {location1_name: location1_endpoint, location2_name: location2_endpoint} + for unavailable in ([], [location1_name], [location1_name, location2_name]): + with self.subTest(unavailable=unavailable): + lc = _refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True, + ) + lc.perform_on_database_account_read(_create_database_account(True)) + + for loc in unavailable: + lc.mark_endpoint_unavailable_for_read(endpoint_by_loc[loc], refresh_cache=True) + + read_primaries = [c.get_primary() for c in lc.get_read_regional_routing_contexts()] + self.assertEqual( + len(read_primaries), len(set(read_primaries)), + f"Read routing list has duplicates: {read_primaries}", + ) + self.assertEqual(set(read_primaries), {location1_endpoint, location2_endpoint}) + + # Read marks don't affect the write side, so mark again for writes. + for loc in unavailable: + lc.mark_endpoint_unavailable_for_write( + endpoint_by_loc[loc], refresh_cache=True, context="test", + ) + write_primaries = [c.get_primary() for c in lc.get_write_regional_routing_contexts()] + self.assertEqual( + len(write_primaries), len(set(write_primaries)), + f"Write routing list has duplicates: {write_primaries}", + ) + self.assertEqual(set(write_primaries), {location1_endpoint, location2_endpoint}) + + async def test_async_health_check_set_includes_unavailable_endpoints(self): + """An endpoint marked unavailable should stay in the set the + background health-check loop probes, so it can be re-marked + available once it recovers.""" + lc = _refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True, + ) + lc.perform_on_database_account_read(_create_database_account(True)) + lc.mark_endpoint_unavailable_for_write( + location1_endpoint, refresh_cache=True, context="test" + ) + self.assertEqual( + lc.get_write_regional_routing_contexts()[0].get_primary(), + location2_endpoint, + "Test precondition failed: location1 must not be the primary write endpoint.", + ) + + lc.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True) + endpoints = lc.endpoints_to_health_check() + self.assertIn( + location1_endpoint, endpoints, + "Health-check probe set is missing the unavailable read endpoint.", + ) + self.assertIn(location2_endpoint, endpoints) + + + async def test_async_master_resource_with_all_healthy_prefers_non_excluded(self): + """With every region healthy, a metadata request should still + prefer a healthy non-excluded region over a healthy excluded one.""" + lc = _refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True, + ) + lc.perform_on_database_account_read(_create_database_account(True)) + + # No mark_endpoint_unavailable calls; both regions stay healthy. + + master_request = RequestObject(ResourceType.Database, _OperationType.Read, None) + master_request.excluded_locations = [location2_name] + + resolved = lc.resolve_service_endpoint(master_request) + self.assertEqual( + resolved, location1_endpoint, + f"Expected the healthy non-excluded region ({location1_endpoint}) " + f"to come first, but got {resolved}.", + ) + + async def test_async_data_call_with_exclusion_and_unavailable_preserves_fallback(self): + # For a data request, excluded_locations is a hard filter. With one + # region unavailable and the other excluded, the SDK should still + # return the unavailable non-excluded region before falling back to + # the global default. + lc = _refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True, + ) + lc.perform_on_database_account_read(_create_database_account(True)) + + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + + data_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + data_request.excluded_locations = [location2_name] + + resolved = lc.resolve_service_endpoint(data_request) + self.assertEqual( + resolved, location1_endpoint, + f"Expected the unavailable non-excluded region ({location1_endpoint}) " + f"as a last-resort regional endpoint, but got {resolved}.", + ) + + """ + Additional async coverage for keeping unavailable endpoints as + fallback options. Recovery, account-topology refresh, and + circuit-breaker read fallback are exercised here. + """ + + async def test_async_mark_endpoint_available_restores_head_position_async(self): + # After recovery, a previously-unavailable preferred endpoint should + # return to the head of the routing list, not stay at the tail. + lc = _refresh_location_cache( + [location1_name, location2_name, location3_name], + use_multiple_write_locations=True, + ) + lc.perform_on_database_account_read(_create_database_account(True)) + + self.assertEqual( + lc.read_regional_routing_contexts[0].get_primary(), location1_endpoint + ) + self.assertEqual( + lc.write_regional_routing_contexts[0].get_primary(), location1_endpoint + ) + + # Mark location1 unavailable on both lanes; it should slide to the tail. + lc.mark_endpoint_unavailable_for_read(location1_endpoint, refresh_cache=True) + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + self.assertEqual(lc.read_regional_routing_contexts[-1].get_primary(), location1_endpoint) + self.assertEqual(lc.write_regional_routing_contexts[-1].get_primary(), location1_endpoint) + + # Health-probe rehabilitates the endpoint. + lc.mark_endpoint_available(location1_endpoint) + lc.update_location_cache() + + self.assertFalse(lc.is_endpoint_unavailable(location1_endpoint, "Read")) + self.assertFalse(lc.is_endpoint_unavailable(location1_endpoint, "Write")) + self.assertEqual( + lc.read_regional_routing_contexts[0].get_primary(), location1_endpoint, + "Recovered endpoint should return to the head of the read routing list." + ) + self.assertEqual( + lc.write_regional_routing_contexts[0].get_primary(), location1_endpoint, + "Recovered endpoint should return to the head of the write routing list." + ) + + async def test_account_topology_refresh_preserves_unavailability_tail_order_async(self): + # A periodic account-topology refresh must not drop endpoints that + # were marked unavailable, and the tail ordering must be preserved. + lc = _refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True, + ) + db_acc = _create_database_account(True) + lc.perform_on_database_account_read(db_acc) + + lc.mark_endpoint_unavailable_for_write(location1_endpoint, refresh_cache=True, context="test") + write_primaries_before = [c.get_primary() for c in lc.get_write_regional_routing_contexts()] + self.assertEqual(write_primaries_before, [location2_endpoint, location1_endpoint]) + + # Simulate periodic background refresh — same topology comes back. + lc.perform_on_database_account_read(db_acc) + + self.assertTrue( + lc.is_endpoint_unavailable(location1_endpoint, "Write"), + "Unavailability mark must survive an account-topology refresh.", + ) + write_primaries_after = [c.get_primary() for c in lc.get_write_regional_routing_contexts()] + self.assertEqual( + write_primaries_after, write_primaries_before, + "Account-topology refresh dropped the unavailable endpoint from the routing list.", + ) + + async def test_async_circuit_breaker_excluded_read_falls_back_before_global_default_async(self): + # With the only healthy region user-excluded and the other region + # circuit-breaker-excluded, reads should still resolve to the + # circuit-breaker-excluded region instead of the global default. + lc = _refresh_location_cache( + [location1_name, location2_name], use_multiple_write_locations=True, + ) + lc.perform_on_database_account_read(_create_database_account(True)) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_request.excluded_locations = [location1_name] + read_request.excluded_locations_circuit_breaker = [location2_name] + + resolved = lc.resolve_service_endpoint(read_request) + self.assertEqual( + resolved, location2_endpoint, + "Read should fall back to the circuit-breaker-excluded region " + "instead of dropping to the global default.", + ) + + +@pytest.mark.cosmosEmulator +class TestRegionNormalizationAsync(unittest.IsolatedAsyncioTestCase): + """Async-context coverage for region-name normalization. + + Drives the same matching behavior the sync tests cover from inside + coroutines so any event-loop interaction with the shared cache shows up. + """ + + async def test_preferred_locations_support_spelling_variants_async(self): + # Each preferred entry uses a different spelling style. All three + # should resolve to the canonical account endpoints in order. + lc = _refresh_location_cache_with_policy( + preferred=["east-us-2", " WEST_US_3 ", "Central US"], + excluded=None, + ) + lc.perform_on_database_account_read( + _create_database_account_with_canonical_regions(True, three_regions=True), + ) + + write_endpoints = [c.get_primary() for c in lc.get_write_regional_routing_contexts()] + read_endpoints = [c.get_primary() for c in lc.get_read_regional_routing_contexts()] + expected = [ + canonical_location1_endpoint, + canonical_location2_endpoint, + canonical_location3_endpoint, + ] + self.assertEqual(write_endpoints, expected) + self.assertEqual(read_endpoints, expected) + + async def test_excluded_locations_support_spelling_variants_async(self): + # Client-level excluded list uses one spelling, per-request list uses + # another. Both should filter the same region the canonical name does. + lc = _refresh_location_cache_with_policy( + preferred=[canonical_location1_name, canonical_location2_name], + excluded=["east-us-2"], + ) + lc.perform_on_database_account_read(_create_database_account_with_canonical_regions(True)) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.excluded_locations = ["WEST_US_3"] + + self.assertEqual(lc.resolve_service_endpoint(read_request), canonical_location2_endpoint) + self.assertEqual(lc.resolve_service_endpoint(write_request), canonical_location1_endpoint) + + async def test_excluded_locations_ignore_none_and_empty_async(self): + # None, empty, and whitespace-only entries must not block valid + # exclusions and must not match real endpoints by accident. + lc = _refresh_location_cache_with_policy( + preferred=[canonical_location1_name, canonical_location2_name], + excluded=[None, "", " ", "east-us-2"], + ) + lc.perform_on_database_account_read(_create_database_account_with_canonical_regions(True)) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.excluded_locations = [None, "", "west_us_3"] + + self.assertEqual(lc.resolve_service_endpoint(read_request), canonical_location2_endpoint) + self.assertEqual(lc.resolve_service_endpoint(write_request), canonical_location1_endpoint) + + async def test_duplicate_normalized_entries_warn_once_async(self): + # The same region listed three times in three spellings should still + # filter correctly and should not produce a mismatch warning. + lc = _refresh_location_cache_with_policy( + preferred=[canonical_location1_name, canonical_location2_name], + excluded=["East US 2", "east-us-2", "EAST_US_2"], + ) + + with self.assertLogs("azure.cosmos.LocationCache", level=logging.WARNING) as captured: + lc.perform_on_database_account_read(_create_database_account_with_canonical_regions(True)) + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + resolved = lc.resolve_service_endpoint(read_request) + # assertLogs requires at least one record, so emit a marker. + logging.getLogger("azure.cosmos.LocationCache").warning("marker") + + self.assertEqual(resolved, canonical_location2_endpoint) + mismatch_messages = [m for m in captured.output if "did not match" in m] + self.assertEqual(mismatch_messages, []) + + async def test_resolve_endpoint_without_preferred_locations_supports_variants_async(self): + # Per-request exclusions should still apply when the request opts out + # of preferred-location routing. + lc = _refresh_location_cache_with_policy( + preferred=[], + excluded=None, + ) + lc.perform_on_database_account_read(_create_database_account_with_canonical_regions(True)) + + write_request = RequestObject(ResourceType.Document, _OperationType.Create, None) + write_request.use_preferred_locations = False + write_request.excluded_locations = ["east-us-2"] + self.assertEqual(lc.resolve_service_endpoint(write_request), canonical_location2_endpoint) + + read_request = RequestObject(ResourceType.Document, _OperationType.Read, None) + read_request.use_preferred_locations = False + read_request.excluded_locations = ["west_us_3"] + self.assertEqual(lc.resolve_service_endpoint(read_request), canonical_location1_endpoint) + + async def test_should_refresh_endpoints_handles_normalized_preferred_async(self): + # When the most preferred region uses a non-canonical spelling and is + # already the primary, no background refresh should be scheduled. + lc = _refresh_location_cache_with_policy(preferred=["east-us-2"], excluded=None) + lc.perform_on_database_account_read(_create_database_account_with_canonical_regions(True)) + self.assertFalse(lc.should_refresh_endpoints()) + + async def test_should_refresh_endpoints_true_when_normalized_preferred_is_not_primary_async(self): + # When the most preferred region is no longer the primary because its + # endpoint was marked unavailable, a refresh should be scheduled. + lc = _refresh_location_cache_with_policy( + preferred=["east-us-2", "west-us-3"], excluded=None, + ) + lc.perform_on_database_account_read(_create_database_account_with_canonical_regions(True)) + lc.mark_endpoint_unavailable_for_read(canonical_location1_endpoint, refresh_cache=True) + + self.assertEqual( + lc.read_regional_routing_contexts[0].get_primary(), canonical_location2_endpoint, + ) + self.assertTrue(lc.should_refresh_endpoints()) + + async def test_get_locational_endpoint_normalizes_customer_region_async(self): + # The static helper should produce the same regional URL for every + # spelling variant of the same region. + default_endpoint_url = "https://contoso.documents.azure.com:443/" + expected = "https://contoso-eastus2.documents.azure.com:443/" + for variant in ("East US 2", "east us 2", "eastus2", "east-us-2", "east_us_2", + " EastUs2 ", "EAST-US_2", "East US 2"): + self.assertEqual(LocationCache.GetLocationalEndpoint(default_endpoint_url, variant), expected) + + async def test_async_endpoint_manager_normalizes_preferred_locations_from_policy_async(self): + # End-to-end check that messy region names set on ConnectionPolicy + # flow through the async endpoint manager into the location cache. + cp = documents.ConnectionPolicy() + cp.PreferredLocations = ["east-us-2", " WEST_US_3 "] + cp.UseMultipleWriteLocations = True + + mock_client = unittest.mock.Mock() + mock_client.connection_policy = cp + mock_client.url_connection = default_endpoint + + gem = _AsyncGlobalEndpointManager(mock_client) + gem.location_cache.perform_on_database_account_read( + _create_database_account_with_canonical_regions(True), + ) + + read_endpoints = [ + c.get_primary() for c in gem.location_cache.get_read_regional_routing_contexts() + ] + self.assertEqual( + read_endpoints, [canonical_location1_endpoint, canonical_location2_endpoint], + ) + + +class TestNormalizeRegionNameAsync(unittest.IsolatedAsyncioTestCase): + """Unit tests for the helper, exercised inside coroutines for parity + with the rest of the async suite.""" + + async def test_does_not_collapse_prefix_sharing_regions_async(self): + self.assertNotEqual(_normalize_region_name("East US"), _normalize_region_name("East US 2")) + self.assertNotEqual(_normalize_region_name("West US"), _normalize_region_name("West US 2")) + self.assertNotEqual(_normalize_region_name("Central US"), _normalize_region_name("North Central US")) + self.assertNotEqual(_normalize_region_name("China East"), _normalize_region_name("China East 2")) + + async def test_collapses_case_and_whitespace_variants_async(self): + canonical = _normalize_region_name("East US 2") + for variant in ("east us 2", "EAST US 2", " East US 2 ", "eastus2", + "east-us-2", "east_us_2", "East US 2", "East-US_2"): + self.assertEqual(_normalize_region_name(variant), canonical) + + async def test_handles_none_and_empty_async(self): + self.assertEqual(_normalize_region_name(None), "") + self.assertEqual(_normalize_region_name(""), "") + self.assertEqual(_normalize_region_name(" "), "") + + async def test_is_idempotent_async(self): + for raw in ("East US 2", " East-US_2 ", "eastus2", "EAST US 2", ""): + once = _normalize_region_name(raw) + self.assertEqual(_normalize_region_name(once), once) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py index a1bf10d40b98..d4ab355ed253 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit.py @@ -17,7 +17,20 @@ from azure.cosmos import _retry_utility from azure.cosmos._cosmos_client_connection import CosmosClientConnection from azure.cosmos._execution_context.base_execution_context import _DefaultQueryExecutionContext -from azure.cosmos._routing.feed_range_continuation import _FIELD_VERSION, _TOKEN_VERSION, _decode_token +from azure.cosmos._routing import routing_range +from azure.cosmos._routing.feed_range_continuation import ( + _FIELD_BACKEND_CONTINUATION, + _FIELD_COLLECTION_RID, + _FIELD_CONTINUATIONS, + _FIELD_FEEDRANGE_HASH, + _FIELD_QUERY_HASH, + _FIELD_VERSION, + _TOKEN_VERSION, + _decode_token, + _encode_token, + _hash_feed_range, + _hash_query_spec, +) from azure.cosmos.http_constants import HttpHeaders, StatusCodes, SubStatusCodes # tracemalloc is not available in PyPy, so we import conditionally @@ -28,9 +41,6 @@ HAS_TRACEMALLOC = False -# ================================= -# Shared Test Helpers -# ================================= class MockGlobalEndpointManager: """Mock global endpoint manager for testing.""" @@ -101,9 +111,6 @@ def raise_410_partition_split_error(*args, **kwargs): raise create_410_partition_split_error() -# ========================== -# Test Class -# ========================== @pytest.mark.cosmosEmulator class TestPartitionSplitRetryUnit(unittest.TestCase): @@ -621,11 +628,8 @@ def test_memory_bounded_no_leak_on_410_retries(self, mock_execute): - Memory growth is minimal (no recursive accumulation) - No infinite recursion (max depth = 0 for PK range queries) """ - # tracemalloc.start() begins tracing memory allocations to detect leaks tracemalloc.start() - # gc.collect() forces garbage collection to get accurate baseline memory measurement gc.collect() - # take_snapshot() captures current memory state for comparison after test snapshot_before = tracemalloc.take_snapshot() start_time = time.time() @@ -643,31 +647,17 @@ def mock_fetch_function(options): context._fetch_items_helper_with_retries(mock_fetch_function) elapsed_time = time.time() - start_time - # gc.collect() before snapshot ensures we measure actual leaks, not pending garbage gc.collect() snapshot_after = tracemalloc.take_snapshot() - # compare_to() shows memory difference between snapshots to identify growth top_stats = snapshot_after.compare_to(snapshot_before, 'lineno') memory_growth = sum(stat.size_diff for stat in top_stats if stat.size_diff > 0) peak_memory = tracemalloc.get_traced_memory()[1] - # tracemalloc.stop() ends memory tracing and frees tracing overhead tracemalloc.stop() # Collect metrics execute_calls = mock_execute.call_count refresh_calls = mock_client.refresh_routing_map_provider_call_count - # Print metrics - print(f"\n{'=' * 60}") - print("MEMORY METRICS - Partition Split Memory Verification") - print(f"{'=' * 60}") - print(f"Metrics:") - print(f" - Execute calls: {execute_calls} (bounded)") - print(f" - Refresh calls: {refresh_calls}") - print(f" - Elapsed time: {elapsed_time:.2f}s") - print(f" - Memory growth: {memory_growth / 1024:.2f} KB") - print(f" - Peak memory: {peak_memory / 1024:.2f} KB") - print(f"{'=' * 60}") assert execute_calls == 4, \ f"Execute calls should be bounded to 4, got {execute_calls}" @@ -692,10 +682,6 @@ def mock_fetch_function(options): pk_execute_calls = mock_execute.call_count pk_refresh_calls = mock_client.refresh_routing_map_provider_call_count - print(f"\nPK Range Query:") - print(f" - Execute calls: {pk_execute_calls} (no retry)") - print(f" - Refresh calls: {pk_refresh_calls} (no recursion)") - print(f"{'=' * 60}\n") assert pk_execute_calls == 1, \ f"PK range query should have 1 execute call, got {pk_execute_calls}" @@ -1169,11 +1155,8 @@ def test_queryfeed_populates_capture_dict_from_options(self): `__QueryFeed` itself does the population. Catches the `options`-vs-`kwargs` extraction regression. """ - from unittest.mock import patch as _patch - - # Build a CosmosClientConnection without running __init__; we - # only need the attributes that the no-query (read-feed) branch - # of __QueryFeed touches. + # Build the connection without running __init__; only the attributes + # used by the no-query (read-feed) branch of __QueryFeed are needed. conn = object.__new__(CosmosClientConnection) conn.default_headers = {} conn.last_response_headers = {} @@ -1197,20 +1180,19 @@ def test_queryfeed_populates_capture_dict_from_options(self): headers={}, ) - # Patch the heavy collaborators inside __QueryFeed's no-query - # branch so we can drive it without a real pipeline. - with _patch( + # Patch the heavy collaborators so the no-query branch can run without a real pipeline. + with patch( "azure.cosmos._cosmos_client_connection.base.GetHeaders", return_value={}, ), \ - _patch( + patch( "azure.cosmos._cosmos_client_connection.base.set_session_token_header" ), \ - _patch( + patch( "azure.cosmos._cosmos_client_connection.RequestObject", return_value=request_obj_mock, - ) as request_obj_ctor, \ - _patch.object( + ), \ + patch.object( CosmosClientConnection, "_CosmosClientConnection__Get", return_value=( @@ -1218,7 +1200,6 @@ def test_queryfeed_populates_capture_dict_from_options(self): canned_headers, ), ) as mock_get: - _ = request_obj_ctor # silence unused-warning # Invoke the name-mangled private method directly. result, headers = conn._CosmosClientConnection__QueryFeed( @@ -1240,8 +1221,8 @@ def test_queryfeed_populates_capture_dict_from_options(self): "'_internal_response_headers_capture' from options." ) - # the marker key must have been removed from options so it - # never leaks downstream into header construction or RequestObject. + # The marker key must have been removed from options so it never + # leaks downstream into header construction or RequestObject. assert "_internal_response_headers_capture" not in options, ( "__QueryFeed should pop the capture marker out of options" ) @@ -1488,6 +1469,151 @@ def post_side_effect(_path, _request_params, _query, _req_headers, **_kwargs): assert decoded[_FIELD_VERSION] == _TOKEN_VERSION assert len(decoded["c"]) == 2 + def test_queryfeed_v1_inbound_token_explodes_head_on_resume(self): + """Resume with a structured continuation token whose head range now spans + two partitions. The loop must send one request per child across two + pages and forward the parent's backend continuation to each child so + results are not replayed from the start. + """ + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + child_left = {"id": "1", "minInclusive": "00", "maxExclusive": "7F"} + child_right = {"id": "2", "minInclusive": "7F", "maxExclusive": "FF"} + + def overlap_side_effect(_rid, ranges, _opts): + requested = ranges[0] + if requested.min == "00" and requested.max == "FF": + return [child_left, child_right] + if requested.min == "00" and requested.max == "7F": + return [child_left] + if requested.min == "7F" and requested.max == "FF": + return [child_right] + return [] + + client._routing_map_provider.get_overlapping_ranges.side_effect = overlap_side_effect + + feed_range_dict = { + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + } + full_range = routing_range.Range( + range_min="00", + range_max="FF", + isMinInclusive=True, + isMaxInclusive=False, + ) + query_text = "SELECT * FROM c" + # Build a structured continuation token whose single entry covers the + # full range and carries a backend continuation that must be forwarded + # to both children once the head is split. + inbound_token_payload = { + "v": 1, + _FIELD_COLLECTION_RID: "rid-c1", + _FIELD_QUERY_HASH: _hash_query_spec(query_text), + _FIELD_FEEDRANGE_HASH: _hash_feed_range(full_range), + _FIELD_CONTINUATIONS: [ + { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + _FIELD_BACKEND_CONTINUATION: "parent-bc", + } + ], + } + inbound_token = _encode_token(inbound_token_payload) + + seen_request_continuations = [] + seen_pkr_ids = [] + post_call_count = {"n": 0} + + def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + post_call_count["n"] += 1 + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + seen_pkr_ids.append(req_headers.get(HttpHeaders.PartitionKeyRangeID)) + # Each child returns one document and signals "drained" by + # returning a None continuation so the next call targets the + # next child. + return ( + {"Documents": [{"id": f"doc-{post_call_count['n']}"}]}, + {HttpHeaders.Continuation: None}, + ) + + with patch("azure.cosmos._cosmos_client_connection.base.GetHeaders", return_value={}): + with patch("azure.cosmos._cosmos_client_connection.base.set_session_token_header", return_value=None): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + # First page: should send a request to the left child with + # the original backend continuation. + docs_p1, headers_p1 = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=query_text, + options={"continuation": inbound_token}, + feed_range=feed_range_dict, + ) + + outbound_p1 = headers_p1.get(HttpHeaders.Continuation) + decoded_p1 = _decode_token(outbound_p1) + assert decoded_p1 is not None, ( + "Page 1 outbound continuation must be a structured " + f"token after the head is split; got {outbound_p1!r}" + ) + assert decoded_p1[_FIELD_VERSION] == _TOKEN_VERSION + assert len(decoded_p1[_FIELD_CONTINUATIONS]) == 1, ( + "After draining the left child, only the right child " + "should remain in the outbound token; got " + f"{decoded_p1[_FIELD_CONTINUATIONS]!r}" + ) + surviving = decoded_p1[_FIELD_CONTINUATIONS][0] + assert surviving["min"] == "7F" and surviving["max"] == "FF", ( + "The remaining entry must be the right child's range; " + f"got {surviving!r}" + ) + assert surviving.get(_FIELD_BACKEND_CONTINUATION) == "parent-bc", ( + "The remaining child must carry the parent's backend " + f"continuation forward; got bc={surviving.get(_FIELD_BACKEND_CONTINUATION)!r}" + ) + + # Second page: should send a request to the right child + # with the same backend continuation. + docs_p2, headers_p2 = client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=query_text, + options={"continuation": outbound_p1}, + feed_range=feed_range_dict, + ) + + assert post_call_count["n"] == 2, ( + "Expected exactly one request per child across the two pages; " + f"got {post_call_count['n']} request(s)" + ) + assert seen_request_continuations == ["parent-bc", "parent-bc"], ( + "Each child must receive the parent's backend continuation; " + f"got {seen_request_continuations!r}" + ) + assert seen_pkr_ids == [child_left["id"], child_right["id"]], ( + "Expected one request per child range id in left-to-right order; " + f"got {seen_pkr_ids!r}" + ) + assert docs_p1 == [{"id": "doc-1"}], ( + f"Page 1 should return only the left child's document; got {docs_p1!r}" + ) + assert docs_p2 == [{"id": "doc-2"}], ( + f"Page 2 should return only the right child's document; got {docs_p2!r}" + ) + outbound_p2 = headers_p2.get(HttpHeaders.Continuation) + assert outbound_p2 in (None, "", b""), ( + "Expected an empty outbound continuation after both children are " + f"drained; got {outbound_p2!r}" + ) + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py index fec2d8208c57..417e141dafab 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_retry_unit_async.py @@ -8,6 +8,7 @@ import gc import time import unittest +from typing import List from unittest.mock import patch, MagicMock, AsyncMock import pytest @@ -20,7 +21,20 @@ from azure.cosmos.aio import CosmosClient # noqa: F401 - needed to resolve circular imports from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection from azure.cosmos._execution_context.aio.base_execution_context import _DefaultQueryExecutionContext -from azure.cosmos._routing.feed_range_continuation import _FIELD_VERSION, _TOKEN_VERSION, _decode_token +from azure.cosmos._routing import routing_range +from azure.cosmos._routing.feed_range_continuation import ( + _FIELD_BACKEND_CONTINUATION, + _FIELD_COLLECTION_RID, + _FIELD_CONTINUATIONS, + _FIELD_FEEDRANGE_HASH, + _FIELD_QUERY_HASH, + _FIELD_VERSION, + _TOKEN_VERSION, + _decode_token, + _encode_token, + _hash_feed_range, + _hash_query_spec, +) # tracemalloc is not available in PyPy, so we import conditionally try: @@ -30,9 +44,6 @@ HAS_TRACEMALLOC = False -# ==================================== -# Shared Test Helpers -# ==================================== class MockGlobalEndpointManager: """Mock global endpoint manager for testing.""" @@ -106,11 +117,6 @@ def raise_410_partition_split_error(*args, **kwargs): raise create_410_partition_split_error() -# =============================== -# Test Class -# =============================== - - @pytest.mark.cosmosEmulator class TestPartitionSplitRetryUnitAsync(unittest.IsolatedAsyncioTestCase): @@ -624,11 +630,8 @@ async def test_memory_bounded_no_leak_on_410_retries_async(self, mock_execute): - Memory growth is minimal (no recursive accumulation) - No infinite recursion (max depth = 0 for PK range queries) """ - # tracemalloc.start() begins tracing memory allocations to detect leaks tracemalloc.start() - # gc.collect() forces garbage collection to get accurate baseline memory measurement gc.collect() - # take_snapshot() captures current memory state for comparison after test snapshot_before = tracemalloc.take_snapshot() start_time = time.time() @@ -646,31 +649,17 @@ async def mock_fetch_function(options): await context._fetch_items_helper_with_retries(mock_fetch_function) elapsed_time = time.time() - start_time - # gc.collect() before snapshot ensures we measure actual leaks, not pending garbage gc.collect() snapshot_after = tracemalloc.take_snapshot() - # compare_to() shows memory difference between snapshots to identify growth top_stats = snapshot_after.compare_to(snapshot_before, 'lineno') memory_growth = sum(stat.size_diff for stat in top_stats if stat.size_diff > 0) peak_memory = tracemalloc.get_traced_memory()[1] - # tracemalloc.stop() ends memory tracing and frees tracing overhead tracemalloc.stop() # Collect metrics execute_calls = mock_execute.call_count refresh_calls = mock_client.refresh_routing_map_provider_call_count - # Print metrics - print(f"\n{'=' * 60}") - print("MEMORY METRICS (Async) - Partition Split Memory Verification") - print(f"{'=' * 60}") - print(f"Metrics:") - print(f" - Execute calls: {execute_calls} (bounded)") - print(f" - Refresh calls: {refresh_calls}") - print(f" - Elapsed time: {elapsed_time:.2f}s") - print(f" - Memory growth: {memory_growth / 1024:.2f} KB") - print(f" - Peak memory: {peak_memory / 1024:.2f} KB") - print(f"{'=' * 60}") assert execute_calls == 4, \ f"Execute calls should be bounded to 4, got {execute_calls}" @@ -695,10 +684,6 @@ async def mock_fetch_function(options): pk_execute_calls = mock_execute.call_count pk_refresh_calls = mock_client.refresh_routing_map_provider_call_count - print(f"\nPK Range Query:") - print(f" - Execute calls: {pk_execute_calls} (no retry)") - print(f" - Refresh calls: {pk_refresh_calls} (no recursion)") - print(f"{'=' * 60}\n") assert pk_execute_calls == 1, \ f"PK range query should have 1 execute call, got {pk_execute_calls}" @@ -1062,11 +1047,8 @@ async def test_queryfeed_populates_capture_dict_from_options_async(self): test-side injection. Catches the `options`-vs-`kwargs` extraction regression on the async path. """ - from unittest.mock import patch as _patch - - # Build a CosmosClientConnection without running __init__; we - # only need the attributes that the no-query (read-feed) branch - # of async __QueryFeed touches. + # Build the connection without running __init__; only the attributes + # used by the no-query (read-feed) branch of async __QueryFeed are needed. conn = object.__new__(CosmosClientConnection) conn.default_headers = {} conn.last_response_headers = {} @@ -1092,21 +1074,20 @@ async def test_queryfeed_populates_capture_dict_from_options_async(self): operation_type="ReadFeed", ) - # Patch the heavy collaborators inside async __QueryFeed's - # no-query branch so we can drive it without a real pipeline. - with _patch( + # Patch the heavy collaborators so the async no-query branch can run without a real pipeline. + with patch( "azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", return_value={}, ), \ - _patch( + patch( "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", new=AsyncMock(), ), \ - _patch( + patch( "azure.cosmos.aio._cosmos_client_connection_async._request_object.RequestObject", return_value=request_obj_mock, ), \ - _patch.object( + patch.object( CosmosClientConnection, "_CosmosClientConnection__Get", new=AsyncMock(return_value=( @@ -1405,3 +1386,214 @@ async def _noop_set_session(*args, **kwargs): assert decoded is not None assert decoded[_FIELD_VERSION] == _TOKEN_VERSION assert len(decoded["c"]) == 2 + + async def test_queryfeed_feed_range_routing_lookup_failure_stamps_checkpoint_async(self): + """If the routing lookup fails mid-page on the async path, a resumable + continuation must be saved on the client before the error is re-raised + so the caller can retry without losing progress. + """ + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + single_overlap = [{"id": "0", "minInclusive": "00", "maxExclusive": "FF"}] + routing_call_count = {"n": 0} + + async def overlap_side_effect(_rid, _ranges, _opts): + routing_call_count["n"] += 1 + # First call succeeds; the next routing lookup raises so the test + # exercises the failure path inside the mid-page loop. + if routing_call_count["n"] >= 2: + raise RuntimeError("routing-map-down") + return single_overlap + + client._routing_map_provider.get_overlapping_ranges = AsyncMock(side_effect=overlap_side_effect) + + async def _noop_set_session(*args, **kwargs): + return None + + with patch("azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", return_value={}): + with patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", + side_effect=_noop_set_session, + ): + with patch.object(client, "_CosmosClientConnection__Post", new=AsyncMock()) as post_mock: + with pytest.raises(RuntimeError, match="routing-map-down"): + await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query="SELECT * FROM c", + options={"continuation": "legacy-inbound-token"}, + feed_range={ + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + }, + ) + # The routing failure must surface before any backend call. + assert post_mock.await_count == 0, ( + "__Post must not be awaited when the routing lookup " + f"raises (got {post_mock.await_count} awaits)" + ) + + # The original inbound continuation is saved unchanged so the caller + # can retry from the same position. + continuation = client.last_response_headers.get(HttpHeaders.Continuation) + assert continuation == "legacy-inbound-token", ( + "Expected the original inbound continuation to be saved on " + f"last_response_headers after the routing failure; got {continuation!r}" + ) + + async def test_queryfeed_v1_inbound_token_explodes_head_on_resume_async(self): + """Resume with a structured continuation token whose head range now spans + two partitions. The loop must send one request per child across two + pages and forward the parent's backend continuation to each child so + results are not replayed from the start. + """ + client = self._create_minimal_connection() + client._query_compatibility_mode = client._QueryCompatibilityMode.Default + client._routing_map_provider = MagicMock() + + child_left = {"id": "1", "minInclusive": "00", "maxExclusive": "7F"} + child_right = {"id": "2", "minInclusive": "7F", "maxExclusive": "FF"} + + async def overlap_side_effect(_rid, ranges, _opts): + requested = ranges[0] + if requested.min == "00" and requested.max == "FF": + return [child_left, child_right] + if requested.min == "00" and requested.max == "7F": + return [child_left] + if requested.min == "7F" and requested.max == "FF": + return [child_right] + return [] + + client._routing_map_provider.get_overlapping_ranges = AsyncMock(side_effect=overlap_side_effect) + + feed_range_dict = { + "Range": { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + } + } + full_range = routing_range.Range( + range_min="00", + range_max="FF", + isMinInclusive=True, + isMaxInclusive=False, + ) + query_text = "SELECT * FROM c" + # Build a structured continuation token whose single entry covers the + # full range and carries a backend continuation that must be forwarded + # to both children once the head is split. + inbound_token_payload = { + "v": 1, + _FIELD_COLLECTION_RID: "rid-c1", + _FIELD_QUERY_HASH: _hash_query_spec(query_text), + _FIELD_FEEDRANGE_HASH: _hash_feed_range(full_range), + _FIELD_CONTINUATIONS: [ + { + "min": "00", + "max": "FF", + "isMinInclusive": True, + "isMaxInclusive": False, + _FIELD_BACKEND_CONTINUATION: "parent-bc", + } + ], + } + inbound_token = _encode_token(inbound_token_payload) + + seen_request_continuations: List[str] = [] + seen_pkr_ids: List[str] = [] + post_call_count = {"n": 0} + + async def post_side_effect(_path, _request_params, _query, req_headers, **_kwargs): + post_call_count["n"] += 1 + seen_request_continuations.append(req_headers.get(HttpHeaders.Continuation)) + seen_pkr_ids.append(req_headers.get(HttpHeaders.PartitionKeyRangeID)) + return ( + {"Documents": [{"id": f"doc-{post_call_count['n']}"}]}, + {HttpHeaders.Continuation: None}, + ) + + async def _noop_set_session(*args, **kwargs): + return None + + with patch("azure.cosmos.aio._cosmos_client_connection_async.base.GetHeaders", return_value={}): + with patch( + "azure.cosmos.aio._cosmos_client_connection_async.base.set_session_token_header_async", + side_effect=_noop_set_session, + ): + with patch.object(client, "_CosmosClientConnection__Post", side_effect=post_side_effect): + # First page: should send a request to the left child with + # the original backend continuation. + docs_p1, headers_p1 = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=query_text, + options={"continuation": inbound_token}, + feed_range=feed_range_dict, + ) + + outbound_p1 = headers_p1.get(HttpHeaders.Continuation) + decoded_p1 = _decode_token(outbound_p1) + assert decoded_p1 is not None, ( + "Page 1 outbound continuation must be a structured " + f"token after the head is split; got {outbound_p1!r}" + ) + assert decoded_p1[_FIELD_VERSION] == _TOKEN_VERSION + assert len(decoded_p1[_FIELD_CONTINUATIONS]) == 1, ( + "After draining the left child, only the right child " + f"should remain in the outbound token; got {decoded_p1[_FIELD_CONTINUATIONS]!r}" + ) + surviving = decoded_p1[_FIELD_CONTINUATIONS][0] + assert surviving["min"] == "7F" and surviving["max"] == "FF", ( + "The remaining entry must be the right child's range; " + f"got {surviving!r}" + ) + assert surviving.get(_FIELD_BACKEND_CONTINUATION) == "parent-bc", ( + "The remaining child must carry the parent's backend " + f"continuation forward; got bc={surviving.get(_FIELD_BACKEND_CONTINUATION)!r}" + ) + + # Second page: should send a request to the right child + # with the same backend continuation. + docs_p2, headers_p2 = await client.QueryFeed( + path="/dbs/db/colls/c1/docs", + collection_id="rid-c1", + query=query_text, + options={"continuation": outbound_p1}, + feed_range=feed_range_dict, + ) + + assert post_call_count["n"] == 2, ( + "Expected exactly one request per child across the two pages; " + f"got {post_call_count['n']} request(s)" + ) + assert seen_request_continuations == ["parent-bc", "parent-bc"], ( + "Each child must receive the parent's backend continuation; " + f"got {seen_request_continuations!r}" + ) + assert seen_pkr_ids == [child_left["id"], child_right["id"]], ( + "Expected one POST per child PKR id in left-to-right order; " + f"got PKR ids {seen_pkr_ids!r}" + ) + assert docs_p1 == [{"id": "doc-1"}], ( + f"Page 1 must return only child_left's doc; got {docs_p1!r}" + ) + assert docs_p2 == [{"id": "doc-2"}], ( + f"Page 2 must return only child_right's doc; got {docs_p2!r}" + ) + outbound_p2 = headers_p2.get(HttpHeaders.Continuation) + assert outbound_p2 in (None, "", b""), ( + "Expected empty outbound continuation after fully draining " + f"both children; got {outbound_p2!r}" + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 6cc99aa6e074..32a2b080c7e5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -4,14 +4,16 @@ import os import unittest import uuid +from unittest.mock import patch import pytest +from azure.core.exceptions import ServiceResponseError import azure.cosmos._retry_utility as retry_utility import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos.exceptions as exceptions import test_config -from azure.cosmos import http_constants, DatabaseProxy, _endpoint_discovery_retry_policy +from azure.cosmos import http_constants, DatabaseProxy, _endpoint_discovery_retry_policy, _synchronized_request from azure.cosmos._routing.feed_range_continuation import _decode_token from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo @@ -1024,6 +1026,539 @@ def test_query_items_parameters_none_with_options(self): self._delete_container_for_test(created_collection.id) + # Tests below verify that the per-request and client-level + # read_timeout reach every page fetch when results are walked + # one page at a time via by_page(). + + def _capture_pipeline_read_timeouts(self): + # Wraps the outgoing HTTP call so each request records its URL + # and the read_timeout the SDK passed along with it. + captured = [] + original = _synchronized_request._PipelineRunFunction + + def _wrapper(pipeline_client, request, **kwargs): + captured.append((str(request.url), kwargs.get("read_timeout"))) + return original(pipeline_client, request, **kwargs) + + return captured, patch.object( + _synchronized_request, "_PipelineRunFunction", side_effect=_wrapper + ) + + @staticmethod + def _doc_call_timeouts(captured): + # Keep only document page fetches. Other internal calls have + # their own timeout settings and are not part of these tests. + return [rt for (url, rt) in captured if "/docs" in url] + + def test_read_timeout_propagates_through_by_page_paging(self): + # Per-request read_timeout should reach every page fetch across + # single-partition queries, cross-partition queries, full reads, + # and change feeds when results are walked via by_page(). + container = self._create_container_for_test( + "by_page_read_timeout_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + # Seed across two partition keys so the cross-partition + # path actually fans out. + for i in range(6): + container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "data": i}) + + request_level_timeout = 25 + + # Single-partition query paged with by_page(). + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": 0}], + partition_key=0, + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0, "expected at least one page fetch") + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"single-partition by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + + # Cross-partition query paged with by_page(). + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = container.query_items( + query="SELECT * FROM c", + enable_cross_partition_query=True, + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"cross-partition by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + + # read_all_items paged with by_page(). + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = container.read_all_items( + max_item_count=2, + read_timeout=request_level_timeout, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"read_all_items by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + + # Change feed paged with by_page(). + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = container.query_items_change_feed( + start_time="Beginning", + max_item_count=2, + read_timeout=request_level_timeout, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"change feed by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container.id) + + def test_client_level_read_timeout_propagates_through_by_page_paging(self): + # When the client is built with a read_timeout and the caller + # does not pass a per-request one, that client value should + # still reach every page fetch. + container = self._create_container_for_test( + "by_page_client_read_timeout_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + ) + try: + for i in range(4): + container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": "p", "data": i}) + + client_timeout = 22 + + with cosmos_client.CosmosClient(self.host, self.credential, read_timeout=client_timeout) as ct_client: + ct_container = ct_client.get_database_client(self.TEST_DATABASE_ID) \ + .get_container_client(container.id) + + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = ct_container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "p"}], + partition_key="p", + max_item_count=1, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == client_timeout for rt in doc_timeouts), + f"client-level read_timeout did not reach all pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container.id) + + def test_client_level_short_read_timeout_fails_on_by_page(self): + # A tiny client-level read_timeout should cause every page + # fetch to fail, confirming that value actually reaches the + # network call. + container = self._create_container_for_test( + "by_page_short_client_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + ) + try: + for i in range(3): + container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": "p", "data": i}) + + with cosmos_client.CosmosClient( + self.host, self.credential, read_timeout=0.000000000001 + ) as short_client: + short_container = short_client.get_database_client(self.TEST_DATABASE_ID) \ + .get_container_client(container.id) + with self.assertRaises((exceptions.CosmosClientTimeoutError, ServiceResponseError)): + pages = short_container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "p"}], + partition_key="p", + max_item_count=1, + ).by_page() + for page in pages: + list(page) + finally: + self._delete_container_for_test(container.id) + + # Aggregate queries (COUNT, SUM, MAX) take a separate execution path + # from regular queries. The tests below confirm a per-request + # read_timeout still reaches every page fetch on that path. GROUP BY + # is excluded because the Python SDK does not declare it as + # supported and the gateway rejects those queries. ``c.amount`` is + # used instead of ``c.value`` because ``VALUE`` is a SQL keyword. + + def test_read_timeout_propagates_through_by_page_value_count_aggregate(self): + # Per-request read_timeout must reach every /docs/ page fetch + # when paging a ``SELECT VALUE COUNT(1) FROM c`` aggregate. + container = self._create_container_for_test( + "by_page_count_agg_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + for i in range(6): + container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "amount": i + 10}) + + request_level_timeout = 25 + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + enable_cross_partition_query=True, + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0, "expected at least one page fetch") + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"VALUE COUNT aggregate by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container.id) + + def test_read_timeout_propagates_through_by_page_value_sum_aggregate(self): + # Per-request read_timeout must reach every /docs/ page fetch + # when paging a ``SELECT VALUE SUM(...) FROM c`` aggregate. + container = self._create_container_for_test( + "by_page_sum_agg_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + for i in range(6): + container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "amount": i + 10}) + + request_level_timeout = 25 + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = container.query_items( + query="SELECT VALUE SUM(c.amount) FROM c WHERE IS_NUMBER(c.amount)", + enable_cross_partition_query=True, + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"VALUE SUM aggregate by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container.id) + + def test_read_timeout_propagates_through_by_page_value_max_aggregate(self): + # Per-request read_timeout must reach every /docs/ page fetch + # when paging a ``SELECT VALUE MAX(...) FROM c`` aggregate. MAX + # exercises the same MultiExecutionAggregator path that GROUP BY + # would, but is actually supported by the Python SDK. + container = self._create_container_for_test( + "by_page_max_agg_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + for i in range(6): + container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "amount": i + 10}) + + request_level_timeout = 25 + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = container.query_items( + query="SELECT VALUE MAX(c.amount) FROM c", + enable_cross_partition_query=True, + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"VALUE MAX aggregate by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container.id) + + def test_client_level_read_timeout_propagates_through_aggregate_by_page(self): + # Mirror of test_client_level_read_timeout_propagates_through_by_page_paging + # but for the aggregate path. The client-level read_timeout must + # reach every /docs/ page fetch across the three aggregate shapes + # when no per-request override is set. + container = self._create_container_for_test( + "by_page_client_agg_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + for i in range(6): + container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "amount": i + 10}) + + client_timeout = 22 + + with cosmos_client.CosmosClient( + self.host, self.credential, read_timeout=client_timeout + ) as ct_client: + ct_container = ct_client.get_database_client(self.TEST_DATABASE_ID) \ + .get_container_client(container.id) + + for query in ( + "SELECT VALUE COUNT(1) FROM c", + "SELECT VALUE SUM(c.amount) FROM c WHERE IS_NUMBER(c.amount)", + "SELECT VALUE MAX(c.amount) FROM c", + ): + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + pages = ct_container.query_items( + query=query, + enable_cross_partition_query=True, + max_item_count=1, + ).by_page() + for page in pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater( + len(doc_timeouts), 0, + f"no /docs/ fetches for aggregate query {query!r}", + ) + # Tier-1 page fetches must carry the client value. + # The query-plan POST against /docs/ is a tier-2-ish + # call that the dispatch path strips kwargs from in + # some flows; tolerate a ``None`` there but require + # at least one /docs/ call to carry the value, and + # any non-None /docs/ call must match it exactly. + non_none = [rt for rt in doc_timeouts if rt is not None] + self.assertGreater( + len(non_none), 0, + f"every /docs/ call dropped read_timeout for aggregate {query!r}", + ) + self.assertTrue( + all(rt == client_timeout for rt in non_none), + f"client-level read_timeout did not reach all pages for " + f"aggregate {query!r}: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container.id) + + # by_page() coverage for the outer timeout, connection_timeout, + # resume from a continuation token, and the cross-partition AVG + # error. + + def _capture_pipeline_kwarg(self, kwarg_name): + # Wraps the outgoing HTTP call so each request records its URL + # and the value of the named kwarg the SDK passed along. + captured = [] + original = _synchronized_request._PipelineRunFunction + + def _wrapper(pipeline_client, request, **kwargs): + captured.append((str(request.url), kwargs.get(kwarg_name))) + return original(pipeline_client, request, **kwargs) + + return captured, patch.object( + _synchronized_request, "_PipelineRunFunction", side_effect=_wrapper, + ) + + def test_outer_timeout_propagates_through_by_page_paging(self): + # The outer wall-clock timeout must reach every page fetch when + # a cross-partition query is walked one page at a time. The + # remaining budget may decrement across attempts, but it must + # never be dropped and must never exceed the requested value. + container = self._create_container_for_test( + "by_page_outer_timeout_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + for i in range(6): + container.create_item( + {"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "data": i} + ) + + outer_timeout = 15 + captured, ctx = self._capture_pipeline_kwarg("timeout") + with ctx: + pages = container.query_items( + query="SELECT * FROM c", + enable_cross_partition_query=True, + max_item_count=1, + timeout=outer_timeout, + ).by_page() + for page in pages: + list(page) + doc_values = [t for (url, t) in captured if "/docs" in url] + self.assertGreater(len(doc_values), 0) + non_none = [t for t in doc_values if t is not None] + self.assertGreater(len(non_none), 0) + for t in non_none: + self.assertLessEqual(t, outer_timeout) + finally: + self._delete_container_for_test(container.id) + + def test_connection_timeout_propagates_through_by_page_paging(self): + # connection_timeout is per-attempt, so it must reach every + # page fetch as the same exact value the caller passed in. + container = self._create_container_for_test( + "by_page_conn_timeout_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + for i in range(6): + container.create_item( + {"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "data": i} + ) + + connection_timeout = 27 + captured, ctx = self._capture_pipeline_kwarg("connection_timeout") + with ctx: + pages = container.query_items( + query="SELECT * FROM c", + enable_cross_partition_query=True, + max_item_count=1, + connection_timeout=connection_timeout, + ).by_page() + for page in pages: + list(page) + doc_values = [t for (url, t) in captured if "/docs" in url] + self.assertGreater(len(doc_values), 0) + self.assertTrue(all(t == connection_timeout for t in doc_values)) + finally: + self._delete_container_for_test(container.id) + + def test_avg_aggregate_by_page_raises_value_error_cross_partition(self): + # A VALUE AVG query across multiple partitions cannot be merged + # to a correct result on the client, so the SDK must raise + # ValueError. A single-partition AVG is the supported workaround + # and is verified as a control at the end. + container = self._create_container_for_test( + "by_page_avg_cross_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + # Seed across many distinct partition key values so the + # query has data on multiple physical partitions to merge. + for i in range(20): + container.create_item( + {"id": f"item_{i}_{uuid.uuid4()}", "pk": f"pk_{i}", "amount": i + 10} + ) + + # A feed range that covers the full hash space forces the + # query to fan out across every physical partition. + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + with self.assertRaises(ValueError) as cm: + pages = container.query_items( + query="SELECT VALUE AVG(c.amount) FROM c", + feed_range=feed_range, + max_item_count=1, + ).by_page() + for page in pages: + list(page) + self.assertIn("AVG", str(cm.exception).upper()) + + # Control: a single-partition AVG must still succeed. + single_pages = container.query_items( + query="SELECT VALUE AVG(c.amount) FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "pk_0"}], + partition_key="pk_0", + max_item_count=1, + ).by_page() + single_result = [] + for page in single_pages: + single_result.extend(list(page)) + self.assertEqual(len(single_result), 1) + finally: + self._delete_container_for_test(container.id) + + def test_read_timeout_propagates_through_by_page_resume(self): + # When the caller resumes paging on a new pager built from a + # continuation token, the per-request read_timeout must reach + # every page fetch on the resumed pager too. + container = self._create_container_for_test( + "by_page_resume_timeout_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + try: + for i in range(10): + container.create_item( + {"id": f"item_{i}_{uuid.uuid4()}", "pk": "p", "data": i} + ) + + request_level_timeout = 19 + + # Pull one page and capture its continuation token. This + # initial call runs outside the capture context because we + # only want to assert on the resumed pager. + first_pages = container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "p"}], + partition_key="p", + max_item_count=2, + ).by_page() + try: + first_page = next(first_pages) + except StopIteration: + self.fail("expected at least one page on the initial pull") + list(first_page) + continuation = first_pages.continuation_token + self.assertIsNotNone(continuation) + + captured, ctx = self._capture_pipeline_read_timeouts() + with ctx: + resumed_pages = container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "p"}], + partition_key="p", + max_item_count=2, + read_timeout=request_level_timeout, + ).by_page(continuation_token=continuation) + for page in resumed_pages: + list(page) + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue(all(rt == request_level_timeout for rt in doc_timeouts)) + finally: + self._delete_container_for_test(container.id) + + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_aggregate_utils_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_aggregate_utils_unit_async.py new file mode 100644 index 000000000000..ea43e3341e52 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_query_aggregate_utils_unit_async.py @@ -0,0 +1,133 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +"""Async tests for SELECT VALUE aggregate classification and result merging. + +Runs the same checks the sync tests cover from inside an async test +class so the async client shows the same behavior when partial query +results from different partitions are classified and merged. +""" + +import unittest + +import pytest + +from azure.cosmos import _base +from azure.cosmos._query_aggregate_utils import ( + _AggregatePartialClassification, + _classify_aggregate_partial, + _get_select_value_aggregate_function, +) +from azure.cosmos._routing import feed_range_continuation as _frc +from azure.cosmos._routing.feed_range_continuation import ( + _count_page_items_from_partial_result, +) +from azure.cosmos.aio import _cosmos_client_connection_async as _async_conn + +pytestmark = pytest.mark.cosmosEmulator + + +class TestQueryAggregateUtilsAsync(unittest.IsolatedAsyncioTestCase): + """Async-side checks for the shared classifier and merge helpers. + + The helpers themselves are synchronous. Running them from an async + test class confirms the async client uses the same helpers and + behaves the same way as the sync client. + """ + + # Boolean partial rows should not be treated as numeric aggregates. + async def test_classify_aggregate_partial_excludes_boolean_value_rows_async(self): + query = "SELECT VALUE COUNT(1) FROM c" + assert _classify_aggregate_partial([True], query) == _AggregatePartialClassification.NONE + assert _classify_aggregate_partial([False], query) == _AggregatePartialClassification.NONE + + # Boolean rows from each partition should be joined into a list, + # not added together. + async def test_value_count_boolean_fragments_concatenate_via_base_merge_async(self): + query = "SELECT VALUE COUNT(1) > 0 FROM c" + assert _count_page_items_from_partial_result({"Documents": [True]}, query) == 1 + merged = _base._merge_query_results( + {"Documents": [True]}, {"Documents": [True]}, query, + ) + assert merged["Documents"] == [True, True] + + # A plain numeric projection without an aggregate function should + # return one row per document, not a summed total. + async def test_value_numeric_non_aggregate_concat_async(self): + query = "SELECT VALUE c.score FROM c" + assert _get_select_value_aggregate_function(query) is None + assert _classify_aggregate_partial([7], query) == _AggregatePartialClassification.NONE + merged = _base._merge_query_results( + {"Documents": [7]}, {"Documents": [3]}, query, + ) + assert merged["Documents"] == [7, 3] + + async def test_value_float_non_aggregate_concat_async(self): + query = "SELECT VALUE c.ratio FROM c" + merged = _base._merge_query_results( + {"Documents": [1.5]}, {"Documents": [2.25]}, query, + ) + assert merged["Documents"] == [1.5, 2.25] + + async def test_value_numeric_non_aggregate_three_way_concat_async(self): + query = "SELECT VALUE c.score FROM c" + merged = _base._merge_query_results( + {"Documents": [7]}, {"Documents": [3]}, query, + ) + merged = _base._merge_query_results(merged, {"Documents": [11]}, query) + assert merged["Documents"] == [7, 3, 11] + + # MIN keeps the smaller value and MAX keeps the larger value when + # combining one row from each partition. + async def test_value_min_max_merge_async(self): + min_query = "SELECT VALUE MIN(c.score) FROM c" + assert _get_select_value_aggregate_function(min_query) == "MIN" + merged_min = _base._merge_query_results( + {"Documents": [7]}, {"Documents": [3]}, min_query, + ) + assert merged_min["Documents"] == [3] + + max_query = "SELECT VALUE MAX(c.score) FROM c" + assert _get_select_value_aggregate_function(max_query) == "MAX" + merged_max = _base._merge_query_results( + {"Documents": [7]}, {"Documents": [3]}, max_query, + ) + assert merged_max["Documents"] == [7] + + # MIN and MAX should be detected regardless of keyword casing. + async def test_value_min_max_lowercase_keyword_detected_async(self): + assert _get_select_value_aggregate_function( + "select value min(c.score) from c" + ) == "MIN" + assert _get_select_value_aggregate_function( + "Select Value Max(c.score) From c" + ) == "MAX" + + # MIN and MAX should still pick the right value when one partition + # returns an int and another a float, or when values are negative. + async def test_value_min_max_merge_mixed_numeric_types_async(self): + min_query = "SELECT VALUE MIN(c.score) FROM c" + merged_min = _base._merge_query_results( + {"Documents": [7]}, {"Documents": [3.5]}, min_query, + ) + assert merged_min["Documents"] == [3.5] + + max_query = "SELECT VALUE MAX(c.score) FROM c" + merged_max = _base._merge_query_results( + {"Documents": [-1]}, {"Documents": [-5]}, max_query, + ) + assert merged_max["Documents"] == [-1] + + # The async client should call the same merge and counting helpers + # as the sync client, so identical behavior is guaranteed. + async def test_async_module_reuses_shared_classifier_and_merge_async(self): + assert _async_conn.base._merge_query_results is _base._merge_query_results + assert _async_conn._count_page_items_from_partial_result is ( + _frc._count_page_items_from_partial_result + ) + assert _async_conn._count_page_items_from_partial_result is ( + _count_page_items_from_partial_result + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_async.py index e63a222b3a0e..4c7aec788f60 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_async.py @@ -5,10 +5,13 @@ import unittest import uuid from asyncio import gather +from unittest.mock import patch import pytest +from azure.core.exceptions import ServiceResponseError import azure.cosmos.aio._retry_utility_async as retry_utility +import azure.cosmos.aio._asynchronous_request as _asynchronous_request import azure.cosmos.exceptions as exceptions import azure.cosmos.cosmos_client as sync_cosmos_client import test_config @@ -1007,6 +1010,489 @@ async def test_query_items_parameters_none_with_options_async(self): self._delete_container_for_test(container_id) + # Async variants of the by_page() read_timeout coverage. Verify + # that the per-request and client-level read_timeout reach every + # page fetch when results are walked one page at a time. + + def _capture_pipeline_read_timeouts_async(self): + # Wraps the outgoing async HTTP call so each request records + # its URL and the read_timeout the SDK passed along with it. + captured = [] + original = _asynchronous_request._PipelineRunFunction + + async def _wrapper(pipeline_client, request, **kwargs): + captured.append((str(request.url), kwargs.get("read_timeout"))) + return await original(pipeline_client, request, **kwargs) + + return captured, patch.object( + _asynchronous_request, "_PipelineRunFunction", side_effect=_wrapper + ) + + @staticmethod + def _doc_call_timeouts(captured): + # Keep only document page fetches. Other internal calls have + # their own timeout settings and are not part of these tests. + return [rt for (url, rt) in captured if "/docs" in url] + + async def test_read_timeout_propagates_through_by_page_paging_async(self): + # Per-request read_timeout should reach every page fetch across + # single-partition queries, cross-partition queries, full reads, + # and change feeds when results are walked via by_page(). + container_id = "by_page_read_timeout_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(6): + await container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "data": i}) + + request_level_timeout = 25 + + # Single-partition query paged with by_page(). + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": 0}], + partition_key=0, + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0, "expected at least one page fetch") + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"single-partition by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + + # Cross-partition query paged with by_page(). + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = container.query_items( + query="SELECT * FROM c", + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"cross-partition by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + + # read_all_items paged with by_page(). + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = container.read_all_items( + max_item_count=2, + read_timeout=request_level_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"read_all_items by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + + # Change feed paged with by_page(). + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = container.query_items_change_feed( + start_time="Beginning", + max_item_count=2, + read_timeout=request_level_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"change feed by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container_id) + + async def test_client_level_read_timeout_propagates_through_by_page_paging_async(self): + # When the async client is built with a read_timeout and the + # caller does not pass a per-request one, that client value + # should still reach every page fetch. + container_id = "by_page_client_read_timeout_async_" + str(uuid.uuid4()) + container = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + for i in range(4): + await container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": "p", "data": i}) + + client_timeout = 22 + + async with CosmosClient(self.host, self.masterKey, read_timeout=client_timeout) as ct_client: + ct_container = ct_client.get_database_client(self.TEST_DATABASE_ID) \ + .get_container_client(container_id) + + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = ct_container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "p"}], + partition_key="p", + max_item_count=1, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == client_timeout for rt in doc_timeouts), + f"client-level read_timeout did not reach all pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container_id) + + async def test_client_level_short_read_timeout_fails_on_by_page_async(self): + # A tiny client-level read_timeout should cause every page + # fetch to fail, confirming that value actually reaches the + # async network call. + container_id = "by_page_short_client_async_" + str(uuid.uuid4()) + container = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + for i in range(3): + await container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": "p", "data": i}) + + async with CosmosClient(self.host, self.masterKey, read_timeout=0.000000000001) as short_client: + short_container = short_client.get_database_client(self.TEST_DATABASE_ID) \ + .get_container_client(container_id) + with self.assertRaises((exceptions.CosmosClientTimeoutError, ServiceResponseError)): + pages = short_container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "p"}], + partition_key="p", + max_item_count=1, + ).by_page() + async for page in pages: + [item async for item in page] + finally: + self._delete_container_for_test(container_id) + + # Aggregate queries (COUNT, SUM, MAX) take a separate execution path + # from regular queries. The tests below confirm a per-request + # read_timeout still reaches every page fetch on that path. GROUP BY + # is excluded because the Python SDK does not declare it as + # supported and the gateway rejects those queries. ``c.amount`` is + # used instead of ``c.value`` because ``VALUE`` is a SQL keyword. + + async def test_read_timeout_propagates_through_by_page_value_count_aggregate_async(self): + container_id = "by_page_count_agg_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(6): + await container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "amount": i + 10}) + + request_level_timeout = 25 + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = container.query_items( + query="SELECT VALUE COUNT(1) FROM c", + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0, "expected at least one page fetch") + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"VALUE COUNT aggregate by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container_id) + + async def test_read_timeout_propagates_through_by_page_value_sum_aggregate_async(self): + container_id = "by_page_sum_agg_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(6): + await container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "amount": i + 10}) + + request_level_timeout = 25 + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = container.query_items( + query="SELECT VALUE SUM(c.amount) FROM c WHERE IS_NUMBER(c.amount)", + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"VALUE SUM aggregate by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container_id) + + async def test_read_timeout_propagates_through_by_page_value_max_aggregate_async(self): + # MAX exercises the same MultiExecutionAggregator path GROUP BY + # would, but is actually supported by the Python SDK. + container_id = "by_page_max_agg_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(6): + await container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "amount": i + 10}) + + request_level_timeout = 25 + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = container.query_items( + query="SELECT VALUE MAX(c.amount) FROM c", + max_item_count=1, + read_timeout=request_level_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue( + all(rt == request_level_timeout for rt in doc_timeouts), + f"VALUE MAX aggregate by_page() dropped read_timeout on some pages: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container_id) + + async def test_client_level_read_timeout_propagates_through_aggregate_by_page_async(self): + # Async mirror of test_client_level_read_timeout_propagates_through_aggregate_by_page. + container_id = "by_page_client_agg_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(6): + await container.create_item({"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "amount": i + 10}) + + client_timeout = 22 + + async with CosmosClient(self.host, self.masterKey, read_timeout=client_timeout) as ct_client: + ct_container = ct_client.get_database_client(self.TEST_DATABASE_ID) \ + .get_container_client(container_id) + + for query in ( + "SELECT VALUE COUNT(1) FROM c", + "SELECT VALUE SUM(c.amount) FROM c WHERE IS_NUMBER(c.amount)", + "SELECT VALUE MAX(c.amount) FROM c", + ): + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + pages = ct_container.query_items( + query=query, + max_item_count=1, + ).by_page() + async for page in pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater( + len(doc_timeouts), 0, + f"no /docs/ fetches for aggregate query {query!r}", + ) + # See sync sibling test for why ``None`` is tolerated. + non_none = [rt for rt in doc_timeouts if rt is not None] + self.assertGreater( + len(non_none), 0, + f"every /docs/ call dropped read_timeout for aggregate {query!r}", + ) + self.assertTrue( + all(rt == client_timeout for rt in non_none), + f"client-level read_timeout did not reach all pages for " + f"aggregate {query!r}: {doc_timeouts}", + ) + finally: + self._delete_container_for_test(container_id) + + # by_page() coverage for the outer timeout, connection_timeout, + # resume from a continuation token, and the cross-partition AVG + # error. See the sync siblings in test_query.py for the rationale. + + def _capture_pipeline_kwarg_async(self, kwarg_name): + captured = [] + original = _asynchronous_request._PipelineRunFunction + + async def _wrapper(pipeline_client, request, **kwargs): + captured.append((str(request.url), kwargs.get(kwarg_name))) + return await original(pipeline_client, request, **kwargs) + + return captured, patch.object( + _asynchronous_request, "_PipelineRunFunction", side_effect=_wrapper, + ) + + async def test_outer_timeout_propagates_through_by_page_paging_async(self): + # The outer wall-clock timeout must reach every page fetch when + # a cross-partition query is walked one page at a time. + container_id = "by_page_outer_timeout_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(6): + await container.create_item( + {"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "data": i} + ) + + outer_timeout = 15 + captured, ctx = self._capture_pipeline_kwarg_async("timeout") + with ctx: + pages = container.query_items( + query="SELECT * FROM c", + max_item_count=1, + timeout=outer_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_values = [t for (url, t) in captured if "/docs" in url] + self.assertGreater(len(doc_values), 0) + non_none = [t for t in doc_values if t is not None] + self.assertGreater(len(non_none), 0) + for t in non_none: + self.assertLessEqual(t, outer_timeout) + finally: + self._delete_container_for_test(container_id) + + async def test_connection_timeout_propagates_through_by_page_paging_async(self): + # connection_timeout is per-attempt; it must reach every page + # fetch as the same exact value the caller passed in. + container_id = "by_page_conn_timeout_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(6): + await container.create_item( + {"id": f"item_{i}_{uuid.uuid4()}", "pk": i % 2, "data": i} + ) + + connection_timeout = 27 + captured, ctx = self._capture_pipeline_kwarg_async("connection_timeout") + with ctx: + pages = container.query_items( + query="SELECT * FROM c", + max_item_count=1, + connection_timeout=connection_timeout, + ).by_page() + async for page in pages: + [item async for item in page] + doc_values = [t for (url, t) in captured if "/docs" in url] + self.assertGreater(len(doc_values), 0) + self.assertTrue(all(t == connection_timeout for t in doc_values)) + finally: + self._delete_container_for_test(container_id) + + async def test_avg_aggregate_by_page_raises_value_error_cross_partition_async(self): + # A VALUE AVG query across multiple partitions must raise + # ValueError. A single-partition AVG is verified as a control. + container_id = "by_page_avg_cross_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(20): + await container.create_item( + {"id": f"item_{i}_{uuid.uuid4()}", "pk": f"pk_{i}", "amount": i + 10} + ) + + # A feed range that covers the full hash space forces the + # query to fan out across every physical partition. + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + with self.assertRaises(ValueError) as cm: + pages = container.query_items( + query="SELECT VALUE AVG(c.amount) FROM c", + feed_range=feed_range, + max_item_count=1, + ).by_page() + async for page in pages: + [item async for item in page] + self.assertIn("AVG", str(cm.exception).upper()) + + # Control: a single-partition AVG must still succeed. + single_pages = container.query_items( + query="SELECT VALUE AVG(c.amount) FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "pk_0"}], + partition_key="pk_0", + max_item_count=1, + ).by_page() + single_result = [] + async for page in single_pages: + async for item in page: + single_result.append(item) + self.assertEqual(len(single_result), 1) + finally: + self._delete_container_for_test(container_id) + + async def test_read_timeout_propagates_through_by_page_resume_async(self): + # When the caller resumes paging on a new pager built from a + # continuation token, the per-request read_timeout must reach + # every page fetch on the resumed pager too. + container_id = "by_page_resume_timeout_async_" + str(uuid.uuid4()) + container = self._create_container_for_test( + container_id, PartitionKey(path="/pk"), offer_throughput=11000, + ) + try: + for i in range(10): + await container.create_item( + {"id": f"item_{i}_{uuid.uuid4()}", "pk": "p", "data": i} + ) + + request_level_timeout = 19 + + # Pull one page and capture its continuation token. This + # initial call runs outside the capture context because we + # only want to assert on the resumed pager. + first_pages = container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "p"}], + partition_key="p", + max_item_count=2, + ).by_page() + first_page = await first_pages.__anext__() + _ = [item async for item in first_page] + continuation = first_pages.continuation_token + self.assertIsNotNone(continuation) + + captured, ctx = self._capture_pipeline_read_timeouts_async() + with ctx: + resumed_pages = container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "p"}], + partition_key="p", + max_item_count=2, + read_timeout=request_level_timeout, + ).by_page(continuation_token=continuation) + async for page in resumed_pages: + [item async for item in page] + doc_timeouts = self._doc_call_timeouts(captured) + self.assertGreater(len(doc_timeouts), 0) + self.assertTrue(all(rt == request_level_timeout for rt in doc_timeouts)) + finally: + self._delete_container_for_test(container_id) + + if __name__ == '__main__': unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range.py index b7f65c7619bc..4ad409e5f2b4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range.py @@ -41,6 +41,8 @@ def _build_lane_suffix(): TEST_OFFER_THROUGHPUTS = [CONFIG.THROUGHPUT_FOR_1_PARTITION, CONFIG.THROUGHPUT_FOR_5_PARTITIONS] PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY PK_VALUES = ('pk1', 'pk2', 'pk3') +RUN_MARKER_FIELD = "_run_marker" +RUN_MARKER_VALUE = str(uuid.uuid4()) def add_all_pk_values_to_set(items: List[Mapping[str, str]], pk_value_set: Set[str]) -> None: if len(items) == 0: return @@ -51,7 +53,15 @@ def add_all_pk_values_to_set(items: List[Mapping[str, str]], pk_value_set: Set[s @pytest.fixture(scope="class", autouse=True) def setup_and_teardown(): print("Setup: This runs before any tests") - document_definitions = [{PARTITION_KEY: pk, 'id': str(uuid.uuid4()), 'value': 100} for pk in PK_VALUES] + document_definitions = [ + { + PARTITION_KEY: pk, + 'id': str(uuid.uuid4()), + 'value': 100, + RUN_MARKER_FIELD: RUN_MARKER_VALUE, + } + for pk in PK_VALUES + ] key_db = CosmosClient(HOST, KEY).get_database_client(DATABASE_ID) data_db = test_config.TestConfig.create_data_client().get_database_client(DATABASE_ID) @@ -164,7 +174,10 @@ def test_query_with_feed_range_for_a_full_range(self, setup, container_id): def test_query_with_avg_aggregate_across_full_feed_range_raises(self, setup): """AVG over a feed_range spanning multiple partitions must raise.""" container = get_container(setup, MULTI_PARTITION_CONTAINER_ID) - query = 'SELECT VALUE AVG(c["value"]) FROM c WHERE IS_DEFINED(c["value"])' + query = ( + f'SELECT VALUE AVG(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) # Full hash range covers every physical partition of the container. full_range = test_config.create_range( @@ -176,7 +189,11 @@ def test_query_with_avg_aggregate_across_full_feed_range_raises(self, setup): feed_range = test_config.create_feed_range_in_dict(full_range) with pytest.raises(ValueError) as excinfo: - list(container.query_items(query=query, feed_range=feed_range)) + list(container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )) message = str(excinfo.value) assert "Unsupported query shape for range-scoped pagination" in message @@ -186,10 +203,17 @@ def test_query_with_avg_aggregate_single_partition_feed_range_succeeds(self, set """AVG scoped to a single-partition feed_range must still succeed.""" # Multi-partition container, but the feed_range maps to one partition. container = get_container(setup, MULTI_PARTITION_CONTAINER_ID) - query = 'SELECT VALUE AVG(c["value"]) FROM c WHERE IS_DEFINED(c["value"])' + query = ( + f'SELECT VALUE AVG(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) feed_range = container.feed_range_from_partition_key(PK_VALUES[0]) - items = list(container.query_items(query=query, feed_range=feed_range)) + items = list(container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )) # Seed data has value=100 for every document. assert items, "Single-partition AVG must return at least one result row" @@ -199,11 +223,183 @@ def test_query_with_avg_aggregate_single_partition_feed_range_succeeds(self, set single_container = get_container(setup, SINGLE_PARTITION_CONTAINER_ID) single_feed_range = single_container.feed_range_from_partition_key(PK_VALUES[0]) single_items = list(single_container.query_items( - query=query, feed_range=single_feed_range, + query=query, + feed_range=single_feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], )) assert single_items, "Single-partition container AVG must return a row" assert single_items[0] == 100 + # The next few tests run against a multi-partition container and check + # that combining partial results returns lists for plain projections and + # single values for MIN, MAX, SUM, and COUNT. + + def test_query_value_numeric_field_across_full_feed_range_returns_list(self, setup): + """Numeric VALUE projections should return one row per document, not a sum.""" + container = get_container(setup, MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE c["value"] FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = list(container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )) + + # All seeded docs have value=100; we expect one entry per doc. + assert len(items) == len(PK_VALUES), ( + f"Expected one value per seeded doc ({len(PK_VALUES)}); got {len(items)}." + ) + assert all(item == 100 for item in items), ( + f"Expected every value to be 100; got {items}" + ) + assert sum(items) == 100 * len(PK_VALUES) + + def test_query_value_boolean_expression_across_full_feed_range_returns_list(self, setup): + """Boolean VALUE projections should return one boolean per document.""" + container = get_container(setup, MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE c["value"] > 0 FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = list(container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )) + + assert len(items) == len(PK_VALUES), ( + f"Expected one boolean per seeded doc ({len(PK_VALUES)}); got {len(items)}." + ) + assert all(item is True for item in items), ( + f"All seeded values are 100 > 0; expected every row to be True. Got {items}" + ) + + def test_query_value_min_across_full_feed_range_returns_scalar(self, setup): + """MIN over a multi-partition feed_range should return one value (the smallest).""" + container = get_container(setup, MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE MIN(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = list(container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )) + + assert len(items) == 1, ( + f"MIN should return a single value across partitions; got {len(items)} rows: {items}" + ) + assert items[0] == 100 + + def test_query_value_max_across_full_feed_range_returns_scalar(self, setup): + """MAX over a multi-partition feed_range should return one value (the largest).""" + container = get_container(setup, MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE MAX(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = list(container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )) + + assert len(items) == 1, ( + f"MAX should return a single value across partitions; got {len(items)} rows: {items}" + ) + assert items[0] == 100 + + def test_query_value_sum_across_full_feed_range_still_sums(self, setup): + """SUM should still add per-partition totals together.""" + container = get_container(setup, MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE SUM(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = list(container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )) + + assert len(items) == 1, ( + f"SUM should return a single value across partitions; got {len(items)} rows: {items}" + ) + assert items[0] == 100 * len(PK_VALUES) + + def test_query_value_count_across_full_feed_range_still_counts(self, setup): + """COUNT should still return a single total across partitions.""" + container = get_container(setup, MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE COUNT(1) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = list(container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )) + + assert len(items) == 1 + assert items[0] == len(PK_VALUES) + @pytest.mark.parametrize('container_id', TEST_CONTAINERS_IDS) @pytest.mark.cosmosSplit def test_query_with_feed_range_during_partition_split_combined(self, setup, container_id): @@ -463,4 +659,3 @@ def test_query_with_continuation(self, setup): if __name__ == "__main__": unittest.main() - diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_async.py index 6e31586100b1..76e2452115e7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_async.py @@ -40,6 +40,8 @@ def _build_lane_suffix(): TEST_OFFER_THROUGHPUTS = [CONFIG.THROUGHPUT_FOR_1_PARTITION, CONFIG.THROUGHPUT_FOR_5_PARTITIONS] PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY PK_VALUES = ('pk1', 'pk2', 'pk3') +RUN_MARKER_FIELD = "_run_marker" +RUN_MARKER_VALUE = str(uuid.uuid4()) async def add_all_pk_values_to_set_async(items: List[Mapping[str, str]], pk_value_set: Set[str]) -> None: if len(items) == 0: @@ -58,7 +60,15 @@ async def setup_and_teardown_async(request): token acquisition + endpoint discovery cost by the number of tests. """ print("Setup: This runs once per test class") - document_definitions = [{PARTITION_KEY: pk, 'id': str(uuid.uuid4()), 'value': 100} for pk in PK_VALUES] + document_definitions = [ + { + PARTITION_KEY: pk, + 'id': str(uuid.uuid4()), + 'value': 100, + RUN_MARKER_FIELD: RUN_MARKER_VALUE, + } + for pk in PK_VALUES + ] # Key-auth client for control-plane (container creation) key_client = CosmosClient(HOST, KEY) @@ -189,7 +199,10 @@ async def test_query_with_feed_range_for_a_full_range_async(self, container_id): async def test_query_with_avg_aggregate_across_full_feed_range_raises_async(self): """AVG over a feed_range spanning multiple partitions must raise.""" container = self.get_container(MULTI_PARTITION_CONTAINER_ID) - query = 'SELECT VALUE AVG(c["value"]) FROM c WHERE IS_DEFINED(c["value"])' + query = ( + f'SELECT VALUE AVG(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) # Full hash range covers every physical partition of the container. full_range = test_config.create_range( @@ -202,7 +215,9 @@ async def test_query_with_avg_aggregate_across_full_feed_range_raises_async(self with pytest.raises(ValueError) as excinfo: _ = [item async for item in container.query_items( - query=query, feed_range=feed_range, + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], )] message = str(excinfo.value) @@ -213,11 +228,16 @@ async def test_query_with_avg_aggregate_single_partition_feed_range_succeeds_asy """AVG scoped to a single-partition feed_range must still succeed.""" # Multi-partition container, but the feed_range maps to one partition. container = self.get_container(MULTI_PARTITION_CONTAINER_ID) - query = 'SELECT VALUE AVG(c["value"]) FROM c WHERE IS_DEFINED(c["value"])' + query = ( + f'SELECT VALUE AVG(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) feed_range = await container.feed_range_from_partition_key(PK_VALUES[0]) items = [item async for item in container.query_items( - query=query, feed_range=feed_range, + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], )] # Seed data has value=100 for every document. @@ -228,11 +248,183 @@ async def test_query_with_avg_aggregate_single_partition_feed_range_succeeds_asy single_container = self.get_container(SINGLE_PARTITION_CONTAINER_ID) single_feed_range = await single_container.feed_range_from_partition_key(PK_VALUES[0]) single_items = [item async for item in single_container.query_items( - query=query, feed_range=single_feed_range, + query=query, + feed_range=single_feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], )] assert single_items, "Single-partition container AVG must return a row" assert single_items[0] == 100 + # The next few tests run against a multi-partition container + # (async variant) and confirm that combining partial results + # returns the right shape: lists for non-aggregate projections, + # single values for MIN, MAX, SUM, and COUNT. + + async def test_query_value_numeric_field_across_full_feed_range_returns_list_async(self): + """Numeric VALUE projections should return one row per document, not a sum.""" + container = self.get_container(MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE c["value"] FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = [item async for item in container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )] + + assert len(items) == len(PK_VALUES), ( + f"Expected one value per seeded doc ({len(PK_VALUES)}); got {len(items)}." + ) + assert all(item == 100 for item in items), ( + f"Expected every value to be 100; got {items}" + ) + assert sum(items) == 100 * len(PK_VALUES) + + async def test_query_value_boolean_expression_across_full_feed_range_returns_list_async(self): + """Boolean VALUE projections should return one boolean per document.""" + container = self.get_container(MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE c["value"] > 0 FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = [item async for item in container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )] + + assert len(items) == len(PK_VALUES), ( + f"Expected one boolean per seeded doc ({len(PK_VALUES)}); got {len(items)}." + ) + assert all(item is True for item in items), ( + f"All seeded values are 100 > 0; expected every row to be True. Got {items}" + ) + + async def test_query_value_min_across_full_feed_range_returns_scalar_async(self): + """MIN over a multi-partition feed_range should return one value (the smallest).""" + container = self.get_container(MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE MIN(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = [item async for item in container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )] + + assert len(items) == 1, ( + f"MIN should return a single value across partitions; got {len(items)} rows: {items}" + ) + assert items[0] == 100 + + async def test_query_value_max_across_full_feed_range_returns_scalar_async(self): + """MAX over a multi-partition feed_range should return one value (the largest).""" + container = self.get_container(MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE MAX(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = [item async for item in container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )] + + assert len(items) == 1, ( + f"MAX should return a single value across partitions; got {len(items)} rows: {items}" + ) + assert items[0] == 100 + + async def test_query_value_sum_across_full_feed_range_still_sums_async(self): + """SUM should still add per-partition totals together.""" + container = self.get_container(MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE SUM(c["value"]) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = [item async for item in container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )] + + assert len(items) == 1, ( + f"SUM should return a single value across partitions; got {len(items)} rows: {items}" + ) + assert items[0] == 100 * len(PK_VALUES) + + async def test_query_value_count_across_full_feed_range_still_counts_async(self): + """COUNT should still return a single total across partitions.""" + container = self.get_container(MULTI_PARTITION_CONTAINER_ID) + query = ( + f'SELECT VALUE COUNT(1) FROM c WHERE IS_DEFINED(c["value"]) ' + f'AND c["{RUN_MARKER_FIELD}"] = @run_marker' + ) + + full_range = test_config.create_range( + range_min="", + range_max="FF", + is_min_inclusive=True, + is_max_inclusive=False, + ) + feed_range = test_config.create_feed_range_in_dict(full_range) + + items = [item async for item in container.query_items( + query=query, + feed_range=feed_range, + parameters=[{"name": "@run_marker", "value": RUN_MARKER_VALUE}], + )] + + assert len(items) == 1 + assert items[0] == len(PK_VALUES) + @pytest.mark.skip(reason="will be moved to a new pipeline") @pytest.mark.parametrize('container_id', TEST_CONTAINERS_IDS) async def test_query_with_feed_range_async_during_back_to_back_partition_splits_async(self, container_id): @@ -536,4 +728,3 @@ async def test_query_with_continuation_async(self): if __name__ == "__main__": unittest.main() - diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py index ba6d5800e21a..de4279fa12eb 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_feed_range_multipartition_async.py @@ -11,15 +11,13 @@ import pytest_asyncio import test_config -from azure.cosmos import _base +from azure.cosmos import _base, documents from azure.cosmos import http_constants from azure.cosmos.aio import CosmosClient from azure.cosmos._routing.feed_range_continuation import _decode_token from azure.cosmos.partition_key import PartitionKey CONFIG = test_config.TestConfig() -HOST = CONFIG.host -KEY = CONFIG.masterKey DATABASE_ID = CONFIG.TEST_DATABASE_ID REPRO_CONTAINER_ID = "FeedRangeMultiPartitionAsync-" + str(uuid.uuid4()) @@ -31,7 +29,9 @@ def _client() -> CosmosClient: - return CosmosClient(HOST, KEY) + """Return a data-plane client. Uses AAD when COSMOS_TEST_DATA_AUTH_MODE=aad, + key auth otherwise, so the AAD lane actually exercises AAD.""" + return test_config.TestConfig.create_data_client_async() def _get_container(client: CosmosClient): @@ -113,15 +113,13 @@ async def setup_and_teardown_async(): @pytest.mark.cosmosQuery +@pytest.mark.cosmosAADQuery @pytest.mark.asyncio @pytest.mark.usefixtures("setup_and_teardown_async") class TestFeedRangeMultiPartitionAsync: """Async end-to-end tests for feed_range queries that overlap multiple physical partitions.""" - # ------------------------------------------------------------------ # - # Single-partition control - # ------------------------------------------------------------------ # async def test_single_partition_feed_range_async(self): """Single-partition regression guard.""" client = _client() @@ -174,9 +172,6 @@ async def test_single_partition_feed_range_async(self): await client.close() - # ------------------------------------------------------------------ # - # Two-partition feed_range - # ------------------------------------------------------------------ # async def test_two_partition_feed_range_async(self): client = _client() try: @@ -511,9 +506,6 @@ async def _failing_post(*args, **kwargs): finally: await client.close() - # ------------------------------------------------------------------ # - # Three-way overlap - # ------------------------------------------------------------------ # async def test_three_way_overlap_async(self): client = _client() try: @@ -566,9 +558,6 @@ async def test_three_way_overlap_async(self): finally: await client.close() - # ------------------------------------------------------------------ # - # Post-split resume (slow) - # ------------------------------------------------------------------ # @pytest.mark.cosmosSplit @pytest.mark.cosmosAADSplit async def test_post_split_resume_async(self): @@ -607,10 +596,19 @@ async def test_post_split_resume_async(self): # Step 2 — trigger a real split. target_throughput = max(REPRO_THROUGHPUT * 2, 60000) + # Split trigger is a control-plane throughput operation; run it via + # key-auth container even in AAD lanes (same pattern as other split tests). + key_client_for_split = CosmosClient(CONFIG.host, CONFIG.masterKey) try: - await test_config.TestConfig.trigger_split_async(container, target_throughput) + key_container_for_split = _get_container(key_client_for_split) + await test_config.TestConfig.trigger_split_async( + key_container_for_split, + target_throughput, + ) except unittest.SkipTest: raise + finally: + await key_client_for_split.close() await asyncio.sleep(10) _ = [fr async for fr in container.read_feed_ranges(force_refresh=True)] @@ -658,9 +656,6 @@ async def test_post_split_resume_async(self): finally: await client.close() - # ------------------------------------------------------------------ # - # Legacy opaque token compatibility - # ------------------------------------------------------------------ # async def test_legacy_opaque_token_compat_async(self, caplog): """Use an opaque continuation token and verify restart behavior.""" client = _client() @@ -725,9 +720,6 @@ async def test_legacy_opaque_token_compat_async(self, caplog): finally: await client.close() - # ------------------------------------------------------------------ # - # Identity-fingerprint mismatch rejection (live half) - # ------------------------------------------------------------------ # async def test_token_identity_mismatch_rejected_async(self): """Live identity-mismatch rejection test.""" client = _client() @@ -806,7 +798,286 @@ async def _drain(p): finally: await client.close() + async def test_full_partition_key_query_pagination_resume_async(self): + """Query with a full partition key on a hierarchical container, drain + page 1, resume from the returned continuation token, and confirm the + resumed pages match the remaining pages of a fresh iterator and + cover the same documents as a baseline scan in the same order. + """ + client = _client() + try: + db = client.get_database_client(DATABASE_ID) + container_id = "FeedRangeMultiPartitionAsyncFullPK-" + str(uuid.uuid4()) + created_container = await db.create_container_if_not_exists( + id=container_id, + partition_key=PartitionKey( + path=['/state', '/city', '/zipcode'], + kind=documents.PartitionKind.MultiHash), + offer_throughput=400, + ) + try: + full_key = ['CA', 'Oxnard', '93033'] + for i in range(25): + await created_container.upsert_item({ + 'id': f'full-pk-doc-async-{i:03d}', + 'state': full_key[0], + 'city': full_key[1], + 'zipcode': full_key[2], + 'value': i, + }) + for i in range(5): + await created_container.upsert_item({ + 'id': f'other-doc-async-{i:03d}', + 'state': 'WA', + 'city': 'Seattle', + 'zipcode': f'98{i:03d}', + 'value': i, + }) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_container.query_items( + query=query, + partition_key=full_key, + max_item_count=7, + ) + + pager = query_iterable.by_page() + first_page_iter = await pager.__anext__() + first_page = [it async for it in first_page_iter] + assert first_page, ( + "first page must contain at least one item to exercise resume") + token = pager.continuation_token + assert token, ( + "expected a non-empty continuation token after page 1 to " + "exercise the resume path") + + expected_remaining_ids: List[str] = [] + async for page in pager: + expected_remaining_ids.extend( + [it['id'] async for it in page]) + + resumed_remaining_ids: List[str] = [] + async for page in query_iterable.by_page(token): + resumed_remaining_ids.extend( + [it['id'] async for it in page]) + + assert expected_remaining_ids == resumed_remaining_ids, ( + "Pages returned after resuming from the continuation token " + "must match the pages returned by draining a fresh iterator") + + baseline_ids: List[str] = [] + async for item in created_container.query_items( + query=query, partition_key=full_key): + baseline_ids.append(item['id']) + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids, ( + "Page 1 plus the resumed pages must equal the baseline " + "documents for this partition key in the same order") + finally: + try: + await db.delete_container(created_container.id) + except Exception: # pylint: disable=broad-except + pass + finally: + await client.close() + + async def test_prefix_partition_key_query_pagination_resume_async(self): + """Query with a partition key prefix on a hierarchical container, + drain page 1, resume from the returned token, and confirm the + resumed pages match a fresh iterator and cover the baseline + documents in the same order. + """ + client = _client() + try: + db = client.get_database_client(DATABASE_ID) + container_id = "FeedRangeMultiPartitionAsyncPrefixPK-" + str(uuid.uuid4()) + created_container = await db.create_container_if_not_exists( + id=container_id, + partition_key=PartitionKey( + path=['/state', '/city', '/zipcode'], + kind=documents.PartitionKind.MultiHash), + offer_throughput=400, + ) + try: + for i in range(30): + await created_container.upsert_item({ + 'id': f'ca-doc-async-{i:03d}', + 'state': 'CA', + 'city': f'city-{i % 5}', + 'zipcode': f'zip-{i:03d}', + 'value': i, + }) + for i in range(6): + await created_container.upsert_item({ + 'id': f'wa-doc-async-{i:03d}', + 'state': 'WA', + 'city': f'city-{i % 2}', + 'zipcode': f'zip-{i:03d}', + 'value': i, + }) + + query = 'SELECT c.id FROM c ORDER BY c.id' + query_iterable = created_container.query_items( + query=query, + partition_key=['CA'], + max_item_count=7, + ) + + pager = query_iterable.by_page() + first_page_iter = await pager.__anext__() + first_page = [it async for it in first_page_iter] + assert first_page + token = pager.continuation_token + assert token, ( + "Expected a non-empty continuation token after page 1 " + "to exercise the resume path") + + expected_remaining_ids: List[str] = [] + async for page in pager: + expected_remaining_ids.extend( + [it['id'] async for it in page]) + + resumed_remaining_ids: List[str] = [] + async for page in query_iterable.by_page(token): + resumed_remaining_ids.extend( + [it['id'] async for it in page]) + + assert expected_remaining_ids == resumed_remaining_ids, ( + "Pages returned after resuming from the continuation token " + "must match the pages returned by draining a fresh iterator") + + baseline_ids: List[str] = [] + async for item in created_container.query_items( + query=query, partition_key=['CA']): + baseline_ids.append(item['id']) + fetched_ids = [item['id'] for item in first_page] + resumed_remaining_ids + assert baseline_ids == fetched_ids, ( + "Page 1 plus the resumed pages must equal the baseline " + "documents for this partition key prefix in the same order") + finally: + try: + await db.delete_container(created_container.id) + except Exception: # pylint: disable=broad-except + pass + finally: + await client.close() + + async def test_explode_iteration_guard_raises_in_query_loop_async(self, monkeypatch): + """If the routing lookup keeps returning multiple children for the same + range, the async query loop must give up after a fixed number of + retries and raise instead of looping forever. + """ + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with >= 2 physical partitions") + + p0, p1 = partitions[0], partitions[1] + crossing = _crossing_feed_range(p0[0], p1[1]) + client_conn = container.client_connection + + # Always return two children whose ranges equal the requested head + # so the loop never makes progress and the guard must trip. + async def _always_multi_overlap(_rid, feed_ranges, _opts): + head = feed_ranges[0] + return [ + {"id": "left", "minInclusive": head.min, "maxExclusive": head.max}, + {"id": "right", "minInclusive": head.min, "maxExclusive": head.max}, + ] + + monkeypatch.setattr( + client_conn._routing_map_provider, + "get_overlapping_ranges", + _always_multi_overlap, + ) + monkeypatch.setattr( + "azure.cosmos._routing.feed_range_continuation." + "_MAX_MULTI_OVERLAP_EXPLODE_ITERATIONS", + 2, + ) + + with pytest.raises(RuntimeError) as excinfo: + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + async for page in pager: + _ = [it async for it in page] + assert "split re-resolution" in str(excinfo.value), ( + "Expected the safety-guard error message; " + f"got: {excinfo.value!r}") + finally: + await client.close() + + async def test_no_progress_guard_logs_warning_in_query_loop_async( + self, monkeypatch, caplog + ): + """If the async query loop keeps receiving empty pages with the same + continuation token, a warning must be logged so the situation is + visible to operators. + """ + client = _client() + try: + container = _get_container(client) + partitions = await _sorted_partition_ranges(container) + if len(partitions) < 2: + pytest.skip("Need a container with >= 2 physical partitions") + + p0, p1 = partitions[0], partitions[1] + crossing = _crossing_feed_range(p0[0], p1[1]) + client_conn = container.client_connection + + post_call_count = 0 + + async def _stalled_post(*_args, **_kwargs): + nonlocal post_call_count + post_call_count += 1 + continuation = "stalled-token-async" if post_call_count <= 3 else None + return ( + {"Documents": []}, + {http_constants.HttpHeaders.Continuation: continuation}, + ) + + monkeypatch.setattr( + client_conn, + "_CosmosClientConnection__Post", + _stalled_post, + ) + monkeypatch.setattr( + "azure.cosmos.aio._cosmos_client_connection_async." + "_MAX_CONSECUTIVE_NO_PROGRESS_PAGES", + 2, + ) + + with caplog.at_level( + "WARNING", + logger="azure.cosmos.aio._cosmos_client_connection_async", + ): + pager = container.query_items( + query="SELECT * FROM c", + feed_range=crossing, + max_item_count=PAGE_SIZE, + ).by_page() + async for page in pager: + _ = [it async for it in page] + + assert post_call_count >= 3, ( + "Expected at least 3 stalled __Post calls to trigger the " + f"no-progress guard; got {post_call_count}") + assert any( + "same continuation token" in record.getMessage() + for record in caplog.records + ), ( + "Expected warning log from the async no-progress guard " + "mentioning the unchanged continuation token") + finally: + await client.close() + if __name__ == "__main__": unittest.main() + diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_response_headers.py b/sdk/cosmos/azure-cosmos/tests/test_query_response_headers.py index 692ec85c8df7..f10cf4a67deb 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_response_headers.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_response_headers.py @@ -1,17 +1,22 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. +import gc import os import threading +import tracemalloc import unittest import uuid +from collections.abc import Mapping from concurrent.futures import ThreadPoolExecutor, as_completed import pytest +from azure.core.utils import CaseInsensitiveDict import azure.cosmos.cosmos_client as cosmos_client import test_config from azure.cosmos import DatabaseProxy +from azure.cosmos._cosmos_responses import CosmosItemPaged from azure.cosmos.partition_key import PartitionKey @@ -235,6 +240,96 @@ def test_query_response_headers_by_page_iteration(self): finally: self._delete_container_for_test(container_id) + def test_query_response_headers_long_pagination_bounded_memory(self): + """Paging through many pages keeps response-header state bounded and + keeps overall iterator memory growth under a linear safety ceiling. + Document payloads are not retained during measurement so growth reflects + headers and iterator overhead only.""" + container_id = "test_headers_long_pagination_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + num_items = 200 + for i in range(num_items): + created_collection.create_item( + body={"pk": "test", "id": f"item_{i:04d}", "value": i} + ) + + query_iterable = created_collection.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "test"}], + partition_key="test", + max_item_count=2, + ) + + # Read the first page outside the measurement window so one-time + # client setup is not counted. Count items only; never retain pages. + page_iter = query_iterable.by_page() + first_page = next(page_iter) + first_page_count = sum(1 for _ in first_page) + baseline_headers = query_iterable.get_response_headers() + self.assertIsNotNone(baseline_headers) + + # Collect transient setup objects so they don't show up as growth. + gc.collect() + tracemalloc.start() + try: + snapshot_before = tracemalloc.take_snapshot() + + page_count = 1 + items_total = first_page_count + header_sizes = [len(baseline_headers)] + all_keys_seen = set(baseline_headers.keys()) + + for page in page_iter: + # Count items without keeping a reference to any of them. + items_total += sum(1 for _ in page) + page_count += 1 + headers = query_iterable.get_response_headers() + self.assertIsNotNone(headers) + header_sizes.append(len(headers)) + all_keys_seen.update(headers.keys()) + # Drop the per-page header copy before the next iteration. + del headers + + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + top_stats = snapshot_after.compare_to(snapshot_before, "lineno") + memory_growth = sum(stat.size_diff for stat in top_stats if stat.size_diff > 0) + finally: + tracemalloc.stop() + + # We really paginated and read every item back. + self.assertGreaterEqual(page_count, 20, f"Expected many pages, got {page_count}.") + self.assertEqual(items_total, num_items) + + # The per-page headers dict stays close to the first page's size. + max_header_size = max(header_sizes) + self.assertLessEqual( + max_header_size, len(baseline_headers) + 8, + f"Headers dict grew across pagination (max={max_header_size}, baseline={len(baseline_headers)}).", + ) + + # The set of header names seen across all pages stays bounded. + self.assertLessEqual( + len(all_keys_seen), len(baseline_headers) + 16, + f"Header name set grew across pagination (seen={len(all_keys_seen)}, baseline={len(baseline_headers)}).", + ) + + # Linear safety ceiling so the check scales if num_items changes. + # This is a "no catastrophic leak" guard, not a strict O(1) proof. + # Observed per-page overhead on a live account is around 24-29 KiB; + # 48 KiB gives roughly 2x headroom and still catches a real leak. + max_per_page_bytes = 48 * 1024 + ceiling_bytes = max_per_page_bytes * page_count + self.assertLess( + memory_growth, ceiling_bytes, + f"Iterator memory grew by {memory_growth} bytes over {page_count} pages; " + f"exceeded linear safety ceiling {ceiling_bytes} bytes.", + ) + + finally: + self._delete_container_for_test(container_id) + def test_query_response_headers_returns_copies(self): """Test that get_response_headers returns copies, not references.""" container_id = "test_headers_copies_" + str(uuid.uuid4()) @@ -267,15 +362,14 @@ def test_query_response_headers_returns_copies(self): self._delete_container_for_test(container_id) def test_query_response_headers_thread_safety(self): - """Test that response headers are captured correctly when multiple queries run concurrently. - - This test verifies that each query operation captures its own headers independently, - without interference from concurrent queries. This is the key thread-safety guarantee. - """ + """Each concurrent query must see only its own response headers. + Each worker installs a response_hook that records every page it + sees; the iterator's final headers must match that worker's own + last hook payload.""" container_id = "test_headers_thread_" + str(uuid.uuid4()) created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) try: - # Create items with different partition keys to ensure different queries + # Different partition keys so different threads run different queries. num_partitions = 5 items_per_partition = 10 for pk_idx in range(num_partitions): @@ -284,7 +378,6 @@ def test_query_response_headers_thread_safety(self): body={"pk": f"partition_{pk_idx}", "id": f"item_{pk_idx}_{item_idx}", "value": item_idx} ) - # Results storage - each thread will store its query results here results = {} errors = [] lock = threading.Lock() @@ -292,16 +385,22 @@ def test_query_response_headers_thread_safety(self): def run_query(partition_key: str, thread_id: int): """Run a query and capture its headers.""" try: + # Per-thread hook: records what this iterator received. + captured_pages = [] + + def hook(headers, _result): + captured_pages.append(dict(headers)) + query = "SELECT * FROM c WHERE c.pk = @pk" query_iterable = created_collection.query_items( query=query, parameters=[{"name": "@pk", "value": partition_key}], partition_key=partition_key, - max_item_count=2, # Small page size to ensure multiple pages - populate_query_metrics=True + max_item_count=2, + populate_query_metrics=True, + response_hook=hook, ) - # Consume all items items = list(query_iterable) headers = query_iterable.get_response_headers() @@ -309,96 +408,128 @@ def run_query(partition_key: str, thread_id: int): results[thread_id] = { "partition_key": partition_key, "item_count": len(items), - "headers": headers + "headers": headers, + "captured_pages": captured_pages, } except Exception as e: with lock: errors.append((thread_id, str(e))) - # Run multiple queries concurrently num_threads = 10 with ThreadPoolExecutor(max_workers=num_threads) as executor: futures = [] for i in range(num_threads): partition_key = f"partition_{i % num_partitions}" futures.append(executor.submit(run_query, partition_key, i)) - - # Wait for all to complete for future in as_completed(futures): - future.result() # This will raise if the thread raised + future.result() - # Verify no errors occurred self.assertEqual(len(errors), 0, f"Errors occurred: {errors}") - - # Verify all threads got results self.assertEqual(len(results), num_threads) - # Verify each thread captured headers correctly for thread_id, result in results.items(): self.assertEqual(result["item_count"], items_per_partition, f"Thread {thread_id} got wrong item count") self.assertIn("x-ms-request-charge", result["headers"], f"Thread {thread_id} headers missing x-ms-request-charge") - # Verify that different threads have independent header dicts + # The iterator's final headers must match this thread's own + # last hook payload, otherwise header state is shared. + self.assertGreater( + len(result["captured_pages"]), 0, + f"Thread {thread_id} hook never fired", + ) + last_hook = result["captured_pages"][-1] + self.assertEqual( + result["headers"].get("x-ms-activity-id"), + last_hook.get("x-ms-activity-id"), + f"Thread {thread_id} got headers that did not come from its own response.", + ) + self.assertEqual( + result["headers"].get("x-ms-request-charge"), + last_hook.get("x-ms-request-charge"), + f"Thread {thread_id} got a request charge that did not come from its own response.", + ) + + # Each thread holds its own headers dict, not a shared object. thread_ids = list(results.keys()) if len(thread_ids) >= 2: self.assertIsNot(results[thread_ids[0]]["headers"], results[thread_ids[1]]["headers"]) + # Different partitions, so activity ids must not all be the same. + activity_ids = { + r["headers"].get("x-ms-activity-id") for r in results.values() + } + self.assertGreater( + len(activity_ids), 1, + "All threads got the same activity id, which means header state is shared.", + ) + finally: self._delete_container_for_test(container_id) def test_query_response_headers_concurrent_same_container(self): - """Test concurrent queries on the same container with overlapping execution. - - This test specifically targets the race condition that would occur if headers - were captured from a shared client.last_response_headers after fetch_next_block(). - """ + """All threads run the same query against the same partition. Item counts + and request charges look identical across threads, so isolation is checked + on x-ms-activity-id, which the service assigns per request.""" container_id = "test_headers_concurrent_" + str(uuid.uuid4()) created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) try: - # Create enough items to ensure multiple pages for i in range(50): created_collection.create_item(body={"pk": "shared", "id": f"item_{i}", "value": i}) - barrier = threading.Barrier(5) # Synchronize 5 threads + barrier = threading.Barrier(5) results = {} + errors = [] lock = threading.Lock() def run_synchronized_query(thread_id: int): """Run a query with synchronization to maximize overlap.""" - query_iterable = created_collection.query_items( - query="SELECT * FROM c WHERE c.pk = @pk", - parameters=[{"name": "@pk", "value": "shared"}], - partition_key="shared", - max_item_count=5, # Small pages = more fetches - populate_query_metrics=True - ) + try: + # Per-thread hook: records what this iterator received. + captured_pages = [] - # Wait for all threads to be ready - barrier.wait() + def hook(headers, _result): + captured_pages.append(dict(headers)) - # Now all threads fetch concurrently - items = list(query_iterable) - headers = query_iterable.get_response_headers() + query_iterable = created_collection.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "shared"}], + partition_key="shared", + max_item_count=5, + populate_query_metrics=True, + response_hook=hook, + ) + + # Wait so all threads start fetching at the same moment. + barrier.wait() - with lock: - results[thread_id] = { - "item_count": len(items), - "request_charge": float(headers.get("x-ms-request-charge", 0)) - } + items = list(query_iterable) + headers = query_iterable.get_response_headers() + + with lock: + results[thread_id] = { + "item_count": len(items), + "request_charge": float(headers.get("x-ms-request-charge", 0)), + "headers": headers, + "captured_pages": captured_pages, + } + except Exception as e: + with lock: + errors.append((thread_id, str(e))) threads = [] for i in range(5): t = threading.Thread(target=run_synchronized_query, args=(i,)) threads.append(t) t.start() - for t in threads: t.join(timeout=60) - # Verify all threads completed and got correct results + # Surface any worker exceptions in the main thread. + self.assertEqual(len(errors), 0, f"Worker errors: {errors}") + self.assertEqual(len(results), 5) for thread_id, result in results.items(): self.assertEqual(result["item_count"], 50, @@ -406,6 +537,420 @@ def run_synchronized_query(thread_id: int): self.assertGreater(result["request_charge"], 0, f"Thread {thread_id} should have positive request charge") + # The iterator's final headers must match this thread's own + # last hook payload (compared on activity id, unique per request). + self.assertGreater( + len(result["captured_pages"]), 0, + f"Thread {thread_id} hook never fired", + ) + last_hook = result["captured_pages"][-1] + self.assertEqual( + result["headers"].get("x-ms-activity-id"), + last_hook.get("x-ms-activity-id"), + f"Thread {thread_id} got headers that did not come from its own response.", + ) + + # Each thread sent its own requests, so all activity ids must be distinct. + final_ids = [r["headers"].get("x-ms-activity-id") for r in results.values()] + self.assertEqual( + len(set(final_ids)), len(final_ids), + f"Two or more threads got the same activity id, which means header state is shared. " + f"Ids: {final_ids}", + ) + + finally: + self._delete_container_for_test(container_id) + + def test_query_response_headers_before_iteration_returns_empty(self): + """Headers must be an empty dict when no page has been fetched yet.""" + container_id = "test_headers_preiter_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + created_collection.create_item(body={"pk": "test", "id": "item_1"}) + + query_iterable = created_collection.query_items( + query="SELECT * FROM c", + partition_key="test", + ) + + # No iteration yet, so the dict must be empty (and not None). + headers = query_iterable.get_response_headers() + self.assertIsInstance(headers, CaseInsensitiveDict) + self.assertEqual(len(headers), 0) + finally: + self._delete_container_for_test(container_id) + + def test_query_response_headers_match_last_response_hook_invocation(self): + """The headers returned after iteration must match the headers handed + to the last response_hook call.""" + container_id = "test_headers_hookparity_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + for i in range(12): + created_collection.create_item(body={"pk": "test", "id": f"item_{i}", "value": i}) + + captured_pages = [] + + def hook(headers, _result): + # Snapshot the headers handed to the hook for every page. + captured_pages.append(dict(headers)) + + query_iterable = created_collection.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "test"}], + partition_key="test", + max_item_count=4, + response_hook=hook, + ) + + items = list(query_iterable) + self.assertEqual(len(items), 12) + + # At least one page was fetched, so the hook fired at least once. + self.assertGreater(len(captured_pages), 0) + + # The getter must return what the last hook invocation saw. + final_headers = query_iterable.get_response_headers() + self.assertEqual( + final_headers["x-ms-request-charge"], + captured_pages[-1]["x-ms-request-charge"], + ) + self.assertEqual( + final_headers["x-ms-activity-id"], + captured_pages[-1]["x-ms-activity-id"], + ) + finally: + self._delete_container_for_test(container_id) + + def test_query_response_headers_return_type_is_dict_not_list(self): + """The getter returns a single dict (not a list), and the old + get_last_response_headers method is not available.""" + container_id = "test_headers_returntype_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + created_collection.create_item(body={"pk": "test", "id": "item_1"}) + + query_iterable = created_collection.query_items( + query="SELECT * FROM c", + partition_key="test", + ) + self.assertIsInstance(query_iterable, CosmosItemPaged) + + list(query_iterable) + headers = query_iterable.get_response_headers() + + self.assertIsInstance(headers, CaseInsensitiveDict) + self.assertNotIsInstance(headers, list) + self.assertFalse(hasattr(query_iterable, "get_last_response_headers")) + finally: + self._delete_container_for_test(container_id) + + def test_read_all_items_response_headers(self): + """read_all_items pagers expose the same headers contract as queries.""" + container_id = "test_headers_readall_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + for i in range(8): + created_collection.create_item(body={"pk": "test", "id": f"item_{i}"}) + + paged = created_collection.read_all_items(max_item_count=3) + items = list(paged) + self.assertEqual(len(items), 8) + + self.assertTrue( + hasattr(paged, "get_response_headers"), + "read_all_items pager must expose get_response_headers", + ) + headers = paged.get_response_headers() + self.assertIsInstance(headers, CaseInsensitiveDict) + self.assertIn("x-ms-request-charge", headers) + finally: + self._delete_container_for_test(container_id) + + + def test_response_hook_parity_query_items_change_feed(self): + """Headers handed to the response_hook on query_items_change_feed + must match the pager's get_response_headers() when the pager + exposes one (some change-feed pager flavors don't).""" + container_id = "test_hookparity_cf_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + for i in range(6): + created_collection.create_item(body={"pk": "test", "id": f"item_{i}"}) + + captured = [] + + def hook(headers, _result): + captured.append(dict(headers)) + + paged = created_collection.query_items_change_feed( + start_time="Beginning", + max_item_count=2, + response_hook=hook, + ) + items = list(paged) + self.assertGreaterEqual(len(items), 6) + self.assertGreater(len(captured), 0, "response_hook must fire at least once") + + # change_feed pages always carry a request charge in the hook payload + self.assertIn("x-ms-request-charge", captured[-1]) + + # When the change-feed pager exposes get_response_headers(), + # the last page's headers must match the last hook invocation. + if hasattr(paged, "get_response_headers"): + final_headers = paged.get_response_headers() + self.assertIn("x-ms-request-charge", final_headers) + self.assertEqual( + final_headers["x-ms-request-charge"], + captured[-1]["x-ms-request-charge"], + ) + finally: + self._delete_container_for_test(container_id) + + def test_response_hook_parity_point_ops(self): + """For every point CRUD method, the headers handed to the + response_hook must match the returned wrapper's + get_response_headers(). delete_item returns no body.""" + container_id = "test_hookparity_point_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + captured = [] + + def hook(headers, _body): + captured.append(dict(headers)) + + # create_item + captured.clear() + created = created_collection.create_item( + body={"pk": "p", "id": "doc1", "value": 1}, + response_hook=hook, + ) + self.assertEqual(len(captured), 1) + self.assertEqual( + created.get_response_headers()["x-ms-request-charge"], + captured[0]["x-ms-request-charge"], + ) + + # read_item + captured.clear() + read = created_collection.read_item( + item="doc1", partition_key="p", response_hook=hook, + ) + self.assertEqual(len(captured), 1) + self.assertEqual( + read.get_response_headers()["x-ms-request-charge"], + captured[0]["x-ms-request-charge"], + ) + + # replace_item + captured.clear() + replaced = created_collection.replace_item( + item="doc1", + body={"pk": "p", "id": "doc1", "value": 2}, + response_hook=hook, + ) + self.assertEqual(len(captured), 1) + self.assertEqual( + replaced.get_response_headers()["x-ms-request-charge"], + captured[0]["x-ms-request-charge"], + ) + + # upsert_item + captured.clear() + upserted = created_collection.upsert_item( + body={"pk": "p", "id": "doc1", "value": 3}, + response_hook=hook, + ) + self.assertEqual(len(captured), 1) + self.assertEqual( + upserted.get_response_headers()["x-ms-request-charge"], + captured[0]["x-ms-request-charge"], + ) + + # delete_item: returns None, but the hook still fires once with + # the response headers from the DELETE round-trip. + captured.clear() + created_collection.delete_item( + item="doc1", partition_key="p", response_hook=hook, + ) + self.assertEqual(len(captured), 1) + self.assertIn("x-ms-request-charge", captured[0]) + finally: + self._delete_container_for_test(container_id) + + def test_response_hook_parity_query_databases_and_query_containers(self): + """query_databases and query_containers fire their response_hook + exactly once after the iterable is returned (the hook fires before + the caller actually walks the iterable). ``query_databases`` passes + a one-arg hook (just headers); ``query_containers`` passes two args + (headers, paged-iterable). Accept both via ``*args``. + + We pin: the hook fires exactly once for each surface, and the + captured payload is a Mapping. We deliberately do *not* pin a + specific header like ``x-ms-request-charge`` because the hook + fires with ``client_connection.last_response_headers`` at the + moment the pager is constructed, which can be a stale value from + an earlier request (e.g. the account probe) when nothing has yet + flowed through this specific query.""" + + captured_db = [] + + def hook_db(*args): + # query_databases: hook(headers); query_containers: hook(headers, paged) + captured_db.append(args[0]) + + db_pager = self.client.query_databases( + query="SELECT * FROM root r", + response_hook=hook_db, + ) + _ = list(db_pager) + self.assertEqual( + len(captured_db), 1, + "query_databases response_hook must fire exactly once", + ) + self.assertIsInstance(captured_db[0], Mapping) + + captured_c = [] + + def hook_c(*args): + captured_c.append(args[0]) + + c_pager = self.created_db.query_containers( + query="SELECT * FROM root r", + response_hook=hook_c, + ) + _ = list(c_pager) + self.assertEqual( + len(captured_c), 1, + "query_containers response_hook must fire exactly once", + ) + self.assertIsInstance(captured_c[0], Mapping) + + def test_response_hook_fires_at_least_once_for_every_paged_surface(self): + """A response_hook attached to any paged surface must fire at least + once, and every captured payload must be a Mapping.""" + + container_id = "test_hookfires_paged_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test(container_id, PartitionKey(path="/pk")) + try: + for i in range(6): + created_collection.create_item(body={"pk": "test", "id": f"item_{i}"}) + + def _run(surface_name, build_pager): + captured = [] + + def hook(headers, _result): + captured.append(dict(headers)) + + pager = build_pager(hook) + list(pager) + self.assertGreater( + len(captured), + 0, + f"{surface_name} response_hook never fired", + ) + for payload in captured: + self.assertIsInstance( + payload, + Mapping, + f"{surface_name} hook received non-Mapping payload", + ) + + _run( + "query_items", + lambda h: created_collection.query_items( + query="SELECT * FROM c", + partition_key="test", + response_hook=h, + ), + ) + _run( + "query_items_change_feed", + lambda h: created_collection.query_items_change_feed( + start_time="Beginning", + max_item_count=2, + response_hook=h, + ), + ) + finally: + self._delete_container_for_test(container_id) + + + def test_response_hook_parity_patch_item(self): + # patch_item must fire its response_hook once with headers that + # match the returned wrapper on request charge and activity id. + container_id = "test_hookparity_patch_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test( + container_id, PartitionKey(path="/pk") + ) + try: + created_collection.create_item( + body={"pk": "p", "id": "doc-patch", "value": 1} + ) + + captured = [] + + def hook(headers, _body): + captured.append(dict(headers)) + + patched = created_collection.patch_item( + item="doc-patch", + partition_key="p", + patch_operations=[ + {"op": "replace", "path": "/value", "value": 99} + ], + response_hook=hook, + ) + + self.assertEqual(len(captured), 1) + patched_headers = patched.get_response_headers() + self.assertEqual( + patched_headers["x-ms-request-charge"], + captured[0]["x-ms-request-charge"], + ) + self.assertEqual( + patched_headers["x-ms-activity-id"], + captured[0]["x-ms-activity-id"], + ) + self.assertEqual(patched["value"], 99) + finally: + self._delete_container_for_test(container_id) + + def test_response_hook_parity_execute_item_batch(self): + # execute_item_batch must fire its response_hook once with + # headers that match the returned wrapper. + container_id = "test_hookparity_batch_" + str(uuid.uuid4()) + created_collection = self._create_container_for_test( + container_id, PartitionKey(path="/pk") + ) + try: + captured = [] + + def hook(headers, _body): + captured.append(dict(headers)) + + batch_ops = [ + ("create", ({"pk": "p", "id": "batch-doc-1", "value": 1},)), + ("upsert", ({"pk": "p", "id": "batch-doc-2", "value": 2},)), + ] + + result = created_collection.execute_item_batch( + batch_operations=batch_ops, + partition_key="p", + response_hook=hook, + ) + + self.assertEqual(len(captured), 1) + result_headers = result.get_response_headers() + self.assertEqual( + result_headers["x-ms-request-charge"], + captured[0]["x-ms-request-charge"], + ) + self.assertEqual( + result_headers["x-ms-activity-id"], + captured[0]["x-ms-activity-id"], + ) + self.assertEqual(len(result), 2) finally: self._delete_container_for_test(container_id) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_response_headers_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_response_headers_async.py index 7554dc722b7f..c2525b6d513c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_response_headers_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_response_headers_async.py @@ -2,13 +2,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. import asyncio +import gc import os +import tracemalloc import unittest import uuid +from collections.abc import Mapping import pytest +from azure.core.utils import CaseInsensitiveDict import test_config +from azure.cosmos._cosmos_responses import CosmosAsyncItemPaged from azure.cosmos.aio import CosmosClient, DatabaseProxy from azure.cosmos.partition_key import PartitionKey @@ -251,6 +256,98 @@ async def test_query_response_headers_by_page_iteration_async(self): finally: await self._delete_container_for_test(cid) + async def test_query_response_headers_long_pagination_bounded_memory_async(self): + """Paging through many pages keeps response-header state bounded and + keeps overall iterator memory growth under a linear safety ceiling. + Document payloads are not retained during measurement so growth reflects + headers and iterator overhead only.""" + cid = "test_headers_long_pagination_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) + try: + num_items = 200 + for i in range(num_items): + await created_collection.create_item( + body={"pk": "test", "id": f"item_{i:04d}", "value": i} + ) + + query_iterable = created_collection.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "test"}], + partition_key="test", + max_item_count=2, + ) + + # Read the first page outside the measurement window so one-time + # client setup is not counted. Count items only; never retain pages. + page_iter = query_iterable.by_page() + first_page = await page_iter.__anext__() + first_page_count = 0 + async for _ in first_page: + first_page_count += 1 + baseline_headers = query_iterable.get_response_headers() + assert baseline_headers is not None + + # Collect transient setup objects so they don't show up as growth. + gc.collect() + tracemalloc.start() + try: + snapshot_before = tracemalloc.take_snapshot() + + page_count = 1 + items_total = first_page_count + header_sizes = [len(baseline_headers)] + all_keys_seen = set(baseline_headers.keys()) + + async for page in page_iter: + # Count items without keeping a reference to any of them. + async for _ in page: + items_total += 1 + page_count += 1 + headers = query_iterable.get_response_headers() + assert headers is not None + header_sizes.append(len(headers)) + all_keys_seen.update(headers.keys()) + # Drop the per-page header copy before the next iteration. + del headers + + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + top_stats = snapshot_after.compare_to(snapshot_before, "lineno") + memory_growth = sum(stat.size_diff for stat in top_stats if stat.size_diff > 0) + finally: + tracemalloc.stop() + + # We really paginated and read every item back. + assert page_count >= 20, f"Expected many pages, got {page_count}." + assert items_total == num_items + + # The per-page headers dict stays close to the first page's size. + max_header_size = max(header_sizes) + assert max_header_size <= len(baseline_headers) + 8, ( + f"Headers dict grew across pagination (max={max_header_size}, " + f"baseline={len(baseline_headers)})." + ) + + # The set of header names seen across all pages stays bounded. + assert len(all_keys_seen) <= len(baseline_headers) + 16, ( + f"Header name set grew across pagination (seen={len(all_keys_seen)}, " + f"baseline={len(baseline_headers)})." + ) + + # Linear safety ceiling so the check scales if num_items changes. + # This is a "no catastrophic leak" guard, not a strict O(1) proof. + # Observed per-page overhead on a live account is around 24-29 KiB; + # 48 KiB gives roughly 2x headroom and still catches a real leak. + max_per_page_bytes = 48 * 1024 + ceiling_bytes = max_per_page_bytes * page_count + assert memory_growth < ceiling_bytes, ( + f"Iterator memory grew by {memory_growth} bytes over {page_count} pages; " + f"exceeded linear safety ceiling {ceiling_bytes} bytes." + ) + + finally: + await self._delete_container_for_test(cid) + async def test_query_response_headers_returns_copies_async(self): """Test that get_response_headers returns copies, not references.""" cid = "test_headers_copies_async_" + str(uuid.uuid4()) @@ -284,16 +381,14 @@ async def test_query_response_headers_returns_copies_async(self): await self._delete_container_for_test(cid) async def test_query_response_headers_concurrent_async(self): - """Test that response headers are captured correctly when multiple async queries run concurrently. - - This test verifies that each query operation captures its own headers independently, - without interference from concurrent queries. This is the key thread-safety guarantee. - """ + """Each concurrent query must see only its own response headers. + Each task installs a response_hook that records every page it sees; + the iterator's final headers must match that task's own last hook payload.""" cid = "test_headers_concurrent_async_" + str(uuid.uuid4()) created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) try: - # Create items with different partition keys + # Different partition keys so different tasks run different queries. num_partitions = 5 items_per_partition = 10 for pk_idx in range(num_partitions): @@ -304,16 +399,22 @@ async def test_query_response_headers_concurrent_async(self): async def run_query(partition_key: str, query_id: int): """Run a query and capture its headers.""" + # Per-task hook: records what this iterator received. + captured_pages = [] + + def hook(headers, _result): + captured_pages.append(dict(headers)) + query = "SELECT * FROM c WHERE c.pk = @pk" query_iterable = created_collection.query_items( query=query, parameters=[{"name": "@pk", "value": partition_key}], partition_key=partition_key, - max_item_count=2, # Small page size to ensure multiple pages - populate_query_metrics=True + max_item_count=2, + populate_query_metrics=True, + response_hook=hook, ) - # Consume all items items = [item async for item in query_iterable] headers = query_iterable.get_response_headers() @@ -321,10 +422,10 @@ async def run_query(partition_key: str, query_id: int): "query_id": query_id, "partition_key": partition_key, "item_count": len(items), - "headers": headers + "headers": headers, + "captured_pages": captured_pages, } - # Run multiple queries concurrently using asyncio.gather num_queries = 10 tasks = [] for i in range(num_queries): @@ -333,82 +434,96 @@ async def run_query(partition_key: str, query_id: int): results = await asyncio.gather(*tasks) - # Verify all queries got results assert len(results) == num_queries - # Verify each query captured headers correctly for result in results: assert result["item_count"] == items_per_partition, \ f"Query {result['query_id']} got wrong item count" assert "x-ms-request-charge" in result["headers"], \ f"Query {result['query_id']} headers missing x-ms-request-charge" - # Verify that different queries have independent header dicts + # The iterator's final headers must match this task's own + # last hook payload, otherwise header state is shared. + assert len(result["captured_pages"]) > 0, \ + f"Query {result['query_id']} hook never fired" + last_hook = result["captured_pages"][-1] + assert result["headers"].get("x-ms-activity-id") == last_hook.get("x-ms-activity-id"), ( + f"Query {result['query_id']} got headers that did not come from its own response." + ) + assert result["headers"].get("x-ms-request-charge") == last_hook.get("x-ms-request-charge"), ( + f"Query {result['query_id']} got a request charge that did not come from its own response." + ) + + # Each task holds its own headers dict, not a shared object. if len(results) >= 2: assert results[0]["headers"] is not results[1]["headers"] + # Different partitions, so activity ids must not all be the same. + activity_ids = {r["headers"].get("x-ms-activity-id") for r in results} + assert len(activity_ids) > 1, ( + "All tasks got the same activity id, which means header state is shared." + ) + finally: await self._delete_container_for_test(cid) async def test_query_response_headers_high_concurrency_async(self): - """Test with high concurrency to stress-test the thread-safety. - - This test specifically targets the race condition that would occur if headers - were captured from a shared client.last_response_headers after fetch operations. - """ + """Many tasks run the same query against the same partition at the same time. + Item counts and request charges look identical across tasks, so isolation is + checked on x-ms-activity-id, which the service assigns per request.""" cid = "test_headers_stress_async_" + str(uuid.uuid4()) created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) try: - # Create enough items to ensure multiple pages for i in range(50): await created_collection.create_item( body={"pk": "shared", "id": f"item_{i}", "value": i} ) - # Use an event to synchronize all coroutines start_event = asyncio.Event() async def run_synchronized_query(query_id: int): """Run a query with synchronization to maximize overlap.""" + # Per-task hook: records what this iterator received. + captured_pages = [] + + def hook(headers, _result): + captured_pages.append(dict(headers)) + query_iterable = created_collection.query_items( query="SELECT * FROM c WHERE c.pk = @pk", parameters=[{"name": "@pk", "value": "shared"}], partition_key="shared", - max_item_count=5, # Small pages = more fetches - populate_query_metrics=True + max_item_count=5, + populate_query_metrics=True, + response_hook=hook, ) - # Wait for the start signal + # Wait so all tasks start fetching at the same moment. await start_event.wait() - # Now all coroutines fetch concurrently items = [item async for item in query_iterable] headers = query_iterable.get_response_headers() return { "query_id": query_id, "item_count": len(items), - "request_charge": float(headers.get("x-ms-request-charge", 0)) + "request_charge": float(headers.get("x-ms-request-charge", 0)), + "headers": headers, + "captured_pages": captured_pages, } - # Create tasks but don't start fetching yet num_concurrent = 20 tasks = [run_synchronized_query(i) for i in range(num_concurrent)] - # Schedule all tasks gathered = asyncio.gather(*tasks) - # Give tasks time to reach the wait point + # Give tasks time to reach the wait point. await asyncio.sleep(0.1) - - # Signal all to start simultaneously start_event.set() - # Wait for all to complete results = await gathered - # Verify all queries completed correctly assert len(results) == num_concurrent for result in results: assert result["item_count"] == 50, \ @@ -416,6 +531,386 @@ async def run_synchronized_query(query_id: int): assert result["request_charge"] > 0, \ f"Query {result['query_id']} should have positive request charge" + # The iterator's final headers must match this task's own + # last hook payload (compared on activity id, unique per request). + assert len(result["captured_pages"]) > 0, \ + f"Query {result['query_id']} hook never fired" + last_hook = result["captured_pages"][-1] + assert result["headers"].get("x-ms-activity-id") == last_hook.get("x-ms-activity-id"), ( + f"Query {result['query_id']} got headers that did not come from its own response." + ) + + # Each task sent its own requests, so all activity ids must be distinct. + final_ids = [r["headers"].get("x-ms-activity-id") for r in results] + assert len(set(final_ids)) == len(final_ids), ( + f"Two or more tasks got the same activity id, which means header state is shared. " + f"Ids: {final_ids}" + ) + + finally: + await self._delete_container_for_test(cid) + + async def test_query_response_headers_before_iteration_returns_empty_async(self): + """Headers must be an empty dict when no async page has been fetched yet.""" + cid = "test_headers_preiter_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) + try: + await created_collection.create_item(body={"pk": "test", "id": "item_1"}) + + query_iterable = created_collection.query_items( + query="SELECT * FROM c", + partition_key="test", + ) + + # No iteration yet, so the dict must be empty (and not None). + headers = query_iterable.get_response_headers() + assert isinstance(headers, CaseInsensitiveDict) + assert len(headers) == 0 + finally: + await self._delete_container_for_test(cid) + + async def test_query_response_headers_match_last_response_hook_invocation_async(self): + """The headers returned after async iteration must match the headers + handed to the last response_hook call.""" + cid = "test_headers_hookparity_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) + try: + for i in range(12): + await created_collection.create_item(body={"pk": "test", "id": f"item_{i}", "value": i}) + + captured_pages = [] + + def hook(headers, _result): + captured_pages.append(dict(headers)) + + query_iterable = created_collection.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "test"}], + partition_key="test", + max_item_count=4, + response_hook=hook, + ) + + items = [item async for item in query_iterable] + assert len(items) == 12 + assert len(captured_pages) > 0 + + final_headers = query_iterable.get_response_headers() + assert final_headers["x-ms-request-charge"] == captured_pages[-1]["x-ms-request-charge"] + assert final_headers["x-ms-activity-id"] == captured_pages[-1]["x-ms-activity-id"] + finally: + await self._delete_container_for_test(cid) + + async def test_query_response_headers_return_type_is_dict_not_list_async(self): + """The getter returns a single dict (not a list), and the old + get_last_response_headers method is not available.""" + cid = "test_headers_returntype_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) + try: + await created_collection.create_item(body={"pk": "test", "id": "item_1"}) + + query_iterable = created_collection.query_items( + query="SELECT * FROM c", + partition_key="test", + ) + assert isinstance(query_iterable, CosmosAsyncItemPaged) + + _ = [item async for item in query_iterable] + headers = query_iterable.get_response_headers() + + assert isinstance(headers, CaseInsensitiveDict) + assert not isinstance(headers, list) + assert not hasattr(query_iterable, "get_last_response_headers") + finally: + await self._delete_container_for_test(cid) + + async def test_read_all_items_response_headers_async(self): + """read_all_items pagers expose the same headers contract as queries.""" + cid = "test_headers_readall_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) + try: + for i in range(8): + await created_collection.create_item(body={"pk": "test", "id": f"item_{i}"}) + + paged = created_collection.read_all_items(max_item_count=3) + items = [item async for item in paged] + assert len(items) == 8 + + assert hasattr( + paged, "get_response_headers" + ), "read_all_items pager must expose get_response_headers" + headers = paged.get_response_headers() + assert isinstance(headers, CaseInsensitiveDict) + assert "x-ms-request-charge" in headers + finally: + await self._delete_container_for_test(cid) + + + async def test_response_hook_parity_query_items_change_feed_async(self): + """Headers handed to the response_hook on async change feed must + match the pager's get_response_headers() when the pager exposes + one (some change-feed pager flavors don't).""" + cid = "test_hookparity_cf_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) + try: + for i in range(6): + await created_collection.create_item(body={"pk": "test", "id": f"item_{i}"}) + + captured = [] + + def hook(headers, _result): + captured.append(dict(headers)) + + paged = created_collection.query_items_change_feed( + start_time="Beginning", + max_item_count=2, + response_hook=hook, + ) + items = [item async for item in paged] + assert len(items) >= 6 + assert len(captured) > 0, "response_hook must fire at least once" + + assert "x-ms-request-charge" in captured[-1] + + if hasattr(paged, "get_response_headers"): + final_headers = paged.get_response_headers() + assert "x-ms-request-charge" in final_headers + assert ( + final_headers["x-ms-request-charge"] + == captured[-1]["x-ms-request-charge"] + ) + finally: + await self._delete_container_for_test(cid) + + async def test_response_hook_parity_point_ops_async(self): + """For every async point CRUD method, the headers handed to the + response_hook must match the returned wrapper's + get_response_headers(). delete_item returns no body.""" + cid = "test_hookparity_point_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) + try: + captured = [] + + def hook(headers, _body): + captured.append(dict(headers)) + + # create_item + captured.clear() + created = await created_collection.create_item( + body={"pk": "p", "id": "doc1", "value": 1}, + response_hook=hook, + ) + assert len(captured) == 1 + assert ( + created.get_response_headers()["x-ms-request-charge"] + == captured[0]["x-ms-request-charge"] + ) + + # read_item + captured.clear() + read = await created_collection.read_item( + item="doc1", partition_key="p", response_hook=hook, + ) + assert len(captured) == 1 + assert ( + read.get_response_headers()["x-ms-request-charge"] + == captured[0]["x-ms-request-charge"] + ) + + # replace_item + captured.clear() + replaced = await created_collection.replace_item( + item="doc1", + body={"pk": "p", "id": "doc1", "value": 2}, + response_hook=hook, + ) + assert len(captured) == 1 + assert ( + replaced.get_response_headers()["x-ms-request-charge"] + == captured[0]["x-ms-request-charge"] + ) + + # upsert_item + captured.clear() + upserted = await created_collection.upsert_item( + body={"pk": "p", "id": "doc1", "value": 3}, + response_hook=hook, + ) + assert len(captured) == 1 + assert ( + upserted.get_response_headers()["x-ms-request-charge"] + == captured[0]["x-ms-request-charge"] + ) + + # delete_item: returns None, but the hook still fires once. + captured.clear() + await created_collection.delete_item( + item="doc1", partition_key="p", response_hook=hook, + ) + assert len(captured) == 1 + assert "x-ms-request-charge" in captured[0] + finally: + await self._delete_container_for_test(cid) + + async def test_response_hook_parity_query_databases_and_query_containers_async(self): + """Async query_databases and query_containers fire their hook once + after the iterable is returned (the hook fires before the caller + walks the iterable). ``query_databases`` passes one arg (just + headers); ``query_containers`` passes two (headers, paged); accept + both via ``*args``. + + We pin: the hook fires exactly once for each surface, and the + captured payload is a Mapping. We deliberately do *not* pin a + specific header like ``x-ms-request-charge`` because the hook + fires with ``client_connection.last_response_headers`` at the + moment the pager is constructed, which can be a stale value from + an earlier request when nothing has yet flowed through this + specific query.""" + + captured_db = [] + + def hook_db(*args): + captured_db.append(args[0]) + + db_pager = self.client.query_databases( + query="SELECT * FROM root r", + response_hook=hook_db, + ) + _ = [item async for item in db_pager] + assert len(captured_db) == 1, "query_databases hook must fire exactly once" + assert isinstance(captured_db[0], Mapping) + + captured_c = [] + + def hook_c(*args): + captured_c.append(args[0]) + + c_pager = self.created_db.query_containers( + query="SELECT * FROM root r", + response_hook=hook_c, + ) + _ = [item async for item in c_pager] + assert len(captured_c) == 1, "query_containers hook must fire exactly once" + assert isinstance(captured_c[0], Mapping) + + async def test_response_hook_fires_at_least_once_for_every_paged_surface_async(self): + """A response_hook attached to any async paged surface must fire at + least once, and every captured payload must be a Mapping.""" + + cid = "test_hookfires_paged_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test(cid, PartitionKey(path="/pk")) + try: + for i in range(6): + await created_collection.create_item(body={"pk": "test", "id": f"item_{i}"}) + + async def _run(surface_name, build_pager): + captured = [] + + def hook(headers, _result): + captured.append(dict(headers)) + + pager = build_pager(hook) + _ = [item async for item in pager] + assert len(captured) > 0, f"{surface_name} response_hook never fired" + for payload in captured: + assert isinstance(payload, Mapping), \ + f"{surface_name} hook received non-Mapping payload" + + await _run( + "query_items", + lambda h: created_collection.query_items( + query="SELECT * FROM c", + partition_key="test", + response_hook=h, + ), + ) + await _run( + "query_items_change_feed", + lambda h: created_collection.query_items_change_feed( + start_time="Beginning", + max_item_count=2, + response_hook=h, + ), + ) + finally: + await self._delete_container_for_test(cid) + + + async def test_response_hook_parity_patch_item_async(self): + # Async patch_item must fire its response_hook once with headers + # that match the returned wrapper. + cid = "test_hookparity_patch_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test( + cid, PartitionKey(path="/pk") + ) + try: + await created_collection.create_item( + body={"pk": "p", "id": "doc-patch", "value": 1} + ) + + captured = [] + + def hook(headers, _body): + captured.append(dict(headers)) + + patched = await created_collection.patch_item( + item="doc-patch", + partition_key="p", + patch_operations=[ + {"op": "replace", "path": "/value", "value": 99} + ], + response_hook=hook, + ) + + assert len(captured) == 1 + patched_headers = patched.get_response_headers() + assert ( + patched_headers["x-ms-request-charge"] + == captured[0]["x-ms-request-charge"] + ) + assert ( + patched_headers["x-ms-activity-id"] + == captured[0]["x-ms-activity-id"] + ) + assert patched["value"] == 99 + finally: + await self._delete_container_for_test(cid) + + async def test_response_hook_parity_execute_item_batch_async(self): + # Async execute_item_batch must fire its response_hook once with + # headers that match the returned wrapper. + cid = "test_hookparity_batch_async_" + str(uuid.uuid4()) + created_collection = await self._create_container_for_test( + cid, PartitionKey(path="/pk") + ) + try: + captured = [] + + def hook(headers, _body): + captured.append(dict(headers)) + + batch_ops = [ + ("create", ({"pk": "p", "id": "batch-doc-1", "value": 1},)), + ("upsert", ({"pk": "p", "id": "batch-doc-2", "value": 2},)), + ] + + result = await created_collection.execute_item_batch( + batch_operations=batch_ops, + partition_key="p", + response_hook=hook, + ) + + assert len(captured) == 1 + result_headers = result.get_response_headers() + assert ( + result_headers["x-ms-request-charge"] + == captured[0]["x-ms-request-charge"] + ) + assert ( + result_headers["x-ms-activity-id"] + == captured[0]["x-ms-activity-id"] + ) + assert len(result) == 2 finally: await self._delete_container_for_test(cid) diff --git a/sdk/cosmos/azure-cosmos/tests/test_read_timeout_propagation.py b/sdk/cosmos/azure-cosmos/tests/test_read_timeout_propagation.py new file mode 100644 index 000000000000..be0c01b34202 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_read_timeout_propagation.py @@ -0,0 +1,271 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Tests that ``read_timeout`` is honored by sync query operations. + +A caller can pass ``read_timeout`` per call (e.g. ``container.query_items(... +read_timeout=X)``) or set it on the client constructor / connection policy. +These tests use a capturing transport to confirm the value actually reaches +the wire for each query entry point and that per-call beats client-level. +""" + +import unittest +from typing import Any, Optional + +import pytest +from azure.core.pipeline.transport import RequestsTransport + +import test_config +from azure.cosmos import CosmosClient, documents + + +def _is_query_document_fetch(url: str) -> bool: + """True for query/result fetches against ``/docs``. + + Filters out SDK-internal calls (partition-range fetches, container + properties reads, account-info probes) as well as non-document query + surfaces that are asserted separately in dedicated tests. + """ + stripped = url.rstrip("/") + return stripped.endswith("/docs") + + +class _CaptureTransport(RequestsTransport): + """Forwards every request and records the timeout values used on the wire.""" + + def __init__(self): + super().__init__() + self.captured: list[dict[str, Any]] = [] + + def send(self, request, **kwargs): + self.captured.append({ + "url": request.url, + "method": request.method, + "read_timeout": kwargs.get("read_timeout"), + "connection_timeout": kwargs.get("connection_timeout"), + }) + return super().send(request, **kwargs) + + def query_document_fetch_read_timeouts(self) -> list[Optional[float]]: + return [c["read_timeout"] for c in self.captured if _is_query_document_fetch(c["url"])] + + +@pytest.mark.cosmosEmulator +@pytest.mark.cosmosAADQuery +class TestReadTimeoutPropagation(unittest.TestCase): + """``read_timeout`` propagation tests for the sync query surface.""" + + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + TEST_DATABASE_ID = configs.TEST_DATABASE_ID + SINGLE_PARTITION_ID = configs.TEST_SINGLE_PARTITION_CONTAINER_ID + MULTI_PARTITION_ID = configs.TEST_MULTI_PARTITION_CONTAINER_ID + + @classmethod + def setUpClass(cls): + if cls.masterKey == "[YOUR_KEY_HERE]" or cls.host == "[YOUR_ENDPOINT_HERE]": + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' to run the tests.") + # Seed a few items with the master key so the queries below have + # something to return regardless of test ordering or auth mode. + cls._seed_client = CosmosClient(cls.host, cls.masterKey) + db = cls._seed_client.get_database_client(cls.TEST_DATABASE_ID) + for cid in (cls.SINGLE_PARTITION_ID, cls.MULTI_PARTITION_ID): + c = db.get_container_client(cid) + seeded_count = 0 + seed_failures = [] + for i in range(3): + try: + c.upsert_item({"id": f"read_timeout_seed_{cid}_{i}", + "pk": f"pk{i}"}) + seeded_count += 1 + except Exception as exc: + seed_failures.append(f"{type(exc).__name__}: {exc}") + if seeded_count == 0: + raise RuntimeError( + f"Failed to seed any items into {cid}. " + f"Seed errors: {seed_failures}" + ) + + @classmethod + def tearDownClass(cls): + try: + cls._seed_client.close() + except Exception: + pass + + def _build_client(self, capture: _CaptureTransport, **kw) -> CosmosClient: + """Builds a client routed through the capturing transport.""" + return self.configs.create_data_client(transport=capture, **kw) + + def _assert_all_query_document_fetch_read_timeouts_equal( + self, capture: _CaptureTransport, expected: float + ) -> None: + observed = capture.query_document_fetch_read_timeouts() + assert observed, ( + "no /docs query requests captured; the test did not run a document query. " + "captured={}".format(capture.captured) + ) + bad = [v for v in observed if v != expected] + assert not bad, ( + "document query request used read_timeout={} (expected {}). " + "all /docs query read_timeouts: {}. all captures: {}".format( + bad, expected, observed, capture.captured + ) + ) + + def test_per_call_read_timeout_propagates_to_single_partition_query(self): + """Per-call ``read_timeout`` reaches the wire for a single-partition query.""" + capture = _CaptureTransport() + client = self._build_client(capture) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + list(container.query_items( + query="SELECT * FROM c", + partition_key="pk0", + read_timeout=17.0, + )) + finally: + client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 17.0) + + def test_per_call_read_timeout_propagates_to_cross_partition_query(self): + """Per-call ``read_timeout`` reaches the wire for a cross-partition query.""" + capture = _CaptureTransport() + client = self._build_client(capture) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.MULTI_PARTITION_ID) + ) + list(container.query_items( + query="SELECT * FROM c", + enable_cross_partition_query=True, + read_timeout=18.0, + )) + finally: + client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 18.0) + + def test_per_call_read_timeout_propagates_to_change_feed(self): + """Per-call ``read_timeout`` reaches the wire for change-feed reads.""" + capture = _CaptureTransport() + client = self._build_client(capture) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + list(container.query_items_change_feed( + start_time="Beginning", + read_timeout=19.0, + )) + finally: + client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 19.0) + + def test_per_call_read_timeout_propagates_to_database_query_containers(self): + """Per-call ``read_timeout`` reaches the wire when querying containers.""" + capture = _CaptureTransport() + client = self._build_client(capture) + try: + db = client.get_database_client(self.TEST_DATABASE_ID) + list(db.query_containers( + query="SELECT * FROM c", + read_timeout=20.0, + )) + finally: + client.close() + # query_containers hits the /colls endpoint. + observed = [c["read_timeout"] for c in capture.captured if "/colls" in c["url"]] + assert observed, "no /colls request captured: {}".format(capture.captured) + assert all(v == 20.0 for v in observed), ( + "db.query_containers dropped per-call read_timeout. captured={}".format( + capture.captured + ) + ) + + def test_per_call_read_timeout_propagates_to_client_query_databases(self): + """Per-call ``read_timeout`` reaches the wire when querying databases.""" + capture = _CaptureTransport() + client = self._build_client(capture) + try: + list(client.query_databases( + query="SELECT * FROM c", + read_timeout=21.0, + )) + finally: + client.close() + # query_databases hits the /dbs endpoint with no collection suffix. + observed = [c["read_timeout"] for c in capture.captured if "/dbs" in c["url"]] + assert observed, "no /dbs request captured: {}".format(capture.captured) + assert all(v == 21.0 for v in observed), ( + "client.query_databases dropped per-call read_timeout. captured={}".format( + capture.captured + ) + ) + + def test_per_call_read_timeout_overrides_client_for_query(self): + """Per-call ``read_timeout`` wins over the value set on the client.""" + capture = _CaptureTransport() + client = self._build_client(capture, read_timeout=33.0) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + list(container.query_items( + query="SELECT * FROM c", + partition_key="pk0", + read_timeout=5.0, + )) + finally: + client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 5.0) + + def test_client_level_read_timeout_kwarg_propagates_to_queries(self): + """``read_timeout`` set on the client constructor reaches the wire for queries.""" + capture = _CaptureTransport() + client = self._build_client(capture, read_timeout=22.0) + try: + assert client.client_connection.connection_policy.ReadTimeout == 22.0 + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + list(container.query_items( + query="SELECT * FROM c", + partition_key="pk0", + )) + finally: + client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 22.0) + + def test_connection_policy_read_timeout_propagates_to_queries(self): + """``ReadTimeout`` set on the connection policy reaches the wire for queries.""" + cp = documents.ConnectionPolicy() + cp.DisableSSLVerification = self.configs.is_emulator + cp.ReadTimeout = 23.0 + capture = _CaptureTransport() + client = self._build_client(capture, connection_policy=cp) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + list(container.query_items( + query="SELECT * FROM c", + partition_key="pk0", + )) + finally: + client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 23.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_read_timeout_propagation_async.py b/sdk/cosmos/azure-cosmos/tests/test_read_timeout_propagation_async.py new file mode 100644 index 000000000000..76e635933016 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_read_timeout_propagation_async.py @@ -0,0 +1,270 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Tests that ``read_timeout`` is honored by async query operations. + +A caller can pass ``read_timeout`` per call (e.g. ``container.query_items(... +read_timeout=X)``) or set it on the client constructor / connection policy. +These tests use a capturing transport to confirm the value actually reaches +the wire for each query entry point and that per-call beats client-level. +""" + +import unittest +from typing import Any, Optional + +import pytest +from azure.core.pipeline.transport import AioHttpTransport + +import test_config +from azure.cosmos import documents +from azure.cosmos.aio import CosmosClient + + +def _is_query_document_fetch(url: str) -> bool: + """True for query/result fetches against ``/docs``. + + Filters out SDK-internal calls (partition-range fetches, container + properties reads, account-info probes) as well as non-document query + surfaces that are asserted separately in dedicated tests. + """ + stripped = url.rstrip("/") + return stripped.endswith("/docs") + + +class _AsyncCaptureTransport(AioHttpTransport): + """Forwards every request and records the timeout values used on the wire.""" + + def __init__(self): + super().__init__() + self.captured: list[dict[str, Any]] = [] + + async def send(self, request, **kwargs): + self.captured.append({ + "url": request.url, + "method": request.method, + "read_timeout": kwargs.get("read_timeout"), + "connection_timeout": kwargs.get("connection_timeout"), + }) + return await super().send(request, **kwargs) + + def query_document_fetch_read_timeouts(self) -> list[Optional[float]]: + return [c["read_timeout"] for c in self.captured if _is_query_document_fetch(c["url"])] + + +@pytest.mark.cosmosEmulator +@pytest.mark.cosmosAADQuery +class TestReadTimeoutPropagationAsync(unittest.IsolatedAsyncioTestCase): + """``read_timeout`` propagation tests for the async query surface.""" + + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + TEST_DATABASE_ID = configs.TEST_DATABASE_ID + SINGLE_PARTITION_ID = configs.TEST_SINGLE_PARTITION_CONTAINER_ID + MULTI_PARTITION_ID = configs.TEST_MULTI_PARTITION_CONTAINER_ID + + @classmethod + def setUpClass(cls): + if cls.masterKey == "[YOUR_KEY_HERE]" or cls.host == "[YOUR_ENDPOINT_HERE]": + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' to run the tests.") + + async def asyncSetUp(self): + """Seeds a few items once per class with the master key.""" + if getattr(self.__class__, "_seeded", False): + return + async with CosmosClient(self.host, self.masterKey) as seed: + db = seed.get_database_client(self.TEST_DATABASE_ID) + for cid in (self.SINGLE_PARTITION_ID, self.MULTI_PARTITION_ID): + c = db.get_container_client(cid) + seeded_count = 0 + seed_failures = [] + for i in range(3): + try: + await c.upsert_item({ + "id": f"read_timeout_seed_{cid}_{i}_async", + "pk": f"pk{i}", + }) + seeded_count += 1 + except Exception as exc: + seed_failures.append(f"{type(exc).__name__}: {exc}") + if seeded_count == 0: + raise RuntimeError( + f"Failed to seed any items into {cid}. " + f"Seed errors: {seed_failures}" + ) + self.__class__._seeded = True + + def _build_client(self, capture: _AsyncCaptureTransport, **kw) -> CosmosClient: + """Builds a client routed through the capturing transport.""" + return self.configs.create_data_client_async(transport=capture, **kw) + + def _assert_all_query_document_fetch_read_timeouts_equal( + self, capture: _AsyncCaptureTransport, expected: float + ) -> None: + observed = capture.query_document_fetch_read_timeouts() + assert observed, ( + "no /docs query requests captured; the test did not run a document query. " + "captured={}".format(capture.captured) + ) + bad = [v for v in observed if v != expected] + assert not bad, ( + "document query request used read_timeout={} (expected {}). " + "all /docs query read_timeouts: {}. all captures: {}".format( + bad, expected, observed, capture.captured + ) + ) + + async def test_per_call_read_timeout_propagates_to_single_partition_query_async(self): + """Per-call ``read_timeout`` reaches the wire for a single-partition query.""" + capture = _AsyncCaptureTransport() + client = self._build_client(capture) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + _ = [item async for item in container.query_items( + query="SELECT * FROM c", + partition_key="pk0", + read_timeout=17.0, + )] + finally: + await client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 17.0) + + async def test_per_call_read_timeout_propagates_to_cross_partition_query_async(self): + """Per-call ``read_timeout`` reaches the wire for a cross-partition query.""" + capture = _AsyncCaptureTransport() + client = self._build_client(capture) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.MULTI_PARTITION_ID) + ) + _ = [item async for item in container.query_items( + query="SELECT * FROM c", + read_timeout=18.0, + )] + finally: + await client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 18.0) + + async def test_per_call_read_timeout_propagates_to_change_feed_async(self): + """Per-call ``read_timeout`` reaches the wire for change-feed reads.""" + capture = _AsyncCaptureTransport() + client = self._build_client(capture) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + _ = [item async for item in container.query_items_change_feed( + start_time="Beginning", + read_timeout=19.0, + )] + finally: + await client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 19.0) + + async def test_per_call_read_timeout_propagates_to_database_query_containers_async(self): + """Per-call ``read_timeout`` reaches the wire when querying containers.""" + capture = _AsyncCaptureTransport() + client = self._build_client(capture) + try: + db = client.get_database_client(self.TEST_DATABASE_ID) + _ = [c async for c in db.query_containers( + query="SELECT * FROM c", + read_timeout=20.0, + )] + finally: + await client.close() + # query_containers hits the /colls endpoint. + observed = [c["read_timeout"] for c in capture.captured if "/colls" in c["url"]] + assert observed, "no /colls request captured: {}".format(capture.captured) + assert all(v == 20.0 for v in observed), ( + "db.query_containers dropped per-call read_timeout. captured={}".format( + capture.captured + ) + ) + + async def test_per_call_read_timeout_propagates_to_client_query_databases_async(self): + """Per-call ``read_timeout`` reaches the wire when querying databases.""" + capture = _AsyncCaptureTransport() + client = self._build_client(capture) + try: + _ = [d async for d in client.query_databases( + query="SELECT * FROM c", + read_timeout=21.0, + )] + finally: + await client.close() + # query_databases hits the /dbs endpoint with no collection suffix. + observed = [c["read_timeout"] for c in capture.captured if "/dbs" in c["url"]] + assert observed, "no /dbs request captured: {}".format(capture.captured) + assert all(v == 21.0 for v in observed), ( + "client.query_databases dropped per-call read_timeout. captured={}".format( + capture.captured + ) + ) + + async def test_per_call_read_timeout_overrides_client_for_query_async(self): + """Per-call ``read_timeout`` wins over the value set on the client.""" + capture = _AsyncCaptureTransport() + client = self._build_client(capture, read_timeout=33.0) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + _ = [item async for item in container.query_items( + query="SELECT * FROM c", + partition_key="pk0", + read_timeout=5.0, + )] + finally: + await client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 5.0) + + async def test_client_level_read_timeout_kwarg_propagates_to_queries_async(self): + """``read_timeout`` set on the client constructor reaches the wire for queries.""" + capture = _AsyncCaptureTransport() + client = self._build_client(capture, read_timeout=22.0) + try: + assert client.client_connection.connection_policy.ReadTimeout == 22.0 + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + _ = [item async for item in container.query_items( + query="SELECT * FROM c", + partition_key="pk0", + )] + finally: + await client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 22.0) + + async def test_connection_policy_read_timeout_propagates_to_queries_async(self): + """``ReadTimeout`` set on the connection policy reaches the wire for queries.""" + cp = documents.ConnectionPolicy() + cp.DisableSSLVerification = self.configs.is_emulator + cp.ReadTimeout = 23.0 + capture = _AsyncCaptureTransport() + client = self._build_client(capture, connection_policy=cp) + try: + container = ( + client.get_database_client(self.TEST_DATABASE_ID) + .get_container_client(self.SINGLE_PARTITION_ID) + ) + _ = [item async for item in container.query_items( + query="SELECT * FROM c", + partition_key="pk0", + )] + finally: + await client.close() + self._assert_all_query_document_fetch_read_timeouts_equal(capture, 23.0) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_request_response_decoding.py b/sdk/cosmos/azure-cosmos/tests/test_request_response_decoding.py index 82143223b72c..534155bc4162 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_request_response_decoding.py +++ b/sdk/cosmos/azure-cosmos/tests/test_request_response_decoding.py @@ -1,32 +1,9 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -"""Wiring tests for the response-body decode call in the core sync/async -``_Request()`` paths. - -Background: ``_synchronized_request._Request`` and -``_asynchronous_request._Request`` are the highest-traffic code paths in -the SDK — every CRUD, query, and change-feed read flows through them. -Both call ``decode_response_body_for_status`` to decode the HTTP -response body before the status-code branching that builds typed -``CosmosResourceNotFoundError`` / ``CosmosHttpResponseError`` exceptions. - -These tests lock in two contracts: - -1. **Wiring** — the call sites actually invoke the shared decoder. If - someone reverts the call back to ``data.decode("utf-8")`` we want a - unit test to fail immediately. Mirrors the wiring tests already in - place for the inference service in ``test_semantic_reranker_unit``. - -2. **Behavior** — when an HTTP error response body contains invalid - UTF-8, ``_Request`` surfaces the real typed exception - (``CosmosResourceNotFoundError`` etc.) instead of letting a - ``UnicodeDecodeError`` escape. This is the property the - ``decode_response_body_for_status`` helper was introduced to - guarantee; without an end-to-end test on ``_Request``, the helper - could quietly fall out of use and the regression would not be - caught at unit-test time. -""" +"""Tests that check the sync and async request functions both call the +shared decode helper, and that a response body with invalid bytes +still surfaces the right typed exception based on the HTTP status.""" import asyncio import unittest from unittest.mock import MagicMock, patch, AsyncMock @@ -38,7 +15,6 @@ from azure.cosmos.http_constants import ResourceType -# Same invalid UTF-8 used in test_response_decoding.py _INVALID_UTF8 = b'{"note":"hello \xc3\x28 world"}' _VALID_UTF8 = b'{"ok":true}' @@ -46,11 +22,7 @@ def _build_request_args(status_code: int, body: bytes): - """Build the minimal set of mocked dependencies ``_Request`` needs to - reach the decode call site. Returns (args_tuple, mock_response).""" - # ``endpoint_override`` short-circuits endpoint resolution so we do - # not need a real GlobalEndpointManager. ``DatabaseAccount`` skips - # ``refresh_endpoint_list`` for the same reason. + """Builds the smallest set of mocks the request function needs.""" request_params = MagicMock() request_params.healthy_tentative_location = False request_params.resource_type = ResourceType.DatabaseAccount @@ -77,7 +49,7 @@ def _build_request_args(status_code: int, body: bytes): request.url = _FAKE_ENDPOINT + "dbs" request.headers = {} - # The fake pipeline response that _PipelineRunFunction will return. + # Fake HTTP response with the given body and status. mock_response = MagicMock() mock_response.http_response.status_code = status_code mock_response.http_response.headers = {} @@ -90,11 +62,12 @@ def _build_request_args(status_code: int, body: bytes): class TestSyncRequestUsesSharedDecoder(unittest.TestCase): - """Wiring + behavioral tests for ``_synchronized_request._Request``.""" + """Sync request function: uses the shared decoder, and turns + error responses into typed exceptions.""" def test_request_invokes_shared_response_decoder(self): - """Reverting the call site back to ``data.decode('utf-8')`` would - make this test fail. Locks in the wiring.""" + """Checks the sync request function actually calls the shared + decode helper with the response bytes and status.""" args, mock_response = _build_request_args(status_code=200, body=_VALID_UTF8) with patch( @@ -109,11 +82,8 @@ def test_request_invokes_shared_response_decoder(self): mock_decode.assert_called_once_with(_VALID_UTF8, 200, "Read") def test_invalid_utf8_on_404_surfaces_resource_not_found(self): - """Behavioral guarantee: a 404 carrying a malformed-UTF-8 body - must surface as ``CosmosResourceNotFoundError``, not as a - ``UnicodeDecodeError``. Customer error handlers branch on the - typed exception; a decode error here would skip those handlers - entirely.""" + """A 404 with invalid bytes in the body should still come + out as the typed not-found exception, not a decode error.""" args, mock_response = _build_request_args(status_code=404, body=_INVALID_UTF8) with patch( @@ -124,9 +94,8 @@ def test_invalid_utf8_on_404_surfaces_resource_not_found(self): _synchronized_request._Request(*args) def test_invalid_utf8_on_503_surfaces_http_response_error(self): - """Same guarantee for the generic ``status_code >= 400`` branch. - 503 specifically matters: it drives cross-region retry; masking - it with a decode error would stop failover from happening.""" + """A 503 with invalid bytes still comes out as the generic + HTTP error with the right status, not a decode error.""" args, mock_response = _build_request_args(status_code=503, body=_INVALID_UTF8) with patch( @@ -139,7 +108,7 @@ def test_invalid_utf8_on_503_surfaces_http_response_error(self): class TestAsyncRequestUsesSharedDecoder(unittest.TestCase): - """Wiring + behavioral tests for ``_asynchronous_request._Request``.""" + """Async request function: same checks as the sync class.""" def test_request_invokes_shared_response_decoder(self): async def run_test(): @@ -187,21 +156,9 @@ async def run_test(): class TestRequestWrapsResidualUnicodeDecodeErrorAsDecodeError(unittest.TestCase): - """For a successful (2xx) response carrying invalid UTF-8 in default - strict mode, ``_Request`` must surface the failure as - ``azure.core.exceptions.DecodeError`` — not as a stdlib - ``UnicodeDecodeError``. This contract matches the existing JSON-parse - failure path (which also raises ``DecodeError``) and keeps the wire - truth intact: - - * ``e.response.status_code`` is the real wire status (e.g. 200) — - no synthetic 400, no faked sub-status. - * ``e.__cause__`` is the original ``UnicodeDecodeError`` so - operators can still see the byte offset and the env-var hint. - - Customer middleware keyed on ``HttpResponseError`` / - ``CosmosHttpResponseError`` continues to work because ``DecodeError`` - is a subclass of ``HttpResponseError``.""" + """A 200 response with invalid bytes (in default strict mode) + should be raised as DecodeError, keeping the real wire status and + the original cause attached.""" def test_sync_2xx_with_invalid_utf8_raises_decode_error(self): args, mock_response = _build_request_args(status_code=200, body=_INVALID_UTF8) diff --git a/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py index 174273515e0f..facc1f704771 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py +++ b/sdk/cosmos/azure-cosmos/tests/test_response_decoding.py @@ -1,15 +1,9 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -"""Tests for the response-body UTF-8 decode helper and its env-var-driven -fallback behavior. Covers the healthy path (strict decode succeeds), the -default behavior when the env var is unset (strict decode fails with an -actionable hint), and the opt-in REPLACE / IGNORE modes. - -The helper reads the env var per-call on the decode-failure path, so -tests just mutate ``os.environ`` (under ``mock.patch.dict``) and call -``decode_response_body`` directly — no cache reset needed. -""" +"""Tests for the response-body UTF-8 decode helper. Covers the default +behavior, the opt-in REPLACE / IGNORE modes, and how the env var that +controls them is parsed.""" # cspell:ignore ufffd import json import os @@ -19,9 +13,9 @@ from azure.cosmos import _response_decoding -# A small payload containing one valid 2-byte UTF-8 sequence followed by a -# byte (0xC3 followed by 0x28) that is not a valid UTF-8 continuation byte. -# `\xC3\x28` is the textbook example of an invalid UTF-8 sequence. +# Small payload with an invalid byte sequence: 0xC3 marks the start of +# a two-byte UTF-8 character, but the next byte (0x28) is not a valid +# continuation byte, so this can't be decoded as UTF-8. _INVALID_UTF8 = b'{"note":"hello \xc3\x28 world"}' _VALID_UTF8 = b'{"note":"hello world"}' @@ -29,29 +23,23 @@ class _DecoderEnvIsolatedTestCase(unittest.TestCase): - """Base class that isolates each test from the surrounding process - environment by rolling back any env mutations the test makes.""" + """Saves and restores the env var so tests don't leak settings.""" def setUp(self): self._env_patch = mock.patch.dict(os.environ, {}, clear=False) self._env_patch.start() - # Strip the env var for the duration of the test so the helper's - # default behavior (no env -> strict) is the explicit baseline. - # Tests that need a specific env value set it themselves. + # Clear the env var so each test starts from the default state. os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) def tearDown(self): - # `mock.patch.dict` rolls back any env mutations the test made, - # including the pop in setUp. self._env_patch.stop() class TestStrictDecodingHealthyPath(_DecoderEnvIsolatedTestCase): def test_valid_utf8_decodes_unchanged(self): - """The healthy path must produce exactly the same string as - ``bytes.decode('utf-8')``. This is the regression guard for - the 99.99% case where the body is well-formed.""" + """A well-formed payload decodes to the same string as a plain + UTF-8 decode would.""" result = _response_decoding.decode_response_body(_VALID_UTF8) self.assertEqual(result, '{"note":"hello world"}') @@ -62,30 +50,25 @@ def test_empty_bytes_decodes_to_empty_string(self): class TestStrictDecodingRaisesActionableError(_DecoderEnvIsolatedTestCase): def test_invalid_utf8_without_env_var_raises_with_hint(self): - """When the env var is unset (the historical default) the helper - must raise ``UnicodeDecodeError`` so existing call sites continue - to behave the same way. The hint in ``reason`` points the - operator at the env var name so they can self-serve.""" - # setUp already cleared the env var; assert it for the reader. + """With the env var unset, invalid bytes raise. The error + message mentions the env var so users know how to opt in.""" self.assertNotIn(_MALFORMED_INPUT_ENV_VAR, os.environ) with self.assertRaises(UnicodeDecodeError) as ctx: _response_decoding.decode_response_body(_INVALID_UTF8, operation_context="read_item") self.assertIn(_MALFORMED_INPUT_ENV_VAR, ctx.exception.reason) - # Original exception must be chained so callers and log readers - # can still see the underlying decoder error. + # The original error stays chained so callers can still see it. self.assertIsInstance(ctx.exception.__cause__, UnicodeDecodeError) class TestPermissiveFallback(_DecoderEnvIsolatedTestCase): - """Exercises the decode behavior in each fallback mode by setting - the env var and calling ``decode_response_body`` directly.""" + """Checks decode behavior in each fallback mode.""" def test_replace_mode_substitutes_replacement_character(self): os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" result = _response_decoding.decode_response_body(_INVALID_UTF8) - # The bad byte is replaced by U+FFFD; the surrounding text is preserved. + # The bad byte becomes the replacement character; the rest stays. self.assertIn("\ufffd", result) self.assertIn("hello", result) self.assertIn("world", result) @@ -93,16 +76,14 @@ def test_replace_mode_substitutes_replacement_character(self): def test_ignore_mode_drops_bad_bytes(self): os.environ[_MALFORMED_INPUT_ENV_VAR] = "IGNORE" result = _response_decoding.decode_response_body(_INVALID_UTF8) - # No replacement character; the bad byte is silently dropped. + # The bad byte is dropped instead of replaced. self.assertNotIn("\ufffd", result) self.assertIn("hello", result) self.assertIn("world", result) class TestEnvVarParser(_DecoderEnvIsolatedTestCase): - """Unit tests for ``_resolve_fallback_mode_from_env`` in isolation. - Each test sets the env var and asserts the parsed mode matches the - documented mapping.""" + """Tests for how the env var value is read and interpreted.""" def test_replace_env_value_resolves_to_replace_mode(self): os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" @@ -113,8 +94,7 @@ def test_ignore_env_value_resolves_to_ignore_mode(self): self.assertEqual(_response_decoding._resolve_fallback_mode_from_env(), "ignore") def test_unknown_env_value_resolves_to_strict(self): - """Anything other than REPLACE / IGNORE (case-insensitive) must - leave strict decoding in effect.""" + """Any value other than REPLACE or IGNORE keeps strict decoding.""" os.environ[_MALFORMED_INPUT_ENV_VAR] = "BOGUS" self.assertIsNone(_response_decoding._resolve_fallback_mode_from_env()) @@ -127,106 +107,125 @@ def test_unset_env_resolves_to_strict(self): self.assertNotIn(_MALFORMED_INPUT_ENV_VAR, os.environ) self.assertIsNone(_response_decoding._resolve_fallback_mode_from_env()) + # The cases below pin down what happens for env-var values that + # are neither REPLACE nor IGNORE. Each one must resolve to strict + # so accidental or typo'd values don't silently change behavior. + + def test_empty_string_env_value_resolves_to_strict(self): + """An empty value (common shell typo: VAR= ) must stay strict.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "" + self.assertIsNone(_response_decoding._resolve_fallback_mode_from_env()) + + def test_whitespace_only_env_value_resolves_to_strict(self): + """Values that are only spaces or newlines must stay strict.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = " \t\n" + self.assertIsNone(_response_decoding._resolve_fallback_mode_from_env()) + + def test_report_env_value_resolves_to_strict(self): + """REPORT isn't one of the accepted values, so it stays strict.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPORT" + self.assertIsNone(_response_decoding._resolve_fallback_mode_from_env()) + + def test_comma_separated_env_value_resolves_to_strict(self): + """The parser does not split on commas. The whole value is + treated as one token and stays strict if it isn't recognized.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE,IGNORE" + self.assertIsNone(_response_decoding._resolve_fallback_mode_from_env()) + + def test_mixed_case_replace_resolves_to_replace_mode(self): + """Values are matched case-insensitively.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "Replace" + self.assertEqual(_response_decoding._resolve_fallback_mode_from_env(), "replace") + class TestEnvVarToBehaviorEndToEnd(_DecoderEnvIsolatedTestCase): - """Verifies the full contract: setting the env var actually changes - what ``decode_response_body`` does, and clearing it returns to - strict. Catches regressions where the env parser and the per-call - read drift apart.""" + """Checks that setting the env var actually changes what the + decode helper does, and clearing it goes back to the default.""" def test_setting_replace_env_var_makes_invalid_utf8_decode_succeed(self): - # Baseline: with no env var, the same input raises. + # With no env var the same input raises. with self.assertRaises(UnicodeDecodeError): _response_decoding.decode_response_body(_INVALID_UTF8) - # Opt in via the env var and prove the same input now decodes to - # a replacement-character-bearing string instead of raising. + # Set REPLACE and the same input now decodes successfully. os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" result = _response_decoding.decode_response_body(_INVALID_UTF8) self.assertIn("\ufffd", result) def test_clearing_env_var_returns_to_strict(self): - # Opt in. os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" self.assertEqual(_response_decoding._resolve_fallback_mode_from_env(), "replace") - # Opt out by removing the var; next decode raises again. + # Remove the var and the next decode raises again. del os.environ[_MALFORMED_INPUT_ENV_VAR] with self.assertRaises(UnicodeDecodeError): _response_decoding.decode_response_body(_INVALID_UTF8) class TestDecodeForStatus(_DecoderEnvIsolatedTestCase): - """Tests ``decode_response_body_for_status`` — the wrapper that - HTTP request paths use so a malformed-UTF-8 error body does not - mask the real status-code exception. The SDK's retry/refresh logic - and customer error handlers branch on status code, so a 404, 410 - (partition split), 429 (throttle), or 503 must surface as the - correct typed exception even when the body has invalid bytes.""" + """Tests the wrapper that lets error responses still produce a + typed exception even when the response body has invalid bytes.""" def test_valid_utf8_success_passes_through(self): - """Healthy path: 2xx with well-formed body decodes normally.""" + """A 200 with a well-formed body decodes normally.""" result = _response_decoding.decode_response_body_for_status( _VALID_UTF8, status_code=200 ) self.assertEqual(result, '{"note":"hello world"}') def test_invalid_utf8_on_2xx_still_raises(self): - """A successful response with malformed bytes is a data-integrity - problem the caller needs to see — do not silently paper over it.""" + """A successful response with invalid bytes still raises so + the caller is informed about the data problem.""" with self.assertRaises(UnicodeDecodeError): _response_decoding.decode_response_body_for_status( _INVALID_UTF8, status_code=200 ) def test_invalid_utf8_on_404_does_not_raise(self): - """404 with malformed body must decode best-effort so callers - receive ``CosmosResourceNotFoundError`` instead of a confusing - ``UnicodeDecodeError``.""" + """A 404 with invalid bytes decodes best-effort so the + caller still gets the typed not-found error.""" result = _response_decoding.decode_response_body_for_status( _INVALID_UTF8, status_code=404 ) - # The bad byte is replaced; surrounding text is preserved so the - # error message remains human-readable. + # The bad byte is replaced; the rest of the text is preserved. self.assertIn("\ufffd", result) self.assertIn("hello", result) self.assertIn("world", result) def test_invalid_utf8_on_throttle_does_not_raise(self): - """429 carries the retry-after signal the SDK's throttle handler - depends on; it must not be masked by a decode error.""" + """A 429 with invalid bytes still decodes so the throttle + handler can see it.""" result = _response_decoding.decode_response_body_for_status( _INVALID_UTF8, status_code=429 ) self.assertIn("\ufffd", result) def test_invalid_utf8_on_partition_gone_does_not_raise(self): - """410 is the partition-split signal that triggers partition-map - refresh; masking it would break split recovery.""" + """A 410 with invalid bytes still decodes so the partition + refresh logic can run.""" result = _response_decoding.decode_response_body_for_status( _INVALID_UTF8, status_code=410 ) self.assertIn("\ufffd", result) def test_invalid_utf8_on_service_unavailable_does_not_raise(self): - """503 drives cross-region retry; masking it makes the SDK give - up instead of failing over.""" + """A 503 with invalid bytes still decodes so the retry logic + can run.""" result = _response_decoding.decode_response_body_for_status( _INVALID_UTF8, status_code=503 ) self.assertIn("\ufffd", result) def test_boundary_399_still_raises(self): - """The wrapper opens up best-effort decode at exactly 400 and - above. 399 (unused in HTTP today, but covers 3xx redirects) - must still raise — same reason as 2xx.""" + """Best-effort decode only kicks in at 400 and above. A 399 + status still raises like a successful response would.""" with self.assertRaises(UnicodeDecodeError): _response_decoding.decode_response_body_for_status( _INVALID_UTF8, status_code=399 ) def test_boundary_400_does_not_raise(self): - """Confirms the threshold is inclusive at 400 (Bad Request).""" + """Confirms 400 is included in the best-effort range.""" result = _response_decoding.decode_response_body_for_status( _INVALID_UTF8, status_code=400 ) @@ -240,10 +239,8 @@ def test_empty_body_decodes_to_empty_string_regardless_of_status(self): ) def test_fallback_env_var_handles_2xx_before_status_check_kicks_in(self): - """When the operator has opted in via the env var, the inner - ``decode_response_body`` already succeeds (with replacement), so - a 2xx with malformed bytes also decodes successfully. The - wrapper's status-code branch never runs in that case.""" + """When the env var is set, a 200 with invalid bytes decodes + too because the env-var path runs before the status check.""" os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" result = _response_decoding.decode_response_body_for_status( _INVALID_UTF8, status_code=200 @@ -252,42 +249,23 @@ def test_fallback_env_var_handles_2xx_before_status_check_kicks_in(self): class TestPermissiveFallbackJsonPipeline(_DecoderEnvIsolatedTestCase): - """End-to-end tests covering the decode-then-``json.loads`` pipeline - every caller of ``decode_response_body`` runs. - - These tests document an important operator-facing trade-off of - enabling permissive fallback (``REPLACE``): malformed bytes inside - a JSON *string value* become silently-corrupted Python str values - after parsing (``"\\ufffd"`` ends up in the data), while malformed - bytes that land on JSON *structural* characters cause the parse - step to fail with ``json.JSONDecodeError`` — which the SDK's - request path catches and surfaces as a typed ``DecodeError``. - - Assertions are intentionally outcome-shaped (does parse succeed? - does the value contain U+FFFD?) and avoid asserting exact error - messages or byte offsets, so CPython upgrades do not break us. - Mirrors the coverage in the Java SDK's MalformedResponseTests.""" - - # Bad bytes (`\xc3\x28`) inside the JSON string value `"caf??"`. - # After REPLACE decode the body is well-formed JSON whose value - # happens to contain U+FFFD — json.loads succeeds. + """Checks what happens when the decoded string is then parsed as + JSON. Bad bytes inside a string value parse cleanly; bad bytes + that land where a JSON structural character was expected do not. + """ + + # Bad bytes inside a JSON string value. After REPLACE the body is + # well-formed JSON whose string value contains the replacement + # character. _BAD_BYTES_IN_STRING_VALUE = b'{"name":"caf\xc3\x28 dining"}' - # Bad bytes (`\xc3\x28`) placed where the JSON colon delimiter - # should be. After REPLACE decode the colon position contains - # U+FFFD instead — json.loads cannot parse this as an object. + # Bad bytes where the JSON colon should be. After REPLACE the + # body is not parseable JSON. _BAD_BYTES_IN_STRUCTURE = b'{"name"\xc3\x28"value"}' def test_replace_mode_corrupts_string_values_silently(self): - """REPLACE + parse on bad bytes inside a string value: parse - succeeds, the resulting Python str contains U+FFFD. This is - the case operators need to be aware of when enabling REPLACE - — application code receives data with substituted characters - and no signal that the substitution happened. - - Cross-SDK parity note: matches the Java - MalformedResponseTests scenario where a corrupted character - inside a JSON string is silently preserved through parsing.""" + """When the bad bytes were inside a JSON string value, parsing + succeeds and the replacement character ends up in the value.""" os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" decoded = _response_decoding.decode_response_body(self._BAD_BYTES_IN_STRING_VALUE) @@ -296,24 +274,19 @@ def test_replace_mode_corrupts_string_values_silently(self): self.assertIsInstance(parsed, dict) self.assertIn("name", parsed) self.assertIsInstance(parsed["name"], str) - # The replacement character is present in the parsed value; - # the rest of the value text is preserved verbatim. + # The replacement character is in the parsed value; the rest + # of the text is preserved. self.assertIn("\ufffd", parsed["name"]) self.assertIn("caf", parsed["name"]) self.assertIn("dining", parsed["name"]) def test_replace_mode_structural_corruption_raises_json_decode_error(self): - """REPLACE + parse on bad bytes in JSON structure: decode - succeeds, parse raises ``json.JSONDecodeError``. The SDK's - ``_Request`` path catches that and surfaces it as - ``azure.core.exceptions.DecodeError`` (covered by the - ``_Request`` wiring tests). Here we just lock in the seam: - decode produces a string, parse rejects it.""" + """When the bad bytes broke the JSON structure, decode still + succeeds but JSON parsing fails.""" os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" decoded = _response_decoding.decode_response_body(self._BAD_BYTES_IN_STRUCTURE) - # Decode itself does not raise — the byte that broke JSON - # structure has become U+FFFD in the decoded string. + # Decode itself does not raise. self.assertIsInstance(decoded, str) self.assertIn("\ufffd", decoded) diff --git a/sdk/cosmos/azure-cosmos/tests/test_response_decoding_paged.py b/sdk/cosmos/azure-cosmos/tests/test_response_decoding_paged.py new file mode 100644 index 000000000000..273054566af4 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_response_decoding_paged.py @@ -0,0 +1,241 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Tests that paging through one response after another keeps working +when one page has bytes that aren't valid UTF-8 and the user has opted +in to the permissive decode mode via the env var.""" +import asyncio +import os +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from azure.core.exceptions import DecodeError + +from azure.cosmos import _synchronized_request +from azure.cosmos.aio import _asynchronous_request +from azure.cosmos.http_constants import ResourceType + + +_MALFORMED_INPUT_ENV_VAR = "AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" +pytestmark = pytest.mark.cosmosEmulator + +# Page 1 is valid JSON but a string value contains an invalid UTF-8 +# byte sequence. Page 2 is fully well-formed. +_PAGE_WITH_BAD_UTF8_IN_STRING_VALUE = ( + b'{"_rid":"rid","Documents":[{"id":"doc1","x":"caf\xc3\x28 dining"}],"_count":1}' +) +_PAGE_VALID_UTF8 = ( + b'{"_rid":"rid","Documents":[{"id":"doc2","x":"hello"}],"_count":1}' +) + +_FAKE_ENDPOINT = "https://example.documents.azure.com:443/" + + +def _build_request_args(): + """Builds the minimum set of mocks needed for one request call.""" + request_params = MagicMock() + request_params.healthy_tentative_location = False + request_params.resource_type = ResourceType.DatabaseAccount + request_params.read_timeout_override = None + request_params.endpoint_override = _FAKE_ENDPOINT + request_params.should_cancel_request.return_value = False + request_params.operation_type = "ReadFeed" + request_params.availability_strategy = None + + connection_policy = MagicMock() + connection_policy.RequestTimeout = 30 + connection_policy.ReadTimeout = 30 + connection_policy.RecoveryReadTimeout = 5 + connection_policy.DBAReadTimeout = 5 + connection_policy.DBAConnectionTimeout = 5 + connection_policy.SSLConfiguration = None + connection_policy.DisableSSLVerification = False + + global_endpoint_manager = MagicMock() + pipeline_client = MagicMock() + + request = MagicMock() + request.url = _FAKE_ENDPOINT + "dbs/db/colls/c/docs" + request.headers = {} + + return global_endpoint_manager, request_params, connection_policy, pipeline_client, request + + +def _mock_response(body: bytes, status_code: int = 200): + """Builds a fake HTTP response with the given body and status.""" + mock_response = MagicMock() + mock_response.http_response.status_code = status_code + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = body + return mock_response + + +class _DecoderEnvIsolatedTestCase(unittest.TestCase): + """Saves and restores the env var so tests don't leak settings.""" + + def setUp(self): + self._saved = os.environ.get(_MALFORMED_INPUT_ENV_VAR) + os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) + + def tearDown(self): + if self._saved is not None: + os.environ[_MALFORMED_INPUT_ENV_VAR] = self._saved + else: + os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) + + +class TestSyncPagedIterationWithReplace(_DecoderEnvIsolatedTestCase): + """Calls the sync request function once per page, the way an + iterator would pull one page at a time.""" + + def _drive_pages(self, page_bodies): + """Calls the request function once per body and returns the + list of results. A fresh args tuple is built each time because + the function mutates the request object in place.""" + results = [] + for body in page_bodies: + args = _build_request_args() + with patch( + "azure.cosmos._synchronized_request._PipelineRunFunction", + return_value=_mock_response(body, status_code=200), + ): + results.append(_synchronized_request._Request(*args)) + return results + + def test_strict_mode_iteration_aborts_on_corrupt_page(self): + """With the env var unset, the bad page should fail right + away. This is the default behavior.""" + with self.assertRaises(DecodeError): + self._drive_pages([_PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, _PAGE_VALID_UTF8]) + + def test_replace_mode_iteration_completes_past_corrupt_page(self): + """With REPLACE set, both pages should decode and parse. The + bad byte on page 1 becomes a replacement character.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + + results = self._drive_pages([_PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, _PAGE_VALID_UTF8]) + + self.assertEqual(len(results), 2) + + page1_body, _ = results[0] + page2_body, _ = results[1] + + # Page 1 came through with a replacement character in place of + # the bad byte. Surrounding text is preserved. + self.assertIn("Documents", page1_body) + self.assertEqual(page1_body["Documents"][0]["id"], "doc1") + self.assertIn("\ufffd", page1_body["Documents"][0]["x"]) + self.assertIn("caf", page1_body["Documents"][0]["x"]) + self.assertIn("dining", page1_body["Documents"][0]["x"]) + + # Page 2 is normal text. The setting only kicks in when the + # bytes are actually invalid. + self.assertEqual(page2_body["Documents"][0]["id"], "doc2") + self.assertEqual(page2_body["Documents"][0]["x"], "hello") + + def test_ignore_mode_iteration_completes_past_corrupt_page(self): + """With IGNORE set, both pages decode. The bad byte is dropped + instead of being replaced.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "IGNORE" + + results = self._drive_pages([_PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, _PAGE_VALID_UTF8]) + + self.assertEqual(len(results), 2) + page1_body, _ = results[0] + page2_body, _ = results[1] + + # No replacement character; the surrounding text stays intact. + self.assertNotIn("\ufffd", page1_body["Documents"][0]["x"]) + self.assertIn("caf", page1_body["Documents"][0]["x"]) + self.assertEqual(page2_body["Documents"][0]["x"], "hello") + + def test_replace_mode_corrupt_page_does_not_poison_next_request_headers(self): + """Three pages in a row (bad, good, bad) should each decode + based on their own bytes, so any state accidentally shared + between calls would show up here.""" + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + + results = self._drive_pages([ + _PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, + _PAGE_VALID_UTF8, + _PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, + ]) + + self.assertEqual(len(results), 3) + page1_body, _ = results[0] + page2_body, _ = results[1] + page3_body, _ = results[2] + + self.assertIn("\ufffd", page1_body["Documents"][0]["x"]) + self.assertNotIn("\ufffd", page2_body["Documents"][0]["x"]) + self.assertIn("\ufffd", page3_body["Documents"][0]["x"]) + + +class TestAsyncPagedIterationWithReplace(_DecoderEnvIsolatedTestCase): + """Async version of the paged-iteration checks above.""" + + def _drive_pages(self, page_bodies): + async def _run(): + results = [] + for body in page_bodies: + args = _build_request_args() + with patch( + "azure.cosmos.aio._asynchronous_request._PipelineRunFunction", + new=AsyncMock(return_value=_mock_response(body, status_code=200)), + ): + results.append(await _asynchronous_request._Request(*args)) + return results + return asyncio.run(_run()) + + def test_strict_mode_iteration_aborts_on_corrupt_page(self): + with self.assertRaises(DecodeError): + self._drive_pages([_PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, _PAGE_VALID_UTF8]) + + def test_replace_mode_iteration_completes_past_corrupt_page(self): + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + + results = self._drive_pages([_PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, _PAGE_VALID_UTF8]) + + self.assertEqual(len(results), 2) + page1_body, _ = results[0] + page2_body, _ = results[1] + self.assertIn("\ufffd", page1_body["Documents"][0]["x"]) + self.assertEqual(page2_body["Documents"][0]["x"], "hello") + + def test_ignore_mode_iteration_completes_past_corrupt_page(self): + os.environ[_MALFORMED_INPUT_ENV_VAR] = "IGNORE" + + results = self._drive_pages([_PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, _PAGE_VALID_UTF8]) + + self.assertEqual(len(results), 2) + page1_body, _ = results[0] + page2_body, _ = results[1] + self.assertNotIn("\ufffd", page1_body["Documents"][0]["x"]) + self.assertIn("caf", page1_body["Documents"][0]["x"]) + self.assertEqual(page2_body["Documents"][0]["x"], "hello") + + def test_replace_mode_corrupt_page_does_not_poison_next_request_headers(self): + """Three async pages in a row (bad, good, bad) must each decode based on + their own bytes; no decoder state may leak across async requests.""" + # patch.dict scopes the env var to this test and restores the prior + # value when the block exits, even if an assertion below raises. + with patch.dict(os.environ, {_MALFORMED_INPUT_ENV_VAR: "REPLACE"}): + results = self._drive_pages([ + _PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, + _PAGE_VALID_UTF8, + _PAGE_WITH_BAD_UTF8_IN_STRING_VALUE, + ]) + + self.assertEqual(len(results), 3) + page1_body, _ = results[0] + page2_body, _ = results[1] + page3_body, _ = results[2] + + self.assertIn("\ufffd", page1_body["Documents"][0]["x"]) + self.assertNotIn("\ufffd", page2_body["Documents"][0]["x"]) + self.assertIn("\ufffd", page3_body["Documents"][0]["x"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit.py b/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit.py index 17c83535799f..5b3b885667ac 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit.py @@ -27,27 +27,19 @@ ) from azure.cosmos._routing.routing_map_provider import ( PartitionKeyRangeCache, + SmartRoutingMapProvider, ) +from azure.cosmos._routing import routing_range from azure.cosmos._routing.collection_routing_map import CollectionRoutingMap -from azure.cosmos import http_constants +from azure.cosmos import http_constants, _base from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos._gone_retry_policy_base import _PartitionKeyRangeGoneRetryPolicyBase -# ========================================================= -# Test-only tolerant shim for evaluate_drain_page -# ========================================================= -# Production wires ``_internal_response_status_capture`` via ``_Request`` so -# ``evaluate_drain_page`` always receives a concrete HTTP status. These unit -# tests use lightweight MagicMock side_effects that bypass ``_Request`` and -# therefore leave the sidecar at ``[None]``. Rather than retrofit every mock -# to populate the sidecar, default an unknown status to ``304`` (Not Modified) -# so the drain terminates after the first page -- which is exactly the -# termination signal each existing mock relies on (data on the data path, -# ``iter([])`` on the INM-match path). -# -# This shim is the *only* test-side concession to the strict status contract -# introduced in commit a1e27a57bd; production code is unchanged. +# Test-only shim: production wires status via the request pipeline so the drain +# helper always sees a real HTTP status. The lightweight mocks in this file +# bypass that pipeline, so an unknown status defaults to 304 (Not Modified) and +# the drain terminates after the first mocked page. Production code is unchanged. # pylint: disable=wrong-import-position import azure.cosmos._routing._routing_map_provider_common as _drain_common # noqa: E402 import azure.cosmos._routing.routing_map_provider as _sync_provider_module # noqa: E402 @@ -74,9 +66,6 @@ def _tolerant_evaluate_drain_page(*, page_new_etag, current_if_none_match, _async_provider_module.evaluate_drain_page = _tolerant_evaluate_drain_page -# ========================================================= -# Helpers -# ========================================================= def _make_complete_routing_map(collection_id="coll1", etag='"etag-1"'): """Create a minimal but complete CollectionRoutingMap for testing.""" @@ -105,9 +94,6 @@ def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): return client -# ========================================================= -# Test Class -# ========================================================= @pytest.mark.cosmosEmulator class TestRoutingMapProviderUnit(unittest.TestCase): @@ -683,13 +669,9 @@ def read_pk_ranges_cascading(collection_link, options, response_hook=None, **kwa self.assertEqual(result.change_feed_etag, '"etag-old"') - # ========================================================================== - # Helper-level retry-policy unit tests. - # - # These target only the pure helper that computes retry backoff / 503 - # escalation (no cache object, no fetch loop). They check that backoff - # stays within the deterministic upper bound and that jitter is applied. - # ========================================================================== + # Helper-level retry-policy tests: target the pure helper that computes + # retry backoff and 503 escalation (no cache, no fetch loop). They check + # that backoff stays within the deterministic upper bound and jitters. def test_overlap_retry_backoff_is_within_deterministic_upper_bound(self): """Each non-terminal attempt's backoff must lie in @@ -842,14 +824,9 @@ def test_overlap_retry_raises_503_at_attempt_budget_exhaustion(self): http_constants.StatusCodes.SERVICE_UNAVAILABLE, ) - # ========================================================================== - # Provider retry-loop behavior tests (mocked integration path). - # - # These exercise the sync provider's fetch/retry loop with mocked - # ``/pkranges`` payloads: transient inconsistencies either recover on - # retry or surface as HTTP 503; ``ValueError("Ranges overlap")`` never - # leaks to callers. - # ========================================================================== + # Provider retry-loop tests: exercise the sync provider's fetch and retry + # loop with mocked /pkranges payloads. Inconsistent snapshots either recover + # on retry or surface as HTTP 503; overlap errors never leak to callers. def test_fetch_routing_map_recovers_after_transient_overlap(self): """An inconsistent ``/pkranges`` snapshot followed by a consistent @@ -950,6 +927,11 @@ def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): "Persistent overlap must surface as HTTP 503 (transient), not as a bare ValueError " "or as a silent empty-result return." ) + self.assertEqual( + ctx.exception.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + "503 from a transient pkranges inconsistency must set sub_status to 21015.", + ) self.assertEqual( call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS, "Should have made exactly _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS fetch attempts before giving up." @@ -1044,6 +1026,11 @@ def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): cache.get_routing_map("dbs/db1/colls/coll1", feed_options={}) self.assertEqual(ctx.exception.status_code, http_constants.StatusCodes.SERVICE_UNAVAILABLE) + self.assertEqual( + ctx.exception.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + "503 from a persistent gap must set sub_status to 21015.", + ) self.assertEqual( call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS, "Should have made exactly _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS fetch attempts before giving up." @@ -1140,6 +1127,11 @@ def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): "Alternating overlap/gap signals must still surface as HTTP 503 once " "the shared budget is exhausted." ) + self.assertEqual( + ctx.exception.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + "Sub-status must be 21015 whether the budget was exhausted by overlaps, gaps, or both.", + ) self.assertEqual( call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS, "Overlap and gap signals must share one retry budget; alternating " @@ -1190,6 +1182,11 @@ def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): ) self.assertEqual(ctx.exception.status_code, http_constants.StatusCodes.SERVICE_UNAVAILABLE) + self.assertEqual( + ctx.exception.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + "503 from a forced refresh that exhausts the retry budget must set sub_status to 21015.", + ) # Critical invariant: the previously-cached map must still be reachable # via the same key. A 503 from a forced refresh must never evict good @@ -1206,6 +1203,201 @@ def fake_read_pk_ranges(collection_link, options, response_hook=None, **kwargs): "Cached ETag must remain the pre-503 value (no partial overwrite)." ) + # End-to-end tests that go through SmartRoutingMapProvider.get_overlapping_ranges + # to confirm overlap/gap errors from the cache never reach the caller as + # bare ValueError or AssertionError. They surface as 503 instead. + + _OVERLAP_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': '10', 'minInclusive': '80', 'maxExclusive': 'A0'}, + {'id': '10/0', 'minInclusive': '80', 'maxExclusive': '90'}, + {'id': '10/1', 'minInclusive': '90', 'maxExclusive': 'A0'}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + _GAP_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + _GOOD_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': '10/0', 'minInclusive': '80', 'maxExclusive': '90', 'parents': ['10']}, + {'id': '10/1', 'minInclusive': '90', 'maxExclusive': 'A0', 'parents': ['10']}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + + @staticmethod + def _make_sequenced_pk_ranges_client(response_sequence): + """Return a mock client that returns the next payload from + response_sequence on each fresh read, and an empty page when the + If-None-Match matches the last etag (acts like a 304 reply). + """ + call_count = {'n': 0} + last_etag = {'v': None} + client = MagicMock() + + def fake_read_pk_ranges(_collection_link, _options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return iter([]) + payload = (response_sequence[call_count['n']] + if call_count['n'] < len(response_sequence) + else response_sequence[-1]) + call_count['n'] += 1 + etag = '"etag-{}"'.format(call_count['n']) + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag + if response_hook: + response_hook(headers, None) + capture_headers = kwargs.get('_internal_response_headers_capture') + if capture_headers is not None: + capture_headers.update(headers) + return iter(payload) + + client._ReadPartitionKeyRanges = MagicMock(side_effect=fake_read_pk_ranges) + return client, call_count, last_etag + + def _assert_is_routing_map_snapshot_503(self, exc): + """Assert ``exc`` is a CosmosHttpResponseError with status 503 and + sub_status 21015, and is not a ValueError or AssertionError.""" + self.assertIsInstance(exc, CosmosHttpResponseError) + self.assertNotIsInstance(exc, AssertionError) + self.assertFalse(isinstance(exc, ValueError)) + self.assertEqual(exc.status_code, http_constants.StatusCodes.SERVICE_UNAVAILABLE) + self.assertEqual( + exc.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + ) + + def test_smart_get_overlapping_ranges_no_bare_value_error_on_persistent_overlap(self): + """A persistent overlap response must raise 503 from + SmartRoutingMapProvider.get_overlapping_ranges, not a ValueError.""" + client, call_count, _ = self._make_sequenced_pk_ranges_client( + [self._OVERLAP_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + with patch('azure.cosmos._routing.routing_map_provider.time.sleep', return_value=None): + with self.assertRaises(CosmosHttpResponseError) as ctx: + provider.get_overlapping_ranges("dbs/db1/colls/coll1", [full_range]) + + self._assert_is_routing_map_snapshot_503(ctx.exception) + self.assertEqual(call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS) + + def test_smart_get_overlapping_ranges_no_bare_assertion_error_on_persistent_gap(self): + """A persistent gap response must raise 503 from + SmartRoutingMapProvider.get_overlapping_ranges, not an AssertionError.""" + client, call_count, _ = self._make_sequenced_pk_ranges_client( + [self._GAP_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + with patch('azure.cosmos._routing.routing_map_provider.time.sleep', return_value=None): + with self.assertRaises(CosmosHttpResponseError) as ctx: + provider.get_overlapping_ranges("dbs/db1/colls/coll1", [full_range]) + + self._assert_is_routing_map_snapshot_503(ctx.exception) + self.assertEqual(call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS) + + def test_smart_get_overlapping_ranges_recovers_after_transient_overlap(self): + """One bad overlap response followed by a good one must return the + expected ranges from get_overlapping_ranges.""" + client, call_count, _ = self._make_sequenced_pk_ranges_client( + [self._OVERLAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + with patch('azure.cosmos._routing.routing_map_provider.time.sleep', return_value=None): + overlapping = provider.get_overlapping_ranges( + "dbs/db1/colls/coll1", [full_range] + ) + + self.assertEqual(call_count['n'], 2) + ids = [r['id'] for r in overlapping] + self.assertEqual(ids, ['L', '10/0', '10/1', 'R']) + + def test_smart_get_overlapping_ranges_recovers_after_transient_gap(self): + """One bad gap response followed by a good one must return the + expected ranges from get_overlapping_ranges.""" + client, call_count, _ = self._make_sequenced_pk_ranges_client( + [self._GAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + with patch('azure.cosmos._routing.routing_map_provider.time.sleep', return_value=None): + overlapping = provider.get_overlapping_ranges( + "dbs/db1/colls/coll1", [full_range] + ) + + self.assertEqual(call_count['n'], 2) + ids = [r['id'] for r in overlapping] + self.assertEqual(ids, ['L', '10/0', '10/1', 'R']) + + def test_cache_etag_advances_to_good_response_after_overlap_recovery(self): + """After recovery, the cached ETag matches the good response, and a + second call returns the same cached object without re-fetching.""" + client, call_count, _ = self._make_sequenced_pk_ranges_client( + [self._OVERLAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + cache = PartitionKeyRangeCache(client) + collection_link = "dbs/db1/colls/coll1" + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + + with patch('azure.cosmos._routing.routing_map_provider.time.sleep', return_value=None): + first = cache.get_routing_map(collection_link, feed_options={}) + + self.assertIsNotNone(first) + self.assertEqual(first.change_feed_etag, '"etag-2"') + self.assertEqual(call_count['n'], 2) + + second = cache.get_routing_map(collection_link, feed_options={}) + self.assertIs(second, first) + self.assertEqual(call_count['n'], 2) + self.assertIs(cache._collection_routing_map_by_item[collection_id], first) + + def test_concurrent_callers_see_single_recovery_not_multiple_503s(self): + """With several threads calling at the same time, only one drives the + bad-then-good recovery and the others read the recovered map from cache.""" + client, call_count, _ = self._make_sequenced_pk_ranges_client( + [self._OVERLAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + n_workers = 5 + barrier = threading.Barrier(n_workers, timeout=10) + results = [None] * n_workers + errors = [] + + def worker(idx): + try: + barrier.wait() + results[idx] = provider.get_overlapping_ranges( + "dbs/db1/colls/coll1", [full_range] + ) + except Exception as e: # pylint: disable=broad-except + errors.append(e) + + with patch('azure.cosmos._routing.routing_map_provider.time.sleep', return_value=None): + threads = [threading.Thread(target=worker, args=(i,)) for i in range(n_workers)] + for t in threads: + t.start() + for t in threads: + t.join(timeout=15) + + self.assertEqual(errors, [], f"Workers must not raise; got: {errors}") + self.assertEqual(call_count['n'], 2) + for i, r in enumerate(results): + self.assertIsNotNone(r, f"Worker {i} returned None.") + first_ids = [pkr['id'] for pkr in results[0]] + self.assertEqual(first_ids, ['L', '10/0', '10/1', 'R']) + for i in range(1, n_workers): + self.assertEqual([pkr['id'] for pkr in results[i]], first_ids) + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit_async.py b/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit_async.py index 107d00bef165..4582b0b89923 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_routing_map_provider_unit_async.py @@ -15,32 +15,24 @@ from azure.cosmos._routing.aio.routing_map_provider import ( PartitionKeyRangeCache, + SmartRoutingMapProvider, ) +from azure.cosmos._routing import routing_range from azure.cosmos._routing.collection_routing_map import CollectionRoutingMap from azure.cosmos._routing._routing_map_provider_common import ( process_fetched_ranges, _IncrementalMergeFailed, _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS, ) -from azure.cosmos import http_constants +from azure.cosmos import http_constants, _base from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos._gone_retry_policy_base import _PartitionKeyRangeGoneRetryPolicyBase -# ========================================================= -# Test-only tolerant shim for evaluate_drain_page -# ========================================================= -# Production wires ``_internal_response_status_capture`` via ``_Request`` so -# ``evaluate_drain_page`` always receives a concrete HTTP status. These unit -# tests use lightweight MagicMock side_effects that bypass ``_Request`` and -# therefore leave the sidecar at ``[None]``. Rather than retrofit every mock -# to populate the sidecar, default an unknown status to ``304`` (Not Modified) -# so the drain terminates after the first page -- which is exactly the -# termination signal each existing mock relies on (data on the data path, -# ``iter([])`` on the INM-match path). -# -# This shim is the *only* test-side concession to the strict status contract -# introduced in commit a1e27a57bd; production code is unchanged. +# Test-only shim: production wires status via the request pipeline so the drain +# helper always sees a real HTTP status. The lightweight mocks in this file +# bypass that pipeline, so an unknown status defaults to 304 (Not Modified) and +# the drain terminates after the first mocked page. Production code is unchanged. # pylint: disable=wrong-import-position import azure.cosmos._routing._routing_map_provider_common as _drain_common # noqa: E402 import azure.cosmos._routing.routing_map_provider as _sync_provider_module # noqa: E402 @@ -556,14 +548,9 @@ async def async_gen(): self.assertEqual(ids, ['4', '5', '3', '1']) self.assertEqual(result.change_feed_etag, '"etag-old"') - # ========================================================================== - # Provider retry-loop behavior tests (mocked integration path). - # - # These exercise the async provider's fetch/retry loop with mocked - # ``/pkranges`` payloads: transient inconsistencies either recover on - # retry or surface as HTTP 503; ``ValueError("Ranges overlap")`` never - # leaks to callers. - # ========================================================================== + # Provider retry-loop tests: exercise the async provider's fetch and retry + # loop with mocked /pkranges payloads. Inconsistent snapshots either recover + # on retry or surface as HTTP 503; overlap errors never leak to callers. async def test_fetch_routing_map_recovers_after_transient_overlap_async(self): """An inconsistent ``/pkranges`` snapshot followed by a consistent @@ -692,6 +679,11 @@ async def _no_sleep(_seconds): "Persistent overlap must surface as HTTP 503 (transient), not as a bare ValueError " "or as a silent empty-result return." ) + self.assertEqual( + ctx.exception.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + "503 from a transient pkranges inconsistency must set sub_status to 21015.", + ) # We should have exhausted the full retry budget (3 attempts by default). self.assertEqual( call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS, @@ -806,6 +798,11 @@ async def _no_sleep(_seconds): await cache.get_routing_map("dbs/db1/colls/coll1", feed_options={}) self.assertEqual(ctx.exception.status_code, http_constants.StatusCodes.SERVICE_UNAVAILABLE) + self.assertEqual( + ctx.exception.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + "503 from a persistent gap must set sub_status to 21015.", + ) self.assertEqual(call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS) async def test_incremental_overlap_converts_to_incremental_merge_failed_async(self): @@ -910,6 +907,11 @@ async def _no_sleep(_seconds): "Alternating overlap/gap signals must still surface as HTTP 503 once " "the shared budget is exhausted." ) + self.assertEqual( + ctx.exception.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + "Sub-status must be 21015 whether the budget was exhausted by overlaps, gaps, or both.", + ) self.assertEqual( call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS, "Overlap and gap signals must share one retry budget; alternating " @@ -967,6 +969,11 @@ async def _no_sleep(_seconds): ) self.assertEqual(ctx.exception.status_code, http_constants.StatusCodes.SERVICE_UNAVAILABLE) + self.assertEqual( + ctx.exception.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + "503 from a forced refresh that exhausts the retry budget must set sub_status to 21015.", + ) # Critical invariant: the previously-cached map must still be reachable # via the same key. A 503 from a forced refresh must never evict good @@ -983,6 +990,215 @@ async def _no_sleep(_seconds): "Cached ETag must remain the pre-503 value (no partial overwrite)." ) + # End-to-end tests that go through SmartRoutingMapProvider.get_overlapping_ranges + # to confirm overlap/gap errors from the cache never reach the caller as + # bare ValueError or AssertionError. They surface as 503 instead. + + _OVERLAP_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': '10', 'minInclusive': '80', 'maxExclusive': 'A0'}, + {'id': '10/0', 'minInclusive': '80', 'maxExclusive': '90'}, + {'id': '10/1', 'minInclusive': '90', 'maxExclusive': 'A0'}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + _GAP_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + _GOOD_PAYLOAD = [ + {'id': 'L', 'minInclusive': '', 'maxExclusive': '80'}, + {'id': '10/0', 'minInclusive': '80', 'maxExclusive': '90', 'parents': ['10']}, + {'id': '10/1', 'minInclusive': '90', 'maxExclusive': 'A0', 'parents': ['10']}, + {'id': 'R', 'minInclusive': 'A0', 'maxExclusive': 'FF'}, + ] + + @staticmethod + def _make_sequenced_pk_ranges_async_client(response_sequence): + """Return a mock async client that returns the next payload from + response_sequence on each fresh read, and an empty async generator + when the If-None-Match matches the last etag (acts like a 304 reply). + """ + call_count = {'n': 0} + last_etag = {'v': None} + client = MagicMock() + + def fake_read_pk_ranges(_collection_link, _options, response_hook=None, **kwargs): + headers_in = kwargs.get('headers') or {} + inm = headers_in.get(http_constants.HttpHeaders.IfNoneMatch) + if inm is not None and inm == last_etag['v']: + return _empty_async_gen() + payload = (response_sequence[call_count['n']] + if call_count['n'] < len(response_sequence) + else response_sequence[-1]) + call_count['n'] += 1 + etag = '"etag-{}"'.format(call_count['n']) + headers = {http_constants.HttpHeaders.ETag: etag} + last_etag['v'] = etag + if response_hook: + response_hook(headers, None) + capture_headers = kwargs.get('_internal_response_headers_capture') + if capture_headers is not None: + capture_headers.update(headers) + + async def async_gen(): + for r in payload: + yield r + + return async_gen() + + client._ReadPartitionKeyRanges = MagicMock(side_effect=fake_read_pk_ranges) + return client, call_count, last_etag + + @staticmethod + async def _no_sleep(_seconds): + return None + + def _assert_is_routing_map_snapshot_503(self, exc): + """Assert ``exc`` is a CosmosHttpResponseError with status 503 and + sub_status 21015, and is not a ValueError or AssertionError.""" + self.assertIsInstance(exc, CosmosHttpResponseError) + self.assertNotIsInstance(exc, AssertionError) + self.assertFalse(isinstance(exc, ValueError)) + self.assertEqual(exc.status_code, http_constants.StatusCodes.SERVICE_UNAVAILABLE) + self.assertEqual( + exc.sub_status, + http_constants.SubStatusCodes.ROUTING_MAP_SNAPSHOT_INCONSISTENT, + ) + + async def test_smart_get_overlapping_ranges_no_bare_value_error_on_persistent_overlap_async(self): + """A persistent overlap response must raise 503 from + SmartRoutingMapProvider.get_overlapping_ranges, not a ValueError.""" + client, call_count, _ = self._make_sequenced_pk_ranges_async_client( + [self._OVERLAP_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + with patch( + 'azure.cosmos._routing.aio.routing_map_provider.asyncio.sleep', + new=self._no_sleep, + ): + with self.assertRaises(CosmosHttpResponseError) as ctx: + await provider.get_overlapping_ranges("dbs/db1/colls/coll1", [full_range]) + + self._assert_is_routing_map_snapshot_503(ctx.exception) + self.assertEqual(call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS) + + async def test_smart_get_overlapping_ranges_no_bare_assertion_error_on_persistent_gap_async(self): + """A persistent gap response must raise 503 from + SmartRoutingMapProvider.get_overlapping_ranges, not an AssertionError.""" + client, call_count, _ = self._make_sequenced_pk_ranges_async_client( + [self._GAP_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + with patch( + 'azure.cosmos._routing.aio.routing_map_provider.asyncio.sleep', + new=self._no_sleep, + ): + with self.assertRaises(CosmosHttpResponseError) as ctx: + await provider.get_overlapping_ranges("dbs/db1/colls/coll1", [full_range]) + + self._assert_is_routing_map_snapshot_503(ctx.exception) + self.assertEqual(call_count['n'], _TRANSIENT_SNAPSHOT_RETRY_MAX_ATTEMPTS) + + async def test_smart_get_overlapping_ranges_recovers_after_transient_overlap_async(self): + """One bad overlap response followed by a good one must return the + expected ranges from get_overlapping_ranges.""" + client, call_count, _ = self._make_sequenced_pk_ranges_async_client( + [self._OVERLAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + with patch( + 'azure.cosmos._routing.aio.routing_map_provider.asyncio.sleep', + new=self._no_sleep, + ): + overlapping = await provider.get_overlapping_ranges( + "dbs/db1/colls/coll1", [full_range] + ) + + self.assertEqual(call_count['n'], 2) + ids = [r['id'] for r in overlapping] + self.assertEqual(ids, ['L', '10/0', '10/1', 'R']) + + async def test_smart_get_overlapping_ranges_recovers_after_transient_gap_async(self): + """One bad gap response followed by a good one must return the + expected ranges from get_overlapping_ranges.""" + client, call_count, _ = self._make_sequenced_pk_ranges_async_client( + [self._GAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + + with patch( + 'azure.cosmos._routing.aio.routing_map_provider.asyncio.sleep', + new=self._no_sleep, + ): + overlapping = await provider.get_overlapping_ranges( + "dbs/db1/colls/coll1", [full_range] + ) + + self.assertEqual(call_count['n'], 2) + ids = [r['id'] for r in overlapping] + self.assertEqual(ids, ['L', '10/0', '10/1', 'R']) + + async def test_cache_etag_advances_to_good_response_after_overlap_recovery_async(self): + """After recovery, the cached ETag matches the good response, and a + second call returns the same cached object without re-fetching.""" + client, call_count, _ = self._make_sequenced_pk_ranges_async_client( + [self._OVERLAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + cache = PartitionKeyRangeCache(client) + collection_link = "dbs/db1/colls/coll1" + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + + with patch( + 'azure.cosmos._routing.aio.routing_map_provider.asyncio.sleep', + new=self._no_sleep, + ): + first = await cache.get_routing_map(collection_link, feed_options={}) + + self.assertIsNotNone(first) + self.assertEqual(first.change_feed_etag, '"etag-2"') + self.assertEqual(call_count['n'], 2) + + second = await cache.get_routing_map(collection_link, feed_options={}) + self.assertIs(second, first) + self.assertEqual(call_count['n'], 2) + self.assertIs(cache._collection_routing_map_by_item[collection_id], first) + + async def test_concurrent_callers_see_single_recovery_not_multiple_503s_async(self): + """With several coroutines calling at the same time, only one drives the + bad-then-good recovery and the others read the recovered map from cache.""" + client, call_count, _ = self._make_sequenced_pk_ranges_async_client( + [self._OVERLAP_PAYLOAD, self._GOOD_PAYLOAD] + ) + provider = SmartRoutingMapProvider(client) + full_range = routing_range.Range("", "FF", True, False) + n_workers = 5 + + async def worker(): + return await provider.get_overlapping_ranges( + "dbs/db1/colls/coll1", [full_range] + ) + + with patch( + 'azure.cosmos._routing.aio.routing_map_provider.asyncio.sleep', + new=self._no_sleep, + ): + results = await asyncio.gather(*[worker() for _ in range(n_workers)]) + + self.assertEqual(call_count['n'], 2) + for i, r in enumerate(results): + self.assertIsNotNone(r, f"Worker {i} returned None.") + first_ids = [pkr['id'] for pkr in results[0]] + self.assertEqual(first_ids, ['L', '10/0', '10/1', 'R']) + for i in range(1, n_workers): + self.assertEqual([pkr['id'] for pkr in results[i]], first_ids) + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py b/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py index e0b63ba30dcd..29bc93137af5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_semantic_reranker_unit.py @@ -3,15 +3,25 @@ # cspell:ignore rerank reranker reranking """Unit tests for semantic reranker inference service timeout policy.""" import asyncio +import json import os +import threading import unittest from unittest.mock import MagicMock, patch -from azure.core.exceptions import ServiceRequestError, ServiceResponseError +from azure.core.exceptions import DecodeError, ServiceRequestError, ServiceResponseError import azure.cosmos.exceptions as exceptions +from azure.cosmos._cosmos_client_connection import CosmosClientConnection as _SyncCosmosClientConnection +from azure.cosmos._inference_service import _InferenceService as _SyncInferenceService +from azure.cosmos.aio._cosmos_client_connection_async import ( + CosmosClientConnection as _AsyncCosmosClientConnection, +) +from azure.cosmos.aio._inference_service_async import _InferenceService as _AsyncInferenceService +from azure.cosmos.documents import ConnectionPolicy _INFERENCE_ENDPOINT_ENV_VAR = "AZURE_COSMOS_SEMANTIC_RERANKER_INFERENCE_ENDPOINT" +_MALFORMED_INPUT_ENV_VAR = "AZURE_COSMOS_CHARSET_DECODER_ERROR_ACTION_ON_MALFORMED_INPUT" class TestInferenceServiceTimeout(unittest.TestCase): @@ -42,10 +52,8 @@ def _create_mock_connection(self, inference_request_timeout=5): def test_sync_inference_timeout_raises_408(self): """Test that sync inference service converts ServiceRequestError to 408.""" - from azure.cosmos._inference_service import _InferenceService - mock_connection = self._create_mock_connection() - service = _InferenceService(mock_connection) + service = _SyncInferenceService(mock_connection) with patch.object( service._inference_pipeline_client._pipeline, "run", @@ -62,11 +70,9 @@ def test_sync_inference_timeout_raises_408(self): def test_async_inference_timeout_raises_408(self): """Test that async inference service converts ServiceRequestError to 408.""" async def run_test(): - from azure.cosmos.aio._inference_service_async import _InferenceService - mock_connection = self._create_mock_connection() mock_connection.connection_policy.DisableSSLVerification = False - service = _InferenceService(mock_connection) + service = _AsyncInferenceService(mock_connection) with patch.object( service._inference_pipeline_client._pipeline, "run", @@ -84,29 +90,23 @@ async def run_test(): def test_sync_inference_timeout_value_from_connection_policy(self): """Test that sync inference service reads timeout from connection policy.""" - from azure.cosmos._inference_service import _InferenceService - mock_connection = self._create_mock_connection(inference_request_timeout=10) - service = _InferenceService(mock_connection) + service = _SyncInferenceService(mock_connection) self.assertEqual(service._inference_request_timeout, 10) def test_async_inference_timeout_value_from_connection_policy(self): """Test that async inference service reads timeout from connection policy.""" - from azure.cosmos.aio._inference_service_async import _InferenceService - mock_connection = self._create_mock_connection(inference_request_timeout=15) mock_connection.connection_policy.DisableSSLVerification = False - service = _InferenceService(mock_connection) + service = _AsyncInferenceService(mock_connection) self.assertEqual(service._inference_request_timeout, 15) def test_sync_inference_passes_timeout_to_pipeline(self): """Test that sync inference service passes timeout kwargs to pipeline.run().""" - from azure.cosmos._inference_service import _InferenceService - mock_connection = self._create_mock_connection(inference_request_timeout=7) - service = _InferenceService(mock_connection) + service = _SyncInferenceService(mock_connection) mock_response = MagicMock() mock_response.http_response.status_code = 200 @@ -129,11 +129,9 @@ def test_sync_inference_passes_timeout_to_pipeline(self): def test_async_inference_passes_timeout_to_pipeline(self): """Test that async inference service passes timeout kwargs to pipeline.run().""" async def run_test(): - from azure.cosmos.aio._inference_service_async import _InferenceService - mock_connection = self._create_mock_connection(inference_request_timeout=12) mock_connection.connection_policy.DisableSSLVerification = False - service = _InferenceService(mock_connection) + service = _AsyncInferenceService(mock_connection) mock_response = MagicMock() mock_response.http_response.status_code = 200 @@ -156,13 +154,10 @@ async def run_test(): asyncio.run(run_test()) def test_sync_inference_uses_shared_response_decoder(self): - """Test that sync inference service decodes response bytes via the - shared decode_response_body_for_status helper. Locks in the wiring - so a regression that reverts to inline data.decode("utf-8") would fail.""" - from azure.cosmos._inference_service import _InferenceService - + """Checks the sync inference service decodes response bytes + through the shared decode helper, not an inline decode call.""" mock_connection = self._create_mock_connection() - service = _InferenceService(mock_connection) + service = _SyncInferenceService(mock_connection) raw_response_data = b'{"Scores": []}' mock_response = MagicMock() @@ -184,14 +179,11 @@ def test_sync_inference_uses_shared_response_decoder(self): mock_decode.assert_called_once_with(raw_response_data, 200, "inference_request") def test_async_inference_uses_shared_response_decoder(self): - """Test that async inference service decodes response bytes via the - shared decode_response_body_for_status helper.""" + """Async version of the wiring check above.""" async def run_test(): - from azure.cosmos.aio._inference_service_async import _InferenceService - mock_connection = self._create_mock_connection() mock_connection.connection_policy.DisableSSLVerification = False - service = _InferenceService(mock_connection) + service = _AsyncInferenceService(mock_connection) raw_response_data = b'{"Scores": []}' mock_response = MagicMock() @@ -215,19 +207,11 @@ async def run_test(): asyncio.run(run_test()) def test_sync_inference_2xx_with_invalid_utf8_raises_decode_error(self): - """A successful (2xx) inference response carrying invalid UTF-8 - in default strict mode must surface as - ``azure.core.exceptions.DecodeError`` with the original wire - status (200) preserved on ``e.response``, not as a stdlib - ``UnicodeDecodeError``. Mirrors the contract enforced for the - core request paths in ``test_request_response_decoding``.""" - from azure.core.exceptions import DecodeError - from azure.cosmos._inference_service import _InferenceService - + """A 200 inference response with invalid bytes should come + out as DecodeError, keeping the real status and original cause.""" mock_connection = self._create_mock_connection() - service = _InferenceService(mock_connection) + service = _SyncInferenceService(mock_connection) - # Same textbook-invalid UTF-8 used in the core decode tests. invalid_utf8 = b'{"Scores": "caf\xc3\x28"}' mock_response = MagicMock() mock_response.http_response.status_code = 200 @@ -248,15 +232,11 @@ def test_sync_inference_2xx_with_invalid_utf8_raises_decode_error(self): self.assertIsInstance(ctx.exception.__cause__, UnicodeDecodeError) def test_async_inference_2xx_with_invalid_utf8_raises_decode_error(self): - """Async counterpart to the sync 2xx-malformed-UTF-8 test.""" - from azure.core.exceptions import DecodeError - + """Async version of the 2xx invalid-bytes check above.""" async def run_test(): - from azure.cosmos.aio._inference_service_async import _InferenceService - mock_connection = self._create_mock_connection() mock_connection.connection_policy.DisableSSLVerification = False - service = _InferenceService(mock_connection) + service = _AsyncInferenceService(mock_connection) invalid_utf8 = b'{"Scores": "caf\xc3\x28"}' mock_response = MagicMock() @@ -279,12 +259,160 @@ async def run_test(): asyncio.run(run_test()) + # Tests below check that the permissive decode env var also works + # through the inference service path, not just the regular request + # path. Without these, a regression that only breaks one path + # would not be caught. + + def test_sync_inference_replace_env_var_lets_2xx_with_invalid_utf8_succeed(self): + """With REPLACE set, a 200 response containing invalid UTF-8 + in a string value should parse successfully and not raise.""" + mock_connection = self._create_mock_connection() + service = _SyncInferenceService(mock_connection) + + # Valid JSON envelope; the bad byte sits inside a string value. + invalid_utf8 = b'{"Scores":[{"index":0,"score":0.5,"label":"caf\xc3\x28"}]}' + mock_response = MagicMock() + mock_response.http_response.status_code = 200 + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = invalid_utf8 + + saved = os.environ.get(_MALFORMED_INPUT_ENV_VAR) + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + try: + with patch.object( + service._inference_pipeline_client._pipeline, "run", + return_value=mock_response, + ): + result = service.rerank( + reranking_context="test query", + documents=["doc1"], + ) + # The replacement character ends up inside the string value; + # surrounding text is preserved. + self.assertIn("Scores", result) + self.assertEqual(len(result["Scores"]), 1) + self.assertIn("\ufffd", result["Scores"][0]["label"]) + self.assertIn("caf", result["Scores"][0]["label"]) + finally: + if saved is None: + os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) + else: + os.environ[_MALFORMED_INPUT_ENV_VAR] = saved + # Confirm the result is a parsed object, not a raw string. + self.assertNotIsInstance(result, str) + _ = json.dumps(result) + + def test_async_inference_replace_env_var_lets_2xx_with_invalid_utf8_succeed(self): + """Async version of the REPLACE check above.""" + async def run_test(): + mock_connection = self._create_mock_connection() + mock_connection.connection_policy.DisableSSLVerification = False + service = _AsyncInferenceService(mock_connection) + + invalid_utf8 = b'{"Scores":[{"index":0,"score":0.5,"label":"caf\xc3\x28"}]}' + mock_response = MagicMock() + mock_response.http_response.status_code = 200 + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = invalid_utf8 + + saved = os.environ.get(_MALFORMED_INPUT_ENV_VAR) + os.environ[_MALFORMED_INPUT_ENV_VAR] = "REPLACE" + try: + with patch.object( + service._inference_pipeline_client._pipeline, "run", + return_value=mock_response, + ): + result = await service.rerank( + reranking_context="test query", + documents=["doc1"], + ) + finally: + if saved is None: + os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) + else: + os.environ[_MALFORMED_INPUT_ENV_VAR] = saved + + self.assertIn("Scores", result) + self.assertEqual(len(result["Scores"]), 1) + self.assertIn("\ufffd", result["Scores"][0]["label"]) + self.assertIn("caf", result["Scores"][0]["label"]) + + asyncio.run(run_test()) + + def test_sync_inference_ignore_env_var_lets_2xx_with_invalid_utf8_succeed(self): + """With IGNORE set, the bad byte is dropped and the response + still parses cleanly.""" + mock_connection = self._create_mock_connection() + service = _SyncInferenceService(mock_connection) + + invalid_utf8 = b'{"Scores":[{"index":0,"score":0.5,"label":"caf\xc3\x28"}]}' + mock_response = MagicMock() + mock_response.http_response.status_code = 200 + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = invalid_utf8 + + saved = os.environ.get(_MALFORMED_INPUT_ENV_VAR) + os.environ[_MALFORMED_INPUT_ENV_VAR] = "IGNORE" + try: + with patch.object( + service._inference_pipeline_client._pipeline, "run", + return_value=mock_response, + ): + result = service.rerank( + reranking_context="test query", + documents=["doc1"], + ) + finally: + if saved is None: + os.environ.pop(_MALFORMED_INPUT_ENV_VAR, None) + else: + os.environ[_MALFORMED_INPUT_ENV_VAR] = saved + + self.assertIn("Scores", result) + self.assertEqual(len(result["Scores"]), 1) + # IGNORE drops the bad byte instead of replacing it. + self.assertNotIn("\ufffd", result["Scores"][0]["label"]) + self.assertIn("caf", result["Scores"][0]["label"]) + + def test_async_inference_ignore_env_var_lets_2xx_with_invalid_utf8_succeed(self): + """Async equivalent of the sync IGNORE test above: with IGNORE set, the + bad byte is dropped from the async inference response and parsing succeeds.""" + async def run_test(): + mock_connection = self._create_mock_connection() + mock_connection.connection_policy.DisableSSLVerification = False + service = _AsyncInferenceService(mock_connection) + + invalid_utf8 = b'{"Scores":[{"index":0,"score":0.5,"label":"caf\xc3\x28"}]}' + mock_response = MagicMock() + mock_response.http_response.status_code = 200 + mock_response.http_response.headers = {} + mock_response.http_response.body.return_value = invalid_utf8 + + # patch.dict scopes the env var to this test and restores the prior + # value when the block exits, even if rerank below raises. + with patch.dict(os.environ, {_MALFORMED_INPUT_ENV_VAR: "IGNORE"}): + with patch.object( + service._inference_pipeline_client._pipeline, "run", + return_value=mock_response, + ): + result = await service.rerank( + reranking_context="test query", + documents=["doc1"], + ) + + self.assertIn("Scores", result) + self.assertEqual(len(result["Scores"]), 1) + # IGNORE drops the bad byte instead of replacing it. + self.assertNotIn("\ufffd", result["Scores"][0]["label"]) + self.assertIn("caf", result["Scores"][0]["label"]) + + asyncio.run(run_test()) + def test_sync_inference_response_timeout_raises_408(self): """Test that sync inference service converts ServiceResponseError to 408.""" - from azure.cosmos._inference_service import _InferenceService - mock_connection = self._create_mock_connection() - service = _InferenceService(mock_connection) + service = _SyncInferenceService(mock_connection) with patch.object( service._inference_pipeline_client._pipeline, "run", @@ -301,11 +429,9 @@ def test_sync_inference_response_timeout_raises_408(self): def test_async_inference_response_timeout_raises_408(self): """Test that async inference service converts ServiceResponseError to 408.""" async def run_test(): - from azure.cosmos.aio._inference_service_async import _InferenceService - mock_connection = self._create_mock_connection() mock_connection.connection_policy.DisableSSLVerification = False - service = _InferenceService(mock_connection) + service = _AsyncInferenceService(mock_connection) with patch.object( service._inference_pipeline_client._pipeline, "run", @@ -323,15 +449,11 @@ async def run_test(): def test_connection_policy_default_inference_timeout(self): """Test that ConnectionPolicy defaults InferenceRequestTimeout to 5 seconds.""" - from azure.cosmos.documents import ConnectionPolicy - policy = ConnectionPolicy() self.assertEqual(policy.InferenceRequestTimeout, 5) def test_connection_policy_custom_inference_timeout(self): """Test that ConnectionPolicy InferenceRequestTimeout can be set.""" - from azure.cosmos.documents import ConnectionPolicy - policy = ConnectionPolicy() policy.InferenceRequestTimeout = 30 self.assertEqual(policy.InferenceRequestTimeout, 30) @@ -339,41 +461,33 @@ def test_connection_policy_custom_inference_timeout(self): def test_sync_lazy_init_raises_error_without_env_var(self): """Test that _InferenceService raises ValueError when env var is missing. With lazy init, this error is deferred from client construction to first use.""" - from azure.cosmos._inference_service import _InferenceService - os.environ.pop(_INFERENCE_ENDPOINT_ENV_VAR, None) mock_connection = self._create_mock_connection() with self.assertRaises(ValueError) as ctx: - _InferenceService(mock_connection) + _SyncInferenceService(mock_connection) self.assertIn(_INFERENCE_ENDPOINT_ENV_VAR, str(ctx.exception)) def test_async_lazy_init_raises_error_without_env_var(self): """Test that async _InferenceService raises ValueError when env var is missing. With lazy init, this error is deferred from client construction to first use.""" - from azure.cosmos.aio._inference_service_async import _InferenceService - os.environ.pop(_INFERENCE_ENDPOINT_ENV_VAR, None) mock_connection = self._create_mock_connection() mock_connection.connection_policy.DisableSSLVerification = False with self.assertRaises(ValueError) as ctx: - _InferenceService(mock_connection) + _AsyncInferenceService(mock_connection) self.assertIn(_INFERENCE_ENDPOINT_ENV_VAR, str(ctx.exception)) def test_sync_inference_service_created_with_env_var(self): """Test that sync _InferenceService can be created when env var is set.""" - from azure.cosmos._inference_service import _InferenceService - mock_connection = self._create_mock_connection() - service = _InferenceService(mock_connection) + service = _SyncInferenceService(mock_connection) self.assertIsNotNone(service) def test_async_inference_service_created_with_env_var(self): """Test that async _InferenceService can be created when env var is set.""" - from azure.cosmos.aio._inference_service_async import _InferenceService - mock_connection = self._create_mock_connection() mock_connection.connection_policy.DisableSSLVerification = False - service = _InferenceService(mock_connection) + service = _AsyncInferenceService(mock_connection) self.assertIsNotNone(service) # ── _get_inference_service() direct call tests ── @@ -381,111 +495,91 @@ def test_async_inference_service_created_with_env_var(self): def test_sync_get_inference_service_returns_service_with_aad_and_env_var(self): """Test that _get_inference_service() returns an _InferenceService when AAD credentials are present and the env var is set.""" - from azure.cosmos._cosmos_client_connection import CosmosClientConnection - import threading - mock_conn = self._create_mock_connection() mock_conn._inference_service = None mock_conn._inference_service_lock = threading.Lock() - result = CosmosClientConnection._get_inference_service(mock_conn) + result = _SyncCosmosClientConnection._get_inference_service(mock_conn) self.assertIsNotNone(result) def test_async_get_inference_service_returns_service_with_aad_and_env_var(self): """Test that async _get_inference_service() returns an _InferenceService when AAD credentials are present and the env var is set.""" - from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection - mock_conn = self._create_mock_connection() mock_conn._inference_service = None mock_conn.connection_policy.DisableSSLVerification = False - result = CosmosClientConnection._get_inference_service(mock_conn) + result = _AsyncCosmosClientConnection._get_inference_service(mock_conn) self.assertIsNotNone(result) def test_sync_get_inference_service_raises_error_without_env_var(self): """Test that _get_inference_service() wraps ValueError into CosmosHttpResponseError when the env var is missing.""" - from azure.cosmos._cosmos_client_connection import CosmosClientConnection - import threading - os.environ.pop(_INFERENCE_ENDPOINT_ENV_VAR, None) mock_conn = self._create_mock_connection() mock_conn._inference_service = None mock_conn._inference_service_lock = threading.Lock() with self.assertRaises(exceptions.CosmosHttpResponseError) as ctx: - CosmosClientConnection._get_inference_service(mock_conn) + _SyncCosmosClientConnection._get_inference_service(mock_conn) self.assertEqual(ctx.exception.status_code, 400) self.assertIn("Failed to initialize inference service", str(ctx.exception)) def test_async_get_inference_service_raises_error_without_env_var(self): """Test that async _get_inference_service() wraps ValueError into CosmosHttpResponseError when the env var is missing.""" - from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection - os.environ.pop(_INFERENCE_ENDPOINT_ENV_VAR, None) mock_conn = self._create_mock_connection() mock_conn._inference_service = None mock_conn.connection_policy.DisableSSLVerification = False with self.assertRaises(exceptions.CosmosHttpResponseError) as ctx: - CosmosClientConnection._get_inference_service(mock_conn) + _AsyncCosmosClientConnection._get_inference_service(mock_conn) self.assertEqual(ctx.exception.status_code, 400) self.assertIn("Failed to initialize inference service", str(ctx.exception)) def test_sync_get_inference_service_returns_none_without_aad(self): """Test that _get_inference_service() returns None when no AAD credentials are present (master key auth).""" - from azure.cosmos._cosmos_client_connection import CosmosClientConnection - import threading - mock_conn = self._create_mock_connection() mock_conn.aad_credentials = None mock_conn._inference_service = None mock_conn._inference_service_lock = threading.Lock() - result = CosmosClientConnection._get_inference_service(mock_conn) + result = _SyncCosmosClientConnection._get_inference_service(mock_conn) self.assertIsNone(result) def test_async_get_inference_service_returns_none_without_aad(self): """Test that async _get_inference_service() returns None when no AAD credentials are present (master key auth).""" - from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection - mock_conn = self._create_mock_connection() mock_conn.aad_credentials = None mock_conn._inference_service = None - result = CosmosClientConnection._get_inference_service(mock_conn) + result = _AsyncCosmosClientConnection._get_inference_service(mock_conn) self.assertIsNone(result) def test_sync_get_inference_service_caches_instance(self): """Test that _get_inference_service() returns the same cached instance on repeated calls.""" - from azure.cosmos._cosmos_client_connection import CosmosClientConnection - import threading - mock_conn = self._create_mock_connection() mock_conn._inference_service = None mock_conn._inference_service_lock = threading.Lock() - first = CosmosClientConnection._get_inference_service(mock_conn) - second = CosmosClientConnection._get_inference_service(mock_conn) + first = _SyncCosmosClientConnection._get_inference_service(mock_conn) + second = _SyncCosmosClientConnection._get_inference_service(mock_conn) self.assertIsNotNone(first) self.assertIs(first, second) def test_async_get_inference_service_caches_instance(self): """Test that async _get_inference_service() returns the same cached instance on repeated calls.""" - from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection - mock_conn = self._create_mock_connection() mock_conn._inference_service = None mock_conn.connection_policy.DisableSSLVerification = False - first = CosmosClientConnection._get_inference_service(mock_conn) - second = CosmosClientConnection._get_inference_service(mock_conn) + first = _AsyncCosmosClientConnection._get_inference_service(mock_conn) + second = _AsyncCosmosClientConnection._get_inference_service(mock_conn) self.assertIsNotNone(first) self.assertIs(first, second) diff --git a/sdk/cosmos/azure-cosmos/tests/test_session.py b/sdk/cosmos/azure-cosmos/tests/test_session.py index d88cc108e4a5..367beed0713b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session.py @@ -266,6 +266,156 @@ def capture_session_token(request): finally: self.key_db.delete_container(test_container_ref) # control-plane + # The three tests below check that a query targeted at a single partition + # sends only that partition's session token on every request, including + # each page and the partition_key entry point. + + def test_feed_range_query_session_token_matches_partition_lsn(self): + # The token sent for a single feed-range query must be the per-partition + # token for that partition, not a comma-joined token across partitions. + test_container_ref = self.key_db.create_container( + "Container fr token value " + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + test_container = self.created_db.get_container_client(test_container_ref.id) + try: + for i in range(60): + test_container.create_item({"id": str(uuid.uuid4()), "pk": f"pk_{i:04d}"}) + + feed_ranges = list(test_container.read_feed_ranges()) + self.assertGreater(len(feed_ranges), 1, "Expected multiple feed ranges") + + captured = {"token": None} + + def capture(request): + captured["token"] = request.http_request.headers.get(HttpHeaders.SessionToken) + + list(test_container.query_items( + query="SELECT * FROM c", + feed_range=feed_ranges[0], + raw_request_hook=capture, + )) + + wire_token = captured["token"] + self.assertIsNotNone(wire_token, "Expected a SessionToken header to be sent.") + self.assertNotIn(",", wire_token, f"Got compound token on the wire: {wire_token!r}") + + # Token should be a single 'partitionRangeId:vector' pair. + self.assertIn(":", wire_token, f"Expected single-partition token shape: {wire_token!r}") + pk_range_id, _, vector_part = wire_token.partition(":") + self.assertNotEqual( + pk_range_id, "", f"Partition range id prefix should not be empty: {wire_token!r}" + ) + self.assertTrue(vector_part, f"Vector portion should not be empty: {wire_token!r}") + + finally: + self.key_db.delete_container(test_container_ref) + + def test_feed_range_query_session_token_single_partition_across_pages(self): + # Every page of a paginated single-partition query must send a single + # partition token, not a compound one. + test_container_ref = self.key_db.create_container( + "Container fr pagination " + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + test_container = self.created_db.get_container_client(test_container_ref.id) + try: + # Build multi-partition session state first so this test can catch + # regressions that accidentally send a compound token. + for i in range(60): + test_container.create_item({"id": str(uuid.uuid4()), "pk": f"pk_{i:04d}"}) + + # Then concentrate items on one key so the query spans several + # pages within one target physical partition. + for i in range(20): + test_container.create_item({ + "id": str(uuid.uuid4()), + "pk": "pinned_pk", + "i": i, + }) + + feed_ranges = list(test_container.read_feed_ranges()) + self.assertGreater(len(feed_ranges), 1, "Expected multiple feed ranges") + + target_fr = test_container.feed_range_from_partition_key("pinned_pk") + + captured_tokens = [] + + def capture(request): + token = request.http_request.headers.get(HttpHeaders.SessionToken) + if token: + captured_tokens.append(token) + + pager = test_container.query_items( + query="SELECT * FROM c", + feed_range=target_fr, + max_item_count=2, + raw_request_hook=capture, + ).by_page() + page_count = 0 + for page in pager: + _ = list(page) + page_count += 1 + + self.assertGreaterEqual(page_count, 1) + self.assertGreater(len(captured_tokens), 0, + "Expected at least one captured SessionToken header across pages.") + # Every page must carry a single-partition token, and all tokens + # must reference the same partition range id. + prefixes = set() + for token in captured_tokens: + self.assertNotIn( + ",", token, + f"Compound token leaked on a paginated request: {token!r}", + ) + self.assertIn(":", token, f"Bad token shape on paginated request: {token!r}") + prefixes.add(token.split(":", 1)[0]) + self.assertEqual( + len(prefixes), 1, + f"Single-partition query produced tokens for multiple partition range ids: {prefixes}", + ) + finally: + self.key_db.delete_container(test_container_ref) + + def test_query_with_partition_key_only_sends_single_partition_token(self): + # A query scoped by partition_key (not feed_range) must also send only + # the relevant partition's token. + test_container_ref = self.key_db.create_container( + "Container pk-only session " + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + test_container = self.created_db.get_container_client(test_container_ref.id) + try: + for i in range(60): + test_container.create_item({"id": str(uuid.uuid4()), "pk": f"pk_{i:04d}"}) + + captured_tokens = [] + + def capture(request): + token = request.http_request.headers.get(HttpHeaders.SessionToken) + if token: + captured_tokens.append(token) + + list(test_container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "pk_0001"}], + partition_key="pk_0001", + raw_request_hook=capture, + )) + + self.assertGreater(len(captured_tokens), 0, + "Expected at least one request with a SessionToken header.") + for token in captured_tokens: + self.assertNotIn( + ",", token, + f"partitionKey-only query leaked a compound token: {token!r}", + ) + finally: + self.key_db.delete_container(test_container_ref) + def test_session_token_with_space_in_container_name(self): # Session token should not be sent for control plane operations diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_async.py b/sdk/cosmos/azure-cosmos/tests/test_session_async.py index 003bb80fb5cc..cf18fb01ac8e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session_async.py @@ -200,6 +200,143 @@ def capture_session_token(request): finally: await self.key_db.delete_container(test_container_ref) # control-plane + # The three async tests below check that a query targeted at a single + # partition sends only that partition's session token on every request, + # including each page and the partition_key entry point. + + async def test_feed_range_query_session_token_matches_partition_lsn_async(self): + # Async twin: a single feed-range query must send only that partition's token. + test_container_ref = await self.key_db.create_container( + "Container fr token value async " + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + test_container = self.created_db.get_container_client(test_container_ref.id) + try: + for i in range(60): + await test_container.create_item({"id": str(uuid.uuid4()), "pk": f"pk_{i:04d}"}) + + feed_ranges = [fr async for fr in test_container.read_feed_ranges()] + self.assertGreater(len(feed_ranges), 1) + + captured = {"token": None} + + def capture(request): + captured["token"] = request.http_request.headers.get(HttpHeaders.SessionToken) + + _ = [item async for item in test_container.query_items( + query="SELECT * FROM c", + feed_range=feed_ranges[0], + raw_request_hook=capture, + )] + + wire_token = captured["token"] + self.assertIsNotNone(wire_token) + self.assertNotIn(",", wire_token, f"Got compound token on the wire: {wire_token!r}") + self.assertIn(":", wire_token, f"Expected single-partition token shape: {wire_token!r}") + pk_range_id, _, vector_part = wire_token.partition(":") + self.assertNotEqual(pk_range_id, "", f"Partition range id should not be empty: {wire_token!r}") + self.assertTrue(vector_part, f"Vector portion should not be empty: {wire_token!r}") + finally: + await self.key_db.delete_container(test_container_ref) + + async def test_feed_range_query_session_token_single_partition_across_pages_async(self): + # Async twin: every page must send a single-partition token. + test_container_ref = await self.key_db.create_container( + "Container fr pagination async " + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + test_container = self.created_db.get_container_client(test_container_ref.id) + try: + # Build multi-partition session state first so this test can catch + # regressions that accidentally send a compound token. + for i in range(60): + await test_container.create_item({"id": str(uuid.uuid4()), "pk": f"pk_{i:04d}"}) + + for i in range(20): + await test_container.create_item({ + "id": str(uuid.uuid4()), + "pk": "pinned_pk", + "i": i, + }) + + feed_ranges = [fr async for fr in test_container.read_feed_ranges()] + self.assertGreater(len(feed_ranges), 1, "Expected multiple feed ranges") + + target_fr = await test_container.feed_range_from_partition_key("pinned_pk") + + captured_tokens = [] + + def capture(request): + token = request.http_request.headers.get(HttpHeaders.SessionToken) + if token: + captured_tokens.append(token) + + pager = test_container.query_items( + query="SELECT * FROM c", + feed_range=target_fr, + max_item_count=2, + raw_request_hook=capture, + ).by_page() + page_count = 0 + async for page in pager: + _ = [item async for item in page] + page_count += 1 + + self.assertGreaterEqual(page_count, 1) + self.assertGreater(len(captured_tokens), 0) + prefixes = set() + for token in captured_tokens: + self.assertNotIn( + ",", token, + f"Compound token leaked on a paginated request: {token!r}", + ) + self.assertIn(":", token, f"Bad token shape on paginated request: {token!r}") + prefixes.add(token.split(":", 1)[0]) + self.assertEqual( + len(prefixes), 1, + f"Single-partition query produced tokens for multiple partition range ids: {prefixes}", + ) + finally: + await self.key_db.delete_container(test_container_ref) + + async def test_query_with_partition_key_only_sends_single_partition_token_async(self): + # Async twin: a query scoped by partition_key must also send only + # the relevant partition's token. + test_container_ref = await self.key_db.create_container( + "Container pk-only session async " + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000, + ) + test_container = self.created_db.get_container_client(test_container_ref.id) + try: + for i in range(60): + await test_container.create_item({"id": str(uuid.uuid4()), "pk": f"pk_{i:04d}"}) + + captured_tokens = [] + + def capture(request): + token = request.http_request.headers.get(HttpHeaders.SessionToken) + if token: + captured_tokens.append(token) + + _ = [item async for item in test_container.query_items( + query="SELECT * FROM c WHERE c.pk = @pk", + parameters=[{"name": "@pk", "value": "pk_0001"}], + partition_key="pk_0001", + raw_request_hook=capture, + )] + + self.assertGreater(len(captured_tokens), 0) + for token in captured_tokens: + self.assertNotIn( + ",", token, + f"partitionKey-only query leaked a compound token: {token!r}", + ) + finally: + await self.key_db.delete_container(test_container_ref) + async def test_manual_session_token_override_async(self): # Create an item to get a valid session token from the response created_document = await self.created_container.create_item( diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_token_unit.py b/sdk/cosmos/azure-cosmos/tests/test_session_token_unit.py index d16c6d0a4395..be368ef03321 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session_token_unit.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session_token_unit.py @@ -8,9 +8,13 @@ import pytest -from azure.cosmos import _session, http_constants +from azure.cosmos import _session, documents, http_constants +from azure.cosmos._base import set_session_token_header, set_session_token_header_async +from azure.cosmos._request_object import RequestObject from azure.cosmos._vector_session_token import VectorSessionToken +from azure.cosmos.documents import _OperationType from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.cosmos.http_constants import HttpHeaders, ResourceType class _DummyCollectionRanges: @@ -256,7 +260,6 @@ def validate_different_session_token_false_progress_merge_scenarios(self, false_ unittest.main() - class TestResolvePartitionLocalSessionTokenRegression(unittest.TestCase): """Regression tests for ``_resolve_partition_local_session_token``. @@ -317,3 +320,152 @@ def __init__(self, t): (pkr,), token_dict={"child": _Wrap(token)}) self.assertEqual(result, token.session_token) + +# Unit tests for set_session_token_header. When a request targets a single +# partition, the helper must send only that partition's token, not a +# comma-joined token covering every cached partition. + +class _SessionTokenGemStub: + """Stub for the global endpoint manager used by session token helpers.""" + + def can_use_multiple_write_locations(self, request): # noqa: D401 + return False + + +class _SessionTokenClientStub: + """Minimal client connection stub for session token header tests.""" + + def __init__(self, collection_link, collection_rid, partition_tokens): + self.session = _session.Session("https://stub.documents.azure.com") + # Seed per-partition tokens that a real client would populate after writes. + self.session.session_container.collection_name_to_rid[collection_link] = collection_rid + self.session.session_container.rid_to_session_token[collection_rid] = { + pk_range_id: VectorSessionToken.create(tok) for pk_range_id, tok in partition_tokens.items() + } + self._container_properties_cache = { + collection_link: {"_rid": collection_rid}, + } + self._routing_map_provider = _DummyRoutingMapProvider(collection_link) + self._global_endpoint_manager = _SessionTokenGemStub() + + +def _build_session_request(collection_link): + """Build a session-consistency document read request for the given path.""" + request_object = RequestObject(ResourceType.Document, _OperationType.Read, None) + headers = {HttpHeaders.ConsistencyLevel: documents.ConsistencyLevel.Session} + return request_object, headers, collection_link + + +@pytest.mark.cosmosEmulator +class TestSetSessionTokenHeaderFeedRange(unittest.TestCase): + """Sync coverage for single-partition session token selection.""" + + COLLECTION_LINK = "dbs/db1/colls/c1" + COLLECTION_RID = "rid_session_token_sync" + PARTITION_TOKENS = { + "0": "1#5", + "1": "1#10", + "2": "1#7", + } + + def test_per_partition_id_selects_single_partition_token_from_compound_state(self): + # When a partition id is supplied, only that partition's token should be sent. + client = _SessionTokenClientStub(self.COLLECTION_LINK, self.COLLECTION_RID, self.PARTITION_TOKENS) + request, headers, path = _build_session_request(self.COLLECTION_LINK) + + set_session_token_header( + client, headers, path, request, {}, partition_key_range_id="1", + ) + + token = headers.get(HttpHeaders.SessionToken) + self.assertEqual( + token, "1:1#10", + "Expected single-partition token '1:1#10', got {!r}.".format(token), + ) + self.assertNotIn(",", token or "", "Per-id selection should never emit a compound token.") + + def test_no_partition_id_returns_compound_for_cross_partition_query(self): + # Cross-partition queries (no partition id, no partition key) keep the joined token. + client = _SessionTokenClientStub(self.COLLECTION_LINK, self.COLLECTION_RID, self.PARTITION_TOKENS) + request, headers, path = _build_session_request(self.COLLECTION_LINK) + + set_session_token_header(client, headers, path, request, {}, partition_key_range_id=None) + + token = headers.get(HttpHeaders.SessionToken, "") + self.assertIn(",", token, "Cross-partition branch must emit a compound token.") + parts = sorted(token.split(",")) + self.assertEqual( + parts, sorted(["0:1#5", "1:1#10", "2:1#7"]), + "Compound token entries do not match the seeded state: {!r}".format(token), + ) + + def test_per_partition_id_returns_no_compound_when_no_state_for_partition(self): + # If we have no cached token for the targeted partition, the helper must not + # silently fall back to a compound token. + client = _SessionTokenClientStub(self.COLLECTION_LINK, self.COLLECTION_RID, self.PARTITION_TOKENS) + request, headers, path = _build_session_request(self.COLLECTION_LINK) + + set_session_token_header( + client, headers, path, request, {}, partition_key_range_id="999", + ) + + token = headers.get(HttpHeaders.SessionToken) + self.assertIsNone( + token, + "Missing per-id state must not set a session token header.", + ) + + +@pytest.mark.cosmosEmulator +class TestSetSessionTokenHeaderFeedRangeAsync(unittest.IsolatedAsyncioTestCase): + """Async coverage for single-partition session token selection.""" + + COLLECTION_LINK = "dbs/db1/colls/c1" + COLLECTION_RID = "rid_session_token_async" + PARTITION_TOKENS = { + "0": "1#5", + "1": "1#10", + "2": "1#7", + } + + async def test_per_partition_id_selects_single_partition_token_from_compound_state_async(self): + client = _SessionTokenClientStub(self.COLLECTION_LINK, self.COLLECTION_RID, self.PARTITION_TOKENS) + request, headers, path = _build_session_request(self.COLLECTION_LINK) + + await set_session_token_header_async( + client, headers, path, request, {}, partition_key_range_id="1", + ) + + token = headers.get(HttpHeaders.SessionToken) + self.assertEqual(token, "1:1#10") + self.assertNotIn(",", token or "") + + async def test_no_partition_id_returns_compound_for_cross_partition_query_async(self): + client = _SessionTokenClientStub(self.COLLECTION_LINK, self.COLLECTION_RID, self.PARTITION_TOKENS) + request, headers, path = _build_session_request(self.COLLECTION_LINK) + + await set_session_token_header_async( + client, headers, path, request, {}, partition_key_range_id=None, + ) + + token = headers.get(HttpHeaders.SessionToken, "") + self.assertIn(",", token) + parts = sorted(token.split(",")) + self.assertEqual(parts, sorted(["0:1#5", "1:1#10", "2:1#7"])) + + async def test_per_partition_id_returns_no_compound_when_no_state_for_partition_async(self): + # Async equivalent of the sync test above: if there is no cached token + # for the targeted partition, the async helper must not fall back to a + # compound (cross-partition) token. + client = _SessionTokenClientStub(self.COLLECTION_LINK, self.COLLECTION_RID, self.PARTITION_TOKENS) + request, headers, path = _build_session_request(self.COLLECTION_LINK) + + await set_session_token_header_async( + client, headers, path, request, {}, partition_key_range_id="999", + ) + + token = headers.get(HttpHeaders.SessionToken) + self.assertIsNone( + token, + "Missing per-id state must not set a session token header.", + ) diff --git a/sdk/cosmos/azure-cosmos/tests/test_user_agent_overwrite_regression.py b/sdk/cosmos/azure-cosmos/tests/test_user_agent_overwrite_regression.py new file mode 100644 index 000000000000..f0c680dad981 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_user_agent_overwrite_regression.py @@ -0,0 +1,193 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Tests that constructing a sync client with ``user_agent_overwrite`` does not crash. + +Covers combinations of the flag with other construction-time options +(connection policy, consistency level, timeouts, logger, connection string, +AAD credential) and checks that the user-supplied user-agent prefix shows up +on outbound requests. +""" + +import unittest + +import pytest +from azure.core.pipeline.transport import RequestsTransport + +import test_config +from azure.cosmos import CosmosClient, documents +from test_aad import CosmosEmulatorCredential +from test_cosmos_http_logging_policy import create_logger + + +class _UserAgentCaptureTransport(RequestsTransport): + """Forwards every request and records the outbound User-Agent header.""" + + def __init__(self): + super().__init__() + self.user_agents: list[str] = [] + + def send(self, request, **kwargs): + ua = request.headers.get("User-Agent") + if ua is not None: + self.user_agents.append(ua) + return super().send(request, **kwargs) + + +@pytest.mark.cosmosEmulator +@pytest.mark.cosmosAADLong +class TestUserAgentOverwriteRegression(unittest.TestCase): + """User-agent overwrite construction tests for the sync client.""" + + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + connection_str = configs.connection_str + TEST_DATABASE_ID = configs.TEST_DATABASE_ID + TEST_SINGLE_PARTITION_CONTAINER_ID = configs.TEST_SINGLE_PARTITION_CONTAINER_ID + + _skip_on_non_emulator = pytest.mark.skipif( + not configs.is_emulator, + reason="Uses the emulator credential; only valid against the local emulator.", + ) + + @classmethod + def setUpClass(cls): + if cls.masterKey == "[YOUR_KEY_HERE]" or cls.host == "[YOUR_ENDPOINT_HERE]": + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' to run the tests.") + + def _smoke_data_plane_call(self, client: CosmosClient) -> None: + """Reads a container so the test fails if the client cannot actually be used.""" + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_SINGLE_PARTITION_CONTAINER_ID) + assert container.read()["id"] == container.id + + def _assert_user_agent_headers(self, capture: _UserAgentCaptureTransport, scenario: str) -> None: + assert capture.user_agents, f"no outbound requests captured for {scenario}" + assert all(ua.startswith("MyApp/1.0") for ua in capture.user_agents), ( + "user-agent prefix dropped for {}. captured={}".format(scenario, capture.user_agents) + ) + assert all("azsdk-python-cosmos/" in ua for ua in capture.user_agents), ( + "sync SDK user-agent missing for {}. captured={}".format(scenario, capture.user_agents) + ) + + def test_overwrite_with_connection_policy(self): + """Builds a client with a custom connection policy and the overwrite flag.""" + cp = documents.ConnectionPolicy() + cp.DisableSSLVerification = self.configs.is_emulator + capture = _UserAgentCaptureTransport() + client = CosmosClient( + self.host, + self.masterKey, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + connection_policy=cp, + transport=capture, + ) + try: + self._smoke_data_plane_call(client) + finally: + client.close() + self._assert_user_agent_headers(capture, "connection_policy") + + def test_overwrite_with_consistency_and_timeouts(self): + """Builds a client combining the overwrite flag with consistency and timeout options.""" + capture = _UserAgentCaptureTransport() + client = CosmosClient( + self.host, + self.masterKey, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + consistency_level="Session", + connection_timeout=10, + read_timeout=10, + transport=capture, + ) + try: + self._smoke_data_plane_call(client) + finally: + client.close() + self._assert_user_agent_headers(capture, "consistency_and_timeouts") + + def test_overwrite_with_logger_and_diagnostics(self): + """Builds a client combining the overwrite flag with a user-provided logger.""" + mock_handler = test_config.MockHandler() + logger = create_logger("test_ua_overwrite_diag_sync", mock_handler) + capture = _UserAgentCaptureTransport() + client = CosmosClient( + self.host, + self.masterKey, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + logger=logger, + enable_diagnostics_logging=True, + transport=capture, + ) + try: + self._smoke_data_plane_call(client) + finally: + client.close() + self._assert_user_agent_headers(capture, "logger_and_diagnostics") + + def test_overwrite_header_contains_user_prefix_under_both_flag_values(self): + """User-supplied user-agent prefix appears on the wire with the flag on or off. + + The Cosmos SDK always keeps its base user-agent in the header, so the + overwrite flag does not actually replace it. This pins down that + observable behavior and confirms neither flag value crashes the client. + """ + for overwrite_value in (True, False): + capture = _UserAgentCaptureTransport() + client = CosmosClient( + self.host, + self.masterKey, + user_agent="MyApp/1.0", + user_agent_overwrite=overwrite_value, + transport=capture, + ) + try: + self._smoke_data_plane_call(client) + finally: + client.close() + self._assert_user_agent_headers(capture, f"overwrite={overwrite_value}") + + def test_overwrite_via_from_connection_string(self): + """Builds a client from a connection string while also passing the overwrite flag.""" + capture = _UserAgentCaptureTransport() + client = CosmosClient.from_connection_string( + self.connection_str, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + transport=capture, + ) + try: + self._smoke_data_plane_call(client) + finally: + client.close() + self._assert_user_agent_headers(capture, "from_connection_string") + + @_skip_on_non_emulator + def test_overwrite_with_aad_emulator_credential(self): + """Builds a client with an AAD credential while also passing the overwrite flag.""" + credential = CosmosEmulatorCredential() + capture = _UserAgentCaptureTransport() + client = CosmosClient( + self.host, + credential, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + transport=capture, + ) + try: + self._smoke_data_plane_call(client) + finally: + client.close() + self._assert_user_agent_headers(capture, "aad_emulator_credential") + + +if __name__ == "__main__": + unittest.main() + + diff --git a/sdk/cosmos/azure-cosmos/tests/test_user_agent_overwrite_regression_async.py b/sdk/cosmos/azure-cosmos/tests/test_user_agent_overwrite_regression_async.py new file mode 100644 index 000000000000..2063f2705e7a --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_user_agent_overwrite_regression_async.py @@ -0,0 +1,195 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""Tests that constructing an async client with ``user_agent_overwrite`` does not crash. + +Covers combinations of the flag with other construction-time options +(connection policy, consistency level, timeouts, logger, connection string, +AAD credential) and checks that the user-supplied user-agent prefix shows up +on outbound requests. +""" + +import unittest + +import pytest +from azure.core.pipeline.transport import AioHttpTransport + +import test_config +from azure.cosmos import documents +from azure.cosmos.aio import CosmosClient +from test_aad_async import CosmosEmulatorCredential +from test_cosmos_http_logging_policy import create_logger + + +class _AsyncUserAgentCaptureTransport(AioHttpTransport): + """Forwards every request and records the outbound User-Agent header.""" + + def __init__(self): + super().__init__() + self.user_agents: list[str] = [] + + async def send(self, request, **kwargs): + ua = request.headers.get("User-Agent") + if ua is not None: + self.user_agents.append(ua) + return await super().send(request, **kwargs) + + +@pytest.mark.cosmosEmulator +@pytest.mark.cosmosAADLong +class TestUserAgentOverwriteRegressionAsync(unittest.IsolatedAsyncioTestCase): + """User-agent overwrite construction tests for the async client.""" + + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + connection_str = configs.connection_str + TEST_DATABASE_ID = configs.TEST_DATABASE_ID + TEST_SINGLE_PARTITION_CONTAINER_ID = configs.TEST_SINGLE_PARTITION_CONTAINER_ID + + _skip_on_non_emulator = pytest.mark.skipif( + not configs.is_emulator, + reason="Uses the emulator credential; only valid against the local emulator.", + ) + + @classmethod + def setUpClass(cls): + if cls.masterKey == "[YOUR_KEY_HERE]" or cls.host == "[YOUR_ENDPOINT_HERE]": + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' to run the tests.") + + async def _smoke_data_plane_call(self, client: CosmosClient) -> None: + """Reads a container so the test fails if the client cannot actually be used.""" + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_SINGLE_PARTITION_CONTAINER_ID) + result = await container.read() + assert result["id"] == container.id + + def _assert_user_agent_headers(self, capture: _AsyncUserAgentCaptureTransport, scenario: str) -> None: + assert capture.user_agents, f"no outbound requests captured for {scenario}" + assert all(ua.startswith("MyApp/1.0") for ua in capture.user_agents), ( + "user-agent prefix dropped for {}. captured={}".format(scenario, capture.user_agents) + ) + assert all("azsdk-python-cosmos-async/" in ua for ua in capture.user_agents), ( + "async SDK user-agent missing for {}. captured={}".format(scenario, capture.user_agents) + ) + + async def test_overwrite_with_connection_policy_async(self): + """Builds a client with a custom connection policy and the overwrite flag.""" + cp = documents.ConnectionPolicy() + cp.DisableSSLVerification = self.configs.is_emulator + capture = _AsyncUserAgentCaptureTransport() + client = CosmosClient( + self.host, + self.masterKey, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + connection_policy=cp, + transport=capture, + ) + try: + await self._smoke_data_plane_call(client) + finally: + await client.close() + self._assert_user_agent_headers(capture, "connection_policy_async") + + async def test_overwrite_with_consistency_and_timeouts_async(self): + """Builds a client combining the overwrite flag with consistency and timeout options.""" + capture = _AsyncUserAgentCaptureTransport() + client = CosmosClient( + self.host, + self.masterKey, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + consistency_level="Session", + connection_timeout=10, + read_timeout=10, + transport=capture, + ) + try: + await self._smoke_data_plane_call(client) + finally: + await client.close() + self._assert_user_agent_headers(capture, "consistency_and_timeouts_async") + + async def test_overwrite_with_logger_and_diagnostics_async(self): + """Builds a client combining the overwrite flag with a user-provided logger.""" + mock_handler = test_config.MockHandler() + logger = create_logger("test_ua_overwrite_diag_async", mock_handler) + capture = _AsyncUserAgentCaptureTransport() + client = CosmosClient( + self.host, + self.masterKey, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + logger=logger, + enable_diagnostics_logging=True, + transport=capture, + ) + try: + await self._smoke_data_plane_call(client) + finally: + await client.close() + self._assert_user_agent_headers(capture, "logger_and_diagnostics_async") + + async def test_overwrite_header_contains_user_prefix_under_both_flag_values_async(self): + """User-supplied user-agent prefix appears on the wire with the flag on or off. + + The Cosmos SDK always keeps its base user-agent in the header, so the + overwrite flag does not actually replace it. This pins down that + observable behavior and confirms neither flag value crashes the client. + """ + for overwrite_value in (True, False): + capture = _AsyncUserAgentCaptureTransport() + client = CosmosClient( + self.host, + self.masterKey, + user_agent="MyApp/1.0", + user_agent_overwrite=overwrite_value, + transport=capture, + ) + try: + await self._smoke_data_plane_call(client) + finally: + await client.close() + self._assert_user_agent_headers(capture, f"overwrite={overwrite_value}") + + async def test_overwrite_via_from_connection_string_async(self): + """Builds a client from a connection string while also passing the overwrite flag.""" + capture = _AsyncUserAgentCaptureTransport() + client = CosmosClient.from_connection_string( + self.connection_str, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + transport=capture, + ) + try: + await self._smoke_data_plane_call(client) + finally: + await client.close() + self._assert_user_agent_headers(capture, "from_connection_string_async") + + @_skip_on_non_emulator + async def test_overwrite_with_aad_emulator_credential_async(self): + """Builds a client with an AAD credential while also passing the overwrite flag.""" + credential = CosmosEmulatorCredential() + capture = _AsyncUserAgentCaptureTransport() + client = CosmosClient( + self.host, + credential, + user_agent="MyApp/1.0", + user_agent_overwrite=True, + transport=capture, + ) + try: + await self._smoke_data_plane_call(client) + finally: + await client.close() + self._assert_user_agent_headers(capture, "aad_emulator_credential_async") + + +if __name__ == "__main__": + unittest.main() + +