From 734ba60d8c40e55b540a4613a9369b906f0d866a Mon Sep 17 00:00:00 2001 From: Axell Padilla <68310020+axellpadilla@users.noreply.github.com> Date: Mon, 15 Jun 2026 06:34:18 +0000 Subject: [PATCH] chore: update stale comment in table_dml_refresh to reflect dbt_sqlserver_use_dbt_transactions flag --- README.md | 16 ++ dbt/adapters/sqlserver/sqlserver_adapter.py | 14 ++ .../sqlserver/sqlserver_connections.py | 28 ++- .../macros/materializations/hooks.sql | 6 +- .../models/table/table_dml_refresh.sql | 7 +- .../adapter/dbt/test_transactions.py | 176 ++++++++++++++++++ .../test_sqlserver_connection_manager.py | 56 ++++++ 7 files changed, 298 insertions(+), 5 deletions(-) create mode 100644 tests/functional/adapter/dbt/test_transactions.py diff --git a/README.md b/README.md index 3c431ec0a..3bbd65729 100644 --- a/README.md +++ b/README.md @@ -148,6 +148,22 @@ vars: *(default: `pyodbc`)* Set to `mssql-python` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required backend package (Python dependency), such as `pyodbc` or `mssql-python`, is not installed. +### `dbt_sqlserver_use_dbt_transactions` + +_(default: `false`)_ When enabled, makes dbt's transaction hooks real at the SQL Server level by emitting `BEGIN TRANSACTION` / `COMMIT TRANSACTION` through the adapter's `add_begin_query` and `add_commit_query` methods. The driver connection remains in autocommit mode (`autocommit=true`). + +The default is `false`, preserving existing behavior where `begin`/`commit` hooks are logical no-ops and the ODBC driver auto-commits each statement. When `dbt_sqlserver_use_dbt_transactions: true`, the adapter emits real T-SQL transaction statements, and rollback uses `IF @@TRANCOUNT > 0 ROLLBACK TRANSACTION`. + +This mode is opt-in and should be tested carefully with project-specific materializations and hooks. + +```yaml +# dbt_project.yml +flags: + dbt_sqlserver_use_dbt_transactions: true # <-- opt-in; default is false +``` + +**Compatibility notes:** Enabling `dbt_sqlserver_use_dbt_transactions: true` may expose transaction-state assumptions that were hidden by autocommit without transaction management. Explicit transaction macros may interact with dbt-managed transactions, and cleanup after failed DDL/DML may differ from current behavior. + ## Contributing [![Unit tests](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/unit-tests.yml) diff --git a/dbt/adapters/sqlserver/sqlserver_adapter.py b/dbt/adapters/sqlserver/sqlserver_adapter.py index a575cf42c..5f9843f3f 100644 --- a/dbt/adapters/sqlserver/sqlserver_adapter.py +++ b/dbt/adapters/sqlserver/sqlserver_adapter.py @@ -55,6 +55,9 @@ def __init__(self, config, mp_context=None): ) if self.behavior.dbt_sqlserver_use_native_string_types: self.Column = SQLServerColumnNative + SQLServerConnectionManager._dbt_sqlserver_use_dbt_transactions = bool( + self.behavior.dbt_sqlserver_use_dbt_transactions + ) @property def _behavior_flags(self) -> List[BehaviorFlag]: @@ -99,6 +102,17 @@ def _behavior_flags(self) -> List[BehaviorFlag]: "The new behaviour is intended to become the default in a future release." ), }, + { + "name": "dbt_sqlserver_use_dbt_transactions", + "default": False, + "description": ( + "When True, dbt transaction hooks (begin/commit) emit real T-SQL " + "BEGIN TRANSACTION / COMMIT TRANSACTION statements. " + "When False (default), begin/commit are no-ops and each statement " + "is auto-committed by the driver." + "True will be the default in v1.11.0" + ), + }, ] @available.parse(lambda *a, **k: []) diff --git a/dbt/adapters/sqlserver/sqlserver_connections.py b/dbt/adapters/sqlserver/sqlserver_connections.py index aa0beede0..a6c47b1e7 100644 --- a/dbt/adapters/sqlserver/sqlserver_connections.py +++ b/dbt/adapters/sqlserver/sqlserver_connections.py @@ -1,5 +1,6 @@ import datetime as dt import time +import traceback from contextlib import contextmanager from typing import ( Any, @@ -25,6 +26,7 @@ from dbt.adapters.events.types import ( AdapterEventDebug, ConnectionUsed, + RollbackFailed, SQLQuery, SQLQueryStatus, ) @@ -63,6 +65,8 @@ class SQLServerConnectionManager(SQLConnectionManager): TYPE = "sqlserver" + _dbt_sqlserver_use_dbt_transactions: bool = False + @contextmanager def exception_handler(self, sql): """Translate backend database errors and re-raise everything else. @@ -142,10 +146,30 @@ def cancel(self, connection: Connection): logger.debug("Cancel query") def add_begin_query(self): - pass + if self._dbt_sqlserver_use_dbt_transactions: + return self.add_query("BEGIN TRANSACTION", auto_begin=False) def add_commit_query(self): - pass + if self._dbt_sqlserver_use_dbt_transactions: + return self.add_query("IF @@TRANCOUNT > 0 COMMIT TRANSACTION", auto_begin=False) + + @classmethod + def _rollback_handle(cls, connection: Connection) -> None: + if cls._dbt_sqlserver_use_dbt_transactions: + try: + cursor = connection.handle.cursor() + cursor.execute("IF @@TRANCOUNT > 0 ROLLBACK TRANSACTION") + cursor.close() + except Exception: + fire_event( + RollbackFailed( + conn_name=cast_to_str(connection.name), + exc_info=traceback.format_exc(), + node_info=get_node_info(), + ) + ) + else: + super()._rollback_handle(connection) def add_query( self, diff --git a/dbt/include/sqlserver/macros/materializations/hooks.sql b/dbt/include/sqlserver/macros/materializations/hooks.sql index 51b01fab2..a47895bca 100644 --- a/dbt/include/sqlserver/macros/materializations/hooks.sql +++ b/dbt/include/sqlserver/macros/materializations/hooks.sql @@ -2,7 +2,11 @@ {% for hook in hooks | selectattr('transaction', 'equalto', inside_transaction) %} {% if not inside_transaction and loop.first %} {% call statement(auto_begin=inside_transaction) %} - if @@trancount > 0 commit; + {% if not flags.dbt_sqlserver_use_dbt_transactions %} + if @@trancount > 0 commit; -- post hooks after fictitious transaction work as expected + {% else %} + commit; --proves real transactions are correctly managed now like core. + {% endif %} {% endcall %} {% endif %} {% set rendered = render(hook.get('sql')) | trim %} diff --git a/dbt/include/sqlserver/macros/materializations/models/table/table_dml_refresh.sql b/dbt/include/sqlserver/macros/materializations/models/table/table_dml_refresh.sql index 4a5d095c4..991cf31a0 100644 --- a/dbt/include/sqlserver/macros/materializations/models/table/table_dml_refresh.sql +++ b/dbt/include/sqlserver/macros/materializations/models/table/table_dml_refresh.sql @@ -54,8 +54,11 @@ {%- set column_list = target_columns | map(attribute='quoted') | join(', ') -%} {# Atomic DML swap — RCSI protects concurrent readers #} - {# dbt-sqlserver uses autocommit=True and add_begin_query/add_commit_query #} - {# are no-ops, so this creates a simple (non-nested) transaction. #} + {# When dbt_sqlserver_use_dbt_transactions is off (default), autocommit #} + {# ensures only this explicit transaction exists. When the flag is on, #} + {# the statement call auto-begins an outer transaction first, but SQL #} + {# Server handles the nesting correctly — only the outermost COMMIT #} + {# actually commits. #} {% call statement('dml_refresh_swap') -%} BEGIN TRANSACTION; DELETE FROM {{ target_relation }}; diff --git a/tests/functional/adapter/dbt/test_transactions.py b/tests/functional/adapter/dbt/test_transactions.py new file mode 100644 index 000000000..976ece10c --- /dev/null +++ b/tests/functional/adapter/dbt/test_transactions.py @@ -0,0 +1,176 @@ +import pytest + +from dbt.tests.util import run_dbt + + +class BaseTransactionsEnabled: + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_dbt_transactions": True}} + + +class TestTableMaterializationTransactionsOn(BaseTransactionsEnabled): + @pytest.fixture(scope="class") + def models(self): + return { + "table_model.sql": """ +{{ config(materialized='table') }} +select 1 as id, 'hello' as name +""", + } + + def test_table_materialization(self, project): + results = run_dbt(["run"]) + assert len(results) == 1 + + rows = project.run_sql("select id, name from {schema}.table_model", fetch="all") + assert len(rows) == 1 + assert rows[0][0] == 1 + assert rows[0][1] == "hello" + + +class TestViewMaterializationTransactionsOn(BaseTransactionsEnabled): + @pytest.fixture(scope="class") + def models(self): + return { + "view_model.sql": """ +{{ config(materialized='view') }} +select 42 as answer +""", + } + + def test_view_materialization(self, project): + results = run_dbt(["run"]) + assert len(results) == 1 + + rows = project.run_sql("select answer from {schema}.view_model", fetch="all") + assert len(rows) == 1 + assert rows[0][0] == 42 + + +class TestIncrementalMaterializationTransactionsOn(BaseTransactionsEnabled): + @pytest.fixture(scope="class") + def models(self): + return { + "incremental_model.sql": """ +{{ config(materialized='incremental', unique_key='id') }} +select 1 as id, 'first' as value +{% if is_incremental() %} +union all +select 2 as id, 'second' as value +{% endif %} +""", + } + + def test_incremental_materialization(self, project): + results = run_dbt(["run"]) + assert len(results) == 1 + + rows = project.run_sql( + "select count(*) as cnt from {schema}.incremental_model", fetch="one" + ) + assert rows[0] == 1 + + results = run_dbt(["run"]) + assert len(results) == 1 + + rows = project.run_sql( + "select count(*) as cnt from {schema}.incremental_model", fetch="one" + ) + assert rows[0] == 2 + + +class BaseFailingModelWithSideEffect: + @pytest.fixture(scope="class") + def models(self): + return { + "failing_model.sql": """ +{{ config( + materialized='table', + pre_hook=[ + "INSERT INTO {{ this.schema }}.audit_log " + "(msg, created_at) VALUES ('from_model', getdate())" + ] +) }} +select 1/0 as boom +""", + } + + +class TestRollbackWithoutFlag(BaseFailingModelWithSideEffect): + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_dbt_transactions": False}} + + @pytest.mark.xfail( + strict=True, + reason="Without transactions flag, DML in pre-hooks is auto-committed and not rolled back," + " remove after migration to always use transactions.", + ) + def test_side_effect_rolled_back(self, project): + project.run_sql("CREATE TABLE {schema}.audit_log (msg varchar(100), created_at datetime)") + run_dbt(["run", "-m", "failing_model"], expect_pass=False) + rows = project.run_sql("SELECT COUNT(*) FROM {schema}.audit_log", fetch="one") + assert rows[0] == 0 + + +class TestRollbackWithFlag(BaseFailingModelWithSideEffect): + @pytest.fixture(scope="class") + def project_config_update(self): + return {"flags": {"dbt_sqlserver_use_dbt_transactions": True}} + + def test_side_effect_rolled_back(self, project): + project.run_sql("CREATE TABLE {schema}.audit_log (msg varchar(100), created_at datetime)") + run_dbt(["run", "-m", "failing_model"], expect_pass=False) + rows = project.run_sql("SELECT COUNT(*) FROM {schema}.audit_log", fetch="one") + assert rows[0] == 0 + + +class TestAfterCommitModelHookTransactionsOn(BaseTransactionsEnabled): + @pytest.fixture(scope="class") + def project_config_update(self): + return { + "flags": {"dbt_sqlserver_use_dbt_transactions": True}, + "models": { + "test": { + "post-hook": [ + {"sql": "select 1", "transaction": False}, + ], + } + }, + } + + @pytest.fixture(scope="class") + def models(self): + return {"after_commit_hook_model.sql": "select 1 as id"} + + def test_after_commit_post_hook_does_not_double_commit(self, project): + run_dbt() + + +class TestFailedModelThenSuccessTransactionsOn(BaseTransactionsEnabled): + @pytest.fixture(scope="class") + def models(self): + return { + "good_model.sql": """ +{{ config(materialized='table') }} +select 1 as id +""", + "bad_model.sql": """ +{{ config(materialized='table') }} +select 1/0 as boom +""", + } + + def test_failed_then_successful_run(self, project): + results = run_dbt(["run", "-m", "bad_model"], expect_pass=False) + assert len(results) == 1 + assert results[0].status == "error" + + results = run_dbt(["run", "-m", "good_model"]) + assert len(results) == 1 + assert results[0].status == "success" + + rows = project.run_sql("select id from {schema}.good_model", fetch="all") + assert len(rows) == 1 + assert rows[0][0] == 1 diff --git a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py index e1e315dd4..848e00bff 100644 --- a/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py +++ b/tests/unit/adapters/mssql/test_sqlserver_connection_manager.py @@ -1460,3 +1460,59 @@ def fake_connect(connection_string, attrs_before, autocommit, timeout): assert captured["autocommit"] is True assert captured["timeout"] == credentials.login_timeout assert "Pooling=true" in captured["connection_string"] + + +@pytest.mark.parametrize("flag_value", [True, False]) +def test_add_begin_query_respects_dbt_sqlserver_use_dbt_transactions( + monkeypatch: pytest.MonkeyPatch, + flag_value: bool, +) -> None: + manager = object.__new__(SQLServerConnectionManager) + monkeypatch.setattr( + SQLServerConnectionManager, "_dbt_sqlserver_use_dbt_transactions", flag_value + ) + + add_query_calls: list[tuple[str, bool]] = [] + + def fake_add_query(sql, auto_begin=True): + add_query_calls.append((sql, auto_begin)) + return None, None + + monkeypatch.setattr(manager, "add_query", fake_add_query) + + result = manager.add_begin_query() + + if flag_value: + assert add_query_calls == [("BEGIN TRANSACTION", False)] + assert result == (None, None) + else: + assert result is None + assert add_query_calls == [] + + +@pytest.mark.parametrize("flag_value", [True, False]) +def test_add_commit_query_respects_dbt_sqlserver_use_dbt_transactions( + monkeypatch: pytest.MonkeyPatch, + flag_value: bool, +) -> None: + manager = object.__new__(SQLServerConnectionManager) + monkeypatch.setattr( + SQLServerConnectionManager, "_dbt_sqlserver_use_dbt_transactions", flag_value + ) + + add_query_calls: list[tuple[str, bool]] = [] + + def fake_add_query(sql, auto_begin=True): + add_query_calls.append((sql, auto_begin)) + return None, None + + monkeypatch.setattr(manager, "add_query", fake_add_query) + + result = manager.add_commit_query() + + if flag_value: + assert add_query_calls == [("IF @@TRANCOUNT > 0 COMMIT TRANSACTION", False)] + assert result == (None, None) + else: + assert result is None + assert add_query_calls == []