Skip to content

Commit 4b89f3a

Browse files
bewithgauravTheekshna Kotian
authored andcommitted
Merged PR 5317: Fix for Numeric & float type + Common Python/C++ logger
Related work items: #33810
1 parent 6a5c592 commit 4b89f3a

File tree

12 files changed

+357
-239
lines changed

12 files changed

+357
-239
lines changed

main.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,42 @@
11
from mssql_python import connect
2+
from mssql_python import setup_logging
23
import os
4+
import decimal
5+
6+
setup_logging('stdout')
37

48
conn_str = os.getenv("DB_CONNECTION_STRING")
59
conn = connect(conn_str)
610

11+
# conn.autocommit = True
12+
713
cursor = conn.cursor()
814
cursor.execute("SELECT database_id, name from sys.databases;")
915
row = cursor.fetchmany(1)
1016
print(row)
1117

12-
cursor.execute("SELECT database_id, name from sys.databases;")
13-
row = cursor.fetchone()
14-
print(row)
18+
cursor.execute("DROP TABLE IF EXISTS main_single_column")
19+
20+
# cursor.execute("CREATE TABLE main_single_column (float_column FLOAT)")
21+
# cursor.executemany("INSERT INTO main_single_column (float_column) VALUES (?)", [[12.34], [1.234], [0.125], [0.0125], [0.00125], [23243243232.432432432], [0.247985732852735032750973209750]])
22+
# cursor.execute("SELECT * FROM main_single_column")
23+
# row = cursor.fetchall()
24+
# print(row)
25+
# print(len(row))
26+
27+
import time
28+
cursor.execute("CREATE TABLE main_single_column (decimal_column NUMERIC(10, 4))")
29+
# time.sleep(45)
30+
31+
cursor.execute("INSERT INTO main_single_column (decimal_column) VALUES (?)", [decimal.Decimal(123.45).quantize(decimal.Decimal('0.00'))])
32+
cursor.execute("SELECT * FROM main_single_column")
33+
row = cursor.fetchone()[0]
1534

16-
cursor.execute("SELECT database_id, name from sys.databases;")
17-
row = cursor.fetchall()
1835
print(row)
36+
print(row.val)
37+
print(row.precision)
38+
print(row.scale)
39+
print(row.sign)
1940

2041
cursor.close()
2142
conn.close()

mssql_python/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,5 @@
2727

2828
# Cursor Objects
2929
from .cursor import Cursor
30+
31+
from .logging_config import setup_logging, get_logger

mssql_python/connection.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
import ctypes
22
from mssql_python.cursor import Cursor
3-
from mssql_python.logging_config import setup_logging, ENABLE_LOGGING
3+
from mssql_python.logging_config import get_logger, ENABLE_LOGGING
44
from mssql_python.constants import ConstantsODBC as odbc_sql_const
55
from mssql_python.helpers import add_driver_to_connection_str, check_error
6-
import logging
76
import os
8-
9-
# Change the current working directory to the directory of the script to import ddbc_bindings
10-
os.chdir(os.path.dirname(__file__))
11-
127
from mssql_python import ddbc_bindings
138

14-
# Setting up logging
15-
setup_logging()
9+
logger = get_logger()
1610

1711
class Connection:
1812
"""
@@ -116,15 +110,15 @@ def _connect_to_db(self) -> None:
116110
InterfaceError: If there is an error related to the database interface.
117111
"""
118112
if ENABLE_LOGGING:
119-
logging.info("Connecting to the database")
113+
logger.info("Connecting to the database")
120114
ret = ddbc_bindings.DDBCSQLDriverConnect(
121115
self.hdbc.value, # Connection handle
122116
0, # Window handle
123117
self.connection_str # Connection string
124118
)
125119
check_error(odbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret)
126120
if ENABLE_LOGGING:
127-
logging.info("Connection established successfully.")
121+
logger.info("Connection established successfully.")
128122

129123
@property
130124
def autocommit(self) -> bool:
@@ -159,7 +153,7 @@ def setautocommit(self, value: bool) -> None:
159153
check_error(odbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret)
160154
self._autocommit = value
161155
if ENABLE_LOGGING:
162-
logging.info("Autocommit mode set to %s.", value)
156+
logger.info("Autocommit mode set to %s.", value)
163157

164158
def cursor(self) -> Cursor:
165159
"""
@@ -198,7 +192,7 @@ def commit(self) -> None:
198192
)
199193
check_error(odbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret)
200194
if ENABLE_LOGGING:
201-
logging.info("Transaction committed successfully.")
195+
logger.info("Transaction committed successfully.")
202196

203197
def rollback(self) -> None:
204198
"""
@@ -219,7 +213,7 @@ def rollback(self) -> None:
219213
)
220214
check_error(odbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret)
221215
if ENABLE_LOGGING:
222-
logging.info("Transaction rolled back successfully.")
216+
logger.info("Transaction rolled back successfully.")
223217

224218
def close(self) -> None:
225219
"""
@@ -243,4 +237,4 @@ def close(self) -> None:
243237
check_error(odbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc.value, ret)
244238

245239
if ENABLE_LOGGING:
246-
logging.info("Connection closed successfully.")
240+
logger.info("Connection closed successfully.")

mssql_python/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,3 +105,4 @@ class ConstantsODBC(Enum):
105105
SQL_PARAM_INPUT_OUTPUT = 3
106106
SQL_C_WCHAR = -8
107107
SQL_NULLABLE = 1
108+
SQL_MAX_NUMERIC_LEN = 16

mssql_python/cursor.py

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,17 @@
11
import ctypes
2-
import logging
32
import decimal, uuid
43
from typing import List, Union
5-
from mssql_python.logging_config import setup_logging, ENABLE_LOGGING
64
from mssql_python.constants import ConstantsODBC as odbc_sql_const
75
from mssql_python.helpers import check_error
6+
from mssql_python.logging_config import get_logger, ENABLE_LOGGING
87
import datetime
98
import decimal
109
import uuid
1110
import os
1211
from mssql_python.exceptions import raise_exception
1312
from mssql_python import ddbc_bindings
1413

15-
# Setting up logging
16-
setup_logging()
14+
logger = get_logger()
1715

1816
class Cursor:
1917
"""
@@ -60,7 +58,7 @@ def __init__(self, connection) -> None:
6058
# Is a list instead of a bool coz bools in Python are immutable.
6159
# Hence, we can't pass around bools by reference & modify them.
6260
# Therefore, it must be a list with exactly one bool element.
63-
61+
6462
def _is_unicode_string(self, param):
6563
"""
6664
Check if a string contains non-ASCII characters.
@@ -207,28 +205,67 @@ def _get_numeric_data(self, param):
207205
Get the data for a numeric parameter.
208206
209207
Args:
210-
param: The numeric parameter.
208+
param: The numeric parameter.
211209
212210
Returns:
213-
A tuple containing the numeric data.
214-
"""
211+
A NumericData struct containing the numeric data.
212+
"""
213+
decimal_as_tuple = param.as_tuple()
214+
num_digits = len(decimal_as_tuple.digits)
215+
exponent = decimal_as_tuple.exponent
216+
217+
# Calculate the SQL precision & scale
218+
# precision = no. of significant digits
219+
# scale = no. digits after decimal point
220+
if exponent >= 0:
221+
# digits=314, exp=2 ---> '31400' --> precision=5, scale=0
222+
precision = num_digits + exponent
223+
scale = 0
224+
elif (-1 * exponent) <= num_digits:
225+
# digits=3140, exp=-3 ---> '3.140' --> precision=4, scale=3
226+
precision = num_digits
227+
scale = exponent * -1
228+
else:
229+
# digits=3140, exp=-5 ---> '0.03140' --> precision=5, scale=5
230+
# TODO: double check the precision calculation here with SQL documentation
231+
precision = exponent * -1
232+
scale = exponent * -1
233+
234+
# TODO: Revisit this check, do we want this restriction?
235+
if precision > 15:
236+
raise ValueError("Precision of the numeric value is too high - " + str(param) +
237+
". Should be less than or equal to 15")
215238
NumericData = ddbc_bindings.NumericData
216239
numeric_data = NumericData()
217-
numeric_data.precision = len(param.as_tuple().digits)
218-
numeric_data.scale = param.as_tuple().exponent * -1
219-
numeric_data.sign = param.as_tuple().sign
220-
numeric_data.val = str(param)
221-
240+
numeric_data.scale = scale
241+
numeric_data.precision = precision
242+
numeric_data.sign = 1 if decimal_as_tuple.sign == 0 else 0
243+
# strip decimal point from param & convert the significant digits to integer
244+
# Ex: 12.34 ---> 1234
245+
val = str(param)
246+
if '.' in val:
247+
val = val.replace('.', '')
248+
val = val.replace('-', '')
249+
val = int(val)
250+
numeric_data.val = val
222251
return numeric_data
223252

224253
def _map_sql_type(self, param, parameters_list, i):
225-
"""Map a Python data type to the corresponding SQL type,C type,Columnsize and Decimal digits."""
254+
"""
255+
Map a Python data type to the corresponding SQL type,C type,Columnsize and Decimal digits.
256+
Takes:
257+
- param: The parameter to map.
258+
- parameters_list: The list of parameters to bind.
259+
- i: The index of the parameter in the list.
260+
Returns:
261+
- A tuple containing the SQL type, C type, column size, and decimal digits.
262+
"""
226263
if param is None:
227-
return odbc_sql_const.SQL_NULL_DATA.value, odbc_sql_const.SQL_C_DEFAULT.value, 1, 0
228-
264+
return odbc_sql_const.SQL_NULL_DATA.value, odbc_sql_const.SQL_C_DEFAULT.value, 1, 0
265+
229266
elif isinstance(param, bool):
230267
return odbc_sql_const.SQL_BIT.value, odbc_sql_const.SQL_C_BIT.value, 1, 0
231-
268+
232269
elif isinstance(param, int):
233270
if 0 <= param <= 255:
234271
return odbc_sql_const.SQL_TINYINT.value, odbc_sql_const.SQL_C_TINYINT.value, 3, 0
@@ -238,22 +275,20 @@ def _map_sql_type(self, param, parameters_list, i):
238275
return odbc_sql_const.SQL_INTEGER.value, odbc_sql_const.SQL_C_LONG.value, 10, 0
239276
else:
240277
return odbc_sql_const.SQL_BIGINT.value, odbc_sql_const.SQL_C_SBIGINT.value, 19, 0
241-
278+
242279
elif isinstance(param, float):
243-
if -3.4028235E+38 <= param <= 3.4028235E+38:
244-
return odbc_sql_const.SQL_REAL.value, odbc_sql_const.SQL_C_FLOAT.value, 7, 0
245-
else:
246-
return odbc_sql_const.SQL_FLOAT.value, odbc_sql_const.SQL_C_DOUBLE.value, 15, 0
247-
280+
return odbc_sql_const.SQL_DOUBLE.value, odbc_sql_const.SQL_C_DOUBLE.value, 15, 0
281+
248282
elif isinstance(param, decimal.Decimal):
283+
# TODO: Support for other numeric types (smallmoney, money etc.)
249284
# if param.as_tuple().exponent == -4: # Scale is 4
250285
# if -214748.3648 <= param <= 214748.3647:
251286
# return odbc_sql_const.SQL_SMALLMONEY.value, odbc_sql_const.SQL_C_NUMERIC.value, 10, 4
252287
# elif -922337203685477.5808 <= param <= 922337203685477.5807:
253288
# return odbc_sql_const.SQL_MONEY.value, odbc_sql_const.SQL_C_NUMERIC.value, 19, 4
254289
parameters_list[i] = self._get_numeric_data(param) # Replace the parameter with the dictionary
255-
return odbc_sql_const.SQL_DECIMAL.value, odbc_sql_const.SQL_C_NUMERIC.value, len(param.as_tuple().digits), param.as_tuple().exponent * -1
256-
290+
return odbc_sql_const.SQL_NUMERIC.value, odbc_sql_const.SQL_C_NUMERIC.value, parameters_list[i].precision, parameters_list[i].scale
291+
257292
elif isinstance(param, str):
258293
# Check for Well-Known Text (WKT) format for geography/geometry
259294
if param.startswith("POINT") or param.startswith("LINESTRING") or param.startswith("POLYGON"):
@@ -269,7 +304,7 @@ def _map_sql_type(self, param, parameters_list, i):
269304
elif self._parse_time(param):
270305
parameters_list[i] = self._parse_time(param)
271306
return odbc_sql_const.SQL_TIME.value, odbc_sql_const.SQL_C_TYPE_TIME.value, 8, 0
272-
# Unsupported types
307+
# TODO: Support for other types (Timestampoffset etc.)
273308
# elif self._parse_timestamptz(param):
274309
# return odbc_sql_const.SQL_TIMESTAMPOFFSET.value, odbc_sql_const.SQL_C_TYPE_TIMESTAMP.value, 34, 7
275310
# elif self._parse_smalldatetime(param):
@@ -466,7 +501,7 @@ def execute(self, operation: str, *parameters, use_prepare: bool = True, reset_c
466501

467502
ParamInfo = ddbc_bindings.ParamInfo
468503
parameters_type = []
469-
504+
470505
# Flatten parameters if a single tuple or list is passed
471506
if len(parameters) == 1 and isinstance(parameters[0], (tuple, list)):
472507
parameters = parameters[0]
@@ -487,19 +522,19 @@ def execute(self, operation: str, *parameters, use_prepare: bool = True, reset_c
487522
'''
488523
Execute SQL Statement - (SQLExecute)
489524
'''
525+
# TODO - Need to evaluate encrypted logs for query parameters
490526
if ENABLE_LOGGING:
491-
# TODO - Need to evaluate encrypted logs for query parameters
492-
logging.debug("Executing query: %s", operation)
527+
logger.debug("Executing query: %s", operation)
493528
for i, param in enumerate(parameters):
494-
logging.debug(
529+
logger.debug(
495530
"Parameter number: %s, Parameter: %s, Param Python Type: %s, ParamInfo: %s, %s, %s, %s, %s",
496531
i+1,
497532
param,
498533
str(type(param)),
499-
parameters_type[i].paramSQLType,
500-
parameters_type[i].paramCType,
501-
parameters_type[i].columnSize,
502-
parameters_type[i].decimalDigits,
534+
parameters_type[i].paramSQLType,
535+
parameters_type[i].paramCType,
536+
parameters_type[i].columnSize,
537+
parameters_type[i].decimalDigits,
503538
parameters_type[i].inputOutputType
504539
)
505540

@@ -538,7 +573,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
538573
# Converting the parameters to a list
539574
parameters = list(parameters)
540575
if ENABLE_LOGGING:
541-
logging.info("Executing query with parameters: %s", parameters)
576+
logger.info("Executing query with parameters: %s", parameters)
542577
# Prepare the statement only during first execution. From second time
543578
# onwards, skip preparing and directly execute. This helps avoid
544579
# unnecessary 'prepare' network calls.
@@ -558,7 +593,7 @@ def executemany(self, operation: str, seq_of_parameters: list) -> None:
558593
self.rowcount = total_rowcount
559594
except Exception as e:
560595
if ENABLE_LOGGING:
561-
logging.info("Executing query with parameters: %s", parameters)
596+
logger.info("Executing query with parameters: %s", parameters)
562597
# Prepare the statement only during first execution. From second time
563598
# onwards, skip preparing and directly execute. This helps avoid
564599
# unnecessary 'prepare' network calls.

mssql_python/db_connection.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,6 @@
1-
import logging
2-
from mssql_python.logging_config import setup_logging
31
from mssql_python.exceptions import DatabaseError, InterfaceError
42
from mssql_python.connection import Connection
53

6-
# Setting up logging
7-
setup_logging()
8-
94
def connect(connection_str: str) -> Connection:
105
"""
116
Constructor for creating a connection to the database.

mssql_python/exceptions.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
from mssql_python.logging_config import setup_logging, ENABLE_LOGGING
2-
import logging
1+
from mssql_python.logging_config import get_logger, ENABLE_LOGGING
32

4-
setup_logging()
3+
logger = get_logger()
54

65
class Exception(Exception):
76
"""
@@ -231,7 +230,7 @@ def truncate_error_message(error_message: str) -> str:
231230
return string_first + string_third
232231
except Exception as e:
233232
if ENABLE_LOGGING:
234-
logging.error(f"Error while truncating error message: {e}")
233+
logger.error(f"Error while truncating error message: {e}")
235234
return error_message
236235

237236
def raise_exception(sqlstate: str, ddbc_error: str) -> None:
@@ -250,6 +249,6 @@ def raise_exception(sqlstate: str, ddbc_error: str) -> None:
250249
exception_class = sqlstate_to_exception(sqlstate, ddbc_error)
251250
if exception_class:
252251
if ENABLE_LOGGING:
253-
logging.error(exception_class)
252+
logger.error(exception_class)
254253
raise exception_class
255-
raise DatabaseError(driver_error="An error occurred with SQLSTATE code", ddbc_error=f"Unknown DDBC error: {sqlstate}")
254+
raise DatabaseError(driver_error="An error occurred with SQLSTATE code", ddbc_error=f"Unknown DDBC error: {sqlstate}")

mssql_python/helpers.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from mssql_python.constants import ConstantsODBC
22
from mssql_python import ddbc_bindings
33
from mssql_python.exceptions import raise_exception
4-
from mssql_python.logging_config import setup_logging, ENABLE_LOGGING
5-
import logging
4+
from mssql_python.logging_config import get_logger, ENABLE_LOGGING
65

7-
setup_logging()
6+
logger = get_logger()
87

98
def add_driver_to_connection_str(connection_str):
109
"""
@@ -62,7 +61,7 @@ def check_error(handle_type, handle, ret):
6261
if ret < 0:
6362
error_info = ddbc_bindings.DDBCSQLCheckError(handle_type, handle, ret)
6463
if ENABLE_LOGGING:
65-
logging.error(f"Error: {error_info.ddbcErrorMsg}")
64+
logger.error(f"Error: {error_info.ddbcErrorMsg}")
6665
raise_exception(error_info.sqlState, error_info.ddbcErrorMsg)
6766

6867
def add_driver_name_to_app_parameter(connection_string):

0 commit comments

Comments
 (0)