diff --git a/src/databricks/sqlalchemy/_ddl.py b/src/databricks/sqlalchemy/_ddl.py index f61673b..9d88d7a 100644 --- a/src/databricks/sqlalchemy/_ddl.py +++ b/src/databricks/sqlalchemy/_ddl.py @@ -165,6 +165,58 @@ def bindparam_string(self, name, **kw): return self._BIND_TEMPLATE % {"name": name.replace("`", "``")} return super().bindparam_string(name, **kw) + @staticmethod + def _split_multivalue_bind_name(bind_name): + """Split SQLAlchemy's ``_m`` bind names into (column, idx).""" + match = re.match(r"^(?P.+)_m(?P\d+)$", bind_name) + if not match: + return None + return match.group("col"), int(match.group("idx")) + + def _build_multi_value_cast_plan(self, insert_stmt): + """Return {bind_name: cast_sql_type} for multi-row VALUES insert binds. + + This is a deterministic fix for Spark inline-table type reconciliation: + for SQLAlchemy-generated multi-row INSERT binds (``*_mN``), always cast + the marker to the bind's dialect SQL type so each column position in the + VALUES table has an explicit server-side type. + """ + if not getattr(insert_stmt, "_multi_values", None): + return {} + + cast_plan = {} + for bind_name, bind_param in self.binds.items(): + if self._split_multivalue_bind_name(bind_name) is None: + continue + + type_engine = getattr(bind_param, "type", None) + if type_engine is None or isinstance(type_engine, sqltypes.NullType): + continue + + dialect_type = type_engine._unwrapped_dialect_impl(self.dialect) + target_type = self.dialect.type_compiler_instance.process( + dialect_type, identifier_preparer=self.preparer + ) + cast_plan[bind_name] = target_type + + return cast_plan + + def _apply_multi_value_casts(self, sql_text, insert_stmt): + """Wrap selected ``:`name``` markers with ``CAST(... AS )``.""" + cast_plan = self._build_multi_value_cast_plan(insert_stmt) + if not cast_plan: + return sql_text + + rendered = sql_text + for bind_name, target_type in cast_plan.items(): + marker = self._BIND_TEMPLATE % {"name": bind_name.replace("`", "``")} + rendered = rendered.replace(marker, f"CAST({marker} AS {target_type})") + return rendered + + def visit_insert(self, insert_stmt, **kw): + sql_text = super().visit_insert(insert_stmt, **kw) + return self._apply_multi_value_casts(sql_text, insert_stmt) + def limit_clause(self, select, **kw): """Identical to the default implementation of SQLCompiler.limit_clause except it writes LIMIT ALL instead of LIMIT -1, since Databricks SQL doesn't support the latter. diff --git a/tests/test_local/e2e/test_pandas_multi_mixed_types.py b/tests/test_local/e2e/test_pandas_multi_mixed_types.py new file mode 100644 index 0000000..4d01227 --- /dev/null +++ b/tests/test_local/e2e/test_pandas_multi_mixed_types.py @@ -0,0 +1,83 @@ +import uuid + +import pandas as pd +import pytest +from sqlalchemy import create_engine, text +from sqlalchemy.engine import Engine + + +@pytest.fixture +def db_engine(connection_details) -> Engine: + host = connection_details["host"] + http_path = connection_details["http_path"] + access_token = connection_details["access_token"] + catalog = connection_details["catalog"] + schema = connection_details["schema"] + + conn_string = ( + f"databricks://token:{access_token}@{host}" + f"?http_path={http_path}&catalog={catalog}&schema={schema}" + ) + engine = create_engine( + conn_string, connect_args={"_user_agent_entry": "SQLAlchemy pandas e2e tests"} + ) + try: + yield engine + finally: + engine.dispose() + + +def test_pandas_to_sql_multi_mixed_object_column_succeeds(db_engine: Engine): + table_name = f"pecoblr_2746_e2e_{uuid.uuid4().hex[:8]}" + fq_table_name = f"`main`.`default`.`{table_name}`" + df = pd.DataFrame( + { + "name": ["alice", "bob", None], + "value": [1, 0, "NE"], + "score": [9.5, 8.1, None], + "active": [True, None, False], + } + ) + + try: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) + conn.execute( + text( + f"CREATE TABLE {fq_table_name} " + "(name STRING, value STRING, score DOUBLE, active BOOLEAN) " + "USING DELTA" + ) + ) + + # This is the failing path from PECOBLR-2746 before the adaptive cast fix. + df.to_sql( + table_name, db_engine, schema="default", if_exists="append", index=False, method="multi" + ) + + with db_engine.begin() as conn: + rows = conn.execute( + text( + f"SELECT name, value, score, active FROM {fq_table_name} " + "ORDER BY CASE WHEN name IS NULL THEN 1 ELSE 0 END, name" + ) + ).fetchall() + + assert len(rows) == 3 + assert rows[0][0] == "alice" + assert rows[0][1] == "1" + assert rows[0][2] == pytest.approx(9.5) + assert rows[0][3] is True + + assert rows[1][0] == "bob" + assert rows[1][1] == "0" + assert rows[1][2] == pytest.approx(8.1) + assert rows[1][3] is None + + assert rows[2][0] is None + assert rows[2][1] == "NE" + assert rows[2][2] is None + assert rows[2][3] is False + finally: + with db_engine.begin() as conn: + conn.execute(text(f"DROP TABLE IF EXISTS {fq_table_name}")) diff --git a/tests/test_local/test_ddl.py b/tests/test_local/test_ddl.py index c3fae18..dcd2526 100644 --- a/tests/test_local/test_ddl.py +++ b/tests/test_local/test_ddl.py @@ -425,3 +425,57 @@ def test_in_clause_expansion_renders_backticked_markers(self): assert ":`col-name_1_1`" in expanded.statement assert ":`col-name_1_2`" in expanded.statement assert ":`col-name_1_3`" in expanded.statement + + +class TestMultiRowInsertCasts(DDLTestBase): + def test_multi_values_casts_mixed_type_column(self): + metadata = MetaData() + table = Table("t", metadata, Column("name", String()), Column("value", String())) + stmt = insert(table).values( + [ + {"name": "alice", "value": 1}, + {"name": "bob", "value": 0}, + {"name": None, "value": "NE"}, + ] + ) + + sql = str(stmt.compile(bind=self.engine)) + + assert "CAST(:`value_m0` AS STRING)" in sql + assert "CAST(:`value_m1` AS STRING)" in sql + assert "CAST(:`value_m2` AS STRING)" in sql + assert "CAST(:`name_m0` AS STRING)" in sql + assert "CAST(:`name_m1` AS STRING)" in sql + assert "CAST(:`name_m2` AS STRING)" in sql + + def test_homogeneous_multi_values_are_cast(self): + metadata = MetaData() + table = Table("t", metadata, Column("value", String())) + stmt = insert(table).values( + [{"value": "A"}, {"value": "B"}, {"value": "C"}] + ) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`value_m0` AS STRING)" in sql + assert "CAST(:`value_m1` AS STRING)" in sql + assert "CAST(:`value_m2` AS STRING)" in sql + + def test_numeric_family_multi_values_are_cast(self): + metadata = MetaData() + table = Table("t", metadata, Column("score", Numeric())) + stmt = insert(table).values( + [{"score": 1}, {"score": 2.5}, {"score": 3}] + ) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`score_m0` AS DECIMAL)" in sql + assert "CAST(:`score_m1` AS DECIMAL)" in sql + assert "CAST(:`score_m2` AS DECIMAL)" in sql + + def test_single_row_insert_does_not_render_casts(self): + metadata = MetaData() + table = Table("t", metadata, Column("value", String())) + stmt = insert(table).values({"value": "A"}) + + sql = str(stmt.compile(bind=self.engine)) + assert "CAST(:`value` AS STRING)" not in sql