Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions src/databricks/sqlalchemy/_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,58 @@ def bindparam_string(self, name, **kw):
return self._BIND_TEMPLATE % {"name": name.replace("`", "``")}
return super().bindparam_string(name, **kw)

@staticmethod
def _split_multivalue_bind_name(bind_name):
"""Split SQLAlchemy's ``<col>_m<idx>`` bind names into (column, idx)."""
match = re.match(r"^(?P<col>.+)_m(?P<idx>\d+)$", bind_name)
if not match:
return None
return match.group("col"), int(match.group("idx"))

def _build_multi_value_cast_plan(self, insert_stmt):
"""Return {bind_name: cast_sql_type} for multi-row VALUES insert binds.

This is a deterministic fix for Spark inline-table type reconciliation:
for SQLAlchemy-generated multi-row INSERT binds (``*_mN``), always cast
the marker to the bind's dialect SQL type so each column position in the
VALUES table has an explicit server-side type.
"""
if not getattr(insert_stmt, "_multi_values", None):
return {}

cast_plan = {}
for bind_name, bind_param in self.binds.items():
if self._split_multivalue_bind_name(bind_name) is None:
continue

type_engine = getattr(bind_param, "type", None)
if type_engine is None or isinstance(type_engine, sqltypes.NullType):
continue

dialect_type = type_engine._unwrapped_dialect_impl(self.dialect)
target_type = self.dialect.type_compiler_instance.process(
dialect_type, identifier_preparer=self.preparer
)
cast_plan[bind_name] = target_type

return cast_plan

def _apply_multi_value_casts(self, sql_text, insert_stmt):
"""Wrap selected ``:`name``` markers with ``CAST(... AS <type>)``."""
cast_plan = self._build_multi_value_cast_plan(insert_stmt)
if not cast_plan:
return sql_text

rendered = sql_text
for bind_name, target_type in cast_plan.items():
marker = self._BIND_TEMPLATE % {"name": bind_name.replace("`", "``")}
rendered = rendered.replace(marker, f"CAST({marker} AS {target_type})")
return rendered

def visit_insert(self, insert_stmt, **kw):
    """Compile the INSERT via the parent compiler, then add multi-row bind casts.

    The SQL rendered by the parent's ``visit_insert`` is passed, together
    with the statement, through ``_apply_multi_value_casts`` and that result
    is returned.
    """
    return self._apply_multi_value_casts(
        super().visit_insert(insert_stmt, **kw), insert_stmt
    )

def limit_clause(self, select, **kw):
"""Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1,
since Databricks SQL doesn't support the latter.
Expand Down
83 changes: 83 additions & 0 deletions tests/test_local/e2e/test_pandas_multi_mixed_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import uuid

import pandas as pd
import pytest
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine


@pytest.fixture
def db_engine(connection_details) -> Engine:
    """Yield a SQLAlchemy engine built from ``connection_details``.

    The engine is disposed once the consuming test finishes, whether or not
    the test passed.
    """
    host, http_path, access_token, catalog, schema = (
        connection_details[key]
        for key in ("host", "http_path", "access_token", "catalog", "schema")
    )
    url = (
        f"databricks://token:{access_token}@{host}"
        f"?http_path={http_path}&catalog={catalog}&schema={schema}"
    )
    databricks_engine = create_engine(
        url, connect_args={"_user_agent_entry": "SQLAlchemy pandas e2e tests"}
    )
    try:
        yield databricks_engine
    finally:
        databricks_engine.dispose()


def test_pandas_to_sql_multi_mixed_object_column_succeeds(db_engine: Engine):
    """Round-trip a mixed-type DataFrame through ``to_sql(method="multi")``.

    The ``value`` column mixes ints and a string, and the other columns carry
    ``None`` entries, so the multi-row INSERT contains heterogeneous values
    per column. The test creates a Delta table, appends the frame, reads the
    rows back and checks every cell, then drops the table.
    """
    # Use the catalog/schema the engine is actually connected to. pandas
    # resolves the table through the connection, so hard-coding a catalog
    # here could make CREATE/SELECT/DROP hit a different table than to_sql.
    catalog = db_engine.url.query["catalog"]
    schema = db_engine.url.query["schema"]

    table_name = f"pecoblr_2746_e2e_{uuid.uuid4().hex[:8]}"
    fq_table_name = f"`{catalog}`.`{schema}`.`{table_name}`"
    df = pd.DataFrame(
        {
            "name": ["alice", "bob", None],
            "value": [1, 0, "NE"],
            "score": [9.5, 8.1, None],
            "active": [True, None, False],
        }
    )

    try:
        with db_engine.begin() as conn:
            conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}"))
            conn.execute(
                text(
                    f"CREATE TABLE {fq_table_name} "
                    "(name STRING, value STRING, score DOUBLE, active BOOLEAN) "
                    "USING DELTA"
                )
            )

        # This is the failing path from PECOBLR-2746 before multi-row binds
        # were wrapped in explicit CASTs.
        df.to_sql(
            table_name,
            db_engine,
            schema=schema,
            if_exists="append",
            index=False,
            method="multi",
        )

        with db_engine.begin() as conn:
            # Sort non-null names first so row order is deterministic.
            rows = conn.execute(
                text(
                    f"SELECT name, value, score, active FROM {fq_table_name} "
                    "ORDER BY CASE WHEN name IS NULL THEN 1 ELSE 0 END, name"
                )
            ).fetchall()

        assert len(rows) == 3
        assert rows[0][0] == "alice"
        assert rows[0][1] == "1"
        assert rows[0][2] == pytest.approx(9.5)
        assert rows[0][3] is True

        assert rows[1][0] == "bob"
        assert rows[1][1] == "0"
        assert rows[1][2] == pytest.approx(8.1)
        assert rows[1][3] is None

        assert rows[2][0] is None
        assert rows[2][1] == "NE"
        assert rows[2][2] is None
        assert rows[2][3] is False
    finally:
        # Always remove the scratch table, even when an assertion failed.
        with db_engine.begin() as conn:
            conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}"))
54 changes: 54 additions & 0 deletions tests/test_local/test_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,3 +425,57 @@ def test_in_clause_expansion_renders_backticked_markers(self):
assert ":`col-name_1_1`" in expanded.statement
assert ":`col-name_1_2`" in expanded.statement
assert ":`col-name_1_3`" in expanded.statement


class TestMultiRowInsertCasts(DDLTestBase):
    """Compile-only checks of the CAST wrapping applied to multi-row INSERT binds."""

    def test_multi_values_casts_mixed_type_column(self):
        """Every row's bind for both String columns is wrapped in a STRING cast."""
        table = Table(
            "t", MetaData(), Column("name", String()), Column("value", String())
        )
        rows = [
            {"name": "alice", "value": 1},
            {"name": "bob", "value": 0},
            {"name": None, "value": "NE"},
        ]

        sql = str(insert(table).values(rows).compile(bind=self.engine))

        for column in ("value", "name"):
            for row in range(3):
                assert f"CAST(:`{column}_m{row}` AS STRING)" in sql

    def test_homogeneous_multi_values_are_cast(self):
        """Casts are rendered even when all rows already share one Python type."""
        table = Table("t", MetaData(), Column("value", String()))
        rows = [{"value": letter} for letter in ("A", "B", "C")]

        sql = str(insert(table).values(rows).compile(bind=self.engine))

        for row in range(3):
            assert f"CAST(:`value_m{row}` AS STRING)" in sql

    def test_numeric_family_multi_values_are_cast(self):
        """A Numeric column's binds are cast to DECIMAL for int and float inputs."""
        table = Table("t", MetaData(), Column("score", Numeric()))
        rows = [{"score": number} for number in (1, 2.5, 3)]

        sql = str(insert(table).values(rows).compile(bind=self.engine))

        for row in range(3):
            assert f"CAST(:`score_m{row}` AS DECIMAL)" in sql

    def test_single_row_insert_does_not_render_casts(self):
        """A plain single-row INSERT keeps its bind marker un-cast."""
        table = Table("t", MetaData(), Column("value", String()))
        single_row_stmt = insert(table).values({"value": "A"})

        sql = str(single_row_stmt.compile(bind=self.engine))

        assert "CAST(:`value` AS STRING)" not in sql
Loading