|
1 | 1 | // Copyright (c) Microsoft Corporation. |
2 | 2 | // Licensed under the MIT license. |
3 | 3 |
|
| 4 | +// INFO|TODO - Note that is file is Windows specific right now. Making it arch agnostic will be |
| 5 | +// taken up in beta release |
| 6 | + |
4 | 7 | #include <pybind11/pybind11.h> // pybind11.h must be the first include - https://pybind11.readthedocs.io/en/latest/basics.html#header-and-namespace-conventions |
5 | 8 |
|
6 | 9 | #include <cstdint> |
@@ -237,14 +240,33 @@ void ThrowStdException(const std::string& message) { throw std::runtime_error(me |
237 | 240 | // TODO: We don't need to do explicit linking using LoadLibrary. We can just use implicit |
238 | 241 | // linking to load this DLL. It will simplify the code a lot. |
239 | 242 | void LoadDriverOrThrowException() { |
240 | | - // Get the DLL directory to the current directory |
241 | | - wchar_t currentDir[MAX_PATH]; |
242 | | - GetCurrentDirectoryW(MAX_PATH, currentDir); |
243 | | - std::wstring dllDir = std::wstring(currentDir) + L"\\libs\\win\\msodbcsql18.dll"; |
| 243 | + HMODULE hDdbcModule; |
| 244 | + wchar_t ddbcModulePath[MAX_PATH]; |
| 245 | + // Get the path to DDBC module: |
| 246 | + // GetModuleHandleExW returns a handle to current shared library (ddbc_bindings.pyd) given a |
| 247 | + // function from the library (LoadDriverOrThrowException). GetModuleFileNameW takes in the |
| 248 | + // library handle (hDdbcModule) & returns the full path to this library (ddbcModulePath) |
| 249 | + if (GetModuleHandleExW( |
| 250 | + GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT, |
| 251 | + (LPWSTR)&LoadDriverOrThrowException, &hDdbcModule) && |
| 252 | + GetModuleFileNameW(hDdbcModule, ddbcModulePath, MAX_PATH)) { |
| 253 | + // Look for last occurence of '\' in the path and set it to null |
| 254 | + wchar_t* lastBackSlash = wcsrchr(ddbcModulePath, L'\\'); |
| 255 | + if (lastBackSlash == nullptr) { |
| 256 | + DEBUG_LOG("Invalid DDBC module path - %S", ddbcModulePath); |
| 257 | + ThrowStdException("Failed to load driver"); |
| 258 | + } |
| 259 | + *lastBackSlash = 0; |
| 260 | + } else { |
| 261 | + DEBUG_LOG("Failed to get DDBC module path. Error code - %d", GetLastError()); |
| 262 | + ThrowStdException("Failed to load driver"); |
| 263 | + } |
244 | 264 |
|
245 | | - // Load the DLL from the specified path |
| 265 | + // Look for msodbcsql18.dll in a path relative to DDBC module |
| 266 | + std::wstring dllDir = std::wstring(ddbcModulePath) + L"\\libs\\win\\msodbcsql18.dll"; |
246 | 267 | HMODULE hModule = LoadLibraryW(dllDir.c_str()); |
247 | 268 | if (!hModule) { |
| 269 | + DEBUG_LOG("LoadLibraryW failed to load driver from - %S", dllDir.c_str()); |
248 | 270 | ThrowStdException("Failed to load driver"); |
249 | 271 | } |
250 | 272 | DEBUG_LOG("Driver loaded successfully from - %S", dllDir.c_str()); |
@@ -294,6 +316,7 @@ void LoadDriverOrThrowException() { |
294 | 316 | SQLFreeHandle_ptr && SQLDisconnect_ptr && SQLFreeStmt_ptr && SQLGetDiagRec_ptr; |
295 | 317 |
|
296 | 318 | if (!success) { |
| 319 | + DEBUG_LOG("Failed to load required function pointers from driver - %S", dllDir.c_str()); |
297 | 320 | ThrowStdException("Failed to load required function pointers from driver"); |
298 | 321 | } |
299 | 322 | DEBUG_LOG("Sucessfully loaded function pointers from driver"); |
|
0 commit comments