Skip to content

Commit 6a5c592

Browse files
committed
Merged PR 5322: Refactor numeric data handling in DDBC bindings to use SQL_CHAR for better co...
Refactor numeric data handling in DDBC bindings to use SQL_CHAR for better compatibility and precision ---- #### AI description (iteration 1) #### PR Classification Code refactor to improve numeric data handling in DDBC bindings. #### PR Summary This pull request refactors the handling of numeric data in DDBC bindings to use `SQL_CHAR` for better compatibility and simplicity. - `mssql_python/pybind/ddbc_bindings.cpp`: Replaced `SQL_NUMERIC_STRUCT` with `SQLCHAR` for numeric data handling, including changes to `SQLGetData_wrap`, `FetchBatchData`, and `SQLBindColums` functions. - Defined `SQL_NUMERIC_SIZE` constant and removed `numericBuffers` from `ColumnBuffers` struct. <!-- GitOpsUserAgent=GitOps.Apps.Server.pullrequestcopilot --> Related work items: #33915
1 parent e2bb27a commit 6a5c592

File tree

1 file changed

+37
-26
lines changed

1 file changed

+37
-26
lines changed

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 37 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ using namespace pybind11::literals;
3131

3232
// This constant is not exposed via sql.h, hence define it here
3333
#define SQL_SS_TIME2 (-154)
34+
#define SQL_NUMERIC_SIZE 64
3435

3536
// Logs data to stdout only in debug builds
3637
// TODO: Handle both UTF-8 and UTF-16 strings
@@ -110,7 +111,6 @@ struct ColumnBuffers {
110111
std::vector<std::vector<SQLSMALLINT>> smallIntBuffers;
111112
std::vector<std::vector<SQLREAL>> realBuffers;
112113
std::vector<std::vector<SQLDOUBLE>> doubleBuffers;
113-
std::vector<std::vector<SQL_NUMERIC_STRUCT>> numericBuffers;
114114
std::vector<std::vector<SQL_TIMESTAMP_STRUCT>> timestampBuffers;
115115
std::vector<std::vector<SQLBIGINT>> bigIntBuffers;
116116
std::vector<std::vector<SQL_DATE_STRUCT>> dateBuffers;
@@ -125,7 +125,6 @@ struct ColumnBuffers {
125125
smallIntBuffers(numCols),
126126
realBuffers(numCols),
127127
doubleBuffers(numCols),
128-
numericBuffers(numCols),
129128
timestampBuffers(numCols),
130129
bigIntBuffers(numCols),
131130
dateBuffers(numCols),
@@ -941,16 +940,22 @@ SQLRETURN SQLGetData_wrap(intptr_t StatementHandle, SQLUSMALLINT colCount, py::l
941940
}
942941
case SQL_DECIMAL:
943942
case SQL_NUMERIC: {
944-
SQL_NUMERIC_STRUCT numericValue;
945-
ret = SQLGetData_ptr(hStmt, i, SQL_C_NUMERIC, &numericValue, sizeof(numericValue),
946-
NULL);
943+
SQLCHAR numericStr[SQL_NUMERIC_SIZE] = { 0 };
944+
SQLLEN indicator;
945+
ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, numericStr, sizeof(numericStr), &indicator);
946+
947947
if (SQL_SUCCEEDED(ret)) {
948-
row.append(NumericData(numericValue.precision, numericValue.scale,
949-
numericValue.sign,
950-
std::string(reinterpret_cast<char*>(numericValue.val),
951-
SQL_MAX_NUMERIC_LEN))
952-
.to_double());
953-
} else {
948+
try{
949+
// Convert numericStr to py::decimal.Decimal and append to row
950+
row.append(py::module_::import("decimal").attr("Decimal")(
951+
std::string(reinterpret_cast<const char*>(numericStr), indicator)));
952+
} catch (const py::error_already_set& e) {
953+
// If the conversion fails, append None
954+
DEBUG_LOG("Error converting to decimal: %s", e.what());
955+
row.append(py::none());
956+
}
957+
}
958+
else {
954959
row.append(py::none());
955960
}
956961
break;
@@ -1142,15 +1147,15 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
11421147
case SQL_CHAR:
11431148
case SQL_VARCHAR:
11441149
case SQL_LONGVARCHAR:
1145-
buffers.charBuffers[col - 1].resize(fetchSize * (columnSize));
1150+
buffers.charBuffers[col - 1].resize(fetchSize * columnSize);
11461151
ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(),
11471152
(columnSize) * sizeof(SQLCHAR),
11481153
buffers.indicators[col - 1].data());
11491154
break;
11501155
case SQL_WCHAR:
11511156
case SQL_WVARCHAR:
11521157
case SQL_WLONGVARCHAR:
1153-
buffers.wcharBuffers[col - 1].resize(fetchSize * (columnSize));
1158+
buffers.wcharBuffers[col - 1].resize(fetchSize * columnSize);
11541159
ret = SQLBindCol_ptr(hStmt, col, SQL_C_WCHAR, buffers.wcharBuffers[col - 1].data(),
11551160
(columnSize) * sizeof(SQLWCHAR),
11561161
buffers.indicators[col - 1].data());
@@ -1183,10 +1188,10 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
11831188
break;
11841189
case SQL_DECIMAL:
11851190
case SQL_NUMERIC:
1186-
buffers.numericBuffers[col - 1].resize(fetchSize);
1191+
buffers.charBuffers[col - 1].resize(fetchSize * SQL_NUMERIC_SIZE);
11871192
ret = SQLBindCol_ptr(
1188-
hStmt, col, SQL_C_NUMERIC, buffers.numericBuffers[col - 1].data(),
1189-
sizeof(SQL_NUMERIC_STRUCT), buffers.indicators[col - 1].data());
1193+
hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(),
1194+
SQL_NUMERIC_SIZE * sizeof(SQLCHAR), buffers.indicators[col - 1].data());
11901195
break;
11911196
case SQL_DOUBLE:
11921197
case SQL_FLOAT:
@@ -1332,14 +1337,17 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
13321337
break;
13331338
case SQL_DECIMAL:
13341339
case SQL_NUMERIC:
1335-
row.append(
1336-
NumericData(buffers.numericBuffers[col - 1][i].precision,
1337-
buffers.numericBuffers[col - 1][i].scale,
1338-
buffers.numericBuffers[col - 1][i].sign,
1339-
std::string(reinterpret_cast<char*>(
1340-
buffers.numericBuffers[col - 1][i].val),
1341-
SQL_MAX_NUMERIC_LEN))
1342-
.to_double());
1340+
try {
1341+
// Convert numericStr to py::decimal.Decimal and append to row
1342+
row.append(py::module_::import("decimal").attr("Decimal")(
1343+
std::string(reinterpret_cast<const char*>(
1344+
&buffers.charBuffers[col - 1][i * SQL_NUMERIC_SIZE]),
1345+
buffers.indicators[col - 1][i])));
1346+
} catch (const py::error_already_set& e) {
1347+
// Handle the exception, e.g., log the error and append py::none()
1348+
DEBUG_LOG("Error converting to decimal: %s", e.what());
1349+
row.append(py::none());
1350+
}
13431351
break;
13441352
case SQL_DOUBLE:
13451353
case SQL_FLOAT:
@@ -1443,13 +1451,16 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
14431451
rowSize += sizeof(SQLSMALLINT);
14441452
break;
14451453
case SQL_REAL:
1446-
case SQL_FLOAT:
1447-
case SQL_DECIMAL:
14481454
rowSize += sizeof(SQLREAL);
14491455
break;
14501456
case SQL_DOUBLE:
1457+
case SQL_FLOAT:
14511458
rowSize += sizeof(SQLDOUBLE);
14521459
break;
1460+
case SQL_DECIMAL:
1461+
case SQL_NUMERIC:
1462+
rowSize += SQL_NUMERIC_SIZE;
1463+
break;
14531464
case SQL_TIMESTAMP:
14541465
case SQL_TYPE_TIMESTAMP:
14551466
case SQL_DATETIME:

0 commit comments

Comments
 (0)