diff --git a/mssql_python/connection.py b/mssql_python/connection.py index e459e00a..0ad5232d 100644 --- a/mssql_python/connection.py +++ b/mssql_python/connection.py @@ -1284,21 +1284,22 @@ def getinfo(self, info_type: int) -> Union[str, int, bool, None]: # Make sure we use the correct amount of data based on length actual_data = data[:length] - # Now decode the string data - try: - return actual_data.decode("utf-8").rstrip("\0") - except UnicodeDecodeError: + # SQLGetInfoW returns UTF-16LE encoded strings (wide-character ODBC API) + # Try UTF-16LE first (expected), then UTF-8 as fallback + for encoding in ("utf-16-le", "utf-8"): try: - return actual_data.decode("latin1").rstrip("\0") - except Exception as e: - logger.debug( - "error", - "Failed to decode string in getinfo: %s. " - "Returning None to avoid silent corruption.", - e, - ) - # Explicitly return None to signal decoding failure - return None + return actual_data.decode(encoding).rstrip("\0") + except UnicodeDecodeError: + continue + + # All decodings failed + logger.debug( + "error", + "Failed to decode string in getinfo (info_type=%d) with supported encodings. " + "Returning None to avoid silent corruption.", + info_type, + ) + return None else: # If it's not bytes, return as is return data diff --git a/mssql_python/pybind/build.bat b/mssql_python/pybind/build.bat index 90241c05..f264f686 100644 --- a/mssql_python/pybind/build.bat +++ b/mssql_python/pybind/build.bat @@ -157,22 +157,6 @@ if exist "%OUTPUT_DIR%\%PYD_NAME%" ( echo [WARNING] PDB file !PDB_NAME! not found in output directory. ) - setlocal enabledelayedexpansion - for %%I in ("%SOURCE_DIR%..") do ( - set PARENT_DIR=%%~fI - ) - echo [DIAGNOSTIC] Parent is: !PARENT_DIR! - - set VCREDIST_DLL_PATH=!PARENT_DIR!\libs\windows\!ARCH!\vcredist\msvcp140.dll - echo [DIAGNOSTIC] Looking for msvcp140.dll at "!VCREDIST_DLL_PATH!" - - if exist "!VCREDIST_DLL_PATH!" ( - copy /Y "!VCREDIST_DLL_PATH!" "%SOURCE_DIR%\.." - echo [SUCCESS] Copied msvcp140.dll from !VCREDIST_DLL_PATH! to "%SOURCE_DIR%\.." - ) else ( - echo [ERROR] Could not find msvcp140.dll at "!VCREDIST_DLL_PATH!" - exit /b 1 - ) ) else ( echo [ERROR] Could not find built .pyd file: %PYD_NAME% REM Exit with an error code here if the .pyd file is not found diff --git a/tests/test_003_connection.py b/tests/test_003_connection.py index 97edff2a..8106e7a4 100644 --- a/tests/test_003_connection.py +++ b/tests/test_003_connection.py @@ -2187,6 +2187,105 @@ def test_getinfo_basic_driver_info(db_connection): pytest.fail(f"getinfo failed for basic driver info: {e}") +def test_getinfo_string_encoding_utf16(db_connection): + """Test that string values from getinfo are properly decoded from UTF-16.""" + + # Test string info types that should not contain null bytes + string_info_types = [ + ("SQL_DRIVER_VER", sql_const.SQL_DRIVER_VER.value), + ("SQL_DRIVER_NAME", sql_const.SQL_DRIVER_NAME.value), + ("SQL_DRIVER_ODBC_VER", sql_const.SQL_DRIVER_ODBC_VER.value), + ("SQL_SERVER_NAME", sql_const.SQL_SERVER_NAME.value), + ] + + for name, info_type in string_info_types: + result = db_connection.getinfo(info_type) + + if result is not None: + # Verify it's a string + assert isinstance(result, str), f"{name}: Expected str, got {type(result).__name__}" + + # Verify no null bytes (indicates UTF-16 decoded as UTF-8 bug) + assert ( + "\x00" not in result + ), f"{name} contains null bytes, likely UTF-16/UTF-8 encoding mismatch: {repr(result)}" + + # Verify it's not empty (optional, but good sanity check) + assert len(result) > 0, f"{name} returned empty string" + + +def test_getinfo_string_decoding_utf8_fallback(db_connection): + """Test that getinfo falls back to UTF-8 when UTF-16LE decoding fails. + + This test verifies the fallback path in the encoding loop where + UTF-16LE fails but UTF-8 succeeds. + """ + from unittest.mock import patch + + # UTF-8 encoded "Hello" - this is valid UTF-8 but NOT valid UTF-16LE + # (odd number of bytes would fail UTF-16LE decode) + utf8_data = "Hello".encode("utf-8") # b'Hello' - 5 bytes, odd length + + mock_result = {"data": utf8_data, "length": len(utf8_data)} + + # Use a string-type info_type (SQL_DRIVER_NAME = 6 is in string_type_constants) + info_type = sql_const.SQL_DRIVER_NAME.value + + with patch.object(db_connection._conn, "get_info", return_value=mock_result): + result = db_connection.getinfo(info_type) + + assert result == "Hello", f"Expected 'Hello', got {repr(result)}" + assert isinstance(result, str), f"Expected str, got {type(result).__name__}" + + +def test_getinfo_string_decoding_all_fail_returns_none(db_connection): + """Test that getinfo returns None when all decoding attempts fail. + + This test verifies that when both UTF-16LE and UTF-8 decoding fail, + the method returns None to avoid silent data corruption. + """ + from unittest.mock import patch + + # Invalid byte sequence that cannot be decoded as UTF-16LE or UTF-8 + # 0xFF 0xFE is a BOM, but followed by invalid continuation bytes for UTF-8 + # and odd length makes it invalid UTF-16LE + invalid_data = bytes([0x80, 0x81, 0x82]) # Invalid for both encodings + + mock_result = {"data": invalid_data, "length": len(invalid_data)} + + # Use a string-type info_type (SQL_DRIVER_NAME = 6 is in string_type_constants) + info_type = sql_const.SQL_DRIVER_NAME.value + + with patch.object(db_connection._conn, "get_info", return_value=mock_result): + result = db_connection.getinfo(info_type) + + # Should return None when all decoding fails + assert result is None, f"Expected None for invalid encoding, got {repr(result)}" + + +def test_getinfo_string_encoding_utf16_primary(db_connection): + """Test that getinfo correctly decodes valid UTF-16LE data. + + This test verifies the primary (expected) encoding path where + UTF-16LE decoding succeeds on first try. + """ + from unittest.mock import patch + + # Valid UTF-16LE encoded "Test" with null terminator + utf16_data = "Test".encode("utf-16-le") + b"\x00\x00" + + mock_result = {"data": utf16_data, "length": len(utf16_data)} + + # Use a string-type info_type + info_type = sql_const.SQL_DRIVER_NAME.value + + with patch.object(db_connection._conn, "get_info", return_value=mock_result): + result = db_connection.getinfo(info_type) + + assert result == "Test", f"Expected 'Test', got {repr(result)}" + assert "\x00" not in result, f"Result contains null bytes: {repr(result)}" + + def test_getinfo_sql_support(db_connection): """Test SQL support and conformance info types."""