diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 347003c..5ccd500 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -15,16 +15,24 @@ from hotdata import ApiClient, Configuration from hotdata.api.connections_api import ConnectionsApi +from hotdata.api.databases_api import DatabasesApi from hotdata.api.datasets_api import DatasetsApi from hotdata.api.indexes_api import IndexesApi from hotdata.api.saved_queries_api import SavedQueriesApi from hotdata.api.secrets_api import SecretsApi from hotdata.api.workspaces_api import WorkspacesApi +from hotdata.models.create_database_request import CreateDatabaseRequest REQUIRED_ENV = ("HOTDATA_SDK_TEST_API_KEY", "HOTDATA_SDK_TEST_WORKSPACE_ID") DEFAULT_API_URL = "https://api.hotdata.dev" +# Queries are scoped to a database via the X-Database-Id header. Databases no +# longer auto-expire, so rather than create one per run (which would leak), +# tests find-or-create a single stable database by name and reuse it — same +# pattern as the e2e suite. Name is not unique server-side; we key off it. +SHARED_DATABASE_NAME = "sdkci-shared" + @dataclass(frozen=True) class TestEnv: @@ -79,6 +87,23 @@ def connection_id(env: TestEnv) -> str: return env.connection_id +@pytest.fixture(scope="session") +def database_id(api_client: ApiClient) -> str: + """Id of the shared `sdkci-shared` database, creating it if absent. + + Queries require an `X-Database-Id` scope; databases persist (no auto-expiry) + so we reuse one across runs instead of creating-and-deleting per session. + """ + databases_api = DatabasesApi(api_client) + for db in databases_api.list_databases().databases: + if db.name == SHARED_DATABASE_NAME: + return db.id + created = databases_api.create_database( + CreateDatabaseRequest(name=SHARED_DATABASE_NAME) + ) + return created.id + + @pytest.fixture def sdkci_name() -> "callable[[str], str]": """Returns `sdkci--` so orphans are identifiable. @@ -109,6 +134,11 @@ def connections_api(api_client: ApiClient) -> ConnectionsApi: return ConnectionsApi(api_client) +@pytest.fixture +def databases_api(api_client: ApiClient) -> DatabasesApi: + return DatabasesApi(api_client) + + @pytest.fixture def indexes_api(api_client: ApiClient) -> IndexesApi: return IndexesApi(api_client) diff --git a/tests/integration/test_query_async_polling.py b/tests/integration/test_query_async_polling.py index fbf46df..a4d797b 100644 --- a/tests/integration/test_query_async_polling.py +++ b/tests/integration/test_query_async_polling.py @@ -40,13 +40,15 @@ def test_query_async_polling( query_api: QueryApi, query_runs_api: QueryRunsApi, results_api: ResultsApi, + database_id: str, ) -> None: # async=True with a small async_after_ms forces the run to come back as # AsyncQueryResponse rather than synchronous. The QueryResponse / async # response variants are union-shaped on the client; we treat anything with # query_run_id as the start of the polling loop. submitted = query_api.query( - QueryRequest(var_async=True, async_after_ms=1000, sql="SELECT 1 AS x") + QueryRequest(var_async=True, async_after_ms=1000, sql="SELECT 1 AS x"), + x_database_id=database_id, ) query_run_id = submitted.query_run_id assert query_run_id diff --git a/tests/integration/test_results_arrow.py b/tests/integration/test_results_arrow.py index 82d3163..e1a2311 100644 --- a/tests/integration/test_results_arrow.py +++ b/tests/integration/test_results_arrow.py @@ -46,13 +46,15 @@ def test_results_arrow( query_api: QueryApi, query_runs_api: QueryRunsApi, results_api: ResultsApi, + database_id: str, ) -> None: submitted = query_api.query( QueryRequest( var_async=True, async_after_ms=1000, sql="SELECT 1 AS x, 'hello' AS msg UNION ALL SELECT 2, 'world'", - ) + ), + x_database_id=database_id, ) query_run_id = submitted.query_run_id assert query_run_id diff --git a/tests/integration/test_saved_query_versioning.py b/tests/integration/test_saved_query_versioning.py index 61a875a..aacdbf4 100644 --- a/tests/integration/test_saved_query_versioning.py +++ b/tests/integration/test_saved_query_versioning.py @@ -13,7 +13,7 @@ def test_saved_query_versioning( - saved_queries_api: SavedQueriesApi, sdkci_name + saved_queries_api: SavedQueriesApi, sdkci_name, database_id: str ) -> None: name = sdkci_name("savedq-versioning") created_id: str | None = None @@ -50,7 +50,12 @@ def test_saved_query_versioning( f"expected versions 1,2,3 in {sorted(version_numbers)}" ) - executed = saved_queries_api.execute_saved_query(created.id) + # execute_saved_query runs SQL, so it also needs the database scope. + # The endpoint has no typed x_database_id param yet, so set the header + # directly via the _headers override. + executed = saved_queries_api.execute_saved_query( + created.id, _headers={"X-Database-Id": database_id} + ) assert executed.row_count == 1 assert executed.rows == [[3]] finally: