From b1d7ab27a50d7590029ed16a9147daf49efcec1e Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 2 Jun 2026 02:03:29 +0530 Subject: [PATCH 1/3] Fix mixed-type pandas multi-row inserts with targeted casts. Cast only mixed scalar bind groups in SQLAlchemy-generated multi-row INSERT VALUES statements so Spark resolves inline table columns consistently. Keep homogeneous, complex, and custom bind-expression types unchanged, and add regression coverage for PECOBLR-2746 plus advertised SQLAlchemy scalar and complex types. Signed-off-by: Madhavendra Rathore --- src/databricks/sqlalchemy/_ddl.py | 121 ++++++++ tests/test_local/e2e/test_complex_types.py | 128 +++++++- .../e2e/test_pandas_multi_mixed_types.py | 279 ++++++++++++++++++ tests/test_local/test_ddl.py | 72 ++++- 4 files changed, 582 insertions(+), 18 deletions(-) create mode 100644 tests/test_local/e2e/test_pandas_multi_mixed_types.py diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index f61673b..f3dd0cf 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -1,4 +1,7 @@ import re +from datetime import date, datetime, time +from numbers import Number +from uuid import UUID from sqlalchemy.sql import compiler, sqltypes import logging @@ -165,6 +168,124 @@ def bindparam_string(self, name, **kw): return self._BIND_TEMPLATE % {"name": name.replace("`", "``")} return super().bindparam_string(name, **kw) + @staticmethod + def _split_multivalue_bind_name(bind_name): + """Split SQLAlchemy's ``_m`` bind names into (column, idx).""" + match = re.match(r"^(?P.+)_m(?P\d+)$", bind_name) + if not match: + return None + return match.group("col"), int(match.group("idx")) + + @staticmethod + def _value_family(value): + """Return scalar value family; ``None`` means non-scalar/unsupported.""" + if value is None: + return "null" + if isinstance(value, bool): + return "bool" + if isinstance(value, Number): + return "number" + if isinstance(value, str): + return "string" + if isinstance(value, (bytes, bytearray, memoryview)): + return "binary" + if isinstance(value, (date, time, datetime)): + return "temporal" + if isinstance(value, UUID): + return "uuid" + return None + + @staticmethod + def _has_custom_bind_expression(type_engine): + """True if the type (or its impl) customizes bind-expression rendering.""" + type_cls = type(type_engine) + if ( + getattr(type_cls, "bind_expression", None) + is not sqltypes.TypeEngine.bind_expression + ): + return True + + impl = getattr(type_engine, "impl", None) + if impl is not None: + impl_cls = type(impl) + if ( + getattr(impl_cls, "bind_expression", None) + is not sqltypes.TypeEngine.bind_expression + ): + return True + return False + + def _build_multi_value_cast_plan(self, insert_stmt): + """Return {bind_name: cast_sql_type} for multi-row VALUES insert binds. + + Cast only *mixed scalar* multi-row bind groups. This avoids breaking + complex/custom bind types (e.g. ARRAY/MAP/VARIANT) while still fixing + Spark inline-table incompatibility for object columns that mix + primitive families (e.g. INT + STRING). + """ + if not getattr(insert_stmt, "_multi_values", None): + return {} + + grouped_binds = {} + for bind_name, bind_param in self.binds.items(): + split = self._split_multivalue_bind_name(bind_name) + if split is None: + continue + column_name, _ = split + grouped_binds.setdefault(column_name, []).append((bind_name, bind_param)) + + cast_plan = {} + for bind_entries in grouped_binds.values(): + families = set() + has_non_scalar = False + has_custom_bind_expression = False + + for _, bind_param in bind_entries: + value_family = self._value_family(getattr(bind_param, "value", None)) + if value_family is None: + has_non_scalar = True + break + if value_family != "null": + families.add(value_family) + + type_engine = getattr(bind_param, "type", None) + if type_engine is not None and self._has_custom_bind_expression( + type_engine + ): + has_custom_bind_expression = True + + if has_non_scalar or has_custom_bind_expression or len(families) <= 1: + continue + + for bind_name, bind_param in bind_entries: + type_engine = getattr(bind_param, "type", None) + if type_engine is None or isinstance(type_engine, sqltypes.NullType): + continue + + dialect_type = type_engine._unwrapped_dialect_impl(self.dialect) + target_type = self.dialect.type_compiler_instance.process( + dialect_type, identifier_preparer=self.preparer + ) + cast_plan[bind_name] = target_type + + return cast_plan + + def _apply_multi_value_casts(self, sql_text, insert_stmt): + """Wrap selected ``:`name``` markers with ``CAST(... AS )``.""" + cast_plan = self._build_multi_value_cast_plan(insert_stmt) + if not cast_plan: + return sql_text + + rendered = sql_text + for bind_name, target_type in cast_plan.items(): + marker = self._BIND_TEMPLATE % {"name": bind_name.replace("`", "``")} + rendered = rendered.replace(marker, f"CAST({marker} AS {target_type})") + return rendered + + def visit_insert(self, insert_stmt, **kw): + sql_text = super().visit_insert(insert_stmt, **kw) + return self._apply_multi_value_casts(sql_text, insert_stmt) + def limit_clause(self, select, **kw): """Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1, since Databricks SQL doesn't support the latter. diff --git a/tests/test_local/e2e/test_complex_types.py b/tests/test_local/e2e/test_complex_types.py index 07cd637..db77493 100644 --- a/tests/test_local/e2e/test_complex_types.py +++ b/tests/test_local/e2e/test_complex_types.py @@ -11,7 +11,13 @@ DateTime, ) from collections.abc import Sequence -from databricks.sqlalchemy import TIMESTAMP, TINYINT, DatabricksArray, DatabricksMap, DatabricksVariant +from databricks.sqlalchemy import ( + TIMESTAMP, + TINYINT, + DatabricksArray, + DatabricksMap, + DatabricksVariant, +) from sqlalchemy.orm import DeclarativeBase, Session from sqlalchemy import select from datetime import date, datetime, time, timedelta, timezone @@ -20,6 +26,7 @@ import decimal import json + class TestComplexTypes(TestSetup): def _parse_to_common_type(self, value): """ @@ -175,8 +182,8 @@ class VariantTable(Base): "number": 123, "boolean": True, "array": [1, 2, 3], - "object": {"nested": "value"} - } + "object": {"nested": "value"}, + }, } return VariantTable, sample_data @@ -239,6 +246,44 @@ def test_map_table_creation_pandas(self): df_result = pd.read_sql(stmt, engine) assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data) + def test_array_table_creation_pandas_multi(self): + table, sample_data = self.sample_array_table() + + with self.table_context(table) as engine: + df = pd.DataFrame([sample_data, sample_data | {"int_col": 2}]) + df.to_sql( + table.__tablename__, + engine, + if_exists="append", + index=False, + method="multi", + ) + + stmt = select(table).order_by(table.int_col) + df_result = pd.read_sql(stmt, engine) + assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data) + expected_second = sample_data | {"int_col": 2} + assert self._recursive_compare(df_result.iloc[1].to_dict(), expected_second) + + def test_map_table_creation_pandas_multi(self): + table, sample_data = self.sample_map_table() + + with self.table_context(table) as engine: + df = pd.DataFrame([sample_data, sample_data | {"int_col": 2}]) + df.to_sql( + table.__tablename__, + engine, + if_exists="append", + index=False, + method="multi", + ) + + stmt = select(table).order_by(table.int_col) + df_result = pd.read_sql(stmt, engine) + assert self._recursive_compare(df_result.iloc[0].to_dict(), sample_data) + expected_second = sample_data | {"int_col": 2} + assert self._recursive_compare(df_result.iloc[1].to_dict(), expected_second) + def test_insert_variant_table_sqlalchemy(self): table, sample_data = self.sample_variant_table() @@ -253,7 +298,12 @@ def test_insert_variant_table_sqlalchemy(self): result = session.scalar(stmt) compare = {key: getattr(result, key) for key in sample_data.keys()} # Parse JSON values back to original format for comparison - for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: + for key in [ + "variant_simple_col", + "variant_nested_col", + "variant_array_col", + "variant_mixed_col", + ]: if compare[key] is not None: compare[key] = json.loads(compare[key]) @@ -263,26 +313,76 @@ def test_variant_table_creation_pandas(self): table, sample_data = self.sample_variant_table() with self.table_context(table) as engine: - + df = pd.DataFrame([sample_data]) dtype_mapping = { "variant_simple_col": DatabricksVariant, "variant_nested_col": DatabricksVariant, "variant_array_col": DatabricksVariant, - "variant_mixed_col": DatabricksVariant + "variant_mixed_col": DatabricksVariant, } - df.to_sql(table.__tablename__, engine, if_exists="append", index=False, dtype=dtype_mapping) - + df.to_sql( + table.__tablename__, + engine, + if_exists="append", + index=False, + dtype=dtype_mapping, + ) + stmt = select(table) df_result = pd.read_sql(stmt, engine) result_dict = df_result.iloc[0].to_dict() # Parse JSON values back to original format for comparison - for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: + for key in [ + "variant_simple_col", + "variant_nested_col", + "variant_array_col", + "variant_mixed_col", + ]: if result_dict[key] is not None: result_dict[key] = json.loads(result_dict[key]) assert result_dict == sample_data + def test_variant_table_creation_pandas_multi(self): + table, sample_data = self.sample_variant_table() + + with self.table_context(table) as engine: + second = sample_data | {"int_col": 2} + df = pd.DataFrame([sample_data, second]) + dtype_mapping = { + "variant_simple_col": DatabricksVariant, + "variant_nested_col": DatabricksVariant, + "variant_array_col": DatabricksVariant, + "variant_mixed_col": DatabricksVariant, + } + df.to_sql( + table.__tablename__, + engine, + if_exists="append", + index=False, + dtype=dtype_mapping, + method="multi", + ) + + stmt = select(table).order_by(table.int_col) + df_result = pd.read_sql(stmt, engine) + first_row = df_result.iloc[0].to_dict() + second_row = df_result.iloc[1].to_dict() + for key in [ + "variant_simple_col", + "variant_nested_col", + "variant_array_col", + "variant_mixed_col", + ]: + if first_row[key] is not None: + first_row[key] = json.loads(first_row[key]) + if second_row[key] is not None: + second_row[key] = json.loads(second_row[key]) + + assert first_row == sample_data + assert second_row == second + def test_variant_literal_processor(self): table, sample_data = self.sample_variant_table() @@ -291,8 +391,7 @@ def test_variant_literal_processor(self): try: compiled = stmt.compile( - dialect=engine.dialect, - compile_kwargs={"literal_binds": True} + dialect=engine.dialect, compile_kwargs={"literal_binds": True} ) sql_str = str(compiled) @@ -311,7 +410,12 @@ def test_variant_literal_processor(self): compare = {key: getattr(result, key) for key in sample_data.keys()} # Parse JSON values back to original Python objects - for key in ['variant_simple_col', 'variant_nested_col', 'variant_array_col', 'variant_mixed_col']: + for key in [ + "variant_simple_col", + "variant_nested_col", + "variant_array_col", + "variant_mixed_col", + ]: if compare[key] is not None: compare[key] = json.loads(compare[key]) diff --git a/tests/test_local/e2e/test_pandas_multi_mixed_types.py b/tests/test_local/e2e/test_pandas_multi_mixed_types.py new file mode 100644 index 0000000..9d38576 --- /dev/null +++ b/tests/test_local/e2e/test_pandas_multi_mixed_types.py @@ -0,0 +1,279 @@ +import uuid +import json +from datetime import date, datetime, time, timedelta, timezone +from decimal import Decimal +from uuid import UUID + +import pandas as pd +import pytest +from sqlalchemy import Uuid, create_engine, text +from sqlalchemy.engine import Engine + +from databricks.sqlalchemy import DatabricksVariant + + +@pytest.fixture +def db_engine(connection_details) -> Engine: + host = connection_details["host"] + http_path = connection_details["http_path"] + access_token = connection_details["access_token"] + catalog = connection_details["catalog"] + schema = connection_details["schema"] + + conn_string = ( + f"databricks://token:{access_token}@{host}" + f"?http_path={http_path}&catalog={catalog}&schema={schema}" + ) + engine = create_engine( + conn_string, connect_args={"_user_agent_entry": "SQLAlchemy pandas e2e tests"} + ) + try: + yield engine + finally: + engine.dispose() + + +def test_pandas_to_sql_multi_mixed_object_column_succeeds(db_engine: Engine): + table_name = f"pecoblr_2746_e2e_{uuid.uuid4().hex[:8]}" + fq_table_name = f"`main`.`default`.`{table_name}`" + df = pd.DataFrame( + { + "name": ["alice", "bob", None], + "value": [1, 0, "NE"], + "score": [9.5, 8.1, None], + "active": [True, None, False], + } + ) + + try: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) + conn.execute( + text( + f"CREATE TABLE {fq_table_name} " + "(name STRING, value STRING, score DOUBLE, active BOOLEAN) " + "USING DELTA" + ) + ) + + # This is the failing path from PECOBLR-2746 before the adaptive cast fix. + df.to_sql( + table_name, + db_engine, + schema="default", + if_exists="append", + index=False, + method="multi", + ) + + with db_engine.begin() as conn: + rows = conn.execute( + text( + f"SELECT name, value, score, active FROM {fq_table_name} " + "ORDER BY CASE WHEN name IS NULL THEN 1 ELSE 0 END, name" + ) + ).fetchall() + + assert len(rows) == 3 + assert rows[0][0] == "alice" + assert rows[0][1] == "1" + assert rows[0][2] == pytest.approx(9.5) + assert rows[0][3] is True + + assert rows[1][0] == "bob" + assert rows[1][1] == "0" + assert rows[1][2] == pytest.approx(8.1) + assert rows[1][3] is None + + assert rows[2][0] is None + assert rows[2][1] == "NE" + assert rows[2][2] is None + assert rows[2][3] is False + finally: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) + + +def test_pandas_to_sql_multi_example_types_succeeds(db_engine: Engine): + table_name = f"pecoblr_2746_example_types_{uuid.uuid4().hex[:8]}" + fq_table_name = f"`main`.`default`.`{table_name}`" + base_variant = { + "name": "John Doe", + "age": 30, + "address": {"city": "San Francisco", "state": "CA"}, + "hobbies": ["reading", "hiking"], + "is_active": True, + } + rows = [ + { + "bigint_col": 1234567890123456789, + "string_col": "foo", + "tinyint_col": -100, + "int_col": 5280, + "numeric_col": Decimal("525600.01"), + "boolean_col": True, + "date_col": date(2020, 12, 25), + "datetime_col": datetime( + 1991, 8, 3, 21, 30, 5, tzinfo=timezone(timedelta(hours=-8)) + ), + "datetime_col_ntz": datetime(1990, 12, 4, 6, 33, 41), + "time_col": time(23, 59, 59), + "uuid_col": UUID(int=255), + "variant_col": base_variant, + }, + { + "bigint_col": 2234567890123456789, + "string_col": "bar", + "tinyint_col": 100, + "int_col": 42, + "numeric_col": Decimal("123.45"), + "boolean_col": False, + "date_col": date(2021, 1, 2), + "datetime_col": datetime( + 1992, 9, 4, 22, 31, 6, tzinfo=timezone(timedelta(hours=-7)) + ), + "datetime_col_ntz": datetime(1991, 1, 5, 7, 34, 42), + "time_col": time(1, 2, 3), + "uuid_col": UUID(int=256), + "variant_col": base_variant | {"name": "Jane Doe"}, + }, + ] + df = pd.DataFrame(rows) + + try: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) + conn.execute( + text( + f"CREATE TABLE {fq_table_name} (" + "bigint_col BIGINT, " + "string_col STRING, " + "tinyint_col TINYINT, " + "int_col INT, " + "numeric_col DECIMAL(10, 2), " + "boolean_col BOOLEAN, " + "date_col DATE, " + "datetime_col TIMESTAMP, " + "datetime_col_ntz TIMESTAMP_NTZ, " + "time_col STRING, " + "uuid_col STRING, " + "variant_col VARIANT" + ") USING DELTA" + ) + ) + + df.to_sql( + table_name, + db_engine, + schema="default", + if_exists="append", + index=False, + method="multi", + dtype={"uuid_col": Uuid(), "variant_col": DatabricksVariant()}, + ) + + with db_engine.begin() as conn: + result = conn.execute( + text( + f"SELECT bigint_col, string_col, tinyint_col, int_col, " + f"numeric_col, boolean_col, date_col, datetime_col, " + f"datetime_col_ntz, time_col, uuid_col, TO_JSON(variant_col) " + f"FROM {fq_table_name} ORDER BY bigint_col" + ) + ).fetchall() + + assert len(result) == 2 + assert result[0][0] == rows[0]["bigint_col"] + assert result[0][1] == rows[0]["string_col"] + assert result[0][2] == rows[0]["tinyint_col"] + assert result[0][3] == rows[0]["int_col"] + assert result[0][4] == rows[0]["numeric_col"] + assert result[0][5] is rows[0]["boolean_col"] + assert result[0][6] == rows[0]["date_col"] + assert result[0][8] == rows[0]["datetime_col_ntz"] + assert result[0][9] == "23:59:59" + assert result[0][10] == str(rows[0]["uuid_col"]) + assert json.loads(result[0][11]) == rows[0]["variant_col"] + + assert result[1][0] == rows[1]["bigint_col"] + assert result[1][1] == rows[1]["string_col"] + assert result[1][2] == rows[1]["tinyint_col"] + assert result[1][3] == rows[1]["int_col"] + assert result[1][4] == rows[1]["numeric_col"] + assert result[1][5] is rows[1]["boolean_col"] + assert result[1][6] == rows[1]["date_col"] + assert result[1][8] == rows[1]["datetime_col_ntz"] + assert result[1][9] == "01:02:03" + assert result[1][10] == str(rows[1]["uuid_col"]) + assert json.loads(result[1][11]) == rows[1]["variant_col"] + finally: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) + + +def test_pandas_to_sql_multi_mixed_scalar_families_cast_to_string(db_engine: Engine): + table_name = f"pecoblr_2746_scalar_families_{uuid.uuid4().hex[:8]}" + fq_table_name = f"`main`.`default`.`{table_name}`" + df = pd.DataFrame( + { + "number_value": [1, "one"], + "decimal_value": [Decimal("1.25"), "one point two five"], + "boolean_value": [True, "true"], + "date_value": [date(2020, 12, 25), "christmas"], + "datetime_value": [datetime(1990, 12, 4, 6, 33, 41), "datetime"], + "uuid_value": [UUID(int=255), str(UUID(int=256))], + } + ) + + try: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) + conn.execute( + text( + f"CREATE TABLE {fq_table_name} (" + "number_value STRING, " + "decimal_value STRING, " + "boolean_value STRING, " + "date_value STRING, " + "datetime_value STRING, " + "uuid_value STRING" + ") USING DELTA" + ) + ) + + df.to_sql( + table_name, + db_engine, + schema="default", + if_exists="append", + index=False, + method="multi", + dtype={"uuid_value": Uuid()}, + ) + + with db_engine.begin() as conn: + rows = conn.execute( + text( + f"SELECT number_value, decimal_value, boolean_value, date_value, " + f"datetime_value, uuid_value FROM {fq_table_name} " + "ORDER BY number_value" + ) + ).fetchall() + + assert len(rows) == 2 + assert rows[0][0] == "1" + assert rows[0][1] == "1.25" + assert rows[0][2].lower() == "true" + assert rows[0][3] == "2020-12-25" + assert rows[0][5] == str(UUID(int=255)) + assert rows[1] == ( + "one", + "one point two five", + "true", + "christmas", + "datetime", + str(UUID(int=256)), + ) + finally: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py index c3fae18..9cc81c7 100644 --- a/tests/test_local/test_ddl.py +++ b/tests/test_local/test_ddl.py @@ -1,5 +1,14 @@ import pytest -from sqlalchemy import Column, MetaData, String, Table, Numeric, Integer, create_engine, insert +from sqlalchemy import ( + Column, + MetaData, + String, + Table, + Numeric, + Integer, + create_engine, + insert, +) from sqlalchemy.schema import ( CreateTable, DropColumnComment, @@ -203,7 +212,6 @@ def test_plain_identifier_bind_names_are_also_backticked(self): assert ":`id`" in sql assert ":`name`" in sql - def test_leading_digit_column_is_backticked(self): """Databricks bind names cannot start with a digit bare.""" metadata = MetaData() @@ -345,7 +353,9 @@ def test_unicode_column_names(self): def test_sql_reserved_word_as_column_name(self): """Reserved words used as column names must work as bind params too.""" metadata = MetaData() - table = Table("t", metadata, Column("select", String()), Column("from", String())) + table = Table( + "t", metadata, Column("select", String()), Column("from", String()) + ) compiled = self._compile_insert(table, {"select": "s", "from": "f"}) sql = str(compiled) assert ":`select`" in sql @@ -419,9 +429,59 @@ def test_in_clause_expansion_renders_backticked_markers(self): # (2) construct_expanded_state at execute time compiled = stmt.compile(bind=self.engine) - expanded = compiled.construct_expanded_state( - {"col-name_1": ["a", "b", "c"]} - ) + expanded = compiled.construct_expanded_state({"col-name_1": ["a", "b", "c"]}) assert ":`col-name_1_1`" in expanded.statement assert ":`col-name_1_2`" in expanded.statement assert ":`col-name_1_3`" in expanded.statement + + +class TestMultiRowInsertCasts(DDLTestBase): + def test_multi_values_casts_mixed_type_column(self): + metadata = MetaData() + table = Table( + "t", metadata, Column("name", String()), Column("value", String()) + ) + stmt = insert(table).values( + [ + {"name": "alice", "value": 1}, + {"name": "bob", "value": 0}, + {"name": None, "value": "NE"}, + ] + ) + + sql = str(stmt.compile(bind=self.engine)) + + assert "CAST(:`value_m0` AS STRING)" in sql + assert "CAST(:`value_m1` AS STRING)" in sql + assert "CAST(:`value_m2` AS STRING)" in sql + assert "CAST(:`name_m0` AS STRING)" not in sql + assert "CAST(:`name_m1` AS STRING)" not in sql + assert "CAST(:`name_m2` AS STRING)" not in sql + + def test_homogeneous_multi_values_are_not_cast(self): + metadata = MetaData() + table = Table("t", metadata, Column("value", String())) + stmt = insert(table).values([{"value": "A"}, {"value": "B"}, {"value": "C"}]) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`value_m0` AS STRING)" not in sql + assert "CAST(:`value_m1` AS STRING)" not in sql + assert "CAST(:`value_m2` AS STRING)" not in sql + + def test_numeric_family_multi_values_are_not_cast(self): + metadata = MetaData() + table = Table("t", metadata, Column("score", Numeric())) + stmt = insert(table).values([{"score": 1}, {"score": 2.5}, {"score": 3}]) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`score_m0` AS DECIMAL)" not in sql + assert "CAST(:`score_m1` AS DECIMAL)" not in sql + assert "CAST(:`score_m2` AS DECIMAL)" not in sql + + def test_single_row_insert_does_not_render_casts(self): + metadata = MetaData() + table = Table("t", metadata, Column("value", String())) + stmt = insert(table).values({"value": "A"}) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`value` AS STRING)" not in sql From aa0d69b061f7e307f18cb66e40e2a24722b602e6 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 2 Jun 2026 02:11:30 +0530 Subject: [PATCH 2/3] Add opt-out gate for multi-row insert casts. Allow users to disable targeted multi-row insert cast rendering with enable_multirow_insert_casts=false in the Databricks SQLAlchemy engine URL while keeping the PECOBLR-2746 fix enabled by default. Signed-off-by: Madhavendra Rathore --- src/databricks/sqlalchemy/_ddl.py | 3 +++ src/databricks/sqlalchemy/base.py | 10 ++++++++++ tests/test_local/test_ddl.py | 16 ++++++++++++++++ 3 files changed, 29 insertions(+) diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index f3dd0cf..f8f90f3 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -223,6 +223,9 @@ def _build_multi_value_cast_plan(self, insert_stmt): Spark inline-table incompatibility for object columns that mix primitive families (e.g. INT + STRING). """ + if not self.dialect.enable_multirow_insert_casts: + return {} + if not getattr(insert_stmt, "_multi_values", None): return {} diff --git a/src/databricks/sqlalchemy/base.py b/src/databricks/sqlalchemy/base.py index 0de2326..e608996 100644 --- a/src/databricks/sqlalchemy/base.py +++ b/src/databricks/sqlalchemy/base.py @@ -42,6 +42,12 @@ class DatabricksImpl(DefaultImpl): logger = logging.getLogger(__name__) +def _parse_bool_url_param(value: Optional[str], default: bool) -> bool: + if value is None: + return default + return value.lower() not in ("0", "false", "no", "off") + + class DatabricksDialect(default.DefaultDialect): """This dialect implements only those methods required to pass our e2e tests""" @@ -65,6 +71,7 @@ class DatabricksDialect(default.DefaultDialect): supports_server_side_cursors: bool = False supports_sequences: bool = False supports_native_boolean: bool = True + enable_multirow_insert_casts: bool = True colspecs = { sqlalchemy.types.DateTime: dialect_type_impl.TIMESTAMP_NTZ, @@ -117,6 +124,9 @@ def create_connect_args(self, url): self.schema = kwargs["schema"] self.catalog = kwargs["catalog"] + self.enable_multirow_insert_casts = _parse_bool_url_param( + url.query.get("enable_multirow_insert_casts"), True + ) self._force_paramstyle_to_native_mode() diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py index 9cc81c7..d648504 100644 --- a/tests/test_local/test_ddl.py +++ b/tests/test_local/test_ddl.py @@ -458,6 +458,22 @@ def test_multi_values_casts_mixed_type_column(self): assert "CAST(:`name_m1` AS STRING)" not in sql assert "CAST(:`name_m2` AS STRING)" not in sql + def test_multi_value_casts_can_be_disabled_by_url_param(self): + engine = create_engine( + "databricks://token:****@****" + "?http_path=****&catalog=****&schema=****" + "&enable_multirow_insert_casts=false" + ) + metadata = MetaData() + table = Table("t", metadata, Column("value", String())) + stmt = insert(table).values([{"value": 1}, {"value": 0}, {"value": "NE"}]) + + sql = str(stmt.compile(bind=engine)) + assert "CAST(:`value_m0` AS STRING)" not in sql + assert ":`value_m0`" in sql + assert ":`value_m1`" in sql + assert ":`value_m2`" in sql + def test_homogeneous_multi_values_are_not_cast(self): metadata = MetaData() table = Table("t", metadata, Column("value", String())) From c022318ed61c6178ec138cc8bb7517dc711bc950 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Tue, 2 Jun 2026 02:17:07 +0530 Subject: [PATCH 3/3] Document SQLAlchemy connection URL parameters. Add README guidance for dialect URL parameters, DBAPI connect_args, and the enable_multirow_insert_casts opt-out flag. Signed-off-by: Madhavendra Rathore --- README.md | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/README.md b/README.md index 4c442ad..ca563f9 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,42 @@ engine = create_engine( ) ``` +### Connection URL parameters and `connect_args` + +The Databricks SQLAlchemy dialect accepts dialect-specific options in the +SQLAlchemy connection URL query string: + +| Parameter | Required | Default | Description | +|-|-|-|-| +| `http_path` | Yes | | HTTP path for the Databricks SQL warehouse or compute resource. | +| `catalog` | Yes | | Initial catalog for the connection. | +| `schema` | Yes | | Initial schema for the connection. | +| `enable_multirow_insert_casts` | No | `true` | Enables targeted casts for mixed scalar values in SQLAlchemy-generated multi-row `INSERT ... VALUES` statements. This avoids Spark inline-table type errors for pandas `to_sql(method="multi")` with mixed scalar/object columns. Set to `false` to disable this rewrite. | + +For example, to disable targeted multi-row insert casts: + +```python +engine = create_engine( + "databricks://token:dapi***@***.cloud.databricks.com" + "?http_path=***&catalog=main&schema=test" + "&enable_multirow_insert_casts=false" +) +``` + +Use SQLAlchemy's `connect_args` for DBAPI connection options that should be +passed through to `databricks-sql-connector`, such as user-agent settings: + +```python +engine = create_engine( + "databricks://token:dapi***@***.cloud.databricks.com" + "?http_path=***&catalog=main&schema=test", + connect_args={"user_agent_entry": "My SQLAlchemy App"}, +) +``` + +Dialect URL parameters control SQLAlchemy compilation behavior and are not +forwarded to the DBAPI connector. + ## Types The [SQLAlchemy type hierarchy](https://docs.sqlalchemy.org/en/20/core/type_basics.html) contains backend-agnostic type implementations (represented in CamelCase) and backend-specific types (represented in UPPERCASE). The majority of SQLAlchemy's [CamelCase](https://docs.sqlalchemy.org/en/20/core/type_basics.html#the-camelcase-datatypes) types are supported. This means that a SQLAlchemy application using these types should "just work" with Databricks.