From 0c517af3c6a255168b5a0ab914cbbffb6a2efe22 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Mon, 10 Nov 2025 19:39:28 +0100 Subject: [PATCH 01/36] use debug config --- mssql_python/pybind/build.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/build.sh b/mssql_python/pybind/build.sh index 811777285..38f141cff 100755 --- a/mssql_python/pybind/build.sh +++ b/mssql_python/pybind/build.sh @@ -87,8 +87,8 @@ if [ $? -ne 0 ]; then fi # Build the project -echo "[DIAGNOSTIC] Running CMake build with: cmake --build . --config Release" -cmake --build . --config Release +echo "[DIAGNOSTIC] Running CMake build with: cmake --build . --config Debug" +cmake --build . --config Debug # Check if build succeeded if [ $? -ne 0 ]; then From b5f3057f1d7316df6a6fdeb19ed994775ed770a6 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Mon, 10 Nov 2025 19:51:49 +0100 Subject: [PATCH 02/36] Add dummy fetch_arrow_batch function & generate dummy schema --- mssql_python/cursor.py | 12 +++ mssql_python/pybind/ddbc_bindings.cpp | 107 ++++++++++++++++++++++++++ 2 files changed, 119 insertions(+) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 2889f2ca8..03fd25633 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2195,6 +2195,17 @@ def fetchall(self) -> List[Row]: # On error, don't increment rownumber - rethrow the error raise e + def fetch_arrow_batch(self) -> Any: + self._check_closed() # Check if the cursor is closed + if not self._has_result_set and self.description: + self._reset_rownumber() + + capsules = [] + ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules) + print(ret) + # assert ret is None, (ret, type(ret)) + return capsules + def nextset(self) -> Union[bool, None]: """ Skip to the next available result set. @@ -2374,6 +2385,7 @@ def __del__(self): """ if "closed" not in self.__dict__ or not self.closed: try: + assert self is not None self.close() except Exception as e: # pylint: disable=broad-exception-caught # Don't raise an exception in __del__, just log it diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 9a8280117..ccc71ebb3 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -147,6 +147,48 @@ struct NumericData { } }; +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE + //------------------------------------------------------------------------------------------------- // Function pointer initialization //------------------------------------------------------------------------------------------------- @@ -3916,6 +3958,70 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; } +void ArrowSchema_release(struct ArrowSchema* schema) { + if (schema->release) { + schema->release = nullptr; + } +} + +SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) { + SQLRETURN ret; + SQLHSTMT hStmt = StatementHandle->get(); + // Retrieve column count + SQLSMALLINT numCols = SQLNumResultCols_wrap(StatementHandle); + + // Retrieve column metadata + py::list columnNames; + ret = SQLDescribeCol_wrap(StatementHandle, columnNames); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to get column descriptions"); + return ret; + } + + std::vector lobColumns; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto colMeta = columnNames[i].cast(); + SQLSMALLINT dataType = colMeta["DataType"].cast(); + SQLULEN columnSize = colMeta["ColumnSize"].cast(); + + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || + dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && + (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { + lobColumns.push_back(i + 1); // 1-based + } + } + + assert(lobColumns.empty() && "Arrow batch fetch does not support LOB columns yet"); + + capsules.append(py::none()); + capsules.append(py::str("")); + + // vector of arrowschema children + auto children = new ArrowSchema* [1]; + auto arrow_schema = new ArrowSchema({ + .format = "l", + .name = "test_column", + .release = ArrowSchema_release, + }); + children[0] = arrow_schema; + auto arrow_schema_batch = new ArrowSchema({ + .format = "+s", + .name = "test_batch", + .n_children = 1, + .children = children, + .release = ArrowSchema_release, + }); + // auto caps = py::capsule((void*)arrow_schema, "arrow_schema", nullptr); + auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { + delete static_cast(ptr); + }); + capsules.append(caps); + + return 0; +} + + // FetchAll_wrap - Fetches all rows of data from the result set. // // @param StatementHandle: Handle to the statement from which data is to be @@ -4222,6 +4328,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), py::arg("fetchSize") = 1, "Fetch many rows from the result set"); m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); + m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, "Fetch an arrow batch of given length from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, From 5a6a09eb5a92f691d8494eb5dc3b224ea892f344 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 11 Nov 2025 19:04:31 +0100 Subject: [PATCH 03/36] schema names --- mssql_python/cursor.py | 11 +++++++++- mssql_python/pybind/ddbc_bindings.cpp | 30 ++++++++++++++------------- 2 files changed, 26 insertions(+), 15 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 03fd25633..d5e081edf 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2200,11 +2200,20 @@ def fetch_arrow_batch(self) -> Any: if not self._has_result_set and self.description: self._reset_rownumber() + try: + import pyarrow as pa + except ImportError as e: + raise ImportError( + "pyarrow is required for fetch_arrow_batch(). Please install pyarrow." + ) from e capsules = [] ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules) print(ret) + schema_capsule = capsules[0] + schema = pa.Schema._import_from_c_capsule(schema_capsule) + # assert ret is None, (ret, type(ret)) - return capsules + return schema, capsules def nextset(self) -> Union[bool, None]: """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index ccc71ebb3..090227f90 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3978,6 +3978,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) return ret; } + auto batch_children = new ArrowSchema* [numCols]; std::vector lobColumns; for (SQLSMALLINT i = 0; i < numCols; i++) { auto colMeta = columnNames[i].cast(); @@ -3990,26 +3991,27 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { lobColumns.push_back(i + 1); // 1-based } + + assert(dataType == SQL_INTEGER && "Only INTEGER type is supported in Arrow batch fetch for now"); + + std::string columnName = colMeta["ColumnName"].cast(); + char* name_copy = strdup(columnName.c_str()); + + auto arrow_schema = new ArrowSchema({ + .format = "i", + .name = name_copy, + .release = ArrowSchema_release, + }); + batch_children[i] = arrow_schema; } assert(lobColumns.empty() && "Arrow batch fetch does not support LOB columns yet"); - capsules.append(py::none()); - capsules.append(py::str("")); - - // vector of arrowschema children - auto children = new ArrowSchema* [1]; - auto arrow_schema = new ArrowSchema({ - .format = "l", - .name = "test_column", - .release = ArrowSchema_release, - }); - children[0] = arrow_schema; auto arrow_schema_batch = new ArrowSchema({ .format = "+s", - .name = "test_batch", - .n_children = 1, - .children = children, + .name = "", + .n_children = numCols, + .children = batch_children, .release = ArrowSchema_release, }); // auto caps = py::capsule((void*)arrow_schema, "arrow_schema", nullptr); From f4210933c3a9541dac1ec08c21d52a027d3c8ace Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 11 Nov 2025 21:21:43 +0100 Subject: [PATCH 04/36] inline py fetch --- mssql_python/pybind/ddbc_bindings.cpp | 295 +++++++++++++++++++++++++- 1 file changed, 294 insertions(+), 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 090227f90..1dc2a3914 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3965,6 +3965,7 @@ void ArrowSchema_release(struct ArrowSchema* schema) { } SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) { + ssize_t fetchSize = 500; SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); // Retrieve column count @@ -4020,7 +4021,299 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) }); capsules.append(caps); - return 0; + // Initialize column buffers + ColumnBuffers buffers(numCols, fetchSize); + + // Bind columns + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error when binding columns"); + return ret; + } + + SQLULEN numRowsFetched; + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); + + + + + ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0); + if (ret == SQL_NO_DATA) { + LOG("No data to fetch"); + return ret; + } + if (!SQL_SUCCEEDED(ret)) { + LOG("Error while fetching rows in batches"); + return ret; + } + // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. It'll be populated by + // SQLFetchScroll + for (SQLULEN i = 0; i < numRowsFetched; i++) { + py::list row; + for (SQLUSMALLINT col = 1; col <= numCols; col++) { + auto columnMeta = columnNames[col - 1].cast(); + SQLSMALLINT dataType = columnMeta["DataType"].cast(); + SQLLEN dataLen = buffers.indicators[col - 1][i]; + + if (dataLen == SQL_NULL_DATA) { + row.append(py::none()); + continue; + } + // TODO: variable length data needs special handling, this logic wont suffice + // This value indicates that the driver cannot determine the length of the data + if (dataLen == SQL_NO_TOTAL) { + LOG("Cannot determine the length of the data. Returning NULL value instead." + "Column ID - {}", col); + row.append(py::none()); + continue; + } else if (dataLen == SQL_NULL_DATA) { + LOG("Column data is NULL. Appending None to the result row. Column ID - {}", col); + row.append(py::none()); + continue; + } else if (dataLen == 0) { + // Handle zero-length (non-NULL) data + if (dataType == SQL_CHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR) { + row.append(std::string("")); + } else if (dataType == SQL_WCHAR || dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR) { + row.append(std::wstring(L"")); + } else if (dataType == SQL_BINARY || dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) { + row.append(py::bytes("")); + } else { + // For other datatypes, 0 length is unexpected. Log & append None + LOG("Column data length is 0 for non-string/binary datatype. Appending None to the result row. Column ID - {}", col); + row.append(py::none()); + } + continue; + } else if (dataLen < 0) { + // Negative value is unexpected, log column index, SQL type & raise exception + LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); + ThrowStdException("Unexpected negative data length, check logs for details"); + } + assert(dataLen > 0 && "Data length must be > 0"); + + switch (dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + HandleZeroColumnSizeAtFetch(columnSize); + uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; + uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); + bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); + // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' + if (!isLob && numCharsInData < fetchBufferSize) { + // SQLFetch will nullterminate the data + row.append(std::string( + reinterpret_cast(&buffers.charBuffers[col - 1][i * fetchBufferSize]), + numCharsInData)); + } else { + row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false)); + } + break; + } + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + // TODO: variable length data needs special handling, this logic wont suffice + SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + HandleZeroColumnSizeAtFetch(columnSize); + uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; + uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); + bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); + // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' + if (!isLob && numCharsInData < fetchBufferSize) { + // SQLFetch will nullterminate the data +#if defined(__APPLE__) || defined(__linux__) + // Use unix-specific conversion to handle the wchar_t/SQLWCHAR size difference + SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][i * fetchBufferSize]; + std::wstring wstr = SQLWCHARToWString(wcharData, numCharsInData); + row.append(wstr); +#else + // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so direct cast works + row.append(std::wstring( + reinterpret_cast(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]), + numCharsInData)); +#endif + } else { + row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false)); + } + break; + } + case SQL_INTEGER: { + row.append(buffers.intBuffers[col - 1][i]); + break; + } + case SQL_SMALLINT: { + row.append(buffers.smallIntBuffers[col - 1][i]); + break; + } + case SQL_TINYINT: { + row.append(buffers.charBuffers[col - 1][i]); + break; + } + case SQL_BIT: { + row.append(static_cast(buffers.charBuffers[col - 1][i])); + break; + } + case SQL_REAL: { + row.append(buffers.realBuffers[col - 1][i]); + break; + } + case SQL_DECIMAL: + case SQL_NUMERIC: { + try { + // Convert the string to use the current decimal separator + std::string numStr(reinterpret_cast( + &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]), + buffers.indicators[col - 1][i]); + + // Get the current separator in a thread-safe way + std::string separator = GetDecimalSeparator(); + + if (separator != ".") { + // Replace the driver's decimal point with our configured separator + size_t pos = numStr.find('.'); + if (pos != std::string::npos) { + numStr.replace(pos, 1, separator); + } + } + + // Convert to Python decimal + row.append(py::module_::import("decimal").attr("Decimal")(numStr)); + } catch (const py::error_already_set& e) { + // Handle the exception, e.g., log the error and append py::none() + LOG("Error converting to decimal: {}", e.what()); + row.append(py::none()); + } + break; + } + case SQL_DOUBLE: + case SQL_FLOAT: { + row.append(buffers.doubleBuffers[col - 1][i]); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + row.append(py::module_::import("datetime") + .attr("datetime")(buffers.timestampBuffers[col - 1][i].year, + buffers.timestampBuffers[col - 1][i].month, + buffers.timestampBuffers[col - 1][i].day, + buffers.timestampBuffers[col - 1][i].hour, + buffers.timestampBuffers[col - 1][i].minute, + buffers.timestampBuffers[col - 1][i].second, + buffers.timestampBuffers[col - 1][i].fraction / 1000 /* Convert back ns to µs */)); + break; + } + case SQL_BIGINT: { + row.append(buffers.bigIntBuffers[col - 1][i]); + break; + } + case SQL_TYPE_DATE: { + row.append(py::module_::import("datetime") + .attr("date")(buffers.dateBuffers[col - 1][i].year, + buffers.dateBuffers[col - 1][i].month, + buffers.dateBuffers[col - 1][i].day)); + break; + } + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + row.append(py::module_::import("datetime") + .attr("time")(buffers.timeBuffers[col - 1][i].hour, + buffers.timeBuffers[col - 1][i].minute, + buffers.timeBuffers[col - 1][i].second)); + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + SQLULEN rowIdx = i; + const DateTimeOffset& dtoValue = buffers.datetimeoffsetBuffers[col - 1][rowIdx]; + SQLLEN indicator = buffers.indicators[col - 1][rowIdx]; + if (indicator != SQL_NULL_DATA) { + int totalMinutes = dtoValue.timezone_hour * 60 + dtoValue.timezone_minute; + py::object datetime = py::module_::import("datetime"); + py::object tzinfo = datetime.attr("timezone")( + datetime.attr("timedelta")(py::arg("minutes") = totalMinutes) + ); + py::object py_dt = datetime.attr("datetime")( + dtoValue.year, + dtoValue.month, + dtoValue.day, + dtoValue.hour, + dtoValue.minute, + dtoValue.second, + dtoValue.fraction / 1000, // ns → µs + tzinfo + ); + row.append(py_dt); + } else { + row.append(py::none()); + } + break; + } + case SQL_GUID: { + SQLLEN indicator = buffers.indicators[col - 1][i]; + if (indicator == SQL_NULL_DATA) { + row.append(py::none()); + break; + } + SQLGUID* guidValue = &buffers.guidBuffers[col - 1][i]; + uint8_t reordered[16]; + reordered[0] = ((char*)&guidValue->Data1)[3]; + reordered[1] = ((char*)&guidValue->Data1)[2]; + reordered[2] = ((char*)&guidValue->Data1)[1]; + reordered[3] = ((char*)&guidValue->Data1)[0]; + reordered[4] = ((char*)&guidValue->Data2)[1]; + reordered[5] = ((char*)&guidValue->Data2)[0]; + reordered[6] = ((char*)&guidValue->Data3)[1]; + reordered[7] = ((char*)&guidValue->Data3)[0]; + std::memcpy(reordered + 8, guidValue->Data4, 8); + + py::bytes py_guid_bytes(reinterpret_cast(reordered), 16); + py::dict kwargs; + kwargs["bytes"] = py_guid_bytes; + py::object uuid_obj = py::module_::import("uuid").attr("UUID")(**kwargs); + row.append(uuid_obj); + break; + } + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + HandleZeroColumnSizeAtFetch(columnSize); + bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); + if (!isLob && static_cast(dataLen) <= columnSize) { + row.append(py::bytes(reinterpret_cast( + &buffers.charBuffers[col - 1][i * columnSize]), + dataLen)); + } else { + row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true)); + } + break; + } + default: { + std::wstring columnName = columnMeta["ColumnName"].cast(); + std::ostringstream errorString; + errorString << "Unsupported data type for column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " << col; + LOG(errorString.str()); + ThrowStdException(errorString.str()); + break; + } + } + } + capsules.append(row); + } + + + + + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); + + return ret; } From 2a8535d3ba5ad77fb272f2cb46a3eeefd8452612 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 11 Nov 2025 22:37:56 +0100 Subject: [PATCH 05/36] arrow int batch which returns bug nulls --- mssql_python/cursor.py | 7 +- mssql_python/pybind/ddbc_bindings.cpp | 318 ++++++++------------------ 2 files changed, 98 insertions(+), 227 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index d5e081edf..e5a1da7fa 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2210,10 +2210,9 @@ def fetch_arrow_batch(self) -> Any: ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules) print(ret) schema_capsule = capsules[0] - schema = pa.Schema._import_from_c_capsule(schema_capsule) - - # assert ret is None, (ret, type(ret)) - return schema, capsules + array_capsule = capsules[1] + batch = pa.RecordBatch._import_from_c_capsule(schema_capsule, array_capsule) + return batch, capsules def nextset(self) -> Union[bool, None]: """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 1dc2a3914..3f709de7a 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -147,6 +147,39 @@ struct NumericData { } }; +// Struct to hold data buffers and indicators for each column +struct ColumnBuffersArrow { + // std::vector> charBuffers; + // std::vector> wcharBuffers; + std::vector> intBuffers; + // std::vector> smallIntBuffers; + // std::vector> realBuffers; + // std::vector> doubleBuffers; + // std::vector> timestampBuffers; + // std::vector> bigIntBuffers; + // std::vector> dateBuffers; + // std::vector> timeBuffers; + // std::vector> guidBuffers; + std::vector> indicators; + // std::vector> datetimeoffsetBuffers; + + ColumnBuffersArrow(SQLSMALLINT numCols, int fetchSize) + : + // : charBuffers(numCols), + // wcharBuffers(numCols), + intBuffers(numCols), + // smallIntBuffers(numCols), + // realBuffers(numCols), + // doubleBuffers(numCols), + // timestampBuffers(numCols), + // bigIntBuffers(numCols), + // dateBuffers(numCols), + // timeBuffers(numCols), + // guidBuffers(numCols), + // datetimeoffsetBuffers(numCols), + indicators(numCols, std::vector(fetchSize)) {} +}; + #ifndef ARROW_C_DATA_INTERFACE #define ARROW_C_DATA_INTERFACE @@ -3964,6 +3997,12 @@ void ArrowSchema_release(struct ArrowSchema* schema) { } } +void ArrowArray_release(struct ArrowArray* array) { + if (array->release) { + array->release = nullptr; + } +} + SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) { ssize_t fetchSize = 500; SQLRETURN ret; @@ -3981,6 +4020,8 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) auto batch_children = new ArrowSchema* [numCols]; std::vector lobColumns; + + ColumnBuffersArrow buffersArrow(numCols, fetchSize); for (SQLSMALLINT i = 0; i < numCols; i++) { auto colMeta = columnNames[i].cast(); SQLSMALLINT dataType = colMeta["DataType"].cast(); @@ -4004,6 +4045,20 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) .release = ArrowSchema_release, }); batch_children[i] = arrow_schema; + + switch(dataType) { + case SQL_INTEGER: + buffersArrow.intBuffers[i].resize(fetchSize); + break; + default: + std::wstring columnName = colMeta["ColumnName"].cast(); + std::ostringstream errorString; + errorString << "Unsupported data type for Arrow batch fetch for column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " << (i + 1); + LOG(errorString.str()); + ThrowStdException(errorString.str()); + break; + } } assert(lobColumns.empty() && "Arrow batch fetch does not support LOB columns yet"); @@ -4050,41 +4105,17 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. It'll be populated by // SQLFetchScroll for (SQLULEN i = 0; i < numRowsFetched; i++) { - py::list row; for (SQLUSMALLINT col = 1; col <= numCols; col++) { auto columnMeta = columnNames[col - 1].cast(); SQLSMALLINT dataType = columnMeta["DataType"].cast(); SQLLEN dataLen = buffers.indicators[col - 1][i]; - if (dataLen == SQL_NULL_DATA) { - row.append(py::none()); - continue; - } // TODO: variable length data needs special handling, this logic wont suffice // This value indicates that the driver cannot determine the length of the data if (dataLen == SQL_NO_TOTAL) { - LOG("Cannot determine the length of the data. Returning NULL value instead." - "Column ID - {}", col); - row.append(py::none()); - continue; + assert(false && "Is this actually possible?"); } else if (dataLen == SQL_NULL_DATA) { - LOG("Column data is NULL. Appending None to the result row. Column ID - {}", col); - row.append(py::none()); - continue; - } else if (dataLen == 0) { - // Handle zero-length (non-NULL) data - if (dataType == SQL_CHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR) { - row.append(std::string("")); - } else if (dataType == SQL_WCHAR || dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR) { - row.append(std::wstring(L"")); - } else if (dataType == SQL_BINARY || dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY) { - row.append(py::bytes("")); - } else { - // For other datatypes, 0 length is unexpected. Log & append None - LOG("Column data length is 0 for non-string/binary datatype. Appending None to the result row. Column ID - {}", col); - row.append(py::none()); - } - continue; + assert(false && "TODO"); } else if (dataLen < 0) { // Negative value is unexpected, log column index, SQL type & raise exception LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); @@ -4093,203 +4124,8 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) assert(dataLen > 0 && "Data length must be > 0"); switch (dataType) { - case SQL_CHAR: - case SQL_VARCHAR: - case SQL_LONGVARCHAR: { - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - uint64_t numCharsInData = dataLen / sizeof(SQLCHAR); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); - // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' - if (!isLob && numCharsInData < fetchBufferSize) { - // SQLFetch will nullterminate the data - row.append(std::string( - reinterpret_cast(&buffers.charBuffers[col - 1][i * fetchBufferSize]), - numCharsInData)); - } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_CHAR, false, false)); - } - break; - } - case SQL_WCHAR: - case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: { - // TODO: variable length data needs special handling, this logic wont suffice - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - HandleZeroColumnSizeAtFetch(columnSize); - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; - uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); - // fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<' - if (!isLob && numCharsInData < fetchBufferSize) { - // SQLFetch will nullterminate the data -#if defined(__APPLE__) || defined(__linux__) - // Use unix-specific conversion to handle the wchar_t/SQLWCHAR size difference - SQLWCHAR* wcharData = &buffers.wcharBuffers[col - 1][i * fetchBufferSize]; - std::wstring wstr = SQLWCHARToWString(wcharData, numCharsInData); - row.append(wstr); -#else - // On Windows, wchar_t and SQLWCHAR are both 2 bytes, so direct cast works - row.append(std::wstring( - reinterpret_cast(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]), - numCharsInData)); -#endif - } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_WCHAR, true, false)); - } - break; - } case SQL_INTEGER: { - row.append(buffers.intBuffers[col - 1][i]); - break; - } - case SQL_SMALLINT: { - row.append(buffers.smallIntBuffers[col - 1][i]); - break; - } - case SQL_TINYINT: { - row.append(buffers.charBuffers[col - 1][i]); - break; - } - case SQL_BIT: { - row.append(static_cast(buffers.charBuffers[col - 1][i])); - break; - } - case SQL_REAL: { - row.append(buffers.realBuffers[col - 1][i]); - break; - } - case SQL_DECIMAL: - case SQL_NUMERIC: { - try { - // Convert the string to use the current decimal separator - std::string numStr(reinterpret_cast( - &buffers.charBuffers[col - 1][i * MAX_DIGITS_IN_NUMERIC]), - buffers.indicators[col - 1][i]); - - // Get the current separator in a thread-safe way - std::string separator = GetDecimalSeparator(); - - if (separator != ".") { - // Replace the driver's decimal point with our configured separator - size_t pos = numStr.find('.'); - if (pos != std::string::npos) { - numStr.replace(pos, 1, separator); - } - } - - // Convert to Python decimal - row.append(py::module_::import("decimal").attr("Decimal")(numStr)); - } catch (const py::error_already_set& e) { - // Handle the exception, e.g., log the error and append py::none() - LOG("Error converting to decimal: {}", e.what()); - row.append(py::none()); - } - break; - } - case SQL_DOUBLE: - case SQL_FLOAT: { - row.append(buffers.doubleBuffers[col - 1][i]); - break; - } - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: { - row.append(py::module_::import("datetime") - .attr("datetime")(buffers.timestampBuffers[col - 1][i].year, - buffers.timestampBuffers[col - 1][i].month, - buffers.timestampBuffers[col - 1][i].day, - buffers.timestampBuffers[col - 1][i].hour, - buffers.timestampBuffers[col - 1][i].minute, - buffers.timestampBuffers[col - 1][i].second, - buffers.timestampBuffers[col - 1][i].fraction / 1000 /* Convert back ns to µs */)); - break; - } - case SQL_BIGINT: { - row.append(buffers.bigIntBuffers[col - 1][i]); - break; - } - case SQL_TYPE_DATE: { - row.append(py::module_::import("datetime") - .attr("date")(buffers.dateBuffers[col - 1][i].year, - buffers.dateBuffers[col - 1][i].month, - buffers.dateBuffers[col - 1][i].day)); - break; - } - case SQL_TIME: - case SQL_TYPE_TIME: - case SQL_SS_TIME2: { - row.append(py::module_::import("datetime") - .attr("time")(buffers.timeBuffers[col - 1][i].hour, - buffers.timeBuffers[col - 1][i].minute, - buffers.timeBuffers[col - 1][i].second)); - break; - } - case SQL_SS_TIMESTAMPOFFSET: { - SQLULEN rowIdx = i; - const DateTimeOffset& dtoValue = buffers.datetimeoffsetBuffers[col - 1][rowIdx]; - SQLLEN indicator = buffers.indicators[col - 1][rowIdx]; - if (indicator != SQL_NULL_DATA) { - int totalMinutes = dtoValue.timezone_hour * 60 + dtoValue.timezone_minute; - py::object datetime = py::module_::import("datetime"); - py::object tzinfo = datetime.attr("timezone")( - datetime.attr("timedelta")(py::arg("minutes") = totalMinutes) - ); - py::object py_dt = datetime.attr("datetime")( - dtoValue.year, - dtoValue.month, - dtoValue.day, - dtoValue.hour, - dtoValue.minute, - dtoValue.second, - dtoValue.fraction / 1000, // ns → µs - tzinfo - ); - row.append(py_dt); - } else { - row.append(py::none()); - } - break; - } - case SQL_GUID: { - SQLLEN indicator = buffers.indicators[col - 1][i]; - if (indicator == SQL_NULL_DATA) { - row.append(py::none()); - break; - } - SQLGUID* guidValue = &buffers.guidBuffers[col - 1][i]; - uint8_t reordered[16]; - reordered[0] = ((char*)&guidValue->Data1)[3]; - reordered[1] = ((char*)&guidValue->Data1)[2]; - reordered[2] = ((char*)&guidValue->Data1)[1]; - reordered[3] = ((char*)&guidValue->Data1)[0]; - reordered[4] = ((char*)&guidValue->Data2)[1]; - reordered[5] = ((char*)&guidValue->Data2)[0]; - reordered[6] = ((char*)&guidValue->Data3)[1]; - reordered[7] = ((char*)&guidValue->Data3)[0]; - std::memcpy(reordered + 8, guidValue->Data4, 8); - - py::bytes py_guid_bytes(reinterpret_cast(reordered), 16); - py::dict kwargs; - kwargs["bytes"] = py_guid_bytes; - py::object uuid_obj = py::module_::import("uuid").attr("UUID")(**kwargs); - row.append(uuid_obj); - break; - } - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: { - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - HandleZeroColumnSizeAtFetch(columnSize); - bool isLob = std::find(lobColumns.begin(), lobColumns.end(), col) != lobColumns.end(); - if (!isLob && static_cast(dataLen) <= columnSize) { - row.append(py::bytes(reinterpret_cast( - &buffers.charBuffers[col - 1][i * columnSize]), - dataLen)); - } else { - row.append(FetchLobColumnData(hStmt, col, SQL_C_BINARY, false, true)); - } + buffersArrow.intBuffers[col - 1][i] = buffers.intBuffers[col - 1][i]; break; } default: { @@ -4303,12 +4139,48 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) } } } - capsules.append(row); } + auto arrow_array_batch = new ArrowArray({ + .length = static_cast(numRowsFetched), + .n_buffers = 1, + .n_children = numCols, + .buffers = new const void* [3], + .children = new ArrowArray* [numCols], + .release = ArrowArray_release, + }); + // dummy buffer + arrow_array_batch->buffers[1] = new int[1]; + + for (SQLUSMALLINT col = 0; col < numCols; col++) { + auto arrow_array_col = new ArrowArray({ + .length = static_cast(numRowsFetched), + .null_count = 0, + .offset = 0, + .n_buffers = 2, + .n_children = 0, + .buffers = new const void* [3], + .children = nullptr, + .release = ArrowArray_release, + }); + // Allocate new memory and copy the data + int* data_copy = new int[numRowsFetched]; + std::memcpy(data_copy, buffersArrow.intBuffers[col].data(), + numRowsFetched * sizeof(int)); + arrow_array_col->buffers[1] = data_copy; + + // TODO Make sure to free in release callback! + arrow_array_batch->children[col] = arrow_array_col; + } + + capsules.append(py::capsule((void*)arrow_array_batch, "arrow_array", [](void* ptr) { + delete static_cast(ptr); + })); + + // Reset attributes before returning to avoid using stack pointers later SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); From 7c5b22d6692654ebae00bba4fc48340d5b782442 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 11 Nov 2025 22:58:33 +0100 Subject: [PATCH 06/36] working nulls --- mssql_python/pybind/ddbc_bindings.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 3f709de7a..c4c447ed4 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -160,7 +160,7 @@ struct ColumnBuffersArrow { // std::vector> dateBuffers; // std::vector> timeBuffers; // std::vector> guidBuffers; - std::vector> indicators; + std::vector> valid; // std::vector> datetimeoffsetBuffers; ColumnBuffersArrow(SQLSMALLINT numCols, int fetchSize) @@ -177,7 +177,7 @@ struct ColumnBuffersArrow { // timeBuffers(numCols), // guidBuffers(numCols), // datetimeoffsetBuffers(numCols), - indicators(numCols, std::vector(fetchSize)) {} + valid(numCols, std::vector(fetchSize)) {} }; #ifndef ARROW_C_DATA_INTERFACE @@ -4059,6 +4059,9 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) ThrowStdException(errorString.str()); break; } + buffersArrow.valid[i].resize(fetchSize / 64 + 1); + // Initialize validity bitmap to all valid + std::memset(buffersArrow.valid[i].data(), 0xFF, buffersArrow.valid[i].size()); } assert(lobColumns.empty() && "Arrow batch fetch does not support LOB columns yet"); @@ -4115,7 +4118,11 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) if (dataLen == SQL_NO_TOTAL) { assert(false && "Is this actually possible?"); } else if (dataLen == SQL_NULL_DATA) { - assert(false && "TODO"); + // Mark as null in validity bitmap + size_t bytePos = i / 8; + size_t bitPos = i % 8; + buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + continue; } else if (dataLen < 0) { // Negative value is unexpected, log column index, SQL type & raise exception LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); @@ -4172,6 +4179,12 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) numRowsFetched * sizeof(int)); arrow_array_col->buffers[1] = data_copy; + // Allocate & copy validity bitmap + size_t validityBitmapSize = buffersArrow.valid[col].size(); + uint8_t* validity_copy = new uint8_t[validityBitmapSize]; + std::memcpy(validity_copy, buffersArrow.valid[col].data(), validityBitmapSize); + arrow_array_col->buffers[0] = validity_copy; + // TODO Make sure to free in release callback! arrow_array_batch->children[col] = arrow_array_col; } From 15fa431908d9941552e2b29bb4aa3db5a403e8e7 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Thu, 13 Nov 2025 00:16:13 +0100 Subject: [PATCH 07/36] Free arrow memory --- mssql_python/cursor.py | 4 +- mssql_python/pybind/ddbc_bindings.cpp | 70 +++++++++++++++++++++------ 2 files changed, 57 insertions(+), 17 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index e5a1da7fa..90d26681d 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2208,11 +2208,11 @@ def fetch_arrow_batch(self) -> Any: ) from e capsules = [] ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules) - print(ret) + assert ret == 0 schema_capsule = capsules[0] array_capsule = capsules[1] batch = pa.RecordBatch._import_from_c_capsule(schema_capsule, array_capsule) - return batch, capsules + return batch def nextset(self) -> Union[bool, None]: """ diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index c4c447ed4..4d108c0b8 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3992,15 +3992,46 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch } void ArrowSchema_release(struct ArrowSchema* schema) { - if (schema->release) { - schema->release = nullptr; + assert (schema != nullptr); + assert (schema->release != nullptr); + schema->release = nullptr; + delete[] schema->name; + for (int i = 0; i < schema->n_children; i++) { + assert (schema->children != nullptr); + if (schema->children[i]) { + schema->children[i]->release(schema->children[i]); + delete schema->children[i]; + } } + delete[] schema->children; + delete[] schema->format; } void ArrowArray_release(struct ArrowArray* array) { - if (array->release) { - array->release = nullptr; + assert (array != nullptr); + assert (array->release != nullptr); + array->release = nullptr; + + uint32_t buffers_freed = 0; + uint32_t current_buffer = 0; + while (buffers_freed < array->n_buffers) { + if (array->buffers[current_buffer]) { + delete[] array->buffers[current_buffer]; + buffers_freed++; + } + current_buffer++; + assert (current_buffer <= 3); + } + delete[] array->buffers; + + for (int i = 0; i < array->n_children; i++) { + assert (array->children != nullptr); + assert (array->children[i] != nullptr); + array->children[i]->release(array->children[i]); + delete array->children[i]; } + delete[] array->children; + } SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) { @@ -4037,11 +4068,10 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) assert(dataType == SQL_INTEGER && "Only INTEGER type is supported in Arrow batch fetch for now"); std::string columnName = colMeta["ColumnName"].cast(); - char* name_copy = strdup(columnName.c_str()); auto arrow_schema = new ArrowSchema({ - .format = "i", - .name = name_copy, + .format = strdup("i"), + .name = strdup(columnName.c_str()), .release = ArrowSchema_release, }); batch_children[i] = arrow_schema; @@ -4067,15 +4097,18 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) assert(lobColumns.empty() && "Arrow batch fetch does not support LOB columns yet"); auto arrow_schema_batch = new ArrowSchema({ - .format = "+s", - .name = "", + .format = strdup("+s"), + .name = strdup(""), .n_children = numCols, .children = batch_children, .release = ArrowSchema_release, }); - // auto caps = py::capsule((void*)arrow_schema, "arrow_schema", nullptr); auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { - delete static_cast(ptr); + auto arrow_schema = static_cast(ptr); + if (arrow_schema->release) { + arrow_schema->release(arrow_schema); + } + delete arrow_schema; }); capsules.append(caps); @@ -4150,12 +4183,13 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) - + auto arrow_array_batch_buffers = new const void* [3]; + memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); auto arrow_array_batch = new ArrowArray({ .length = static_cast(numRowsFetched), .n_buffers = 1, .n_children = numCols, - .buffers = new const void* [3], + .buffers = arrow_array_batch_buffers, .children = new ArrowArray* [numCols], .release = ArrowArray_release, }); @@ -4163,13 +4197,15 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) arrow_array_batch->buffers[1] = new int[1]; for (SQLUSMALLINT col = 0; col < numCols; col++) { + auto arrow_array_col_buffers = new const void* [3]; + memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); auto arrow_array_col = new ArrowArray({ .length = static_cast(numRowsFetched), .null_count = 0, .offset = 0, .n_buffers = 2, .n_children = 0, - .buffers = new const void* [3], + .buffers = arrow_array_col_buffers, .children = nullptr, .release = ArrowArray_release, }); @@ -4190,7 +4226,11 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) } capsules.append(py::capsule((void*)arrow_array_batch, "arrow_array", [](void* ptr) { - delete static_cast(ptr); + auto arrow_array = static_cast(ptr); + if (arrow_array->release) { + arrow_array->release(arrow_array); + } + delete arrow_array; })); From 31b418d2dfc0a27ce9ed183b8055131818e2e3fd Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 14 Nov 2025 22:12:30 +0100 Subject: [PATCH 08/36] try adding more datatypes --- mssql_python/pybind/ddbc_bindings.cpp | 88 ++++++++++++++++++++------- 1 file changed, 67 insertions(+), 21 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 4d108c0b8..795078b0c 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -151,13 +151,13 @@ struct NumericData { struct ColumnBuffersArrow { // std::vector> charBuffers; // std::vector> wcharBuffers; - std::vector> intBuffers; + std::vector> intBuffers; // std::vector> smallIntBuffers; // std::vector> realBuffers; - // std::vector> doubleBuffers; + std::vector> doubleBuffers; // std::vector> timestampBuffers; - // std::vector> bigIntBuffers; - // std::vector> dateBuffers; + std::vector> bigIntBuffers; + // std::vector> dateBuffers; // std::vector> timeBuffers; // std::vector> guidBuffers; std::vector> valid; @@ -170,9 +170,9 @@ struct ColumnBuffersArrow { intBuffers(numCols), // smallIntBuffers(numCols), // realBuffers(numCols), - // doubleBuffers(numCols), + doubleBuffers(numCols), // timestampBuffers(numCols), - // bigIntBuffers(numCols), + bigIntBuffers(numCols), // dateBuffers(numCols), // timeBuffers(numCols), // guidBuffers(numCols), @@ -4016,7 +4016,7 @@ void ArrowArray_release(struct ArrowArray* array) { uint32_t current_buffer = 0; while (buffers_freed < array->n_buffers) { if (array->buffers[current_buffer]) { - delete[] array->buffers[current_buffer]; + free((void*)array->buffers[current_buffer]); buffers_freed++; } current_buffer++; @@ -4065,21 +4065,22 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) lobColumns.push_back(i + 1); // 1-based } - assert(dataType == SQL_INTEGER && "Only INTEGER type is supported in Arrow batch fetch for now"); - std::string columnName = colMeta["ColumnName"].cast(); - auto arrow_schema = new ArrowSchema({ - .format = strdup("i"), - .name = strdup(columnName.c_str()), - .release = ArrowSchema_release, - }); - batch_children[i] = arrow_schema; - + char* format = nullptr; switch(dataType) { case SQL_INTEGER: + format = strdup("i"); buffersArrow.intBuffers[i].resize(fetchSize); break; + case SQL_DOUBLE: + format = strdup("g"); + buffersArrow.doubleBuffers[i].resize(fetchSize); + break; + case SQL_BIGINT: + format = strdup("l"); + buffersArrow.bigIntBuffers[i].resize(fetchSize); + break; default: std::wstring columnName = colMeta["ColumnName"].cast(); std::ostringstream errorString; @@ -4089,6 +4090,14 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) ThrowStdException(errorString.str()); break; } + + auto arrow_schema = new ArrowSchema({ + .format = format, + .name = strdup(columnName.c_str()), + .release = ArrowSchema_release, + }); + batch_children[i] = arrow_schema; + buffersArrow.valid[i].resize(fetchSize / 64 + 1); // Initialize validity bitmap to all valid std::memset(buffersArrow.valid[i].data(), 0xFF, buffersArrow.valid[i].size()); @@ -4168,6 +4177,14 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) buffersArrow.intBuffers[col - 1][i] = buffers.intBuffers[col - 1][i]; break; } + case SQL_DOUBLE: { + buffersArrow.doubleBuffers[col - 1][i] = buffers.doubleBuffers[col - 1][i]; + break; + } + case SQL_BIGINT: { + buffersArrow.bigIntBuffers[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; + break; + } default: { std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; @@ -4197,6 +4214,8 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) arrow_array_batch->buffers[1] = new int[1]; for (SQLUSMALLINT col = 0; col < numCols; col++) { + auto columnMeta = columnNames[col].cast(); + SQLSMALLINT dataType = columnMeta["DataType"].cast(); auto arrow_array_col_buffers = new const void* [3]; memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); auto arrow_array_col = new ArrowArray({ @@ -4210,10 +4229,38 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) .release = ArrowArray_release, }); // Allocate new memory and copy the data - int* data_copy = new int[numRowsFetched]; - std::memcpy(data_copy, buffersArrow.intBuffers[col].data(), - numRowsFetched * sizeof(int)); - arrow_array_col->buffers[1] = data_copy; + switch (dataType) { + case SQL_INTEGER: { + int* data_copy = new int[numRowsFetched]; + std::memcpy(data_copy, buffersArrow.intBuffers[col].data(), + numRowsFetched * sizeof(int)); + arrow_array_col->buffers[1] = data_copy; + break; + } + case SQL_DOUBLE: { + double* data_copy = new double[numRowsFetched]; + std::memcpy(data_copy, buffersArrow.doubleBuffers[col].data(), + numRowsFetched * sizeof(double)); + arrow_array_col->buffers[1] = data_copy; + break; + } + case SQL_BIGINT: { + int64_t* data_copy = new int64_t[numRowsFetched]; + std::memcpy(data_copy, buffersArrow.bigIntBuffers[col].data(), + numRowsFetched * sizeof(int64_t)); + arrow_array_col->buffers[1] = data_copy; + break; + } + default: { + std::wstring columnName = columnMeta["ColumnName"].cast(); + std::ostringstream errorString; + errorString << "Unsupported data type for column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " << (col + 1); + LOG(errorString.str()); + ThrowStdException(errorString.str()); + break; + } + } // Allocate & copy validity bitmap size_t validityBitmapSize = buffersArrow.valid[col].size(); @@ -4221,7 +4268,6 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) std::memcpy(validity_copy, buffersArrow.valid[col].data(), validityBitmapSize); arrow_array_col->buffers[0] = validity_copy; - // TODO Make sure to free in release callback! arrow_array_batch->children[col] = arrow_array_col; } From 3f5b335c328553119fd0dfc2e5e5f0a9e404a3db Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 14 Nov 2025 22:53:08 +0100 Subject: [PATCH 09/36] unique pointers for arrow array --- mssql_python/pybind/ddbc_bindings.cpp | 60 ++++++++++----------------- 1 file changed, 23 insertions(+), 37 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 795078b0c..cfdde2203 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -149,19 +149,19 @@ struct NumericData { // Struct to hold data buffers and indicators for each column struct ColumnBuffersArrow { - // std::vector> charBuffers; - // std::vector> wcharBuffers; - std::vector> intBuffers; - // std::vector> smallIntBuffers; - // std::vector> realBuffers; - std::vector> doubleBuffers; - // std::vector> timestampBuffers; - std::vector> bigIntBuffers; - // std::vector> dateBuffers; - // std::vector> timeBuffers; - // std::vector> guidBuffers; - std::vector> valid; - // std::vector> datetimeoffsetBuffers; + // std::vector> charBuffers; + // std::vector> wcharBuffers; + std::vector> intBuffers; + // std::vector> smallIntBuffers; + // std::vector> realBuffers; + std::vector> doubleBuffers; + // std::vector> timestampBuffers; + std::vector> bigIntBuffers; + // std::vector> dateBuffers; + // std::vector> timeBuffers; + // std::vector> guidBuffers; + std::vector> valid; + // std::vector> datetimeoffsetBuffers; ColumnBuffersArrow(SQLSMALLINT numCols, int fetchSize) : @@ -177,7 +177,7 @@ struct ColumnBuffersArrow { // timeBuffers(numCols), // guidBuffers(numCols), // datetimeoffsetBuffers(numCols), - valid(numCols, std::vector(fetchSize)) {} + valid(numCols) {} }; #ifndef ARROW_C_DATA_INTERFACE @@ -4071,15 +4071,15 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) switch(dataType) { case SQL_INTEGER: format = strdup("i"); - buffersArrow.intBuffers[i].resize(fetchSize); + buffersArrow.intBuffers[i] = std::make_unique(fetchSize); break; case SQL_DOUBLE: format = strdup("g"); - buffersArrow.doubleBuffers[i].resize(fetchSize); + buffersArrow.doubleBuffers[i] = std::make_unique(fetchSize); break; case SQL_BIGINT: format = strdup("l"); - buffersArrow.bigIntBuffers[i].resize(fetchSize); + buffersArrow.bigIntBuffers[i] = std::make_unique(fetchSize); break; default: std::wstring columnName = colMeta["ColumnName"].cast(); @@ -4098,9 +4098,9 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) }); batch_children[i] = arrow_schema; - buffersArrow.valid[i].resize(fetchSize / 64 + 1); + buffersArrow.valid[i] = std::make_unique((fetchSize + 7) / 8); // Initialize validity bitmap to all valid - std::memset(buffersArrow.valid[i].data(), 0xFF, buffersArrow.valid[i].size()); + std::memset(buffersArrow.valid[i].get(), 0xFF, (fetchSize + 7) / 8); } assert(lobColumns.empty() && "Arrow batch fetch does not support LOB columns yet"); @@ -4231,24 +4231,15 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) // Allocate new memory and copy the data switch (dataType) { case SQL_INTEGER: { - int* data_copy = new int[numRowsFetched]; - std::memcpy(data_copy, buffersArrow.intBuffers[col].data(), - numRowsFetched * sizeof(int)); - arrow_array_col->buffers[1] = data_copy; + arrow_array_col->buffers[1] = buffersArrow.intBuffers[col].release(); break; } case SQL_DOUBLE: { - double* data_copy = new double[numRowsFetched]; - std::memcpy(data_copy, buffersArrow.doubleBuffers[col].data(), - numRowsFetched * sizeof(double)); - arrow_array_col->buffers[1] = data_copy; + arrow_array_col->buffers[1] = buffersArrow.doubleBuffers[col].release(); break; } case SQL_BIGINT: { - int64_t* data_copy = new int64_t[numRowsFetched]; - std::memcpy(data_copy, buffersArrow.bigIntBuffers[col].data(), - numRowsFetched * sizeof(int64_t)); - arrow_array_col->buffers[1] = data_copy; + arrow_array_col->buffers[1] = buffersArrow.bigIntBuffers[col].release(); break; } default: { @@ -4262,12 +4253,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) } } - // Allocate & copy validity bitmap - size_t validityBitmapSize = buffersArrow.valid[col].size(); - uint8_t* validity_copy = new uint8_t[validityBitmapSize]; - std::memcpy(validity_copy, buffersArrow.valid[col].data(), validityBitmapSize); - arrow_array_col->buffers[0] = validity_copy; - + arrow_array_col->buffers[0] = buffersArrow.valid[col].release(); arrow_array_batch->children[col] = arrow_array_col; } From 724eab96113e766a916e500b912d287ecfcf3b98 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 14 Nov 2025 23:12:23 +0100 Subject: [PATCH 10/36] more arrow like buffersArrow names --- mssql_python/pybind/ddbc_bindings.cpp | 78 ++++++++++++++------------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index cfdde2203..f42630e30 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -149,35 +149,37 @@ struct NumericData { // Struct to hold data buffers and indicators for each column struct ColumnBuffersArrow { - // std::vector> charBuffers; - // std::vector> wcharBuffers; - std::vector> intBuffers; - // std::vector> smallIntBuffers; - // std::vector> realBuffers; - std::vector> doubleBuffers; - // std::vector> timestampBuffers; - std::vector> bigIntBuffers; - // std::vector> dateBuffers; - // std::vector> timeBuffers; - // std::vector> guidBuffers; + std::vector> uint8; + std::vector> int16; + std::vector> int32; + std::vector> int64; + std::vector> float64; + std::vector> bit; + std::vector> varlen; + std::vector> date; + std::vector> ts_micro; + std::vector> time_nano; + std::vector> decimal; + std::vector> valid; - // std::vector> datetimeoffsetBuffers; + std::vector> var_data; - ColumnBuffersArrow(SQLSMALLINT numCols, int fetchSize) + ColumnBuffersArrow(SQLSMALLINT numCols) : - // : charBuffers(numCols), - // wcharBuffers(numCols), - intBuffers(numCols), - // smallIntBuffers(numCols), - // realBuffers(numCols), - doubleBuffers(numCols), - // timestampBuffers(numCols), - bigIntBuffers(numCols), - // dateBuffers(numCols), - // timeBuffers(numCols), - // guidBuffers(numCols), - // datetimeoffsetBuffers(numCols), - valid(numCols) {} + uint8(numCols), + int16(numCols), + int32(numCols), + int64(numCols), + float64(numCols), + bit(numCols), + varlen(numCols), + date(numCols), + ts_micro(numCols), + time_nano(numCols), + decimal(numCols), + + valid(numCols), + var_data(numCols){} }; #ifndef ARROW_C_DATA_INTERFACE @@ -4052,7 +4054,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) auto batch_children = new ArrowSchema* [numCols]; std::vector lobColumns; - ColumnBuffersArrow buffersArrow(numCols, fetchSize); + ColumnBuffersArrow buffersArrow(numCols); for (SQLSMALLINT i = 0; i < numCols; i++) { auto colMeta = columnNames[i].cast(); SQLSMALLINT dataType = colMeta["DataType"].cast(); @@ -4071,15 +4073,19 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) switch(dataType) { case SQL_INTEGER: format = strdup("i"); - buffersArrow.intBuffers[i] = std::make_unique(fetchSize); + buffersArrow.int32[i] = std::make_unique(fetchSize); break; case SQL_DOUBLE: format = strdup("g"); - buffersArrow.doubleBuffers[i] = std::make_unique(fetchSize); + buffersArrow.float64[i] = std::make_unique(fetchSize); break; case SQL_BIGINT: format = strdup("l"); - buffersArrow.bigIntBuffers[i] = std::make_unique(fetchSize); + buffersArrow.int64[i] = std::make_unique(fetchSize); + break; + case SQL_DATE: + format = strdup("tdD"); + buffersArrow.date[i] = std::make_unique(fetchSize); break; default: std::wstring columnName = colMeta["ColumnName"].cast(); @@ -4174,15 +4180,15 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) switch (dataType) { case SQL_INTEGER: { - buffersArrow.intBuffers[col - 1][i] = buffers.intBuffers[col - 1][i]; + buffersArrow.int32[col - 1][i] = buffers.intBuffers[col - 1][i]; break; } case SQL_DOUBLE: { - buffersArrow.doubleBuffers[col - 1][i] = buffers.doubleBuffers[col - 1][i]; + buffersArrow.float64[col - 1][i] = buffers.doubleBuffers[col - 1][i]; break; } case SQL_BIGINT: { - buffersArrow.bigIntBuffers[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; + buffersArrow.int64[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; break; } default: { @@ -4231,15 +4237,15 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) // Allocate new memory and copy the data switch (dataType) { case SQL_INTEGER: { - arrow_array_col->buffers[1] = buffersArrow.intBuffers[col].release(); + arrow_array_col->buffers[1] = buffersArrow.int32[col].release(); break; } case SQL_DOUBLE: { - arrow_array_col->buffers[1] = buffersArrow.doubleBuffers[col].release(); + arrow_array_col->buffers[1] = buffersArrow.float64[col].release(); break; } case SQL_BIGINT: { - arrow_array_col->buffers[1] = buffersArrow.bigIntBuffers[col].release(); + arrow_array_col->buffers[1] = buffersArrow.int64[col].release(); break; } default: { From 24c2c5e9f8a8da48a891c2b7d04521da7799d2c4 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 14 Nov 2025 23:41:28 +0100 Subject: [PATCH 11/36] Add all formats/buffer allocs --- mssql_python/pybind/ddbc_bindings.cpp | 94 +++++++++++++++++++++++++-- 1 file changed, 87 insertions(+), 7 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f42630e30..167de0a3b 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -155,7 +155,7 @@ struct ColumnBuffersArrow { std::vector> int64; std::vector> float64; std::vector> bit; - std::vector> varlen; + std::vector> var; std::vector> date; std::vector> ts_micro; std::vector> time_nano; @@ -163,6 +163,7 @@ struct ColumnBuffersArrow { std::vector> valid; std::vector> var_data; + std::vector var_data_len; ColumnBuffersArrow(SQLSMALLINT numCols) : @@ -172,14 +173,16 @@ struct ColumnBuffersArrow { int64(numCols), float64(numCols), bit(numCols), - varlen(numCols), + var(numCols), date(numCols), ts_micro(numCols), time_nano(numCols), decimal(numCols), valid(numCols), - var_data(numCols){} + var_data(numCols), + // initialize lengths to 0 + var_data_len(numCols, 0){} }; #ifndef ARROW_C_DATA_INTERFACE @@ -4071,22 +4074,81 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) char* format = nullptr; switch(dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + format = strdup("u"); + buffersArrow.var[i] = std::make_unique(fetchSize); + buffersArrow.var_data[i] = std::make_unique(fetchSize * 42); + buffersArrow.var_data_len[i] = 42; + break; + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + format = strdup("z"); + buffersArrow.var[i] = std::make_unique(fetchSize); + buffersArrow.var_data[i] = std::make_unique(fetchSize * 42); + buffersArrow.var_data_len[i] = 42; + case SQL_TINYINT: + format = strdup("C"); + buffersArrow.uint8[i] = std::make_unique(fetchSize); + break; + case SQL_SMALLINT: + format = strdup("s"); + buffersArrow.int16[i] = std::make_unique(fetchSize); + break; case SQL_INTEGER: format = strdup("i"); buffersArrow.int32[i] = std::make_unique(fetchSize); break; + case SQL_BIGINT: + format = strdup("l"); + buffersArrow.int64[i] = std::make_unique(fetchSize); + break; + case SQL_REAL: + case SQL_FLOAT: case SQL_DOUBLE: format = strdup("g"); buffersArrow.float64[i] = std::make_unique(fetchSize); break; - case SQL_BIGINT: - format = strdup("l"); - buffersArrow.int64[i] = std::make_unique(fetchSize); + case SQL_DECIMAL: + case SQL_NUMERIC: { + std::ostringstream formatStream; + formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); + std::string formatStr = formatStream.str(); + format = strdup(formatStr.c_str()); + buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(fetchSize); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: + format = strdup("tsu"); + buffersArrow.ts_micro[i] = std::make_unique(fetchSize); + break; + case SQL_SS_TIMESTAMPOFFSET: + format = strdup("tsu:+00:00"); + buffersArrow.ts_micro[i] = std::make_unique(fetchSize); break; - case SQL_DATE: + case SQL_TYPE_DATE: format = strdup("tdD"); buffersArrow.date[i] = std::make_unique(fetchSize); break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + format = strdup("ttu"); + buffersArrow.time_nano[i] = std::make_unique(fetchSize); + break; + case SQL_BIT: + format = strdup("b"); + buffersArrow.bit[i] = std::make_unique((fetchSize + 7) / 8); + break; default: std::wstring columnName = colMeta["ColumnName"].cast(); std::ostringstream errorString; @@ -4191,6 +4253,24 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) buffersArrow.int64[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; break; } + // case SQL_DATE: { + // // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) + // SQL_DATE_STRUCT sqlDate = buffers.dateBuffers[col - 1][i]; + // std::tm tm_date = {}; + // tm_date.tm_year = sqlDate.year - 1900; // tm_year is years since 1900 + // tm_date.tm_mon = sqlDate.month - 1; // tm_mon is 0-11 + // tm_date.tm_mday = sqlDate.day; + + // std::time_t time_since_epoch = std::mktime(&tm_date); + // if (time_since_epoch == -1) { + // LOG("Failed to convert SQL_DATE_STRUCT to time_t for Column ID - {}", col); + // ThrowStdException("Date conversion error, check logs for details"); + // } + // // Calculate days since epoch + // int32_t days_since_epoch = static_cast(time_since_epoch / 86400); + // buffersArrow.date32[col - 1][i] = days_since_epoch; + // break; + // } default: { std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; From 84b4b787acc7257bf618751617a9d88b2ad12977 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 14 Nov 2025 23:48:07 +0100 Subject: [PATCH 12/36] Add ownership -> arrow transfer for all --- mssql_python/pybind/ddbc_bindings.cpp | 55 ++++++++++++++++++++++++--- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 167de0a3b..c82bfdf77 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4316,18 +4316,61 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) }); // Allocate new memory and copy the data switch (dataType) { - case SQL_INTEGER: { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + arrow_array_col->buffers[1] = buffersArrow.var[col].release(); + arrow_array_col->buffers[2] = buffersArrow.var_data[col].release(); + break; + case SQL_TINYINT: + arrow_array_col->buffers[1] = buffersArrow.uint8[col].release(); + break; + case SQL_SMALLINT: + arrow_array_col->buffers[1] = buffersArrow.int16[col].release(); + break; + case SQL_INTEGER: arrow_array_col->buffers[1] = buffersArrow.int32[col].release(); break; - } - case SQL_DOUBLE: { + case SQL_BIGINT: + arrow_array_col->buffers[1] = buffersArrow.int64[col].release(); + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: arrow_array_col->buffers[1] = buffersArrow.float64[col].release(); break; - } - case SQL_BIGINT: { - arrow_array_col->buffers[1] = buffersArrow.int64[col].release(); + case SQL_DECIMAL: + case SQL_NUMERIC: { + arrow_array_col->buffers[1] = buffersArrow.decimal[col].release(); break; } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: + arrow_array_col->buffers[1] = buffersArrow.ts_micro[col].release(); + break; + case SQL_SS_TIMESTAMPOFFSET: + arrow_array_col->buffers[1] = buffersArrow.ts_micro[col].release(); + break; + case SQL_TYPE_DATE: + arrow_array_col->buffers[1] = buffersArrow.date[col].release(); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + arrow_array_col->buffers[1] = buffersArrow.time_nano[col].release(); + break; + case SQL_BIT: + arrow_array_col->buffers[1] = buffersArrow.bit[col].release(); + break; default: { std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; From a78f66d6df12fd9c105343d31e14aa513601add5 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 14 Nov 2025 23:54:05 +0100 Subject: [PATCH 13/36] working date --- mssql_python/pybind/ddbc_bindings.cpp | 36 +++++++++++++-------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index c82bfdf77..f721967c3 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4253,24 +4253,24 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) buffersArrow.int64[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; break; } - // case SQL_DATE: { - // // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) - // SQL_DATE_STRUCT sqlDate = buffers.dateBuffers[col - 1][i]; - // std::tm tm_date = {}; - // tm_date.tm_year = sqlDate.year - 1900; // tm_year is years since 1900 - // tm_date.tm_mon = sqlDate.month - 1; // tm_mon is 0-11 - // tm_date.tm_mday = sqlDate.day; - - // std::time_t time_since_epoch = std::mktime(&tm_date); - // if (time_since_epoch == -1) { - // LOG("Failed to convert SQL_DATE_STRUCT to time_t for Column ID - {}", col); - // ThrowStdException("Date conversion error, check logs for details"); - // } - // // Calculate days since epoch - // int32_t days_since_epoch = static_cast(time_since_epoch / 86400); - // buffersArrow.date32[col - 1][i] = days_since_epoch; - // break; - // } + case SQL_TYPE_DATE: { + // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) + SQL_DATE_STRUCT sqlDate = buffers.dateBuffers[col - 1][i]; + std::tm tm_date = {}; + tm_date.tm_year = sqlDate.year - 1900; // tm_year is years since 1900 + tm_date.tm_mon = sqlDate.month - 1; // tm_mon is 0-11 + tm_date.tm_mday = sqlDate.day; + + std::time_t time_since_epoch = std::mktime(&tm_date); + if (time_since_epoch == -1) { + LOG("Failed to convert SQL_DATE_STRUCT to time_t for Column ID - {}", col); + ThrowStdException("Date conversion error, check logs for details"); + } + // Calculate days since epoch + int32_t days_since_epoch = static_cast(time_since_epoch / 86400); + buffersArrow.date[col - 1][i] = days_since_epoch; + break; + } default: { std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; From 1dbd833b5004980e8ebba2ba6047d1f033b579d9 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 15 Nov 2025 00:41:06 +0100 Subject: [PATCH 14/36] working timestamp(offset) --- mssql_python/pybind/ddbc_bindings.cpp | 72 ++++++++++++++++++++------- 1 file changed, 55 insertions(+), 17 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f721967c3..834d4cfaa 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4039,6 +4039,22 @@ void ArrowArray_release(struct ArrowArray* array) { } +int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) { + // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) + std::tm tm_date = {}; + tm_date.tm_year = year - 1900; // tm_year is years since 1900 + tm_date.tm_mon = month - 1; // tm_mon is 0-11 + tm_date.tm_mday = day; + + std::time_t time_since_epoch = std::mktime(&tm_date); + if (time_since_epoch == -1) { + LOG("Failed to convert SQL_DATE_STRUCT to time_t"); + ThrowStdException("Date conversion error"); + } + // Calculate days since epoch + return time_since_epoch / 86400; +} + SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) { ssize_t fetchSize = 500; SQLRETURN ret; @@ -4128,7 +4144,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: - format = strdup("tsu"); + format = strdup("tsu:"); buffersArrow.ts_micro[i] = std::make_unique(fetchSize); break; case SQL_SS_TIMESTAMPOFFSET: @@ -4253,22 +4269,44 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) buffersArrow.int64[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; break; } - case SQL_TYPE_DATE: { - // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) - SQL_DATE_STRUCT sqlDate = buffers.dateBuffers[col - 1][i]; - std::tm tm_date = {}; - tm_date.tm_year = sqlDate.year - 1900; // tm_year is years since 1900 - tm_date.tm_mon = sqlDate.month - 1; // tm_mon is 0-11 - tm_date.tm_mday = sqlDate.day; - - std::time_t time_since_epoch = std::mktime(&tm_date); - if (time_since_epoch == -1) { - LOG("Failed to convert SQL_DATE_STRUCT to time_t for Column ID - {}", col); - ThrowStdException("Date conversion error, check logs for details"); - } - // Calculate days since epoch - int32_t days_since_epoch = static_cast(time_since_epoch / 86400); - buffersArrow.date[col - 1][i] = days_since_epoch; + case SQL_TYPE_DATE: + buffersArrow.date[col - 1][i] = dateAsDayCount( + buffers.dateBuffers[col - 1][i].year, + buffers.dateBuffers[col - 1][i].month, + buffers.dateBuffers[col - 1][i].day + ); + break; + + case SQL_SS_TIMESTAMPOFFSET: { + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][i]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][i] = + days * 86400 * 1000000 + + (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][i]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][i] = + days * 86400 * 1000000 + + static_cast(sql_value.hour) * 3600 * 1000000 + + static_cast(sql_value.minute) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; break; } default: { From cafd1a8750d5926f5bb28f4534f5d538f9a50d24 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 15 Nov 2025 20:48:32 +0100 Subject: [PATCH 15/36] working wchar --- mssql_python/pybind/ddbc_bindings.cpp | 103 +++++++++++++++++--------- 1 file changed, 68 insertions(+), 35 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 834d4cfaa..16f35b4ef 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -162,8 +162,7 @@ struct ColumnBuffersArrow { std::vector> decimal; std::vector> valid; - std::vector> var_data; - std::vector var_data_len; + std::vector> var_data; ColumnBuffersArrow(SQLSMALLINT numCols) : @@ -180,9 +179,7 @@ struct ColumnBuffersArrow { decimal(numCols), valid(numCols), - var_data(numCols), - // initialize lengths to 0 - var_data_len(numCols, 0){} + var_data(numCols) {} }; #ifndef ARROW_C_DATA_INTERFACE @@ -4100,16 +4097,19 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_GUID: format = strdup("u"); buffersArrow.var[i] = std::make_unique(fetchSize); - buffersArrow.var_data[i] = std::make_unique(fetchSize * 42); - buffersArrow.var_data_len[i] = 42; + buffersArrow.var_data[i].resize(fetchSize * 42); + // start at offset 0 + buffersArrow.var_data[i][0] = 0; break; case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: format = strdup("z"); buffersArrow.var[i] = std::make_unique(fetchSize); - buffersArrow.var_data[i] = std::make_unique(fetchSize * 42); - buffersArrow.var_data_len[i] = 42; + buffersArrow.var_data[i].resize(fetchSize * 42); + // start at offset 0 + buffersArrow.var_data[i][0] = 0; + break; case SQL_TINYINT: format = strdup("C"); buffersArrow.uint8[i] = std::make_unique(fetchSize); @@ -4254,9 +4254,34 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); ThrowStdException("Unexpected negative data length, check logs for details"); } - assert(dataLen > 0 && "Data length must be > 0"); + assert(dataLen >= 0 && "Data length must be >= 0"); switch (dataType) { + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + auto wcharSource = &buffers.wcharBuffers[col - 1][i]; + auto start = buffersArrow.var[col - 1][i]; + auto target_vec = &buffersArrow.var_data[col - 1]; +#if defined(_WIN32) + // Convert wide string + int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLen, NULL, 0, NULL, NULL); + while (target_vec->size() < start + dataLenConverted) { + target_vec->resize(target_vec->size() * 2); + } + WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLen, &(*target_vec)[start], dataLenConverted, NULL, NULL); + buffersArrow.var[col - 1][i + 1] = start + dataLenConverted; +#else + // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 + std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLen)); + std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); + buffersArrow.var[col - 1][i + 1] = start + utf8str.size(); + // debug print results + std::cout << "UTF-8 string: " << utf8str << " " << utf8str.size() << std::endl; +#endif + break; + } case SQL_INTEGER: { buffersArrow.int32[col - 1][i] = buffers.intBuffers[col - 1][i]; break; @@ -4342,16 +4367,6 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) SQLSMALLINT dataType = columnMeta["DataType"].cast(); auto arrow_array_col_buffers = new const void* [3]; memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); - auto arrow_array_col = new ArrowArray({ - .length = static_cast(numRowsFetched), - .null_count = 0, - .offset = 0, - .n_buffers = 2, - .n_children = 0, - .buffers = arrow_array_col_buffers, - .children = nullptr, - .release = ArrowArray_release, - }); // Allocate new memory and copy the data switch (dataType) { case SQL_CHAR: @@ -4364,50 +4379,57 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_GUID: case SQL_BINARY: case SQL_VARBINARY: - case SQL_LONGVARBINARY: - arrow_array_col->buffers[1] = buffersArrow.var[col].release(); - arrow_array_col->buffers[2] = buffersArrow.var_data[col].release(); - break; + case SQL_LONGVARBINARY: { + assert(buffersArrow.var[col][0] == 0); + // length of string at index i is the difference between values at i and i+1 + // so total length is value at index numRowsFetched + auto data_buf_len_total = buffersArrow.var[col][numRowsFetched]; + uint8_t* dataBuffer = new uint8_t[data_buf_len_total]; + std::memcpy(dataBuffer, buffersArrow.var_data[col].data(), data_buf_len_total); + arrow_array_col_buffers[2] = dataBuffer; + arrow_array_col_buffers[1] = buffersArrow.var[col].release(); + } + break; case SQL_TINYINT: - arrow_array_col->buffers[1] = buffersArrow.uint8[col].release(); + arrow_array_col_buffers[1] = buffersArrow.uint8[col].release(); break; case SQL_SMALLINT: - arrow_array_col->buffers[1] = buffersArrow.int16[col].release(); + arrow_array_col_buffers[1] = buffersArrow.int16[col].release(); break; case SQL_INTEGER: - arrow_array_col->buffers[1] = buffersArrow.int32[col].release(); + arrow_array_col_buffers[1] = buffersArrow.int32[col].release(); break; case SQL_BIGINT: - arrow_array_col->buffers[1] = buffersArrow.int64[col].release(); + arrow_array_col_buffers[1] = buffersArrow.int64[col].release(); break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: - arrow_array_col->buffers[1] = buffersArrow.float64[col].release(); + arrow_array_col_buffers[1] = buffersArrow.float64[col].release(); break; case SQL_DECIMAL: case SQL_NUMERIC: { - arrow_array_col->buffers[1] = buffersArrow.decimal[col].release(); + arrow_array_col_buffers[1] = buffersArrow.decimal[col].release(); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: - arrow_array_col->buffers[1] = buffersArrow.ts_micro[col].release(); + arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); break; case SQL_SS_TIMESTAMPOFFSET: - arrow_array_col->buffers[1] = buffersArrow.ts_micro[col].release(); + arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); break; case SQL_TYPE_DATE: - arrow_array_col->buffers[1] = buffersArrow.date[col].release(); + arrow_array_col_buffers[1] = buffersArrow.date[col].release(); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: - arrow_array_col->buffers[1] = buffersArrow.time_nano[col].release(); + arrow_array_col_buffers[1] = buffersArrow.time_nano[col].release(); break; case SQL_BIT: - arrow_array_col->buffers[1] = buffersArrow.bit[col].release(); + arrow_array_col_buffers[1] = buffersArrow.bit[col].release(); break; default: { std::wstring columnName = columnMeta["ColumnName"].cast(); @@ -4420,6 +4442,17 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) } } + auto arrow_array_col = new ArrowArray({ + .length = static_cast(numRowsFetched), + .null_count = 0, + .offset = 0, + .n_buffers = arrow_array_col_buffers[2] ? 3 : 2, + .n_children = 0, + .buffers = arrow_array_col_buffers, + .children = nullptr, + .release = ArrowArray_release, + }); + arrow_array_col->buffers[0] = buffersArrow.valid[col].release(); arrow_array_batch->children[col] = arrow_array_col; } From 382115659df7c607dc2300e3da987a2788a8b5c2 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 15 Nov 2025 20:54:27 +0100 Subject: [PATCH 16/36] add placeholder asserts --- mssql_python/pybind/ddbc_bindings.cpp | 78 ++++++++++++++++++--------- 1 file changed, 54 insertions(+), 24 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 16f35b4ef..5d15ce992 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4257,6 +4257,11 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) assert(dataLen >= 0 && "Data length must be >= 0"); switch (dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + assert(0 && "TODO"); + break; case SQL_SS_XML: case SQL_WCHAR: case SQL_WVARCHAR: @@ -4282,28 +4287,40 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) #endif break; } - case SQL_INTEGER: { + case SQL_GUID: + assert(0 && "TODO"); + break; + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + assert(0 && "TODO"); + break; + case SQL_TINYINT: + assert(0 && "TODO"); + break; + case SQL_SMALLINT: + assert(0 && "TODO"); + break; + case SQL_INTEGER: buffersArrow.int32[col - 1][i] = buffers.intBuffers[col - 1][i]; break; - } - case SQL_DOUBLE: { + case SQL_BIGINT: + buffersArrow.int64[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: buffersArrow.float64[col - 1][i] = buffers.doubleBuffers[col - 1][i]; break; - } - case SQL_BIGINT: { - buffersArrow.int64[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; + case SQL_DECIMAL: + case SQL_NUMERIC: { + assert(0 && "TODO"); break; } - case SQL_TYPE_DATE: - buffersArrow.date[col - 1][i] = dateAsDayCount( - buffers.dateBuffers[col - 1][i].year, - buffers.dateBuffers[col - 1][i].month, - buffers.dateBuffers[col - 1][i].day - ); - break; - - case SQL_SS_TIMESTAMPOFFSET: { - DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][i]; + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][i]; int64_t days = dateAsDayCount( sql_value.year, sql_value.month, @@ -4311,16 +4328,14 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) ); buffersArrow.ts_micro[col - 1][i] = days * 86400 * 1000000 + - (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + - (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + + static_cast(sql_value.hour) * 3600 * 1000000 + + static_cast(sql_value.minute) * 60 * 1000000 + static_cast(sql_value.second) * 1000000 + static_cast(sql_value.fraction) / 1000; break; } - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: { - SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][i]; + case SQL_SS_TIMESTAMPOFFSET: { + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][i]; int64_t days = dateAsDayCount( sql_value.year, sql_value.month, @@ -4328,12 +4343,27 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) ); buffersArrow.ts_micro[col - 1][i] = days * 86400 * 1000000 + - static_cast(sql_value.hour) * 3600 * 1000000 + - static_cast(sql_value.minute) * 60 * 1000000 + + (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + static_cast(sql_value.second) * 1000000 + static_cast(sql_value.fraction) / 1000; break; } + case SQL_TYPE_DATE: + buffersArrow.date[col - 1][i] = dateAsDayCount( + buffers.dateBuffers[col - 1][i].year, + buffers.dateBuffers[col - 1][i].month, + buffers.dateBuffers[col - 1][i].day + ); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + assert(0 && "TODO"); + break; + case SQL_BIT: + assert(0 && "TODO"); + break; default: { std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; From 8454e9fecc057759f3fc11cb4122ffdc1f29a1fb Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 15 Nov 2025 21:04:03 +0100 Subject: [PATCH 17/36] implement char/binary --- mssql_python/pybind/ddbc_bindings.cpp | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 5d15ce992..5ca049eae 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4257,11 +4257,21 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) assert(dataLen >= 0 && "Data length must be >= 0"); switch (dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: case SQL_CHAR: case SQL_VARCHAR: - case SQL_LONGVARCHAR: - assert(0 && "TODO"); - break; + case SQL_LONGVARCHAR: { + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][i]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][i], dataLen); + buffersArrow.var[col - 1][i + 1] = start + dataLen; + break; + } case SQL_SS_XML: case SQL_WCHAR: case SQL_WVARCHAR: @@ -4290,11 +4300,6 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_GUID: assert(0 && "TODO"); break; - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: - assert(0 && "TODO"); - break; case SQL_TINYINT: assert(0 && "TODO"); break; From 9889d4a838f2242a2f7164f3014b18ac511113b3 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 15 Nov 2025 21:26:53 +0100 Subject: [PATCH 18/36] add guids --- mssql_python/pybind/ddbc_bindings.cpp | 30 +++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 5ca049eae..1db655662 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4297,9 +4297,35 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) #endif break; } - case SQL_GUID: - assert(0 && "TODO"); + case SQL_GUID: { + // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") + // Each GUID is exactly 36 bytes in UTF-8 + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][i]; + + // Ensure buffer has space for the GUID string + null terminator + while (target_vec->size() < start + 37) { + target_vec->resize(target_vec->size() * 2); + } + + // Get the GUID from the buffer + const SQLGUID& guidValue = buffers.guidBuffers[col - 1][i]; + + // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + snprintf(reinterpret_cast(&target_vec->data()[start]), 37, + "%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", + guidValue.Data1, + guidValue.Data2, + guidValue.Data3, + guidValue.Data4[0], guidValue.Data4[1], + guidValue.Data4[2], guidValue.Data4[3], + guidValue.Data4[4], guidValue.Data4[5], + guidValue.Data4[6], guidValue.Data4[7]); + + // Update offset for next row, ignoring null terminator + buffersArrow.var[col - 1][i + 1] = start + 36; break; + } case SQL_TINYINT: assert(0 && "TODO"); break; From 0e00daee737c841411e1ea77cfbe99eb5ffbf2fb Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 15 Nov 2025 23:52:46 +0100 Subject: [PATCH 19/36] add bit, time --- mssql_python/cursor.py | 2 +- mssql_python/pybind/ddbc_bindings.cpp | 43 ++++++++++++++++++++------- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 90d26681d..0bf8f8a3b 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2208,7 +2208,7 @@ def fetch_arrow_batch(self) -> Any: ) from e capsules = [] ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules) - assert ret == 0 + assert ret in (0, 1), ret schema_capsule = capsules[0] array_capsule = capsules[1] batch = pa.RecordBatch._import_from_c_capsule(schema_capsule, array_capsule) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 1db655662..454e8cc5a 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -158,7 +158,7 @@ struct ColumnBuffersArrow { std::vector> var; std::vector> date; std::vector> ts_micro; - std::vector> time_nano; + std::vector> time_second; std::vector> decimal; std::vector> valid; @@ -175,7 +175,7 @@ struct ColumnBuffersArrow { var(numCols), date(numCols), ts_micro(numCols), - time_nano(numCols), + time_second(numCols), decimal(numCols), valid(numCols), @@ -4158,8 +4158,8 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: - format = strdup("ttu"); - buffersArrow.time_nano[i] = std::make_unique(fetchSize); + format = strdup("tts"); + buffersArrow.time_second[i] = std::make_unique(fetchSize); break; case SQL_BIT: format = strdup("b"); @@ -4327,10 +4327,10 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) break; } case SQL_TINYINT: - assert(0 && "TODO"); + buffersArrow.uint8[col - 1][i] = buffers.charBuffers[col - 1][i]; break; case SQL_SMALLINT: - assert(0 && "TODO"); + buffersArrow.int16[col - 1][i] = buffers.smallIntBuffers[col - 1][i]; break; case SQL_INTEGER: buffersArrow.int32[col - 1][i] = buffers.intBuffers[col - 1][i]; @@ -4389,12 +4389,33 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) break; case SQL_TIME: case SQL_TYPE_TIME: - case SQL_SS_TIME2: - assert(0 && "TODO"); + case SQL_SS_TIME2: { + // TODO wrong ctype for SQL_SS_TIME2 + const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][i]; + buffersArrow.time_second[col - 1][i] = + static_cast(timeValue.hour) * 3600 + + static_cast(timeValue.minute) * 60 + + static_cast(timeValue.second); break; - case SQL_BIT: - assert(0 && "TODO"); + } + case SQL_BIT: { + // SQL_BIT is stored as a single bit in Arrow's bitmap format + // Get the boolean value from the buffer + bool bitValue = buffers.charBuffers[col - 1][i] != 0; + + // Set the bit in the Arrow bitmap + size_t byteIndex = i / 8; + size_t bitIndex = i % 8; + + if (bitValue) { + // Set bit to 1 + buffersArrow.bit[col - 1][byteIndex] |= (1 << bitIndex); + } else { + // Clear bit to 0 + buffersArrow.bit[col - 1][byteIndex] &= ~(1 << bitIndex); + } break; + } default: { std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; @@ -4487,7 +4508,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: - arrow_array_col_buffers[1] = buffersArrow.time_nano[col].release(); + arrow_array_col_buffers[1] = buffersArrow.time_second[col].release(); break; case SQL_BIT: arrow_array_col_buffers[1] = buffersArrow.bit[col].release(); From e86ce43e0e6736a399e6114487f43cbb5f24334b Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 16 Nov 2025 00:20:39 +0100 Subject: [PATCH 20/36] fix string length issues --- mssql_python/pybind/ddbc_bindings.cpp | 31 ++++++++++++++++++++------- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 454e8cc5a..3c53ecf40 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4237,6 +4237,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) for (SQLUSMALLINT col = 1; col <= numCols; col++) { auto columnMeta = columnNames[col - 1].cast(); SQLSMALLINT dataType = columnMeta["DataType"].cast(); + SQLULEN columnSize = columnMeta["ColumnSize"].cast(); SQLLEN dataLen = buffers.indicators[col - 1][i]; // TODO: variable length data needs special handling, this logic wont suffice @@ -4259,16 +4260,29 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) switch (dataType) { case SQL_BINARY: case SQL_VARBINARY: - case SQL_LONGVARBINARY: + case SQL_LONGVARBINARY: { + uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][i]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][i * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][i + 1] = start + dataLen; + break; + } case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { + uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; auto target_vec = &buffersArrow.var_data[col - 1]; auto start = buffersArrow.var[col - 1][i]; while (target_vec->size() < start + dataLen) { target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][i], dataLen); + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][i * fetchBufferSize], dataLen); buffersArrow.var[col - 1][i + 1] = start + dataLen; break; } @@ -4276,24 +4290,25 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { - auto wcharSource = &buffers.wcharBuffers[col - 1][i]; + // uint64_t fetchBufferSize = (columnSize + 1) * sizeof(SQLWCHAR); // +1 for null terminator + assert(dataLen % sizeof(SQLWCHAR) == 0); + auto dataLenW = dataLen / sizeof(SQLWCHAR); + auto wcharSource = &buffers.wcharBuffers[col - 1][i * (columnSize + 1)]; auto start = buffersArrow.var[col - 1][i]; auto target_vec = &buffersArrow.var_data[col - 1]; #if defined(_WIN32) // Convert wide string - int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLen, NULL, 0, NULL, NULL); + int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); while (target_vec->size() < start + dataLenConverted) { target_vec->resize(target_vec->size() * 2); } - WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLen, &(*target_vec)[start], dataLenConverted, NULL, NULL); + WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); buffersArrow.var[col - 1][i + 1] = start + dataLenConverted; #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 - std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLen)); + std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); buffersArrow.var[col - 1][i + 1] = start + utf8str.size(); - // debug print results - std::cout << "UTF-8 string: " << utf8str << " " << utf8str.size() << std::endl; #endif break; } From 99d4bb7a4bcbded7ec0f377edfa9a4e893643ed0 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Thu, 20 Nov 2025 21:24:10 +0100 Subject: [PATCH 21/36] Adapt LOG change --- mssql_python/pybind/ddbc_bindings.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 3c53ecf40..3997ad76e 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4170,7 +4170,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) std::ostringstream errorString; errorString << "Unsupported data type for Arrow batch fetch for column - " << columnName.c_str() << ", Type - " << dataType << ", column ID - " << (i + 1); - LOG(errorString.str()); + LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); break; } @@ -4436,7 +4436,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) std::ostringstream errorString; errorString << "Unsupported data type for column - " << columnName.c_str() << ", Type - " << dataType << ", column ID - " << col; - LOG(errorString.str()); + LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); break; } @@ -4533,7 +4533,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) std::ostringstream errorString; errorString << "Unsupported data type for column - " << columnName.c_str() << ", Type - " << dataType << ", column ID - " << (col + 1); - LOG(errorString.str()); + LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); break; } From d641970e3409f9f4d12ad12cfb1cb5d1bad68d80 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 21 Nov 2025 14:46:34 +0100 Subject: [PATCH 22/36] fix var nulls --- mssql_python/pybind/ddbc_bindings.cpp | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 3997ad76e..fd4b34b55 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4249,6 +4249,24 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) size_t bytePos = i / 8; size_t bitPos = i % 8; buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + switch (dataType) + { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + buffersArrow.var[col - 1][i + 1] = buffersArrow.var[col - 1][i]; + break; + default: + break; + } continue; } else if (dataLen < 0) { // Negative value is unexpected, log column index, SQL type & raise exception From 46047f9492166231900a4e6bf77172901898af2d Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:15:23 +0100 Subject: [PATCH 23/36] Add numeric --- mssql_python/pybind/ddbc_bindings.cpp | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index fd4b34b55..e2f9aca6f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4240,7 +4240,6 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) SQLULEN columnSize = columnMeta["ColumnSize"].cast(); SQLLEN dataLen = buffers.indicators[col - 1][i]; - // TODO: variable length data needs special handling, this logic wont suffice // This value indicates that the driver cannot determine the length of the data if (dataLen == SQL_NO_TOTAL) { assert(false && "Is this actually possible?"); @@ -4249,6 +4248,9 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) size_t bytePos = i / 8; size_t bitPos = i % 8; buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + + // Value buffer for variable length data types needs to be set appropriately + // as it will be used by the next non null value switch (dataType) { case SQL_CHAR: @@ -4308,7 +4310,6 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { - // uint64_t fetchBufferSize = (columnSize + 1) * sizeof(SQLWCHAR); // +1 for null terminator assert(dataLen % sizeof(SQLWCHAR) == 0); auto dataLenW = dataLen / sizeof(SQLWCHAR); auto wcharSource = &buffers.wcharBuffers[col - 1][i * (columnSize + 1)]; @@ -4378,7 +4379,20 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) break; case SQL_DECIMAL: case SQL_NUMERIC: { - assert(0 && "TODO"); + assert(dataLen <= MAX_DIGITS_IN_NUMERIC); + __int128_t decimalValue = 0; + auto start = i * MAX_DIGITS_IN_NUMERIC; + for (SQLULEN idx = start; idx < start + dataLen; idx++) { + char digitChar = buffers.charBuffers[col - 1][idx]; + if (digitChar == '-') { + decimalValue = -decimalValue; + } else if (digitChar >= '0' && digitChar <= '9') { + decimalValue = decimalValue * 10 + (digitChar - '0'); + } + std::cout << idx << ":" << digitChar << " "; + } + std::cout << std::endl; + buffersArrow.decimal[col - 1][i] = decimalValue; break; } case SQL_TIMESTAMP: From 087fcc384c7d15344db3589b3212f817c7d56297 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Fri, 21 Nov 2025 17:07:36 +0100 Subject: [PATCH 24/36] Separate arrowBatchSize and fetchSize --- mssql_python/pybind/ddbc_bindings.cpp | 517 +++++++++++++------------- 1 file changed, 260 insertions(+), 257 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index e2f9aca6f..ec3d4d898 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4053,7 +4053,8 @@ int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) } SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) { - ssize_t fetchSize = 500; + ssize_t arrowBatchSize = 500; + ssize_t fetchSize = 1; SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); // Retrieve column count @@ -4096,8 +4097,8 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_WLONGVARCHAR: case SQL_GUID: format = strdup("u"); - buffersArrow.var[i] = std::make_unique(fetchSize); - buffersArrow.var_data[i].resize(fetchSize * 42); + buffersArrow.var[i] = std::make_unique(arrowBatchSize); + buffersArrow.var_data[i].resize(arrowBatchSize * 42); // start at offset 0 buffersArrow.var_data[i][0] = 0; break; @@ -4105,32 +4106,32 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_VARBINARY: case SQL_LONGVARBINARY: format = strdup("z"); - buffersArrow.var[i] = std::make_unique(fetchSize); - buffersArrow.var_data[i].resize(fetchSize * 42); + buffersArrow.var[i] = std::make_unique(arrowBatchSize); + buffersArrow.var_data[i].resize(arrowBatchSize * 42); // start at offset 0 buffersArrow.var_data[i][0] = 0; break; case SQL_TINYINT: format = strdup("C"); - buffersArrow.uint8[i] = std::make_unique(fetchSize); + buffersArrow.uint8[i] = std::make_unique(arrowBatchSize); break; case SQL_SMALLINT: format = strdup("s"); - buffersArrow.int16[i] = std::make_unique(fetchSize); + buffersArrow.int16[i] = std::make_unique(arrowBatchSize); break; case SQL_INTEGER: format = strdup("i"); - buffersArrow.int32[i] = std::make_unique(fetchSize); + buffersArrow.int32[i] = std::make_unique(arrowBatchSize); break; case SQL_BIGINT: format = strdup("l"); - buffersArrow.int64[i] = std::make_unique(fetchSize); + buffersArrow.int64[i] = std::make_unique(arrowBatchSize); break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: format = strdup("g"); - buffersArrow.float64[i] = std::make_unique(fetchSize); + buffersArrow.float64[i] = std::make_unique(arrowBatchSize); break; case SQL_DECIMAL: case SQL_NUMERIC: { @@ -4138,32 +4139,32 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); std::string formatStr = formatStream.str(); format = strdup(formatStr.c_str()); - buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(fetchSize); + buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(arrowBatchSize); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: format = strdup("tsu:"); - buffersArrow.ts_micro[i] = std::make_unique(fetchSize); + buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); break; case SQL_SS_TIMESTAMPOFFSET: format = strdup("tsu:+00:00"); - buffersArrow.ts_micro[i] = std::make_unique(fetchSize); + buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); break; case SQL_TYPE_DATE: format = strdup("tdD"); - buffersArrow.date[i] = std::make_unique(fetchSize); + buffersArrow.date[i] = std::make_unique(arrowBatchSize); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: format = strdup("tts"); - buffersArrow.time_second[i] = std::make_unique(fetchSize); + buffersArrow.time_second[i] = std::make_unique(arrowBatchSize); break; case SQL_BIT: format = strdup("b"); - buffersArrow.bit[i] = std::make_unique((fetchSize + 7) / 8); + buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); break; default: std::wstring columnName = colMeta["ColumnName"].cast(); @@ -4182,9 +4183,9 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) }); batch_children[i] = arrow_schema; - buffersArrow.valid[i] = std::make_unique((fetchSize + 7) / 8); + buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); // Initialize validity bitmap to all valid - std::memset(buffersArrow.valid[i].get(), 0xFF, (fetchSize + 7) / 8); + std::memset(buffersArrow.valid[i].get(), 0xFF, (arrowBatchSize + 7) / 8); } assert(lobColumns.empty() && "Arrow batch fetch does not support LOB columns yet"); @@ -4220,259 +4221,261 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); + size_t idxRowArrow = 0; + assert(arrowBatchSize % fetchSize == 0); + while (idxRowArrow < arrowBatchSize) { + ret = SQLFetch_ptr(hStmt); + if (ret == SQL_NO_DATA) { + ret = SQL_SUCCESS; // Normal completion + break; + } + if (!SQL_SUCCEEDED(ret)) { + LOG("Error while fetching rows in batches"); + return ret; + } + // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. + // It'll be populated by SQLFetch + assert(numRowsFetched + idxRowArrow <= static_cast(arrowBatchSize)); + for (SQLULEN idxRowSql = 0; idxRowSql < numRowsFetched; idxRowSql++) { + for (SQLUSMALLINT col = 1; col <= numCols; col++) { + auto columnMeta = columnNames[col - 1].cast(); + SQLSMALLINT dataType = columnMeta["DataType"].cast(); + SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + SQLLEN dataLen = buffers.indicators[col - 1][idxRowSql]; + + // This value indicates that the driver cannot determine the length of the data + if (dataLen == SQL_NO_TOTAL) { + assert(false && "Is this actually possible?"); + } else if (dataLen == SQL_NULL_DATA) { + // Mark as null in validity bitmap + size_t bytePos = idxRowArrow / 8; + size_t bitPos = idxRowArrow % 8; + buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + + // Value buffer for variable length data types needs to be set appropriately + // as it will be used by the next non null value + switch (dataType) + { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + buffersArrow.var[col - 1][idxRowArrow + 1] = buffersArrow.var[col - 1][idxRowArrow]; + break; + default: + break; + } + continue; + } else if (dataLen < 0) { + // Negative value is unexpected, log column index, SQL type & raise exception + LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); + ThrowStdException("Unexpected negative data length, check logs for details"); + } + assert(dataLen >= 0 && "Data length must be >= 0"); - - ret = SQLFetchScroll_ptr(hStmt, SQL_FETCH_NEXT, 0); - if (ret == SQL_NO_DATA) { - LOG("No data to fetch"); - return ret; - } - if (!SQL_SUCCEEDED(ret)) { - LOG("Error while fetching rows in batches"); - return ret; - } - // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. It'll be populated by - // SQLFetchScroll - for (SQLULEN i = 0; i < numRowsFetched; i++) { - for (SQLUSMALLINT col = 1; col <= numCols; col++) { - auto columnMeta = columnNames[col - 1].cast(); - SQLSMALLINT dataType = columnMeta["DataType"].cast(); - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); - SQLLEN dataLen = buffers.indicators[col - 1][i]; + switch (dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } - // This value indicates that the driver cannot determine the length of the data - if (dataLen == SQL_NO_TOTAL) { - assert(false && "Is this actually possible?"); - } else if (dataLen == SQL_NULL_DATA) { - // Mark as null in validity bitmap - size_t bytePos = i / 8; - size_t bitPos = i % 8; - buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); - - // Value buffer for variable length data types needs to be set appropriately - // as it will be used by the next non null value - switch (dataType) - { + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + break; + } case SQL_CHAR: case SQL_VARCHAR: - case SQL_LONGVARCHAR: + case SQL_LONGVARCHAR: { + uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + break; + } case SQL_SS_XML: case SQL_WCHAR: case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: - case SQL_GUID: - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: - buffersArrow.var[col - 1][i + 1] = buffersArrow.var[col - 1][i]; - break; - default: + case SQL_WLONGVARCHAR: { + assert(dataLen % sizeof(SQLWCHAR) == 0); + auto dataLenW = dataLen / sizeof(SQLWCHAR); + auto wcharSource = &buffers.wcharBuffers[col - 1][idxRowSql * (columnSize + 1)]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &buffersArrow.var_data[col - 1]; + #if defined(_WIN32) + // Convert wide string + int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); + while (target_vec->size() < start + dataLenConverted) { + target_vec->resize(target_vec->size() * 2); + } + WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); + buffersArrow.var[col - 1][i + 1] = start + dataLenConverted; + #else + // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 + std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); + std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); + #endif break; - } - continue; - } else if (dataLen < 0) { - // Negative value is unexpected, log column index, SQL type & raise exception - LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); - ThrowStdException("Unexpected negative data length, check logs for details"); - } - assert(dataLen >= 0 && "Data length must be >= 0"); - - switch (dataType) { - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: { - uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col - 1][i]; - while (target_vec->size() < start + dataLen) { - target_vec->resize(target_vec->size() * 2); } + case SQL_GUID: { + // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") + // Each GUID is exactly 36 bytes in UTF-8 + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + + // Ensure buffer has space for the GUID string + null terminator + while (target_vec->size() < start + 37) { + target_vec->resize(target_vec->size() * 2); + } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][i * fetchBufferSize], dataLen); - buffersArrow.var[col - 1][i + 1] = start + dataLen; - break; - } - case SQL_CHAR: - case SQL_VARCHAR: - case SQL_LONGVARCHAR: { - uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col - 1][i]; - while (target_vec->size() < start + dataLen) { - target_vec->resize(target_vec->size() * 2); + // Get the GUID from the buffer + const SQLGUID& guidValue = buffers.guidBuffers[col - 1][idxRowSql]; + + // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + snprintf(reinterpret_cast(&target_vec->data()[start]), 37, + "%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", + guidValue.Data1, + guidValue.Data2, + guidValue.Data3, + guidValue.Data4[0], guidValue.Data4[1], + guidValue.Data4[2], guidValue.Data4[3], + guidValue.Data4[4], guidValue.Data4[5], + guidValue.Data4[6], guidValue.Data4[7]); + + // Update offset for next row, ignoring null terminator + buffersArrow.var[col - 1][idxRowArrow + 1] = start + 36; + break; } - - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][i * fetchBufferSize], dataLen); - buffersArrow.var[col - 1][i + 1] = start + dataLen; - break; - } - case SQL_SS_XML: - case SQL_WCHAR: - case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: { - assert(dataLen % sizeof(SQLWCHAR) == 0); - auto dataLenW = dataLen / sizeof(SQLWCHAR); - auto wcharSource = &buffers.wcharBuffers[col - 1][i * (columnSize + 1)]; - auto start = buffersArrow.var[col - 1][i]; - auto target_vec = &buffersArrow.var_data[col - 1]; -#if defined(_WIN32) - // Convert wide string - int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); - while (target_vec->size() < start + dataLenConverted) { - target_vec->resize(target_vec->size() * 2); + case SQL_TINYINT: + buffersArrow.uint8[col - 1][idxRowArrow] = buffers.charBuffers[col - 1][idxRowSql]; + break; + case SQL_SMALLINT: + buffersArrow.int16[col - 1][idxRowArrow] = buffers.smallIntBuffers[col - 1][idxRowSql]; + break; + case SQL_INTEGER: + buffersArrow.int32[col - 1][idxRowArrow] = buffers.intBuffers[col - 1][idxRowSql]; + break; + case SQL_BIGINT: + buffersArrow.int64[col - 1][idxRowArrow] = buffers.bigIntBuffers[col - 1][idxRowSql]; + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + buffersArrow.float64[col - 1][idxRowArrow] = buffers.doubleBuffers[col - 1][idxRowSql]; + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + assert(dataLen <= MAX_DIGITS_IN_NUMERIC); + __int128_t decimalValue = 0; + auto start = idxRowSql * MAX_DIGITS_IN_NUMERIC; + for (SQLULEN idx = start; idx < start + dataLen; idx++) { + char digitChar = buffers.charBuffers[col - 1][idx]; + if (digitChar == '-') { + decimalValue = -decimalValue; + } else if (digitChar >= '0' && digitChar <= '9') { + decimalValue = decimalValue * 10 + (digitChar - '0'); + } + } + buffersArrow.decimal[col - 1][idxRowArrow] = decimalValue; + break; } - WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); - buffersArrow.var[col - 1][i + 1] = start + dataLenConverted; -#else - // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 - std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); - std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); - buffersArrow.var[col - 1][i + 1] = start + utf8str.size(); -#endif - break; - } - case SQL_GUID: { - // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") - // Each GUID is exactly 36 bytes in UTF-8 - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col - 1][i]; - - // Ensure buffer has space for the GUID string + null terminator - while (target_vec->size() < start + 37) { - target_vec->resize(target_vec->size() * 2); + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][idxRowArrow] = + days * 86400 * 1000000 + + static_cast(sql_value.hour) * 3600 * 1000000 + + static_cast(sql_value.minute) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; } - - // Get the GUID from the buffer - const SQLGUID& guidValue = buffers.guidBuffers[col - 1][i]; - - // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx - snprintf(reinterpret_cast(&target_vec->data()[start]), 37, - "%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", - guidValue.Data1, - guidValue.Data2, - guidValue.Data3, - guidValue.Data4[0], guidValue.Data4[1], - guidValue.Data4[2], guidValue.Data4[3], - guidValue.Data4[4], guidValue.Data4[5], - guidValue.Data4[6], guidValue.Data4[7]); - - // Update offset for next row, ignoring null terminator - buffersArrow.var[col - 1][i + 1] = start + 36; - break; - } - case SQL_TINYINT: - buffersArrow.uint8[col - 1][i] = buffers.charBuffers[col - 1][i]; - break; - case SQL_SMALLINT: - buffersArrow.int16[col - 1][i] = buffers.smallIntBuffers[col - 1][i]; - break; - case SQL_INTEGER: - buffersArrow.int32[col - 1][i] = buffers.intBuffers[col - 1][i]; - break; - case SQL_BIGINT: - buffersArrow.int64[col - 1][i] = buffers.bigIntBuffers[col - 1][i]; - break; - case SQL_REAL: - case SQL_FLOAT: - case SQL_DOUBLE: - buffersArrow.float64[col - 1][i] = buffers.doubleBuffers[col - 1][i]; - break; - case SQL_DECIMAL: - case SQL_NUMERIC: { - assert(dataLen <= MAX_DIGITS_IN_NUMERIC); - __int128_t decimalValue = 0; - auto start = i * MAX_DIGITS_IN_NUMERIC; - for (SQLULEN idx = start; idx < start + dataLen; idx++) { - char digitChar = buffers.charBuffers[col - 1][idx]; - if (digitChar == '-') { - decimalValue = -decimalValue; - } else if (digitChar >= '0' && digitChar <= '9') { - decimalValue = decimalValue * 10 + (digitChar - '0'); + case SQL_SS_TIMESTAMPOFFSET: { + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][idxRowArrow] = + days * 86400 * 1000000 + + (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_TYPE_DATE: + buffersArrow.date[col - 1][idxRowArrow] = dateAsDayCount( + buffers.dateBuffers[col - 1][idxRowSql].year, + buffers.dateBuffers[col - 1][idxRowSql].month, + buffers.dateBuffers[col - 1][idxRowSql].day + ); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + // TODO wrong ctype for SQL_SS_TIME2 + const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][idxRowSql]; + buffersArrow.time_second[col - 1][idxRowArrow] = + static_cast(timeValue.hour) * 3600 + + static_cast(timeValue.minute) * 60 + + static_cast(timeValue.second); + break; + } + case SQL_BIT: { + // SQL_BIT is stored as a single bit in Arrow's bitmap format + // Get the boolean value from the buffer + bool bitValue = buffers.charBuffers[col - 1][idxRowSql] != 0; + + // Set the bit in the Arrow bitmap + size_t byteIndex = idxRowArrow / 8; + size_t bitIndex = idxRowArrow % 8; + + if (bitValue) { + // Set bit to 1 + buffersArrow.bit[col - 1][byteIndex] |= (1 << bitIndex); + } else { + // Clear bit to 0 + buffersArrow.bit[col - 1][byteIndex] &= ~(1 << bitIndex); } - std::cout << idx << ":" << digitChar << " "; + break; } - std::cout << std::endl; - buffersArrow.decimal[col - 1][i] = decimalValue; - break; - } - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: { - SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][i]; - int64_t days = dateAsDayCount( - sql_value.year, - sql_value.month, - sql_value.day - ); - buffersArrow.ts_micro[col - 1][i] = - days * 86400 * 1000000 + - static_cast(sql_value.hour) * 3600 * 1000000 + - static_cast(sql_value.minute) * 60 * 1000000 + - static_cast(sql_value.second) * 1000000 + - static_cast(sql_value.fraction) / 1000; - break; - } - case SQL_SS_TIMESTAMPOFFSET: { - DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][i]; - int64_t days = dateAsDayCount( - sql_value.year, - sql_value.month, - sql_value.day - ); - buffersArrow.ts_micro[col - 1][i] = - days * 86400 * 1000000 + - (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + - (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + - static_cast(sql_value.second) * 1000000 + - static_cast(sql_value.fraction) / 1000; - break; - } - case SQL_TYPE_DATE: - buffersArrow.date[col - 1][i] = dateAsDayCount( - buffers.dateBuffers[col - 1][i].year, - buffers.dateBuffers[col - 1][i].month, - buffers.dateBuffers[col - 1][i].day - ); - break; - case SQL_TIME: - case SQL_TYPE_TIME: - case SQL_SS_TIME2: { - // TODO wrong ctype for SQL_SS_TIME2 - const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][i]; - buffersArrow.time_second[col - 1][i] = - static_cast(timeValue.hour) * 3600 + - static_cast(timeValue.minute) * 60 + - static_cast(timeValue.second); - break; - } - case SQL_BIT: { - // SQL_BIT is stored as a single bit in Arrow's bitmap format - // Get the boolean value from the buffer - bool bitValue = buffers.charBuffers[col - 1][i] != 0; - - // Set the bit in the Arrow bitmap - size_t byteIndex = i / 8; - size_t bitIndex = i % 8; - - if (bitValue) { - // Set bit to 1 - buffersArrow.bit[col - 1][byteIndex] |= (1 << bitIndex); - } else { - // Clear bit to 0 - buffersArrow.bit[col - 1][byteIndex] &= ~(1 << bitIndex); + default: { + std::wstring columnName = columnMeta["ColumnName"].cast(); + std::ostringstream errorString; + errorString << "Unsupported data type for column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " << col; + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; } - break; - } - default: { - std::wstring columnName = columnMeta["ColumnName"].cast(); - std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; - LOG(errorString.str().c_str()); - ThrowStdException(errorString.str()); - break; } } + idxRowArrow++; } } @@ -4481,7 +4484,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) auto arrow_array_batch_buffers = new const void* [3]; memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); auto arrow_array_batch = new ArrowArray({ - .length = static_cast(numRowsFetched), + .length = static_cast(idxRowArrow), .n_buffers = 1, .n_children = numCols, .buffers = arrow_array_batch_buffers, @@ -4511,8 +4514,8 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) case SQL_LONGVARBINARY: { assert(buffersArrow.var[col][0] == 0); // length of string at index i is the difference between values at i and i+1 - // so total length is value at index numRowsFetched - auto data_buf_len_total = buffersArrow.var[col][numRowsFetched]; + // so total length is value at index idxRowArrow + auto data_buf_len_total = buffersArrow.var[col][idxRowArrow]; uint8_t* dataBuffer = new uint8_t[data_buf_len_total]; std::memcpy(dataBuffer, buffersArrow.var_data[col].data(), data_buf_len_total); arrow_array_col_buffers[2] = dataBuffer; @@ -4572,7 +4575,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) } auto arrow_array_col = new ArrowArray({ - .length = static_cast(numRowsFetched), + .length = static_cast(idxRowArrow), .null_count = 0, .offset = 0, .n_buffers = arrow_array_col_buffers[2] ? 3 : 2, From 506bfa7e9840889dc6d11def3dd5954d300b83ef Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 22 Nov 2025 16:30:40 +0100 Subject: [PATCH 25/36] Add Lob support to arrow fetch --- mssql_python/pybind/ddbc_bindings.cpp | 313 +++++++++++++++++++++++++- 1 file changed, 303 insertions(+), 10 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index ec3d4d898..6413b9b2e 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -3993,6 +3993,108 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; } +// GetDataVar - Progressively fetches variable-length column data using SQLGetData. +// +// Calls SQLGetData repeatedly, reallocating the buffer as needed, until all data is retrieved. +// Handles both fixed-size and unknown-size (SQL_NO_TOTAL) responses from the driver. +// +// @param hStmt: Statement handle +// @param colNumber: 1-based column index +// @param cType: SQL C data type (SQL_C_CHAR, SQL_C_WCHAR, or SQL_C_BINARY) +// @param dataVec: Reference to vector that will hold the fetched data (will be resized as needed) +// @param indicator: Pointer to indicator value (SQL_NULL_DATA for NULL, or data length) +// +// @return SQLRETURN: SQL_SUCCESS on success, or error code on failure +template +SQLRETURN GetDataVar(SQLHSTMT hStmt, + SQLUSMALLINT colNumber, + SQLSMALLINT cType, + std::vector& dataVec, + SQLLEN* indicator) { + if (!SQLGetData_ptr) { + ThrowStdException("SQLGetData function not loaded"); + } + + size_t start = 0; + size_t end = 0; + + // Determine null terminator size based on data type + size_t sizeNullTerminator = 0; + switch (cType) { + case SQL_C_WCHAR: + case SQL_C_CHAR: + sizeNullTerminator = 1; + break; + case SQL_C_BINARY: + sizeNullTerminator = 0; + break; + default: + ThrowStdException("GetDataVar only supports SQL_C_CHAR, SQL_C_WCHAR, and SQL_C_BINARY"); + } + + // Ensure initial buffer has space for at least the null terminator + if (dataVec.size() < sizeNullTerminator) { + dataVec.resize(sizeNullTerminator); + } + + while (true) { + SQLLEN localInd = 0; + SQLRETURN ret = SQLGetData_ptr( + hStmt, + colNumber, + cType, + reinterpret_cast(dataVec.data() + start), + sizeof(T) * (dataVec.size() - start), // Available buffer size from start position + &localInd + ); + + // Handle NULL data + if (localInd == SQL_NULL_DATA) { + *indicator = SQL_NULL_DATA; + return SQL_SUCCESS; + } + + // Check for errors (excluding SQL_SUCCESS_WITH_INFO which means more data available) + if (ret == SQL_ERROR || ret == SQL_INVALID_HANDLE) { + return ret; + } + + // SQL_SUCCESS or SQL_NO_DATA means we got all the data + if (ret == SQL_SUCCESS || ret == SQL_NO_DATA) { + if (localInd >= 0) { + *indicator = static_cast(start) * sizeof(T) + localInd; + } else { + *indicator = localInd; // Preserve SQL_NO_TOTAL or other negative values + } + break; + } + + // SQL_SUCCESS_WITH_INFO means buffer was too small, need to continue fetching + if (ret == SQL_SUCCESS_WITH_INFO) { + // Determine how much more space we need + if (localInd < 0) { + // SQL_NO_TOTAL: driver doesn't know total size, double the buffer + end = dataVec.size() * 2; + } else { + // Driver returned total size: allocate exactly what we need + assert(localInd % sizeof(T) == 0); + end = start + static_cast(localInd) / sizeof(T) + sizeNullTerminator; + } + + // The next read starts where the null terminator would have been placed + start = dataVec.size() - sizeNullTerminator; + + // Resize buffer for next iteration + dataVec.resize(end); + } else { + // Unexpected return code + return ret; + } + } + + return SQL_SUCCESS; +} + void ArrowSchema_release(struct ArrowSchema* schema) { assert (schema != nullptr); assert (schema->release != nullptr); @@ -4054,7 +4156,7 @@ int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) { ssize_t arrowBatchSize = 500; - ssize_t fetchSize = 1; + ssize_t fetchSize = arrowBatchSize; SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); // Retrieve column count @@ -4069,7 +4171,7 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) } auto batch_children = new ArrowSchema* [numCols]; - std::vector lobColumns; + bool hasLobColumns = false; ColumnBuffersArrow buffersArrow(numCols); for (SQLSMALLINT i = 0; i < numCols; i++) { @@ -4081,7 +4183,8 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { - lobColumns.push_back(i + 1); // 1-based + hasLobColumns = true; + fetchSize = 1; // LOBs require row-by-row fetch } std::string columnName = colMeta["ColumnName"].cast(); @@ -4188,8 +4291,6 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) std::memset(buffersArrow.valid[i].get(), 0xFF, (arrowBatchSize + 7) / 8); } - assert(lobColumns.empty() && "Arrow batch fetch does not support LOB columns yet"); - auto arrow_schema_batch = new ArrowSchema({ .format = strdup("+s"), .name = strdup(""), @@ -4209,11 +4310,13 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) // Initialize column buffers ColumnBuffers buffers(numCols, fetchSize); - // Bind columns - ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); - if (!SQL_SUCCEEDED(ret)) { - LOG("Error when binding columns"); - return ret; + if (!hasLobColumns) { + // Bind columns + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error when binding columns"); + return ret; + } } SQLULEN numRowsFetched; @@ -4241,6 +4344,196 @@ SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) auto columnMeta = columnNames[col - 1].cast(); SQLSMALLINT dataType = columnMeta["DataType"].cast(); SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + + if (hasLobColumns) { + assert(idxRowSql == 0 && "GetData only works one row at a time"); + + switch(dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + GetDataVar( + hStmt, + col, + SQL_C_BINARY, + buffers.charBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + GetDataVar( + hStmt, + col, + SQL_C_CHAR, + buffers.charBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + GetDataVar( + hStmt, + col, + SQL_C_WCHAR, + buffers.wcharBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_INTEGER: { + buffers.intBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SLONG, + buffers.intBuffers[col - 1].data(), + sizeof(SQLINTEGER), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SMALLINT: { + buffers.smallIntBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SSHORT, + buffers.smallIntBuffers[col - 1].data(), + sizeof(SQLSMALLINT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TINYINT: { + buffers.charBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TINYINT, + buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_BIT: { + buffers.charBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_BIT, + buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_REAL: { + buffers.realBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_FLOAT, + buffers.realBuffers[col - 1].data(), + sizeof(SQLREAL), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_DECIMAL: + case SQL_NUMERIC: { + buffers.charBuffers[col - 1].resize(MAX_DIGITS_IN_NUMERIC); + SQLGetData_ptr( + hStmt, col, SQL_C_CHAR, + buffers.charBuffers[col - 1].data(), + MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_DOUBLE: + case SQL_FLOAT: { + buffers.doubleBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_DOUBLE, + buffers.doubleBuffers[col - 1].data(), + sizeof(SQLDOUBLE), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + buffers.timestampBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[col - 1].data(), + sizeof(SQL_TIMESTAMP_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_BIGINT: { + buffers.bigIntBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SBIGINT, + buffers.bigIntBuffers[col - 1].data(), + sizeof(SQLBIGINT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TYPE_DATE: { + buffers.dateBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_DATE, + buffers.dateBuffers[col - 1].data(), + sizeof(SQL_DATE_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + buffers.timeBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_TIME, + buffers.timeBuffers[col - 1].data(), + sizeof(SQL_TIME_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_GUID: { + buffers.guidBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_GUID, + buffers.guidBuffers[col - 1].data(), + sizeof(SQLGUID), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + buffers.datetimeoffsetBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[col - 1].data(), + sizeof(DateTimeOffset), + buffers.indicators[col - 1].data() + ); + break; + } + default: { + std::wstring columnName = columnMeta["ColumnName"].cast(); + std::ostringstream errorString; + errorString << "Unsupported data type for column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " << col; + LOG("SQLGetData: %s", errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + } + SQLLEN dataLen = buffers.indicators[col - 1][idxRowSql]; // This value indicates that the driver cannot determine the length of the data From 884763047bbecaac6ecd858d34753a963f30009b Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 22 Nov 2025 16:35:29 +0100 Subject: [PATCH 26/36] Parameterize batch length --- mssql_python/cursor.py | 4 ++-- mssql_python/pybind/ddbc_bindings.cpp | 7 +++++-- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 0bf8f8a3b..c9c6aec86 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2195,7 +2195,7 @@ def fetchall(self) -> List[Row]: # On error, don't increment rownumber - rethrow the error raise e - def fetch_arrow_batch(self) -> Any: + def fetch_arrow_batch(self, batch_length: int) -> Any: self._check_closed() # Check if the cursor is closed if not self._has_result_set and self.description: self._reset_rownumber() @@ -2207,7 +2207,7 @@ def fetch_arrow_batch(self) -> Any: "pyarrow is required for fetch_arrow_batch(). Please install pyarrow." ) from e capsules = [] - ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules) + ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, batch_length) assert ret in (0, 1), ret schema_capsule = capsules[0] array_capsule = capsules[1] diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 6413b9b2e..d63ae97ff 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4154,8 +4154,11 @@ int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) return time_since_epoch / 86400; } -SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules) { - ssize_t arrowBatchSize = 500; +SQLRETURN FetchArrowBatch_wrap( + SqlHandlePtr StatementHandle, + py::list& capsules, + ssize_t arrowBatchSize +) { ssize_t fetchSize = arrowBatchSize; SQLRETURN ret; SQLHSTMT hStmt = StatementHandle->get(); From 1e89057f6c1bc4b4d4765d24b6684f1c3c0b3674 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 22 Nov 2025 17:39:58 +0100 Subject: [PATCH 27/36] Some fixes around length 0 behavior --- mssql_python/pybind/ddbc_bindings.cpp | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index d63ae97ff..426bd20fd 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4164,6 +4164,9 @@ SQLRETURN FetchArrowBatch_wrap( SQLHSTMT hStmt = StatementHandle->get(); // Retrieve column count SQLSMALLINT numCols = SQLNumResultCols_wrap(StatementHandle); + if (numCols <= 0) { + ThrowStdException("No active result set. Cannot fetch Arrow batch."); + } // Retrieve column metadata py::list columnNames; @@ -4203,19 +4206,19 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WLONGVARCHAR: case SQL_GUID: format = strdup("u"); - buffersArrow.var[i] = std::make_unique(arrowBatchSize); + buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); buffersArrow.var_data[i].resize(arrowBatchSize * 42); // start at offset 0 - buffersArrow.var_data[i][0] = 0; + buffersArrow.var[i][0] = 0; break; case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: format = strdup("z"); - buffersArrow.var[i] = std::make_unique(arrowBatchSize); + buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); buffersArrow.var_data[i].resize(arrowBatchSize * 42); // start at offset 0 - buffersArrow.var_data[i][0] = 0; + buffersArrow.var[i][0] = 0; break; case SQL_TINYINT: format = strdup("C"); @@ -4321,14 +4324,14 @@ SQLRETURN FetchArrowBatch_wrap( return ret; } } - + SQLULEN numRowsFetched; SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); size_t idxRowArrow = 0; - assert(arrowBatchSize % fetchSize == 0); + assert(fetchSize == 0 || arrowBatchSize % fetchSize == 0); while (idxRowArrow < arrowBatchSize) { ret = SQLFetch_ptr(hStmt); if (ret == SQL_NO_DATA) { From 5acffd9af5054ca73741c5f996dd3fb1fa1efaa0 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 23 Nov 2025 00:11:22 +0100 Subject: [PATCH 28/36] Use vector instead of py dict for hot loop, tweak fetchSize --- mssql_python/pybind/ddbc_bindings.cpp | 46 +++++++++++++++++++-------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 426bd20fd..f9afbb1d3 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4179,11 +4179,16 @@ SQLRETURN FetchArrowBatch_wrap( auto batch_children = new ArrowSchema* [numCols]; bool hasLobColumns = false; + std::vector dataTypes(numCols); + std::vector columnSizes(numCols); + ColumnBuffersArrow buffersArrow(numCols); for (SQLSMALLINT i = 0; i < numCols; i++) { auto colMeta = columnNames[i].cast(); SQLSMALLINT dataType = colMeta["DataType"].cast(); SQLULEN columnSize = colMeta["ColumnSize"].cast(); + dataTypes[i] = dataType; + columnSizes[i] = columnSize; if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || @@ -4313,6 +4318,20 @@ SQLRETURN FetchArrowBatch_wrap( }); capsules.append(caps); + if (fetchSize > 1) { + // An overly large fetch size doesn't seem to help performance + SQLSMALLINT searchStart = 64; + if (arrowBatchSize < 64) { + searchStart = static_cast(arrowBatchSize); + } + for (SQLSMALLINT maybeNewSize = searchStart; maybeNewSize >= 1; maybeNewSize -= 1) { + if (arrowBatchSize % maybeNewSize == 0) { + fetchSize = maybeNewSize; + break; + } + } + } + // Initialize column buffers ColumnBuffers buffers(numCols, fetchSize); @@ -4331,7 +4350,11 @@ SQLRETURN FetchArrowBatch_wrap( size_t idxRowArrow = 0; + // arrowBatchSize % fetchSize == 0 ensures that any followup (even non-arrow) fetches + // start with a fresh batch assert(fetchSize == 0 || arrowBatchSize % fetchSize == 0); + assert(fetchSize <= arrowBatchSize); + while (idxRowArrow < arrowBatchSize) { ret = SQLFetch_ptr(hStmt); if (ret == SQL_NO_DATA) { @@ -4347,9 +4370,8 @@ SQLRETURN FetchArrowBatch_wrap( assert(numRowsFetched + idxRowArrow <= static_cast(arrowBatchSize)); for (SQLULEN idxRowSql = 0; idxRowSql < numRowsFetched; idxRowSql++) { for (SQLUSMALLINT col = 1; col <= numCols; col++) { - auto columnMeta = columnNames[col - 1].cast(); - SQLSMALLINT dataType = columnMeta["DataType"].cast(); - SQLULEN columnSize = columnMeta["ColumnSize"].cast(); + auto dataType = dataTypes[col - 1]; + auto columnSize = columnSizes[col - 1]; if (hasLobColumns) { assert(idxRowSql == 0 && "GetData only works one row at a time"); @@ -4529,10 +4551,9 @@ SQLRETURN FetchArrowBatch_wrap( break; } default: { - std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; + errorString << "Unsupported data type for column ID - " << col + << ", Type - " << dataType; LOG("SQLGetData: %s", errorString.str().c_str()); ThrowStdException(errorString.str()); break; @@ -4764,10 +4785,9 @@ SQLRETURN FetchArrowBatch_wrap( break; } default: { - std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << col; + errorString << "Unsupported data type for column ID - " << col + << ", Type - " << dataType; LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); break; @@ -4794,8 +4814,7 @@ SQLRETURN FetchArrowBatch_wrap( arrow_array_batch->buffers[1] = new int[1]; for (SQLUSMALLINT col = 0; col < numCols; col++) { - auto columnMeta = columnNames[col].cast(); - SQLSMALLINT dataType = columnMeta["DataType"].cast(); + auto dataType = dataTypes[col]; auto arrow_array_col_buffers = new const void* [3]; memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); // Allocate new memory and copy the data @@ -4863,10 +4882,9 @@ SQLRETURN FetchArrowBatch_wrap( arrow_array_col_buffers[1] = buffersArrow.bit[col].release(); break; default: { - std::wstring columnName = columnMeta["ColumnName"].cast(); std::ostringstream errorString; - errorString << "Unsupported data type for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << (col + 1); + errorString << "Unsupported data type for column ID - " << (col + 1) + << ", Type - " << dataType; LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); break; From dbfc26af2e68deb349bc79d8e59c2b8058f7b916 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 23 Nov 2025 20:00:07 +0100 Subject: [PATCH 29/36] tweak fetchSize calculation --- mssql_python/pybind/ddbc_bindings.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f9afbb1d3..5a872a66b 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4195,7 +4195,9 @@ SQLRETURN FetchArrowBatch_wrap( dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { hasLobColumns = true; - fetchSize = 1; // LOBs require row-by-row fetch + if (fetchSize > 1) { + fetchSize = 1; // LOBs require row-by-row fetch + } } std::string columnName = colMeta["ColumnName"].cast(); @@ -4335,7 +4337,7 @@ SQLRETURN FetchArrowBatch_wrap( // Initialize column buffers ColumnBuffers buffers(numCols, fetchSize); - if (!hasLobColumns) { + if (!hasLobColumns && fetchSize > 0) { // Bind columns ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); if (!SQL_SUCCEEDED(ret)) { From 0e9f3eec169d9f3717c194678b5fb586b7b03347 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 23 Nov 2025 21:58:03 +0100 Subject: [PATCH 30/36] Transfer ownership to arrow at the end of the function --- mssql_python/pybind/ddbc_bindings.cpp | 108 +++++++++++++++----------- 1 file changed, 62 insertions(+), 46 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 5a872a66b..a4e829b70 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4176,11 +4176,12 @@ SQLRETURN FetchArrowBatch_wrap( return ret; } - auto batch_children = new ArrowSchema* [numCols]; bool hasLobColumns = false; std::vector dataTypes(numCols); std::vector columnSizes(numCols); + std::vector> columnFormats(numCols); + std::vector> columnNamesCStr(numCols); ColumnBuffersArrow buffersArrow(numCols); for (SQLSMALLINT i = 0; i < numCols; i++) { @@ -4201,8 +4202,11 @@ SQLRETURN FetchArrowBatch_wrap( } std::string columnName = colMeta["ColumnName"].cast(); + size_t nameLen = columnName.length() + 1; + columnNamesCStr[i] = std::make_unique(nameLen); + std::memcpy(columnNamesCStr[i].get(), columnName.c_str(), nameLen); - char* format = nullptr; + const char* format = nullptr; switch(dataType) { case SQL_CHAR: case SQL_VARCHAR: @@ -4212,7 +4216,7 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WVARCHAR: case SQL_WLONGVARCHAR: case SQL_GUID: - format = strdup("u"); + format = "u"; buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); buffersArrow.var_data[i].resize(arrowBatchSize * 42); // start at offset 0 @@ -4221,32 +4225,32 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: - format = strdup("z"); + format = "z"; buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); buffersArrow.var_data[i].resize(arrowBatchSize * 42); // start at offset 0 buffersArrow.var[i][0] = 0; break; case SQL_TINYINT: - format = strdup("C"); + format = "C"; buffersArrow.uint8[i] = std::make_unique(arrowBatchSize); break; case SQL_SMALLINT: - format = strdup("s"); + format = "s"; buffersArrow.int16[i] = std::make_unique(arrowBatchSize); break; case SQL_INTEGER: - format = strdup("i"); + format = "i"; buffersArrow.int32[i] = std::make_unique(arrowBatchSize); break; case SQL_BIGINT: - format = strdup("l"); + format = "l"; buffersArrow.int64[i] = std::make_unique(arrowBatchSize); break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: - format = strdup("g"); + format = "g"; buffersArrow.float64[i] = std::make_unique(arrowBatchSize); break; case SQL_DECIMAL: @@ -4254,32 +4258,35 @@ SQLRETURN FetchArrowBatch_wrap( std::ostringstream formatStream; formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); std::string formatStr = formatStream.str(); - format = strdup(formatStr.c_str()); + size_t formatLen = formatStr.length() + 1; + columnFormats[i] = std::make_unique(formatLen); + std::memcpy(columnFormats[i].get(), formatStr.c_str(), formatLen); + format = columnFormats[i].get(); buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(arrowBatchSize); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: - format = strdup("tsu:"); + format = "tsu:"; buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); break; case SQL_SS_TIMESTAMPOFFSET: - format = strdup("tsu:+00:00"); + format = "tsu:+00:00"; buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); break; case SQL_TYPE_DATE: - format = strdup("tdD"); + format = "tdD"; buffersArrow.date[i] = std::make_unique(arrowBatchSize); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: - format = strdup("tts"); + format = "tts"; buffersArrow.time_second[i] = std::make_unique(arrowBatchSize); break; case SQL_BIT: - format = strdup("b"); + format = "b"; buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); break; default: @@ -4292,34 +4299,18 @@ SQLRETURN FetchArrowBatch_wrap( break; } - auto arrow_schema = new ArrowSchema({ - .format = format, - .name = strdup(columnName.c_str()), - .release = ArrowSchema_release, - }); - batch_children[i] = arrow_schema; + // Store format string if not already stored (for non-decimal types) + if (!columnFormats[i]) { + size_t formatLen = std::strlen(format) + 1; + columnFormats[i] = std::make_unique(formatLen); + std::memcpy(columnFormats[i].get(), format, formatLen); + } buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); // Initialize validity bitmap to all valid std::memset(buffersArrow.valid[i].get(), 0xFF, (arrowBatchSize + 7) / 8); } - auto arrow_schema_batch = new ArrowSchema({ - .format = strdup("+s"), - .name = strdup(""), - .n_children = numCols, - .children = batch_children, - .release = ArrowSchema_release, - }); - auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { - auto arrow_schema = static_cast(ptr); - if (arrow_schema->release) { - arrow_schema->release(arrow_schema); - } - delete arrow_schema; - }); - capsules.append(caps); - if (fetchSize > 1) { // An overly large fetch size doesn't seem to help performance SQLSMALLINT searchStart = 64; @@ -4640,7 +4631,7 @@ SQLRETURN FetchArrowBatch_wrap( auto wcharSource = &buffers.wcharBuffers[col - 1][idxRowSql * (columnSize + 1)]; auto start = buffersArrow.var[col - 1][idxRowArrow]; auto target_vec = &buffersArrow.var_data[col - 1]; - #if defined(_WIN32) +#if defined(_WIN32) // Convert wide string int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); while (target_vec->size() < start + dataLenConverted) { @@ -4648,12 +4639,12 @@ SQLRETURN FetchArrowBatch_wrap( } WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); buffersArrow.var[col - 1][i + 1] = start + dataLenConverted; - #else +#else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); - #endif +#endif break; } case SQL_GUID: { @@ -4800,7 +4791,37 @@ SQLRETURN FetchArrowBatch_wrap( } } + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); + + // Transfer ownerhip of buffers to Arrow structures + // Exceptions beyond this point would cause memory leaks + auto batch_children = new ArrowSchema* [numCols]; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto arrow_schema = new ArrowSchema({ + .format = columnFormats[i].release(), + .name = columnNamesCStr[i].release(), + .release = ArrowSchema_release, + }); + batch_children[i] = arrow_schema; + } + auto arrow_schema_batch = new ArrowSchema({ + .format = strdup("+s"), + .name = strdup(""), + .n_children = numCols, + .children = batch_children, + .release = ArrowSchema_release, + }); + auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { + auto arrow_schema = static_cast(ptr); + if (arrow_schema->release) { + arrow_schema->release(arrow_schema); + } + delete arrow_schema; + }); + capsules.append(caps); auto arrow_array_batch_buffers = new const void* [3]; memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); @@ -4812,7 +4833,7 @@ SQLRETURN FetchArrowBatch_wrap( .children = new ArrowArray* [numCols], .release = ArrowArray_release, }); - // dummy buffer + // Necessary dummy buffer arrow_array_batch->buffers[1] = new int[1]; for (SQLUSMALLINT col = 0; col < numCols; col++) { @@ -4916,11 +4937,6 @@ SQLRETURN FetchArrowBatch_wrap( delete arrow_array; })); - - // Reset attributes before returning to avoid using stack pointers later - SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); - SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); - return ret; } From 2d2064074d90fa151f7c8a49dca7334f70192cba Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 23 Nov 2025 22:47:37 +0100 Subject: [PATCH 31/36] Add functions for arrow table/reader, rename fetch_arrow_batch to arrow_batch --- mssql_python/cursor.py | 57 ++++++++++++++++++++++++++++++++++++------ 1 file changed, 49 insertions(+), 8 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index c9c6aec86..8edb7b362 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -25,7 +25,10 @@ from mssql_python import get_settings if TYPE_CHECKING: + import pyarrow # type: ignore from mssql_python.connection import Connection +else: + pyarrow = None # Constants for string handling MAX_INLINE_CHAR: int = ( @@ -2195,25 +2198,64 @@ def fetchall(self) -> List[Row]: # On error, don't increment rownumber - rethrow the error raise e - def fetch_arrow_batch(self, batch_length: int) -> Any: + def arrow_batch(self, batch_size: int=8192) -> "pyarrow.RecordBatch": self._check_closed() # Check if the cursor is closed if not self._has_result_set and self.description: self._reset_rownumber() try: - import pyarrow as pa + import pyarrow except ImportError as e: raise ImportError( - "pyarrow is required for fetch_arrow_batch(). Please install pyarrow." + "pyarrow is required for arrow_batch(). Please install pyarrow." ) from e + capsules = [] - ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, batch_length) + ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, batch_size) assert ret in (0, 1), ret - schema_capsule = capsules[0] - array_capsule = capsules[1] - batch = pa.RecordBatch._import_from_c_capsule(schema_capsule, array_capsule) + batch = pyarrow.RecordBatch._import_from_c_capsule(*capsules) return batch + def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": + try: + import pyarrow + except ImportError as e: + raise ImportError( + "pyarrow is required for arrow(). Please install pyarrow." + ) from e + + assert batch_size > 0 + batches: list["pyarrow.RecordBatch"] = [] + while True: + batch = self.arrow_batch(batch_size) + if batch.num_rows < batch_size: + if not batches or batch.num_rows > 0: + batches.append(batch) + break + batches.append(batch) + return pyarrow.Table.from_batches(batches, schema=batches[0].schema) + + def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": + """ + Fetch the result as a pyarrow RecordBatchReader. + """ + try: + import pyarrow + except ImportError as e: + raise ImportError( + "pyarrow is required for fetch_record_batch(). Please install pyarrow." + ) from e + + # Fetch schema without advancing cursor + schema_batch = self.arrow_batch(0) + schema = schema_batch.schema + + def batch_generator(): + while len(batch := self.arrow_batch(batch_size)) > 0: + yield batch + + return pyarrow.RecordBatchReader.from_batches(schema, batch_generator()) + def nextset(self) -> Union[bool, None]: """ Skip to the next available result set. @@ -2389,7 +2431,6 @@ def __del__(self): Destructor to ensure the cursor is closed when it is no longer needed. This is a safety net to ensure resources are cleaned up even if close() was not called explicitly. - If the cursor is already closed, it will not raise an exception during cleanup. """ if "closed" not in self.__dict__ or not self.closed: try: From a4474b012f8f6bb8d3bccfe502a57ac9ba5dfad8 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Mon, 24 Nov 2025 15:59:37 +0100 Subject: [PATCH 32/36] Update Docstrings --- mssql_python/cursor.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 8edb7b362..7d2ac00ef 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2199,6 +2199,16 @@ def fetchall(self) -> List[Row]: raise e def arrow_batch(self, batch_size: int=8192) -> "pyarrow.RecordBatch": + """ + Fetch a single pyarrow Record Batch of the specified size from the + query result set. + + Args: + batch_size: Maximum number of rows to fetch in the Record Batch. + + Returns: + A pyarrow RecordBatch object containing up to batch_size rows. + """ self._check_closed() # Check if the cursor is closed if not self._has_result_set and self.description: self._reset_rownumber() @@ -2217,6 +2227,15 @@ def arrow_batch(self, batch_size: int=8192) -> "pyarrow.RecordBatch": return batch def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": + """ + Fetch the entire result as a pyarrow Table. + + Args: + batch_size: Size of the Record Batches which make up the Table. + + Returns: + A pyarrow Table containing all remaining rows from the result set. + """ try: import pyarrow except ImportError as e: @@ -2237,7 +2256,15 @@ def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": """ - Fetch the result as a pyarrow RecordBatchReader. + Fetch the result as a pyarrow RecordBatchReader, which yields Record + Batches of the specified size until the current result set is + exhausted. + + Args: + batch_size: Size of the Record Batches produced by the reader. + + Returns: + A pyarrow RecordBatchReader for the result set. """ try: import pyarrow From 1d3438f129bb48119e913aa17102e380f8757d65 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 25 Nov 2025 17:44:49 +0100 Subject: [PATCH 33/36] Undo accidental changes --- mssql_python/cursor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 7d2ac00ef..9fc07ccdb 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2458,10 +2458,10 @@ def __del__(self): Destructor to ensure the cursor is closed when it is no longer needed. This is a safety net to ensure resources are cleaned up even if close() was not called explicitly. + If the cursor is already closed, it will not raise an exception during cleanup. """ if "closed" not in self.__dict__ or not self.closed: try: - assert self is not None self.close() except Exception as e: # pylint: disable=broad-exception-caught # Don't raise an exception in __del__, just log it From 6c683d5862af7162d53e0212b976b28174a08f1a Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 25 Nov 2025 17:54:50 +0100 Subject: [PATCH 34/36] check_error instead of ret assert + handle negative/zero batch_size --- mssql_python/cursor.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 9fc07ccdb..877723fae 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2221,8 +2221,9 @@ def arrow_batch(self, batch_size: int=8192) -> "pyarrow.RecordBatch": ) from e capsules = [] - ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, batch_size) - assert ret in (0, 1), ret + ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, max(batch_size, 0)) + check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret) + batch = pyarrow.RecordBatch._import_from_c_capsule(*capsules) return batch @@ -2243,11 +2244,10 @@ def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": "pyarrow is required for arrow(). Please install pyarrow." ) from e - assert batch_size > 0 batches: list["pyarrow.RecordBatch"] = [] while True: batch = self.arrow_batch(batch_size) - if batch.num_rows < batch_size: + if batch.num_rows < batch_size or batch_size <= 0: if not batches or batch.num_rows > 0: batches.append(batch) break From a8a4bf33f9a0cd5c3b2d771895886d7c4e198519 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 25 Nov 2025 18:42:27 +0100 Subject: [PATCH 35/36] apply AI suggestions --- mssql_python/cursor.py | 4 ++-- mssql_python/pybind/ddbc_bindings.cpp | 13 +++++-------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 877723fae..265131655 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2270,7 +2270,7 @@ def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": import pyarrow except ImportError as e: raise ImportError( - "pyarrow is required for fetch_record_batch(). Please install pyarrow." + "pyarrow is required for arrow_reader(). Please install pyarrow." ) from e # Fetch schema without advancing cursor @@ -2278,7 +2278,7 @@ def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": schema = schema_batch.schema def batch_generator(): - while len(batch := self.arrow_batch(batch_size)) > 0: + while (batch := self.arrow_batch(batch_size)).num_rows > 0: yield batch return pyarrow.RecordBatchReader.from_batches(schema, batch_generator()) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index a4e829b70..815c4633f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4556,10 +4556,7 @@ SQLRETURN FetchArrowBatch_wrap( SQLLEN dataLen = buffers.indicators[col - 1][idxRowSql]; - // This value indicates that the driver cannot determine the length of the data - if (dataLen == SQL_NO_TOTAL) { - assert(false && "Is this actually possible?"); - } else if (dataLen == SQL_NULL_DATA) { + if (dataLen == SQL_NULL_DATA) { // Mark as null in validity bitmap size_t bytePos = idxRowArrow / 8; size_t bitPos = idxRowArrow % 8; @@ -4589,9 +4586,8 @@ SQLRETURN FetchArrowBatch_wrap( } else if (dataLen < 0) { // Negative value is unexpected, log column index, SQL type & raise exception LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); - ThrowStdException("Unexpected negative data length, check logs for details"); + ThrowStdException("Unexpected negative data length."); } - assert(dataLen >= 0 && "Data length must be >= 0"); switch (dataType) { case SQL_BINARY: @@ -4638,7 +4634,7 @@ SQLRETURN FetchArrowBatch_wrap( target_vec->resize(target_vec->size() * 2); } WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); - buffersArrow.var[col - 1][i + 1] = start + dataLenConverted; + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLenConverted; #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); @@ -4751,7 +4747,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: { - // TODO wrong ctype for SQL_SS_TIME2 + // NOTE: SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does not. + // To fully support SQL_SS_TIME2, the corresponding c-type should be used. const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][idxRowSql]; buffersArrow.time_second[col - 1][idxRowArrow] = static_cast(timeValue.hour) * 3600 + From 9b086dd5f3b8e9e13487159799fa8360afb97426 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 25 Nov 2025 19:58:37 +0100 Subject: [PATCH 36/36] Apply black formatting --- mssql_python/cursor.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py index 265131655..46d3f4539 100644 --- a/mssql_python/cursor.py +++ b/mssql_python/cursor.py @@ -2198,7 +2198,7 @@ def fetchall(self) -> List[Row]: # On error, don't increment rownumber - rethrow the error raise e - def arrow_batch(self, batch_size: int=8192) -> "pyarrow.RecordBatch": + def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch": """ Fetch a single pyarrow Record Batch of the specified size from the query result set. @@ -2240,9 +2240,7 @@ def arrow(self, batch_size: int = 8192) -> "pyarrow.Table": try: import pyarrow except ImportError as e: - raise ImportError( - "pyarrow is required for arrow(). Please install pyarrow." - ) from e + raise ImportError("pyarrow is required for arrow(). Please install pyarrow.") from e batches: list["pyarrow.RecordBatch"] = [] while True: @@ -2276,11 +2274,11 @@ def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader": # Fetch schema without advancing cursor schema_batch = self.arrow_batch(0) schema = schema_batch.schema - + def batch_generator(): while (batch := self.arrow_batch(batch_size)).num_rows > 0: yield batch - + return pyarrow.RecordBatchReader.from_batches(schema, batch_generator()) def nextset(self) -> Union[bool, None]: