Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
- id: mixed-line-ending
- id: trailing-whitespace
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: "v0.15.17"
rev: "v0.15.18"
hooks:
- id: ruff
args: ["--fix"]
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ maintainers = [{ name = "Litestar Developers", email = "hello@litestar.dev" }]
name = "sqlspec"
readme = "README.md"
requires-python = ">=3.10, <4.0"
version = "0.50.0"
version = "0.50.1"

[project.urls]
Discord = "https://discord.gg/litestar"
Expand Down Expand Up @@ -201,6 +201,7 @@ packages = []




[tool.hatch.build.targets.wheel.hooks.mypyc]
dependencies = ["hatch-mypyc", "hatch-cython", "mypy>=2.0.0"]
enable-by-default = false
Expand Down Expand Up @@ -302,7 +303,7 @@ opt_level = "3" # Maximum optimization (0-3)
allow_dirty = true
commit = false
commit_args = "--no-verify"
current_version = "0.50.0"
current_version = "0.50.1"
ignore_missing_files = false
ignore_missing_version = false
message = "chore(release): bump to v{new_version}"
Expand Down
9 changes: 9 additions & 0 deletions sqlspec/adapters/duckdb/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,12 @@ class DuckDBExtensionConfig(TypedDict):
force_install: NotRequired[bool]
"""Force reinstallation of the extension."""

install: NotRequired[bool]
"""Force an explicit install_extension() call even for a name-only config."""

required: NotRequired[bool]
"""When True, install/load failure raises instead of best-effort WARNING."""


class DuckDBSecretConfig(TypedDict):
"""DuckDB secret configuration for AI/API integrations."""
Expand All @@ -134,6 +140,9 @@ class DuckDBSecretConfig(TypedDict):
persistent: NotRequired[bool]
"""Persist the secret to DuckDB's configured secret directory."""

required: NotRequired[bool]
"""When True, secret-creation failure raises (and is verified). Default best-effort."""


class DuckDBDriverFeatures(TypedDict):
"""TypedDict for DuckDB driver features configuration.
Expand Down
100 changes: 79 additions & 21 deletions sqlspec/adapters/duckdb/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@


_SQL_IDENTIFIER_RE: Final[re.Pattern[str]] = re.compile(r"^[A-Za-z][A-Za-z0-9_]*$")
_EXPLICIT_INSTALL_KEYS: Final[tuple[str, ...]] = ("version", "repository", "repository_url")


logger = get_logger(POOL_LOGGER_NAME)
Expand Down Expand Up @@ -50,6 +51,7 @@ class DuckDBConnectionPool:
"_extension_flags",
"_extensions",
"_health_check_interval",
"_installed_signatures",
"_is_memory_db",
"_lock",
"_on_connection_create",
Expand Down Expand Up @@ -87,6 +89,7 @@ def __init__(
self._extension_flags = extension_flags or {}
self._secrets = secrets or []
self._on_connection_create = on_connection_create
self._installed_signatures: set[tuple[Any, ...]] = set()
self._thread_local = threading.local()
self._lock = threading.RLock()
self._pool_id = str(uuid.uuid4())[:8]
Expand Down Expand Up @@ -128,21 +131,29 @@ def _create_connection(self) -> DuckDBConnection:
if not ext_name:
continue

install_kwargs = {}
if "version" in ext_config:
install_kwargs["version"] = ext_config["version"]
if "repository" in ext_config:
install_kwargs["repository"] = ext_config["repository"]
if "repository_url" in ext_config:
install_kwargs["repository_url"] = ext_config["repository_url"]
if ext_config.get("force_install", False):
install_kwargs["force_install"] = True

if install_kwargs:
connection.install_extension(ext_name, **install_kwargs)
else:
connection.install_extension(ext_name)
connection.load_extension(ext_name)
required = bool(ext_config.get("required", False))
install_kwargs: dict[str, Any] = {k: ext_config[k] for k in _EXPLICIT_INSTALL_KEYS if k in ext_config}
force_install = bool(ext_config.get("force_install", False))
explicit_install = bool(ext_config.get("install", False)) or force_install or bool(install_kwargs)

if explicit_install:
self._install_extension_once(connection, ext_name, install_kwargs, force_install, required)

try:
connection.load_extension(ext_name)
except Exception as exc:
if required:
raise
log_with_context(
logger,
logging.WARNING,
"pool.extension.load.failed",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
database=self._database_name,
extension=ext_name,
error=str(exc),
)

for secret_config in self._secrets:
_create_secret(connection, secret_config)
Expand All @@ -154,6 +165,47 @@ def _create_connection(self) -> DuckDBConnection:

return connection

def _install_extension_once(
self,
connection: "DuckDBConnection",
ext_name: str,
install_kwargs: "dict[str, Any]",
force_install: bool,
required: bool,
) -> None:
"""Install an extension once per pool per signature, best-effort unless required."""
signature = (
ext_name,
install_kwargs.get("version"),
install_kwargs.get("repository"),
install_kwargs.get("repository_url"),
)
with self._lock:
if not force_install and signature in self._installed_signatures:
return
try:
if force_install:
connection.install_extension(ext_name, force_install=True, **install_kwargs)
elif install_kwargs:
connection.install_extension(ext_name, **install_kwargs)
else:
connection.install_extension(ext_name)
except Exception as exc:
if required:
raise
log_with_context(
logger,
logging.WARNING,
"pool.extension.install.failed",
adapter=_ADAPTER_NAME,
pool_id=self._pool_id,
database=self._database_name,
extension=ext_name,
error=str(exc),
)
return
self._installed_signatures.add(signature)

def _apply_extension_flags(self, connection: DuckDBConnection) -> None:
"""Apply connection-level extension flags via SET statements."""

Expand Down Expand Up @@ -335,12 +387,18 @@ def _create_secret(connection: DuckDBConnection, secret_config: dict[str, Any])
if not (secret_name and secret_type):
return

_validate_sql_identifier(secret_name, "secret_name")
_validate_sql_identifier(secret_type, "secret_type")

sql = _build_secret_sql(secret_config, secret_name, secret_type)
connection.execute(sql)
_verify_secret(connection, secret_config, secret_name, secret_type)
required = bool(secret_config.get("required", False))
try:
_validate_sql_identifier(secret_name, "secret_name")
_validate_sql_identifier(secret_type, "secret_type")
sql = _build_secret_sql(secret_config, secret_name, secret_type)
connection.execute(sql)
if required:
_verify_secret(connection, secret_config, secret_name, secret_type)
except Exception:
if required:
raise
logger.warning("DuckDB secret %r creation failed (best-effort)", secret_name)


def _build_secret_sql(secret_config: dict[str, Any], secret_name: str, secret_type: str) -> str:
Expand Down
183 changes: 183 additions & 0 deletions tests/unit/adapters/test_duckdb/test_extension_lifecycle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
"""Call-count tests for DuckDB extension install/load lifecycle.

These tests pin the install-vs-load contract without touching the network or
relying on timing. A spy connection records every install/load call so we can
assert exact call counts:

* name-only ``{"name": X}`` => LOAD-ONLY (install never called)
* explicit install (``install=True`` / ``force_install`` / version / repository
/ repository_url) => install runs once per pool per signature
* ``load_extension`` runs per physical connection
* best-effort by default; ``required=True`` raises
"""

import logging
import threading
from typing import Any

import pytest

# Skip the module before importing the pool: the tests below patch
# ``sqlspec.adapters.duckdb.pool.duckdb.connect``, so the pool module is assumed
# to import ``duckdb`` itself. Importing it first would then fail collection
# instead of producing a skip when duckdb is not installed.
pytest.importorskip("duckdb", reason="DuckDB adapter requires duckdb package")

from sqlspec.adapters.duckdb.pool import DuckDBConnectionPool  # noqa: E402


class _SpyConnection:
"""Fake DuckDB connection recording install/load calls into shared logs."""

def __init__(
self,
install_log: "list[tuple[str, dict[str, Any]]]",
load_log: "list[str]",
*,
fail_install: bool = False,
fail_load: bool = False,
) -> None:
self._install_log = install_log
self._load_log = load_log
self._fail_install = fail_install
self._fail_load = fail_load

def install_extension(self, extension: str, **kwargs: Any) -> None:
self._install_log.append((extension, kwargs))
if self._fail_install:
msg = "install failed"
raise RuntimeError(msg)

def load_extension(self, extension: str) -> None:
if self._fail_load:
msg = "load failed"
raise RuntimeError(msg)
self._load_log.append(extension)

def execute(self, sql: str, parameters: Any = None) -> "_SpyConnection":
return self

def fetchone(self) -> "tuple[Any, ...] | None":
return None

def cursor(self) -> "_SpyConnection":
return self

def close(self) -> None:
pass


def _spy_connect(
    monkeypatch: pytest.MonkeyPatch, *, fail_install: bool = False, fail_load: bool = False
) -> "tuple[list[tuple[str, dict[str, Any]]], list[str]]":
    """Replace ``duckdb.connect`` in the pool module with a spy factory.

    Every connection produced by the patched factory appends to the same two
    lists, which are returned as ``(install_log, load_log)``. The failure flags
    are forwarded to each spy.
    """
    installs: list[tuple[str, dict[str, Any]]] = []
    loads: list[str] = []
    monkeypatch.setattr(
        "sqlspec.adapters.duckdb.pool.duckdb.connect",
        lambda **_: _SpyConnection(installs, loads, fail_install=fail_install, fail_load=fail_load),
    )
    return installs, loads


def test_name_only_extension_never_installs(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test A: a bare ``{"name": ...}`` entry is loaded but never installed."""
    installs, loads = _spy_connect(monkeypatch)

    DuckDBConnectionPool({"database": ":memory:"}, extensions=[{"name": "postgres"}])._create_connection()

    assert not installs
    assert loads == ["postgres"]


def test_explicit_install_runs_once_across_sessions(monkeypatch: pytest.MonkeyPatch, tmp_path: Any) -> None:
    """Test B: ``install=True`` installs once per pool while every connection still loads."""
    installs, loads = _spy_connect(monkeypatch)
    database = str(tmp_path / "assessment.db")
    pool = DuckDBConnectionPool({"database": database}, extensions=[{"name": "postgres", "install": True}])

    connection_count = 3
    for _ in range(connection_count):
        pool._create_connection()

    assert len(installs) == 1
    assert loads == ["postgres"] * connection_count


def test_concurrent_sessions_do_not_multiply_installs(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test C: eight threads building connections trigger one install and eight loads."""
    installs, loads = _spy_connect(monkeypatch)
    pool = DuckDBConnectionPool({"database": ":memory:"}, extensions=[{"name": "postgres", "install": True}])
    worker_count = 8

    workers = [threading.Thread(target=pool._create_connection) for _ in range(worker_count)]
    for worker in workers:
        worker.start()
    for worker in workers:
        worker.join()

    assert len(installs) == 1
    assert len(loads) == worker_count


@pytest.mark.parametrize(
    "extension",
    [
        {"name": "h3", "version": "1.0"},
        {"name": "h3", "repository": "community"},
        {"name": "h3", "repository_url": "https://ext.example.test"},
    ],
)
def test_version_repository_imply_install(monkeypatch: pytest.MonkeyPatch, extension: "dict[str, Any]") -> None:
    """Test D: version/repository/repository_url imply an explicit install.

    Besides checking that exactly one install happens, this pins that the
    option which triggered the install is forwarded to ``install_extension``.
    """
    install_log, load_log = _spy_connect(monkeypatch)
    pool = DuckDBConnectionPool({"database": ":memory:"}, extensions=[extension])
    # Everything except the name is expected to reach install_extension as a keyword.
    expected_kwargs = {key: value for key, value in extension.items() if key != "name"}

    pool._create_connection()

    assert len(install_log) == 1
    assert install_log[0] == ("h3", expected_kwargs)
    assert load_log == ["h3"]


def test_load_failure_is_best_effort_by_default(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Test E (load): without ``required``, a LOAD error only produces a WARNING record."""
    _spy_connect(monkeypatch, fail_load=True)
    pool = DuckDBConnectionPool({"database": ":memory:"}, extensions=[{"name": "postgres"}])

    with caplog.at_level(logging.WARNING):
        pool._create_connection()

    load_warnings = [
        record for record in caplog.records if "load" in record.message and record.levelno == logging.WARNING
    ]
    assert load_warnings


def test_load_failure_raises_when_required(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test E (load): with ``required=True`` the LOAD error reaches the caller."""
    _spy_connect(monkeypatch, fail_load=True)
    required_extension = {"name": "postgres", "required": True}

    with pytest.raises(RuntimeError, match="load failed"):
        DuckDBConnectionPool({"database": ":memory:"}, extensions=[required_extension])._create_connection()


def test_install_failure_is_best_effort_by_default(
    monkeypatch: pytest.MonkeyPatch, caplog: pytest.LogCaptureFixture
) -> None:
    """Test E (install): without ``required``, an INSTALL error only produces a WARNING record."""
    _spy_connect(monkeypatch, fail_install=True)
    pool = DuckDBConnectionPool({"database": ":memory:"}, extensions=[{"name": "postgres", "install": True}])

    with caplog.at_level(logging.WARNING):
        pool._create_connection()

    install_warnings = [
        record for record in caplog.records if "install" in record.message and record.levelno == logging.WARNING
    ]
    assert install_warnings


def test_install_failure_raises_when_required(monkeypatch: pytest.MonkeyPatch) -> None:
    """Test E (install): with ``required=True`` the INSTALL error reaches the caller."""
    _spy_connect(monkeypatch, fail_install=True)
    required_extension = {"name": "postgres", "install": True, "required": True}
    pool = DuckDBConnectionPool({"database": ":memory:"}, extensions=[required_extension])

    with pytest.raises(RuntimeError, match="install failed"):
        pool._create_connection()
Loading
Loading