Skip to content

Commit a0e43e5

Browse files
authored
Merge pull request #45 from microsoft/jahnvithakkar/access_token
FEAT: Access Token Login
2 parents d2b82b5 + a10cb00 commit a0e43e5

File tree

6 files changed

+177
-25
lines changed

6 files changed

+177
-25
lines changed

mssql_python/connection.py

Lines changed: 101 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class Connection:
3333
close() -> None:
3434
"""
3535

36-
def __init__(self, connection_str: str, autocommit: bool = False, **kwargs) -> None:
36+
def __init__(self, connection_str: str = "", autocommit: bool = False, attrs_before: dict = None, **kwargs) -> None:
3737
"""
3838
Initialize the connection object with the specified connection string and parameters.
3939
@@ -58,11 +58,12 @@ def __init__(self, connection_str: str, autocommit: bool = False, **kwargs) -> N
5858
self.connection_str = self._construct_connection_string(
5959
connection_str, **kwargs
6060
)
61+
self._attrs_before = attrs_before
62+
self._autocommit = autocommit # Initialize _autocommit before calling _initializer
6163
self._initializer()
62-
self._autocommit = autocommit
6364
self.setautocommit(autocommit)
6465

65-
def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
66+
def _construct_connection_string(self, connection_str: str = "", **kwargs) -> str:
6667
"""
6768
Construct the connection string by concatenating the connection string
6869
with key/value pairs from kwargs.
@@ -76,13 +77,14 @@ def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
7677
"""
7778
# Add the driver attribute to the connection string
7879
conn_str = add_driver_to_connection_str(connection_str)
80+
7981
# Add additional key-value pairs to the connection string
8082
for key, value in kwargs.items():
81-
if key.lower() == "host":
83+
if key.lower() == "host" or key.lower() == "server":
8284
key = "Server"
83-
elif key.lower() == "user":
85+
elif key.lower() == "user" or key.lower() == "uid":
8486
key = "Uid"
85-
elif key.lower() == "password":
87+
elif key.lower() == "password" or key.lower() == "pwd":
8688
key = "Pwd"
8789
elif key.lower() == "database":
8890
key = "Database"
@@ -93,6 +95,11 @@ def _construct_connection_string(self, connection_str: str, **kwargs) -> str:
9395
else:
9496
continue
9597
conn_str += f"{key}={value};"
98+
print(f"Connection string after adding driver: {conn_str}")
99+
100+
if ENABLE_LOGGING:
101+
logger.info("Final connection string: %s", conn_str)
102+
96103
return conn_str
97104

98105
def _is_closed(self) -> bool:
@@ -103,7 +110,7 @@ def _is_closed(self) -> bool:
103110
bool: True if the connection is closed, False otherwise.
104111
"""
105112
return self.hdbc is None
106-
113+
107114
def _initializer(self) -> None:
108115
"""
109116
Initialize the environment and connection handles.
@@ -115,9 +122,79 @@ def _initializer(self) -> None:
115122
self._allocate_environment_handle()
116123
self._set_environment_attributes()
117124
self._allocate_connection_handle()
118-
self._set_connection_attributes()
125+
if self._attrs_before != {}:
126+
self._apply_attrs_before() # Apply pre-connection attributes
127+
if self._autocommit:
128+
self._set_connection_attributes(
129+
ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value,
130+
ddbc_sql_const.SQL_AUTOCOMMIT_ON.value,
131+
)
119132
self._connect_to_db()
120133

134+
def _apply_attrs_before(self):
135+
"""
136+
Apply specific pre-connection attributes.
137+
Currently, this method only processes an attribute with key 1256 (e.g., SQL_COPT_SS_ACCESS_TOKEN)
138+
if present in `self._attrs_before`. Other attributes are ignored.
139+
140+
Returns:
141+
bool: True.
142+
"""
143+
144+
if ENABLE_LOGGING:
145+
logger.info("Attempting to apply pre-connection attributes (attrs_before): %s", self._attrs_before)
146+
147+
if not isinstance(self._attrs_before, dict):
148+
if self._attrs_before is not None and ENABLE_LOGGING:
149+
logger.warning(
150+
f"_attrs_before is of type {type(self._attrs_before).__name__}, "
151+
f"expected dict. Skipping attribute application."
152+
)
153+
elif self._attrs_before is None and ENABLE_LOGGING:
154+
logger.debug("_attrs_before is None. No pre-connection attributes to apply.")
155+
return True # Exit if _attrs_before is not a dictionary or is None
156+
157+
for key, value in self._attrs_before.items():
158+
ikey = None
159+
if isinstance(key, int):
160+
ikey = key
161+
elif isinstance(key, str) and key.isdigit():
162+
try:
163+
ikey = int(key)
164+
except ValueError:
165+
if ENABLE_LOGGING:
166+
logger.debug(
167+
f"Skipping attribute with key '{key}' in attrs_before: "
168+
f"could not convert string to int."
169+
)
170+
continue # Skip if string key is not a valid integer
171+
else:
172+
if ENABLE_LOGGING:
173+
logger.debug(
174+
f"Skipping attribute with key '{key}' in attrs_before due to "
175+
f"unsupported key type: {type(key).__name__}. Expected int or string representation of an int."
176+
)
177+
continue # Skip keys that are not int or string representation of an int
178+
179+
if ikey == ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value:
180+
if ENABLE_LOGGING:
181+
logger.info(
182+
f"Found attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value}. Attempting to set it."
183+
)
184+
self._set_connection_attributes(ikey, value)
185+
if ENABLE_LOGGING:
186+
logger.info(
187+
f"Call to set attribute {ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value} with value '{value}' completed."
188+
)
189+
# If you expect only one such key, you could add 'break' here.
190+
else:
191+
if ENABLE_LOGGING:
192+
logger.debug(
193+
f"Ignoring attribute with key '{key}' (resolved to {ikey}) in attrs_before "
194+
f"as it is not the target attribute ({ddbc_sql_const.SQL_COPT_SS_ACCESS_TOKEN.value})."
195+
)
196+
return True
197+
121198
def _allocate_environment_handle(self):
122199
"""
123200
Allocate the environment handle.
@@ -152,18 +229,25 @@ def _allocate_connection_handle(self):
152229
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, handle, ret)
153230
self.hdbc = handle
154231

155-
def _set_connection_attributes(self):
232+
def _set_connection_attributes(self, ikey: int, ivalue: any) -> None:
156233
"""
157234
Set the connection attributes before connecting.
235+
236+
Args:
237+
ikey (int): The attribute key to set.
238+
ivalue (Any): The value to set for the attribute. Can be bytes, bytearray, int, or unicode.
239+
vallen (int): The length of the value.
240+
241+
Raises:
242+
DatabaseError: If there is an error while setting the connection attribute.
158243
"""
159-
if self.autocommit:
160-
ret = ddbc_bindings.DDBCSQLSetConnectAttr(
161-
self.hdbc, # Using the wrapper class
162-
ddbc_sql_const.SQL_ATTR_AUTOCOMMIT.value,
163-
ddbc_sql_const.SQL_AUTOCOMMIT_ON.value,
164-
0
165-
)
166-
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)
244+
245+
ret = ddbc_bindings.DDBCSQLSetConnectAttr(
246+
self.hdbc, # Connection handle
247+
ikey, # Attribute
248+
ivalue, # Value
249+
)
250+
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)
167251

168252
def _connect_to_db(self) -> None:
169253
"""
@@ -224,7 +308,6 @@ def autocommit(self, value: bool) -> None:
224308
if value
225309
else ddbc_sql_const.SQL_AUTOCOMMIT_OFF.value
226310
), # Value
227-
0, # String length
228311
)
229312
check_error(ddbc_sql_const.SQL_HANDLE_DBC.value, self.hdbc, ret)
230313
self._autocommit = value

mssql_python/constants.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,5 @@ class ConstantsDDBC(Enum):
116116
SQL_C_WCHAR = -8
117117
SQL_NULLABLE = 1
118118
SQL_MAX_NUMERIC_LEN = 16
119+
SQL_IS_POINTER = -4
120+
SQL_COPT_SS_ACCESS_TOKEN = 1256

mssql_python/db_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mssql_python.connection import Connection
77

88

9-
def connect(connection_str: str, autocommit: bool = True, **kwargs) -> Connection:
9+
def connect(connection_str: str = "", autocommit: bool = True, attrs_before: dict = None, **kwargs) -> Connection:
1010
"""
1111
Constructor for creating a connection to the database.
1212
@@ -34,5 +34,5 @@ def connect(connection_str: str, autocommit: bool = True, **kwargs) -> Connectio
3434
be used to perform database operations such as executing queries, committing
3535
transactions, and closing the connection.
3636
"""
37-
conn = Connection(connection_str, autocommit=autocommit, **kwargs)
37+
conn = Connection(connection_str, autocommit=autocommit, attrs_before=attrs_before, **kwargs)
3838
return conn

mssql_python/helpers.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def add_driver_to_connection_str(connection_str):
4646
# Insert the driver attribute at the beginning of the connection string
4747
final_connection_attributes.insert(0, driver_name)
4848
connection_str = ";".join(final_connection_attributes)
49+
print(f"Connection string after adding driver: {connection_str}")
4950
except Exception as e:
5051
raise Exception(
5152
"Invalid connection string, Please follow the format: "

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -692,18 +692,58 @@ SQLRETURN SQLSetEnvAttr_wrap(SqlHandlePtr EnvHandle, SQLINTEGER Attribute, intpt
692692
}
693693

694694
// Wrap SQLSetConnectAttr
695-
SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute, intptr_t ValuePtr,
696-
SQLINTEGER StringLength) {
695+
SQLRETURN SQLSetConnectAttr_wrap(SqlHandlePtr ConnectionHandle, SQLINTEGER Attribute,
696+
py::object ValuePtr) {
697697
LOG("Set SQL Connection Attribute");
698698
if (!SQLSetConnectAttr_ptr) {
699699
LoadDriverOrThrowException();
700700
}
701701

702-
// TODO: Does ValuePtr need to be converted from Python to C++ object?
703-
SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, reinterpret_cast<SQLPOINTER>(ValuePtr), StringLength);
702+
// Print the type of ValuePtr and attribute value - helpful for debugging
703+
LOG("Type of ValuePtr: {}, Attribute: {}", py::type::of(ValuePtr).attr("__name__").cast<std::string>(), Attribute);
704+
705+
SQLPOINTER value = 0;
706+
SQLINTEGER length = 0;
707+
708+
if (py::isinstance<py::int_>(ValuePtr)) {
709+
// Handle integer values
710+
int intValue = ValuePtr.cast<int>();
711+
value = reinterpret_cast<SQLPOINTER>(intValue);
712+
length = SQL_IS_INTEGER; // Integer values don't require a length
713+
// } else if (py::isinstance<py::str>(ValuePtr)) {
714+
// // Handle Unicode string values
715+
// static std::wstring unicodeValueBuffer;
716+
// unicodeValueBuffer = ValuePtr.cast<std::wstring>();
717+
// value = const_cast<SQLWCHAR*>(unicodeValueBuffer.c_str());
718+
// length = SQL_NTS; // Indicates null-terminated string
719+
} else if (py::isinstance<py::bytes>(ValuePtr) || py::isinstance<py::bytearray>(ValuePtr)) {
720+
// Handle byte or bytearray values (like access tokens)
721+
// Store in static buffer to ensure memory remains valid during connection
722+
static std::vector<std::string> bytesBuffers;
723+
bytesBuffers.push_back(ValuePtr.cast<std::string>());
724+
value = const_cast<char*>(bytesBuffers.back().c_str());
725+
length = SQL_IS_POINTER; // Indicates we're passing a pointer (required for token)
726+
// } else if (py::isinstance<py::list>(ValuePtr) || py::isinstance<py::tuple>(ValuePtr)) {
727+
// // Handle list or tuple values
728+
// LOG("ValuePtr is a sequence (list or tuple)");
729+
// for (py::handle item : ValuePtr) {
730+
// LOG("Processing item in sequence");
731+
// SQLRETURN ret = SQLSetConnectAttr_wrap(ConnectionHandle, Attribute, py::reinterpret_borrow<py::object>(item));
732+
// if (!SQL_SUCCEEDED(ret)) {
733+
// LOG("Failed to set attribute for item in sequence");
734+
// return ret;
735+
// }
736+
// }
737+
} else {
738+
LOG("Unsupported ValuePtr type");
739+
return SQL_ERROR;
740+
}
741+
742+
SQLRETURN ret = SQLSetConnectAttr_ptr(ConnectionHandle->get(), Attribute, value, length);
704743
if (!SQL_SUCCEEDED(ret)) {
705744
LOG("Failed to set Connection attribute");
706745
}
746+
LOG("Set Connection attribute successfully");
707747
return ret;
708748
}
709749

tests/test_003_connection.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,33 @@ def test_connection(db_connection):
3333

3434
def test_construct_connection_string(db_connection):
3535
# Check if the connection string is constructed correctly with kwargs
36-
conn_str = db_connection._construct_connection_string("",host="localhost", user="me", password="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes")
36+
conn_str = db_connection._construct_connection_string(host="localhost", user="me", password="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes")
37+
assert "Server=localhost;" in conn_str, "Connection string should contain 'Server=localhost;'"
38+
assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'"
39+
assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'"
40+
assert "Database=mydb;" in conn_str, "Connection string should contain 'Database=mydb;'"
41+
assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'"
42+
assert "TrustServerCertificate=yes;" in conn_str, "Connection string should contain 'TrustServerCertificate=yes;'"
43+
assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
44+
assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
45+
assert "Driver={ODBC Driver 18 for SQL Server};;APP=MSSQL-Python;Server=localhost;Uid=me;Pwd=mypwd;Database=mydb;Encrypt=yes;TrustServerCertificate=yes;" == conn_str, "Connection string is incorrect"
46+
47+
def test_connection_string_with_attrs_before(db_connection):
48+
# Check if the connection string is constructed correctly with attrs_before
49+
conn_str = db_connection._construct_connection_string(host="localhost", user="me", password="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes", attrs_before={1256: "token"})
50+
assert "Server=localhost;" in conn_str, "Connection string should contain 'Server=localhost;'"
51+
assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'"
52+
assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'"
53+
assert "Database=mydb;" in conn_str, "Connection string should contain 'Database=mydb;'"
54+
assert "Encrypt=yes;" in conn_str, "Connection string should contain 'Encrypt=yes;'"
55+
assert "TrustServerCertificate=yes;" in conn_str, "Connection string should contain 'TrustServerCertificate=yes;'"
56+
assert "APP=MSSQL-Python" in conn_str, "Connection string should contain 'APP=MSSQL-Python'"
57+
assert "Driver={ODBC Driver 18 for SQL Server}" in conn_str, "Connection string should contain 'Driver={ODBC Driver 18 for SQL Server}'"
58+
assert "{1256: token}" not in conn_str, "Connection string should not contain '{1256: token}'"
59+
60+
def test_connection_string_with_odbc_param(db_connection):
61+
# Check if the connection string is constructed correctly with ODBC parameters
62+
conn_str = db_connection._construct_connection_string(server="localhost", uid="me", pwd="mypwd", database="mydb", encrypt="yes", trust_server_certificate="yes")
3763
assert "Server=localhost;" in conn_str, "Connection string should contain 'Server=localhost;'"
3864
assert "Uid=me;" in conn_str, "Connection string should contain 'Uid=me;'"
3965
assert "Pwd=mypwd;" in conn_str, "Connection string should contain 'Pwd=mypwd;'"

0 commit comments

Comments
 (0)