From 257f7c7ce2b90384c2be54a001e384deec0a565e Mon Sep 17 00:00:00 2001
From: ffelixg <142172984+ffelixg@users.noreply.github.com>
Date: Sun, 30 Nov 2025 21:06:07 +0100
Subject: [PATCH 01/11] Add arrow fetch support

---
 mssql_python/cursor.py                |  123 +++
 mssql_python/pybind/ddbc_bindings.cpp | 1028 +++++++++++++++++++++++++
 requirements.txt                      |    1 +
 tests/test_004_cursor.py              |  246 ++++++
 4 files changed, 1398 insertions(+)

diff --git a/mssql_python/cursor.py b/mssql_python/cursor.py
index fd9d7b32..a95e5cbf 100644
--- a/mssql_python/cursor.py
+++ b/mssql_python/cursor.py
@@ -25,7 +25,10 @@
 from mssql_python import get_settings
 
 if TYPE_CHECKING:
+    import pyarrow  # type: ignore
     from mssql_python.connection import Connection
+else:
+    pyarrow = None
 
 # Constants for string handling
 MAX_INLINE_CHAR: int = (
@@ -2198,6 +2201,126 @@ def fetchall(self) -> List[Row]:
             # On error, don't increment rownumber - rethrow the error
             raise e
 
+    @staticmethod
+    def _require_pyarrow(caller: str):
+        """
+        Import and return the optional pyarrow module.
+
+        Args:
+            caller: Name of the public method that needs pyarrow; it is
+                interpolated into the error message.
+
+        Returns:
+            The imported pyarrow module.
+
+        Raises:
+            ImportError: If pyarrow is not installed.
+        """
+        try:
+            import pyarrow
+        except ImportError as e:
+            raise ImportError(
+                f"pyarrow is required for {caller}(). Please install pyarrow."
+            ) from e
+        return pyarrow
+
+    def arrow_batch(self, batch_size: int = 8192) -> "pyarrow.RecordBatch":
+        """
+        Fetch a single pyarrow Record Batch of the specified size from the
+        query result set.
+
+        Args:
+            batch_size: Maximum number of rows to fetch in the Record Batch.
+                Values below zero are treated as zero, which yields an empty
+                batch that still carries the result set's schema.
+
+        Returns:
+            A pyarrow RecordBatch object containing up to batch_size rows.
+
+        Raises:
+            ImportError: If pyarrow is not installed.
+        """
+        self._check_closed()  # Check if the cursor is closed
+        if not self._has_result_set and self.description:
+            self._reset_rownumber()
+
+        pyarrow = self._require_pyarrow("arrow_batch")
+
+        # The binding appends the Arrow C Data Interface capsules (schema
+        # first, then array) to this list.
+        capsules = []
+        ret = ddbc_bindings.DDBCSQLFetchArrowBatch(self.hstmt, capsules, max(batch_size, 0))
+        check_error(ddbc_sql_const.SQL_HANDLE_STMT.value, self.hstmt, ret)
+
+        # NOTE(review): _import_from_c_capsule is a private pyarrow API.
+        return pyarrow.RecordBatch._import_from_c_capsule(*capsules)
+
+    def arrow(self, batch_size: int = 8192) -> "pyarrow.Table":
+        """
+        Fetch the entire result as a pyarrow Table.
+
+        Args:
+            batch_size: Size of the Record Batches which make up the Table.
+                If it is not positive, no rows are fetched and an empty
+                Table with the result set's schema is returned.
+
+        Returns:
+            A pyarrow Table containing all remaining rows from the result set.
+
+        Raises:
+            ImportError: If pyarrow is not installed.
+        """
+        pyarrow = self._require_pyarrow("arrow")
+
+        batches: list["pyarrow.RecordBatch"] = []
+        while True:
+            batch = self.arrow_batch(batch_size)
+            # Keep the first batch even when it is empty: it supplies the schema.
+            if batch.num_rows > 0 or not batches:
+                batches.append(batch)
+            # A short batch means the result set is exhausted. A non-positive
+            # batch_size can never fill a batch, so stop after one pass.
+            if batch_size <= 0 or batch.num_rows < batch_size:
+                break
+        return pyarrow.Table.from_batches(batches, schema=batches[0].schema)
+
+    def arrow_reader(self, batch_size: int = 8192) -> "pyarrow.RecordBatchReader":
+        """
+        Fetch the result as a pyarrow RecordBatchReader, which yields Record
+        Batches of the specified size until the current result set is
+        exhausted.
+
+        Batches are produced by a generator that calls arrow_batch() on this cursor.
+
+        Args:
+            batch_size: Size of the Record Batches produced by the reader.
+                If it is not positive, the reader yields no batches.
+
+        Returns:
+            A pyarrow RecordBatchReader for the result set.
+
+        Raises:
+            ImportError: If pyarrow is not installed.
+        """
+        pyarrow = self._require_pyarrow("arrow_reader")
+
+        # Fetch schema without advancing cursor
+        schema = self.arrow_batch(0).schema
+
+        def batch_generator():
+            if batch_size <= 0:
+                return
+            while True:
+                batch = self.arrow_batch(batch_size)
+                if batch.num_rows > 0:
+                    yield batch
+                # A short batch means the result set is exhausted, so stop
+                # here instead of issuing one more fetch just to see zero rows.
+                if batch.num_rows < batch_size:
+                    return
+
+        return pyarrow.RecordBatchReader.from_batches(schema, batch_generator())
+
     def nextset(self) -> Union[bool, None]:
         """
         Skip to the next available result set.
diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 31cdc514..c7ef91cd 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -157,6 +157,83 @@ struct NumericData { } }; +// Struct to hold data buffers and indicators for each column +struct ColumnBuffersArrow { + std::vector> uint8; + std::vector> int16; + std::vector> int32; + std::vector> int64; + std::vector> float64; + std::vector> bit; + std::vector> var; + std::vector> date; + std::vector> ts_micro; + std::vector> time_second; + std::vector> decimal; + + std::vector> valid; + std::vector> var_data; + + ColumnBuffersArrow(SQLSMALLINT numCols) + : + uint8(numCols), + int16(numCols), + int32(numCols), + int64(numCols), + float64(numCols), + bit(numCols), + var(numCols), + date(numCols), + ts_micro(numCols), + time_second(numCols), + decimal(numCols), + + valid(numCols), + var_data(numCols) {} +}; + +#ifndef ARROW_C_DATA_INTERFACE +#define ARROW_C_DATA_INTERFACE + +#define ARROW_FLAG_DICTIONARY_ORDERED 1 +#define ARROW_FLAG_NULLABLE 2 +#define ARROW_FLAG_MAP_KEYS_SORTED 4 + +struct ArrowSchema { + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + void* private_data; +}; + +struct ArrowArray { + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + void* private_data; +}; + +#endif // ARROW_C_DATA_INTERFACE + //------------------------------------------------------------------------------------------------- // Function pointer 
initialization //------------------------------------------------------------------------------------------------- @@ -3926,6 +4003,956 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch return ret; } +// GetDataVar - Progressively fetches variable-length column data using SQLGetData. +// +// Calls SQLGetData repeatedly, reallocating the buffer as needed, until all data is retrieved. +// Handles both fixed-size and unknown-size (SQL_NO_TOTAL) responses from the driver. +// +// @param hStmt: Statement handle +// @param colNumber: 1-based column index +// @param cType: SQL C data type (SQL_C_CHAR, SQL_C_WCHAR, or SQL_C_BINARY) +// @param dataVec: Reference to vector that will hold the fetched data (will be resized as needed) +// @param indicator: Pointer to indicator value (SQL_NULL_DATA for NULL, or data length) +// +// @return SQLRETURN: SQL_SUCCESS on success, or error code on failure +template +SQLRETURN GetDataVar(SQLHSTMT hStmt, + SQLUSMALLINT colNumber, + SQLSMALLINT cType, + std::vector& dataVec, + SQLLEN* indicator) { + if (!SQLGetData_ptr) { + ThrowStdException("SQLGetData function not loaded"); + } + + size_t start = 0; + size_t end = 0; + + // Determine null terminator size based on data type + size_t sizeNullTerminator = 0; + switch (cType) { + case SQL_C_WCHAR: + case SQL_C_CHAR: + sizeNullTerminator = 1; + break; + case SQL_C_BINARY: + sizeNullTerminator = 0; + break; + default: + ThrowStdException("GetDataVar only supports SQL_C_CHAR, SQL_C_WCHAR, and SQL_C_BINARY"); + } + + // Ensure initial buffer has space for at least the null terminator + if (dataVec.size() < sizeNullTerminator) { + dataVec.resize(sizeNullTerminator); + } + + while (true) { + SQLLEN localInd = 0; + SQLRETURN ret = SQLGetData_ptr( + hStmt, + colNumber, + cType, + reinterpret_cast(dataVec.data() + start), + sizeof(T) * (dataVec.size() - start), // Available buffer size from start position + &localInd + ); + + // Handle NULL data + if (localInd == 
SQL_NULL_DATA) { + *indicator = SQL_NULL_DATA; + return SQL_SUCCESS; + } + + // Check for errors (excluding SQL_SUCCESS_WITH_INFO which means more data available) + if (ret == SQL_ERROR || ret == SQL_INVALID_HANDLE) { + return ret; + } + + // SQL_SUCCESS or SQL_NO_DATA means we got all the data + if (ret == SQL_SUCCESS || ret == SQL_NO_DATA) { + if (localInd >= 0) { + *indicator = static_cast(start) * sizeof(T) + localInd; + } else { + *indicator = localInd; // Preserve SQL_NO_TOTAL or other negative values + } + break; + } + + // SQL_SUCCESS_WITH_INFO means buffer was too small, need to continue fetching + if (ret == SQL_SUCCESS_WITH_INFO) { + // Determine how much more space we need + if (localInd < 0) { + // SQL_NO_TOTAL: driver doesn't know total size, double the buffer + end = dataVec.size() * 2; + } else { + // Driver returned total size: allocate exactly what we need + assert(localInd % sizeof(T) == 0); + end = start + static_cast(localInd) / sizeof(T) + sizeNullTerminator; + } + + // The next read starts where the null terminator would have been placed + start = dataVec.size() - sizeNullTerminator; + + // Resize buffer for next iteration + dataVec.resize(end); + } else { + // Unexpected return code + return ret; + } + } + + return SQL_SUCCESS; +} + +void ArrowSchema_release(struct ArrowSchema* schema) { + assert (schema != nullptr); + assert (schema->release != nullptr); + schema->release = nullptr; + delete[] schema->name; + for (int i = 0; i < schema->n_children; i++) { + assert (schema->children != nullptr); + if (schema->children[i]) { + schema->children[i]->release(schema->children[i]); + delete schema->children[i]; + } + } + delete[] schema->children; + delete[] schema->format; +} + +void ArrowArray_release(struct ArrowArray* array) { + assert (array != nullptr); + assert (array->release != nullptr); + array->release = nullptr; + + uint32_t buffers_freed = 0; + uint32_t current_buffer = 0; + while (buffers_freed < array->n_buffers) { + if 
(array->buffers[current_buffer]) { + free((void*)array->buffers[current_buffer]); + buffers_freed++; + } + current_buffer++; + assert (current_buffer <= 3); + } + delete[] array->buffers; + + for (int i = 0; i < array->n_children; i++) { + assert (array->children != nullptr); + assert (array->children[i] != nullptr); + array->children[i]->release(array->children[i]); + delete array->children[i]; + } + delete[] array->children; + +} + +int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) { + // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) + std::tm tm_date = {}; + tm_date.tm_year = year - 1900; // tm_year is years since 1900 + tm_date.tm_mon = month - 1; // tm_mon is 0-11 + tm_date.tm_mday = day; + + std::time_t time_since_epoch = std::mktime(&tm_date); + if (time_since_epoch == -1) { + LOG("Failed to convert SQL_DATE_STRUCT to time_t"); + ThrowStdException("Date conversion error"); + } + // Calculate days since epoch + return time_since_epoch / 86400; +} + +SQLRETURN FetchArrowBatch_wrap( + SqlHandlePtr StatementHandle, + py::list& capsules, + ssize_t arrowBatchSize +) { + ssize_t fetchSize = arrowBatchSize; + SQLRETURN ret; + SQLHSTMT hStmt = StatementHandle->get(); + // Retrieve column count + SQLSMALLINT numCols = SQLNumResultCols_wrap(StatementHandle); + if (numCols <= 0) { + ThrowStdException("No active result set. 
Cannot fetch Arrow batch."); + } + + // Retrieve column metadata + py::list columnNames; + ret = SQLDescribeCol_wrap(StatementHandle, columnNames); + if (!SQL_SUCCEEDED(ret)) { + LOG("Failed to get column descriptions"); + return ret; + } + + bool hasLobColumns = false; + + std::vector dataTypes(numCols); + std::vector columnSizes(numCols); + std::vector columnNullable(numCols); + std::vector> columnFormats(numCols); + std::vector> columnNamesCStr(numCols); + + ColumnBuffersArrow buffersArrow(numCols); + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto colMeta = columnNames[i].cast(); + SQLSMALLINT dataType = colMeta["DataType"].cast(); + SQLULEN columnSize = colMeta["ColumnSize"].cast(); + SQLSMALLINT nullable = colMeta["Nullable"].cast(); + dataTypes[i] = dataType; + columnSizes[i] = columnSize; + columnNullable[i] = (nullable != SQL_NO_NULLS); + + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || + dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || + dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML) && + (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { + hasLobColumns = true; + if (fetchSize > 1) { + fetchSize = 1; // LOBs require row-by-row fetch + } + } + + std::string columnName = colMeta["ColumnName"].cast(); + size_t nameLen = columnName.length() + 1; + columnNamesCStr[i] = std::make_unique(nameLen); + std::memcpy(columnNamesCStr[i].get(), columnName.c_str(), nameLen); + + const char* format = nullptr; + switch(dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + format = "u"; + buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); + buffersArrow.var_data[i].resize(arrowBatchSize * 42); + // start at offset 0 + buffersArrow.var[i][0] = 0; + break; + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + format = "z"; + 
buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); + buffersArrow.var_data[i].resize(arrowBatchSize * 42); + // start at offset 0 + buffersArrow.var[i][0] = 0; + break; + case SQL_TINYINT: + format = "C"; + buffersArrow.uint8[i] = std::make_unique(arrowBatchSize); + break; + case SQL_SMALLINT: + format = "s"; + buffersArrow.int16[i] = std::make_unique(arrowBatchSize); + break; + case SQL_INTEGER: + format = "i"; + buffersArrow.int32[i] = std::make_unique(arrowBatchSize); + break; + case SQL_BIGINT: + format = "l"; + buffersArrow.int64[i] = std::make_unique(arrowBatchSize); + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + format = "g"; + buffersArrow.float64[i] = std::make_unique(arrowBatchSize); + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + std::ostringstream formatStream; + formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); + std::string formatStr = formatStream.str(); + size_t formatLen = formatStr.length() + 1; + columnFormats[i] = std::make_unique(formatLen); + std::memcpy(columnFormats[i].get(), formatStr.c_str(), formatLen); + format = columnFormats[i].get(); + buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(arrowBatchSize); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: + format = "tsu:"; + buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + break; + case SQL_SS_TIMESTAMPOFFSET: + format = "tsu:+00:00"; + buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + break; + case SQL_TYPE_DATE: + format = "tdD"; + buffersArrow.date[i] = std::make_unique(arrowBatchSize); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + format = "tts"; + buffersArrow.time_second[i] = std::make_unique(arrowBatchSize); + break; + case SQL_BIT: + format = "b"; + buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); + break; + default: + std::wstring columnName = colMeta["ColumnName"].cast(); + std::ostringstream errorString; + 
errorString << "Unsupported data type for Arrow batch fetch for column - " << columnName.c_str() + << ", Type - " << dataType << ", column ID - " << (i + 1); + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + + // Store format string if not already stored (for non-decimal types) + if (!columnFormats[i]) { + size_t formatLen = std::strlen(format) + 1; + columnFormats[i] = std::make_unique(formatLen); + std::memcpy(columnFormats[i].get(), format, formatLen); + } + + buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); + // Initialize validity bitmap to all valid + std::memset(buffersArrow.valid[i].get(), 0xFF, (arrowBatchSize + 7) / 8); + } + + if (fetchSize > 1) { + // An overly large fetch size doesn't seem to help performance + SQLSMALLINT searchStart = 64; + if (arrowBatchSize < 64) { + searchStart = static_cast(arrowBatchSize); + } + for (SQLSMALLINT maybeNewSize = searchStart; maybeNewSize >= 1; maybeNewSize -= 1) { + if (arrowBatchSize % maybeNewSize == 0) { + fetchSize = maybeNewSize; + break; + } + } + } + + // Initialize column buffers + ColumnBuffers buffers(numCols, fetchSize); + + if (!hasLobColumns && fetchSize > 0) { + // Bind columns + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error when binding columns"); + return ret; + } + } + + SQLULEN numRowsFetched; + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); + + + size_t idxRowArrow = 0; + // arrowBatchSize % fetchSize == 0 ensures that any followup (even non-arrow) fetches + // start with a fresh batch + assert(fetchSize == 0 || arrowBatchSize % fetchSize == 0); + assert(fetchSize <= arrowBatchSize); + + while (idxRowArrow < arrowBatchSize) { + ret = SQLFetch_ptr(hStmt); + if (ret == SQL_NO_DATA) { + ret = SQL_SUCCESS; // Normal completion + break; + } + if 
(!SQL_SUCCEEDED(ret)) { + LOG("Error while fetching rows in batches"); + return ret; + } + // numRowsFetched is the SQL_ATTR_ROWS_FETCHED_PTR attribute. + // It'll be populated by SQLFetch + assert(numRowsFetched + idxRowArrow <= static_cast(arrowBatchSize)); + for (SQLULEN idxRowSql = 0; idxRowSql < numRowsFetched; idxRowSql++) { + for (SQLUSMALLINT col = 1; col <= numCols; col++) { + auto dataType = dataTypes[col - 1]; + auto columnSize = columnSizes[col - 1]; + + if (hasLobColumns) { + assert(idxRowSql == 0 && "GetData only works one row at a time"); + + switch(dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + GetDataVar( + hStmt, + col, + SQL_C_BINARY, + buffers.charBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + GetDataVar( + hStmt, + col, + SQL_C_CHAR, + buffers.charBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + GetDataVar( + hStmt, + col, + SQL_C_WCHAR, + buffers.wcharBuffers[col - 1], + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_INTEGER: { + buffers.intBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SLONG, + buffers.intBuffers[col - 1].data(), + sizeof(SQLINTEGER), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SMALLINT: { + buffers.smallIntBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SSHORT, + buffers.smallIntBuffers[col - 1].data(), + sizeof(SQLSMALLINT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TINYINT: { + buffers.charBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TINYINT, + buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_BIT: { + buffers.charBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_BIT, + 
buffers.charBuffers[col - 1].data(), + sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_REAL: { + buffers.realBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_FLOAT, + buffers.realBuffers[col - 1].data(), + sizeof(SQLREAL), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_DECIMAL: + case SQL_NUMERIC: { + buffers.charBuffers[col - 1].resize(MAX_DIGITS_IN_NUMERIC); + SQLGetData_ptr( + hStmt, col, SQL_C_CHAR, + buffers.charBuffers[col - 1].data(), + MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_DOUBLE: + case SQL_FLOAT: { + buffers.doubleBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_DOUBLE, + buffers.doubleBuffers[col - 1].data(), + sizeof(SQLDOUBLE), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + buffers.timestampBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[col - 1].data(), + sizeof(SQL_TIMESTAMP_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_BIGINT: { + buffers.bigIntBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SBIGINT, + buffers.bigIntBuffers[col - 1].data(), + sizeof(SQLBIGINT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TYPE_DATE: { + buffers.dateBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_DATE, + buffers.dateBuffers[col - 1].data(), + sizeof(SQL_DATE_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + buffers.timeBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_TYPE_TIME, + buffers.timeBuffers[col - 1].data(), + sizeof(SQL_TIME_STRUCT), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_GUID: { + buffers.guidBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, 
SQL_C_GUID, + buffers.guidBuffers[col - 1].data(), + sizeof(SQLGUID), + buffers.indicators[col - 1].data() + ); + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + buffers.datetimeoffsetBuffers[col - 1].resize(1); + SQLGetData_ptr( + hStmt, col, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[col - 1].data(), + sizeof(DateTimeOffset), + buffers.indicators[col - 1].data() + ); + break; + } + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << col + << ", Type - " << dataType; + LOG("SQLGetData: %s", errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + } + + SQLLEN dataLen = buffers.indicators[col - 1][idxRowSql]; + + if (dataLen == SQL_NULL_DATA) { + // Mark as null in validity bitmap + size_t bytePos = idxRowArrow / 8; + size_t bitPos = idxRowArrow % 8; + buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + + // Value buffer for variable length data types needs to be set appropriately + // as it will be used by the next non null value + switch (dataType) + { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: + buffersArrow.var[col - 1][idxRowArrow + 1] = buffersArrow.var[col - 1][idxRowArrow]; + break; + default: + break; + } + continue; + } else if (dataLen < 0) { + // Negative value is unexpected, log column index, SQL type & raise exception + LOG("Unexpected negative data length. 
Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); + ThrowStdException("Unexpected negative data length."); + } + + switch (dataType) { + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + break; + } + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: { + uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + while (target_vec->size() < start + dataLen) { + target_vec->resize(target_vec->size() * 2); + } + + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + break; + } + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: { + assert(dataLen % sizeof(SQLWCHAR) == 0); + auto dataLenW = dataLen / sizeof(SQLWCHAR); + auto wcharSource = &buffers.wcharBuffers[col - 1][idxRowSql * (columnSize + 1)]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &buffersArrow.var_data[col - 1]; +#if defined(_WIN32) + // Convert wide string + int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); + while (target_vec->size() < start + dataLenConverted) { + target_vec->resize(target_vec->size() * 2); + } + WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); + buffersArrow.var[col - 1][idxRowArrow + 1] 
= start + dataLenConverted; +#else + // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 + std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); + std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); + buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); +#endif + break; + } + case SQL_GUID: { + // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") + // Each GUID is exactly 36 bytes in UTF-8 + auto target_vec = &buffersArrow.var_data[col - 1]; + auto start = buffersArrow.var[col - 1][idxRowArrow]; + + // Ensure buffer has space for the GUID string + null terminator + while (target_vec->size() < start + 37) { + target_vec->resize(target_vec->size() * 2); + } + + // Get the GUID from the buffer + const SQLGUID& guidValue = buffers.guidBuffers[col - 1][idxRowSql]; + + // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx + snprintf(reinterpret_cast(&target_vec->data()[start]), 37, + "%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", + guidValue.Data1, + guidValue.Data2, + guidValue.Data3, + guidValue.Data4[0], guidValue.Data4[1], + guidValue.Data4[2], guidValue.Data4[3], + guidValue.Data4[4], guidValue.Data4[5], + guidValue.Data4[6], guidValue.Data4[7]); + + // Update offset for next row, ignoring null terminator + buffersArrow.var[col - 1][idxRowArrow + 1] = start + 36; + break; + } + case SQL_TINYINT: + buffersArrow.uint8[col - 1][idxRowArrow] = buffers.charBuffers[col - 1][idxRowSql]; + break; + case SQL_SMALLINT: + buffersArrow.int16[col - 1][idxRowArrow] = buffers.smallIntBuffers[col - 1][idxRowSql]; + break; + case SQL_INTEGER: + buffersArrow.int32[col - 1][idxRowArrow] = buffers.intBuffers[col - 1][idxRowSql]; + break; + case SQL_BIGINT: + buffersArrow.int64[col - 1][idxRowArrow] = buffers.bigIntBuffers[col - 1][idxRowSql]; + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + buffersArrow.float64[col - 
1][idxRowArrow] = buffers.doubleBuffers[col - 1][idxRowSql]; + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + assert(dataLen <= MAX_DIGITS_IN_NUMERIC); + __int128_t decimalValue = 0; + auto start = idxRowSql * MAX_DIGITS_IN_NUMERIC; + int sign = 1; + for (SQLULEN idx = start; idx < start + dataLen; idx++) { + char digitChar = buffers.charBuffers[col - 1][idx]; + if (digitChar == '-') { + sign = -1; + } else if (digitChar >= '0' && digitChar <= '9') { + decimalValue = decimalValue * 10 + (digitChar - '0'); + } + } + buffersArrow.decimal[col - 1][idxRowArrow] = decimalValue * sign; + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: { + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][idxRowArrow] = + days * 86400 * 1000000 + + static_cast(sql_value.hour) * 3600 * 1000000 + + static_cast(sql_value.minute) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_SS_TIMESTAMPOFFSET: { + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][idxRowSql]; + int64_t days = dateAsDayCount( + sql_value.year, + sql_value.month, + sql_value.day + ); + buffersArrow.ts_micro[col - 1][idxRowArrow] = + days * 86400 * 1000000 + + (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + + static_cast(sql_value.second) * 1000000 + + static_cast(sql_value.fraction) / 1000; + break; + } + case SQL_TYPE_DATE: + buffersArrow.date[col - 1][idxRowArrow] = dateAsDayCount( + buffers.dateBuffers[col - 1][idxRowSql].year, + buffers.dateBuffers[col - 1][idxRowSql].month, + buffers.dateBuffers[col - 1][idxRowSql].day + ); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: { + // NOTE: 
SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does not. + // To fully support SQL_SS_TIME2, the corresponding c-type should be used. + const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][idxRowSql]; + buffersArrow.time_second[col - 1][idxRowArrow] = + static_cast(timeValue.hour) * 3600 + + static_cast(timeValue.minute) * 60 + + static_cast(timeValue.second); + break; + } + case SQL_BIT: { + // SQL_BIT is stored as a single bit in Arrow's bitmap format + // Get the boolean value from the buffer + bool bitValue = buffers.charBuffers[col - 1][idxRowSql] != 0; + + // Set the bit in the Arrow bitmap + size_t byteIndex = idxRowArrow / 8; + size_t bitIndex = idxRowArrow % 8; + + if (bitValue) { + // Set bit to 1 + buffersArrow.bit[col - 1][byteIndex] |= (1 << bitIndex); + } else { + // Clear bit to 0 + buffersArrow.bit[col - 1][byteIndex] &= ~(1 << bitIndex); + } + break; + } + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << col + << ", Type - " << dataType; + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + } + idxRowArrow++; + } + } + + // Reset attributes before returning to avoid using stack pointers later + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); + SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); + + // Transfer ownerhip of buffers to Arrow structures + // Exceptions beyond this point would cause memory leaks + auto batch_children = new ArrowSchema* [numCols]; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto arrow_schema = new ArrowSchema({ + .format = columnFormats[i].release(), + .name = columnNamesCStr[i].release(), + .flags = columnNullable[i] ? 
2 : 0, // ARROW_FLAG_NULLABLE + .release = ArrowSchema_release, + }); + batch_children[i] = arrow_schema; + } + + auto arrow_schema_batch = new ArrowSchema({ + .format = strdup("+s"), + .name = strdup(""), + .n_children = numCols, + .children = batch_children, + .release = ArrowSchema_release, + }); + auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { + auto arrow_schema = static_cast(ptr); + if (arrow_schema->release) { + arrow_schema->release(arrow_schema); + } + delete arrow_schema; + }); + capsules.append(caps); + + auto arrow_array_batch_buffers = new const void* [3]; + memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); + auto arrow_array_batch = new ArrowArray({ + .length = static_cast(idxRowArrow), + .n_buffers = 1, + .n_children = numCols, + .buffers = arrow_array_batch_buffers, + .children = new ArrowArray* [numCols], + .release = ArrowArray_release, + }); + // Necessary dummy buffer + arrow_array_batch->buffers[1] = new int[1]; + + for (SQLUSMALLINT col = 0; col < numCols; col++) { + auto dataType = dataTypes[col]; + auto arrow_array_col_buffers = new const void* [3]; + memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); + // Allocate new memory and copy the data + switch (dataType) { + case SQL_CHAR: + case SQL_VARCHAR: + case SQL_LONGVARCHAR: + case SQL_SS_XML: + case SQL_WCHAR: + case SQL_WVARCHAR: + case SQL_WLONGVARCHAR: + case SQL_GUID: + case SQL_BINARY: + case SQL_VARBINARY: + case SQL_LONGVARBINARY: { + assert(buffersArrow.var[col][0] == 0); + // length of string at index i is the difference between values at i and i+1 + // so total length is value at index idxRowArrow + auto data_buf_len_total = buffersArrow.var[col][idxRowArrow]; + uint8_t* dataBuffer = new uint8_t[data_buf_len_total]; + std::memcpy(dataBuffer, buffersArrow.var_data[col].data(), data_buf_len_total); + arrow_array_col_buffers[2] = dataBuffer; + arrow_array_col_buffers[1] = buffersArrow.var[col].release(); + } + break; + case 
SQL_TINYINT: + arrow_array_col_buffers[1] = buffersArrow.uint8[col].release(); + break; + case SQL_SMALLINT: + arrow_array_col_buffers[1] = buffersArrow.int16[col].release(); + break; + case SQL_INTEGER: + arrow_array_col_buffers[1] = buffersArrow.int32[col].release(); + break; + case SQL_BIGINT: + arrow_array_col_buffers[1] = buffersArrow.int64[col].release(); + break; + case SQL_REAL: + case SQL_FLOAT: + case SQL_DOUBLE: + arrow_array_col_buffers[1] = buffersArrow.float64[col].release(); + break; + case SQL_DECIMAL: + case SQL_NUMERIC: { + arrow_array_col_buffers[1] = buffersArrow.decimal[col].release(); + break; + } + case SQL_TIMESTAMP: + case SQL_TYPE_TIMESTAMP: + case SQL_DATETIME: + arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + break; + case SQL_SS_TIMESTAMPOFFSET: + arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + break; + case SQL_TYPE_DATE: + arrow_array_col_buffers[1] = buffersArrow.date[col].release(); + break; + case SQL_TIME: + case SQL_TYPE_TIME: + case SQL_SS_TIME2: + arrow_array_col_buffers[1] = buffersArrow.time_second[col].release(); + break; + case SQL_BIT: + arrow_array_col_buffers[1] = buffersArrow.bit[col].release(); + break; + default: { + std::ostringstream errorString; + errorString << "Unsupported data type for column ID - " << (col + 1) + << ", Type - " << dataType; + LOG(errorString.str().c_str()); + ThrowStdException(errorString.str()); + break; + } + } + + auto arrow_array_col = new ArrowArray({ + .length = static_cast(idxRowArrow), + .null_count = 0, + .offset = 0, + .n_buffers = arrow_array_col_buffers[2] ? 
3 : 2, + .n_children = 0, + .buffers = arrow_array_col_buffers, + .children = nullptr, + .release = ArrowArray_release, + }); + + arrow_array_col->buffers[0] = buffersArrow.valid[col].release(); + arrow_array_batch->children[col] = arrow_array_col; + } + + capsules.append(py::capsule((void*)arrow_array_batch, "arrow_array", [](void* ptr) { + auto arrow_array = static_cast(ptr); + if (arrow_array->release) { + arrow_array->release(arrow_array); + } + delete arrow_array; + })); + + return ret; +} + + // FetchAll_wrap - Fetches all rows of data from the result set. // // @param StatementHandle: Handle to the statement from which data is to be @@ -4232,6 +5259,7 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchMany", &FetchMany_wrap, py::arg("StatementHandle"), py::arg("rows"), py::arg("fetchSize") = 1, "Fetch many rows from the result set"); m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set"); + m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, "Fetch an arrow batch of given length from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, diff --git a/requirements.txt b/requirements.txt index 0951f7d0..4cd60771 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ pytest-cov coverage unittest-xml-reporting psutil +pyarrow # Build dependencies pybind11 diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index a37b2b6a..5f3cbdd3 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -18,6 +18,11 @@ import re from conftest import is_azure_sql_connection +try: + import pyarrow as pa +except ImportError: + pa = None + # Setup test table TEST_TABLE = """ @@ -14764,3 +14769,244 @@ def test_close(db_connection): pytest.fail(f"Cursor close test failed: {e}") finally: cursor = db_connection.cursor() + + +def 
get_arrow_test_data(include_lobs: bool, batch_length: int): + arrow_test_data = [ + (pa.uint8(), "tinyint", [1, 2, None, 4, 5, 0, 2**8 - 1]), + (pa.int16(), "smallint", [1, 2, None, 4, 5, -(2**15), 2**15 - 1]), + (pa.int32(), "int", [1, 2, None, 4, 5, 0, -(2**31), 2**31 - 1]), + (pa.int64(), "bigint", [1, 2, None, 4, 5, 0, -(2**63), 2**63 - 1]), + (pa.float64(), "float", [1.0, 2.5, None, 4.25, 5.125]), + ( + pa.decimal128(precision=10, scale=2), + "decimal(10, 2)", + [ + decimal.Decimal("1.23"), + None, + decimal.Decimal("0.25"), + decimal.Decimal("-99999999.99"), + decimal.Decimal("99999999.99"), + ], + ), + ( + pa.decimal128(precision=38, scale=10), + "decimal(38, 10)", + [ + decimal.Decimal("1.1234567890"), + None, + decimal.Decimal("0"), + decimal.Decimal("1.0000000001"), + decimal.Decimal("-9999999999999999999999999999.9999999999"), + decimal.Decimal("9999999999999999999999999999.9999999999"), + ], + ), + (pa.bool_(), "bit", [True, None, False]), + (pa.binary(), "binary(9)", [b"asdfghjkl", None, b"lkjhgfdsa"]), + (pa.string(), "varchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), + (pa.string(), "nvarchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), + (pa.date32(), "date", [date(1, 1, 1), None, date(2345, 12, 31), date(9999, 12, 31)]), + ( + pa.time32("s"), + "time(0)", + [time(12, 0, 5, 0), None, time(23, 59, 59, 0), time(0, 0, 0, 0)], + ), + ( + pa.time32("s"), + "time(7)", + [time(12, 0, 5, 0), None, time(23, 59, 59, 0), time(0, 0, 0, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(0)", + [datetime(2025, 1, 1, 12, 0, 5, 0), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(3)", + [datetime(2025, 1, 1, 12, 0, 5, 123_000), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(6)", + [datetime(2025, 1, 1, 12, 0, 5, 123_456), None, datetime(2345, 12, 31, 23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(7)", + [datetime(2025, 1, 1, 12, 0, 5, 123_456), None, datetime(2145, 12, 31, 
23, 59, 59, 0)], + ), + ( + pa.timestamp("us"), + "datetime2(2)", + [datetime(2025, 1, 1, 12, 0, 5, 0), None, datetime(2145, 12, 31, 23, 59, 59, 0)], + ), + ] + + if include_lobs: + arrow_test_data += [ + (pa.string(), "nvarchar(max)", ["hey", None, "ho"]), + (pa.string(), "varchar(max)", ["hey", None, "ho"]), + (pa.binary(), "varbinary(max)", [b"hey", None, b"ho"]), + ] + + for ix in range(len(arrow_test_data)): + while True: + T, sql_type, vals = arrow_test_data[ix] + if len(vals) >= batch_length: + arrow_test_data[ix] = (T, sql_type, vals[:batch_length]) + break + arrow_test_data[ix] = (T, sql_type, vals + vals) + + return arrow_test_data + + +def _test_arrow_test_data(cursor: mssql_python.Cursor, arrow_test_data, fetch_length=500): + cols = [] + for i_col, (pa_type, sql_type, values) in enumerate(arrow_test_data): + rows = [] + for value in values: + if type(value) is bool: + value = int(value) + if type(value) is bytes: + value = value.decode() + if value is None: + value = "null" + else: + value = f"'{value}'" + rows.append(f"col_{i_col} = cast({value} as {sql_type})") + cols.append(rows) + + selects = [] + for row in zip(*cols): + selects.append(f"select {', '.join(col for col in row)}") + full_query = "\nunion all\n".join(selects) + ret = cursor.execute(full_query).arrow_batch(fetch_length) + for i_col, col in enumerate(ret): + for i_row, (v_expected, v_actual) in enumerate( + zip(arrow_test_data[i_col][2][:fetch_length], col.to_pylist(), strict=True) + ): + assert ( + v_expected == v_actual + ), f"Mismatch in column {i_col}, row {i_row}: expected {v_expected}, got {v_actual}" + for i_col, (pa_type, sql_type, values) in enumerate(arrow_test_data): + field = ret.schema.field(i_col) + assert ( + field.name == f"col_{i_col}" + ), f"Column {i_col} name mismatch: expected col_{i_col}, got {field.name}" + assert field.type.equals( + pa_type + ), f"Column {i_col} type mismatch: expected {pa_type}, got {field.type}" + + +@pytest.mark.skipif(pa is None, 
reason="pyarrow is not installed") +def test_arrow_lob_wide(cursor: mssql_python.Cursor): + "Take the SQLGetData branch for a wide table." + arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + _test_arrow_test_data(cursor, arrow_test_data) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_nolob_wide(cursor: mssql_python.Cursor): + "Test the SQLBindData branch for a wide table." + arrow_test_data = get_arrow_test_data(include_lobs=False, batch_length=123) + _test_arrow_test_data(cursor, arrow_test_data) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_single_column(cursor: mssql_python.Cursor): + "Test each datatype as a single column fetch." + arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + for col_data in arrow_test_data: + _test_arrow_test_data(cursor, [col_data]) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_empty_fetch(cursor: mssql_python.Cursor): + "Test each datatype as a single column fetch of length 0." + arrow_test_data = get_arrow_test_data(include_lobs=True, batch_length=123) + for col_data in arrow_test_data: + _test_arrow_test_data(cursor, [col_data], fetch_length=0) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_table_batchsize_negative(cursor: mssql_python.Cursor): + tbl = cursor.execute("select 1 a").arrow(batch_size=-42) + assert type(tbl) is pa.Table + assert tbl.num_rows == 0 + assert tbl.num_columns == 1 + assert cursor.fetchone()[0] == 1 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_empty_result_set(cursor: mssql_python.Cursor): + "Test fetching from an empty result set." 
+ cursor.execute("select 1 where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 1 + cursor.execute("select cast(N'' as nvarchar(max)) where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 1 + cursor.execute("select 1, cast(N'' as nvarchar(max)) where 1 = 0") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 0 + assert batch.num_columns == 2 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_no_result_set(cursor: mssql_python.Cursor): + "Test fetching when there is no result set." + cursor.execute("declare @a int") + with pytest.raises(Exception, match=".*No active result set.*"): + cursor.arrow_batch(10) + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_datetimeoffset(cursor: mssql_python.Cursor): + "Datetimeoffset converts correctly to utc" + cursor.execute( + "declare @dt datetimeoffset(0) = '2345-02-03 12:34:56 +00:00';\n" + "select @dt, @dt at time zone 'Pacific Standard Time';\n" + ) + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 2 + for col in batch.columns: + assert pa.types.is_timestamp(col.type) + assert col.type.tz == "+00:00", col.type.tz + assert col.to_pylist() == [ + datetime(2345, 2, 3, 12, 34, 56, tzinfo=timezone.utc), + ] + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_schema_nullable(cursor: mssql_python.Cursor): + "Test that the schema is nullable." 
+ cursor.execute("select 1 a, null b") + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 2 + assert not batch.schema.field(0).nullable + assert batch.schema.field(1).nullable + assert batch.schema.field(0).name == "a" + assert batch.schema.field(1).name == "b" + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_table(cursor: mssql_python.Cursor): + tbl = cursor.execute("select top 11 1 a from sys.objects").arrow(batch_size=5) + assert type(tbl) is pa.Table + assert tbl.num_rows == 11 + assert tbl.num_columns == 1 + assert [len(b) for b in tbl.to_batches()] == [5, 5, 1] + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_reader(cursor: mssql_python.Cursor): + reader = cursor.execute("select top 11 1 a from sys.objects").arrow_reader(batch_size=4) + assert type(reader) is pa.RecordBatchReader + batches = list(reader) + assert [len(b) for b in batches] == [4, 4, 3] + assert sum(len(b) for b in batches) == 11 From 43284523e80920cac14a8a1bf9d16ab5ae201a1c Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:32:04 +0100 Subject: [PATCH 02/11] Copilot suggestion: Fix typo Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mssql_python/pybind/ddbc_bindings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index c7ef91cd..e561ac79 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4806,7 +4806,7 @@ SQLRETURN FetchArrowBatch_wrap( SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); - // Transfer ownerhip of buffers to Arrow structures + // Transfer ownership of buffers to Arrow structures // Exceptions beyond this point would cause memory leaks auto 
batch_children = new ArrowSchema* [numCols]; for (SQLSMALLINT i = 0; i < numCols; i++) { From 16341995c92bddb31359fb882d27a4203ac13b86 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:34:37 +0100 Subject: [PATCH 03/11] Copilot suggestion: Fix missing buffer resize Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mssql_python/pybind/ddbc_bindings.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index e561ac79..caf3bcaf 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4651,6 +4651,9 @@ SQLRETURN FetchArrowBatch_wrap( #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); + while (target_vec->size() < start + utf8str.size()) { + target_vec->resize(target_vec->size() * 2); + } std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); #endif From eb08a933bc610bde74210d4d34c1440ec1f6f71d Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:35:57 +0100 Subject: [PATCH 04/11] Copilot suggestion: Initialize bool value buffer Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mssql_python/pybind/ddbc_bindings.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index caf3bcaf..bc654754 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4301,6 +4301,7 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BIT: format = "b"; buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); + std::memset(buffersArrow.bit[i].get(), 0, (arrowBatchSize + 7) / 8); break; default: std::wstring 
columnName = colMeta["ColumnName"].cast(); From c61c6bb6a5949a7a14ac5e2082e0e4d8bacdfe5a Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:39:48 +0100 Subject: [PATCH 05/11] Add test for long data --- tests/test_004_cursor.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 5f3cbdd3..1b02cd77 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -15010,3 +15010,14 @@ def test_arrow_reader(cursor: mssql_python.Cursor): batches = list(reader) assert [len(b) for b in batches] == [4, 4, 3] assert sum(len(b) for b in batches) == 11 + + +@pytest.mark.skipif(pa is None, reason="pyarrow is not installed") +def test_arrow_long_string(cursor: mssql_python.Cursor): + "Make sure resizing the data buffer works" + long_string = "A" * 100000 # 100k characters + cursor.execute("select cast(? as nvarchar(max))", (long_string,)) + batch = cursor.arrow_batch(10) + assert batch.num_rows == 1 + assert batch.num_columns == 1 + assert batch.column(0).to_pylist() == [long_string] From 790d94d1b94fc97dc7ba28da47e2af658ecbb607 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:45:32 +0100 Subject: [PATCH 06/11] Copilot suggestion: Uppercase uuids --- mssql_python/pybind/ddbc_bindings.cpp | 2 +- tests/test_004_cursor.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index bc654754..9ea98b84 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4676,7 +4676,7 @@ SQLRETURN FetchArrowBatch_wrap( // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx snprintf(reinterpret_cast(&target_vec->data()[start]), 37, - "%08x-%04x-%04x-%02x%02x-%02x%02x%02x%02x%02x%02x", + "%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X", guidValue.Data1, 
guidValue.Data2, guidValue.Data3, diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 1b02cd77..47981817 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -14805,6 +14805,7 @@ def get_arrow_test_data(include_lobs: bool, batch_length: int): (pa.binary(), "binary(9)", [b"asdfghjkl", None, b"lkjhgfdsa"]), (pa.string(), "varchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), (pa.string(), "nvarchar(100)", ["asdfghjkl", None, "lkjhgfdsa"]), + (pa.string(), "uniqueidentifier", ["58185E0D-3A91-44D8-BC46-7107217E0A6D", None]), (pa.date32(), "date", [date(1, 1, 1), None, date(2345, 12, 31), date(9999, 12, 31)]), ( pa.time32("s"), From 322c9a71a53c2e0c870f19c51c374cee71e6d9d8 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sun, 30 Nov 2025 22:52:35 +0100 Subject: [PATCH 07/11] Copilot suggestion: use new for batch schema format/name Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- mssql_python/pybind/ddbc_bindings.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 9ea98b84..dc9e3aea 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4824,8 +4824,8 @@ SQLRETURN FetchArrowBatch_wrap( } auto arrow_schema_batch = new ArrowSchema({ - .format = strdup("+s"), - .name = strdup(""), + .format = []{ char* f = new char[3]; std::strcpy(f, "+s"); return f; }(), + .name = []{ char* n = new char[1]; n[0] = '\0'; return n; }(), .n_children = numCols, .children = batch_children, .release = ArrowSchema_release, From e274d1a4fd1133a12f6a69222456ec9187ff375f Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Tue, 2 Dec 2025 00:46:30 +0100 Subject: [PATCH 08/11] Replace free calls in release callbacks with unique pointers tracked by private_data --- mssql_python/pybind/ddbc_bindings.cpp | 231 
+++++++++++++++++--------- 1 file changed, 151 insertions(+), 80 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index dc9e3aea..4a46efd9 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -192,6 +192,28 @@ struct ColumnBuffersArrow { var_data(numCols) {} }; +struct ArrowArrayPrivateData { + std::unique_ptr buffer_uint8; + std::unique_ptr buffer_int16; + std::unique_ptr buffer_int32; + std::unique_ptr buffer_int64; + std::unique_ptr buffer_float64; + std::unique_ptr buffer_bit; + std::unique_ptr buffer_var; + std::unique_ptr buffer_date; + std::unique_ptr buffer_ts_micro; + std::unique_ptr buffer_time_second; + std::unique_ptr<__int128_t[]> buffer_decimal; + + std::unique_ptr buffer_valid; + std::unique_ptr buffer_var_data; +}; + +struct ArrowSchemaPrivateData { + std::unique_ptr name; + std::unique_ptr format; +}; + #ifndef ARROW_C_DATA_INTERFACE #define ARROW_C_DATA_INTERFACE @@ -212,7 +234,8 @@ struct ArrowSchema { // Release callback void (*release)(struct ArrowSchema*); // Opaque producer-specific data - void* private_data; + // Only our child-arrays will set this, so we can give it the correct type + ArrowSchemaPrivateData* private_data; }; struct ArrowArray { @@ -229,7 +252,8 @@ struct ArrowArray { // Release callback void (*release)(struct ArrowArray*); // Opaque producer-specific data - void* private_data; + // Only our child-arrays will set this, so we can give it the correct type + ArrowArrayPrivateData* private_data; }; #endif // ARROW_C_DATA_INTERFACE @@ -4021,10 +4045,6 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, SQLSMALLINT cType, std::vector& dataVec, SQLLEN* indicator) { - if (!SQLGetData_ptr) { - ThrowStdException("SQLGetData function not loaded"); - } - size_t start = 0; size_t end = 0; @@ -4105,49 +4125,6 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, return SQL_SUCCESS; } -void ArrowSchema_release(struct ArrowSchema* schema) { - assert (schema != 
nullptr); - assert (schema->release != nullptr); - schema->release = nullptr; - delete[] schema->name; - for (int i = 0; i < schema->n_children; i++) { - assert (schema->children != nullptr); - if (schema->children[i]) { - schema->children[i]->release(schema->children[i]); - delete schema->children[i]; - } - } - delete[] schema->children; - delete[] schema->format; -} - -void ArrowArray_release(struct ArrowArray* array) { - assert (array != nullptr); - assert (array->release != nullptr); - array->release = nullptr; - - uint32_t buffers_freed = 0; - uint32_t current_buffer = 0; - while (buffers_freed < array->n_buffers) { - if (array->buffers[current_buffer]) { - free((void*)array->buffers[current_buffer]); - buffers_freed++; - } - current_buffer++; - assert (current_buffer <= 3); - } - delete[] array->buffers; - - for (int i = 0; i < array->n_children; i++) { - assert (array->children != nullptr); - assert (array->children[i] != nullptr); - array->children[i]->release(array->children[i]); - delete array->children[i]; - } - delete[] array->children; - -} - int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) { // Convert SQL_DATE_STRUCT to Arrow Date32 (days since epoch) std::tm tm_date = {}; @@ -4160,6 +4137,8 @@ int32_t dateAsDayCount(SQLUSMALLINT year, SQLUSMALLINT month, SQLUSMALLINT day) LOG("Failed to convert SQL_DATE_STRUCT to time_t"); ThrowStdException("Date conversion error"); } + // Sanity check against timezone issues. 
Since we only provide the date, this has to be true + assert(time_since_epoch % 86400 == 0); // Calculate days since epoch return time_since_epoch / 86400; } @@ -4219,7 +4198,7 @@ SQLRETURN FetchArrowBatch_wrap( columnNamesCStr[i] = std::make_unique(nameLen); std::memcpy(columnNamesCStr[i].get(), columnName.c_str(), nameLen); - const char* format = nullptr; + std::string format = ""; switch(dataType) { case SQL_CHAR: case SQL_VARCHAR: @@ -4315,9 +4294,9 @@ SQLRETURN FetchArrowBatch_wrap( // Store format string if not already stored (for non-decimal types) if (!columnFormats[i]) { - size_t formatLen = std::strlen(format) + 1; + size_t formatLen = format.length() + 1; columnFormats[i] = std::make_unique(formatLen); - std::memcpy(columnFormats[i].get(), format, formatLen); + std::memcpy(columnFormats[i].get(), format.c_str(), formatLen); } buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); @@ -4812,24 +4791,64 @@ SQLRETURN FetchArrowBatch_wrap( // Transfer ownership of buffers to Arrow structures // Exceptions beyond this point would cause memory leaks - auto batch_children = new ArrowSchema* [numCols]; + + auto batch_children = new ArrowSchema*[numCols]; + for (SQLSMALLINT i = 0; i < numCols; i++) { + auto col_private_data = new ArrowSchemaPrivateData(); + col_private_data->format = std::move(columnFormats[i]); + col_private_data->name = std::move(columnNamesCStr[i]); + auto arrow_schema = new ArrowSchema({ - .format = columnFormats[i].release(), - .name = columnNamesCStr[i].release(), - .flags = columnNullable[i] ? 2 : 0, // ARROW_FLAG_NULLABLE - .release = ArrowSchema_release, + .format = col_private_data->format.get(), + .name = col_private_data->name.get(), + .metadata = nullptr, + .flags = static_cast(columnNullable[i] ? 
ARROW_FLAG_NULLABLE : 0), + .n_children = 0, + .children = nullptr, + .dictionary = nullptr, + .release = [](ArrowSchema* schema) { + assert(schema != nullptr); + assert(schema->release != nullptr); + assert(schema->private_data != nullptr); + assert(schema->children == nullptr && schema->n_children == 0); + delete schema->private_data; // Frees format and name + schema->release = nullptr; + }, + .private_data = col_private_data, }); batch_children[i] = arrow_schema; } auto arrow_schema_batch = new ArrowSchema({ - .format = []{ char* f = new char[3]; std::strcpy(f, "+s"); return f; }(), - .name = []{ char* n = new char[1]; n[0] = '\0'; return n; }(), + .format = "+s", + .name = "", + .metadata = nullptr, + .flags = 0, .n_children = numCols, .children = batch_children, - .release = ArrowSchema_release, + .dictionary = nullptr, + .release = [](ArrowSchema* schema) { + // format and name are string literals, no need to free + assert(schema != nullptr); + assert(schema->release != nullptr); + assert(schema->private_data == nullptr); + assert(schema->children != nullptr); + assert(schema->n_children > 0); + for (int64_t i = 0; i < schema->n_children; ++i) { + if (schema->children[i]) { + if (schema->children[i]->release) { + schema->children[i]->release(schema->children[i]); + } + delete schema->children[i]; + } + } + delete[] schema->children; + schema->release = nullptr; + }, + .private_data = nullptr, }); + auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { auto arrow_schema = static_cast(ptr); if (arrow_schema->release) { @@ -4841,21 +4860,48 @@ SQLRETURN FetchArrowBatch_wrap( auto arrow_array_batch_buffers = new const void* [3]; memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); + // Necessary dummy buffer, pyarrow will error without it + arrow_array_batch_buffers[1] = new uint8_t[1]{0}; auto arrow_array_batch = new ArrowArray({ .length = static_cast(idxRowArrow), + // only the non null dummy buffer counts .n_buffers = 1, 
.n_children = numCols, .buffers = arrow_array_batch_buffers, .children = new ArrowArray* [numCols], - .release = ArrowArray_release, + .release = [](ArrowArray* array) { + assert(array != nullptr); + assert(array->private_data == nullptr); + assert(array->release != nullptr); + assert(array->children != nullptr); + assert(array->n_children > 0); + for (int64_t i = 0; i < array->n_children; ++i) { + if (array->children[i]) { + if (array->children[i]->release) { + array->children[i]->release(array->children[i]); + } + delete array->children[i]; + } + } + delete[] array->children; + assert(array->buffers != nullptr); + assert(array->n_buffers == 1); + assert(array->buffers[0] == nullptr); + assert(array->buffers[1] != nullptr); + assert(array->buffers[2] == nullptr); + // Delete dummy buffer + delete[] const_cast(static_cast(array->buffers[1])); + + delete[] array->buffers; + array->release = nullptr; + }, }); - // Necessary dummy buffer - arrow_array_batch->buffers[1] = new int[1]; for (SQLUSMALLINT col = 0; col < numCols; col++) { auto dataType = dataTypes[col]; auto arrow_array_col_buffers = new const void* [3]; memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); + auto private_data = new ArrowArrayPrivateData(); // Allocate new memory and copy the data switch (dataType) { case SQL_CHAR: @@ -4873,52 +4919,65 @@ SQLRETURN FetchArrowBatch_wrap( // length of string at index i is the difference between values at i and i+1 // so total length is value at index idxRowArrow auto data_buf_len_total = buffersArrow.var[col][idxRowArrow]; - uint8_t* dataBuffer = new uint8_t[data_buf_len_total]; - std::memcpy(dataBuffer, buffersArrow.var_data[col].data(), data_buf_len_total); - arrow_array_col_buffers[2] = dataBuffer; - arrow_array_col_buffers[1] = buffersArrow.var[col].release(); + auto dataBuffer = std::make_unique(data_buf_len_total); + std::memcpy(dataBuffer.get(), buffersArrow.var_data[col].data(), data_buf_len_total); + private_data->buffer_var_data = 
std::move(dataBuffer); + arrow_array_col_buffers[2] = private_data->buffer_var_data.get(); + private_data->buffer_var = std::move(buffersArrow.var[col]); + arrow_array_col_buffers[1] = private_data->buffer_var.get(); } break; case SQL_TINYINT: - arrow_array_col_buffers[1] = buffersArrow.uint8[col].release(); + private_data->buffer_uint8 = std::move(buffersArrow.uint8[col]); + arrow_array_col_buffers[1] = private_data->buffer_uint8.get(); break; case SQL_SMALLINT: - arrow_array_col_buffers[1] = buffersArrow.int16[col].release(); + private_data->buffer_int16 = std::move(buffersArrow.int16[col]); + arrow_array_col_buffers[1] = private_data->buffer_int16.get(); break; case SQL_INTEGER: - arrow_array_col_buffers[1] = buffersArrow.int32[col].release(); + private_data->buffer_int32 = std::move(buffersArrow.int32[col]); + arrow_array_col_buffers[1] = private_data->buffer_int32.get(); break; case SQL_BIGINT: - arrow_array_col_buffers[1] = buffersArrow.int64[col].release(); + private_data->buffer_int64 = std::move(buffersArrow.int64[col]); + arrow_array_col_buffers[1] = private_data->buffer_int64.get(); break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: - arrow_array_col_buffers[1] = buffersArrow.float64[col].release(); + private_data->buffer_float64 = std::move(buffersArrow.float64[col]); + arrow_array_col_buffers[1] = private_data->buffer_float64.get(); break; case SQL_DECIMAL: case SQL_NUMERIC: { - arrow_array_col_buffers[1] = buffersArrow.decimal[col].release(); + private_data->buffer_decimal = std::move(buffersArrow.decimal[col]); + arrow_array_col_buffers[1] = private_data->buffer_decimal.get(); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: - arrow_array_col_buffers[1] = buffersArrow.ts_micro[col].release(); + private_data->buffer_ts_micro = std::move(buffersArrow.ts_micro[col]); + arrow_array_col_buffers[1] = private_data->buffer_ts_micro.get(); break; case SQL_SS_TIMESTAMPOFFSET: - arrow_array_col_buffers[1] = 
buffersArrow.ts_micro[col].release(); + private_data->buffer_ts_micro = std::move(buffersArrow.ts_micro[col]); + arrow_array_col_buffers[1] = private_data->buffer_ts_micro.get(); break; case SQL_TYPE_DATE: - arrow_array_col_buffers[1] = buffersArrow.date[col].release(); + private_data->buffer_date = std::move(buffersArrow.date[col]); + arrow_array_col_buffers[1] = private_data->buffer_date.get(); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: - arrow_array_col_buffers[1] = buffersArrow.time_second[col].release(); + private_data->buffer_time_second = std::move(buffersArrow.time_second[col]); + arrow_array_col_buffers[1] = private_data->buffer_time_second.get(); break; case SQL_BIT: - arrow_array_col_buffers[1] = buffersArrow.bit[col].release(); + private_data->buffer_bit = std::move(buffersArrow.bit[col]); + arrow_array_col_buffers[1] = private_data->buffer_bit.get(); break; default: { std::ostringstream errorString; @@ -4938,10 +4997,22 @@ SQLRETURN FetchArrowBatch_wrap( .n_children = 0, .buffers = arrow_array_col_buffers, .children = nullptr, - .release = ArrowArray_release, + .release = [](ArrowArray* array) { + assert(array != nullptr); + assert(array->private_data != nullptr); + assert(array->release != nullptr); + assert(array->children == nullptr); + assert(array->n_children == 0); + delete array->private_data; // Frees all buffer entries + assert(array->buffers != nullptr); + delete[] array->buffers; + array->release = nullptr; + }, + .private_data = private_data, }); - arrow_array_col->buffers[0] = buffersArrow.valid[col].release(); + private_data->buffer_valid = std::move(buffersArrow.valid[col]); + arrow_array_col->buffers[0] = private_data->buffer_valid.get(); arrow_array_batch->children[col] = arrow_array_col; } From e1d08c5529e2992b7c20a218a21ab4c419ac6ac2 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:08:10 +0100 Subject: [PATCH 09/11] Eliminate potential memory leaks on 
allocation failures when transferring ownership to arrow --- mssql_python/pybind/ddbc_bindings.cpp | 640 ++++++++++++-------------- 1 file changed, 289 insertions(+), 351 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index 4a46efd9..b14752e7 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -157,56 +157,31 @@ struct NumericData { } }; -// Struct to hold data buffers and indicators for each column -struct ColumnBuffersArrow { - std::vector> uint8; - std::vector> int16; - std::vector> int32; - std::vector> int64; - std::vector> float64; - std::vector> bit; - std::vector> var; - std::vector> date; - std::vector> ts_micro; - std::vector> time_second; - std::vector> decimal; - - std::vector> valid; - std::vector> var_data; - - ColumnBuffersArrow(SQLSMALLINT numCols) - : - uint8(numCols), - int16(numCols), - int32(numCols), - int64(numCols), - float64(numCols), - bit(numCols), - var(numCols), - date(numCols), - ts_micro(numCols), - time_second(numCols), - decimal(numCols), - - valid(numCols), - var_data(numCols) {} -}; - struct ArrowArrayPrivateData { - std::unique_ptr buffer_uint8; - std::unique_ptr buffer_int16; - std::unique_ptr buffer_int32; - std::unique_ptr buffer_int64; - std::unique_ptr buffer_float64; - std::unique_ptr buffer_bit; - std::unique_ptr buffer_var; - std::unique_ptr buffer_date; - std::unique_ptr buffer_ts_micro; - std::unique_ptr buffer_time_second; - std::unique_ptr<__int128_t[]> buffer_decimal; - - std::unique_ptr buffer_valid; - std::unique_ptr buffer_var_data; + std::unique_ptr valid; + + std::unique_ptr uint8Val; + std::unique_ptr int16Val; + std::unique_ptr int32Val; + std::unique_ptr int64Val; + std::unique_ptr float64Val; + std::unique_ptr bitVal; + std::unique_ptr varVal; + std::unique_ptr dateVal; + std::unique_ptr tsMicroVal; + std::unique_ptr timeSecondVal; + std::unique_ptr<__int128_t[]> decimalVal; + + std::vector varData; + + // 
first buffer will be the valid bitmap + // second buffer will be one of the value buffers above + // third buffer will be the varData buffer for variable length types + std::array buffers; + + // Points to one of the typed *Val buffers above. Since the buffer pointers + // don't change, this can be set once during batch initialization. + void* ptrValueBuffer; }; struct ArrowSchemaPrivateData { @@ -4170,15 +4145,20 @@ SQLRETURN FetchArrowBatch_wrap( std::vector dataTypes(numCols); std::vector columnSizes(numCols); std::vector columnNullable(numCols); - std::vector> columnFormats(numCols); - std::vector> columnNamesCStr(numCols); + std::vector columnVarLen(numCols, false); - ColumnBuffersArrow buffersArrow(numCols); + std::vector> arrowArrayPrivateData(numCols); + std::vector> arrowSchemaPrivateData(numCols); for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowArrayPrivateData[i] = std::make_unique(); + auto& arrowColumnProducer = arrowArrayPrivateData[i]; + arrowSchemaPrivateData[i] = std::make_unique(); + auto colMeta = columnNames[i].cast(); SQLSMALLINT dataType = colMeta["DataType"].cast(); SQLULEN columnSize = colMeta["ColumnSize"].cast(); SQLSMALLINT nullable = colMeta["Nullable"].cast(); + dataTypes[i] = dataType; columnSizes[i] = columnSize; columnNullable[i] = (nullable != SQL_NO_NULLS); @@ -4195,8 +4175,8 @@ SQLRETURN FetchArrowBatch_wrap( std::string columnName = colMeta["ColumnName"].cast(); size_t nameLen = columnName.length() + 1; - columnNamesCStr[i] = std::make_unique(nameLen); - std::memcpy(columnNamesCStr[i].get(), columnName.c_str(), nameLen); + arrowSchemaPrivateData[i]->name = std::make_unique(nameLen); + std::memcpy(arrowSchemaPrivateData[i]->name.get(), columnName.c_str(), nameLen); std::string format = ""; switch(dataType) { @@ -4209,41 +4189,50 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WLONGVARCHAR: case SQL_GUID: format = "u"; - buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); - buffersArrow.var_data[i].resize(arrowBatchSize * 
42); + arrowColumnProducer->varVal = std::make_unique(arrowBatchSize + 1); + arrowColumnProducer->varData.resize(arrowBatchSize * 42); + columnVarLen[i] = true; // start at offset 0 - buffersArrow.var[i][0] = 0; + arrowColumnProducer->varVal[0] = 0; + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->varVal.get(); break; case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: format = "z"; - buffersArrow.var[i] = std::make_unique(arrowBatchSize + 1); - buffersArrow.var_data[i].resize(arrowBatchSize * 42); + arrowColumnProducer->varVal = std::make_unique(arrowBatchSize + 1); + arrowColumnProducer->varData.resize(arrowBatchSize * 42); + columnVarLen[i] = true; // start at offset 0 - buffersArrow.var[i][0] = 0; + arrowColumnProducer->varVal[0] = 0; + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->varVal.get(); break; case SQL_TINYINT: format = "C"; - buffersArrow.uint8[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->uint8Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->uint8Val.get(); break; case SQL_SMALLINT: format = "s"; - buffersArrow.int16[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->int16Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int16Val.get(); break; case SQL_INTEGER: format = "i"; - buffersArrow.int32[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->int32Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int32Val.get(); break; case SQL_BIGINT: format = "l"; - buffersArrow.int64[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->int64Val = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->int64Val.get(); break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: format = "g"; - buffersArrow.float64[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->float64Val = 
std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->float64Val.get(); break; case SQL_DECIMAL: case SQL_NUMERIC: { @@ -4251,36 +4240,42 @@ SQLRETURN FetchArrowBatch_wrap( formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast(); std::string formatStr = formatStream.str(); size_t formatLen = formatStr.length() + 1; - columnFormats[i] = std::make_unique(formatLen); - std::memcpy(columnFormats[i].get(), formatStr.c_str(), formatLen); - format = columnFormats[i].get(); - buffersArrow.decimal[i] = std::make_unique<__int128_t[]>(arrowBatchSize); + arrowSchemaPrivateData[i]->format = std::make_unique(formatLen); + std::memcpy(arrowSchemaPrivateData[i]->format.get(), formatStr.c_str(), formatLen); + format = arrowSchemaPrivateData[i]->format.get(); + arrowColumnProducer->decimalVal = std::make_unique<__int128_t[]>(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->decimalVal.get(); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: format = "tsu:"; - buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->tsMicroVal = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->tsMicroVal.get(); break; case SQL_SS_TIMESTAMPOFFSET: format = "tsu:+00:00"; - buffersArrow.ts_micro[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->tsMicroVal = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->tsMicroVal.get(); break; case SQL_TYPE_DATE: format = "tdD"; - buffersArrow.date[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->dateVal = std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->dateVal.get(); break; case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: format = "tts"; - buffersArrow.time_second[i] = std::make_unique(arrowBatchSize); + arrowColumnProducer->timeSecondVal = 
std::make_unique(arrowBatchSize); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->timeSecondVal.get(); break; case SQL_BIT: format = "b"; - buffersArrow.bit[i] = std::make_unique((arrowBatchSize + 7) / 8); - std::memset(buffersArrow.bit[i].get(), 0, (arrowBatchSize + 7) / 8); + arrowColumnProducer->bitVal = std::make_unique((arrowBatchSize + 7) / 8); + std::memset(arrowColumnProducer->bitVal.get(), 0, (arrowBatchSize + 7) / 8); + arrowColumnProducer->ptrValueBuffer = arrowColumnProducer->bitVal.get(); break; default: std::wstring columnName = colMeta["ColumnName"].cast(); @@ -4292,16 +4287,17 @@ SQLRETURN FetchArrowBatch_wrap( break; } - // Store format string if not already stored (for non-decimal types) - if (!columnFormats[i]) { + // Store format string if not already stored. + // For non-decimal types, format is now a static string. + if (!arrowSchemaPrivateData[i]->format) { size_t formatLen = format.length() + 1; - columnFormats[i] = std::make_unique(formatLen); - std::memcpy(columnFormats[i].get(), format.c_str(), formatLen); + arrowSchemaPrivateData[i]->format = std::make_unique(formatLen); + std::memcpy(arrowSchemaPrivateData[i]->format.get(), format.c_str(), formatLen); } - buffersArrow.valid[i] = std::make_unique((arrowBatchSize + 7) / 8); + arrowColumnProducer->valid = std::make_unique((arrowBatchSize + 7) / 8); // Initialize validity bitmap to all valid - std::memset(buffersArrow.valid[i].get(), 0xFF, (arrowBatchSize + 7) / 8); + std::memset(arrowColumnProducer->valid.get(), 0xFF, (arrowBatchSize + 7) / 8); } if (fetchSize > 1) { @@ -4334,7 +4330,6 @@ SQLRETURN FetchArrowBatch_wrap( SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)(intptr_t)fetchSize, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, &numRowsFetched, 0); - size_t idxRowArrow = 0; // arrowBatchSize % fetchSize == 0 ensures that any followup (even non-arrow) fetches // start with a fresh batch @@ -4355,9 +4350,10 @@ SQLRETURN FetchArrowBatch_wrap( // 
It'll be populated by SQLFetch assert(numRowsFetched + idxRowArrow <= static_cast(arrowBatchSize)); for (SQLULEN idxRowSql = 0; idxRowSql < numRowsFetched; idxRowSql++) { - for (SQLUSMALLINT col = 1; col <= numCols; col++) { - auto dataType = dataTypes[col - 1]; - auto columnSize = columnSizes[col - 1]; + for (SQLUSMALLINT idxCol = 0; idxCol < numCols; idxCol++) { + auto& arrowColumnProducer = arrowArrayPrivateData[idxCol]; + auto dataType = dataTypes[idxCol]; + auto columnSize = columnSizes[idxCol]; if (hasLobColumns) { assert(idxRowSql == 0 && "GetData only works one row at a time"); @@ -4368,10 +4364,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_LONGVARBINARY: { GetDataVar( hStmt, - col, + idxCol + 1, SQL_C_BINARY, - buffers.charBuffers[col - 1], - buffers.indicators[col - 1].data() + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data() ); break; } @@ -4380,10 +4376,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_LONGVARCHAR: { GetDataVar( hStmt, - col, + idxCol + 1, SQL_C_CHAR, - buffers.charBuffers[col - 1], - buffers.indicators[col - 1].data() + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data() ); break; } @@ -4393,152 +4389,152 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WLONGVARCHAR: { GetDataVar( hStmt, - col, + idxCol + 1, SQL_C_WCHAR, - buffers.wcharBuffers[col - 1], - buffers.indicators[col - 1].data() + buffers.wcharBuffers[idxCol], + buffers.indicators[idxCol].data() ); break; } case SQL_INTEGER: { - buffers.intBuffers[col - 1].resize(1); + buffers.intBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_SLONG, - buffers.intBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_SLONG, + buffers.intBuffers[idxCol].data(), sizeof(SQLINTEGER), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_SMALLINT: { - buffers.smallIntBuffers[col - 1].resize(1); + buffers.smallIntBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_SSHORT, - buffers.smallIntBuffers[col - 1].data(), 
+ hStmt, idxCol + 1, SQL_C_SSHORT, + buffers.smallIntBuffers[idxCol].data(), sizeof(SQLSMALLINT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_TINYINT: { - buffers.charBuffers[col - 1].resize(1); + buffers.charBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_TINYINT, - buffers.charBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_TINYINT, + buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_BIT: { - buffers.charBuffers[col - 1].resize(1); + buffers.charBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_BIT, - buffers.charBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_BIT, + buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_REAL: { - buffers.realBuffers[col - 1].resize(1); + buffers.realBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_FLOAT, - buffers.realBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_FLOAT, + buffers.realBuffers[idxCol].data(), sizeof(SQLREAL), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_DECIMAL: case SQL_NUMERIC: { - buffers.charBuffers[col - 1].resize(MAX_DIGITS_IN_NUMERIC); + buffers.charBuffers[idxCol].resize(MAX_DIGITS_IN_NUMERIC); SQLGetData_ptr( - hStmt, col, SQL_C_CHAR, - buffers.charBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_CHAR, + buffers.charBuffers[idxCol].data(), MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_DOUBLE: case SQL_FLOAT: { - buffers.doubleBuffers[col - 1].resize(1); + buffers.doubleBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_DOUBLE, - buffers.doubleBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_DOUBLE, + buffers.doubleBuffers[idxCol].data(), sizeof(SQLDOUBLE), - 
buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { - buffers.timestampBuffers[col - 1].resize(1); + buffers.timestampBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_TYPE_TIMESTAMP, - buffers.timestampBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[idxCol].data(), sizeof(SQL_TIMESTAMP_STRUCT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_BIGINT: { - buffers.bigIntBuffers[col - 1].resize(1); + buffers.bigIntBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_SBIGINT, - buffers.bigIntBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_SBIGINT, + buffers.bigIntBuffers[idxCol].data(), sizeof(SQLBIGINT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_TYPE_DATE: { - buffers.dateBuffers[col - 1].resize(1); + buffers.dateBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_TYPE_DATE, - buffers.dateBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_TYPE_DATE, + buffers.dateBuffers[idxCol].data(), sizeof(SQL_DATE_STRUCT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: { - buffers.timeBuffers[col - 1].resize(1); + buffers.timeBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_TYPE_TIME, - buffers.timeBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_TYPE_TIME, + buffers.timeBuffers[idxCol].data(), sizeof(SQL_TIME_STRUCT), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_GUID: { - buffers.guidBuffers[col - 1].resize(1); + buffers.guidBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_GUID, - buffers.guidBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_GUID, + buffers.guidBuffers[idxCol].data(), sizeof(SQLGUID), - buffers.indicators[col - 
1].data() + buffers.indicators[idxCol].data() ); break; } case SQL_SS_TIMESTAMPOFFSET: { - buffers.datetimeoffsetBuffers[col - 1].resize(1); + buffers.datetimeoffsetBuffers[idxCol].resize(1); SQLGetData_ptr( - hStmt, col, SQL_C_SS_TIMESTAMPOFFSET, - buffers.datetimeoffsetBuffers[col - 1].data(), + hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[idxCol].data(), sizeof(DateTimeOffset), - buffers.indicators[col - 1].data() + buffers.indicators[idxCol].data() ); break; } default: { std::ostringstream errorString; - errorString << "Unsupported data type for column ID - " << col + errorString << "Unsupported data type for column ID - " << (idxCol + 1) << ", Type - " << dataType; LOG("SQLGetData: %s", errorString.str().c_str()); ThrowStdException(errorString.str()); @@ -4547,13 +4543,13 @@ SQLRETURN FetchArrowBatch_wrap( } } - SQLLEN dataLen = buffers.indicators[col - 1][idxRowSql]; + SQLLEN dataLen = buffers.indicators[idxCol][idxRowSql]; if (dataLen == SQL_NULL_DATA) { // Mark as null in validity bitmap size_t bytePos = idxRowArrow / 8; size_t bitPos = idxRowArrow % 8; - buffersArrow.valid[col - 1][bytePos] &= ~(1 << bitPos); + arrowColumnProducer->valid[bytePos] &= ~(1 << bitPos); // Value buffer for variable length data types needs to be set appropriately // as it will be used by the next non null value @@ -4570,7 +4566,7 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: - buffersArrow.var[col - 1][idxRowArrow + 1] = buffersArrow.var[col - 1][idxRowArrow]; + arrowColumnProducer->varVal[idxRowArrow + 1] = arrowColumnProducer->varVal[idxRowArrow]; break; default: break; @@ -4578,7 +4574,7 @@ SQLRETURN FetchArrowBatch_wrap( continue; } else if (dataLen < 0) { // Negative value is unexpected, log column index, SQL type & raise exception - LOG("Unexpected negative data length. Column ID - {}, SQL Type - {}, Data Length - {}", col, dataType, dataLen); + LOG("Unexpected negative data length. 
Column ID - {}, SQL Type - {}, Data Length - {}", idxCol + 1, dataType, dataLen); ThrowStdException("Unexpected negative data length."); } @@ -4587,28 +4583,28 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_VARBINARY: case SQL_LONGVARBINARY: { uint64_t fetchBufferSize = columnSize /* bytes are not null terminated */; - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; while (target_vec->size() < start + dataLen) { target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); - buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; break; } case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { uint64_t fetchBufferSize = columnSize + 1 /* null-termination */; - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; while (target_vec->size() < start + dataLen) { target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[col - 1][idxRowSql * fetchBufferSize], dataLen); - buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLen; + std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; break; } case SQL_SS_XML: @@ -4617,9 +4613,9 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WLONGVARCHAR: { assert(dataLen % sizeof(SQLWCHAR) == 0); auto dataLenW = dataLen / sizeof(SQLWCHAR); - auto wcharSource = &buffers.wcharBuffers[col - 
1][idxRowSql * (columnSize + 1)]; - auto start = buffersArrow.var[col - 1][idxRowArrow]; - auto target_vec = &buffersArrow.var_data[col - 1]; + auto wcharSource = &buffers.wcharBuffers[idxCol][idxRowSql * (columnSize + 1)]; + auto start = arrowColumnProducer->varVal[idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; #if defined(_WIN32) // Convert wide string int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, NULL, 0, NULL, NULL); @@ -4627,7 +4623,7 @@ SQLRETURN FetchArrowBatch_wrap( target_vec->resize(target_vec->size() * 2); } WideCharToMultiByte(CP_UTF8, 0, wcharSource, dataLenW, &(*target_vec)[start], dataLenConverted, NULL, NULL); - buffersArrow.var[col - 1][idxRowArrow + 1] = start + dataLenConverted; + arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLenConverted; #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 std::string utf8str = WideToUTF8(SQLWCHARToWString(wcharSource, dataLenW)); @@ -4635,15 +4631,15 @@ SQLRETURN FetchArrowBatch_wrap( target_vec->resize(target_vec->size() * 2); } std::memcpy(&(*target_vec)[start], utf8str.data(), utf8str.size()); - buffersArrow.var[col - 1][idxRowArrow + 1] = start + utf8str.size(); + arrowColumnProducer->varVal[idxRowArrow + 1] = start + utf8str.size(); #endif break; } case SQL_GUID: { // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") // Each GUID is exactly 36 bytes in UTF-8 - auto target_vec = &buffersArrow.var_data[col - 1]; - auto start = buffersArrow.var[col - 1][idxRowArrow]; + auto target_vec = &arrowColumnProducer->varData; + auto start = arrowColumnProducer->varVal[idxRowArrow]; // Ensure buffer has space for the GUID string + null terminator while (target_vec->size() < start + 37) { @@ -4651,7 +4647,7 @@ SQLRETURN FetchArrowBatch_wrap( } // Get the GUID from the buffer - const SQLGUID& guidValue = buffers.guidBuffers[col - 1][idxRowSql]; + const SQLGUID& guidValue = 
buffers.guidBuffers[idxCol][idxRowSql]; // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx snprintf(reinterpret_cast(&target_vec->data()[start]), 37, @@ -4665,25 +4661,25 @@ SQLRETURN FetchArrowBatch_wrap( guidValue.Data4[6], guidValue.Data4[7]); // Update offset for next row, ignoring null terminator - buffersArrow.var[col - 1][idxRowArrow + 1] = start + 36; + arrowColumnProducer->varVal[idxRowArrow + 1] = start + 36; break; } case SQL_TINYINT: - buffersArrow.uint8[col - 1][idxRowArrow] = buffers.charBuffers[col - 1][idxRowSql]; + arrowColumnProducer->uint8Val[idxRowArrow] = buffers.charBuffers[idxCol][idxRowSql]; break; case SQL_SMALLINT: - buffersArrow.int16[col - 1][idxRowArrow] = buffers.smallIntBuffers[col - 1][idxRowSql]; + arrowColumnProducer->int16Val[idxRowArrow] = buffers.smallIntBuffers[idxCol][idxRowSql]; break; case SQL_INTEGER: - buffersArrow.int32[col - 1][idxRowArrow] = buffers.intBuffers[col - 1][idxRowSql]; + arrowColumnProducer->int32Val[idxRowArrow] = buffers.intBuffers[idxCol][idxRowSql]; break; case SQL_BIGINT: - buffersArrow.int64[col - 1][idxRowArrow] = buffers.bigIntBuffers[col - 1][idxRowSql]; + arrowColumnProducer->int64Val[idxRowArrow] = buffers.bigIntBuffers[idxCol][idxRowSql]; break; case SQL_REAL: case SQL_FLOAT: case SQL_DOUBLE: - buffersArrow.float64[col - 1][idxRowArrow] = buffers.doubleBuffers[col - 1][idxRowSql]; + arrowColumnProducer->float64Val[idxRowArrow] = buffers.doubleBuffers[idxCol][idxRowSql]; break; case SQL_DECIMAL: case SQL_NUMERIC: { @@ -4692,26 +4688,26 @@ SQLRETURN FetchArrowBatch_wrap( auto start = idxRowSql * MAX_DIGITS_IN_NUMERIC; int sign = 1; for (SQLULEN idx = start; idx < start + dataLen; idx++) { - char digitChar = buffers.charBuffers[col - 1][idx]; + char digitChar = buffers.charBuffers[idxCol][idx]; if (digitChar == '-') { sign = -1; } else if (digitChar >= '0' && digitChar <= '9') { decimalValue = decimalValue * 10 + (digitChar - '0'); } } - buffersArrow.decimal[col - 1][idxRowArrow] 
= decimalValue * sign; + arrowColumnProducer->decimalVal[idxRowArrow] = decimalValue * sign; break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { - SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[col - 1][idxRowSql]; + SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[idxCol][idxRowSql]; int64_t days = dateAsDayCount( sql_value.year, sql_value.month, sql_value.day ); - buffersArrow.ts_micro[col - 1][idxRowArrow] = + arrowColumnProducer->tsMicroVal[idxRowArrow] = days * 86400 * 1000000 + static_cast(sql_value.hour) * 3600 * 1000000 + static_cast(sql_value.minute) * 60 * 1000000 + @@ -4720,13 +4716,13 @@ SQLRETURN FetchArrowBatch_wrap( break; } case SQL_SS_TIMESTAMPOFFSET: { - DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[col - 1][idxRowSql]; + DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[idxCol][idxRowSql]; int64_t days = dateAsDayCount( sql_value.year, sql_value.month, sql_value.day ); - buffersArrow.ts_micro[col - 1][idxRowArrow] = + arrowColumnProducer->tsMicroVal[idxRowArrow] = days * 86400 * 1000000 + (static_cast(sql_value.hour) - static_cast(sql_value.timezone_hour)) * 3600 * 1000000 + (static_cast(sql_value.minute) - static_cast(sql_value.timezone_minute)) * 60 * 1000000 + @@ -4735,10 +4731,10 @@ SQLRETURN FetchArrowBatch_wrap( break; } case SQL_TYPE_DATE: - buffersArrow.date[col - 1][idxRowArrow] = dateAsDayCount( - buffers.dateBuffers[col - 1][idxRowSql].year, - buffers.dateBuffers[col - 1][idxRowSql].month, - buffers.dateBuffers[col - 1][idxRowSql].day + arrowColumnProducer->dateVal[idxRowArrow] = dateAsDayCount( + buffers.dateBuffers[idxCol][idxRowSql].year, + buffers.dateBuffers[idxCol][idxRowSql].month, + buffers.dateBuffers[idxCol][idxRowSql].day ); break; case SQL_TIME: @@ -4746,8 +4742,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_SS_TIME2: { // NOTE: SQL_SS_TIME2 supports fractional seconds, but SQL_C_TYPE_TIME does not. 
// To fully support SQL_SS_TIME2, the corresponding c-type should be used. - const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[col - 1][idxRowSql]; - buffersArrow.time_second[col - 1][idxRowArrow] = + const SQL_TIME_STRUCT& timeValue = buffers.timeBuffers[idxCol][idxRowSql]; + arrowColumnProducer->timeSecondVal[idxRowArrow] = static_cast(timeValue.hour) * 3600 + static_cast(timeValue.minute) * 60 + static_cast(timeValue.second); @@ -4756,7 +4752,7 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BIT: { // SQL_BIT is stored as a single bit in Arrow's bitmap format // Get the boolean value from the buffer - bool bitValue = buffers.charBuffers[col - 1][idxRowSql] != 0; + bool bitValue = buffers.charBuffers[idxCol][idxRowSql] != 0; // Set the bit in the Arrow bitmap size_t byteIndex = idxRowArrow / 8; @@ -4764,16 +4760,16 @@ SQLRETURN FetchArrowBatch_wrap( if (bitValue) { // Set bit to 1 - buffersArrow.bit[col - 1][byteIndex] |= (1 << bitIndex); + arrowColumnProducer->bitVal[byteIndex] |= (1 << bitIndex); } else { // Clear bit to 0 - buffersArrow.bit[col - 1][byteIndex] &= ~(1 << bitIndex); + arrowColumnProducer->bitVal[byteIndex] &= ~(1 << bitIndex); } break; } default: { std::ostringstream errorString; - errorString << "Unsupported data type for column ID - " << col + errorString << "Unsupported data type for column ID - " << (idxCol + 1) << ", Type - " << dataType; LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); @@ -4789,19 +4785,23 @@ SQLRETURN FetchArrowBatch_wrap( SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROW_ARRAY_SIZE, (SQLPOINTER)1, 0); SQLSetStmtAttr_ptr(hStmt, SQL_ATTR_ROWS_FETCHED_PTR, NULL, 0); - // Transfer ownership of buffers to Arrow structures - // Exceptions beyond this point would cause memory leaks - - auto batch_children = new ArrowSchema*[numCols]; + // Transfer ownership of buffers to batch ArrowSchema + // First, allocate memory for the necessary structures + auto arrowSchemaBatch = std::make_unique(); + auto 
arrowSchemaBatchChildren = std::make_unique(numCols); + auto arrowSchemaBatchChildPointers = std::make_unique[]>(numCols); for (SQLSMALLINT i = 0; i < numCols; i++) { - auto col_private_data = new ArrowSchemaPrivateData(); - col_private_data->format = std::move(columnFormats[i]); - col_private_data->name = std::move(columnNamesCStr[i]); + arrowSchemaBatchChildPointers[i] = std::make_unique(); + } - auto arrow_schema = new ArrowSchema({ - .format = col_private_data->format.get(), - .name = col_private_data->name.get(), + // Second, transfer ownership to arrowSchemaBatch + // No unhandled exceptions until the pycapsule owns the arrowSchemaBatch to avoid memory leaks + + for (SQLSMALLINT i = 0; i < numCols; i++) { + *arrowSchemaBatchChildPointers[i] = { + .format = arrowSchemaPrivateData[i]->format.get(), + .name = arrowSchemaPrivateData[i]->name.get(), .metadata = nullptr, .flags = static_cast(columnNullable[i] ? ARROW_FLAG_NULLABLE : 0), .n_children = 0, @@ -4815,18 +4815,21 @@ SQLRETURN FetchArrowBatch_wrap( delete schema->private_data; // Frees format and name schema->release = nullptr; }, - .private_data = col_private_data, - }); - batch_children[i] = arrow_schema; + .private_data = arrowSchemaPrivateData[i].release(), + }; + } + + for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowSchemaBatchChildren[i] = arrowSchemaBatchChildPointers[i].release(); } - auto arrow_schema_batch = new ArrowSchema({ + *arrowSchemaBatch = ArrowSchema{ .format = "+s", .name = "", .metadata = nullptr, .flags = 0, .n_children = numCols, - .children = batch_children, + .children = arrowSchemaBatchChildren.release(), .dictionary = nullptr, .release = [](ArrowSchema* schema) { // format and name are string literals, no need to free @@ -4847,28 +4850,79 @@ SQLRETURN FetchArrowBatch_wrap( schema->release = nullptr; }, .private_data = nullptr, - }); + }; - auto caps = py::capsule((void*)arrow_schema_batch, "arrow_schema", [](void* ptr) { - auto arrow_schema = static_cast(ptr); - if 
(arrow_schema->release) { - arrow_schema->release(arrow_schema); - } - delete arrow_schema; - }); - capsules.append(caps); + // Finally, transfer ownership of arrowSchemaBatch and its pointer to pycapsule + py::capsule arrowSchemaBatchCapsule; + try { + arrowSchemaBatchCapsule = py::capsule(arrowSchemaBatch.get(), "arrow_schema", [](void* ptr) { + auto arrowSchema = static_cast(ptr); + if (arrowSchema->release) { + arrowSchema->release(arrowSchema); + } + delete arrowSchema; + }); + } catch (...) { + arrowSchemaBatch->release(arrowSchemaBatch.get()); + throw; + } + arrowSchemaBatch.release(); + capsules.append(arrowSchemaBatchCapsule); + + // Transfer ownership of buffers to batch ArrowArray + // First, allocate memory for the necessary structures + auto arrowArrayBatch = std::make_unique(); + + auto arrowArrayBatchBuffers = std::make_unique(1); + arrowArrayBatchBuffers[0] = nullptr; - auto arrow_array_batch_buffers = new const void* [3]; - memset(arrow_array_batch_buffers, 0, sizeof(const void*) * 3); - // Necessary dummy buffer, pyarrow will error without it - arrow_array_batch_buffers[1] = new uint8_t[1]{0}; - auto arrow_array_batch = new ArrowArray({ + auto arrowArrayBatchChildren = std::make_unique(numCols); + auto arrowArrayBatchChildPointers = std::make_unique[]>(numCols); + for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowArrayBatchChildPointers[i] = std::make_unique(); + } + + // Second, transfer ownership to arrowArrayBatch + // No unhandled exceptions until the pycapsule owns the arrowArrayBatch to avoid memory leaks + + for (SQLUSMALLINT col = 0; col < numCols; col++) { + auto dataType = dataTypes[col]; + arrowArrayPrivateData[col]->buffers[0] = arrowArrayPrivateData[col]->valid.get(); + arrowArrayPrivateData[col]->buffers[1] = arrowArrayPrivateData[col]->ptrValueBuffer; + arrowArrayPrivateData[col]->buffers[2] = arrowArrayPrivateData[col]->varData.data(); + + *arrowArrayBatchChildPointers[col] = { + .length = static_cast(idxRowArrow), + .null_count = 
0, + .offset = 0, + .n_buffers = columnVarLen[col] ? 3 : 2, + .n_children = 0, + .buffers = (const void**)arrowArrayPrivateData[col]->buffers.data(), + .children = nullptr, + .release = [](ArrowArray* array) { + assert(array != nullptr); + assert(array->private_data != nullptr); + assert(array->release != nullptr); + assert(array->children == nullptr); + assert(array->n_children == 0); + delete array->private_data; // Frees all buffer entries + assert(array->buffers != nullptr); + array->release = nullptr; + }, + .private_data = arrowArrayPrivateData[col].release(), + }; + } + + for (SQLSMALLINT i = 0; i < numCols; i++) { + arrowArrayBatchChildren[i] = arrowArrayBatchChildPointers[i].release(); + } + + *arrowArrayBatch = ArrowArray{ .length = static_cast(idxRowArrow), - // only the non null dummy buffer counts .n_buffers = 1, .n_children = numCols, - .buffers = arrow_array_batch_buffers, - .children = new ArrowArray* [numCols], + .buffers = arrowArrayBatchBuffers.release(), + .children = arrowArrayBatchChildren.release(), .release = [](ArrowArray* array) { assert(array != nullptr); assert(array->private_data == nullptr); @@ -4887,147 +4941,31 @@ SQLRETURN FetchArrowBatch_wrap( assert(array->buffers != nullptr); assert(array->n_buffers == 1); assert(array->buffers[0] == nullptr); - assert(array->buffers[1] != nullptr); - assert(array->buffers[2] == nullptr); - // Delete dummy buffer - delete[] const_cast(static_cast(array->buffers[1])); - delete[] array->buffers; array->release = nullptr; }, - }); + }; - for (SQLUSMALLINT col = 0; col < numCols; col++) { - auto dataType = dataTypes[col]; - auto arrow_array_col_buffers = new const void* [3]; - memset(arrow_array_col_buffers, 0, sizeof(const void*) * 3); - auto private_data = new ArrowArrayPrivateData(); - // Allocate new memory and copy the data - switch (dataType) { - case SQL_CHAR: - case SQL_VARCHAR: - case SQL_LONGVARCHAR: - case SQL_SS_XML: - case SQL_WCHAR: - case SQL_WVARCHAR: - case SQL_WLONGVARCHAR: - case 
SQL_GUID: - case SQL_BINARY: - case SQL_VARBINARY: - case SQL_LONGVARBINARY: { - assert(buffersArrow.var[col][0] == 0); - // length of string at index i is the difference between values at i and i+1 - // so total length is value at index idxRowArrow - auto data_buf_len_total = buffersArrow.var[col][idxRowArrow]; - auto dataBuffer = std::make_unique(data_buf_len_total); - std::memcpy(dataBuffer.get(), buffersArrow.var_data[col].data(), data_buf_len_total); - private_data->buffer_var_data = std::move(dataBuffer); - arrow_array_col_buffers[2] = private_data->buffer_var_data.get(); - private_data->buffer_var = std::move(buffersArrow.var[col]); - arrow_array_col_buffers[1] = private_data->buffer_var.get(); - } - break; - case SQL_TINYINT: - private_data->buffer_uint8 = std::move(buffersArrow.uint8[col]); - arrow_array_col_buffers[1] = private_data->buffer_uint8.get(); - break; - case SQL_SMALLINT: - private_data->buffer_int16 = std::move(buffersArrow.int16[col]); - arrow_array_col_buffers[1] = private_data->buffer_int16.get(); - break; - case SQL_INTEGER: - private_data->buffer_int32 = std::move(buffersArrow.int32[col]); - arrow_array_col_buffers[1] = private_data->buffer_int32.get(); - break; - case SQL_BIGINT: - private_data->buffer_int64 = std::move(buffersArrow.int64[col]); - arrow_array_col_buffers[1] = private_data->buffer_int64.get(); - break; - case SQL_REAL: - case SQL_FLOAT: - case SQL_DOUBLE: - private_data->buffer_float64 = std::move(buffersArrow.float64[col]); - arrow_array_col_buffers[1] = private_data->buffer_float64.get(); - break; - case SQL_DECIMAL: - case SQL_NUMERIC: { - private_data->buffer_decimal = std::move(buffersArrow.decimal[col]); - arrow_array_col_buffers[1] = private_data->buffer_decimal.get(); - break; - } - case SQL_TIMESTAMP: - case SQL_TYPE_TIMESTAMP: - case SQL_DATETIME: - private_data->buffer_ts_micro = std::move(buffersArrow.ts_micro[col]); - arrow_array_col_buffers[1] = private_data->buffer_ts_micro.get(); - break; - case 
SQL_SS_TIMESTAMPOFFSET: - private_data->buffer_ts_micro = std::move(buffersArrow.ts_micro[col]); - arrow_array_col_buffers[1] = private_data->buffer_ts_micro.get(); - break; - case SQL_TYPE_DATE: - private_data->buffer_date = std::move(buffersArrow.date[col]); - arrow_array_col_buffers[1] = private_data->buffer_date.get(); - break; - case SQL_TIME: - case SQL_TYPE_TIME: - case SQL_SS_TIME2: - private_data->buffer_time_second = std::move(buffersArrow.time_second[col]); - arrow_array_col_buffers[1] = private_data->buffer_time_second.get(); - break; - case SQL_BIT: - private_data->buffer_bit = std::move(buffersArrow.bit[col]); - arrow_array_col_buffers[1] = private_data->buffer_bit.get(); - break; - default: { - std::ostringstream errorString; - errorString << "Unsupported data type for column ID - " << (col + 1) - << ", Type - " << dataType; - LOG(errorString.str().c_str()); - ThrowStdException(errorString.str()); - break; + // Finally, transfer ownership of arrowArrayBatch and its pointer to pycapsule + py::capsule arrowArrayBatchCapsule; + try { + arrowArrayBatchCapsule = py::capsule(arrowArrayBatch.get(), "arrow_array", [](void* ptr) { + auto arrowArray = static_cast(ptr); + if (arrowArray->release) { + arrowArray->release(arrowArray); } - } - - auto arrow_array_col = new ArrowArray({ - .length = static_cast(idxRowArrow), - .null_count = 0, - .offset = 0, - .n_buffers = arrow_array_col_buffers[2] ? 
3 : 2, - .n_children = 0, - .buffers = arrow_array_col_buffers, - .children = nullptr, - .release = [](ArrowArray* array) { - assert(array != nullptr); - assert(array->private_data != nullptr); - assert(array->release != nullptr); - assert(array->children == nullptr); - assert(array->n_children == 0); - delete array->private_data; // Frees all buffer entries - assert(array->buffers != nullptr); - delete[] array->buffers; - array->release = nullptr; - }, - .private_data = private_data, + delete arrowArray; }); - - private_data->buffer_valid = std::move(buffersArrow.valid[col]); - arrow_array_col->buffers[0] = private_data->buffer_valid.get(); - arrow_array_batch->children[col] = arrow_array_col; + } catch (...) { + arrowArrayBatch->release(arrowArrayBatch.get()); + throw; } - - capsules.append(py::capsule((void*)arrow_array_batch, "arrow_array", [](void* ptr) { - auto arrow_array = static_cast(ptr); - if (arrow_array->release) { - arrow_array->release(arrow_array); - } - delete arrow_array; - })); + arrowArrayBatch.release(); + capsules.append(arrowArrayBatchCapsule); return ret; } - // FetchAll_wrap - Fetches all rows of data from the result set. 
// // @param StatementHandle: Handle to the statement from which data is to be From f79f02943affaed62f8b6d99a2f209a84e351e6a Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:12:34 +0100 Subject: [PATCH 10/11] Check returncode for SQLGetData --- mssql_python/pybind/ddbc_bindings.cpp | 96 ++++++++++++++++++++++----- 1 file changed, 80 insertions(+), 16 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index b14752e7..cb3df321 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4362,174 +4362,238 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: { - GetDataVar( + ret = GetDataVar( hStmt, idxCol + 1, SQL_C_BINARY, buffers.charBuffers[idxCol], buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching LOB for column %d", idxCol + 1); + return ret; + } break; } case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - GetDataVar( + ret = GetDataVar( hStmt, idxCol + 1, SQL_C_CHAR, buffers.charBuffers[idxCol], buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching LOB for column %d", idxCol + 1); + return ret; + } break; } case SQL_SS_XML: case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { - GetDataVar( + ret = GetDataVar( hStmt, idxCol + 1, SQL_C_WCHAR, buffers.wcharBuffers[idxCol], buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching binary data for column %d", idxCol + 1); + return ret; + } break; } case SQL_INTEGER: { buffers.intBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_SLONG, buffers.intBuffers[idxCol].data(), sizeof(SQLINTEGER), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SLONG data for column %d", idxCol + 1); + return ret; + } break; } case SQL_SMALLINT: { 
buffers.smallIntBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_SSHORT, buffers.smallIntBuffers[idxCol].data(), sizeof(SQLSMALLINT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SSHORT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_TINYINT: { buffers.charBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_TINYINT, buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TINYINT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_BIT: { buffers.charBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_BIT, buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching BIT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_REAL: { buffers.realBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_FLOAT, buffers.realBuffers[idxCol].data(), sizeof(SQLREAL), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching FLOAT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_DECIMAL: case SQL_NUMERIC: { buffers.charBuffers[idxCol].resize(MAX_DIGITS_IN_NUMERIC); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_CHAR, buffers.charBuffers[idxCol].data(), MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching CHAR data for column %d", idxCol + 1); + return ret; + } break; } case SQL_DOUBLE: case SQL_FLOAT: { buffers.doubleBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_DOUBLE, buffers.doubleBuffers[idxCol].data(), sizeof(SQLDOUBLE), buffers.indicators[idxCol].data() ); + 
if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching DOUBLE data for column %d", idxCol + 1); + return ret; + } break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { buffers.timestampBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, buffers.timestampBuffers[idxCol].data(), sizeof(SQL_TIMESTAMP_STRUCT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_TIMESTAMP data for column %d", idxCol + 1); + return ret; + } break; } case SQL_BIGINT: { buffers.bigIntBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_SBIGINT, buffers.bigIntBuffers[idxCol].data(), sizeof(SQLBIGINT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SBIGINT data for column %d", idxCol + 1); + return ret; + } break; } case SQL_TYPE_DATE: { buffers.dateBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_TYPE_DATE, buffers.dateBuffers[idxCol].data(), sizeof(SQL_DATE_STRUCT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_DATE data for column %d", idxCol + 1); + return ret; + } break; } case SQL_TIME: case SQL_TYPE_TIME: case SQL_SS_TIME2: { buffers.timeBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_TYPE_TIME, buffers.timeBuffers[idxCol].data(), sizeof(SQL_TIME_STRUCT), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching TYPE_TIME data for column %d", idxCol + 1); + return ret; + } break; } case SQL_GUID: { buffers.guidBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_GUID, buffers.guidBuffers[idxCol].data(), sizeof(SQLGUID), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching GUID data for column %d", idxCol + 1); + return ret; + } break; } case 
SQL_SS_TIMESTAMPOFFSET: { buffers.datetimeoffsetBuffers[idxCol].resize(1); - SQLGetData_ptr( + ret = SQLGetData_ptr( hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, buffers.datetimeoffsetBuffers[idxCol].data(), sizeof(DateTimeOffset), buffers.indicators[idxCol].data() ); + if (!SQL_SUCCEEDED(ret)) { + LOG("Error fetching SS_TIMESTAMPOFFSET data for column %d", idxCol + 1); + return ret; + } break; } default: { From 5834c731652412b787047758957031124990c454 Mon Sep 17 00:00:00 2001 From: ffelixg <142172984+ffelixg@users.noreply.github.com> Date: Sat, 6 Dec 2025 20:31:21 +0100 Subject: [PATCH 11/11] Fix null count array attribute --- mssql_python/pybind/ddbc_bindings.cpp | 5 ++++- tests/test_004_cursor.py | 26 +++++++++++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index cb3df321..810c5b9c 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -4146,6 +4146,7 @@ SQLRETURN FetchArrowBatch_wrap( std::vector columnSizes(numCols); std::vector columnNullable(numCols); std::vector columnVarLen(numCols, false); + std::vector nullCounts(numCols, 0); std::vector> arrowArrayPrivateData(numCols); std::vector> arrowSchemaPrivateData(numCols); @@ -4635,6 +4636,8 @@ SQLRETURN FetchArrowBatch_wrap( default: break; } + + nullCounts[idxCol] += 1; continue; } else if (dataLen < 0) { // Negative value is unexpected, log column index, SQL type & raise exception @@ -4957,7 +4960,7 @@ SQLRETURN FetchArrowBatch_wrap( *arrowArrayBatchChildPointers[col] = { .length = static_cast(idxRowArrow), - .null_count = 0, + .null_count = nullCounts[col], .offset = 0, .n_buffers = columnVarLen[col] ? 
3 : 2, .n_children = 0, diff --git a/tests/test_004_cursor.py b/tests/test_004_cursor.py index 47981817..c82afaa4 100644 --- a/tests/test_004_cursor.py +++ b/tests/test_004_cursor.py @@ -20,8 +20,11 @@ try: import pyarrow as pa + import pyarrow.parquet as pq + import io except ImportError: pa = None + pq = None # Setup test table @@ -14884,12 +14887,17 @@ def _test_arrow_test_data(cursor: mssql_python.Cursor, arrow_test_data, fetch_le full_query = "\nunion all\n".join(selects) ret = cursor.execute(full_query).arrow_batch(fetch_length) for i_col, col in enumerate(ret): + expected_data = arrow_test_data[i_col][2][:fetch_length] for i_row, (v_expected, v_actual) in enumerate( - zip(arrow_test_data[i_col][2][:fetch_length], col.to_pylist(), strict=True) + zip(expected_data, col.to_pylist(), strict=True) ): assert ( v_expected == v_actual ), f"Mismatch in column {i_col}, row {i_row}: expected {v_expected}, got {v_actual}" + # check that null counts match + expected_null_count = sum(1 for v in expected_data if v is None) + actual_null_count = col.null_count + assert expected_null_count == actual_null_count, (expected_null_count, actual_null_count) for i_col, (pa_type, sql_type, values) in enumerate(arrow_test_data): field = ret.schema.field(i_col) assert ( @@ -14899,6 +14907,22 @@ def _test_arrow_test_data(cursor: mssql_python.Cursor, arrow_test_data, fetch_le pa_type ), f"Column {i_col} type mismatch: expected {pa_type}, got {field.type}" + # Validate that Parquet serialization/deserialization does not detect any issues + tbl = pa.Table.from_batches([ret]) + # for some reason parquet converts seconds to milliseconds in time32 + for i_col, col in enumerate(tbl.columns): + if col.type == pa.time32("s"): + tbl = tbl.set_column( + i_col, + tbl.schema.field(i_col).name, + col.cast(pa.time32("ms")), + ) + buffer = io.BytesIO() + pq.write_table(tbl, buffer) + buffer.seek(0) + read_table = pq.read_table(buffer) + assert read_table.equals(tbl) + @pytest.mark.skipif(pa is None, 
reason="pyarrow is not installed") def test_arrow_lob_wide(cursor: mssql_python.Cursor):