@@ -31,6 +31,7 @@ using namespace pybind11::literals;
3131
3232// This constant is not exposed via sql.h, hence define it here
3333#define SQL_SS_TIME2 (-154 )
34+ #define SQL_NUMERIC_SIZE 64
3435
3536// Logs data to stdout only in debug builds
3637// TODO: Handle both UTF-8 and UTF-16 strings
@@ -110,7 +111,6 @@ struct ColumnBuffers {
110111 std::vector<std::vector<SQLSMALLINT>> smallIntBuffers;
111112 std::vector<std::vector<SQLREAL>> realBuffers;
112113 std::vector<std::vector<SQLDOUBLE>> doubleBuffers;
113- std::vector<std::vector<SQL_NUMERIC_STRUCT>> numericBuffers;
114114 std::vector<std::vector<SQL_TIMESTAMP_STRUCT>> timestampBuffers;
115115 std::vector<std::vector<SQLBIGINT>> bigIntBuffers;
116116 std::vector<std::vector<SQL_DATE_STRUCT>> dateBuffers;
@@ -125,7 +125,6 @@ struct ColumnBuffers {
125125 smallIntBuffers(numCols),
126126 realBuffers(numCols),
127127 doubleBuffers(numCols),
128- numericBuffers(numCols),
129128 timestampBuffers(numCols),
130129 bigIntBuffers(numCols),
131130 dateBuffers(numCols),
@@ -941,16 +940,22 @@ SQLRETURN SQLGetData_wrap(intptr_t StatementHandle, SQLUSMALLINT colCount, py::l
941940 }
942941 case SQL_DECIMAL:
943942 case SQL_NUMERIC: {
944- SQL_NUMERIC_STRUCT numericValue;
945- ret = SQLGetData_ptr (hStmt, i, SQL_C_NUMERIC, &numericValue, sizeof (numericValue),
946- NULL );
943+ SQLCHAR numericStr[SQL_NUMERIC_SIZE] = { 0 };
944+ SQLLEN indicator;
945+ ret = SQLGetData_ptr (hStmt, i, SQL_C_CHAR, numericStr, sizeof (numericStr), &indicator);
946+
947947 if (SQL_SUCCEEDED (ret)) {
948- row.append (NumericData (numericValue.precision , numericValue.scale ,
949- numericValue.sign ,
950- std::string (reinterpret_cast <char *>(numericValue.val ),
951- SQL_MAX_NUMERIC_LEN))
952- .to_double ());
953- } else {
948+ try {
949+ // Convert numericStr to py::decimal.Decimal and append to row
950+ row.append (py::module_::import (" decimal" ).attr (" Decimal" )(
951+ std::string (reinterpret_cast <const char *>(numericStr), indicator)));
952+ } catch (const py::error_already_set& e) {
953+ // If the conversion fails, append None
954+ DEBUG_LOG (" Error converting to decimal: %s" , e.what ());
955+ row.append (py::none ());
956+ }
957+ }
958+ else {
954959 row.append (py::none ());
955960 }
956961 break ;
@@ -1142,15 +1147,15 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
11421147 case SQL_CHAR:
11431148 case SQL_VARCHAR:
11441149 case SQL_LONGVARCHAR:
1145- buffers.charBuffers [col - 1 ].resize (fetchSize * ( columnSize) );
1150+ buffers.charBuffers [col - 1 ].resize (fetchSize * columnSize);
11461151 ret = SQLBindCol_ptr (hStmt, col, SQL_C_CHAR, buffers.charBuffers [col - 1 ].data (),
11471152 (columnSize) * sizeof (SQLCHAR),
11481153 buffers.indicators [col - 1 ].data ());
11491154 break ;
11501155 case SQL_WCHAR:
11511156 case SQL_WVARCHAR:
11521157 case SQL_WLONGVARCHAR:
1153- buffers.wcharBuffers [col - 1 ].resize (fetchSize * ( columnSize) );
1158+ buffers.wcharBuffers [col - 1 ].resize (fetchSize * columnSize);
11541159 ret = SQLBindCol_ptr (hStmt, col, SQL_C_WCHAR, buffers.wcharBuffers [col - 1 ].data (),
11551160 (columnSize) * sizeof (SQLWCHAR),
11561161 buffers.indicators [col - 1 ].data ());
@@ -1183,10 +1188,10 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
11831188 break ;
11841189 case SQL_DECIMAL:
11851190 case SQL_NUMERIC:
1186- buffers.numericBuffers [col - 1 ].resize (fetchSize);
1191+ buffers.charBuffers [col - 1 ].resize (fetchSize * SQL_NUMERIC_SIZE );
11871192 ret = SQLBindCol_ptr (
1188- hStmt, col, SQL_C_NUMERIC , buffers.numericBuffers [col - 1 ].data (),
1189- sizeof (SQL_NUMERIC_STRUCT ), buffers.indicators [col - 1 ].data ());
1193+ hStmt, col, SQL_C_CHAR , buffers.charBuffers [col - 1 ].data (),
1194+ SQL_NUMERIC_SIZE * sizeof (SQLCHAR ), buffers.indicators [col - 1 ].data ());
11901195 break ;
11911196 case SQL_DOUBLE:
11921197 case SQL_FLOAT:
@@ -1332,14 +1337,17 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
13321337 break ;
13331338 case SQL_DECIMAL:
13341339 case SQL_NUMERIC:
1335- row.append (
1336- NumericData (buffers.numericBuffers [col - 1 ][i].precision ,
1337- buffers.numericBuffers [col - 1 ][i].scale ,
1338- buffers.numericBuffers [col - 1 ][i].sign ,
1339- std::string (reinterpret_cast <char *>(
1340- buffers.numericBuffers [col - 1 ][i].val ),
1341- SQL_MAX_NUMERIC_LEN))
1342- .to_double ());
1340+ try {
1341+ // Convert numericStr to py::decimal.Decimal and append to row
1342+ row.append (py::module_::import (" decimal" ).attr (" Decimal" )(
1343+ std::string (reinterpret_cast <const char *>(
1344+ &buffers.charBuffers [col - 1 ][i * SQL_NUMERIC_SIZE]),
1345+ buffers.indicators [col - 1 ][i])));
1346+ } catch (const py::error_already_set& e) {
1347+ // Handle the exception, e.g., log the error and append py::none()
1348+ DEBUG_LOG (" Error converting to decimal: %s" , e.what ());
1349+ row.append (py::none ());
1350+ }
13431351 break ;
13441352 case SQL_DOUBLE:
13451353 case SQL_FLOAT:
@@ -1443,13 +1451,16 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) {
14431451 rowSize += sizeof (SQLSMALLINT);
14441452 break ;
14451453 case SQL_REAL:
1446- case SQL_FLOAT:
1447- case SQL_DECIMAL:
14481454 rowSize += sizeof (SQLREAL);
14491455 break ;
14501456 case SQL_DOUBLE:
1457+ case SQL_FLOAT:
14511458 rowSize += sizeof (SQLDOUBLE);
14521459 break ;
1460+ case SQL_DECIMAL:
1461+ case SQL_NUMERIC:
1462+ rowSize += SQL_NUMERIC_SIZE;
1463+ break ;
14531464 case SQL_TIMESTAMP:
14541465 case SQL_TYPE_TIMESTAMP:
14551466 case SQL_DATETIME:
0 commit comments