Skip to content

Commit 7056dca

Browse files
author
Theekshna Kotian
committed
Merged PR 5333: Fix truncation of varchar/binary at full capacity + Add related tests
Currently, if you add a string of size 10 to VARCHAR(10) column, & try to fetch it, you'll get a truncated string of size 9. This PR fixes this case & adds related tests. ---- #### AI description (iteration 1) #### PR Classification Bug fix #### PR Summary This pull request addresses the truncation issue of `VARCHAR` and `VARBINARY` types at full capacity and adds related tests to ensure proper functionality. - Added tests in `/tests/test_004_cursor.py` to verify `VARCHAR`, `NVARCHAR`, and `VARBINARY` columns can handle values at their full capacity. - Modified `/mssql_python/pybind/ddbc_bindings.cpp` to correctly handle buffer sizes and null-termination for `VARCHAR`, `NVARCHAR`, and `VARBINARY` types during data fetching and binding. <!-- GitOpsUserAgent=GitOps.Apps.Server.pullrequestcopilot --> Related work items: #33942
1 parent ce3083f commit 7056dca

File tree

3 files changed

+347
-89
lines changed

3 files changed

+347
-89
lines changed

mssql_python/cursor.py

Lines changed: 65 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -148,68 +148,68 @@ def _parse_time(self, param):
148148
continue
149149
return None
150150

151-
def _parse_timestamptz(self, param):
152-
"""
153-
Attempt to parse a string as a timestamp with time zone (timestamptz).
154-
155-
Args:
156-
param: The string to parse.
157-
158-
Returns:
159-
A datetime.datetime object if parsing is successful, else None.
160-
"""
161-
formats = [
162-
"%Y-%m-%dT%H:%M:%S%z", # ISO 8601 datetime with timezone offset
163-
"%Y-%m-%d %H:%M:%S.%f%z", # Datetime with fractional seconds and timezone offset
164-
]
165-
for fmt in formats:
166-
try:
167-
return datetime.datetime.strptime(param, fmt)
168-
except ValueError:
169-
continue
170-
return None
171-
172-
def _parse_smalldatetime(self, param):
173-
"""
174-
Attempt to parse a string as a smalldatetime.
175-
176-
Args:
177-
param: The string to parse.
178-
179-
Returns:
180-
A datetime.datetime object if parsing is successful, else None.
181-
"""
182-
formats = [
183-
"%Y-%m-%d %H:%M:%S", # Standard datetime
184-
]
185-
for fmt in formats:
186-
try:
187-
return datetime.datetime.strptime(param, fmt)
188-
except ValueError:
189-
continue
190-
return None
191-
192-
def _parse_datetime2(self, param):
193-
"""
194-
Attempt to parse a string as a datetime2.
195-
196-
Args:
197-
param: The string to parse.
198-
199-
Returns:
200-
A datetime.datetime object if parsing is successful, else None.
201-
"""
202-
formats = [
203-
"%Y-%m-%d %H:%M:%S.%f", # Datetime with fractional seconds (up to 6 digits)
204-
]
205-
for fmt in formats:
206-
try:
207-
dt = datetime.datetime.strptime(param, fmt)
208-
if fmt == "%Y-%m-%d %H:%M:%S.%f" and len(param.split('.')[-1]) > 3:
209-
return dt
210-
except ValueError:
211-
continue
212-
return None
151+
# def _parse_timestamptz(self, param):
152+
# """
153+
# Attempt to parse a string as a timestamp with time zone (timestamptz).
154+
#
155+
# Args:
156+
# param: The string to parse.
157+
#
158+
# Returns:
159+
# A datetime.datetime object if parsing is successful, else None.
160+
# """
161+
# formats = [
162+
# "%Y-%m-%dT%H:%M:%S%z", # ISO 8601 datetime with timezone offset
163+
# "%Y-%m-%d %H:%M:%S.%f%z", # Datetime with fractional seconds and timezone offset
164+
# ]
165+
# for fmt in formats:
166+
# try:
167+
# return datetime.datetime.strptime(param, fmt)
168+
# except ValueError:
169+
# continue
170+
# return None
171+
172+
# def _parse_smalldatetime(self, param):
173+
# """
174+
# Attempt to parse a string as a smalldatetime.
175+
#
176+
# Args:
177+
# param: The string to parse.
178+
#
179+
# Returns:
180+
# A datetime.datetime object if parsing is successful, else None.
181+
# """
182+
# formats = [
183+
# "%Y-%m-%d %H:%M:%S", # Standard datetime
184+
# ]
185+
# for fmt in formats:
186+
# try:
187+
# return datetime.datetime.strptime(param, fmt)
188+
# except ValueError:
189+
# continue
190+
# return None
191+
192+
# def _parse_datetime2(self, param):
193+
# """
194+
# Attempt to parse a string as a datetime2.
195+
#
196+
# Args:
197+
# param: The string to parse.
198+
#
199+
# Returns:
200+
# A datetime.datetime object if parsing is successful, else None.
201+
# """
202+
# formats = [
203+
# "%Y-%m-%d %H:%M:%S.%f", # Datetime with fractional seconds (up to 6 digits)
204+
# ]
205+
# for fmt in formats:
206+
# try:
207+
# dt = datetime.datetime.strptime(param, fmt)
208+
# if fmt == "%Y-%m-%d %H:%M:%S.%f" and len(param.split('.')[-1]) > 3:
209+
# return dt
210+
# except ValueError:
211+
# continue
212+
# return None
213213

214214
def _get_numeric_data(self, param):
215215
"""
@@ -326,6 +326,7 @@ def _map_sql_type(self, param, parameters_list, i):
326326

327327
# String mapping logic here
328328
is_unicode = self._is_unicode_string(param)
329+
# TODO: revisit
329330
if len(param) > 4000: # Long strings
330331
if is_unicode:
331332
return odbc_sql_const.SQL_WLONGVARCHAR.value, odbc_sql_const.SQL_C_WCHAR.value, len(param), 0
@@ -348,8 +349,8 @@ def _map_sql_type(self, param, parameters_list, i):
348349
else:
349350
return odbc_sql_const.SQL_BINARY.value, odbc_sql_const.SQL_C_BINARY.value, len(param), 0
350351

351-
elif isinstance(param, uuid.UUID): # Handle uniqueidentifier
352-
return odbc_sql_const.SQL_GUID.value, odbc_sql_const.SQL_C_GUID.value, 36, 0
352+
# elif isinstance(param, uuid.UUID): # Handle uniqueidentifier
353+
# return odbc_sql_const.SQL_GUID.value, odbc_sql_const.SQL_C_GUID.value, 36, 0
353354

354355
elif isinstance(param, datetime.datetime):
355356
# Always keep datetime.datetime check before datetime.date check since datetime.datetime is a subclass of datetime (isinstance(datetime.datetime, datetime.date) returns True)

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 44 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,8 @@ SQLRETURN BindParameters(SQLHANDLE hStmt, const py::list& params,
420420
if (!py::isinstance<py::none>(param)) {
421421
ThrowStdException(MakeParamMismatchErrorStr(paramInfo.paramCType, paramIndex));
422422
}
423+
// TODO: This wont work for None values added to BINARY/VARBINARY columns. None values
424+
// of binary columns need to have C type = SQL_C_BINARY & SQL type = SQL_BINARY
423425
dataPtr = nullptr;
424426
strLenOrIndPtr = AllocateParamBuffer<SQLLEN>(paramBuffers);
425427
*strLenOrIndPtr = SQL_NULL_DATA;
@@ -918,19 +920,21 @@ SQLRETURN SQLGetData_wrap(intptr_t StatementHandle, SQLUSMALLINT colCount, py::l
918920
case SQL_LONGVARCHAR: {
919921
// TODO: revisit
920922
HandleZeroColumnSizeAtFetch(columnSize);
921-
std::vector<SQLCHAR> dataBuffer(columnSize + 1);
923+
uint64_t fetchBufferSize = columnSize + 1 /* null-termination */;
924+
std::vector<SQLCHAR> dataBuffer(fetchBufferSize);
922925
SQLLEN dataLen;
923926
// TODO: Handle the return code better
924-
ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), dataBuffer.size() - 1,
927+
ret = SQLGetData_ptr(hStmt, i, SQL_C_CHAR, dataBuffer.data(), dataBuffer.size(),
925928
&dataLen);
926929

927930
if (SQL_SUCCEEDED(ret)) {
928931
// TODO: Refactor these if's across other switches to avoid code duplication
929932
// columnSize is in chars, dataLen is in bytes
930933
if (dataLen > 0) {
931-
int numCharsInData = dataLen / sizeof(SQLCHAR);
934+
uint64_t numCharsInData = dataLen / sizeof(SQLCHAR);
935+
// NOTE: dataBuffer.size() includes null-terminator, dataLen doesn't. Hence use '<'.
932936
if (numCharsInData < dataBuffer.size()) {
933-
dataBuffer[numCharsInData] = '\0'; // Null-terminate
937+
// SQLGetData will null-terminate the data
934938
row.append(std::string(reinterpret_cast<char*>(dataBuffer.data())));
935939
} else {
936940
// In this case, buffer size is smaller, and data to be retrieved is longer
@@ -962,17 +966,18 @@ SQLRETURN SQLGetData_wrap(intptr_t StatementHandle, SQLUSMALLINT colCount, py::l
962966
case SQL_WLONGVARCHAR: {
963967
// TODO: revisit
964968
HandleZeroColumnSizeAtFetch(columnSize);
965-
std::vector<SQLWCHAR> dataBuffer(columnSize + 1);
969+
uint64_t fetchBufferSize = columnSize + 1 /* null-termination */;
970+
std::vector<SQLWCHAR> dataBuffer(fetchBufferSize);
966971
SQLLEN dataLen;
967972
ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, dataBuffer.data(),
968-
(dataBuffer.size() - 1) * sizeof(SQLWCHAR), &dataLen);
973+
dataBuffer.size() * sizeof(SQLWCHAR), &dataLen);
969974

970975
if (SQL_SUCCEEDED(ret)) {
971976
// TODO: Refactor these if's across other switches to avoid code duplication
972977
if (dataLen > 0) {
973-
int numCharsInData = dataLen / sizeof(SQLWCHAR);
978+
uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR);
974979
if (numCharsInData < dataBuffer.size()) {
975-
dataBuffer[numCharsInData] = L'\0'; // Null-terminate
980+
// SQLGetData will null-terminate the data
976981
row.append(std::wstring(dataBuffer.data()));
977982
} else {
978983
// In this case, buffer size is smaller, and data to be retrieved is longer
@@ -1273,24 +1278,37 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column
12731278
switch (dataType) {
12741279
case SQL_CHAR:
12751280
case SQL_VARCHAR:
1276-
case SQL_LONGVARCHAR:
1281+
case SQL_LONGVARCHAR: {
12771282
// TODO: handle variable length data correctly. This logic wont suffice
12781283
HandleZeroColumnSizeAtFetch(columnSize);
1279-
buffers.charBuffers[col - 1].resize(fetchSize * (columnSize + 1 /*null-terminator*/));
1284+
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
1285+
// TODO: For LONGVARCHAR/BINARY types, columnSize is returned as 2GB-1 by
1286+
// SQLDescribeCol. So fetchBufferSize = 2GB. fetchSize=1 if columnSize>1GB.
1287+
// So we'll allocate a vector of size 2GB. If a query fetches multiple (say N)
1288+
// LONG... columns, we will have allocated multiple (N) 2GB sized vectors. This
1289+
// will make driver very slow. And if the N is high enough, we could hit the OS
1290+
// limit for heap memory that we can allocate, & hence get a std::bad_alloc. The
1291+
// process could also be killed by OS for consuming too much memory.
1292+
// Hence this will be revisited in beta to not allocate 2GB+ memory,
1293+
// & use streaming instead
1294+
buffers.charBuffers[col - 1].resize(fetchSize * fetchBufferSize);
12801295
ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(),
1281-
(columnSize) * sizeof(SQLCHAR),
1296+
fetchBufferSize * sizeof(SQLCHAR),
12821297
buffers.indicators[col - 1].data());
12831298
break;
1299+
}
12841300
case SQL_WCHAR:
12851301
case SQL_WVARCHAR:
1286-
case SQL_WLONGVARCHAR:
1302+
case SQL_WLONGVARCHAR: {
12871303
// TODO: handle variable length data correctly. This logic wont suffice
12881304
HandleZeroColumnSizeAtFetch(columnSize);
1289-
buffers.wcharBuffers[col - 1].resize(fetchSize * (columnSize + 1 /*null-terminator*/));
1305+
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
1306+
buffers.wcharBuffers[col - 1].resize(fetchSize * fetchBufferSize);
12901307
ret = SQLBindCol_ptr(hStmt, col, SQL_C_WCHAR, buffers.wcharBuffers[col - 1].data(),
1291-
(columnSize) * sizeof(SQLWCHAR),
1308+
fetchBufferSize * sizeof(SQLWCHAR),
12921309
buffers.indicators[col - 1].data());
12931310
break;
1311+
}
12941312
case SQL_INTEGER:
12951313
buffers.intBuffers[col - 1].resize(fetchSize);
12961314
ret = SQLBindCol_ptr(hStmt, col, SQL_C_SLONG, buffers.intBuffers[col - 1].data(),
@@ -1439,12 +1457,13 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
14391457
// TODO: variable length data needs special handling, this logic wont suffice
14401458
SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
14411459
HandleZeroColumnSizeAtFetch(columnSize);
1442-
int numCharsInData = dataLen / sizeof(SQLCHAR);
1443-
if (numCharsInData <= columnSize) {
1444-
buffers.charBuffers[col - 1][(i * columnSize) + numCharsInData] =
1445-
'\0'; // Null-terminate
1460+
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
1461+
uint64_t numCharsInData = dataLen / sizeof(SQLCHAR);
1462+
// fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<'
1463+
if (numCharsInData < fetchBufferSize) {
1464+
// SQLFetch will nullterminate the data
14461465
row.append(std::string(
1447-
reinterpret_cast<char*>(&buffers.charBuffers[col - 1][i * columnSize]),
1466+
reinterpret_cast<char*>(&buffers.charBuffers[col - 1][i * fetchBufferSize]),
14481467
numCharsInData));
14491468
} else {
14501469
// In this case, buffer size is smaller, and data to be retrieved is longer
@@ -1463,13 +1482,13 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum
14631482
// TODO: variable length data needs special handling, this logic wont suffice
14641483
SQLULEN columnSize = columnMeta["ColumnSize"].cast<SQLULEN>();
14651484
HandleZeroColumnSizeAtFetch(columnSize);
1466-
int numCharsInData = dataLen / sizeof(SQLWCHAR);
1467-
if (numCharsInData <= columnSize) {
1468-
buffers.wcharBuffers[col - 1]
1469-
[(i * columnSize) + numCharsInData] =
1470-
L'\0'; // Null-terminate
1485+
uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/;
1486+
uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR);
1487+
// fetchBufferSize includes null-terminator, numCharsInData doesn't. Hence '<'
1488+
if (numCharsInData < fetchBufferSize) {
1489+
// SQLFetch will nullterminate the data
14711490
row.append(std::wstring(
1472-
reinterpret_cast<wchar_t*>(&buffers.wcharBuffers[col - 1][i * columnSize]),
1491+
reinterpret_cast<wchar_t*>(&buffers.wcharBuffers[col - 1][i * fetchBufferSize]),
14731492
numCharsInData));
14741493
} else {
14751494
// In this case, buffer size is smaller, and data to be retrieved is longer

0 commit comments

Comments
 (0)