Skip to content

Commit 9c1d92a

Browse files
author
subrata-ms
committed
resolving co-pilot review comment
1 parent 9ff1de0 commit 9c1d92a

File tree

3 files changed

+344
-20
lines changed

3 files changed

+344
-20
lines changed

mssql_python/pybind/ddbc_bindings.h

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -459,11 +459,9 @@ inline std::wstring Utf8ToWString(const std::string& str) {
459459
return result;
460460
#else
461461
// Optimized UTF-8 to UTF-32 conversion (wstring on Unix)
462-
if (str.empty())
463-
return {};
464462

465463
// Lambda to decode UTF-8 multi-byte sequences
466-
constexpr auto decodeUtf8 = [](const unsigned char* data, size_t& i, size_t len) -> wchar_t {
464+
auto decodeUtf8 = [](const unsigned char* data, size_t& i, size_t len) -> wchar_t {
467465
unsigned char byte = data[i];
468466

469467
// 1-byte sequence (ASCII): 0xxxxxxx
@@ -473,24 +471,58 @@ inline std::wstring Utf8ToWString(const std::string& str) {
473471
}
474472
// 2-byte sequence: 110xxxxx 10xxxxxx
475473
if ((byte & 0xE0) == 0xC0 && i + 1 < len) {
474+
// Validate continuation byte has correct bit pattern (10xxxxxx)
475+
if ((data[i + 1] & 0xC0) != 0x80) {
476+
++i;
477+
return 0xFFFD; // Invalid continuation byte
478+
}
476479
uint32_t cp = ((static_cast<uint32_t>(byte & 0x1F) << 6) | (data[i + 1] & 0x3F));
477-
i += 2;
478-
return static_cast<wchar_t>(cp);
480+
// Reject overlong encodings (must be >= 0x80)
481+
if (cp >= 0x80) {
482+
i += 2;
483+
return static_cast<wchar_t>(cp);
484+
}
485+
// Overlong encoding - invalid
486+
++i;
487+
return 0xFFFD;
479488
}
480489
// 3-byte sequence: 1110xxxx 10xxxxxx 10xxxxxx
481490
if ((byte & 0xF0) == 0xE0 && i + 2 < len) {
491+
// Validate continuation bytes have correct bit pattern (10xxxxxx)
492+
if ((data[i + 1] & 0xC0) != 0x80 || (data[i + 2] & 0xC0) != 0x80) {
493+
++i;
494+
return 0xFFFD; // Invalid continuation bytes
495+
}
482496
uint32_t cp = ((static_cast<uint32_t>(byte & 0x0F) << 12) |
483497
((data[i + 1] & 0x3F) << 6) | (data[i + 2] & 0x3F));
484-
i += 3;
485-
return static_cast<wchar_t>(cp);
498+
// Reject overlong encodings (must be >= 0x800) and surrogates (0xD800-0xDFFF)
499+
if (cp >= 0x800 && (cp < 0xD800 || cp > 0xDFFF)) {
500+
i += 3;
501+
return static_cast<wchar_t>(cp);
502+
}
503+
// Overlong encoding or surrogate - invalid
504+
++i;
505+
return 0xFFFD;
486506
}
487507
// 4-byte sequence: 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
488508
if ((byte & 0xF8) == 0xF0 && i + 3 < len) {
509+
// Validate continuation bytes have correct bit pattern (10xxxxxx)
510+
if ((data[i + 1] & 0xC0) != 0x80 || (data[i + 2] & 0xC0) != 0x80 ||
511+
(data[i + 3] & 0xC0) != 0x80) {
512+
++i;
513+
return 0xFFFD; // Invalid continuation bytes
514+
}
489515
uint32_t cp =
490516
((static_cast<uint32_t>(byte & 0x07) << 18) | ((data[i + 1] & 0x3F) << 12) |
491517
((data[i + 2] & 0x3F) << 6) | (data[i + 3] & 0x3F));
492-
i += 4;
493-
return static_cast<wchar_t>(cp);
518+
// Reject overlong encodings (must be >= 0x10000) and values above max Unicode
519+
if (cp >= 0x10000 && cp <= 0x10FFFF) {
520+
i += 4;
521+
return static_cast<wchar_t>(cp);
522+
}
523+
// Overlong encoding or out of range - invalid
524+
++i;
525+
return 0xFFFD;
494526
}
495527
// Invalid sequence - skip byte
496528
++i;
@@ -513,9 +545,9 @@ inline std::wstring Utf8ToWString(const std::string& str) {
513545
// Handle remaining multi-byte sequences
514546
while (i < len) {
515547
wchar_t wc = decodeUtf8(data, i, len);
516-
if (wc != 0xFFFD || data[i - 1] >= 0x80) { // Skip invalid sequences
517-
result.push_back(wc);
518-
}
548+
// Always push the decoded character (including 0xFFFD replacement characters)
549+
// This correctly handles both legitimate 0xFFFD in input and invalid sequences
550+
result.push_back(wc);
519551
}
520552

521553
return result;

mssql_python/pybind/unix_utils.cpp

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,17 @@
1414

1515
#if defined(__APPLE__) || defined(__linux__)
1616

17+
// Unicode constants for validation
18+
constexpr uint32_t kUnicodeReplacementChar = 0xFFFD;
19+
constexpr uint32_t kUnicodeMaxCodePoint = 0x10FFFF;
20+
1721
// Constants for character encoding
1822
const char* kOdbcEncoding = "utf-16-le"; // ODBC uses UTF-16LE for SQLWCHAR
1923
const size_t kUcsLength = 2; // SQLWCHAR is 2 bytes on all platforms
2024

2125
// Function to convert SQLWCHAR strings to std::wstring on macOS/Linux
22-
// Optimized version: direct conversion without intermediate buffer
26+
// Converts UTF-16 (SQLWCHAR) to UTF-32 (wstring on Unix)
27+
// Invalid surrogates (unpaired high/low) are replaced with U+FFFD
2328
std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS) {
2429
if (!sqlwStr) {
2530
return std::wstring();
@@ -73,19 +78,20 @@ std::wstring SQLWCHARToWString(const SQLWCHAR* sqlwStr, size_t length = SQL_NTS)
7378
continue;
7479
}
7580
}
76-
// Invalid surrogate - push as-is
77-
result.push_back(static_cast<wchar_t>(utf16Char));
81+
// Invalid surrogate - replace with Unicode replacement character
82+
result.push_back(static_cast<wchar_t>(kUnicodeReplacementChar));
7883
++i;
79-
} else { // Low surrogate without high - invalid but push as-is
80-
result.push_back(static_cast<wchar_t>(utf16Char));
84+
} else { // Low surrogate without high - invalid, replace with replacement character
85+
result.push_back(static_cast<wchar_t>(kUnicodeReplacementChar));
8186
++i;
8287
}
8388
}
8489
return result;
8590
}
8691

8792
// Function to convert std::wstring to SQLWCHAR array on macOS/Linux
88-
// Optimized version: streamlined conversion with better branch prediction
93+
// Converts UTF-32 (wstring on Unix) to UTF-16 (SQLWCHAR)
94+
// Invalid Unicode scalars (surrogates, values > 0x10FFFF) are replaced with U+FFFD
8995
std::vector<SQLWCHAR> WStringToSQLWCHAR(const std::wstring& str) {
9096
if (str.empty()) {
9197
return std::vector<SQLWCHAR>(1, 0); // Just null terminator
@@ -98,22 +104,35 @@ std::vector<SQLWCHAR> WStringToSQLWCHAR(const std::wstring& str) {
98104
vec.push_back(static_cast<SQLWCHAR>(0xDC00 | (cp & 0x3FF)));
99105
};
100106

107+
// Lambda to check if code point is a valid Unicode scalar value
108+
auto isValidUnicodeScalar = [](uint32_t cp) -> bool {
109+
// Exclude surrogate range (0xD800-0xDFFF) and values beyond max Unicode
110+
return cp <= kUnicodeMaxCodePoint && (cp < 0xD800 || cp > 0xDFFF);
111+
};
112+
101113
// Convert wstring (UTF-32) to UTF-16
102114
std::vector<SQLWCHAR> result;
103115
result.reserve(str.size() + 1); // Most chars are BMP (one UTF-16 unit each), so this is usually enough; +1 for the null terminator
104116

105117
for (wchar_t wc : str) {
106118
uint32_t codePoint = static_cast<uint32_t>(wc);
107119

120+
// Validate code point first
121+
if (!isValidUnicodeScalar(codePoint)) {
122+
codePoint = kUnicodeReplacementChar;
123+
}
124+
108125
// Fast path: BMP character (most common - ~99% of strings)
126+
// After validation, codePoint cannot be in surrogate range (0xD800-0xDFFF)
109127
if (codePoint <= 0xFFFF) {
110128
result.push_back(static_cast<SQLWCHAR>(codePoint));
111129
}
112130
// Encode as surrogate pair for characters outside BMP
113-
else if (codePoint <= 0x10FFFF) {
131+
else if (codePoint <= kUnicodeMaxCodePoint) {
114132
encodeSurrogatePair(result, codePoint);
115133
}
116-
// Invalid code points silently skipped
134+
// Note: invalid code points (surrogates and values > 0x10FFFF) were already
135+
// replaced with replacement character (0xFFFD) at validation above
117136
}
118137

119138
result.push_back(0); // Null terminator

0 commit comments

Comments
 (0)