def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch":
    """
    Return the next chunk of the current result set as one pyarrow RecordBatch.

    Args:
        batch_size: Upper bound on the number of rows placed in the batch.
            Values below zero are treated as zero, which yields an empty
            batch that still carries the result-set schema.

    Returns:
        A pyarrow RecordBatch holding at most batch_size rows.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    self._check_closed()  # Refuse to operate on a closed cursor
    if not self._has_result_set and self.description:
        self._reset_rownumber()

    try:
        import pyarrow
    except ImportError as exc:
        raise ImportError(
            "pyarrow is required for arrow_batch(). Please install pyarrow."
        ) from exc

    row_limit = batch_size if batch_size > 0 else 0

    # The native layer appends the Arrow C-data capsules (schema first,
    # then array) to this list.
    capsules = []
    ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, row_limit)
    check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)

    return pyarrow.RecordBatch._import_from_c_capsule(*capsules)


def arrow(self, batch_size: int = 8192) -> "pyarrow.Table":
    """
    Drain the current result set into a single pyarrow Table.

    Args:
        batch_size: Number of rows per RecordBatch inside the Table. A
            non-positive value stops after the first (empty) batch, so the
            returned Table then has the schema but no rows.

    Returns:
        A pyarrow Table built from every batch fetched by this call.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    try:
        import pyarrow
    except ImportError as exc:
        raise ImportError("pyarrow is required for arrow(). Please install pyarrow.") from exc

    collected = []
    while True:
        chunk = self.arrow_batch(batch_size)
        row_count = chunk.num_rows
        final_chunk = row_count < batch_size or batch_size <= 0
        if not final_chunk:
            collected.append(chunk)
            continue
        # Keep a short final chunk when it has rows, or when it is the only
        # chunk (so the Table always has a schema to take from).
        if not collected or row_count > 0:
            collected.append(chunk)
        break

    first_schema = collected[0].schema
    return pyarrow.Table.from_batches(collected, schema=first_schema)


def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader":
    """
    Expose the current result set as a lazily evaluated pyarrow RecordBatchReader.

    Batches are fetched from the cursor only as the reader is consumed, and
    the stream ends at the first batch that contains no rows.

    Args:
        batch_size: Number of rows requested for each batch the reader yields.

    Returns:
        A pyarrow RecordBatchReader over the result set.

    Raises:
        ImportError: If pyarrow is not installed.
    """
    try:
        import pyarrow
    except ImportError as exc:
        raise ImportError(
            "pyarrow is required for arrow_reader(). Please install pyarrow."
        ) from exc

    # A zero-row batch gives us the schema without consuming any rows.
    schema = self.arrow_batch(0).schema

    def _remaining_batches():
        while True:
            chunk = self.arrow_batch(batch_size)
            if chunk.num_rows <= 0:
                return
            yield chunk

    return pyarrow.RecordBatchReader.from_batches(schema, _remaining_batches())
-ne 0 ]; then diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 9a8280117..815c4633f 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -147,6 +147,83 @@ struct NumericData { } }; +// Struct to hold data buffers and indicators for each column +struct ColumnBuffersArrow { + std::vector> uint8; + std::vector> int16; + std::vector> int32; + std::vector> int64; + std::vector> float64; + std::vector> bit; + std::vector> var; + std::vector> date; + std::vector> ts_micro; + std::vector> time_second; + std::vector> decimal; + + std::vector> valid; + std::vector> var_data; + + ColumnBuffersArrow(SQLSMALLINT numCols) + : + uint8(numCols), + int16(numCols), + int32(numCols), + int64(numCols), + float64(numCols), + bit(numCols), + var(numCols), + date(numCols), + ts_micro(numCols), + time_second(numCols), + decimal(numCols), + + valid(numCols), + var_data(numCols) {} +}; + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE + //------------------------------------------------------------------------------------------------- // 
Function pointer initialization //------------------------------------------------------------------------------------------------- @@ -3916,6 +3993,951 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; } +// GetDataVar - Progressively fetches variable-length column data using SQLGetData. +// +// Calls SQLGetData repeatedly, reallocating the buffer as needed, until all data is retrieved. +// Handles both fixed-size and unknown-size (SQL_NO_TOTAL) responses from the driver. +// +// @param hStmt: Statement handle +// @param colNumber: 1-based column index +// @param cType: SQL C data type (SQL_C_CHAR, SQL_C_WCHAR, or SQL_C_BINARY) +// @param dataVec: Reference to vector that will hold the fetched data (will be resized as needed) +// @param indicator: Pointer to indicator value (SQL_NULL_DATA for NULL, or data length) +// +// @return SQLRETURN: SQL_SUCCESS on success, or error code on failure +template +SQLRETURN GetDataVar(SQLHSTMT hStmt, + SQLUSMALLINT colNumber, + SQLSMALLINT cType, + std::vector& dataVec, + SQLLEN* indicator) { + if (!SQLGetData_ptr) { + ThrowStdException("SQLGetData function not loaded"); + } + + size_t start = 0; + size_t end = 0; + + // Determine null terminator size based on data type + size_t sizeNullTerminator = 0; + switch (cType) { + case SQL_C_WCHAR: + case SQL_C_CHAR: + sizeNullTerminator = 1; + break; + case SQL_C_BINARY: + sizeNullTerminator = 0; + break; + default: + ThrowStdException("GetDataVar only supports SQL_C_CHAR, SQL_C_WCHAR, and SQL_C_BINARY"); + } + + // Ensure initial buffer has space for at least the null terminator + if (dataVec.size() < sizeNullTerminator) { + dataVec.resize(sizeNullTerminator); + } + + while (true) { + SQLLEN localInd = 0; + SQLRETURN ret = SQLGetData_ptr( + hStmt, + colNumber, + cType, + reinterpret_cast(dataVec.data() + start), + sizeof(T) * (dataVec.size() - start), // Available buffer size from start position + &localInd + ); + + // Handle NULL data + 
if (localInd == SQL_NULL_DATA) { + *indicator = SQL_NULL_DATA; + return SQL_SUCCESS; + } + + // Check for errors (excluding SQL_SUCCESS_WITH_INFO which means more data available) + if (ret == SQL_ERROR || ret == SQL_INVALID_HANDLE) { + return ret; + } + + // SQL_SUCCESS or SQL_NO_DATA means we got all the data + if (ret == SQL_SUCCESS || ret == SQL_NO_DATA) { + if (localInd >= 0) { + *indicator = static_cast(start) * sizeof(T) + localInd; + } else { + *indicator = localInd; // Preserve SQL_NO_TOTAL or other negative values + } + break; + } + + // SQL_SUCCESS_WITH_INFO means buffer was too small, need to continue fetching + if (ret == SQL_SUCCESS_WITH_INFO) { + // Determine how much more space we need + if (localInd < 0) { + // SQL_NO_TOTAL: driver doesn't know total size, double the buffer + end = dataVec.size() * 2; + } else { + // Driver returned total size: allocate exactly what we need + assert(localInd % sizeof(T) == 0); + end = start + static_cast(localInd) / sizeof(T) + sizeNullTerminator; + } + + // The next read starts where the null terminator would have been placed + start = dataVec.size() - sizeNullTerminator; + + // Resize buffer for next iteration + dataVec.resize(end); + } else { + // Unexpected return code + return ret; + } + } + + return SQL_SUCCESS; +} + +void ArrowSchema_release(struct ArrowSchema* schema) { + assert (schema != nullptr); + assert (schema->release != nullptr); + schema->release = nullptr; + delete[] schema->name; + for (int i = 0; i < schema->n_children; i++) { + assert (schema->children != nullptr); + if (schema->children[i]) { + schema->children[i]->release(schema->children[i]); + delete schema->children[i]; + } + } + delete[] schema->children; + delete[] schema->format; +} + +void ArrowArray_release(struct ArrowArray* array) { + assert (array != nullptr); + assert (array->release != nullptr); + array->release = nullptr; + + uint32_t buffers_freed = 0; + uint32_t current_buffer = 0; + while (buffers_freed < array->n_buffers) { + 
if (array->buffers[current_buffer]) { + free((void*)array->buffers[current_buffer]); + buffers_freed++; + } + current_buffer++; + assert (current_buffer <= 3); + } + delete[] array->buffers; + + for (int i = 0; i < array->n_children; i++) { + assert (array->children != nullptr); + assert (array->children[i] != nullptr); + array->children[i]->release(array->children[i]); + delete array->children[i]; + } + delete[] array->children; + +} + +int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) { + // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) + std::tm tm_date = {}; + tm_date.tm_year = year - 1900; // tm_year is years since 1900 + tm_date.tm_mon = month - 1; // tm_mon is 0-11 + tm_date.tm_mday = day; + + std::time_t time_since_epoch = std::mktime(&tm_date); + if (time_since_epoch == -1) { + LOG("Failed to convert SQL_DATE_STRUCT to time_t"); + ThrowStdException("Date conversion error"); + } + // Calculate days since epoch + return time_since_epoch / 86400; +} + +SQLRETURN FetchArrowBatch_wrap( + SqlHandlePtr StatementHandle, + py::list& capsules, + ssize_t arrowBatchSize +) { + ssize_t fetchSize = arrowBatchSize; + SQLRETURN ret; + SQLHSTMT hStmt = StatementHandle->get(); + // Retrieve column count + SQLSMALLINT numCols = SQLNumResultCols_wrap(StatementHandle); + if (numCols <= 0) { + ThrowStdException("No active result set. 
Cannot fetch Arrow batch."); + } + + // Retrieve column metadata + py::list columnNames; + ret = SQLDescribeCol_wrap(StatementHandle, columnNames); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to get column descriptions"); + return ret; + } + + bool hasLobColumns = false; + + std::vector dataTypes(numCols); + std::vector columnSizes(numCols); + std::vector> columnFormats(numCols); + std::vector> columnNamesCStr(numCols); + + ColumnBuffersArrow buffersArrow(numCols); + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto colMeta = columnNames[i].cast(); + SQLSMALLINT dataType = colMeta["DataType"].cast(); + SQLULEN columnSize = colMeta["ColumnSize"].cast(); + dataTypes[i] = dataType; + columnSizes[i] = columnSize; + + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || + dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && + (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { + hasLobColumns = true; + if (fetchSize > 1) { + fetchSize = 1; // LOBs require row-by-row fetch + } + } + + std::string columnName = colMeta["ColumnName"].cast(); + size_t nameLen = columnName.length() + 1; + columnNamesCStr[i] = std::make_unique(nameLen); + std::memcpy(columnNamesCStr[i].get(), columnName.c_str(), nameLen); + + const char* format = nullptr; + switch(dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + format = "u"; + buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); + buffersArrow.var_data[i].resize(arrowBatchSize * 42); + // start at offset 0 + buffersArrow.var[i][0] = 0; + break; + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + format = "z"; + buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); + buffersArrow.var_data[i].resize(arrowBatchSize * 42); + // start at offset 0 + 
buffersArrow.var[i][0] = 0; + break; + case SQL_TINYINT: + format = "C"; + buffersArrow.uint8[i] = std::make_unique(arrowBatchSize); + break; + case SQL_SMALLINT: + format = "s"; + buffersArrow.int16[i] = std::make_unique(arrowBatchSize); + break; + case SQL_INTEGER: + format = "i"; + buffersArrow.int32[i] = std::make_unique(arrowBatchSize); + break; + case SQL_BIGINT: + format = "l"; + buffersArrow.int64[i] = std::make_unique(arrowBatchSize); + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + format = "g"; + buffersArrow.float64[i] = std::make_unique(arrowBatchSize); + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + std::ostringstream formatStream; + formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); + std::string formatStr = formatStream.str(); + size_t formatLen = formatStr.length() + 1; + columnFormats[i] = std::make_unique(formatLen); + std::memcpy(columnFormats[i].get(), formatStr.c_str(), formatLen); + format = columnFormats[i].get(); + buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(arrowBatchSize); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: + format = "tsu:"; + buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + break; + case SQL_SS_TIMESTAMPOFFSET: + format = "tsu:+00:00"; + buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + break; + case SQL_TYPE_DATE: + format = "tdD"; + buffersArrow.date[i] = std::make_unique(arrowBatchSize); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + format = "tts"; + buffersArrow.time_second[i] = std::make_unique(arrowBatchSize); + break; + case SQL_BIT: + format = "b"; + buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); + break; + default: + std::wstring columnName = colMeta["ColumnName"].cast(); + std::ostringstream errorString; + errorString << "Unsupported data type for Arrow batch fetch for column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " 
<< (i + 1); + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + + // Store format string if not already stored (for non-decimal types) + if (!columnFormats[i]) { + size_t formatLen = std::strlen(format) + 1; + columnFormats[i] = std::make_unique(formatLen); + std::memcpy(columnFormats[i].get(), format, formatLen); + } + + buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); + // Initialize validity bitmap to all valid + std::memset(buffersArrow.valid[i].get(), 0xFF, (arrowBatchSize + 7) / 8); + } + + if (fetchSize > 1) { + // An overly large fetch size doesn't seem to help performance + SQLSMALLINT searchStart = 64; + if (arrowBatchSize < 64) { + searchStart = static_cast(arrowBatchSize); + } + for (SQLSMALLINT maybeNewSize = searchStart; maybeNewSize >= 1; maybeNewSize -= 1) { + if (arrowBatchSize % maybeNewSize == 0) { + fetchSize = maybeNewSize; + break; + } + } + } + + // Initialize column buffers + ColumnBuffers buffers(numCols, fetchSize); + + if (!hasLobColumns && fetchSize > 0) { + // Bind columns + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error when binding columns"); + return ret; + } + } + + SQLULEN numRowsFetched; + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); + + + size_t idxRowArrow = 0; + // arrowBatchSize % fetchSize == 0 ensures that any followup (even non-arrow) fetches + // start with a fresh batch + assert(fetchSize == 0 || arrowBatchSize % fetchSize == 0); + assert(fetchSize <= arrowBatchSize); + + while (idxRowArrow < arrowBatchSize) { + ret = SQLFetch_ptr(hStmt); + if (ret == SQL_NO_DATA) { + ret = SQL_SUCCESS; // Normal completion + break; + } + if (!SQL_SUCCEEDED(ret)) { + LOG("Error while fetching rows in batches"); + return ret; + } + // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. 
+ // It'll be populated by SQLFetch + assert(numRowsFetched + idxRowArrow <= static_cast(arrowBatchSize)); + for (SQLULEN idxRowSql = 0; idxRowSql < numRowsFetched; idxRowSql++) { + for (SQLUSMALLINT col = 1; col <= numCols; col++) { + auto dataType = dataTypes[col - 1]; + auto columnSize = columnSizes[col - 1]; + + if (hasLobColumns) { + assert(idxRowSql == 0 && "GetData only works one row at a time"); + + switch(dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + GetDataVar( + hStmt, + col, + SQL_C_BINARY, + buffers.charBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + GetDataVar( + hStmt, + col, + SQL_C_CHAR, + buffers.charBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + GetDataVar( + hStmt, + col, + SQL_C_WCHAR, + buffers.wcharBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_INTEGER: { + buffers.intBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SLONG, + buffers.intBuffers[col - 1].data(), + sizeof(SQLINTEGER), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SMALLINT: { + buffers.smallIntBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SSHORT, + buffers.smallIntBuffers[col - 1].data(), + sizeof(SQLSMALLINT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TINYINT: { + buffers.charBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TINYINT, + buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_BIT: { + buffers.charBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_BIT, + buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_REAL: { + buffers.realBuffers[col - 
1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_FLOAT, + buffers.realBuffers[col - 1].data(), + sizeof(SQLREAL), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_DECIMAL: + case SQL_NUMERIC: { + buffers.charBuffers[col - 1].resize(MAX_DIGITS_IN_NUMERIC); + SQLGetData_ptr( + hStmt, col, SQL_C_CHAR, + buffers.charBuffers[col - 1].data(), + MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_DOUBLE: + case SQL_FLOAT: { + buffers.doubleBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_DOUBLE, + buffers.doubleBuffers[col - 1].data(), + sizeof(SQLDOUBLE), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + buffers.timestampBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[col - 1].data(), + sizeof(SQL_TIMESTAMP_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_BIGINT: { + buffers.bigIntBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SBIGINT, + buffers.bigIntBuffers[col - 1].data(), + sizeof(SQLBIGINT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TYPE_DATE: { + buffers.dateBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_DATE, + buffers.dateBuffers[col - 1].data(), + sizeof(SQL_DATE_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + buffers.timeBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_TIME, + buffers.timeBuffers[col - 1].data(), + sizeof(SQL_TIME_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_GUID: { + buffers.guidBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_GUID, + buffers.guidBuffers[col - 1].data(), + sizeof(SQLGUID), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + 
buffers.datetimeoffsetBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[col - 1].data(), + sizeof(DateTimeOffset), + buffers.indicators[col - 1].data() + ); + break; + } + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << col + << ", Type - " << dataType; + LOG("SQLGetData: %s", errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + } + + SQLLEN dataLen = buffers.indicators[col - 1][idxRowSql]; + + if (dataLen == SQL_NULL_DATA) { + // Mark as null in validity bitmap + size_t bytePos = idxRowArrow / 8; + size_t bitPos = idxRowArrow % 8; + buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + + // Value buffer for variable length data types needs to be set appropriately + // as it will be used by the next non null value + switch (dataType) + { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + buffersArrow.var[col - 1][idxRowArrow + 1] = buffersArrow.var[col - 1][idxRowArrow]; + break; + default: + break; + } + continue; + } else if (dataLen < 0) { + // Negative value is unexpected, log column index, SQL type & raise exception + LOG("Unexpected negative data length. 
Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); + ThrowStdException("Unexpected negative data length."); + } + + switch (dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + assert(dataLen % sizeof(SQLWCHAR) == 0); + auto dataLenW = dataLen / sizeof(SQLWCHAR); + auto wcharSource = &buffers.wcharBuffers[col - 1][idxRowSql * (columnSize + 1)]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &buffersArrow.var_data[col - 1]; +#if defined(_WIN32) + // Convert wide string + int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); + while (target_vec->size() < start + dataLenConverted) { + target_vec->resize(target_vec->size() * 2); + } + WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); + buffersArrow.var[col - 1][idxRowArrow + 1] 
= start + dataLenConverted; +#else + // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 + std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); + std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); +#endif + break; + } + case SQL_GUID: { + // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") + // Each GUID is exactly 36 bytes in UTF-8 + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + + // Ensure buffer has space for the GUID string + null terminator + while (target_vec->size() < start + 37) { + target_vec->resize(target_vec->size() * 2); + } + + // Get the GUID from the buffer + const SQLGUID& guidValue = buffers.guidBuffers[col - 1][idxRowSql]; + + // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + snprintf(reinterpret_cast(&target_vec->data()[start]), 37, + "%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", + guidValue.Data1, + guidValue.Data2, + guidValue.Data3, + guidValue.Data4[0], guidValue.Data4[1], + guidValue.Data4[2], guidValue.Data4[3], + guidValue.Data4[4], guidValue.Data4[5], + guidValue.Data4[6], guidValue.Data4[7]); + + // Update offset for next row, ignoring null terminator + buffersArrow.var[col - 1][idxRowArrow + 1] = start + 36; + break; + } + case SQL_TINYINT: + buffersArrow.uint8[col - 1][idxRowArrow] = buffers.charBuffers[col - 1][idxRowSql]; + break; + case SQL_SMALLINT: + buffersArrow.int16[col - 1][idxRowArrow] = buffers.smallIntBuffers[col - 1][idxRowSql]; + break; + case SQL_INTEGER: + buffersArrow.int32[col - 1][idxRowArrow] = buffers.intBuffers[col - 1][idxRowSql]; + break; + case SQL_BIGINT: + buffersArrow.int64[col - 1][idxRowArrow] = buffers.bigIntBuffers[col - 1][idxRowSql]; + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + buffersArrow.float64[col - 
1][idxRowArrow] = buffers.doubleBuffers[col - 1][idxRowSql]; + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + assert(dataLen <= MAX_DIGITS_IN_NUMERIC); + __int128_t decimalValue = 0; + auto start = idxRowSql * MAX_DIGITS_IN_NUMERIC; + for (SQLULEN idx = start; idx < start + dataLen; idx++) { + char digitChar = buffers.charBuffers[col - 1][idx]; + if (digitChar == '-') { + decimalValue = -decimalValue; + } else if (digitChar >= '0' && digitChar <= '9') { + decimalValue = decimalValue * 10 + (digitChar - '0'); + } + } + buffersArrow.decimal[col - 1][idxRowArrow] = decimalValue; + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][idxRowArrow] = + days * 86400 * 1000000 + + static_cast(sql_value.hour) * 3600 * 1000000 + + static_cast(sql_value.minute) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][idxRowArrow] = + days * 86400 * 1000000 + + (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_TYPE_DATE: + buffersArrow.date[col - 1][idxRowArrow] = dateAsDayCount( + buffers.dateBuffers[col - 1][idxRowSql].year, + buffers.dateBuffers[col - 1][idxRowSql].month, + buffers.dateBuffers[col - 1][idxRowSql].day + ); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + // NOTE: 
SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does not. + // To fully support SQL_SS_TIME2, the corresponding c-type should be used. + const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][idxRowSql]; + buffersArrow.time_second[col - 1][idxRowArrow] = + static_cast(timeValue.hour) * 3600 + + static_cast(timeValue.minute) * 60 + + static_cast(timeValue.second); + break; + } + case SQL_BIT: { + // SQL_BIT is stored as a single bit in Arrow's bitmap format + // Get the boolean value from the buffer + bool bitValue = buffers.charBuffers[col - 1][idxRowSql] != 0; + + // Set the bit in the Arrow bitmap + size_t byteIndex = idxRowArrow / 8; + size_t bitIndex = idxRowArrow % 8; + + if (bitValue) { + // Set bit to 1 + buffersArrow.bit[col - 1][byteIndex] |= (1 << bitIndex); + } else { + // Clear bit to 0 + buffersArrow.bit[col - 1][byteIndex] &= ~(1 << bitIndex); + } + break; + } + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << col + << ", Type - " << dataType; + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + } + idxRowArrow++; + } + } + + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); + + // Transfer ownerhip of buffers to Arrow structures + // Exceptions beyond this point would cause memory leaks + auto batch_children = new ArrowSchema* [numCols]; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto arrow_schema = new ArrowSchema({ + .format = columnFormats[i].release(), + .name = columnNamesCStr[i].release(), + .release = ArrowSchema_release, + }); + batch_children[i] = arrow_schema; + } + + auto arrow_schema_batch = new ArrowSchema({ + .format = strdup("+s"), + .name = strdup(""), + .n_children = numCols, + .children = batch_children, + .release = ArrowSchema_release, + }); + auto 
caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { + auto arrow_schema = static_cast(ptr); + if (arrow_schema->release) { + arrow_schema->release(arrow_schema); + } + delete arrow_schema; + }); + capsules.append(caps); + + auto arrow_array_batch_buffers = new const void* [3]; + memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); + auto arrow_array_batch = new ArrowArray({ + .length = static_cast(idxRowArrow), + .n_buffers = 1, + .n_children = numCols, + .buffers = arrow_array_batch_buffers, + .children = new ArrowArray* [numCols], + .release = ArrowArray_release, + }); + // Necessary dummy buffer + arrow_array_batch->buffers[1] = new int[1]; + + for (SQLUSMALLINT col = 0; col < numCols; col++) { + auto dataType = dataTypes[col]; + auto arrow_array_col_buffers = new const void* [3]; + memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); + // Allocate new memory and copy the data + switch (dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + assert(buffersArrow.var[col][0] == 0); + // length of string at index i is the difference between values at i and i+1 + // so total length is value at index idxRowArrow + auto data_buf_len_total = buffersArrow.var[col][idxRowArrow]; + uint8_t* dataBuffer = new uint8_t[data_buf_len_total]; + std::memcpy(dataBuffer, buffersArrow.var_data[col].data(), data_buf_len_total); + arrow_array_col_buffers[2] = dataBuffer; + arrow_array_col_buffers[1] = buffersArrow.var[col].release(); + } + break; + case SQL_TINYINT: + arrow_array_col_buffers[1] = buffersArrow.uint8[col].release(); + break; + case SQL_SMALLINT: + arrow_array_col_buffers[1] = buffersArrow.int16[col].release(); + break; + case SQL_INTEGER: + arrow_array_col_buffers[1] = buffersArrow.int32[col].release(); + break; + case SQL_BIGINT: + 
arrow_array_col_buffers[1] = buffersArrow.int64[col].release(); + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + arrow_array_col_buffers[1] = buffersArrow.float64[col].release(); + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + arrow_array_col_buffers[1] = buffersArrow.decimal[col].release(); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: + arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + break; + case SQL_SS_TIMESTAMPOFFSET: + arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + break; + case SQL_TYPE_DATE: + arrow_array_col_buffers[1] = buffersArrow.date[col].release(); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + arrow_array_col_buffers[1] = buffersArrow.time_second[col].release(); + break; + case SQL_BIT: + arrow_array_col_buffers[1] = buffersArrow.bit[col].release(); + break; + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << (col + 1) + << ", Type - " << dataType; + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + + auto arrow_array_col = new ArrowArray({ + .length = static_cast(idxRowArrow), + .null_count = 0, + .offset = 0, + .n_buffers = arrow_array_col_buffers[2] ? 3 : 2, + .n_children = 0, + .buffers = arrow_array_col_buffers, + .children = nullptr, + .release = ArrowArray_release, + }); + + arrow_array_col->buffers[0] = buffersArrow.valid[col].release(); + arrow_array_batch->children[col] = arrow_array_col; + } + + capsules.append(py::capsule((void*)arrow_array_batch, "arrow_array", [](void* ptr) { + auto arrow_array = static_cast(ptr); + if (arrow_array->release) { + arrow_array->release(arrow_array); + } + delete arrow_array; + })); + + return ret; +} + + // FetchAll_wrap - Fetches all rows of data from the result set. 
// // @param StatementHandle: Handle to the statement from which data is to be @@ -4222,6 +5244,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), py::arg("fetchSize") = 1, "Fetch many rows from the result set"); m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); + m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, "Fetch an arrow batch of given length from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords,