Skip to content

Commit e2bb27a

Browse files
author
Theekshna Kotian
committed
Merged PR 5318: Use DDBC module path to find ODBC module path
This PR removes reliance on current working directory (CWD) when locating ODBC dll. User can run their apps from any directory. We cannot rely on CWD to find ODBC dll. Additionally, GetCurrentDirectory API is not safe to use from a shared library & in a multithreaded app (acc to it's [documentation](https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-getcurrentdirectory#remarks)) Instead, look for ODBC dll relative to the ddbc_bindings.pyd file path. The pyd will be installed somewhere inside the Python installation directory. The expectation is that we will find msodbcsql18.dll in the following directory - **<ddbc_bindings.pyd directory>\libs\win\msodbcsql18.dll** Related work items: #33826
1 parent 1d4a417 commit e2bb27a

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

mssql_python/cursor.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,6 @@
1010
import uuid
1111
import os
1212
from mssql_python.exceptions import raise_exception
13-
14-
os.chdir(os.path.dirname(__file__))
15-
1613
from mssql_python import ddbc_bindings
1714

1815
# Setting up logging

mssql_python/pybind/ddbc_bindings.cpp

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
// Copyright (c) Microsoft Corporation.
22
// Licensed under the MIT license.
33

4+
// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be
5+
// taken up in beta release
6+
47
#include <pybind11/pybind11.h> // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions
58

69
#include <cstdint>
@@ -237,14 +240,33 @@ void ThrowStdException(const std::string& message) { throw std::runtime_error(me
237240
// TODO: We don't need to do explicit linking using LoadLibrary. We can just use implicit
238241
// linking to load this DLL. It will simplify the code a lot.
239242
void LoadDriverOrThrowException() {
240-
// Get the DLL directory to the current directory
241-
wchar_t currentDir[MAX_PATH];
242-
GetCurrentDirectoryW(MAX_PATH, currentDir);
243-
std::wstring dllDir = std::wstring(currentDir) + L"\\libs\\win\\msodbcsql18.dll";
243+
HMODULE hDdbcModule;
244+
wchar_t ddbcModulePath[MAX_PATH];
245+
// Get the path to DDBC module:
246+
// GetModuleHandleExW returns a handle to current shared library (ddbc_bindings.pyd) given a
247+
// function from the library (LoadDriverOrThrowException). GetModuleFileNameW takes in the
248+
// library handle (hDdbcModule) & returns the full path to this library (ddbcModulePath)
249+
if (GetModuleHandleExW(
250+
GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
251+
(LPWSTR)&LoadDriverOrThrowException, &hDdbcModule) &&
252+
GetModuleFileNameW(hDdbcModule, ddbcModulePath, MAX_PATH)) {
253+
// Look for last occurence of '\' in the path and set it to null
254+
wchar_t* lastBackSlash = wcsrchr(ddbcModulePath, L'\\');
255+
if (lastBackSlash == nullptr) {
256+
DEBUG_LOG("Invalid DDBC module path - %S", ddbcModulePath);
257+
ThrowStdException("Failed to load driver");
258+
}
259+
*lastBackSlash = 0;
260+
} else {
261+
DEBUG_LOG("Failed to get DDBC module path. Error code - %d", GetLastError());
262+
ThrowStdException("Failed to load driver");
263+
}
244264

245-
// Load the DLL from the specified path
265+
// Look for msodbcsql18.dll in a path relative to DDBC module
266+
std::wstring dllDir = std::wstring(ddbcModulePath) + L"\\libs\\win\\msodbcsql18.dll";
246267
HMODULE hModule = LoadLibraryW(dllDir.c_str());
247268
if (!hModule) {
269+
DEBUG_LOG("LoadLibraryW failed to load driver from - %S", dllDir.c_str());
248270
ThrowStdException("Failed to load driver");
249271
}
250272
DEBUG_LOG("Driver loaded successfully from - %S", dllDir.c_str());
@@ -294,6 +316,7 @@ void LoadDriverOrThrowException() {
294316
SQLFreeHandle_ptr && SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr;
295317

296318
if (!success) {
319+
DEBUG_LOG("Failed to load required function pointers from driver - %S", dllDir.c_str());
297320
ThrowStdException("Failed to load required function pointers from driver");
298321
}
299322
DEBUG_LOG("Sucessfully loaded function pointers from driver");

0 commit comments

Comments
 (0)