From ffe0ddd99b44e67986d694eb2199edf724bbb240 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Fri, 29 May 2026 16:55:01 +0530 Subject: [PATCH 1/2] Fix mixed-type pandas multi-row inserts with adaptive bind casts. Target cast rendering only for heterogeneous INSERT ... VALUES multi-row bind groups to prevent Spark inline-table type incompatibility, and add compile-time plus e2e tests to validate behavior in main.default. --- src/databricks/sqlalchemy/_ddl.py | 86 +++++++++++++++++++ .../e2e/test_pandas_multi_mixed_types.py | 83 ++++++++++++++++++ tests/test_local/test_ddl.py | 47 ++++++++++ 3 files changed, 216 insertions(+) create mode 100644 tests/test_local/e2e/test_pandas_multi_mixed_types.py diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index f61673b..659415f 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -1,4 +1,5 @@ import re +from numbers import Number from sqlalchemy.sql import compiler, sqltypes import logging @@ -165,6 +166,91 @@ def bindparam_string(self, name, **kw): return self._BIND_TEMPLATE % {"name": name.replace("`", "``")} return super().bindparam_string(name, **kw) + @staticmethod + def _value_family(value): + """Return a coarse runtime family for adaptive multi-row cast decisions.""" + if value is None: + return "null" + if isinstance(value, bool): + return "bool" + if isinstance(value, str): + return "string" + if isinstance(value, (bytes, bytearray, memoryview)): + return "binary" + if isinstance(value, Number): + return "number" + return "other" + + @staticmethod + def _split_multivalue_bind_name(bind_name): + """Split SQLAlchemy's ``<col>_m<idx>`` bind names into (column, idx).""" + match = re.match(r"^(?P<col>.+)_m(?P<idx>\d+)$", bind_name) + if not match: + return None + return match.group("col"), int(match.group("idx")) + + def _build_adaptive_cast_plan(self): + """Return {bind_name: cast_sql_type} for risky multi-row value groups. 
+ + We only target SQLAlchemy-generated multi-row binds (``*_mN``). For + each logical column we inspect row values available at compile time and + cast only when families are heterogeneous in a way that commonly causes + Spark inline-table incompatibility (e.g., number + string). + """ + column_bind_names = {} + for bind_name, bind_param in self.binds.items(): + split = self._split_multivalue_bind_name(bind_name) + if split is None: + continue + column_name, _ = split + column_bind_names.setdefault(column_name, []).append((bind_name, bind_param)) + + cast_plan = {} + for bind_entries in column_bind_names.values(): + families = set() + for _, bind_param in bind_entries: + value = getattr(bind_param, "value", None) + family = self._value_family(value) + if family != "null": + families.add(family) + + if len(families) <= 1: + continue + + # Numeric + numeric is safe for Spark inline tables and does not + # need explicit casting. + if families == {"number"}: + continue + + for bind_name, bind_param in bind_entries: + type_engine = getattr(bind_param, "type", None) + if type_engine is None or isinstance(type_engine, sqltypes.NullType): + continue + + dialect_type = type_engine._unwrapped_dialect_impl(self.dialect) + target_type = self.dialect.type_compiler_instance.process( + dialect_type, identifier_preparer=self.preparer + ) + cast_plan[bind_name] = target_type + + return cast_plan + + def _apply_adaptive_multi_value_casts(self, sql_text): + """Wrap selected ``:`name``` markers with ``CAST(... 
AS <type>)``.""" + cast_plan = self._build_adaptive_cast_plan() + if not cast_plan: + return sql_text + + rendered = sql_text + for bind_name, target_type in cast_plan.items(): + marker = self._BIND_TEMPLATE % {"name": bind_name.replace("`", "``")} + rendered = rendered.replace(marker, f"CAST({marker} AS {target_type})") + return rendered + + def visit_insert(self, insert_stmt, **kw): + sql_text = super().visit_insert(insert_stmt, **kw) + return self._apply_adaptive_multi_value_casts(sql_text) + def limit_clause(self, select, **kw): """Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1, since Databricks SQL doesn't support the latter. diff --git a/tests/test_local/e2e/test_pandas_multi_mixed_types.py b/tests/test_local/e2e/test_pandas_multi_mixed_types.py new file mode 100644 index 0000000..4d01227 --- /dev/null +++ b/tests/test_local/e2e/test_pandas_multi_mixed_types.py @@ -0,0 +1,83 @@ +import uuid + +import pandas as pd +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine + + +@pytest.fixture +def db_engine(connection_details) -> Engine: + host = connection_details["host"] + http_path = connection_details["http_path"] + access_token = connection_details["access_token"] + catalog = connection_details["catalog"] + schema = connection_details["schema"] + + conn_string = ( + f"databricks://token:{access_token}@{host}" + f"?http_path={http_path}&catalog={catalog}&schema={schema}" + ) + engine = create_engine( + conn_string, connect_args={"_user_agent_entry": "SQLAlchemy pandas e2e tests"} + ) + try: + yield engine + finally: + engine.dispose() + + +def test_pandas_to_sql_multi_mixed_object_column_succeeds(db_engine: Engine): + table_name = f"pecoblr_2746_e2e_{uuid.uuid4().hex[:8]}" + fq_table_name = f"`main`.`default`.`{table_name}`" + df = pd.DataFrame( + { + "name": ["alice", "bob", None], + "value": [1, 0, "NE"], + "score": [9.5, 8.1, None], + "active": [True, 
None, False], + } + ) + + try: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) + conn.execute( + text( + f"CREATE TABLE {fq_table_name} " + "(name STRING, value STRING, score DOUBLE, active BOOLEAN) " + "USING DELTA" + ) + ) + + # This is the failing path from PECOBLR-2746 before the adaptive cast fix. + df.to_sql( + table_name, db_engine, schema="default", if_exists="append", index=False, method="multi" + ) + + with db_engine.begin() as conn: + rows = conn.execute( + text( + f"SELECT name, value, score, active FROM {fq_table_name} " + "ORDER BY CASE WHEN name IS NULL THEN 1 ELSE 0 END, name" + ) + ).fetchall() + + assert len(rows) == 3 + assert rows[0][0] == "alice" + assert rows[0][1] == "1" + assert rows[0][2] == pytest.approx(9.5) + assert rows[0][3] is True + + assert rows[1][0] == "bob" + assert rows[1][1] == "0" + assert rows[1][2] == pytest.approx(8.1) + assert rows[1][3] is None + + assert rows[2][0] is None + assert rows[2][1] == "NE" + assert rows[2][2] is None + assert rows[2][3] is False + finally: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py index c3fae18..1231955 100644 --- a/tests/test_local/test_ddl.py +++ b/tests/test_local/test_ddl.py @@ -425,3 +425,50 @@ def test_in_clause_expansion_renders_backticked_markers(self): assert ":`col-name_1_1`" in expanded.statement assert ":`col-name_1_2`" in expanded.statement assert ":`col-name_1_3`" in expanded.statement + + +class TestAdaptiveMultiRowInsertCasts(DDLTestBase): + def test_mixed_runtime_families_in_multi_values_are_cast(self): + metadata = MetaData() + table = Table("t", metadata, Column("name", String()), Column("value", String())) + stmt = insert(table).values( + [ + {"name": "alice", "value": 1}, + {"name": "bob", "value": 0}, + {"name": None, "value": "NE"}, + ] + ) + + sql = str(stmt.compile(bind=self.engine)) + + assert 
"CAST(:`value_m0` AS STRING)" in sql + assert "CAST(:`value_m1` AS STRING)" in sql + assert "CAST(:`value_m2` AS STRING)" in sql + # Name values are already all string/null and should remain untouched. + assert "CAST(:`name_m0` AS STRING)" not in sql + assert "CAST(:`name_m1` AS STRING)" not in sql + assert "CAST(:`name_m2` AS STRING)" not in sql + + def test_homogeneous_multi_values_are_not_cast(self): + metadata = MetaData() + table = Table("t", metadata, Column("value", String())) + stmt = insert(table).values( + [{"value": "A"}, {"value": "B"}, {"value": "C"}] + ) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`value_m0` AS STRING)" not in sql + assert "CAST(:`value_m1` AS STRING)" not in sql + assert "CAST(:`value_m2` AS STRING)" not in sql + + def test_numeric_family_multi_values_are_not_cast(self): + metadata = MetaData() + table = Table("t", metadata, Column("score", Numeric())) + stmt = insert(table).values( + [{"score": 1}, {"score": 2.5}, {"score": 3}] + ) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`score_m0` AS DECIMAL)" not in sql + assert "CAST(:`score_m1` AS DECIMAL)" not in sql + assert "CAST(:`score_m2` AS DECIMAL)" not in sql From 443ea63a3aeb5f04e2ce4d591ddecd54f92a5fb8 Mon Sep 17 00:00:00 2001 From: Madhavendra Rathore Date: Fri, 29 May 2026 17:36:06 +0530 Subject: [PATCH 2/2] Harden multi-row insert casting to deterministic typed behavior. Replace runtime-value adaptive logic with deterministic casting for SQLAlchemy multi-row VALUES bind markers only, and update compiler tests to verify multi-row casts and single-row non-cast behavior. 
--- src/databricks/sqlalchemy/_ddl.py | 74 +++++++++---------------------- tests/test_local/test_ddl.py | 35 +++++++++------ 2 files changed, 41 insertions(+), 68 deletions(-) diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index 659415f..9d88d7a 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -1,5 +1,4 @@ import re -from numbers import Number from sqlalchemy.sql import compiler, sqltypes import logging @@ -166,21 +165,6 @@ def bindparam_string(self, name, **kw): return self._BIND_TEMPLATE % {"name": name.replace("`", "``")} return super().bindparam_string(name, **kw) - @staticmethod - def _value_family(value): - """Return a coarse runtime family for adaptive multi-row cast decisions.""" - if value is None: - return "null" - if isinstance(value, bool): - return "bool" - if isinstance(value, str): - return "string" - if isinstance(value, (bytes, bytearray, memoryview)): - return "binary" - if isinstance(value, Number): - return "number" - return "other" - @staticmethod def _split_multivalue_bind_name(bind_name): """Split SQLAlchemy's ``<col>_m<idx>`` bind names into (column, idx).""" @@ -189,55 +173,37 @@ def _split_multivalue_bind_name(bind_name): return None return match.group("col"), int(match.group("idx")) - def _build_adaptive_cast_plan(self): - """Return {bind_name: cast_sql_type} for risky multi-row value groups. + def _build_multi_value_cast_plan(self, insert_stmt): + """Return {bind_name: cast_sql_type} for multi-row VALUES insert binds. - We only target SQLAlchemy-generated multi-row binds (``*_mN``). For - each logical column we inspect row values available at compile time and - cast only when families are heterogeneous in a way that commonly causes - Spark inline-table incompatibility (e.g., number + string). 
+ This is a deterministic fix for Spark inline-table type reconciliation: + for SQLAlchemy-generated multi-row INSERT binds (``*_mN``), always cast + the marker to the bind's dialect SQL type so each column position in the + VALUES table has an explicit server-side type. """ - column_bind_names = {} - for bind_name, bind_param in self.binds.items(): - split = self._split_multivalue_bind_name(bind_name) - if split is None: - continue - column_name, _ = split - column_bind_names.setdefault(column_name, []).append((bind_name, bind_param)) + if not getattr(insert_stmt, "_multi_values", None): + return {} cast_plan = {} - for bind_entries in column_bind_names.values(): - families = set() - for _, bind_param in bind_entries: - value = getattr(bind_param, "value", None) - family = self._value_family(value) - if family != "null": - families.add(family) - - if len(families) <= 1: + for bind_name, bind_param in self.binds.items(): + if self._split_multivalue_bind_name(bind_name) is None: continue - # Numeric + numeric is safe for Spark inline tables and does not - # need explicit casting. 
- if families == {"number"}: + type_engine = getattr(bind_param, "type", None) + if type_engine is None or isinstance(type_engine, sqltypes.NullType): continue - for bind_name, bind_param in bind_entries: - type_engine = getattr(bind_param, "type", None) - if type_engine is None or isinstance(type_engine, sqltypes.NullType): - continue - - dialect_type = type_engine._unwrapped_dialect_impl(self.dialect) - target_type = self.dialect.type_compiler_instance.process( - dialect_type, identifier_preparer=self.preparer - ) - cast_plan[bind_name] = target_type + dialect_type = type_engine._unwrapped_dialect_impl(self.dialect) + target_type = self.dialect.type_compiler_instance.process( + dialect_type, identifier_preparer=self.preparer + ) + cast_plan[bind_name] = target_type return cast_plan - def _apply_adaptive_multi_value_casts(self, sql_text): + def _apply_multi_value_casts(self, sql_text, insert_stmt): """Wrap selected ``:`name``` markers with ``CAST(... AS <type>)``.""" - cast_plan = self._build_adaptive_cast_plan() + cast_plan = self._build_multi_value_cast_plan(insert_stmt) if not cast_plan: return sql_text @@ -249,7 +215,7 @@ def _apply_adaptive_multi_value_casts(self, sql_text): def visit_insert(self, insert_stmt, **kw): sql_text = super().visit_insert(insert_stmt, **kw) - return self._apply_adaptive_multi_value_casts(sql_text) + return self._apply_multi_value_casts(sql_text, insert_stmt) def limit_clause(self, select, **kw): """Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1, diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py index 1231955..dcd2526 100644 --- a/tests/test_local/test_ddl.py +++ b/tests/test_local/test_ddl.py @@ -427,8 +427,8 @@ def test_in_clause_expansion_renders_backticked_markers(self): assert ":`col-name_1_3`" in expanded.statement -class TestAdaptiveMultiRowInsertCasts(DDLTestBase): - def test_mixed_runtime_families_in_multi_values_are_cast(self): +class 
TestMultiRowInsertCasts(DDLTestBase): + def test_multi_values_casts_mixed_type_column(self): metadata = MetaData() table = Table("t", metadata, Column("name", String()), Column("value", String())) stmt = insert(table).values( @@ -444,12 +444,11 @@ def test_mixed_runtime_families_in_multi_values_are_cast(self): assert "CAST(:`value_m0` AS STRING)" in sql assert "CAST(:`value_m1` AS STRING)" in sql assert "CAST(:`value_m2` AS STRING)" in sql - # Name values are already all string/null and should remain untouched. - assert "CAST(:`name_m0` AS STRING)" not in sql - assert "CAST(:`name_m1` AS STRING)" not in sql - assert "CAST(:`name_m2` AS STRING)" not in sql + assert "CAST(:`name_m0` AS STRING)" in sql + assert "CAST(:`name_m1` AS STRING)" in sql + assert "CAST(:`name_m2` AS STRING)" in sql - def test_homogeneous_multi_values_are_not_cast(self): + def test_homogeneous_multi_values_are_cast(self): metadata = MetaData() table = Table("t", metadata, Column("value", String())) stmt = insert(table).values( @@ -457,11 +456,11 @@ def test_homogeneous_multi_values_are_not_cast(self): ) sql = str(stmt.compile(bind=self.engine)) - assert "CAST(:`value_m0` AS STRING)" not in sql - assert "CAST(:`value_m1` AS STRING)" not in sql - assert "CAST(:`value_m2` AS STRING)" not in sql + assert "CAST(:`value_m0` AS STRING)" in sql + assert "CAST(:`value_m1` AS STRING)" in sql + assert "CAST(:`value_m2` AS STRING)" in sql - def test_numeric_family_multi_values_are_not_cast(self): + def test_numeric_family_multi_values_are_cast(self): metadata = MetaData() table = Table("t", metadata, Column("score", Numeric())) stmt = insert(table).values( @@ -469,6 +468,14 @@ def test_numeric_family_multi_values_are_not_cast(self): ) sql = str(stmt.compile(bind=self.engine)) - assert "CAST(:`score_m0` AS DECIMAL)" not in sql - assert "CAST(:`score_m1` AS DECIMAL)" not in sql - assert "CAST(:`score_m2` AS DECIMAL)" not in sql + assert "CAST(:`score_m0` AS DECIMAL)" in sql + assert "CAST(:`score_m1` AS 
DECIMAL)" in sql + assert "CAST(:`score_m2` AS DECIMAL)" in sql + + def test_single_row_insert_does_not_render_casts(self): + metadata = MetaData() + table = Table("t", metadata, Column("value", String())) + stmt = insert(table).values({"value": "A"}) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`value` AS STRING)" not in sql