Skip to content

Commit ab15ef9

Browse files
authored
Merge branch 'main' into subrata-ms/DepricatedFixLinux
2 parents be4b70e + eb95d2e commit ab15ef9

File tree

7 files changed

+9413
-5774
lines changed

7 files changed

+9413
-5774
lines changed

mssql_python/connection.py

Lines changed: 122 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,57 @@
5454
INFO_TYPE_STRING_THRESHOLD: int = 10000
5555

5656
# UTF-16 encoding variants that should use SQL_WCHAR by default
57-
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16", "utf-16le", "utf-16be"])
57+
# Note: "utf-16" with BOM is NOT included as it's problematic for SQL_WCHAR
58+
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16le", "utf-16be"])
59+
60+
61+
def _validate_utf16_wchar_compatibility(
    encoding: str, wchar_type: int, context: str = "SQL_WCHAR"
) -> None:
    """
    Validate that an encoding may be used together with SQL_WCHAR.

    Centralizes the check shared by setencoding/setdecoding. Only the
    explicit-byte-order variants listed in UTF16_ENCODINGS are accepted;
    the function returns None silently when the encoding is acceptable.

    Args:
        encoding: The encoding name. It is compared verbatim against
            "utf-16" and UTF16_ENCODINGS, so callers must pass it already
            normalized to lowercase.
        wchar_type: The SQL_WCHAR constant value supplied by the caller.
            It is not consulted by this function and is kept only so the
            existing call signature stays unchanged.
        context: Label included in the log output ('SQL_WCHAR',
            'SQL_WCHAR ctype', ...). When it contains the substring
            "ctype", the raised error is worded for the ctype case.

    Raises:
        ProgrammingError: If encoding is "utf-16" (BOM variant) or is not
            a member of UTF16_ENCODINGS.
    """
    # The BOM variant is checked first so it always gets its dedicated message.
    if encoding == "utf-16":
        # UTF-16 with BOM is rejected due to byte order ambiguity
        logger.warning("utf-16 with BOM rejected for %s", context)
        raise ProgrammingError(
            driver_error="UTF-16 with Byte Order Mark not supported for SQL_WCHAR",
            ddbc_error=(
                "Cannot use 'utf-16' encoding with SQL_WCHAR due to Byte Order Mark ambiguity. "
                "Use 'utf-16le' or 'utf-16be' instead for explicit byte order."
            ),
        )

    if encoding in UTF16_ENCODINGS:
        return

    # Non-UTF-16 encodings are not supported with SQL_WCHAR
    logger.warning(
        "Non-UTF-16 encoding %s attempted with %s", sanitize_user_input(encoding), context
    )

    # Pick the wording once; both error strings are derived from it.
    ddbc_context = "SQL_WCHAR ctype" if "ctype" in context else "SQL_WCHAR"

    raise ProgrammingError(
        driver_error=f"{ddbc_context} only supports UTF-16 encodings",
        ddbc_error=(
            f"Cannot use encoding '{encoding}' with {ddbc_context}. "
            "SQL_WCHAR requires UTF-16 encodings (utf-16le, utf-16be)"
        ),
    )
58108

59109

60110
def _validate_encoding(encoding: str) -> bool:
@@ -70,7 +120,21 @@ def _validate_encoding(encoding: str) -> bool:
70120
Note:
71121
Uses LRU cache to avoid repeated expensive codecs.lookup() calls.
72122
Cache size is limited to 128 entries which should cover most use cases.
123+
Also validates that encoding name only contains safe characters.
73124
"""
125+
# Basic security checks - prevent obvious attacks
126+
if not encoding or not isinstance(encoding, str):
127+
return False
128+
129+
# Check length limit (prevent DoS via oversized encoding names)
130+
if len(encoding) > 100:
131+
return False
132+
133+
# Prevent null bytes and control characters that could cause issues
134+
if "\x00" in encoding or any(ord(c) < 32 and c not in "\t\n\r" for c in encoding):
135+
return False
136+
137+
# Then check if it's a valid Python codec
74138
try:
75139
codecs.lookup(encoding)
76140
return True
@@ -227,6 +291,15 @@ def __init__(
227291
self._output_converters = {}
228292
self._converters_lock = threading.Lock()
229293

294+
# Initialize encoding/decoding settings lock for thread safety
295+
# This lock protects both _encoding_settings and _decoding_settings dictionaries
296+
# from concurrent modification. We use a simple Lock (not RLock) because:
297+
# - Write operations (setencoding/setdecoding) assign a freshly built settings dict while holding the lock
298+
# - Read operations (getencoding/getdecoding) return a copy, so they're safe
299+
# - No recursive locking is needed in our usage pattern
300+
# This is more performant than RLock for the multiple-readers-single-writer pattern
301+
self._encoding_lock = threading.Lock()
302+
230303
# Initialize search escape character
231304
self._searchescape = None
232305

@@ -416,8 +489,7 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
416489
# Validate encoding using cached validation for better performance
417490
if not _validate_encoding(encoding):
418491
# Log the sanitized encoding for security
419-
logger.debug(
420-
"warning",
492+
logger.warning(
421493
"Invalid encoding attempted: %s",
422494
sanitize_user_input(str(encoding)),
423495
)
@@ -430,6 +502,10 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
430502
encoding = encoding.casefold()
431503
logger.debug("setencoding: Encoding normalized to %s", encoding)
432504

505+
# Early validation if ctype is already specified as SQL_WCHAR
506+
if ctype == ConstantsDDBC.SQL_WCHAR.value:
507+
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR")
508+
433509
# Set default ctype based on encoding if not provided
434510
if ctype is None:
435511
if encoding in UTF16_ENCODINGS:
@@ -443,8 +519,7 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
443519
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
444520
if ctype not in valid_ctypes:
445521
# Log the sanitized ctype for security
446-
logger.debug(
447-
"warning",
522+
logger.warning(
448523
"Invalid ctype attempted: %s",
449524
sanitize_user_input(str(ctype)),
450525
)
@@ -456,20 +531,24 @@ def setencoding(self, encoding: Optional[str] = None, ctype: Optional[int] = Non
456531
),
457532
)
458533

459-
# Store the encoding settings
460-
self._encoding_settings = {"encoding": encoding, "ctype": ctype}
534+
# Final validation: SQL_WCHAR ctype only supports UTF-16 encodings (without BOM)
535+
if ctype == ConstantsDDBC.SQL_WCHAR.value:
536+
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR")
537+
538+
# Store the encoding settings (thread-safe with lock)
539+
with self._encoding_lock:
540+
self._encoding_settings = {"encoding": encoding, "ctype": ctype}
461541

462542
# Log with sanitized values for security
463-
logger.debug(
464-
"info",
543+
logger.info(
465544
"Text encoding set to %s with ctype %s",
466545
sanitize_user_input(encoding),
467546
sanitize_user_input(str(ctype)),
468547
)
469548

470549
def getencoding(self) -> Dict[str, Union[str, int]]:
471550
"""
472-
Gets the current text encoding settings.
551+
Gets the current text encoding settings (thread-safe).
473552
474553
Returns:
475554
dict: A dictionary containing 'encoding' and 'ctype' keys.
@@ -481,14 +560,20 @@ def getencoding(self) -> Dict[str, Union[str, int]]:
481560
settings = cnxn.getencoding()
482561
print(f"Current encoding: {settings['encoding']}")
483562
print(f"Current ctype: {settings['ctype']}")
563+
564+
Note:
565+
This method is thread-safe and can be called from multiple threads concurrently.
566+
Returns a copy of the settings to prevent external modification.
484567
"""
485568
if self._closed:
486569
raise InterfaceError(
487570
driver_error="Connection is closed",
488571
ddbc_error="Connection is closed",
489572
)
490573

491-
return self._encoding_settings.copy()
574+
# Thread-safe read with lock to prevent race conditions
575+
with self._encoding_lock:
576+
return self._encoding_settings.copy()
492577

493578
def setdecoding(
494579
self, sqltype: int, encoding: Optional[str] = None, ctype: Optional[int] = None
@@ -539,8 +624,7 @@ def setdecoding(
539624
SQL_WMETADATA,
540625
]
541626
if sqltype not in valid_sqltypes:
542-
logger.debug(
543-
"warning",
627+
logger.warning(
544628
"Invalid sqltype attempted: %s",
545629
sanitize_user_input(str(sqltype)),
546630
)
@@ -562,8 +646,7 @@ def setdecoding(
562646

563647
# Validate encoding using cached validation for better performance
564648
if not _validate_encoding(encoding):
565-
logger.debug(
566-
"warning",
649+
logger.warning(
567650
"Invalid encoding attempted: %s",
568651
sanitize_user_input(str(encoding)),
569652
)
@@ -575,6 +658,13 @@ def setdecoding(
575658
# Normalize encoding to lowercase for consistency
576659
encoding = encoding.lower()
577660

661+
# Validate SQL_WCHAR encoding compatibility
662+
if sqltype == ConstantsDDBC.SQL_WCHAR.value:
663+
_validate_utf16_wchar_compatibility(encoding, sqltype, "SQL_WCHAR sqltype")
664+
665+
# SQL_WMETADATA can use any valid encoding (UTF-8, UTF-16, etc.)
666+
# No restriction needed here - let users configure as needed
667+
578668
# Set default ctype based on encoding if not provided
579669
if ctype is None:
580670
if encoding in UTF16_ENCODINGS:
@@ -585,8 +675,7 @@ def setdecoding(
585675
# Validate ctype
586676
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
587677
if ctype not in valid_ctypes:
588-
logger.debug(
589-
"warning",
678+
logger.warning(
590679
"Invalid ctype attempted: %s",
591680
sanitize_user_input(str(ctype)),
592681
)
@@ -598,8 +687,13 @@ def setdecoding(
598687
),
599688
)
600689

601-
# Store the decoding settings for the specified sqltype
602-
self._decoding_settings[sqltype] = {"encoding": encoding, "ctype": ctype}
690+
# Validate SQL_WCHAR ctype encoding compatibility
691+
if ctype == ConstantsDDBC.SQL_WCHAR.value:
692+
_validate_utf16_wchar_compatibility(encoding, ctype, "SQL_WCHAR ctype")
693+
694+
# Store the decoding settings for the specified sqltype (thread-safe with lock)
695+
with self._encoding_lock:
696+
self._decoding_settings[sqltype] = {"encoding": encoding, "ctype": ctype}
603697

604698
# Log with sanitized values for security
605699
sqltype_name = {
@@ -608,8 +702,7 @@ def setdecoding(
608702
SQL_WMETADATA: "SQL_WMETADATA",
609703
}.get(sqltype, str(sqltype))
610704

611-
logger.debug(
612-
"info",
705+
logger.info(
613706
"Text decoding set for %s to %s with ctype %s",
614707
sqltype_name,
615708
sanitize_user_input(encoding),
@@ -618,7 +711,7 @@ def setdecoding(
618711

619712
def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
620713
"""
621-
Gets the current text decoding settings for the specified SQL type.
714+
Gets the current text decoding settings for the specified SQL type (thread-safe).
622715
623716
Args:
624717
sqltype (int): The SQL type to get settings for: SQL_CHAR, SQL_WCHAR, or SQL_WMETADATA.
@@ -634,6 +727,10 @@ def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
634727
settings = cnxn.getdecoding(mssql_python.SQL_CHAR)
635728
print(f"SQL_CHAR encoding: {settings['encoding']}")
636729
print(f"SQL_CHAR ctype: {settings['ctype']}")
730+
731+
Note:
732+
This method is thread-safe and can be called from multiple threads concurrently.
733+
Returns a copy of the settings to prevent external modification.
637734
"""
638735
if self._closed:
639736
raise InterfaceError(
@@ -657,7 +754,9 @@ def getdecoding(self, sqltype: int) -> Dict[str, Union[str, int]]:
657754
),
658755
)
659756

660-
return self._decoding_settings[sqltype].copy()
757+
# Thread-safe read with lock to prevent race conditions
758+
with self._encoding_lock:
759+
return self._decoding_settings[sqltype].copy()
661760

662761
def set_attr(self, attribute: int, value: Union[int, str, bytes, bytearray]) -> None:
663762
"""

0 commit comments

Comments
 (0)