11import io
22import logging
33import pandas
4+ import pyodbc
5+ import re
6+
7+ import sqlalchemy .exc
48from sqlalchemy import create_engine
59from sqlalchemy import MetaData
610from sqlalchemy .schema import Table
11+ from sqlalchemy .sql import text
12+
713from rdl .ColumnTypeResolver import ColumnTypeResolver
814from rdl .data_sources .ChangeTrackingInfo import ChangeTrackingInfo
9- from sqlalchemy .sql import text
1015from rdl .shared import Constants
1116
1217
1318class MsSqlDataSource (object ):
1419 SOURCE_TABLE_ALIAS = 'src'
1520 CHANGE_TABLE_ALIAS = 'chg'
21+ MSSQL_STRING_REGEX = r"mssql\+pyodbc://" \
22+ r"(?:(?P<username>[^@/?&:]+)?:(?P<password>[^@/?&:]+)?@)?" \
23+ r"(?P<server>[^@/?&:]*)/(?P<database>[^@/?&:]*)" \
24+ r"\?driver=(?P<driver>[^@/?&:]*)" \
25+ r"(?:&failover=(?P<failover>[^@/?&:]*))?"
1626
1727 def __init__ (self , connection_string , logger = None ):
1828 self .logger = logger or logging .getLogger (__name__ )
1929 self .connection_string = connection_string
20- self .database_engine = create_engine (connection_string )
30+ self .database_engine = create_engine (connection_string , creator = self . create_connection_with_failover )
2131 self .column_type_resolver = ColumnTypeResolver ()
2232
2333 @staticmethod
2434 def can_handle_connection_string (connection_string ):
25- return connection_string .startswith (MsSqlDataSource .connection_string_prefix ())
35+ return MsSqlDataSource .connection_string_regex_match (connection_string ) is not None
36+
37+ @staticmethod
38+ def connection_string_regex_match (connection_string ):
39+ return re .match (MsSqlDataSource .MSSQL_STRING_REGEX , connection_string )
2640
2741 @staticmethod
2842 def connection_string_prefix ():
@@ -37,6 +51,36 @@ def prefix_column(column_name, full_refresh, primary_key_column_names):
3751 else :
3852 return f"{ MsSqlDataSource .SOURCE_TABLE_ALIAS } .{ column_name } "
3953
54+ def create_connection_with_failover (self ):
55+ conn_string_data = MsSqlDataSource .connection_string_regex_match (self .connection_string )
56+ server = conn_string_data .group ('server' )
57+ failover = conn_string_data .group ('failover' )
58+ database = conn_string_data .group ('database' )
59+ driver = "{" + conn_string_data .group ('driver' ).replace ('+' , ' ' )+ "}"
60+ dsn = f'DRIVER={ driver } ;DATABASE={ database } ;'
61+
62+ username = conn_string_data .group ('username' )
63+ password = conn_string_data .group ('password' )
64+
65+ login_cred = "Trusted_Connection=yes;"
66+ if username is not None and password is not None :
67+ login_cred = f'UID={ username } ;PWD={ password } ;'
68+
69+ dsn += login_cred
70+ self .logger .info (
71+ 'Parsed Connection Details: ' +
72+ f'''FAILOVER={ failover }
73+ SERVER={ server }
74+ DRIVER={ driver }
75+ DATABASE={ database } ''' )
76+ try :
77+ return pyodbc .connect (dsn , server = server )
78+ except (sqlalchemy .exc .OperationalError , pyodbc .OperationalError ) as e :
79+ if e .args [0 ] == "08001" and failover is not None :
80+ self .logger .warning (f'Using Failover Server: { failover } ' )
81+ return pyodbc .connect (dsn , server = failover )
82+ raise e
83+
4084 def build_select_statement (self , table_config , columns , batch_config , batch_key_tracker , full_refresh ,
4185 change_tracking_info ):
4286 column_array = list (
0 commit comments