diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index a2167eb..c3bef3a 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -42,8 +42,8 @@ env: CBCI_SUPPORTED_ARM64_PLATFORMS: "linux macos" CBCI_DEFAULT_LINUX_X86_64_PLATFORM: "ubuntu-22.04" CBCI_DEFAULT_LINUX_ARM64_PLATFORM: "ubuntu-22.04-arm" - CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-13" - CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-14" + CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-15-intel" + CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-15" CBCI_DEFAULT_WINDOWS_PLATFORM: "windows-2022" CBCI_DEFAULT_LINUX_CONTAINER: "slim-bookworm" CBCI_DEFAULT_ALPINE_CONTAINER: "alpine" diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 6366d78..b89ef77 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -64,8 +64,8 @@ env: CBCI_SUPPORTED_ARM64_PLATFORMS: "linux macos" CBCI_DEFAULT_LINUX_X86_64_PLATFORM: "ubuntu-22.04" CBCI_DEFAULT_LINUX_ARM64_PLATFORM: "ubuntu-22.04-arm" - CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-13" - CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-14" + CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-15-intel" + CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-15" CBCI_DEFAULT_WINDOWS_PLATFORM: "windows-2022" CBCI_DEFAULT_LINUX_CONTAINER: "slim-bookworm" CBCI_DEFAULT_ALPINE_CONTAINER: "alpine" diff --git a/.github/workflows/verify_release.yml b/.github/workflows/verify_release.yml index 81d7c65..7c8ceb6 100644 --- a/.github/workflows/verify_release.yml +++ b/.github/workflows/verify_release.yml @@ -57,8 +57,8 @@ env: CBCI_SUPPORTED_ARM64_PLATFORMS: "linux macos" CBCI_DEFAULT_LINUX_X86_64_PLATFORM: "ubuntu-22.04" CBCI_DEFAULT_LINUX_ARM64_PLATFORM: "ubuntu-22.04-arm" - CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-13" - CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-14" + CBCI_DEFAULT_MACOS_X86_64_PLATFORM: "macos-15-intel" + CBCI_DEFAULT_MACOS_ARM64_PLATFORM: "macos-15" CBCI_DEFAULT_WINDOWS_PLATFORM: "windows-2022" 
CBCI_DEFAULT_LINUX_CONTAINER: "slim-bookworm" CBCI_DEFAULT_ALPINE_CONTAINER: "alpine" diff --git a/.gitignore b/.gitignore index 2511da8..164c3c3 100644 --- a/.gitignore +++ b/.gitignore @@ -176,5 +176,8 @@ gocaves* .pytest_cache/ test_scripts/ -# rff +# ruff .ruff_cache/ + +# other +.DS_Store diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6edfbdc..b9feb9f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,6 +45,8 @@ repos: - pytest~=8.3.5 - httpx~=0.28.1 - aiohttp~=3.11.10 + - sniffio~=1.3.1 + - anyio~=4.9.0 types: - python require_serial: true diff --git a/acouchbase_analytics/cluster.py b/acouchbase_analytics/cluster.py index 2cc0caf..8617aba 100644 --- a/acouchbase_analytics/cluster.py +++ b/acouchbase_analytics/cluster.py @@ -25,7 +25,8 @@ from typing import TypeAlias from acouchbase_analytics.database import AsyncDatabase -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult if TYPE_CHECKING: from couchbase_analytics.credential import Credential @@ -143,6 +144,22 @@ def execute_query(self, statement: str, *args: object, **kwargs: object) -> Awai """ # noqa: E501 return self._impl.execute_query(statement, *args, **kwargs) + def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryHandle]: + """Executes a query against an Analytics cluster in async mode. + + .. seealso:: + :meth:`acouchbase_analytics.Scope.start_query`: For how to execute scope-level queries. + + Args: + statement: The SQL++ statement to execute. + options (:class:`~acouchbase_analytics.options.StartQueryOptions`): Optional parameters for the query operation. 
+ **kwargs (Dict[str, Any]): keyword arguments that can be used in place or to override provided :class:`~acouchbase_analytics.options.StartQueryOptions` + + Returns: + :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle`: An instance of a :class:`~acouchbase_analytics.query_handle.AsyncQueryHandle` + """ # noqa: E501 + return self._impl.start_query(statement, *args, **kwargs) + async def shutdown(self) -> None: """Shuts down this cluster instance. Cleaning up all resources associated with it. diff --git a/acouchbase_analytics/cluster.pyi b/acouchbase_analytics/cluster.pyi index bea6643..4da9fff 100644 --- a/acouchbase_analytics/cluster.pyi +++ b/acouchbase_analytics/cluster.pyi @@ -21,10 +21,19 @@ if sys.version_info < (3, 11): else: from typing import Unpack +from acouchbase_analytics import JSONType +from acouchbase_analytics.credential import Credential from acouchbase_analytics.database import AsyncDatabase -from couchbase_analytics.credential import Credential -from couchbase_analytics.options import ClusterOptions, ClusterOptionsKwargs, QueryOptions, QueryOptionsKwargs -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.options import ( + ClusterOptions, + ClusterOptionsKwargs, + QueryOptions, + QueryOptionsKwargs, + StartQueryOptions, + StartQueryOptionsKwargs, +) +from acouchbase_analytics.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult class AsyncCluster: @overload @@ -54,14 +63,34 @@ class AsyncCluster: ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: Unpack[QueryOptionsKwargs] + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: Unpack[QueryOptionsKwargs] ) -> Awaitable[AsyncQueryResult]: ... 
@overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: str + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: str ) -> Awaitable[AsyncQueryResult]: ... @overload - def execute_query(self, statement: str, *args: str, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + def execute_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + @overload + def start_query(self, statement: str) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> AsyncQueryHandle: ... def shutdown(self) -> Awaitable[None]: ... 
@overload @classmethod diff --git a/acouchbase_analytics/options.py b/acouchbase_analytics/options.py index bc8f846..47c432b 100644 --- a/acouchbase_analytics/options.py +++ b/acouchbase_analytics/options.py @@ -16,9 +16,13 @@ from couchbase_analytics.common.options import ClusterOptions as ClusterOptions # noqa: F401 from couchbase_analytics.common.options import ClusterOptionsKwargs as ClusterOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options import FetchResultsOptions as FetchResultsOptions # noqa: F401 +from couchbase_analytics.common.options import FetchResultsOptionsKwargs as FetchResultsOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import QueryOptions as QueryOptions # noqa: F401 from couchbase_analytics.common.options import QueryOptionsKwargs as QueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import SecurityOptions as SecurityOptions # noqa: F401 from couchbase_analytics.common.options import SecurityOptionsKwargs as SecurityOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options import StartQueryOptions as StartQueryOptions # noqa: F401 +from couchbase_analytics.common.options import StartQueryOptionsKwargs as StartQueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import TimeoutOptions as TimeoutOptions # noqa: F401 from couchbase_analytics.common.options import TimeoutOptionsKwargs as TimeoutOptionsKwargs # noqa: F401 diff --git a/acouchbase_analytics/protocol/_core/anyio_utils.py b/acouchbase_analytics/protocol/_core/anyio_utils.py index ce7a751..5a2e211 100644 --- a/acouchbase_analytics/protocol/_core/anyio_utils.py +++ b/acouchbase_analytics/protocol/_core/anyio_utils.py @@ -66,7 +66,7 @@ def current_async_library() -> Optional[AsyncBackend]: try: import sniffio except ImportError: - async_lib = 'asyncio' + return AsyncBackend('asyncio') try: async_lib = sniffio.current_async_library() diff --git a/acouchbase_analytics/protocol/_core/client_adapter.py 
b/acouchbase_analytics/protocol/_core/client_adapter.py index 6a6ac60..a24c1fa 100644 --- a/acouchbase_analytics/protocol/_core/client_adapter.py +++ b/acouchbase_analytics/protocol/_core/client_adapter.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, cast +from typing import Optional, Union, cast from uuid import uuid4 from httpx import URL, AsyncClient, BasicAuth, Response @@ -25,12 +25,10 @@ from couchbase_analytics.common.credential import Credential from couchbase_analytics.common.deserializer import Deserializer from couchbase_analytics.common.logging import LogLevel, log_message +from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import _ConnectionDetails from couchbase_analytics.protocol.options import OptionsBuilder -if TYPE_CHECKING: - from couchbase_analytics.protocol._core.request import QueryRequest - class _AsyncClientAdapter: """ @@ -164,7 +162,9 @@ async def create_client(self) -> None: def log_message(self, message: str, log_level: LogLevel) -> None: log_message(logger, f'{self.log_prefix} {message}', log_level) - async def send_request(self, request: QueryRequest) -> Response: + async def send_request( + self, request: Union[CancelRequest, HttpRequest, QueryRequest, StartQueryRequest], stream: Optional[bool] = True + ) -> Response: """ **INTERNAL** """ @@ -177,8 +177,19 @@ async def send_request(self, request: QueryRequest) -> Response: port=request.url.port, path=request.url.path, ) - req = self._client.build_request(request.method, url, json=request.body, extensions=request.extensions) - return await self._client.send(req, stream=True) + if isinstance(request, (QueryRequest, StartQueryRequest)): + req = self._client.build_request(request.method, url, json=request.body, extensions=request.extensions) + else: + headers = request.headers if request.headers is not None else 
None + data = request.data if isinstance(request, CancelRequest) else None + req = self._client.build_request( + request.method, url, data=data, headers=headers, extensions=request.extensions + ) + + if stream is None: + stream = True + + return await self._client.send(req, stream=stream) def reset_client(self) -> None: """ diff --git a/acouchbase_analytics/protocol/_core/request_context.py b/acouchbase_analytics/protocol/_core/request_context.py index f01eec1..bdddf76 100644 --- a/acouchbase_analytics/protocol/_core/request_context.py +++ b/acouchbase_analytics/protocol/_core/request_context.py @@ -36,20 +36,20 @@ from couchbase_analytics.common.errors import AnalyticsError from couchbase_analytics.common.logging import LogLevel from couchbase_analytics.common.request import RequestState +from couchbase_analytics.protocol._core.request import FetchResultsRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import DEFAULT_TIMEOUTS from couchbase_analytics.protocol.errors import ErrorMapper if TYPE_CHECKING: from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter - from couchbase_analytics.protocol._core.request import QueryRequest class AsyncRequestContext: def __init__( self, client_adapter: _AsyncClientAdapter, - request: QueryRequest, - stream_config: Optional[JsonStreamConfig] = None, + request: Union[FetchResultsRequest, HttpRequest, QueryRequest, StartQueryRequest], + supports_cancellation: Optional[bool] = None, backend: Optional[AsyncBackend] = None, ) -> None: self._id = str(uuid4()) @@ -57,11 +57,9 @@ def __init__( self._request = request self._backend = backend or current_async_library() self._backoff_calc = DefaultBackoffCalculator() - self._error_ctx = ErrorContext(num_attempts=0, method=request.method, statement=request.get_request_statement()) + self._error_context = ErrorContext(num_attempts=0, method=request.method) self._request_state = RequestState.NotStarted - 
self._stream_config = stream_config or JsonStreamConfig() - self._json_stream: AsyncJsonStream - self._stage_completed: Optional[anyio.Event] = None + self._supports_cancellation = False if supports_cancellation is None else supports_cancellation self._request_error: Optional[Union[BaseException, Exception]] = None connect_timeout = self._client_adapter.connection_details.get_connect_timeout() self._connect_deadline = get_time() + connect_timeout @@ -71,31 +69,19 @@ def __init__( @property def cancelled(self) -> bool: + if not self._supports_cancellation: + return False self._check_timed_out() return self._request_state in [RequestState.Cancelled, RequestState.AsyncCancelledPriorToTimeout] @property def error_context(self) -> ErrorContext: - return self._error_ctx - - @property - def has_stage_completed(self) -> bool: - return self._stage_completed is not None and self._stage_completed.is_set() + return self._error_context @property def is_shutdown(self) -> bool: return self._shutdown - @property - def okay_to_iterate(self) -> bool: - self._check_timed_out() - return RequestState.okay_to_iterate(self._request_state) - - @property - def okay_to_stream(self) -> bool: - self._check_timed_out() - return RequestState.okay_to_stream(self._request_state) - @property def request_error(self) -> Optional[Union[BaseException, Exception]]: return self._request_error @@ -106,139 +92,15 @@ def request_state(self) -> RequestState: @property def retry_limit_exceeded(self) -> bool: - return self.error_context.num_attempts > self._request.max_retries - - @property - def results_or_errors_type(self) -> ParsedResultType: - return self._json_stream.results_or_errors_type + return self._error_context.num_attempts > self._request.max_retries @property def timed_out(self) -> bool: self._check_timed_out() return self._request_state == RequestState.Timeout - def _check_timed_out(self) -> None: - if self._request_state in [RequestState.Timeout, RequestState.Cancelled, RequestState.Error]: - 
return - - if hasattr(self, '_request_deadline') is False: - return - - current_time = get_time() - timed_out = current_time >= self._request_deadline - if timed_out: - message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} - self.log_message('Request has timed out', LogLevel.DEBUG, message_data=message_data) - if self._request_state == RequestState.Cancelled: - self._request_state = RequestState.AsyncCancelledPriorToTimeout - else: - self._request_state = RequestState.Timeout - - async def _execute(self, fn: Callable[..., Awaitable[Any]], *args: object) -> None: - await fn(*args) - if self._stage_completed is not None: - self._stage_completed.set() - - def _maybe_set_request_error( - self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None - ) -> None: - self._check_timed_out() - if exc_val is None: - return - if not RequestState.is_timeout_or_cancelled(self._request_state): - # This handles httpx timeouts - if exc_type is not None and issubclass(exc_type, TimeoutException): - self._request_state = RequestState.Timeout - elif issubclass(type(exc_val), TimeoutException): - self._request_state = RequestState.Timeout - elif isinstance(exc_val, CancelledError): - self._request_state = RequestState.Cancelled - else: - self._request_state = RequestState.Error - self._request_error = exc_val - - async def _process_error( - self, json_data: Union[str, List[Dict[str, Any]]], handle_context_shutdown: Optional[bool] = False - ) -> None: - self._request_state = RequestState.Error - if isinstance(json_data, str): - self._request_error = ErrorMapper.build_error_from_http_status_code(json_data, self._error_ctx) - elif not isinstance(json_data, list): - self._request_error = AnalyticsError( - 'Cannot parse error response; expected JSON array', context=str(self._error_ctx) - ) - else: - self._request_error = ErrorMapper.build_error_from_json(json_data, self._error_ctx) - if handle_context_shutdown is 
True: - await self.reraise_after_shutdown(self._request_error) - - raise self._request_error - - def _reset_stream(self) -> None: - if hasattr(self, '_json_stream'): - del self._json_stream - self._request_state = RequestState.ResetAndNotStarted - self._stage_completed = None - self._cancel_scope_deadline_updated = False - - def _start_next_stage( - self, fn: Callable[..., Awaitable[Any]], *args: object, reset_previous_stage: Optional[bool] = False - ) -> None: - if self._stage_completed is not None: - if reset_previous_stage is True: - self._stage_completed = None - else: - raise RuntimeError('Task already running in this context.') - - self._stage_completed = anyio.Event() - self._taskgroup.start_soon(self._execute, fn, *args) - - async def _trace_handler(self, event_name: str, _: str) -> None: - if event_name == 'connection.connect_tcp.complete': - # after connection is established, we need to update the cancel_scope deadline to match the query_timeout - self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True) - self._cancel_scope_deadline_updated = True - elif self._cancel_scope_deadline_updated is False and event_name.endswith('send_request_headers.started'): - # if the socket is reused, we won't get the connect_tcp.complete event, - # so the deadline at the next closest event - self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True) - self._cancel_scope_deadline_updated = True - - def _update_cancel_scope_deadline(self, deadline: float, is_absolute: Optional[bool] = False) -> None: - new_deadline = deadline if is_absolute else get_time() + deadline - current_time = get_time() - if current_time >= new_deadline: - self.log_message( - 'Deadline already exceeded, cancelling request', - LogLevel.DEBUG, - message_data={ - 'current_time': f'{current_time}', - 'new_deadline': f'{new_deadline}', - }, - ) - self._taskgroup.cancel_scope.cancel() - else: - self.log_message( - f'Updating cancel scope deadline: 
{self._taskgroup.cancel_scope.deadline} -> {new_deadline}', - LogLevel.DEBUG, - ) - self._taskgroup.cancel_scope.deadline = new_deadline - - async def _wait_for_stage_to_complete(self) -> None: - if self._stage_completed is None: - return - await self._stage_completed.wait() - def calculate_backoff(self) -> float: - return self._backoff_calc.calculate_backoff(self._error_ctx.num_attempts) / 1000 - - def cancel_request(self, fn: Optional[Callable[..., Awaitable[Any]]] = None, *args: object) -> None: - if fn is not None: - self._taskgroup.start_soon(fn, *args) - if self._request_state == RequestState.Timeout: - return - self._taskgroup.cancel_scope.cancel() - self._request_state = RequestState.Cancelled + return self._backoff_calc.calculate_backoff(self._error_context.num_attempts) / 1000 def create_response_task(self, fn: Callable[..., Coroutine[Any, Any, Any]], *args: object) -> Task[Any]: if self._backend is None or self._backend.backend_lib != 'asyncio': @@ -250,20 +112,6 @@ def create_response_task(self, fn: Callable[..., Coroutine[Any, Any, Any]], *arg self._response_task = task return task - def deserialize_result(self, result: bytes) -> Any: - return self._request.deserializer.deserialize(result) - - async def finish_processing_stream(self) -> None: - if not self.has_stage_completed: - await self._wait_for_stage_to_complete() - - while not self._json_stream.token_stream_exhausted: - self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) - await self._wait_for_stage_to_complete() - - async def get_result_from_stream(self) -> ParsedResult: - return await self._json_stream.get_result() - async def initialize(self) -> None: if self._request_state == RequestState.ResetAndNotStarted: current_time = get_time() @@ -297,18 +145,11 @@ def log_message( message = f'{message}, {message_data_str}' self._client_adapter.log_message(message, log_level) - def maybe_continue_to_process_stream(self) -> None: - if not self.has_stage_completed: - 
return - - if self._json_stream.token_stream_exhausted: - return - - self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) - def okay_to_delay_and_retry(self, delay: float) -> bool: - self._check_timed_out() - if self._request_state in [RequestState.Timeout, RequestState.Cancelled]: + # calling self.timed_out will call _check_timed_out, so we don't need to call it again + if self.timed_out: + return False + elif self._supports_cancellation and self._request_state == RequestState.Cancelled: return False current_time = get_time() @@ -331,37 +172,28 @@ def okay_to_delay_and_retry(self, delay: float) -> bool: } self.log_message('Request has exceeded max retries', LogLevel.DEBUG, message_data=message_data) return False - else: - self._reset_stream() - return True + elif self._supports_cancellation: + # _reset_stream() _should_ exist, but surround w/ try/except just in case + try: + self._reset_stream() # type: ignore[attr-defined] + except AttributeError: + pass # nosec + + return True async def process_response( self, - close_handler: Callable[[], Coroutine[Any, Any, None]], - raw_response: Optional[ParsedResult] = None, + core_response: HttpCoreResponse, + close_handler: Callable[[], Awaitable[None]], handle_context_shutdown: Optional[bool] = False, ) -> Any: - if raw_response is None: - raw_response = await self._json_stream.get_result() - if raw_response is None: - await close_handler() - raise AnalyticsError( - message='Received unexpected empty result from JsonStream.', context=str(self._error_ctx) - ) - - if raw_response.value is None: - await close_handler() - raise AnalyticsError( - message='Received unexpected empty result from JsonStream.', context=str(self._error_ctx) - ) - # we have all the data, close the core response/stream await close_handler() try: - json_response = json.loads(raw_response.value) + json_response = core_response.json() except json.JSONDecodeError: - await self._process_error(str(raw_response.value), 
handle_context_shutdown=handle_context_shutdown) + await self._process_error(str(core_response.text), handle_context_shutdown=handle_context_shutdown) else: if 'errors' in json_response: await self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown) @@ -375,29 +207,35 @@ async def reraise_after_shutdown(self, err: Exception) -> None: raise ex from None async def send_request(self, enable_trace_handling: Optional[bool] = False) -> HttpCoreResponse: - self._error_ctx.update_num_attempts() + self._error_context.update_num_attempts() ip = await get_request_ip_async(self._request.url.host, self._request.url.port, self.log_message) - if enable_trace_handling is True: - ( - self._request.update_url(ip, self._client_adapter.analytics_path).add_trace_to_extensions( - self._trace_handler - ) - ) + if self._request.path and not self._request.path.isspace(): + req_path = f'{self._request.path}' + else: + req_path = self._client_adapter.analytics_path + if enable_trace_handling is True and hasattr(self, '_trace_handler'): + self._request.update_url(ip, req_path).add_trace_to_extensions(self._trace_handler) else: - self._request.update_url(ip, self._client_adapter.analytics_path) - self._error_ctx.update_request_context(self._request) + self._request.update_url(ip, req_path) + self._error_context.update_request_context(self._request, path=self._request.path) message_data = { 'url': f'{self._request.url.get_formatted_url()}', - 'body': f'{self._request.body}', 'request_deadline': f'{self._request_deadline}', } + + if isinstance(self._request, (QueryRequest, StartQueryRequest)): + message_data['body'] = f'{self._request.body}' + + stream = hasattr(self._request, 'should_stream') and self._request.should_stream is True + message_data['streaming'] = str(stream) + self.log_message('HTTP request', LogLevel.DEBUG, message_data=message_data) - response = await self._client_adapter.send_request(self._request) - 
self._error_ctx.update_response_context(response) + response = await self._client_adapter.send_request(self._request, stream=stream) + self._error_context.update_response_context(response) message_data = { 'status_code': f'{response.status_code}', - 'last_dispatched_to': f'{self._error_ctx.last_dispatched_to}', - 'last_dispatched_from': f'{self._error_ctx.last_dispatched_from}', + 'last_dispatched_to': f'{self._error_context.last_dispatched_to}', + 'last_dispatched_from': f'{self._error_context.last_dispatched_from}', 'request_deadline': f'{self._request_deadline}', } self.log_message('HTTP response', LogLevel.DEBUG, message_data=message_data) @@ -422,21 +260,60 @@ async def shutdown( self._shutdown = True self.log_message('Request context shutdown complete', LogLevel.INFO) - def start_stream(self, core_response: HttpCoreResponse) -> None: - if hasattr(self, '_json_stream'): - self.log_message('JSON stream already exists', LogLevel.WARNING) + def _check_timed_out(self) -> None: + if self._request_state in (RequestState.Timeout, RequestState.Error): return - self._json_stream = AsyncJsonStream( - core_response.aiter_bytes(), stream_config=self._stream_config, logger_handler=self.log_message - ) - self._start_next_stage(self._json_stream.start_parsing) + if self._supports_cancellation and self._request_state == RequestState.Cancelled: + return - async def wait_for_results_or_errors(self) -> None: - await self._json_stream.has_results_or_errors.wait() - if self._json_stream.results_or_errors_type == ParsedResultType.ROW: - # we move to iterating rows - self._request_state = RequestState.StreamingResults + if hasattr(self, '_request_deadline') is False: + return + + current_time = get_time() + timed_out = current_time >= self._request_deadline + if timed_out: + message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} + self.log_message('Request has timed out', LogLevel.DEBUG, message_data=message_data) + if 
self._supports_cancellation and self._request_state == RequestState.Cancelled: + self._request_state = RequestState.AsyncCancelledPriorToTimeout + else: + self._request_state = RequestState.Timeout + + def _maybe_set_request_error( + self, exc_type: Optional[Type[BaseException]] = None, exc_val: Optional[BaseException] = None + ) -> None: + self._check_timed_out() + if exc_val is None: + return + if not RequestState.is_timeout_or_cancelled(self._request_state): + # This handles httpx timeouts + if exc_type is not None and issubclass(exc_type, TimeoutException): + self._request_state = RequestState.Timeout + elif issubclass(type(exc_val), TimeoutException): + self._request_state = RequestState.Timeout + elif isinstance(exc_val, CancelledError): + self._request_state = RequestState.Cancelled + else: + self._request_state = RequestState.Error + self._request_error = exc_val + + async def _process_error( + self, json_data: Union[str, List[Dict[str, Any]]], handle_context_shutdown: Optional[bool] = False + ) -> None: + self._request_state = RequestState.Error + if isinstance(json_data, str): + self._request_error = ErrorMapper.build_error_from_http_status_code(json_data, self._error_context) + elif not isinstance(json_data, list): + self._request_error = AnalyticsError( + 'Cannot parse error response; expected JSON array', context=str(self._error_context) + ) + else: + self._request_error = ErrorMapper.build_error_from_json(json_data, self._error_context) + if handle_context_shutdown is True: + await self.reraise_after_shutdown(self._request_error) + + raise self._request_error async def __aenter__(self) -> AsyncRequestContext: self._taskgroup = anyio.create_task_group() @@ -457,3 +334,179 @@ async def __aexit__( self._maybe_set_request_error(exc_type, exc_val) del self._taskgroup return None # noqa: B012 + + +class AsyncStreamingRequestContext(AsyncRequestContext): + def __init__( + self, + client_adapter: _AsyncClientAdapter, + request: Union[FetchResultsRequest, 
QueryRequest], + stream_config: Optional[JsonStreamConfig] = None, + backend: Optional[AsyncBackend] = None, + ) -> None: + super().__init__(client_adapter, request, supports_cancellation=True, backend=backend) + if isinstance(request, QueryRequest): + self._error_context.set_statement(request.get_request_statement()) + self._stream_config = stream_config or JsonStreamConfig() + self._json_stream: AsyncJsonStream + self._stage_completed: Optional[anyio.Event] = None + self._deserializer = request.deserializer + + @property + def has_stage_completed(self) -> bool: + return self._stage_completed is not None and self._stage_completed.is_set() + + @property + def okay_to_iterate(self) -> bool: + self._check_timed_out() + return RequestState.okay_to_iterate(self._request_state) + + @property + def okay_to_stream(self) -> bool: + self._check_timed_out() + return RequestState.okay_to_stream(self._request_state) + + @property + def results_or_errors_type(self) -> ParsedResultType: + return self._json_stream.results_or_errors_type + + def cancel_request(self, fn: Optional[Callable[..., Awaitable[Any]]] = None, *args: object) -> None: + if fn is not None: + self._taskgroup.start_soon(fn, *args) + if self._request_state == RequestState.Timeout: + return + self._taskgroup.cancel_scope.cancel() + self._request_state = RequestState.Cancelled + + def deserialize_result(self, result: bytes) -> Any: + if not self._deserializer: + raise RuntimeError('No deserializer found for this request context.') + return self._deserializer.deserialize(result) + + async def finish_processing_stream(self) -> None: + if not self.has_stage_completed: + await self._wait_for_stage_to_complete() + + while not self._json_stream.token_stream_exhausted: + self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) + await self._wait_for_stage_to_complete() + + async def get_result_from_stream(self) -> ParsedResult: + return await self._json_stream.get_result() + + def 
maybe_continue_to_process_stream(self) -> None: + if not self.has_stage_completed: + return + + if self._json_stream.token_stream_exhausted: + return + + self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) + + async def process_streaming_response( + self, + close_handler: Callable[[], Awaitable[None]], + raw_response: Optional[ParsedResult] = None, + handle_context_shutdown: Optional[bool] = False, + ) -> Any: + if raw_response is None: + raw_response = await self._json_stream.get_result() + if raw_response is None: + await close_handler() + raise AnalyticsError( + message='Received unexpected empty result from JsonStream.', context=str(self._error_context) + ) + + if raw_response.value is None: + await close_handler() + raise AnalyticsError( + message='Received unexpected empty result from JsonStream.', context=str(self._error_context) + ) + + # we have all the data, close the core response/stream + await close_handler() + + try: + json_response = json.loads(raw_response.value) + except json.JSONDecodeError: + await self._process_error(str(raw_response.value), handle_context_shutdown=handle_context_shutdown) + else: + if 'errors' in json_response: + await self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown) + return json_response + + def start_stream(self, core_response: HttpCoreResponse) -> None: + if hasattr(self, '_json_stream'): + self.log_message('JSON stream already exists', LogLevel.WARNING) + return + + self._json_stream = AsyncJsonStream( + core_response.aiter_bytes(), stream_config=self._stream_config, logger_handler=self.log_message + ) + self._start_next_stage(self._json_stream.start_parsing) + + async def wait_for_results_or_errors(self) -> None: + await self._json_stream.has_results_or_errors.wait() + if self._json_stream.results_or_errors_type == ParsedResultType.ROW: + # we move to iterating rows + self._request_state = RequestState.StreamingResults + + async def 
_execute(self, fn: Callable[..., Awaitable[Any]], *args: object) -> None: + await fn(*args) + if self._stage_completed is not None: + self._stage_completed.set() + + def _reset_stream(self) -> None: + if hasattr(self, '_json_stream'): + del self._json_stream + self._request_state = RequestState.ResetAndNotStarted + self._stage_completed = None + self._cancel_scope_deadline_updated = False + + def _start_next_stage( + self, fn: Callable[..., Awaitable[Any]], *args: object, reset_previous_stage: Optional[bool] = False + ) -> None: + if self._stage_completed is not None: + if reset_previous_stage is True: + self._stage_completed = None + else: + raise RuntimeError('Task already running in this context.') + + self._stage_completed = anyio.Event() + self._taskgroup.start_soon(self._execute, fn, *args) + + async def _trace_handler(self, event_name: str, _: str) -> None: + if event_name == 'connection.connect_tcp.complete': + # after connection is established, we need to update the cancel_scope deadline to match the query_timeout + self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True) + self._cancel_scope_deadline_updated = True + elif self._cancel_scope_deadline_updated is False and event_name.endswith('send_request_headers.started'): + # if the socket is reused, we won't get the connect_tcp.complete event, + # so the deadline at the next closest event + self._update_cancel_scope_deadline(self._request_deadline, is_absolute=True) + self._cancel_scope_deadline_updated = True + + def _update_cancel_scope_deadline(self, deadline: float, is_absolute: Optional[bool] = False) -> None: + new_deadline = deadline if is_absolute else get_time() + deadline + current_time = get_time() + if current_time >= new_deadline: + self.log_message( + 'Deadline already exceeded, cancelling request', + LogLevel.DEBUG, + message_data={ + 'current_time': f'{current_time}', + 'new_deadline': f'{new_deadline}', + }, + ) + self._taskgroup.cancel_scope.cancel() + else: + 
self.log_message( + f'Updating cancel scope deadline: {self._taskgroup.cancel_scope.deadline} -> {new_deadline}', + LogLevel.DEBUG, + ) + self._taskgroup.cancel_scope.deadline = new_deadline + + async def _wait_for_stage_to_complete(self) -> None: + if self._stage_completed is None: + return + await self._stage_completed.wait() diff --git a/acouchbase_analytics/protocol/_core/response.py b/acouchbase_analytics/protocol/_core/response.py new file mode 100644 index 0000000..2d54df1 --- /dev/null +++ b/acouchbase_analytics/protocol/_core/response.py @@ -0,0 +1,115 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +from typing import Any, Optional + +from httpx import Response as HttpCoreResponse + +from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext +from acouchbase_analytics.protocol._core.retries import AsyncRetryHandler +from couchbase_analytics.common._core.query import build_query_metadata +from couchbase_analytics.common.errors import AnalyticsError, InternalSDKError +from couchbase_analytics.common.logging import LogLevel +from couchbase_analytics.common.query import QueryMetadata + + +class AsyncHttpResponse: + def __init__( + self, + request_context: AsyncRequestContext, + skip_process_response: Optional[bool] = None, + request_id: Optional[str] = None, + ) -> None: + # Goal is to treat the AsyncHttpStreamingResponse as a "task group" + self._request_context = request_context + self._metadata: Optional[QueryMetadata] = None + self._core_response: HttpCoreResponse + self._json_response: Optional[Any] = None + self._skip_process_response = skip_process_response + self._request_id = request_id + + @property + def json_response(self) -> Optional[Any]: + return self._json_response + + async def close(self) -> None: + """ + **INTERNAL** + """ + if hasattr(self, '_core_response'): + await self._core_response.aclose() + self._request_context.log_message('HTTP core response closed', LogLevel.INFO) + del self._core_response + + def get_metadata(self) -> QueryMetadata: + """ + **INTERNAL** + """ + if self._metadata is None: + raise RuntimeError('Query metadata is only available after all rows have been iterated.') + return self._metadata + + async def set_metadata(self, json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None) -> None: + """ + **INTERNAL** + """ + try: + self._metadata = QueryMetadata( + build_query_metadata(json_data=json_data, raw_metadata=raw_metadata, request_id=self._request_id) + ) + await self._request_context.shutdown() + except (AnalyticsError, ValueError) as err: 
+ await self._request_context.reraise_after_shutdown(err) + except Exception as ex: + internal_err = InternalSDKError(cause=ex, message=str(ex), context=str(self._request_context.error_context)) + await self._request_context.reraise_after_shutdown(internal_err) + finally: + await self.close() + + @AsyncRetryHandler.with_retries + async def send_request(self) -> None: + """ + **INTERNAL** + """ + await self._request_context.initialize() + self._core_response = await self._request_context.send_request() + if self._skip_process_response is True: + return + await self._process_response() + + async def shutdown(self) -> None: + """ + **INTERNAL** + """ + await self.close() + await self._request_context.shutdown() + + async def _close_in_background(self) -> None: + """ + **INTERNAL** + """ + await self.close() + + async def _process_response(self) -> None: + """ + **INTERNAL** + """ + self._json_response = await self._request_context.process_response( + self._core_response, self.close, handle_context_shutdown=True + ) + await self.set_metadata(json_data=self._json_response) diff --git a/acouchbase_analytics/protocol/_core/retries.py b/acouchbase_analytics/protocol/_core/retries.py index 83423d1..4570225 100644 --- a/acouchbase_analytics/protocol/_core/retries.py +++ b/acouchbase_analytics/protocol/_core/retries.py @@ -18,7 +18,7 @@ from asyncio import CancelledError from functools import wraps -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Optional, TypeVar, Union from httpx import ConnectError, ConnectTimeout, CookieConflict, HTTPError, InvalidURL, ReadTimeout, StreamError @@ -29,19 +29,22 @@ from couchbase_analytics.protocol.errors import WrappedError if TYPE_CHECKING: - from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext + from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext, AsyncStreamingRequestContext + from 
acouchbase_analytics.protocol._core.response import AsyncHttpResponse from acouchbase_analytics.protocol.streaming import AsyncHttpStreamingResponse +AsyncReqContext = Union['AsyncRequestContext', 'AsyncStreamingRequestContext'] +T = TypeVar('T', bound=Union['AsyncHttpResponse', 'AsyncHttpStreamingResponse']) + + class AsyncRetryHandler: """ **INTERNAL** """ @staticmethod - async def handle_httpx_retry( - ex: Union[ConnectError, ConnectTimeout], ctx: AsyncRequestContext - ) -> Optional[Exception]: + async def handle_httpx_retry(ex: Union[ConnectError, ConnectTimeout], ctx: AsyncReqContext) -> Optional[Exception]: err_str = str(ex) if 'SSL:' in err_str: message = 'TLS connection error occurred.' @@ -64,7 +67,7 @@ async def handle_httpx_retry( return None @staticmethod - async def handle_retry(ex: WrappedError, ctx: AsyncRequestContext) -> Optional[Union[BaseException, Exception]]: + async def handle_retry(ex: WrappedError, ctx: AsyncReqContext) -> Optional[Union[BaseException, Exception]]: if ex.retriable is True: delay = ctx.calculate_backoff() err: Optional[Union[BaseException, Exception]] = None @@ -94,10 +97,10 @@ async def handle_retry(ex: WrappedError, ctx: AsyncRequestContext) -> Optional[U @staticmethod def with_retries( # noqa: C901 - fn: Callable[[AsyncHttpStreamingResponse], Coroutine[Any, Any, None]], - ) -> Callable[[AsyncHttpStreamingResponse], Coroutine[Any, Any, None]]: + fn: Callable[[T], Coroutine[Any, Any, None]], + ) -> Callable[[T], Coroutine[Any, Any, None]]: @wraps(fn) - async def wrapped_fn(self: AsyncHttpStreamingResponse) -> None: # noqa: C901 + async def wrapped_fn(self: T) -> None: # noqa: C901 while True: try: await fn(self) diff --git a/acouchbase_analytics/protocol/cluster.py b/acouchbase_analytics/protocol/cluster.py index f08519f..c28661b 100644 --- a/acouchbase_analytics/protocol/cluster.py +++ b/acouchbase_analytics/protocol/cluster.py @@ -27,15 +27,18 @@ from acouchbase_analytics.protocol._core.anyio_utils import 
current_async_library from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter -from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext +from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext, AsyncStreamingRequestContext +from acouchbase_analytics.protocol._core.response import AsyncHttpResponse +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle from acouchbase_analytics.protocol.streaming import AsyncHttpStreamingResponse from couchbase_analytics.common.logging import LogLevel from couchbase_analytics.common.result import AsyncQueryResult from couchbase_analytics.protocol._core.request import _RequestBuilder if TYPE_CHECKING: + from acouchbase_analytics.options import ClusterOptions + from couchbase_analytics.common._core import JsonStreamConfig from couchbase_analytics.common.credential import Credential - from couchbase_analytics.options import ClusterOptions class AsyncCluster: @@ -106,16 +109,36 @@ async def _execute_query(self, http_resp: AsyncHttpStreamingResponse) -> AsyncQu return AsyncQueryResult(http_resp) def execute_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryResult]: - base_req = self._request_builder.build_base_query_request(statement, *args, is_async=True, **kwargs) - stream_config = base_req.options.pop('stream_config', None) - request_context = AsyncRequestContext( - client_adapter=self.client_adapter, request=base_req, stream_config=stream_config, backend=self._backend + req = self._request_builder.build_query_request(statement, *args, **kwargs) + stream_config = req.options.pop('stream_config', None) + request_context = AsyncStreamingRequestContext( + self.client_adapter, req, stream_config=stream_config, backend=self._backend ) resp = AsyncHttpStreamingResponse(request_context) if self._backend.backend_lib == 'asyncio': return request_context.create_response_task(self._execute_query, resp) return 
self._execute_query(resp) + async def _start_query( + self, http_resp: AsyncHttpResponse, stream_config: Optional[JsonStreamConfig] + ) -> AsyncQueryHandle: + if not self.has_client: + self.client_adapter.log_message( + 'Cluster does not have a connection. Creating the client.', LogLevel.WARNING + ) + await self._create_client() + await http_resp.send_request() + return AsyncQueryHandle(self._client_adapter, self._request_builder, http_resp, stream_config=stream_config) + + def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryHandle]: + req = self._request_builder.build_start_query_request(statement, *args, **kwargs) + stream_config = req.options.pop('stream_config', None) + request_context = AsyncRequestContext(self.client_adapter, req, backend=self._backend) + resp = AsyncHttpResponse(request_context) + if self._backend.backend_lib == 'asyncio': + return request_context.create_response_task(self._start_query, resp, stream_config) + return self._start_query(resp, stream_config) + @classmethod def create_instance( cls, endpoint: str, credential: Credential, options: Optional[ClusterOptions] = None, **kwargs: object diff --git a/acouchbase_analytics/protocol/cluster.pyi b/acouchbase_analytics/protocol/cluster.pyi index 87cb2a1..d9757e6 100644 --- a/acouchbase_analytics/protocol/cluster.pyi +++ b/acouchbase_analytics/protocol/cluster.pyi @@ -21,11 +21,20 @@ if sys.version_info < (3, 11): else: from typing import Unpack +from acouchbase_analytics import JSONType +from acouchbase_analytics.credential import Credential +from acouchbase_analytics.options import ( + ClusterOptions, + ClusterOptionsKwargs, + QueryOptions, + QueryOptionsKwargs, + StartQueryOptions, + StartQueryOptionsKwargs, +) from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter from acouchbase_analytics.protocol.database import AsyncDatabase -from couchbase_analytics.common.credential import Credential -from 
couchbase_analytics.common.result import AsyncQueryResult -from couchbase_analytics.options import ClusterOptions, ClusterOptionsKwargs, QueryOptions, QueryOptionsKwargs +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult class AsyncCluster: @overload @@ -56,14 +65,34 @@ class AsyncCluster: ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: Unpack[QueryOptionsKwargs] + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: Unpack[QueryOptionsKwargs] ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: str + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: str ) -> Awaitable[AsyncQueryResult]: ... @overload - def execute_query(self, statement: str, *args: str, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + def execute_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + @overload + def start_query(self, statement: str) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> AsyncQueryHandle: ... 
@overload @classmethod def create_instance(cls, endpoint: str, credential: Credential) -> AsyncCluster: ... diff --git a/acouchbase_analytics/protocol/query_handle.py b/acouchbase_analytics/protocol/query_handle.py new file mode 100644 index 0000000..1459b70 --- /dev/null +++ b/acouchbase_analytics/protocol/query_handle.py @@ -0,0 +1,142 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter +from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext, AsyncStreamingRequestContext +from acouchbase_analytics.protocol._core.response import AsyncHttpResponse +from acouchbase_analytics.protocol.streaming import AsyncHttpStreamingResponse +from couchbase_analytics.common._core.query_handle import QueryHandleStatusResponse +from couchbase_analytics.common.errors import AnalyticsError +from couchbase_analytics.common.query_handle import AsyncQueryHandle as _CoreAsyncQueryHandle +from couchbase_analytics.common.query_handle import AsyncQueryResultHandle as _CoreAsyncQueryResultHandle +from couchbase_analytics.common.result import AsyncQueryResult +from couchbase_analytics.protocol._core.request import _RequestBuilder + +if TYPE_CHECKING: + from couchbase_analytics.common._core import JsonStreamConfig + + +class 
AsyncQueryHandle(_CoreAsyncQueryHandle): + def __init__( + self, + client_adapter: _AsyncClientAdapter, + request_builder: _RequestBuilder, + http_response: AsyncHttpResponse, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._http_response = http_response + self._stream_config = stream_config + self._request_id: str = '' + self._handle: str = '' + self._get_status_handle() + + @property + def handle(self) -> str: + return self._handle + + @property + def request_id(self) -> str: + return self._request_id + + async def fetch_result_handle(self) -> Optional[AsyncQueryResultHandle]: + server_req = self._request_builder.build_request_from_handle(self._handle) + request_context = AsyncRequestContext(self._client_adapter, server_req) + resp = AsyncHttpResponse(request_context) + await resp.send_request() + if resp.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + if 'handle' not in resp.json_response: + return None + + status_response = self._get_handle_status_response(resp) + return AsyncQueryResultHandle( + self._client_adapter, self._request_builder, status_response, stream_config=self._stream_config + ) + + async def cancel(self) -> None: + cancel_req = self._request_builder.build_cancel_request(self._request_id) + request_context = AsyncRequestContext(self._client_adapter, cancel_req) + resp = AsyncHttpResponse(request_context, skip_process_response=True, request_id=self._request_id) + await resp.send_request() + + def _get_status_handle(self) -> None: + if self._http_response.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + request_id = self._http_response.json_response.get('requestID', None) + if request_id is None: + raise AnalyticsError(message='Server response is missing "requestID" field.') + handle = 
self._http_response.json_response.get('handle', None) + if handle is None: + raise AnalyticsError(message='Server response is missing "handle" field.') + + self._request_id = request_id + self._handle = handle + # expected -> "handle": "/api/v1/request/status/062fd2f0-4b48-45f8-b494-d458b9a751e0/2-0" + # handle_tokens = self._handle.split(request_id) + # self._handle_id = handle_tokens[-1].replace('/', '') + + def _get_handle_status_response(self, resp: AsyncHttpResponse) -> QueryHandleStatusResponse: + if resp.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + handle = resp.json_response.get('handle', None) + if handle is None: + raise AnalyticsError(message='Server response is missing "handle" field.') + + return QueryHandleStatusResponse.from_server(self._request_id, resp.json_response) + + +class AsyncQueryResultHandle(_CoreAsyncQueryResultHandle): + def __init__( + self, + client_adapter: _AsyncClientAdapter, + request_builder: _RequestBuilder, + status_resp: QueryHandleStatusResponse, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._status_resp = status_resp + self._stream_config = stream_config + + @property + def request_id(self) -> str: + return self._status_resp.request_id + + async def fetch_results(self) -> AsyncQueryResult: + server_req = self._request_builder.build_fetch_results_request(self._status_resp.handle) + request_context = AsyncStreamingRequestContext( + self._client_adapter, server_req, stream_config=self._stream_config + ) + resp = AsyncHttpStreamingResponse(request_context, request_id=self._status_resp.request_id) + await resp.send_request() + return AsyncQueryResult(resp) + + async def discard_results(self) -> None: + req = self._request_builder.build_discard_results_request(self._status_resp.handle) + request_context = AsyncRequestContext(self._client_adapter, 
req) + resp = AsyncHttpResponse(request_context, skip_process_response=True, request_id=self._status_resp.request_id) + await resp.send_request() diff --git a/acouchbase_analytics/protocol/scope.py b/acouchbase_analytics/protocol/scope.py index cd97f5f..c72aa4a 100644 --- a/acouchbase_analytics/protocol/scope.py +++ b/acouchbase_analytics/protocol/scope.py @@ -17,7 +17,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Awaitable +from typing import TYPE_CHECKING, Awaitable, Optional if sys.version_info < (3, 10): from typing_extensions import TypeAlias @@ -26,14 +26,17 @@ from acouchbase_analytics.protocol._core.anyio_utils import current_async_library from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter -from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext +from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext, AsyncStreamingRequestContext +from acouchbase_analytics.protocol._core.response import AsyncHttpResponse +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle from acouchbase_analytics.protocol.streaming import AsyncHttpStreamingResponse +from acouchbase_analytics.result import AsyncQueryResult from couchbase_analytics.common.logging import LogLevel -from couchbase_analytics.common.result import AsyncQueryResult from couchbase_analytics.protocol._core.request import _RequestBuilder if TYPE_CHECKING: from acouchbase_analytics.protocol.database import AsyncDatabase + from couchbase_analytics.common._core import JsonStreamConfig class AsyncScope: @@ -73,15 +76,35 @@ async def _execute_query(self, http_resp: AsyncHttpStreamingResponse) -> AsyncQu return AsyncQueryResult(http_resp) def execute_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryResult]: - base_req = self._request_builder.build_base_query_request(statement, *args, is_async=True, **kwargs) - stream_config = 
base_req.options.pop('stream_config', None) - request_context = AsyncRequestContext( - client_adapter=self.client_adapter, request=base_req, stream_config=stream_config, backend=self._backend + req = self._request_builder.build_query_request(statement, *args, **kwargs) + stream_config = req.options.pop('stream_config', None) + request_context = AsyncStreamingRequestContext( + client_adapter=self.client_adapter, request=req, stream_config=stream_config, backend=self._backend ) resp = AsyncHttpStreamingResponse(request_context) if self._backend.backend_lib == 'asyncio': return request_context.create_response_task(self._execute_query, resp) return self._execute_query(resp) + async def _start_query( + self, http_resp: AsyncHttpResponse, stream_config: Optional[JsonStreamConfig] + ) -> AsyncQueryHandle: + if not self.has_client: + self.client_adapter.log_message( + 'Cluster does not have a connection. Creating the client.', LogLevel.WARNING + ) + await self._create_client() + await http_resp.send_request() + return AsyncQueryHandle(self._client_adapter, self._request_builder, http_resp, stream_config=stream_config) + + def start_query(self, statement: str, *args: object, **kwargs: object) -> Awaitable[AsyncQueryHandle]: + req = self._request_builder.build_start_query_request(statement, *args, **kwargs) + stream_config = req.options.pop('stream_config', None) + request_context = AsyncRequestContext(self.client_adapter, req, backend=self._backend) + resp = AsyncHttpResponse(request_context) + if self._backend.backend_lib == 'asyncio': + return request_context.create_response_task(self._start_query, resp, stream_config) + return self._start_query(resp, stream_config) + Scope: TypeAlias = AsyncScope diff --git a/acouchbase_analytics/protocol/scope.pyi b/acouchbase_analytics/protocol/scope.pyi index 87b1a52..74929b8 100644 --- a/acouchbase_analytics/protocol/scope.pyi +++ b/acouchbase_analytics/protocol/scope.pyi @@ -21,10 +21,12 @@ if sys.version_info < (3, 11): else: from 
typing import Unpack +from acouchbase_analytics import JSONType +from acouchbase_analytics.options import QueryOptions, QueryOptionsKwargs, StartQueryOptions, StartQueryOptionsKwargs from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter from acouchbase_analytics.protocol.database import AsyncDatabase as AsyncDatabase -from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.protocol.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult class AsyncScope: def __init__(self, database: AsyncDatabase, scope_name: str) -> None: ... @@ -52,3 +54,23 @@ class AsyncScope: ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query(self, statement: str, *args: str, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + @overload + def start_query(self, statement: str) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> AsyncQueryHandle: ... 
diff --git a/acouchbase_analytics/protocol/streaming.py b/acouchbase_analytics/protocol/streaming.py index 48d7e58..b2af332 100644 --- a/acouchbase_analytics/protocol/streaming.py +++ b/acouchbase_analytics/protocol/streaming.py @@ -20,7 +20,7 @@ from httpx import Response as HttpCoreResponse -from acouchbase_analytics.protocol._core.request_context import AsyncRequestContext +from acouchbase_analytics.protocol._core.request_context import AsyncStreamingRequestContext from acouchbase_analytics.protocol._core.retries import AsyncRetryHandler from couchbase_analytics.common._core import ParsedResult, ParsedResultType from couchbase_analytics.common._core.query import build_query_metadata @@ -30,11 +30,12 @@ class AsyncHttpStreamingResponse: - def __init__(self, request_context: AsyncRequestContext) -> None: + def __init__(self, request_context: AsyncStreamingRequestContext, request_id: Optional[str] = None) -> None: self._metadata: Optional[QueryMetadata] = None self._core_response: HttpCoreResponse # Goal is to treat the AsyncHttpStreamingResponse as a "task group" self._request_context = request_context + self._request_id = request_id async def _close_in_background(self) -> None: """ @@ -68,7 +69,7 @@ async def _process_response( """ **INTERNAL** """ - json_response = await self._request_context.process_response( + json_response = await self._request_context.process_streaming_response( self.close, raw_response=raw_response, handle_context_shutdown=handle_context_shutdown ) await self.set_metadata(json_data=json_response) @@ -111,7 +112,9 @@ async def set_metadata(self, json_data: Optional[Any] = None, raw_metadata: Opti **INTERNAL** """ try: - self._metadata = QueryMetadata(build_query_metadata(json_data=json_data, raw_metadata=raw_metadata)) + self._metadata = QueryMetadata( + build_query_metadata(json_data=json_data, raw_metadata=raw_metadata, request_id=self._request_id) + ) await self._request_context.shutdown() except (AnalyticsError, ValueError) as err: await 
self._request_context.reraise_after_shutdown(err) diff --git a/couchbase_analytics/protocol/result.py b/acouchbase_analytics/query_handle.py similarity index 73% rename from couchbase_analytics/protocol/result.py rename to acouchbase_analytics/query_handle.py index 7165b68..bb3f4cb 100644 --- a/couchbase_analytics/protocol/result.py +++ b/acouchbase_analytics/query_handle.py @@ -13,5 +13,5 @@ # See the License for the specific language governing permissions and # limitations under the License. - -from __future__ import annotations +from couchbase_analytics.common.query_handle import AsyncQueryHandle as AsyncQueryHandle # noqa: F401 +from couchbase_analytics.common.query_handle import AsyncQueryResultHandle as AsyncQueryResultHandle # noqa: F401 diff --git a/acouchbase_analytics/scope.py b/acouchbase_analytics/scope.py index 6d12ac3..936aa74 100644 --- a/acouchbase_analytics/scope.py +++ b/acouchbase_analytics/scope.py @@ -25,7 +25,7 @@ else: from typing import TypeAlias -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.result import AsyncQueryResult if TYPE_CHECKING: from acouchbase_analytics.protocol.database import AsyncDatabase diff --git a/acouchbase_analytics/scope.pyi b/acouchbase_analytics/scope.pyi index b02fa4d..906ca1f 100644 --- a/acouchbase_analytics/scope.pyi +++ b/acouchbase_analytics/scope.pyi @@ -21,9 +21,11 @@ if sys.version_info < (3, 11): else: from typing import Unpack +from acouchbase_analytics import JSONType +from acouchbase_analytics.options import QueryOptions, QueryOptionsKwargs, StartQueryOptions, StartQueryOptionsKwargs from acouchbase_analytics.protocol.database import AsyncDatabase as AsyncDatabase -from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs -from couchbase_analytics.result import AsyncQueryResult +from acouchbase_analytics.query_handle import AsyncQueryHandle +from acouchbase_analytics.result import AsyncQueryResult class AsyncScope: def __init__(self, database: 
AsyncDatabase, scope_name: str) -> None: ... @@ -41,11 +43,31 @@ class AsyncScope: ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: Unpack[QueryOptionsKwargs] + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: Unpack[QueryOptionsKwargs] ) -> Awaitable[AsyncQueryResult]: ... @overload def execute_query( - self, statement: str, options: QueryOptions, *args: str, **kwargs: str + self, statement: str, options: QueryOptions, *args: JSONType, **kwargs: str ) -> Awaitable[AsyncQueryResult]: ... @overload - def execute_query(self, statement: str, *args: str, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + def execute_query(self, statement: str, *args: JSONType, **kwargs: str) -> Awaitable[AsyncQueryResult]: ... + @overload + def start_query(self, statement: str) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> AsyncQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> AsyncQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> AsyncQueryHandle: ... 
diff --git a/acouchbase_analytics/tests/connection_t.py b/acouchbase_analytics/tests/connection_t.py index 1567cfb..bf62ee1 100644 --- a/acouchbase_analytics/tests/connection_t.py +++ b/acouchbase_analytics/tests/connection_t.py @@ -67,7 +67,7 @@ def test_connstr_options_max_retries(self) -> None: connstr = f'https://localhost?max_retries={max_retries}' client = _AsyncClientAdapter(connstr, cred) req_builder = _RequestBuilder(client) - req = req_builder.build_base_query_request('SELECT 1=1') + req = req_builder.build_query_request('SELECT 1=1') assert req.max_retries == max_retries @pytest.mark.parametrize( @@ -100,7 +100,7 @@ def test_connstr_options_timeout(self, duration: str, expected_seconds: str) -> connstr = f'https://localhost?{to_query_str(opts)}' client = _AsyncClientAdapter(connstr, cred) req_builder = _RequestBuilder(client) - req = req_builder.build_base_query_request('SELECT 1=1') + req = req_builder.build_query_request('SELECT 1=1') expected = float(expected_seconds) returned_timeout_opts = req.get_request_timeouts() assert isinstance(returned_timeout_opts, dict) diff --git a/acouchbase_analytics/tests/query_options_t.py b/acouchbase_analytics/tests/query_options_t.py index dea8e14..efd4d7b 100644 --- a/acouchbase_analytics/tests/query_options_t.py +++ b/acouchbase_analytics/tests/query_options_t.py @@ -76,7 +76,7 @@ def test_options_deserializer( deserializer = DefaultJsonDeserializer() q_opts = QueryOptions(deserializer=deserializer) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.deserializer == deserializer @@ -89,7 +89,7 @@ def test_options_deserializer_kwargs( deserializer = DefaultJsonDeserializer() kwargs: QueryOptionsKwargs = {'deserializer': deserializer} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = 
request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.deserializer == deserializer @@ -101,9 +101,9 @@ def test_options_max_retries( ) -> None: if max_retries is not None: q_opts = QueryOptions(max_retries=max_retries) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) else: - req = request_builder.build_base_query_request(query_statment) + req = request_builder.build_query_request(query_statment) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.max_retries == (max_retries if max_retries is not None else 7) @@ -115,9 +115,9 @@ def test_options_max_retries_kwargs( ) -> None: if max_retries is not None: kwargs: QueryOptionsKwargs = {'max_retries': max_retries} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) else: - req = request_builder.build_base_query_request(query_statment) + req = request_builder.build_query_request(query_statment) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.max_retries == (max_retries if max_retries is not None else 7) @@ -128,7 +128,7 @@ def test_options_named_parameters( ) -> None: params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} q_opts = QueryOptions(named_parameters=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'named_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -138,7 +138,7 @@ def test_options_named_parameters_kwargs( ) -> None: params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} kwargs: QueryOptionsKwargs = {'named_parameters': params} - req = 
request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'named_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -148,7 +148,7 @@ def test_options_positional_parameters( ) -> None: params: List[JSONType] = ['foo', 'bar', 1, False] q_opts = QueryOptions(positional_parameters=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'positional_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -158,7 +158,7 @@ def test_options_positional_parameters_kwargs( ) -> None: params: List[JSONType] = ['foo', 'bar', 1, False] kwargs: QueryOptionsKwargs = {'positional_parameters': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'positional_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -167,7 +167,7 @@ def test_options_raw(self, query_statment: str, request_builder: _RequestBuilder pos_params: List[JSONType] = ['foo', 'bar', 1, False] params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} q_opts = QueryOptions(raw=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'raw': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -178,7 +178,7 @@ def test_options_raw_kwargs( pos_params: List[JSONType] = ['foo', 'bar', 1, False] params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} kwargs: QueryOptionsKwargs = {'raw': params} - 
req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'raw': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -187,7 +187,7 @@ def test_options_readonly( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: q_opts = QueryOptions(readonly=True) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'readonly': True} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -196,7 +196,7 @@ def test_options_readonly_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: kwargs: QueryOptionsKwargs = {'readonly': True} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'readonly': True} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -207,7 +207,7 @@ def test_options_scan_consistency( from couchbase_analytics.query import QueryScanConsistency q_opts = QueryOptions(scan_consistency=QueryScanConsistency.REQUEST_PLUS) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -218,7 +218,7 @@ def test_options_scan_consistency_kwargs( from couchbase_analytics.query import QueryScanConsistency kwargs: QueryOptionsKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS} - req = request_builder.build_base_query_request(query_statment, **kwargs) + 
req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -227,7 +227,7 @@ def test_options_timeout( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: q_opts = QueryOptions(timeout=timedelta(seconds=20)) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'timeout': 20.0} assert req.options == exp_opts # NOTE: we add time to the server timeout to ensure a client side timeout @@ -238,7 +238,7 @@ def test_options_timeout_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: kwargs: QueryOptionsKwargs = {'timeout': timedelta(seconds=20)} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'timeout': 20.0} assert req.options == exp_opts # NOTE: we add time to the server timeout to ensure a client side timeout @@ -248,14 +248,14 @@ def test_options_timeout_kwargs( def test_options_timeout_must_be_positive(self, query_statment: str, request_builder: _RequestBuilder) -> None: q_opts = QueryOptions(timeout=timedelta(seconds=-1)) with pytest.raises(ValueError): - request_builder.build_base_query_request(query_statment, q_opts) + request_builder.build_query_request(query_statment, q_opts) def test_options_timeout_must_be_positive_kwargs( self, query_statment: str, request_builder: _RequestBuilder ) -> None: kwargs: QueryOptionsKwargs = {'timeout': timedelta(seconds=-1)} with pytest.raises(ValueError): - request_builder.build_base_query_request(query_statment, **kwargs) + request_builder.build_query_request(query_statment, 
**kwargs) class ClusterQueryOptionsTests(QueryOptionsTestSuite): diff --git a/couchbase_analytics/cluster.py b/couchbase_analytics/cluster.py index 8ca0784..9c133de 100644 --- a/couchbase_analytics/cluster.py +++ b/couchbase_analytics/cluster.py @@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Optional, Union from couchbase_analytics.database import Database +from couchbase_analytics.query_handle import BlockingQueryHandle from couchbase_analytics.result import BlockingQueryResult if TYPE_CHECKING: @@ -139,6 +140,22 @@ def execute_query( """ # noqa: E501 return self._impl.execute_query(statement, *args, **kwargs) + def start_query(self, statement: str, *args: object, **kwargs: object) -> BlockingQueryHandle: + """Executes a query against an Analytics cluster in async mode. + + .. seealso:: + :meth:`couchbase_analytics.Scope.start_query`: For how to execute scope-level queries. + + Args: + statement: The SQL++ statement to execute. + options (:class:`~couchbase_analytics.options.StartQueryOptions`): Optional parameters for the query operation. + **kwargs (Dict[str, Any]): keyword arguments that can be used in place or to override provided :class:`~couchbase_analytics.options.StartQueryOptions` + + Returns: + :class:`~couchbase_analytics.query_handle.BlockingQueryHandle`: An instance of a :class:`~couchbase_analytics.query_handle.BlockingQueryHandle` + """ # noqa: E501 + return self._impl.start_query(statement, *args, **kwargs) + def shutdown(self) -> None: """Shuts down this cluster instance. Cleaning up all resources associated with it. 
diff --git a/couchbase_analytics/cluster.pyi b/couchbase_analytics/cluster.pyi index 38f44ff..0b32759 100644 --- a/couchbase_analytics/cluster.pyi +++ b/couchbase_analytics/cluster.pyi @@ -25,7 +25,15 @@ else: from couchbase_analytics import JSONType from couchbase_analytics.credential import Credential from couchbase_analytics.database import Database -from couchbase_analytics.options import ClusterOptions, ClusterOptionsKwargs, QueryOptions, QueryOptionsKwargs +from couchbase_analytics.options import ( + ClusterOptions, + ClusterOptionsKwargs, + QueryOptions, + QueryOptionsKwargs, + StartQueryOptions, + StartQueryOptionsKwargs, +) +from couchbase_analytics.query_handle import BlockingQueryHandle from couchbase_analytics.result import BlockingQueryResult class Cluster: @@ -114,6 +122,26 @@ class Cluster: def execute_query( self, statement: str, *args: JSONType, enable_cancel: bool, **kwargs: str ) -> Future[BlockingQueryResult]: ... + @overload + def start_query(self, statement: str) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... def shutdown(self) -> None: ... 
@overload @classmethod diff --git a/couchbase_analytics/common/_core/error_context.py b/couchbase_analytics/common/_core/error_context.py index 1356bc0..b43c1c6 100644 --- a/couchbase_analytics/common/_core/error_context.py +++ b/couchbase_analytics/common/_core/error_context.py @@ -17,11 +17,11 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from httpx import Response as HttpCoreResponse -from couchbase_analytics.protocol._core.request import QueryRequest +from couchbase_analytics.protocol._core.request import FetchResultsRequest, HttpRequest, QueryRequest @dataclass @@ -42,6 +42,9 @@ def set_errors(self, errors: List[Dict[str, Any]]) -> None: def set_first_error(self, error: Dict[str, Any]) -> None: self.first_error = error + def set_statement(self, statement: Optional[str]) -> None: + self.statement = statement + def maybe_update_errors(self) -> None: if self.errors is not None and len(self.errors) > 0: return @@ -51,8 +54,10 @@ def maybe_update_errors(self) -> None: def update_num_attempts(self) -> None: self.num_attempts += 1 - def update_request_context(self, request: QueryRequest) -> None: - self.path = request.url.path + def update_request_context( + self, request: Union[HttpRequest, FetchResultsRequest, QueryRequest], path: Optional[str] = None + ) -> None: + self.path = path or request.url.path def update_response_context(self, response: HttpCoreResponse) -> None: network_stream = response.extensions.get('network_stream', None) diff --git a/couchbase_analytics/common/_core/query.py b/couchbase_analytics/common/_core/query.py index 93c18d7..9bb40fa 100644 --- a/couchbase_analytics/common/_core/query.py +++ b/couchbase_analytics/common/_core/query.py @@ -59,7 +59,9 @@ class QueryMetadataCore(TypedDict, total=False): status: Optional[str] -def build_query_metadata(json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None) -> 
QueryMetadataCore: +def build_query_metadata( + json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None, request_id: Optional[str] = None +) -> QueryMetadataCore: """ Builds the query metadata from the raw bytes. @@ -83,7 +85,7 @@ def build_query_metadata(json_data: Optional[Any] = None, raw_metadata: Optional warnings.append({'code': warning.get('code', 0), 'message': warning.get('msg', '')}) metadata: QueryMetadataCore = { - 'request_id': json_data.get('requestID', ''), + 'request_id': json_data.get('requestID', request_id or ''), 'client_context_id': json_data.get('clientContextID', ''), 'warnings': warnings, } diff --git a/couchbase_analytics/common/_core/query_handle.py b/couchbase_analytics/common/_core/query_handle.py new file mode 100644 index 0000000..70597ea --- /dev/null +++ b/couchbase_analytics/common/_core/query_handle.py @@ -0,0 +1,110 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Awaitable, List, Mapping, Optional, TypedDict, Union + +from couchbase_analytics.common._core.result import QueryResult + + +class QueryHandle(ABC): + @property + @abstractmethod + def handle(self) -> str: + """ """ + raise NotImplementedError + + @property + @abstractmethod + def request_id(self) -> str: + """ """ + raise NotImplementedError + + @abstractmethod + def cancel(self) -> Union[Awaitable[None], None]: + """ + Cancel the query associated with the QueryHandle. + """ + raise NotImplementedError + + @abstractmethod + def fetch_result_handle(self) -> Union[Awaitable[Optional[QueryResultHandle]], Optional[QueryResultHandle]]: + raise NotImplementedError + + +class QueryResultHandle(ABC): + """Abstract base class for query result handle.""" + + @property + @abstractmethod + def request_id(self) -> str: + """ """ + raise NotImplementedError + + @abstractmethod + def fetch_results(self) -> Union[Awaitable[QueryResult], QueryResult]: + """ + Get all the results. + """ + raise NotImplementedError + + @abstractmethod + def discard_results(self) -> Union[Awaitable[None], None]: + """ + Discard the query results associated with the QueryResultHandle. 
+ """ + raise NotImplementedError + + +class ResultPartition(TypedDict): + handle: str + result_count: Optional[int] + + +@dataclass +class QueryHandleStatusResponse: + """**INTERNAL**""" + + request_id: str + status: str + handle: str + result_count: Optional[int] = None + partitions: Optional[List[ResultPartition]] = None + result_set_ordered: Optional[bool] = None + metrics: Optional[Mapping[str, Union[str, int]]] = None + created_at: Optional[str] = None + + @classmethod + def from_server(cls, request_id: str, raw_json: Any) -> QueryHandleStatusResponse: + raw_partitions = raw_json.get('partitions', []) + partitions: list[ResultPartition] = [] + for partition in raw_partitions: + partitions.append( + {'handle': partition.get('handle', None), 'result_count': partition.get('resultCount', None)} + ) + return cls( + request_id, + raw_json.get('status', None), + raw_json.get('handle', None), + result_count=raw_json.get('resultCount', None), + partitions=partitions, + result_set_ordered=raw_json.get('resultSetOrdered', None), + metrics=raw_json.get('metrics', None), + created_at=raw_json.get('createdAt', None), + ) diff --git a/couchbase_analytics/common/_core/result.py b/couchbase_analytics/common/_core/result.py index 0146122..0dbfef5 100644 --- a/couchbase_analytics/common/_core/result.py +++ b/couchbase_analytics/common/_core/result.py @@ -18,7 +18,7 @@ import sys from abc import ABC, abstractmethod -from typing import Any, Coroutine, List, Optional, Union +from typing import Any, Awaitable, List, Optional, Union if sys.version_info < (3, 9): from typing import AsyncIterator as PyAsyncIterator @@ -34,7 +34,7 @@ class QueryResult(ABC): """Abstract base class for query results.""" @abstractmethod - def cancel(self) -> Union[Coroutine[Any, Any, None], None]: + def cancel(self) -> Union[Awaitable[None], None]: """ Cancel streaming the query results. 
@@ -43,7 +43,7 @@ def cancel(self) -> Union[Coroutine[Any, Any, None], None]: raise NotImplementedError @abstractmethod - def get_all_rows(self) -> Union[Coroutine[Any, Any, List[Any]], List[Any]]: + def get_all_rows(self) -> Union[Awaitable[List[Any]], List[Any]]: """Convenience method to load all query results into memory.""" raise NotImplementedError diff --git a/couchbase_analytics/common/logging.py b/couchbase_analytics/common/logging.py index d599174..09253fd 100644 --- a/couchbase_analytics/common/logging.py +++ b/couchbase_analytics/common/logging.py @@ -16,6 +16,7 @@ import logging from enum import Enum +from typing import Optional LOG_FORMAT_ARR = [ '[%(asctime)s.%(msecs)03d]', @@ -36,10 +37,26 @@ class LogLevel(Enum): CRITICAL = logging.CRITICAL +def _has_open_handlers(logger: logging.Logger) -> bool: + current: Optional[logging.Logger] = logger + while current is not None: + for handler in current.handlers: + if isinstance(handler, logging.StreamHandler): + if hasattr(handler.stream, 'closed') and handler.stream.closed: + return False + if not current.propagate: + break + current = current.parent + return True + + def log_message(logger: logging.Logger, message: str, log_level: LogLevel) -> None: if not logger or not logger.hasHandlers(): return + if not _has_open_handlers(logger): + return + if log_level == LogLevel.DEBUG: logger.debug(message) elif log_level == LogLevel.INFO: diff --git a/couchbase_analytics/common/options.py b/couchbase_analytics/common/options.py index 5a680cc..d282b9b 100644 --- a/couchbase_analytics/common/options.py +++ b/couchbase_analytics/common/options.py @@ -26,13 +26,17 @@ from couchbase_analytics.common.options_base import ( ClusterOptionsBase, + FetchResultsOptionsBase, QueryOptionsBase, SecurityOptionsBase, + StartQueryOptionsBase, TimeoutOptionsBase, ) from couchbase_analytics.common.options_base import ClusterOptionsKwargs as ClusterOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options_base import 
FetchResultsOptionsKwargs as FetchResultsOptionsKwargs # noqa: F401 from couchbase_analytics.common.options_base import QueryOptionsKwargs as QueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options_base import SecurityOptionsKwargs as SecurityOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options_base import StartQueryOptionsKwargs as StartQueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options_base import TimeoutOptionsKwargs as TimeoutOptionsKwargs # noqa: F401 """ @@ -57,6 +61,14 @@ class ClusterOptions(ClusterOptionsBase): """ # noqa: E501 +class FetchResultsOptions(FetchResultsOptionsBase): + """Available options for Analytics asynchronous server query fetch results operation. + + Args: + deserializer (Optional[Deserializer]): Specifies a :class:`~couchbase_analytics.deserializer.Deserializer` to apply to results. Defaults to `None` (:class:`~couchbase_analytics.deserializer.DefaultJsonDeserializer`). + """ # noqa: E501 + + class SecurityOptions(SecurityOptionsBase): """Available security options to set when creating a cluster. @@ -149,7 +161,29 @@ class QueryOptions(QueryOptionsBase): Args: client_context_id (Optional[str]): Set to configure a unique identifier for this query request. Defaults to `None` (autogenerated by client). deserializer (Optional[Deserializer]): Specifies a :class:`~couchbase_analytics.deserializer.Deserializer` to apply to results. Defaults to `None` (:class:`~couchbase_analytics.deserializer.DefaultJsonDeserializer`). - lazy_execute (Optional[bool]): **VOLATILE** If enabled, the query will not execute until the application begins to iterate over results. Defaulst to `None` (disabled). + lazy_execute (Optional[bool]): **VOLATILE** If enabled, the query will not execute until the application begins to iterate over results. Defaults to `None` (disabled). + max_retries (Optional[int]): **VOLATILE** Set to configure the maximum number of retries for a request. 
+ named_parameters (Optional[Dict[str, :py:type:`~couchbase_analytics.JSONType`]]): Values to use for named placeholders in query. + positional_parameters (Optional[List[:py:type:`~couchbase_analytics.JSONType`]]): Values to use for positional placeholders in query. + query_context (Optional[str]): Specifies the context within which this query should be executed. + raw (Optional[Dict[str, Any]]): Specifies any additional parameters which should be passed to the Analytics engine when executing the query. + readonly (Optional[bool]): Specifies that this query should be executed in read-only mode, disabling the ability for the query to make any changes to the data. + scan_consistency (Optional[QueryScanConsistency]): Specifies the consistency requirements when executing the query. + timeout (Optional[timedelta]): Set to configure allowed time for operation to complete. Defaults to `None` (75s). + stream_config (Optional[JsonStreamConfig]): **VOLATILE** Configuration for JSON stream processing. Defaults to `None` (default configuration). See :class:`~couchbase_analytics.common.json_parsing.JsonStreamConfig` for details. + """ # noqa: E501 + + +class StartQueryOptions(StartQueryOptionsBase): + """Available options for Analytics asynchronous server query operation. + + Timeout will default to cluster setting if not set for the operation. + + .. note:: + Options marked **VOLATILE** are subject to change at any time. + + Args: + client_context_id (Optional[str]): Set to configure a unique identifier for this query request. Defaults to `None` (autogenerated by client). max_retries (Optional[int]): **VOLATILE** Set to configure the maximum number of retries for a request. named_parameters (Optional[Dict[str, :py:type:`~couchbase_analytics.JSONType`]]): Values to use for named placeholders in query. positional_parameters (Optional[List[:py:type:`~couchbase_analytics.JSONType`]]): Values to use for positional placeholders in query.
@@ -164,7 +198,9 @@ class QueryOptions(QueryOptionsBase): OptionsClass: TypeAlias = Union[ ClusterOptions, + FetchResultsOptions, SecurityOptions, TimeoutOptions, QueryOptions, + StartQueryOptions, ] diff --git a/couchbase_analytics/common/options_base.py b/couchbase_analytics/common/options_base.py index 1fd4811..a02a658 100644 --- a/couchbase_analytics/common/options_base.py +++ b/couchbase_analytics/common/options_base.py @@ -183,3 +183,66 @@ class QueryOptionsBase(Dict[str, object]): def __init__(self, **kwargs: Unpack[QueryOptionsKwargs]) -> None: filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} super().__init__(**filtered_kwargs) + + +class StartQueryOptionsKwargs(TypedDict, total=False): + client_context_id: Optional[str] + max_retries: Optional[int] + named_parameters: Optional[Dict[str, JSONType]] + positional_parameters: Optional[Iterable[JSONType]] + query_context: Optional[str] + raw: Optional[Dict[str, Any]] + readonly: Optional[bool] + scan_consistency: Optional[Union[QueryScanConsistency, str]] + stream_config: Optional[JsonStreamConfig] + timeout: Optional[timedelta] + + +StartQueryOptionsValidKeys: TypeAlias = Literal[ + 'client_context_id', + 'max_retries', + 'named_parameters', + 'positional_parameters', + 'query_context', + 'raw', + 'readonly', + 'scan_consistency', + 'stream_config', + 'timeout', +] + + +class StartQueryOptionsBase(Dict[str, object]): + VALID_OPTION_KEYS: List[StartQueryOptionsValidKeys] = [ + 'client_context_id', + 'max_retries', + 'named_parameters', + 'positional_parameters', + 'query_context', + 'raw', + 'readonly', + 'scan_consistency', + 'stream_config', + 'timeout', + ] + + def __init__(self, **kwargs: Unpack[StartQueryOptionsKwargs]) -> None: + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + super().__init__(**filtered_kwargs) + + +class FetchResultsOptionsKwargs(TypedDict, total=False): + deserializer: Optional[Deserializer] + + +FetchResultsOptionsValidKeys: TypeAlias = 
Literal['deserializer',] + + +class FetchResultsOptionsBase(Dict[str, object]): + VALID_OPTION_KEYS: List[FetchResultsOptionsValidKeys] = [ + 'deserializer', + ] + + def __init__(self, **kwargs: Unpack[FetchResultsOptionsKwargs]) -> None: + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + super().__init__(**filtered_kwargs) diff --git a/couchbase_analytics/common/query_handle.py b/couchbase_analytics/common/query_handle.py new file mode 100644 index 0000000..0f09404 --- /dev/null +++ b/couchbase_analytics/common/query_handle.py @@ -0,0 +1,121 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from abc import abstractmethod +from typing import Awaitable, Optional + +from couchbase_analytics.common._core.query_handle import QueryHandle, QueryResultHandle +from couchbase_analytics.common.result import AsyncQueryResult, BlockingQueryResult + + +class AsyncQueryHandle(QueryHandle): + @property + @abstractmethod + def handle(self) -> str: + """ """ + raise NotImplementedError + + @property + @abstractmethod + def request_id(self) -> str: + """ """ + raise NotImplementedError + + @abstractmethod + def cancel(self) -> Awaitable[None]: + """ + Cancel the query associated with the QueryHandle. 
+ """ + raise NotImplementedError + + @abstractmethod + def fetch_result_handle(self) -> Awaitable[Optional[AsyncQueryResultHandle]]: + raise NotImplementedError + + +class BlockingQueryHandle(QueryHandle): + @property + @abstractmethod + def handle(self) -> str: + """ """ + raise NotImplementedError + + @property + @abstractmethod + def request_id(self) -> str: + """ """ + raise NotImplementedError + + @abstractmethod + def cancel(self) -> None: + """ + Cancel the query associated with the QueryHandle. + """ + raise NotImplementedError + + @abstractmethod + def fetch_result_handle(self) -> Optional[BlockingQueryResultHandle]: + raise NotImplementedError + + +class AsyncQueryResultHandle(QueryResultHandle): + """Abstract base class for async query result handle.""" + + @property + @abstractmethod + def request_id(self) -> str: + """ """ + raise NotImplementedError + + @abstractmethod + def fetch_results(self) -> Awaitable[AsyncQueryResult]: + """ + Get all the results. + """ + raise NotImplementedError + + @abstractmethod + def discard_results(self) -> Awaitable[None]: + """ + Discard the query results associated with the AsyncQueryResultHandle. + """ + raise NotImplementedError + + +class BlockingQueryResultHandle(QueryResultHandle): + """Abstract base class for query result handle.""" + + @property + @abstractmethod + def request_id(self) -> str: + """ """ + raise NotImplementedError + + @abstractmethod + def fetch_results(self) -> BlockingQueryResult: + """ + Get all the results. + """ + raise NotImplementedError + + @abstractmethod + def discard_results(self) -> None: + """ + Discard the query results associated with the BlockingQueryResultHandle. 
+ """ + raise NotImplementedError diff --git a/couchbase_analytics/options.py b/couchbase_analytics/options.py index bc8f846..47c432b 100644 --- a/couchbase_analytics/options.py +++ b/couchbase_analytics/options.py @@ -16,9 +16,13 @@ from couchbase_analytics.common.options import ClusterOptions as ClusterOptions # noqa: F401 from couchbase_analytics.common.options import ClusterOptionsKwargs as ClusterOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options import FetchResultsOptions as FetchResultsOptions # noqa: F401 +from couchbase_analytics.common.options import FetchResultsOptionsKwargs as FetchResultsOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import QueryOptions as QueryOptions # noqa: F401 from couchbase_analytics.common.options import QueryOptionsKwargs as QueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import SecurityOptions as SecurityOptions # noqa: F401 from couchbase_analytics.common.options import SecurityOptionsKwargs as SecurityOptionsKwargs # noqa: F401 +from couchbase_analytics.common.options import StartQueryOptions as StartQueryOptions # noqa: F401 +from couchbase_analytics.common.options import StartQueryOptionsKwargs as StartQueryOptionsKwargs # noqa: F401 from couchbase_analytics.common.options import TimeoutOptions as TimeoutOptions # noqa: F401 from couchbase_analytics.common.options import TimeoutOptionsKwargs as TimeoutOptionsKwargs # noqa: F401 diff --git a/couchbase_analytics/protocol/_core/client_adapter.py b/couchbase_analytics/protocol/_core/client_adapter.py index e620ac8..d4bf976 100644 --- a/couchbase_analytics/protocol/_core/client_adapter.py +++ b/couchbase_analytics/protocol/_core/client_adapter.py @@ -17,7 +17,7 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Optional, cast +from typing import Optional, Union, cast from uuid import uuid4 from httpx import URL, BasicAuth, Client, Response @@ -25,12 +25,10 @@ from 
couchbase_analytics.common.credential import Credential from couchbase_analytics.common.deserializer import Deserializer from couchbase_analytics.common.logging import LogLevel, log_message +from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import _ConnectionDetails from couchbase_analytics.protocol.options import OptionsBuilder -if TYPE_CHECKING: - from couchbase_analytics.protocol._core.request import QueryRequest - class _ClientAdapter: """ @@ -162,7 +160,9 @@ def create_client(self) -> None: def log_message(self, message: str, log_level: LogLevel) -> None: log_message(logger, f'{self.log_prefix} {message}', log_level) - def send_request(self, request: QueryRequest) -> Response: + def send_request( + self, request: Union[CancelRequest, HttpRequest, QueryRequest, StartQueryRequest], stream: Optional[bool] = True + ) -> Response: """ **INTERNAL** """ @@ -170,8 +170,18 @@ def send_request(self, request: QueryRequest) -> Response: raise RuntimeError('Client not created yet') url = URL(scheme=request.url.scheme, host=request.url.ip, port=request.url.port, path=request.url.path) - req = self._client.build_request(request.method, url, json=request.body, extensions=request.extensions) - return self._client.send(req, stream=True) + if isinstance(request, (QueryRequest, StartQueryRequest)): + req = self._client.build_request(request.method, url, json=request.body, extensions=request.extensions) + else: + headers = request.headers if request.headers is not None else None + data = request.data if isinstance(request, CancelRequest) else None + req = self._client.build_request( + request.method, url, data=data, headers=headers, extensions=request.extensions + ) + + if stream is None: + stream = True + return self._client.send(req, stream=stream) def reset_client(self) -> None: """ diff --git a/couchbase_analytics/protocol/_core/json_stream.py 
b/couchbase_analytics/protocol/_core/json_stream.py index f1bc6d5..0fc975a 100644 --- a/couchbase_analytics/protocol/_core/json_stream.py +++ b/couchbase_analytics/protocol/_core/json_stream.py @@ -30,7 +30,7 @@ from couchbase_analytics.protocol._core.json_token_parser import JsonTokenParser if TYPE_CHECKING: - from couchbase_analytics.protocol._core.request_context import RequestContext + from couchbase_analytics.protocol._core.request_context import StreamingRequestContext class JsonStream: @@ -80,7 +80,7 @@ def token_stream_exhausted(self) -> bool: """ return self._token_stream_exhausted - def _continue_processing(self, request_context: Optional[RequestContext] = None) -> bool: + def _continue_processing(self, request_context: Optional[StreamingRequestContext] = None) -> bool: """ **INTERNAL** """ @@ -125,7 +125,7 @@ def _log_message(self, message: str, level: LogLevel) -> None: if self._log_handler is not None: self._log_handler(message, level) - def _process_token_stream(self, request_context: Optional[RequestContext] = None) -> None: + def _process_token_stream(self, request_context: Optional[StreamingRequestContext] = None) -> None: """ **INTERNAL** """ @@ -207,7 +207,7 @@ def get_result(self, timeout: float) -> Optional[ParsedResult]: def start_parsing( self, - request_context: Optional[RequestContext] = None, + request_context: Optional[StreamingRequestContext] = None, notify_on_results_or_error: Optional[Future[ParsedResultType]] = None, ) -> None: if self._json_stream_parser is not None: @@ -218,6 +218,6 @@ def start_parsing( def continue_parsing( self, - request_context: Optional[RequestContext] = None, + request_context: Optional[StreamingRequestContext] = None, ) -> None: self._process_token_stream(request_context=request_context) diff --git a/couchbase_analytics/protocol/_core/request.py b/couchbase_analytics/protocol/_core/request.py index c2cc5f1..9f8476e 100644 --- a/couchbase_analytics/protocol/_core/request.py +++ 
b/couchbase_analytics/protocol/_core/request.py @@ -18,13 +18,26 @@ from copy import deepcopy from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Dict, Optional, TypedDict, Union, cast +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Dict, + List, + Mapping, + Optional, + TypedDict, + Union, + cast, + overload, +) from uuid import uuid4 from couchbase_analytics.common.deserializer import Deserializer -from couchbase_analytics.common.options import QueryOptions +from couchbase_analytics.common.options import FetchResultsOptions, QueryOptions, StartQueryOptions from couchbase_analytics.common.request import RequestURL -from couchbase_analytics.protocol.options import QueryOptionsTransformedKwargs +from couchbase_analytics.protocol.options import QueryOptionsTransformedKwargs, StartQueryOptionsTransformedKwargs from couchbase_analytics.query import QueryScanConsistency if TYPE_CHECKING: @@ -46,20 +59,17 @@ class RequestExtensions(TypedDict, total=False): @dataclass -class QueryRequest: +class HttpRequest: url: RequestURL - deserializer: Deserializer - body: Dict[str, Union[str, object]] extensions: RequestExtensions + path: str + method: str + headers: Mapping[str, str] max_retries: int - method: str = 'POST' - - options: Optional[QueryOptionsTransformedKwargs] = None - enable_cancel: Optional[bool] = None def add_trace_to_extensions( self, handler: Callable[[str, str], Union[None, Coroutine[Any, Any, None]]] - ) -> QueryRequest: + ) -> HttpRequest: """ **INTERNAL** """ @@ -68,14 +78,6 @@ def add_trace_to_extensions( self.extensions['trace'] = handler return self - def get_request_statement(self) -> Optional[str]: - """ - **INTERNAL** - """ - if 'statement' in self.body: - return cast(str, self.body['statement']) - return None - def get_request_timeouts(self) -> Optional[RequestTimeoutExtensions]: """ **INTERNAL** @@ -84,7 +86,7 @@ def get_request_timeouts(self) -> Optional[RequestTimeoutExtensions]: 
return {} return self.extensions['timeout'] - def update_url(self, ip: str, path: str) -> QueryRequest: + def update_url(self, ip: str, path: str) -> HttpRequest: """ **INTERNAL** """ @@ -93,6 +95,53 @@ def update_url(self, ip: str, path: str) -> QueryRequest: return self +class CancelRequestData(TypedDict): + request_id: str + + +@dataclass +class CancelRequest(HttpRequest): + data: CancelRequestData + + +@dataclass +class FetchResultsRequest(HttpRequest): + deserializer: Deserializer + should_stream: bool = True + + +@dataclass +class QueryRequest(HttpRequest): + deserializer: Deserializer + body: Dict[str, Union[str, object]] + options: Optional[QueryOptionsTransformedKwargs] = None + enable_cancel: Optional[bool] = None + should_stream: bool = True + + def get_request_statement(self) -> Optional[str]: + """ + **INTERNAL** + """ + if 'statement' in self.body: + return cast(str, self.body['statement']) + return None + + +@dataclass +class StartQueryRequest(HttpRequest): + body: Dict[str, Union[str, object]] + options: Optional[StartQueryOptionsTransformedKwargs] = None + should_stream: bool = False + + def get_request_statement(self) -> Optional[str]: + """ + **INTERNAL** + """ + if 'statement' in self.body: + return cast(str, self.body['statement']) + return None + + class _RequestBuilder: def __init__( self, @@ -113,13 +162,51 @@ def __init__( if self._conn_details.is_secure() and self._conn_details.sni_hostname is not None: self._extensions['sni_hostname'] = self._conn_details.sni_hostname - def build_base_query_request( # noqa: C901 + def build_request_from_handle(self, handle: str, method: Optional[str] = None) -> HttpRequest: + method = method or 'GET' + extensions = deepcopy(self._extensions) + max_retries = self._conn_details.get_max_retries() + return HttpRequest( + self._conn_details.url, extensions, handle, method=method, headers={}, max_retries=max_retries + ) + + def build_cancel_request(self, request_id: str) -> CancelRequest: + extensions = 
deepcopy(self._extensions) + max_retries = self._conn_details.get_max_retries() + return CancelRequest( + self._conn_details.url, + extensions, + '/api/v1/active_requests', + 'DELETE', + {'Content-Type': 'application/x-www-form-urlencoded'}, + max_retries, + {'request_id': request_id}, + ) + + def build_discard_results_request(self, handle: str) -> HttpRequest: + return self.build_request_from_handle(handle, method='DELETE') + + def build_fetch_results_request(self, handle: str, *args: object, **kwargs: object) -> FetchResultsRequest: + q_opts = self._opts_builder.build_options(FetchResultsOptions, kwargs, args) + base_request = self.build_request_from_handle(handle) + deserializer = q_opts.pop('deserializer', None) or self._conn_details.default_deserializer + max_retries = self._conn_details.get_max_retries() + return FetchResultsRequest( + base_request.url, + base_request.extensions, + base_request.path, + base_request.method, + {}, + max_retries, + deserializer, + ) + + def build_query_request( self, statement: str, *args: object, - is_async: Optional[bool] = False, **kwargs: object, - ) -> QueryRequest: # noqa: C901 + ) -> QueryRequest: enable_cancel: Optional[bool] = None cancel_kwarg_token = kwargs.pop('enable_cancel', None) if isinstance(cancel_kwarg_token, bool): @@ -138,21 +225,102 @@ def build_base_query_request( # noqa: C901 else: parsed_args_list.append(arg) + extensions, body, q_opts = self._get_query_request_details( + QueryOptions, opts, statement, parsed_args_list=parsed_args_list, **kwargs + ) + + # handle deserializer and max_retries + deserializer = q_opts.pop('deserializer', None) or self._conn_details.default_deserializer + max_retries = q_opts.pop('max_retries', None) or self._conn_details.get_max_retries() + + return QueryRequest( + self._conn_details.url, + extensions, + '', + 'POST', + {}, + max_retries, + deserializer, + body, + options=q_opts, + enable_cancel=enable_cancel, + ) + + def build_start_query_request( # noqa: C901 + self, + 
statement: str, + *args: object, + **kwargs: object, + ) -> StartQueryRequest: # noqa: C901 + # default if no options provided + opts = StartQueryOptions() + args_list = list(args) + parsed_args_list = [] + for arg in args_list: + if isinstance(arg, StartQueryOptions): + # we have options passed in + opts = arg + else: + parsed_args_list.append(arg) + + extensions, body, q_opts = self._get_query_request_details( + StartQueryOptions, opts, statement, parsed_args_list=parsed_args_list, **kwargs + ) + + body['mode'] = 'async' + max_retries = q_opts.pop('max_retries', None) or self._conn_details.get_max_retries() + + return StartQueryRequest( + self._conn_details.url, + extensions, + '', + 'POST', + {}, + max_retries, + body, + options=q_opts, + ) + + @overload + def _get_query_request_details( + self, + option_type: type[QueryOptions], + query_opts: QueryOptions, + statement: str, + parsed_args_list: Optional[List[object]] = None, + **kwargs: object, + ) -> tuple[RequestExtensions, Dict[str, Union[str, object]], QueryOptionsTransformedKwargs]: ... + + @overload + def _get_query_request_details( + self, + option_type: type[StartQueryOptions], + query_opts: StartQueryOptions, + statement: str, + parsed_args_list: Optional[List[object]] = None, + **kwargs: object, + ) -> tuple[RequestExtensions, Dict[str, Union[str, object]], StartQueryOptionsTransformedKwargs]: ... 
+ + def _get_query_request_details( # noqa: C901 + self, + option_type: Union[type[QueryOptions], type[StartQueryOptions]], + query_opts: Union[QueryOptions, StartQueryOptions], + statement: str, + parsed_args_list: Optional[List[object]] = None, + **kwargs: object, + ) -> Any: # noqa: C901 # need to pop out named params prior to sending options to the builder - named_param_keys = list(filter(lambda k: k not in QueryOptions.VALID_OPTION_KEYS, kwargs.keys())) + named_param_keys = list(filter(lambda k: k not in option_type.VALID_OPTION_KEYS, kwargs.keys())) named_params = {} for key in named_param_keys: named_params[key] = kwargs.pop(key) - q_opts = self._opts_builder.build_options(QueryOptions, QueryOptionsTransformedKwargs, kwargs, opts) + q_opts = self._opts_builder.build_options(option_type, kwargs, query_opts) # positional params and named params passed in outside of QueryOptions serve as overrides if parsed_args_list and len(parsed_args_list) > 0: q_opts['positional_parameters'] = parsed_args_list if named_params and len(named_params) > 0: q_opts['named_parameters'] = named_params - # handle deserializer and max_retries - deserializer = q_opts.pop('deserializer', None) or self._conn_details.default_deserializer - max_retries = q_opts.pop('max_retries', None) or self._conn_details.get_max_retries() body: Dict[str, Union[str, object]] = { 'statement': statement, @@ -191,12 +359,4 @@ def build_base_query_request( # noqa: C901 else: body['scan_consistency'] = opt_val - return QueryRequest( - self._conn_details.url, - deserializer, - body, - extensions=extensions, - max_retries=max_retries, - options=q_opts, - enable_cancel=enable_cancel, - ) + return extensions, body, q_opts diff --git a/couchbase_analytics/protocol/_core/request_context.py b/couchbase_analytics/protocol/_core/request_context.py index feab717..506af32 100644 --- a/couchbase_analytics/protocol/_core/request_context.py +++ b/couchbase_analytics/protocol/_core/request_context.py @@ -21,7 +21,7 @@ 
import time from concurrent.futures import CancelledError, Future, ThreadPoolExecutor from threading import Event -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union, cast from uuid import uuid4 from httpx import Response as HttpCoreResponse @@ -35,12 +35,12 @@ from couchbase_analytics.common.result import BlockingQueryResult from couchbase_analytics.protocol._core.json_stream import JsonStream from couchbase_analytics.protocol._core.net_utils import get_request_ip +from couchbase_analytics.protocol._core.request import FetchResultsRequest, HttpRequest, QueryRequest, StartQueryRequest from couchbase_analytics.protocol.connection import DEFAULT_TIMEOUTS from couchbase_analytics.protocol.errors import ErrorMapper, WrappedError if TYPE_CHECKING: from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter - from couchbase_analytics.protocol._core.request import QueryRequest class BackgroundRequest: @@ -89,84 +89,209 @@ class RequestContext: def __init__( self, client_adapter: _ClientAdapter, - request: QueryRequest, - tp_executor: ThreadPoolExecutor, - stream_config: Optional[JsonStreamConfig] = None, + request: Union[FetchResultsRequest, HttpRequest, QueryRequest, StartQueryRequest], + supports_cancellation: Optional[bool] = None, ) -> None: self._id = str(uuid4()) self._client_adapter = client_adapter self._request = request self._backoff_calc = DefaultBackoffCalculator() - self._error_ctx = ErrorContext(num_attempts=0, method=request.method, statement=request.get_request_statement()) + self._error_context = ErrorContext(num_attempts=0, method=request.method) + self._supports_cancellation = False if supports_cancellation is None else supports_cancellation self._request_state = RequestState.NotStarted - self._stream_config = stream_config or JsonStreamConfig() - self._json_stream: JsonStream - self._cancel_event = Event() - self._tp_executor = 
tp_executor - self._stage_completed_ft: Optional[Future[Any]] = None - self._stage_notification_ft: Optional[Future[ParsedResultType]] = None + self._cancel_event: Optional[Event] = None self._request_deadline = math.inf self._background_request: Optional[BackgroundRequest] = None self._shutdown = False - - @property - def cancel_enabled(self) -> Optional[bool]: - return self._request.enable_cancel + if self._supports_cancellation: + self._cancel_event = Event() @property def cancelled(self) -> bool: + if not self._supports_cancellation: + return False self._check_cancelled_or_timed_out() return self._request_state in [RequestState.Cancelled, RequestState.SyncCancelledPriorToTimeout] @property def error_context(self) -> ErrorContext: - return self._error_ctx - - @property - def has_stage_completed(self) -> bool: - return self._stage_completed_ft is not None and self._stage_completed_ft.done() + return self._error_context @property def is_shutdown(self) -> bool: return self._shutdown - @property - def okay_to_iterate(self) -> bool: - # NOTE: Called prior to upstream logic attempting to iterate over results from HTTP client - self._check_cancelled_or_timed_out() - return RequestState.okay_to_iterate(self._request_state) - - @property - def okay_to_stream(self) -> bool: - # NOTE: Called prior to upstream logic attempting to send request to HTTP client - self._check_cancelled_or_timed_out() - return RequestState.okay_to_stream(self._request_state) - @property def request_state(self) -> RequestState: return self._request_state @property def retry_limit_exceeded(self) -> bool: - return self.error_context.num_attempts > self._request.max_retries + return self._error_context.num_attempts > self._request.max_retries @property def timed_out(self) -> bool: self._check_cancelled_or_timed_out() return self._request_state == RequestState.Timeout + def calculate_backoff(self) -> float: + return self._backoff_calc.calculate_backoff(self._error_context.num_attempts) / 1000 + + def 
initialize(self) -> None: + if self._request_state == RequestState.ResetAndNotStarted: + self.log_message( + 'Request is a retry, skipping initialization', + LogLevel.DEBUG, + message_data={'request_deadline': f'{self._request_deadline}'}, + ) + return + self._request_state = RequestState.Started + timeouts = self._request.get_request_timeouts() or {} + current_time = time.monotonic() + self._request_deadline = current_time + (timeouts.get('read', None) or DEFAULT_TIMEOUTS['query_timeout']) + message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} + self.log_message('Request context initialized', LogLevel.DEBUG, message_data=message_data) + + def log_message( + self, + message: str, + log_level: LogLevel, + message_data: Optional[Dict[str, str]] = None, + append_ctx: Optional[bool] = True, + ) -> None: + if append_ctx is True: + message = f'{message}: ctx={self._id}' + if message_data is not None: + message_data_str = ', '.join(f'{k}={v}' for k, v in message_data.items()) + message = f'{message}, {message_data_str}' + self._client_adapter.log_message(message, log_level) + + def okay_to_delay_and_retry(self, delay: float) -> bool: + # calling self.timed_out will call _check_cancelled_or_timed_out, so we don't need to call it again + if self.timed_out: + return False + elif self._supports_cancellation and self._request_state == RequestState.Cancelled: + return False + + current_time = time.monotonic() + delay_time = current_time + delay + will_time_out = self._request_deadline < delay_time + if will_time_out: + self._request_state = RequestState.Timeout + message_data = { + 'current_time': f'{current_time}', + 'delay_time': f'{delay_time}', + 'request_deadline': f'{self._request_deadline}', + } + self.log_message('Request will timeout after delay', LogLevel.DEBUG, message_data=message_data) + return False + elif self.retry_limit_exceeded: + self._request_state = RequestState.Error + message_data = { + 'num_attempts': 
f'{self.error_context.num_attempts}', + 'max_retries': f'{self._request.max_retries}', + } + self.log_message('Request has exceeded max retries', LogLevel.DEBUG, message_data=message_data) + return False + elif self._supports_cancellation: + # _reset_stream() _should_ exist, but surround w/ try/except just in case + try: + self._reset_stream() # type: ignore[attr-defined] + except AttributeError: + pass # nosec + + return True + + def process_response( + self, + core_response: HttpCoreResponse, + close_handler: Callable[[], None], + handle_context_shutdown: Optional[bool] = False, + ) -> Any: + # we have all the data, close the core response/stream + close_handler() + try: + json_response = core_response.json() + except json.JSONDecodeError: + self._process_error(core_response.text, handle_context_shutdown=handle_context_shutdown) + else: + if 'errors' in json_response: + self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown) + return json_response + + def send_request(self, enable_trace_handling: Optional[bool] = False) -> HttpCoreResponse: + self._error_context.update_num_attempts() + ip = get_request_ip(self._request.url.host, self._request.url.port, self.log_message) + + if self._request.path and not self._request.path.isspace(): + req_path = f'{self._request.path}' + else: + req_path = self._client_adapter.analytics_path + + if enable_trace_handling is True and hasattr(self, '_trace_handler'): + self._request.update_url(ip, req_path).add_trace_to_extensions(self._trace_handler) + else: + self._request.update_url(ip, req_path) + + self._error_context.update_request_context(self._request, path=req_path) + message_data = { + 'url': f'{self._request.url.get_formatted_url()}', + 'request_deadline': f'{self._request_deadline}', + } + + if isinstance(self._request, (QueryRequest, StartQueryRequest)): + message_data['body'] = f'{self._request.body}' + + stream = hasattr(self._request, 'should_stream') and 
self._request.should_stream is True + message_data['streaming'] = str(stream) + self.log_message('HTTP request', LogLevel.DEBUG, message_data=message_data) + response = self._client_adapter.send_request(self._request, stream=stream) + self._error_context.update_response_context(response) + message_data = { + 'status_code': f'{response.status_code}', + 'last_dispatched_to': f'{self._error_context.last_dispatched_to}', + 'last_dispatched_from': f'{self._error_context.last_dispatched_from}', + 'request_deadline': f'{self._request_deadline}', + } + self.log_message('HTTP response', LogLevel.DEBUG, message_data=message_data) + return response + + def shutdown(self, exc_val: Optional[BaseException] = None) -> None: + if self.is_shutdown: + self.log_message('Request context already shutdown', LogLevel.WARNING) + return + if self._supports_cancellation and isinstance(exc_val, CancelledError): + self._request_state = RequestState.Cancelled + elif exc_val is not None: + # calling self.timed_out will call _check_cancelled_or_timed_out, so we don't need to call it again + is_timed_out = self.timed_out + is_cancelled = self._supports_cancellation and self._request_state in ( + RequestState.Cancelled, + RequestState.SyncCancelledPriorToTimeout, + ) + if not is_timed_out and not is_cancelled: + self._request_state = RequestState.Error + + if RequestState.is_okay(self._request_state): + self._request_state = RequestState.Completed + self._shutdown = True + self.log_message('Request context shutdown complete', LogLevel.INFO) + def _check_cancelled_or_timed_out(self) -> None: - if self._request_state in [RequestState.Timeout, RequestState.Cancelled, RequestState.Error]: + if self._request_state in (RequestState.Timeout, RequestState.Error): return - if self._cancel_event.is_set() or ( - self._background_request is not None and self._background_request.user_cancelled - ): + if self._supports_cancellation and self._request_state == RequestState.Cancelled: + return + + if 
self._supports_cancellation and self._cancel_event and self._cancel_event.is_set(): self._request_state = RequestState.Cancelled if self._cancel_event.is_set(): self.log_message('Request has been cancelled', LogLevel.DEBUG) - elif self._background_request is not None and self._background_request.user_cancelled: - self.log_message('Request has been cancelled via user background request', LogLevel.DEBUG) + # TODO: if we can go w/ this simplified logic, we should move this log message + # (but maybe not b/c of the logger??) + # elif self._background_request is not None and self._background_request.user_cancelled: + # self.log_message('Request has been cancelled via user background request', LogLevel.DEBUG) return current_time = time.monotonic() @@ -174,74 +299,68 @@ def _check_cancelled_or_timed_out(self) -> None: if timed_out: message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} self.log_message('Request has timed out', LogLevel.DEBUG, message_data=message_data) - if self._request_state == RequestState.Cancelled: + if self._supports_cancellation and self._request_state == RequestState.Cancelled: self._request_state = RequestState.SyncCancelledPriorToTimeout else: self._request_state = RequestState.Timeout - def _create_stage_notification_future(self) -> None: - # TODO(PYCO-75): custom ThreadPoolExecutor, to get a "plain" future - if self._stage_notification_ft is not None: - raise RuntimeError('Stage notification future already created for this context.') - self._stage_notification_ft = Future[ParsedResultType]() - def _process_error( self, json_data: Union[str, List[Dict[str, Any]]], handle_context_shutdown: Optional[bool] = False ) -> None: self._request_state = RequestState.Error request_error: Union[AnalyticsError, WrappedError] if isinstance(json_data, str): - request_error = ErrorMapper.build_error_from_http_status_code(json_data, self._error_ctx) + request_error = 
ErrorMapper.build_error_from_http_status_code(json_data, self._error_context) elif not isinstance(json_data, list): request_error = AnalyticsError( - message='Cannot parse error response; expected JSON array', context=str(self._error_ctx) + message='Cannot parse error response; expected JSON array', context=str(self._error_context) ) else: - request_error = ErrorMapper.build_error_from_json(json_data, self._error_ctx) + request_error = ErrorMapper.build_error_from_json(json_data, self._error_context) if handle_context_shutdown is True: self.shutdown() raise request_error - def _reset_stream(self) -> None: - if hasattr(self, '_json_stream'): - del self._json_stream - self._request_state = RequestState.ResetAndNotStarted - self._stage_notification_ft = None - self.log_message('Request state has been reset', LogLevel.DEBUG) - def _start_next_stage( +class StreamingRequestContext(RequestContext): + def __init__( self, - fn: Callable[..., Any], - *args: object, - create_notification: Optional[bool] = False, - reset_previous_stage: Optional[bool] = False, + client_adapter: _ClientAdapter, + request: Union[FetchResultsRequest, QueryRequest], + tp_executor: ThreadPoolExecutor, + stream_config: Optional[JsonStreamConfig] = None, ) -> None: - if reset_previous_stage is True: - if self._stage_completed_ft is not None: - self._stage_completed_ft = None - elif self._stage_completed_ft is not None and not self._stage_completed_ft.done(): - raise RuntimeError('Future already running in this context.') - - kwargs: Dict[str, Union[RequestContext, Future[ParsedResultType]]] = {'request_context': self} - if create_notification is True: - self._create_stage_notification_future() - if self._stage_notification_ft is None: - raise RuntimeError('Unable to create stage notification future.') - kwargs['notify_on_results_or_error'] = self._stage_notification_ft + super().__init__(client_adapter, request, supports_cancellation=True) + if isinstance(request, QueryRequest): + 
self._error_context.set_statement(request.get_request_statement()) + self._stream_config = stream_config or JsonStreamConfig() + self._json_stream: JsonStream + self._tp_executor = tp_executor + self._stage_completed_ft: Optional[Future[Any]] = None + self._stage_notification_ft: Optional[Future[ParsedResultType]] = None + self._deserializer = request.deserializer - self._stage_completed_ft = self._tp_executor.submit(fn, *args, **kwargs) + @property + def cancel_enabled(self) -> Optional[bool]: + if not isinstance(self._request, QueryRequest): + return None + return self._request.enable_cancel - def _trace_handler(self, event_name: str, _: str) -> None: - if event_name == 'connection.connect_tcp.complete': - pass + @property + def has_stage_completed(self) -> bool: + return self._stage_completed_ft is not None and self._stage_completed_ft.done() - def _wait_for_stage_completed(self) -> None: - if self._stage_completed_ft is None: - raise RuntimeError('Stage completed future not created for this context.') - self._stage_completed_ft.result() + @property + def okay_to_iterate(self) -> bool: + # NOTE: Called prior to upstream logic attempting to iterate over results from HTTP client + self._check_cancelled_or_timed_out() + return RequestState.okay_to_iterate(self._request_state) - def calculate_backoff(self) -> float: - return self._backoff_calc.calculate_backoff(self._error_ctx.num_attempts) / 1000 + @property + def okay_to_stream(self) -> bool: + # NOTE: Called prior to upstream logic attempting to send request to HTTP client + self._check_cancelled_or_timed_out() + return RequestState.okay_to_stream(self._request_state) def cancel_request(self) -> None: if self._request_state == RequestState.Timeout: @@ -249,7 +368,9 @@ def cancel_request(self) -> None: self._request_state = RequestState.Cancelled def deserialize_result(self, result: bytes) -> Any: - return self._request.deserializer.deserialize(result) + if not self._deserializer: + raise RuntimeError('No 
deserializer found for this request context.') + return self._deserializer.deserialize(result) def finish_processing_stream(self) -> None: if not self.has_stage_completed: @@ -264,35 +385,6 @@ def finish_processing_stream(self) -> None: def get_result_from_stream(self) -> Optional[ParsedResult]: return self._json_stream.get_result(self._stream_config.queue_timeout) - def initialize(self) -> None: - if self._request_state == RequestState.ResetAndNotStarted: - self.log_message( - 'Request is a retry, skipping initialization', - LogLevel.DEBUG, - message_data={'request_deadline': f'{self._request_deadline}'}, - ) - return - self._request_state = RequestState.Started - timeouts = self._request.get_request_timeouts() or {} - current_time = time.monotonic() - self._request_deadline = current_time + (timeouts.get('read', None) or DEFAULT_TIMEOUTS['query_timeout']) - message_data = {'current_time': f'{current_time}', 'request_deadline': f'{self._request_deadline}'} - self.log_message('Request context initialized', LogLevel.DEBUG, message_data=message_data) - - def log_message( - self, - message: str, - log_level: LogLevel, - message_data: Optional[Dict[str, str]] = None, - append_ctx: Optional[bool] = True, - ) -> None: - if append_ctx is True: - message = f'{message}: ctx={self._id}' - if message_data is not None: - message_data_str = ', '.join(f'{k}={v}' for k, v in message_data.items()) - message = f'{message}, {message_data_str}' - self._client_adapter.log_message(message, log_level) - def maybe_continue_to_process_stream(self) -> None: if not self.has_stage_completed: return @@ -306,36 +398,7 @@ def maybe_continue_to_process_stream(self) -> None: # NOTE: _start_next_stage injects the request context into args self._start_next_stage(self._json_stream.continue_parsing, reset_previous_stage=True) - def okay_to_delay_and_retry(self, delay: float) -> bool: - self._check_cancelled_or_timed_out() - if self._request_state in [RequestState.Timeout, RequestState.Cancelled]: - 
return False - - current_time = time.monotonic() - delay_time = current_time + delay - will_time_out = self._request_deadline < delay_time - if will_time_out: - self._request_state = RequestState.Timeout - message_data = { - 'current_time': f'{current_time}', - 'delay_time': f'{delay_time}', - 'request_deadline': f'{self._request_deadline}', - } - self.log_message('Request will timeout after delay', LogLevel.DEBUG, message_data=message_data) - return False - elif self.retry_limit_exceeded: - self._request_state = RequestState.Error - message_data = { - 'num_attempts': f'{self.error_context.num_attempts}', - 'max_retries': f'{self._request.max_retries}', - } - self.log_message('Request has exceeded max retries', LogLevel.DEBUG, message_data=message_data) - return False - else: - self._reset_stream() - return True - - def process_response( + def process_streaming_response( self, close_handler: Callable[[], None], raw_response: Optional[ParsedResult] = None, @@ -346,13 +409,13 @@ def process_response( if raw_response is None: close_handler() raise AnalyticsError( - message='Received unexpected empty result from JsonStream.', context=str(self._error_ctx) + message='Received unexpected empty result from JsonStream.', context=str(self._error_context) ) if raw_response.value is None: close_handler() raise AnalyticsError( - message='Received unexpected empty response value from JsonStream.', context=str(self._error_ctx) + message='Received unexpected empty response value from JsonStream.', context=str(self._error_context) ) # we have all the data, close the core response/stream @@ -366,35 +429,6 @@ def process_response( self._process_error(json_response['errors'], handle_context_shutdown=handle_context_shutdown) return json_response - def send_request(self, enable_trace_handling: Optional[bool] = False) -> HttpCoreResponse: - self._error_ctx.update_num_attempts() - ip = get_request_ip(self._request.url.host, self._request.url.port, self.log_message) - if 
enable_trace_handling is True: - ( - self._request.update_url(ip, self._client_adapter.analytics_path).add_trace_to_extensions( - self._trace_handler - ) - ) - else: - self._request.update_url(ip, self._client_adapter.analytics_path) - self._error_ctx.update_request_context(self._request) - message_data = { - 'url': f'{self._request.url.get_formatted_url()}', - 'body': f'{self._request.body}', - 'request_deadline': f'{self._request_deadline}', - } - self.log_message('HTTP request', LogLevel.DEBUG, message_data=message_data) - response = self._client_adapter.send_request(self._request) - self._error_ctx.update_response_context(response) - message_data = { - 'status_code': f'{response.status_code}', - 'last_dispatched_to': f'{self._error_ctx.last_dispatched_to}', - 'last_dispatched_from': f'{self._error_ctx.last_dispatched_from}', - 'request_deadline': f'{self._request_deadline}', - } - self.log_message('HTTP response', LogLevel.DEBUG, message_data=message_data) - return response - def send_request_in_background( self, fn: Callable[..., BlockingQueryResult], @@ -405,32 +439,12 @@ def send_request_in_background( # TODO(PYCO-75): custom ThreadPoolExecutor, to get a "plain" future user_ft = Future[BlockingQueryResult]() background_work_ft = self._tp_executor.submit(fn, *args) - self._background_request = BackgroundRequest(background_work_ft, user_ft, self._cancel_event) + self._background_request = BackgroundRequest(background_work_ft, user_ft, cast(Event, self._cancel_event)) return user_ft def set_state_to_streaming(self) -> None: self._request_state = RequestState.StreamingResults - def shutdown(self, exc_val: Optional[BaseException] = None) -> None: - if self.is_shutdown: - self.log_message('Request context already shutdown', LogLevel.WARNING) - return - if isinstance(exc_val, CancelledError): - self._request_state = RequestState.Cancelled - elif exc_val is not None: - self._check_cancelled_or_timed_out() - if self._request_state not in [ - RequestState.Timeout, - 
RequestState.Cancelled, - RequestState.SyncCancelledPriorToTimeout, - ]: - self._request_state = RequestState.Error - - if RequestState.is_okay(self._request_state): - self._request_state = RequestState.Completed - self._shutdown = True - self.log_message('Request context shutdown complete', LogLevel.INFO) - def start_stream(self, core_response: HttpCoreResponse) -> None: if hasattr(self, '_json_stream'): self.log_message('JSON stream already exists', LogLevel.WARNING) @@ -447,7 +461,9 @@ def wait_for_stage_notification(self) -> None: raise RuntimeError('Stage notification future not created for this context.') deadline = round(self._request_deadline - time.monotonic(), 6) # round to microseconds if deadline <= 0: - raise TimeoutError(message='Request timed out waiting for stage notification', context=str(self._error_ctx)) + raise TimeoutError( + message='Request timed out waiting for stage notification', context=str(self._error_context) + ) result_type = self._stage_notification_ft.result(timeout=deadline) if result_type == ParsedResultType.ROW: self.log_message('Received row, setting status to streaming', LogLevel.DEBUG) @@ -455,3 +471,64 @@ def wait_for_stage_notification(self) -> None: self._request_state = RequestState.StreamingResults else: self.log_message(f'Received result type {result_type.name}', LogLevel.DEBUG) + + def _create_stage_notification_future(self) -> None: + # TODO(PYCO-75): custom ThreadPoolExecutor, to get a "plain" future + if self._stage_notification_ft is not None: + raise RuntimeError('Stage notification future already created for this context.') + self._stage_notification_ft = Future[ParsedResultType]() + + # def _process_error( + # self, json_data: Union[str, List[Dict[str, Any]]], handle_context_shutdown: Optional[bool] = False + # ) -> None: + # self._request_state = RequestState.Error + # request_error: Union[AnalyticsError, WrappedError] + # if isinstance(json_data, str): + # request_error = 
ErrorMapper.build_error_from_http_status_code(json_data, self._error_context) + # elif not isinstance(json_data, list): + # request_error = AnalyticsError( + # message='Cannot parse error response; expected JSON array', context=str(self._error_context) + # ) + # else: + # request_error = ErrorMapper.build_error_from_json(json_data, self._error_context) + # if handle_context_shutdown is True: + # self.shutdown() + # raise request_error + + def _reset_stream(self) -> None: + if hasattr(self, '_json_stream'): + del self._json_stream + self._request_state = RequestState.ResetAndNotStarted + self._stage_notification_ft = None + self.log_message('Request state has been reset', LogLevel.DEBUG) + + def _start_next_stage( + self, + fn: Callable[..., Any], + *args: object, + create_notification: Optional[bool] = False, + reset_previous_stage: Optional[bool] = False, + ) -> None: + if reset_previous_stage is True: + if self._stage_completed_ft is not None: + self._stage_completed_ft = None + elif self._stage_completed_ft is not None and not self._stage_completed_ft.done(): + raise RuntimeError('Future already running in this context.') + + kwargs: Dict[str, Union[StreamingRequestContext, Future[ParsedResultType]]] = {'request_context': self} + if create_notification is True: + self._create_stage_notification_future() + if self._stage_notification_ft is None: + raise RuntimeError('Unable to create stage notification future.') + kwargs['notify_on_results_or_error'] = self._stage_notification_ft + + self._stage_completed_ft = self._tp_executor.submit(fn, *args, **kwargs) + + def _trace_handler(self, event_name: str, _: str) -> None: + if event_name == 'connection.connect_tcp.complete': + pass + + def _wait_for_stage_completed(self) -> None: + if self._stage_completed_ft is None: + raise RuntimeError('Stage completed future not created for this context.') + self._stage_completed_ft.result() diff --git a/couchbase_analytics/protocol/_core/response.py 
b/couchbase_analytics/protocol/_core/response.py new file mode 100644 index 0000000..905042f --- /dev/null +++ b/couchbase_analytics/protocol/_core/response.py @@ -0,0 +1,90 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +from typing import Any, Optional + +from httpx import Response as HttpCoreResponse + +from couchbase_analytics.common._core.query import build_query_metadata +from couchbase_analytics.common.errors import AnalyticsError, InternalSDKError +from couchbase_analytics.common.logging import LogLevel +from couchbase_analytics.common.query import QueryMetadata +from couchbase_analytics.protocol._core.request_context import RequestContext +from couchbase_analytics.protocol._core.retries import RetryHandler + + +class HttpResponse: + def __init__( + self, + request_context: RequestContext, + skip_process_response: Optional[bool] = None, + request_id: Optional[str] = None, + ) -> None: + self._request_context = request_context + self._metadata: Optional[QueryMetadata] = None + self._core_response: HttpCoreResponse + self._json_response: Optional[Any] = None + self._skip_process_response = skip_process_response + self._request_id = request_id + + @property + def json_response(self) -> Optional[Any]: + return self._json_response + + def close(self) -> None: + """ + **INTERNAL** + """ + if hasattr(self, '_core_response'): + self._core_response.close() + 
self._request_context.log_message('HTTP core response closed', LogLevel.INFO) + del self._core_response + + def get_metadata(self) -> QueryMetadata: + if self._metadata is None: + raise RuntimeError('Query metadata is only available after response has been processed.') + return self._metadata + + def set_metadata(self, json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None) -> None: + try: + self._metadata = QueryMetadata( + build_query_metadata(json_data=json_data, raw_metadata=raw_metadata, request_id=self._request_id) + ) + self._request_context.shutdown() + except (AnalyticsError, ValueError) as err: + self._request_context.shutdown(err) + raise err + except Exception as ex: + internal_err = InternalSDKError(cause=ex, message=str(ex), context=str(self._request_context.error_context)) + self._request_context.shutdown(internal_err) + finally: + self.close() + + @RetryHandler.with_retries + def send_request(self) -> None: + self._request_context.initialize() + self._core_response = self._request_context.send_request() + if self._skip_process_response is True: + return + self._process_response() + + def _process_response(self) -> None: + self._json_response = self._request_context.process_response( + self._core_response, self.close, handle_context_shutdown=True + ) + self.set_metadata(json_data=self._json_response) diff --git a/couchbase_analytics/protocol/_core/retries.py b/couchbase_analytics/protocol/_core/retries.py index c87fa43..4df1266 100644 --- a/couchbase_analytics/protocol/_core/retries.py +++ b/couchbase_analytics/protocol/_core/retries.py @@ -19,7 +19,7 @@ from concurrent.futures import CancelledError from functools import wraps from time import sleep -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, TypeVar, Union from httpx import ConnectError, ConnectTimeout, CookieConflict, HTTPError, InvalidURL, ReadTimeout, StreamError @@ -29,9 +29,13 @@ from 
couchbase_analytics.protocol.errors import WrappedError if TYPE_CHECKING: - from couchbase_analytics.protocol._core.request_context import RequestContext + from couchbase_analytics.protocol._core.request_context import RequestContext, StreamingRequestContext + from couchbase_analytics.protocol._core.response import HttpResponse from couchbase_analytics.protocol.streaming import HttpStreamingResponse +ReqContext = Union['RequestContext', 'StreamingRequestContext'] +T = TypeVar('T', bound=Union['HttpResponse', 'HttpStreamingResponse']) + class RetryHandler: """ @@ -39,7 +43,7 @@ class RetryHandler: """ @staticmethod - def handle_httpx_retry(ex: Union[ConnectError, ConnectTimeout], ctx: RequestContext) -> Optional[Exception]: + def handle_httpx_retry(ex: Union[ConnectError, ConnectTimeout], ctx: ReqContext) -> Optional[Exception]: err_str = str(ex) if 'SSL:' in err_str: message = 'TLS connection error occurred.' @@ -62,7 +66,7 @@ def handle_httpx_retry(ex: Union[ConnectError, ConnectTimeout], ctx: RequestCont return None @staticmethod - def handle_retry(ex: WrappedError, ctx: RequestContext) -> Optional[Union[BaseException, Exception]]: + def handle_retry(ex: WrappedError, ctx: ReqContext) -> Optional[Union[BaseException, Exception]]: if ex.retriable is True: delay = ctx.calculate_backoff() err: Optional[Union[BaseException, Exception]] = None @@ -91,9 +95,11 @@ def handle_retry(ex: WrappedError, ctx: RequestContext) -> Optional[Union[BaseEx return ex.unwrap() @staticmethod - def with_retries(fn: Callable[[HttpStreamingResponse], None]) -> Callable[[HttpStreamingResponse], None]: # noqa: C901 + def with_retries( # noqa: C901 + fn: Callable[[T], None], + ) -> Callable[[T], None]: # noqa: C901 @wraps(fn) - def wrapped_fn(self: HttpStreamingResponse) -> None: # noqa: C901 + def wrapped_fn(self: T) -> None: # noqa: C901 while True: try: fn(self) diff --git a/couchbase_analytics/protocol/cluster.py b/couchbase_analytics/protocol/cluster.py index a0b2053..4e77b62 100644 --- 
a/couchbase_analytics/protocol/cluster.py +++ b/couchbase_analytics/protocol/cluster.py @@ -17,6 +17,7 @@ from __future__ import annotations import atexit +import sys from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING, Optional, Union from uuid import uuid4 @@ -25,7 +26,9 @@ from couchbase_analytics.common.result import BlockingQueryResult from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter from couchbase_analytics.protocol._core.request import _RequestBuilder -from couchbase_analytics.protocol._core.request_context import RequestContext +from couchbase_analytics.protocol._core.request_context import RequestContext, StreamingRequestContext +from couchbase_analytics.protocol._core.response import HttpResponse +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle from couchbase_analytics.protocol.streaming import HttpStreamingResponse if TYPE_CHECKING: @@ -85,6 +88,7 @@ def _shutdown(self) -> None: """ **INTERNAL** """ + atexit.unregister(self._shutdown_executor) self._client_adapter.close_client() self._client_adapter.reset_client() self._shutdown_executor() @@ -97,9 +101,10 @@ def _create_client(self) -> None: def _shutdown_executor(self) -> None: if self._tp_executor_shutdown_called is False: - self._client_adapter.log_message( - f'Shutting down ThreadPoolExecutor({self._tp_executor_prefix})', LogLevel.INFO - ) + if not sys.is_finalizing(): + self._client_adapter.log_message( + f'Shutting down ThreadPoolExecutor({self._tp_executor_prefix})', LogLevel.INFO + ) self._tp_executor.shutdown() self._tp_executor_shutdown_called = True @@ -120,11 +125,11 @@ def shutdown(self) -> None: def execute_query( self, statement: str, *args: object, **kwargs: object ) -> Union[BlockingQueryResult, Future[BlockingQueryResult]]: - base_req = self._request_builder.build_base_query_request(statement, *args, **kwargs) - lazy_execute = base_req.options.pop('lazy_execute', None) - stream_config = 
base_req.options.pop('stream_config', None) - request_context = RequestContext( - self.client_adapter, base_req, self.threadpool_executor, stream_config=stream_config + req = self._request_builder.build_query_request(statement, *args, **kwargs) + lazy_execute = req.options.pop('lazy_execute', None) + stream_config = req.options.pop('stream_config', None) + request_context = StreamingRequestContext( + self.client_adapter, req, self.threadpool_executor, stream_config=stream_config ) resp = HttpStreamingResponse(request_context, lazy_execute=lazy_execute) @@ -147,6 +152,16 @@ def _execute_query(http_response: HttpStreamingResponse) -> BlockingQueryResult: resp.send_request() return BlockingQueryResult(resp) + def start_query(self, statement: str, *args: object, **kwargs: object) -> BlockingQueryHandle: + base_req = self._request_builder.build_start_query_request(statement, *args, **kwargs) + stream_config = base_req.options.pop('stream_config', None) + request_context = RequestContext(self.client_adapter, base_req) + resp = HttpResponse(request_context) + resp.send_request() + return BlockingQueryHandle( + self._client_adapter, self._request_builder, resp, self._tp_executor, stream_config=stream_config + ) + @classmethod def create_instance( cls, http_endpoint: str, credential: Credential, options: Optional[ClusterOptions], **kwargs: object diff --git a/couchbase_analytics/protocol/cluster.pyi b/couchbase_analytics/protocol/cluster.pyi index dbb950a..206bb55 100644 --- a/couchbase_analytics/protocol/cluster.pyi +++ b/couchbase_analytics/protocol/cluster.pyi @@ -25,8 +25,16 @@ else: from couchbase_analytics import JSONType from couchbase_analytics.common.credential import Credential from couchbase_analytics.common.result import BlockingQueryResult -from couchbase_analytics.options import ClusterOptions, ClusterOptionsKwargs, QueryOptions, QueryOptionsKwargs +from couchbase_analytics.options import ( + ClusterOptions, + ClusterOptionsKwargs, + QueryOptions, + 
QueryOptionsKwargs, + StartQueryOptions, + StartQueryOptionsKwargs, +) from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle class Cluster: @overload @@ -119,6 +127,26 @@ class Cluster: def execute_query( self, statement: str, *args: JSONType, enable_cancel: bool, **kwargs: str ) -> Future[BlockingQueryResult]: ... + @overload + def start_query(self, statement: str) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... def shutdown(self) -> None: ... 
@overload @classmethod diff --git a/couchbase_analytics/protocol/connection.py b/couchbase_analytics/protocol/connection.py index 3d37496..002b8b7 100644 --- a/couchbase_analytics/protocol/connection.py +++ b/couchbase_analytics/protocol/connection.py @@ -27,7 +27,7 @@ from couchbase_analytics.common._core.utils import is_null_or_empty from couchbase_analytics.common.credential import Credential from couchbase_analytics.common.deserializer import DefaultJsonDeserializer, Deserializer -from couchbase_analytics.common.options import ClusterOptions, SecurityOptions, TimeoutOptions +from couchbase_analytics.common.options import SecurityOptions, TimeoutOptions from couchbase_analytics.common.request import RequestURL from couchbase_analytics.protocol.options import ( ClusterOptionsTransformedKwargs, @@ -255,8 +255,6 @@ def create( logger_name = cast(Optional[str], kwargs.pop('logger_name', None)) cluster_opts = opts_builder.build_cluster_options( - ClusterOptions, - ClusterOptionsTransformedKwargs, kwargs, options, query_str_opts=parse_query_str_options(query_str_opts, logger_name=logger_name), diff --git a/couchbase_analytics/protocol/options.py b/couchbase_analytics/protocol/options.py index d47ee08..4c0f984 100644 --- a/couchbase_analytics/protocol/options.py +++ b/couchbase_analytics/protocol/options.py @@ -17,7 +17,7 @@ from __future__ import annotations from copy import copy -from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypedDict, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypedDict, Union, overload from couchbase_analytics.common._core import JsonStreamConfig from couchbase_analytics.common._core.utils import ( @@ -35,15 +35,19 @@ from couchbase_analytics.common.enums import QueryScanConsistency from couchbase_analytics.common.options import ( ClusterOptions, + FetchResultsOptions, OptionsClass, QueryOptions, SecurityOptions, + StartQueryOptions, TimeoutOptions, ) from 
couchbase_analytics.common.options_base import ( ClusterOptionsValidKeys, + FetchResultsOptionsValidKeys, QueryOptionsValidKeys, SecurityOptionsValidKeys, + StartQueryOptionsValidKeys, TimeoutOptionsValidKeys, ) @@ -164,25 +168,66 @@ class QueryOptionsTransformedKwargs(TypedDict, total=False): timeout: Optional[float] -TransformedOptionKwargs = TypeVar( - 'TransformedOptionKwargs', - QueryOptionsTransformedKwargs, - ClusterOptionsTransformedKwargs, - SecurityOptionsTransformedKwargs, - TimeoutOptionsTransformedKwargs, -) +class StartQueryOptionsTransforms(TypedDict): + client_context_id: Dict[Literal['client_context_id'], Callable[[Any], str]] + max_retries: Dict[Literal['max_retries'], Callable[[Any], int]] + named_parameters: Dict[Literal['named_parameters'], Callable[[Any], Any]] + positional_parameters: Dict[Literal['positional_parameters'], Callable[[Any], Any]] + query_context: Dict[Literal['query_context'], Callable[[Any], str]] + raw: Dict[Literal['raw'], Callable[[Any], Dict[str, Any]]] + readonly: Dict[Literal['readonly'], Callable[[Any], bool]] + scan_consistency: Dict[Literal['scan_consistency'], Callable[[Any], str]] + stream_config: Dict[Literal['stream_config'], Callable[[Any], JsonStreamConfig]] + timeout: Dict[Literal['timeout'], Callable[[Any], float]] + + +START_QUERY_OPTIONS_TRANSFORMS: StartQueryOptionsTransforms = { + 'client_context_id': {'client_context_id': VALIDATE_STR}, + 'max_retries': {'max_retries': VALIDATE_INT}, + 'named_parameters': {'named_parameters': lambda x: x}, + 'positional_parameters': {'positional_parameters': lambda x: x}, + 'query_context': {'query_context': VALIDATE_STR}, + 'raw': {'raw': validate_raw_dict}, + 'readonly': {'readonly': VALIDATE_BOOL}, + 'scan_consistency': {'scan_consistency': QUERY_CONSISTENCY_TO_STR}, + 'stream_config': {'stream_config': lambda x: x}, + 'timeout': {'timeout': to_seconds}, +} + + +class StartQueryOptionsTransformedKwargs(TypedDict, total=False): + client_context_id: Optional[str] + 
max_retries: Optional[int] + named_parameters: Optional[Any] + positional_parameters: Optional[Any] + priority: Optional[bool] + query_context: Optional[str] + raw: Optional[Dict[str, Any]] + readonly: Optional[bool] + scan_consistency: Optional[str] + stream_config: Optional[JsonStreamConfig] + timeout: Optional[float] + + +class FetchResultsOptionsTransforms(TypedDict): + deserializer: Dict[Literal['deserializer'], Callable[[Any], Deserializer]] + + +FETCH_RESULTS_OPTIONS_TRANSFORMS: FetchResultsOptionsTransforms = { + 'deserializer': {'deserializer': VALIDATE_DESERIALIZER}, +} + + +class FetchResultsOptionsTransformedKwargs(TypedDict, total=False): + deserializer: Optional[Deserializer] -TransformedClusterOptionKwargs = TypeVar( - 'TransformedClusterOptionKwargs', - ClusterOptionsTransformedKwargs, - SecurityOptionsTransformedKwargs, - TimeoutOptionsTransformedKwargs, -) TransformDetailsPair = Union[ Tuple[List[QueryOptionsValidKeys], QueryOptionsTransforms], Tuple[List[ClusterOptionsValidKeys], ClusterOptionsTransforms], + Tuple[List[FetchResultsOptionsValidKeys], FetchResultsOptionsTransforms], Tuple[List[SecurityOptionsValidKeys], SecurityOptionsTransforms], + Tuple[List[StartQueryOptionsValidKeys], StartQueryOptionsTransforms], Tuple[List[TimeoutOptionsValidKeys], TimeoutOptionsTransforms], ] @@ -216,18 +261,20 @@ def _get_transform_details(self, option_type: str) -> TransformDetailsPair: # n return TimeoutOptions.VALID_OPTION_KEYS, TIMEOUT_OPTIONS_TRANSFORMS elif option_type == 'QueryOptions': return QueryOptions.VALID_OPTION_KEYS, QUERY_OPTIONS_TRANSFORMS + elif option_type == 'StartQueryOptions': + return StartQueryOptions.VALID_OPTION_KEYS, START_QUERY_OPTIONS_TRANSFORMS + elif option_type == 'FetchResultsOptions': + return FetchResultsOptions.VALID_OPTION_KEYS, FETCH_RESULTS_OPTIONS_TRANSFORMS else: raise ValueError('Invalid OptionType.') def build_cluster_options( # noqa: C901 self, - option_type: type[OptionsClass], - output_type: 
type[TransformedClusterOptionKwargs], orig_kwargs: Dict[str, object], options: Optional[object] = None, query_str_opts: Optional[Dict[str, QueryStrVal]] = None, - ) -> TransformedClusterOptionKwargs: - temp_options = self._get_options_copy(option_type, orig_kwargs, options) + ) -> ClusterOptionsTransformedKwargs: + temp_options = self._get_options_copy(ClusterOptions, orig_kwargs, options) # we flatten all the nested options (timeout_options & security_options) # so that we can combine the nested options w/ potential query string options @@ -254,37 +301,84 @@ def build_cluster_options( # noqa: C901 keys_to_ignore: List[str] = [*ClusterOptions.VALID_OPTION_KEYS, *TimeoutOptions.VALID_OPTION_KEYS] - # not going to be able to make mypy happy w/ keys_to_ignore :/ - transformed_security_opts = self.build_options( - SecurityOptions, SecurityOptionsTransformedKwargs, temp_options, keys_to_ignore=keys_to_ignore - ) + transformed_security_opts = self.build_options(SecurityOptions, temp_options, keys_to_ignore=keys_to_ignore) if transformed_security_opts: temp_options['security_options'] = transformed_security_opts keys_to_ignore = [*ClusterOptions.VALID_OPTION_KEYS, *SecurityOptions.VALID_OPTION_KEYS] - # not going to be able to make mypy happy w/ keys_to_ignore :/ - transformed_timeout_opts = self.build_options( - TimeoutOptions, TimeoutOptionsTransformedKwargs, temp_options, keys_to_ignore=keys_to_ignore - ) + transformed_timeout_opts = self.build_options(TimeoutOptions, temp_options, keys_to_ignore=keys_to_ignore) if transformed_timeout_opts: temp_options['timeout_options'] = transformed_timeout_opts # transform final ClusterOptions - transformed_opts = self.build_options(option_type, output_type, temp_options) + transformed_opts = self.build_options(ClusterOptions, temp_options) return transformed_opts + @overload + def build_options( + self, + option_type: type[ClusterOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: 
Optional[List[str]] = ..., + ) -> ClusterOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[SecurityOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> SecurityOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[TimeoutOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> TimeoutOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[QueryOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> QueryOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[StartQueryOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> StartQueryOptionsTransformedKwargs: ... + + @overload + def build_options( + self, + option_type: type[FetchResultsOptions], + orig_kwargs: Dict[str, object], + options: Optional[object] = ..., + keys_to_ignore: Optional[List[str]] = ..., + ) -> FetchResultsOptionsTransformedKwargs: ... + def build_options( self, option_type: type[OptionsClass], - output_type: type[TransformedOptionKwargs], orig_kwargs: Dict[str, object], options: Optional[object] = None, keys_to_ignore: Optional[List[str]] = None, - ) -> TransformedOptionKwargs: + ) -> Any: temp_options = self._get_options_copy(option_type, orig_kwargs, options) - transformed_opts: TransformedOptionKwargs = {} + transformed_opts: Any = {} # Option 1 satisfies mypy, but we want temp_options to be the limiting factor for the loop. # Option 2. 
Also makes providing warnings/exceptions for users not using static type checking easier, # but unfortunately we need to use some type: ignore comments @@ -304,7 +398,7 @@ def build_options( for nk, cfn in transforms.items(): conv = cfn(v) if conv is not None: - transformed_opts[nk] = conv # type: ignore[literal-required] + transformed_opts[nk] = conv elif keys_to_ignore and k not in keys_to_ignore: raise ValueError(f'Invalid key provided (key={k}).') diff --git a/couchbase_analytics/protocol/query_handle.py b/couchbase_analytics/protocol/query_handle.py new file mode 100644 index 0000000..897e699 --- /dev/null +++ b/couchbase_analytics/protocol/query_handle.py @@ -0,0 +1,148 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from typing import TYPE_CHECKING, Optional + +from couchbase_analytics.common._core.query_handle import QueryHandleStatusResponse +from couchbase_analytics.common.errors import AnalyticsError +from couchbase_analytics.common.query_handle import BlockingQueryHandle as _CoreBlockingQueryHandle +from couchbase_analytics.common.query_handle import BlockingQueryResultHandle as _CoreBlockingQueryResultHandle +from couchbase_analytics.common.result import BlockingQueryResult +from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter +from couchbase_analytics.protocol._core.request import _RequestBuilder +from couchbase_analytics.protocol._core.request_context import RequestContext, StreamingRequestContext +from couchbase_analytics.protocol._core.response import HttpResponse +from couchbase_analytics.protocol.streaming import HttpStreamingResponse + +if TYPE_CHECKING: + from couchbase_analytics.common._core import JsonStreamConfig + + +class BlockingQueryHandle(_CoreBlockingQueryHandle): + def __init__( + self, + client_adapter: _ClientAdapter, + request_builder: _RequestBuilder, + http_response: HttpResponse, + tp_executor: ThreadPoolExecutor, + stream_config: Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._http_response = http_response + self._tp_executor = tp_executor + self._stream_config = stream_config + self._request_id: str = '' + self._handle: str = '' + self._get_status_handle() + + @property + def handle(self) -> str: + return self._handle + + @property + def request_id(self) -> str: + return self._request_id + + def fetch_result_handle(self) -> Optional[BlockingQueryResultHandle]: + server_req = self._request_builder.build_request_from_handle(self._handle) + request_context = RequestContext(self._client_adapter, server_req) + resp = 
HttpResponse(request_context) + resp.send_request() + if resp.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + if 'handle' not in resp.json_response: + return None + + status_response = self._get_handle_status_response(resp) + return BlockingQueryResultHandle( + self._client_adapter, + self._request_builder, + self._tp_executor, + status_response, + stream_config=self._stream_config, + ) + + def cancel(self) -> None: + cancel_req = self._request_builder.build_cancel_request(self._request_id) + request_context = RequestContext(self._client_adapter, cancel_req) + resp = HttpResponse(request_context, skip_process_response=True, request_id=self._request_id) + resp.send_request() + + def _get_status_handle(self) -> None: + if self._http_response.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + request_id = self._http_response.json_response.get('requestID', None) + if request_id is None: + raise AnalyticsError(message='Server response is missing "requestID" field.') + handle = self._http_response.json_response.get('handle', None) + if handle is None: + raise AnalyticsError(message='Server response is missing "handle" field.') + + self._request_id = request_id + self._handle = handle + + def _get_handle_status_response(self, resp: HttpResponse) -> QueryHandleStatusResponse: + if resp.json_response is None: + raise AnalyticsError(message='HTTP response does not contain JSON data.') + + handle = resp.json_response.get('handle', None) + if handle is None: + raise AnalyticsError(message='Server response is missing "handle" field.') + + return QueryHandleStatusResponse.from_server(self._request_id, resp.json_response) + + +class BlockingQueryResultHandle(_CoreBlockingQueryResultHandle): + def __init__( + self, + client_adapter: _ClientAdapter, + request_builder: _RequestBuilder, + tp_executor: ThreadPoolExecutor, + status_resp: QueryHandleStatusResponse, + stream_config: 
Optional[JsonStreamConfig] = None, + ) -> None: + super().__init__() + self._client_adapter = client_adapter + self._request_builder = request_builder + self._tp_executor = tp_executor + self._status_resp = status_resp + self._stream_config = stream_config + + @property + def request_id(self) -> str: + return self._status_resp.request_id + + def fetch_results(self) -> BlockingQueryResult: + server_req = self._request_builder.build_fetch_results_request(self._status_resp.handle) + request_context = StreamingRequestContext( + self._client_adapter, server_req, self._tp_executor, stream_config=self._stream_config + ) + resp = HttpStreamingResponse(request_context, request_id=self._status_resp.request_id) + resp.send_request() + return BlockingQueryResult(resp) + + def discard_results(self) -> None: + req = self._request_builder.build_discard_results_request(self._status_resp.handle) + request_context = RequestContext(self._client_adapter, req) + resp = HttpResponse(request_context, skip_process_response=True, request_id=self._status_resp.request_id) + resp.send_request() diff --git a/couchbase_analytics/protocol/scope.py b/couchbase_analytics/protocol/scope.py index 1e77457..f5aae28 100644 --- a/couchbase_analytics/protocol/scope.py +++ b/couchbase_analytics/protocol/scope.py @@ -22,7 +22,9 @@ from couchbase_analytics.common.result import BlockingQueryResult from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter from couchbase_analytics.protocol._core.request import _RequestBuilder -from couchbase_analytics.protocol._core.request_context import RequestContext +from couchbase_analytics.protocol._core.request_context import RequestContext, StreamingRequestContext +from couchbase_analytics.protocol._core.response import HttpResponse +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle from couchbase_analytics.protocol.streaming import HttpStreamingResponse if TYPE_CHECKING: @@ -59,11 +61,11 @@ def threadpool_executor(self) -> 
ThreadPoolExecutor: def execute_query( self, statement: str, *args: object, **kwargs: object ) -> Union[BlockingQueryResult, Future[BlockingQueryResult]]: - base_req = self._request_builder.build_base_query_request(statement, *args, **kwargs) - lazy_execute = base_req.options.pop('lazy_execute', None) - stream_config = base_req.options.pop('stream_config', None) - request_context = RequestContext( - self.client_adapter, base_req, self.threadpool_executor, stream_config=stream_config + req = self._request_builder.build_query_request(statement, *args, **kwargs) + lazy_execute = req.options.pop('lazy_execute', None) + stream_config = req.options.pop('stream_config', None) + request_context = StreamingRequestContext( + self.client_adapter, req, self.threadpool_executor, stream_config=stream_config ) resp = HttpStreamingResponse(request_context, lazy_execute=lazy_execute) @@ -84,3 +86,13 @@ def _execute_query(http_response: HttpStreamingResponse) -> BlockingQueryResult: if lazy_execute is not True: resp.send_request() return BlockingQueryResult(resp) + + def start_query(self, statement: str, *args: object, **kwargs: object) -> BlockingQueryHandle: + base_req = self._request_builder.build_start_query_request(statement, *args, **kwargs) + stream_config = base_req.options.pop('stream_config', None) + request_context = RequestContext(self.client_adapter, base_req) + resp = HttpResponse(request_context) + resp.send_request() + return BlockingQueryHandle( + self._client_adapter, self._request_builder, resp, self._tp_executor, stream_config=stream_config + ) diff --git a/couchbase_analytics/protocol/scope.pyi b/couchbase_analytics/protocol/scope.pyi index 9296863..e3d0b59 100644 --- a/couchbase_analytics/protocol/scope.pyi +++ b/couchbase_analytics/protocol/scope.pyi @@ -24,9 +24,10 @@ else: from couchbase_analytics import JSONType from couchbase_analytics.common.result import BlockingQueryResult -from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs +from 
couchbase_analytics.options import QueryOptions, QueryOptionsKwargs, StartQueryOptions, StartQueryOptionsKwargs from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter from couchbase_analytics.protocol.database import Database as Database +from couchbase_analytics.protocol.query_handle import BlockingQueryHandle class Scope: def __init__(self, database: Database, scope_name: str) -> None: ... @@ -106,3 +107,23 @@ class Scope: def execute_query( self, statement: str, *args: JSONType, enable_cancel: bool, **kwargs: str ) -> Future[BlockingQueryResult]: ... + @overload + def start_query(self, statement: str) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... 
diff --git a/couchbase_analytics/protocol/streaming.py b/couchbase_analytics/protocol/streaming.py index d6372a8..c6981fc 100644 --- a/couchbase_analytics/protocol/streaming.py +++ b/couchbase_analytics/protocol/streaming.py @@ -26,12 +26,17 @@ from couchbase_analytics.common.errors import AnalyticsError, InternalSDKError, TimeoutError from couchbase_analytics.common.logging import LogLevel from couchbase_analytics.common.query import QueryMetadata -from couchbase_analytics.protocol._core.request_context import RequestContext +from couchbase_analytics.protocol._core.request_context import StreamingRequestContext from couchbase_analytics.protocol._core.retries import RetryHandler class HttpStreamingResponse: - def __init__(self, request_context: RequestContext, lazy_execute: Optional[bool] = None) -> None: + def __init__( + self, + request_context: StreamingRequestContext, + lazy_execute: Optional[bool] = None, + request_id: Optional[str] = None, + ) -> None: self._request_context = request_context if lazy_execute is not None: self._lazy_execute = lazy_execute @@ -39,6 +44,7 @@ def __init__(self, request_context: RequestContext, lazy_execute: Optional[bool] self._lazy_execute = False self._metadata: Optional[QueryMetadata] = None self._core_response: HttpCoreResponse + self._request_id = request_id @property def lazy_execute(self) -> bool: @@ -68,7 +74,7 @@ def _handle_iteration_abort(self) -> None: def _process_response( self, raw_response: Optional[ParsedResult] = None, handle_context_shutdown: Optional[bool] = False ) -> None: - json_response = self._request_context.process_response( + json_response = self._request_context.process_streaming_response( self.close, raw_response=raw_response, handle_context_shutdown=handle_context_shutdown ) self.set_metadata(json_data=json_response) @@ -98,7 +104,9 @@ def get_metadata(self) -> QueryMetadata: def set_metadata(self, json_data: Optional[Any] = None, raw_metadata: Optional[bytes] = None) -> None: try: - self._metadata = 
QueryMetadata(build_query_metadata(json_data=json_data, raw_metadata=raw_metadata)) + self._metadata = QueryMetadata( + build_query_metadata(json_data=json_data, raw_metadata=raw_metadata, request_id=self._request_id) + ) self._request_context.shutdown() except (AnalyticsError, ValueError) as err: self._request_context.shutdown(err) diff --git a/couchbase_analytics/query_handle.py b/couchbase_analytics/query_handle.py new file mode 100644 index 0000000..215bad1 --- /dev/null +++ b/couchbase_analytics/query_handle.py @@ -0,0 +1,17 @@ +# Copyright 2016-2025. Couchbase, Inc. +# All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from couchbase_analytics.common.query_handle import BlockingQueryHandle as BlockingQueryHandle # noqa: F401 +from couchbase_analytics.common.query_handle import BlockingQueryResultHandle as BlockingQueryResultHandle # noqa: F401 diff --git a/couchbase_analytics/result.py b/couchbase_analytics/result.py index 4f0e8e4..4b0cb37 100644 --- a/couchbase_analytics/result.py +++ b/couchbase_analytics/result.py @@ -13,7 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- from couchbase_analytics.common.result import AsyncQueryResult as AsyncQueryResult # noqa: F401 from couchbase_analytics.common.result import BlockingQueryResult as BlockingQueryResult # noqa: F401 from couchbase_analytics.common.result import QueryResult as QueryResult # noqa: F401 diff --git a/couchbase_analytics/scope.py b/couchbase_analytics/scope.py index 702becb..7a805e6 100644 --- a/couchbase_analytics/scope.py +++ b/couchbase_analytics/scope.py @@ -19,6 +19,7 @@ from concurrent.futures import Future from typing import TYPE_CHECKING, Union +from couchbase_analytics.query_handle import BlockingQueryHandle from couchbase_analytics.result import BlockingQueryResult if TYPE_CHECKING: @@ -114,3 +115,19 @@ def execute_query( """ # noqa: E501 return self._impl.execute_query(statement, *args, **kwargs) + + def start_query(self, statement: str, *args: object, **kwargs: object) -> BlockingQueryHandle: + """Executes a query against an Analytics scope in async mode. + + .. seealso:: + :meth:`couchbase_analytics.Cluster.start_query`: For how to execute cluster-level queries. + + Args: + statement: The SQL++ statement to execute. + options (:class:`~couchbase_analytics.options.StartQueryOptions`): Optional parameters for the query operation. 
+ **kwargs (Dict[str, Any]): keyword arguments that can be used in place of or to override provided :class:`~couchbase_analytics.options.StartQueryOptions` + + Returns: + :class:`~couchbase_analytics.query_handle.BlockingQueryHandle`: An instance of a :class:`~couchbase_analytics.query_handle.BlockingQueryHandle` + """ # noqa: E501 + return self._impl.start_query(statement, *args, **kwargs) diff --git a/couchbase_analytics/scope.pyi b/couchbase_analytics/scope.pyi index c5d36d2..fa70bef 100644 --- a/couchbase_analytics/scope.pyi +++ b/couchbase_analytics/scope.pyi @@ -23,8 +23,9 @@ else: from typing import Unpack from couchbase_analytics import JSONType -from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs +from couchbase_analytics.options import QueryOptions, QueryOptionsKwargs, StartQueryOptions, StartQueryOptionsKwargs from couchbase_analytics.protocol.database import Database as Database +from couchbase_analytics.query_handle import BlockingQueryHandle from couchbase_analytics.result import BlockingQueryResult class Scope: @@ -101,3 +102,23 @@ class Scope: def execute_query( self, statement: str, *args: JSONType, enable_cancel: bool, **kwargs: str ) -> Future[BlockingQueryResult]: ... + @overload + def start_query(self, statement: str) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, options: StartQueryOptions) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, **kwargs: Unpack[StartQueryOptionsKwargs]) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... + @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: Unpack[StartQueryOptionsKwargs] + ) -> BlockingQueryHandle: ... 
+ @overload + def start_query( + self, statement: str, options: StartQueryOptions, *args: JSONType, **kwargs: str + ) -> BlockingQueryHandle: ... + @overload + def start_query(self, statement: str, *args: JSONType, **kwargs: str) -> BlockingQueryHandle: ... diff --git a/couchbase_analytics/tests/connection_t.py b/couchbase_analytics/tests/connection_t.py index b9b0cf8..709124d 100644 --- a/couchbase_analytics/tests/connection_t.py +++ b/couchbase_analytics/tests/connection_t.py @@ -67,7 +67,7 @@ def test_connstr_options_max_retries(self) -> None: connstr = f'https://localhost?max_retries={max_retries}' client = _ClientAdapter(connstr, cred) req_builder = _RequestBuilder(client) - req = req_builder.build_base_query_request('SELECT 1=1') + req = req_builder.build_query_request('SELECT 1=1') assert req.max_retries == max_retries @pytest.mark.parametrize( @@ -99,7 +99,7 @@ def test_connstr_options_timeout(self, duration: str, expected_seconds: str) -> connstr = f'https://localhost?{to_query_str(opts)}' client = _ClientAdapter(connstr, cred) req_builder = _RequestBuilder(client) - req = req_builder.build_base_query_request('SELECT 1=1') + req = req_builder.build_query_request('SELECT 1=1') expected = float(expected_seconds) returned_timeout_opts = req.get_request_timeouts() assert isinstance(returned_timeout_opts, dict) diff --git a/couchbase_analytics/tests/query_options_t.py b/couchbase_analytics/tests/query_options_t.py index 39d0b67..473c3c0 100644 --- a/couchbase_analytics/tests/query_options_t.py +++ b/couchbase_analytics/tests/query_options_t.py @@ -76,7 +76,7 @@ def test_options_deserializer( deserializer = DefaultJsonDeserializer() q_opts = QueryOptions(deserializer=deserializer) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.deserializer == deserializer @@ -89,7 +89,7 @@ def 
test_options_deserializer_kwargs( deserializer = DefaultJsonDeserializer() kwargs: QueryOptionsKwargs = {'deserializer': deserializer} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.deserializer == deserializer @@ -101,9 +101,9 @@ def test_options_max_retries( ) -> None: if max_retries is not None: q_opts = QueryOptions(max_retries=max_retries) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) else: - req = request_builder.build_base_query_request(query_statment) + req = request_builder.build_query_request(query_statment) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.max_retries == (max_retries if max_retries is not None else 7) @@ -115,9 +115,9 @@ def test_options_max_retries_kwargs( ) -> None: if max_retries is not None: kwargs: QueryOptionsKwargs = {'max_retries': max_retries} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) else: - req = request_builder.build_base_query_request(query_statment) + req = request_builder.build_query_request(query_statment) exp_opts: QueryOptionsTransformedKwargs = {} assert req.options == exp_opts assert req.max_retries == (max_retries if max_retries is not None else 7) @@ -128,7 +128,7 @@ def test_options_named_parameters( ) -> None: params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} q_opts = QueryOptions(named_parameters=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'named_parameters': params} assert req.options == exp_opts 
query_ctx.validate_query_context(req.body) @@ -138,7 +138,7 @@ def test_options_named_parameters_kwargs( ) -> None: params: Dict[str, JSONType] = {'foo': 'bar', 'baz': 1, 'quz': False} kwargs: QueryOptionsKwargs = {'named_parameters': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'named_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -148,7 +148,7 @@ def test_options_positional_parameters( ) -> None: params: List[JSONType] = ['foo', 'bar', 1, False] q_opts = QueryOptions(positional_parameters=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'positional_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -158,7 +158,7 @@ def test_options_positional_parameters_kwargs( ) -> None: params: List[JSONType] = ['foo', 'bar', 1, False] kwargs: QueryOptionsKwargs = {'positional_parameters': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'positional_parameters': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -167,7 +167,7 @@ def test_options_raw(self, query_statment: str, request_builder: _RequestBuilder pos_params: List[JSONType] = ['foo', 'bar', 1, False] params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} q_opts = QueryOptions(raw=params) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'raw': params} assert req.options == exp_opts 
query_ctx.validate_query_context(req.body) @@ -178,7 +178,7 @@ def test_options_raw_kwargs( pos_params: List[JSONType] = ['foo', 'bar', 1, False] params: Dict[str, Any] = {'readonly': True, 'positional_params': pos_params} kwargs: QueryOptionsKwargs = {'raw': params} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'raw': params} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -187,7 +187,7 @@ def test_options_readonly( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: q_opts = QueryOptions(readonly=True) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'readonly': True} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -196,7 +196,7 @@ def test_options_readonly_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: kwargs: QueryOptionsKwargs = {'readonly': True} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'readonly': True} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -207,7 +207,7 @@ def test_options_scan_consistency( from couchbase_analytics.query import QueryScanConsistency q_opts = QueryOptions(scan_consistency=QueryScanConsistency.REQUEST_PLUS) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -218,7 
+218,7 @@ def test_options_scan_consistency_kwargs( from couchbase_analytics.query import QueryScanConsistency kwargs: QueryOptionsKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'scan_consistency': QueryScanConsistency.REQUEST_PLUS.value} assert req.options == exp_opts query_ctx.validate_query_context(req.body) @@ -227,7 +227,7 @@ def test_options_timeout( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: q_opts = QueryOptions(timeout=timedelta(seconds=20)) - req = request_builder.build_base_query_request(query_statment, q_opts) + req = request_builder.build_query_request(query_statment, q_opts) exp_opts: QueryOptionsTransformedKwargs = {'timeout': 20.0} assert req.options == exp_opts # NOTE: we add time to the server timeout to ensure a client side timeout @@ -238,7 +238,7 @@ def test_options_timeout_kwargs( self, query_statment: str, request_builder: _RequestBuilder, query_ctx: QueryContext ) -> None: kwargs: QueryOptionsKwargs = {'timeout': timedelta(seconds=20)} - req = request_builder.build_base_query_request(query_statment, **kwargs) + req = request_builder.build_query_request(query_statment, **kwargs) exp_opts: QueryOptionsTransformedKwargs = {'timeout': 20.0} assert req.options == exp_opts # NOTE: we add time to the server timeout to ensure a client side timeout @@ -248,14 +248,14 @@ def test_options_timeout_kwargs( def test_options_timeout_must_be_positive(self, query_statment: str, request_builder: _RequestBuilder) -> None: q_opts = QueryOptions(timeout=timedelta(seconds=-1)) with pytest.raises(ValueError): - request_builder.build_base_query_request(query_statment, q_opts) + request_builder.build_query_request(query_statment, q_opts) def test_options_timeout_must_be_positive_kwargs( self, query_statment: 
str, request_builder: _RequestBuilder ) -> None: kwargs: QueryOptionsKwargs = {'timeout': timedelta(seconds=-1)} with pytest.raises(ValueError): - request_builder.build_base_query_request(query_statment, **kwargs) + request_builder.build_query_request(query_statment, **kwargs) class ClusterQueryOptionsTests(QueryOptionsTestSuite): diff --git a/tests/utils/_async_client_adapter.py b/tests/utils/_async_client_adapter.py index 5043edb..7978e3b 100644 --- a/tests/utils/_async_client_adapter.py +++ b/tests/utils/_async_client_adapter.py @@ -14,12 +14,12 @@ # limitations under the License. -from typing import Dict +from typing import Dict, Optional, Union from httpx import URL, Response from acouchbase_analytics.protocol._core.client_adapter import _AsyncClientAdapter -from couchbase_analytics.protocol._core.request import QueryRequest +from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] @@ -38,7 +38,11 @@ def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore self._http_transport_cls = adapter._http_transport_cls -async def send_request_override(self: _AsyncClientAdapter, request: QueryRequest) -> Response: +async def send_request_override( + self: _AsyncClientAdapter, + request: Union[CancelRequest, HttpRequest, QueryRequest, StartQueryRequest], + stream: Optional[bool] = True, +) -> Response: if not hasattr(self, '_client'): raise RuntimeError('Client not created yet') diff --git a/tests/utils/_client_adapter.py b/tests/utils/_client_adapter.py index 0acf76d..7e6cbe8 100644 --- a/tests/utils/_client_adapter.py +++ b/tests/utils/_client_adapter.py @@ -14,12 +14,12 @@ # limitations under the License. 
-from typing import Dict +from typing import Dict, Optional, Union from httpx import URL, Response from couchbase_analytics.protocol._core.client_adapter import _ClientAdapter -from couchbase_analytics.protocol._core.request import QueryRequest +from couchbase_analytics.protocol._core.request import CancelRequest, HttpRequest, QueryRequest, StartQueryRequest def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def] @@ -39,7 +39,11 @@ def client_adapter_init_override(self, *args, **kwargs) -> None: # type: ignore self._http_transport_cls = adapter._http_transport_cls -def send_request_override(self: _ClientAdapter, request: QueryRequest) -> Response: +def send_request_override( + self: _ClientAdapter, + request: Union[CancelRequest, HttpRequest, QueryRequest, StartQueryRequest], + stream: Optional[bool] = True, +) -> Response: if not hasattr(self, '_client'): raise RuntimeError('Client not created yet') @@ -57,7 +61,7 @@ def send_request_override(self: _ClientAdapter, request: QueryRequest) -> Respon url = URL(scheme=request.url.scheme, host=request.url.host, port=request.url.port, path=request.url.path) req = self._client.build_request(request.method, url, json=request_json, extensions=request_extensions) - return self._client.send(req, stream=True) + return self._client.send(req, stream=stream) def set_request_path(self: _ClientAdapter, path: str) -> None: