Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,22 @@ vars:

*(default: `pyodbc`)* Set to `mssql-python` in a profile target to use the `mssql-python` backend instead of `pyodbc`. The adapter fails if the required backend package (Python dependency), such as `pyodbc` or `mssql-python`, is not installed.

### `dbt_sqlserver_use_dbt_transactions`

_(default: `false`)_ When enabled, makes dbt's transaction hooks real at the SQL Server level by emitting `BEGIN TRANSACTION` / `COMMIT TRANSACTION` through the adapter's `add_begin_query` and `add_commit_query` methods. The driver connection remains in autocommit mode (`autocommit=true`).

The default is `false`, preserving existing behavior where `begin`/`commit` hooks are logical no-ops and the ODBC driver auto-commits each statement. When `dbt_sqlserver_use_dbt_transactions: true`, the adapter emits real T-SQL transaction statements, and rollback uses `IF @@TRANCOUNT > 0 ROLLBACK TRANSACTION`.

This mode is opt-in and should be tested carefully with project-specific materializations and hooks.

```yaml
# dbt_project.yml
flags:
dbt_sqlserver_use_dbt_transactions: true # <-- opt-in; default is false
```

**Compatibility notes:** Enabling `dbt_sqlserver_use_dbt_transactions: true` may expose transaction-state assumptions that were hidden by autocommit without transaction management. Explicit transaction macros may interact with dbt-managed transactions, and cleanup after failed DDL/DML may differ from current behavior.

## Contributing

[![Unit tests](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/unit-tests.yml/badge.svg)](https://github.com/dbt-msft/dbt-sqlserver/actions/workflows/unit-tests.yml)
Expand Down
14 changes: 14 additions & 0 deletions dbt/adapters/sqlserver/sqlserver_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def __init__(self, config, mp_context=None):
)
if self.behavior.dbt_sqlserver_use_native_string_types:
self.Column = SQLServerColumnNative
SQLServerConnectionManager._dbt_sqlserver_use_dbt_transactions = bool(
self.behavior.dbt_sqlserver_use_dbt_transactions
)

@property
def _behavior_flags(self) -> List[BehaviorFlag]:
Expand Down Expand Up @@ -99,6 +102,17 @@ def _behavior_flags(self) -> List[BehaviorFlag]:
"The new behaviour is intended to become the default in a future release."
),
},
{
"name": "dbt_sqlserver_use_dbt_transactions",
"default": False,
"description": (
"When True, dbt transaction hooks (begin/commit) emit real T-SQL "
"BEGIN TRANSACTION / COMMIT TRANSACTION statements. "
"When False (default), begin/commit are no-ops and each statement "
"is auto-committed by the driver."
"True will be the default in v1.11.0"
),
},
]

@available.parse(lambda *a, **k: [])
Expand Down
28 changes: 26 additions & 2 deletions dbt/adapters/sqlserver/sqlserver_connections.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import datetime as dt
import time
import traceback
from contextlib import contextmanager
from typing import (
Any,
Expand All @@ -25,6 +26,7 @@
from dbt.adapters.events.types import (
AdapterEventDebug,
ConnectionUsed,
RollbackFailed,
SQLQuery,
SQLQueryStatus,
)
Expand Down Expand Up @@ -63,6 +65,8 @@
class SQLServerConnectionManager(SQLConnectionManager):
TYPE = "sqlserver"

_dbt_sqlserver_use_dbt_transactions: bool = False

@contextmanager
def exception_handler(self, sql):
"""Translate backend database errors and re-raise everything else.
Expand Down Expand Up @@ -142,10 +146,30 @@ def cancel(self, connection: Connection):
logger.debug("Cancel query")

def add_begin_query(self):
pass
if self._dbt_sqlserver_use_dbt_transactions:
return self.add_query("BEGIN TRANSACTION", auto_begin=False)

def add_commit_query(self):
pass
if self._dbt_sqlserver_use_dbt_transactions:
return self.add_query("IF @@TRANCOUNT > 0 COMMIT TRANSACTION", auto_begin=False)

@classmethod
def _rollback_handle(cls, connection: Connection) -> None:
if cls._dbt_sqlserver_use_dbt_transactions:
try:
cursor = connection.handle.cursor()
cursor.execute("IF @@TRANCOUNT > 0 ROLLBACK TRANSACTION")
cursor.close()
except Exception:
fire_event(
RollbackFailed(
conn_name=cast_to_str(connection.name),
exc_info=traceback.format_exc(),
node_info=get_node_info(),
)
)
else:
super()._rollback_handle(connection)

def add_query(
self,
Expand Down
6 changes: 5 additions & 1 deletion dbt/include/sqlserver/macros/materializations/hooks.sql

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hooks.sql reads the flag from a different source (flags) than the connection manager (adapter.behavior), and the two agree only while the default is False

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
{% for hook in hooks | selectattr('transaction', 'equalto', inside_transaction) %}
{% if not inside_transaction and loop.first %}
{% call statement(auto_begin=inside_transaction) %}
if @@trancount > 0 commit;
{% if not flags.dbt_sqlserver_use_dbt_transactions %}
if @@trancount > 0 commit; -- post hooks after fictitious transaction work as expected
{% else %}
commit; --proves real transactions are correctly managed now like core.
{% endif %}
{% endcall %}
{% endif %}
{% set rendered = render(hook.get('sql')) | trim %}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,11 @@
{%- set column_list = target_columns | map(attribute='quoted') | join(', ') -%}

{# Atomic DML swap — RCSI protects concurrent readers #}
{# dbt-sqlserver uses autocommit=True and add_begin_query/add_commit_query #}
{# are no-ops, so this creates a simple (non-nested) transaction. #}
{# When dbt_sqlserver_use_dbt_transactions is off (default), autocommit #}
{# ensures only this explicit transaction exists. When the flag is on, #}
{# the statement call auto-begins an outer transaction first, but SQL #}
{# Server handles the nesting correctly — only the outermost COMMIT #}
{# actually commits. #}
{% call statement('dml_refresh_swap') -%}
BEGIN TRANSACTION;
DELETE FROM {{ target_relation }};
Expand Down
176 changes: 176 additions & 0 deletions tests/functional/adapter/dbt/test_transactions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import pytest

from dbt.tests.util import run_dbt


class BaseTransactionsEnabled:
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {"dbt_sqlserver_use_dbt_transactions": True}}


class TestTableMaterializationTransactionsOn(BaseTransactionsEnabled):
@pytest.fixture(scope="class")
def models(self):
return {
"table_model.sql": """
{{ config(materialized='table') }}
select 1 as id, 'hello' as name
""",
}

def test_table_materialization(self, project):
results = run_dbt(["run"])
assert len(results) == 1

rows = project.run_sql("select id, name from {schema}.table_model", fetch="all")
assert len(rows) == 1
assert rows[0][0] == 1
assert rows[0][1] == "hello"


class TestViewMaterializationTransactionsOn(BaseTransactionsEnabled):
@pytest.fixture(scope="class")
def models(self):
return {
"view_model.sql": """
{{ config(materialized='view') }}
select 42 as answer
""",
}

def test_view_materialization(self, project):
results = run_dbt(["run"])
assert len(results) == 1

rows = project.run_sql("select answer from {schema}.view_model", fetch="all")
assert len(rows) == 1
assert rows[0][0] == 42


class TestIncrementalMaterializationTransactionsOn(BaseTransactionsEnabled):
@pytest.fixture(scope="class")
def models(self):
return {
"incremental_model.sql": """
{{ config(materialized='incremental', unique_key='id') }}
select 1 as id, 'first' as value
{% if is_incremental() %}
union all
select 2 as id, 'second' as value
{% endif %}
""",
}

def test_incremental_materialization(self, project):
results = run_dbt(["run"])
assert len(results) == 1

rows = project.run_sql(
"select count(*) as cnt from {schema}.incremental_model", fetch="one"
)
assert rows[0] == 1

results = run_dbt(["run"])
assert len(results) == 1

rows = project.run_sql(
"select count(*) as cnt from {schema}.incremental_model", fetch="one"
)
assert rows[0] == 2


class BaseFailingModelWithSideEffect:
@pytest.fixture(scope="class")
def models(self):
return {
"failing_model.sql": """
{{ config(
materialized='table',
pre_hook=[
"INSERT INTO {{ this.schema }}.audit_log "
"(msg, created_at) VALUES ('from_model', getdate())"
]
) }}
select 1/0 as boom
""",
}


class TestRollbackWithoutFlag(BaseFailingModelWithSideEffect):
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {"dbt_sqlserver_use_dbt_transactions": False}}

@pytest.mark.xfail(
strict=True,
reason="Without transactions flag, DML in pre-hooks is auto-committed and not rolled back,"
" remove after migration to always use transactions.",
)
def test_side_effect_rolled_back(self, project):
project.run_sql("CREATE TABLE {schema}.audit_log (msg varchar(100), created_at datetime)")
run_dbt(["run", "-m", "failing_model"], expect_pass=False)
rows = project.run_sql("SELECT COUNT(*) FROM {schema}.audit_log", fetch="one")
assert rows[0] == 0


class TestRollbackWithFlag(BaseFailingModelWithSideEffect):
@pytest.fixture(scope="class")
def project_config_update(self):
return {"flags": {"dbt_sqlserver_use_dbt_transactions": True}}

def test_side_effect_rolled_back(self, project):
project.run_sql("CREATE TABLE {schema}.audit_log (msg varchar(100), created_at datetime)")
run_dbt(["run", "-m", "failing_model"], expect_pass=False)
rows = project.run_sql("SELECT COUNT(*) FROM {schema}.audit_log", fetch="one")
assert rows[0] == 0


class TestAfterCommitModelHookTransactionsOn(BaseTransactionsEnabled):
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"flags": {"dbt_sqlserver_use_dbt_transactions": True},
"models": {
"test": {
"post-hook": [
{"sql": "select 1", "transaction": False},
],
}
},
}

@pytest.fixture(scope="class")
def models(self):
return {"after_commit_hook_model.sql": "select 1 as id"}

def test_after_commit_post_hook_does_not_double_commit(self, project):
run_dbt()


class TestFailedModelThenSuccessTransactionsOn(BaseTransactionsEnabled):
@pytest.fixture(scope="class")
def models(self):
return {
"good_model.sql": """
{{ config(materialized='table') }}
select 1 as id
""",
"bad_model.sql": """
{{ config(materialized='table') }}
select 1/0 as boom
""",
}

def test_failed_then_successful_run(self, project):
results = run_dbt(["run", "-m", "bad_model"], expect_pass=False)
assert len(results) == 1
assert results[0].status == "error"

results = run_dbt(["run", "-m", "good_model"])
assert len(results) == 1
assert results[0].status == "success"

rows = project.run_sql("select id from {schema}.good_model", fetch="all")
assert len(rows) == 1
assert rows[0][0] == 1
56 changes: 56 additions & 0 deletions tests/unit/adapters/mssql/test_sqlserver_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1460,3 +1460,59 @@ def fake_connect(connection_string, attrs_before, autocommit, timeout):
assert captured["autocommit"] is True
assert captured["timeout"] == credentials.login_timeout
assert "Pooling=true" in captured["connection_string"]


@pytest.mark.parametrize("flag_value", [True, False])
def test_add_begin_query_respects_dbt_sqlserver_use_dbt_transactions(
monkeypatch: pytest.MonkeyPatch,
flag_value: bool,
) -> None:
manager = object.__new__(SQLServerConnectionManager)
monkeypatch.setattr(
SQLServerConnectionManager, "_dbt_sqlserver_use_dbt_transactions", flag_value
)

add_query_calls: list[tuple[str, bool]] = []

def fake_add_query(sql, auto_begin=True):
add_query_calls.append((sql, auto_begin))
return None, None

monkeypatch.setattr(manager, "add_query", fake_add_query)

result = manager.add_begin_query()

if flag_value:
assert add_query_calls == [("BEGIN TRANSACTION", False)]
assert result == (None, None)
else:
assert result is None
assert add_query_calls == []


@pytest.mark.parametrize("flag_value", [True, False])
def test_add_commit_query_respects_dbt_sqlserver_use_dbt_transactions(
monkeypatch: pytest.MonkeyPatch,
flag_value: bool,
) -> None:
manager = object.__new__(SQLServerConnectionManager)
monkeypatch.setattr(
SQLServerConnectionManager, "_dbt_sqlserver_use_dbt_transactions", flag_value
)

add_query_calls: list[tuple[str, bool]] = []

def fake_add_query(sql, auto_begin=True):
add_query_calls.append((sql, auto_begin))
return None, None

monkeypatch.setattr(manager, "add_query", fake_add_query)

result = manager.add_commit_query()

if flag_value:
assert add_query_calls == [("IF @@TRANCOUNT > 0 COMMIT TRANSACTION", False)]
assert result == (None, None)
else:
assert result is None
assert add_query_calls == []