Skip to content
This repository was archived by the owner on Mar 13, 2020. It is now read-only.

Commit 644b04c

Browse files
committed
[SP-3] Add unit test for AWSLambdaDataSource, rename connection_string_prefix() to get_connection_string_prefix(), reformat file
1 parent 74544af commit 644b04c

File tree

4 files changed

+109
-45
lines changed

4 files changed

+109
-45
lines changed

rdl/data_sources/AWSLambdaDataSource.py

Lines changed: 71 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -11,101 +11,129 @@
1111

1212
class AWSLambdaDataSource(object):
1313
# 'aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;'
14-
CONNECTION_STRING_PREFIX = 'aws-lambda://'
15-
CONNECTION_STRING_GROUP_SEPARATOR = ';'
16-
CONNECTION_STRING_KEY_VALUE_SEPARATOR = '='
14+
CONNECTION_STRING_PREFIX = "aws-lambda://"
15+
CONNECTION_STRING_GROUP_SEPARATOR = ";"
16+
CONNECTION_STRING_KEY_VALUE_SEPARATOR = "="
1717

1818
def __init__(self, connection_string, logger=None):
1919
self.logger = logger or logging.getLogger(__name__)
2020
if not AWSLambdaDataSource.can_handle_connection_string(connection_string):
2121
raise ValueError(connection_string)
2222
self.connection_string = connection_string
23-
self.connection_data = dict(kv.split(AWSLambdaDataSource.CONNECTION_STRING_KEY_VALUE_SEPARATOR) for kv in
24-
self.connection_string
25-
.lstrip(AWSLambdaDataSource.CONNECTION_STRING_PREFIX)
26-
.rstrip(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
27-
.split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR))
28-
self.aws_lambda_client = boto3.client('lambda')
23+
self.connection_data = dict(
24+
kv.split(AWSLambdaDataSource.CONNECTION_STRING_KEY_VALUE_SEPARATOR)
25+
for kv in self.connection_string.lstrip(
26+
AWSLambdaDataSource.CONNECTION_STRING_PREFIX
27+
)
28+
.rstrip(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
29+
.split(AWSLambdaDataSource.CONNECTION_STRING_GROUP_SEPARATOR)
30+
)
31+
self.aws_lambda_client = boto3.client("lambda")
2932

3033
@staticmethod
3134
def can_handle_connection_string(connection_string):
32-
return connection_string.startswith(AWSLambdaDataSource.CONNECTION_STRING_PREFIX)
35+
return connection_string.startswith(
36+
AWSLambdaDataSource.CONNECTION_STRING_PREFIX
37+
) and len(connection_string) != len(
38+
AWSLambdaDataSource.CONNECTION_STRING_PREFIX
39+
)
3340

3441
@staticmethod
35-
def connection_string_prefix():
42+
def get_connection_string_prefix():
3643
return AWSLambdaDataSource.CONNECTION_STRING_PREFIX
3744

3845
def get_table_info(self, table_config, last_known_sync_version):
39-
column_names, last_sync_version, sync_version, full_refresh_required, data_changed_since_last_sync = \
40-
self.__get_table_info(table_config, last_known_sync_version)
46+
column_names, last_sync_version, sync_version, full_refresh_required, data_changed_since_last_sync = self.__get_table_info(
47+
table_config, last_known_sync_version
48+
)
4149
columns_in_database = column_names
4250
change_tracking_info = ChangeTrackingInfo(
4351
last_sync_version=last_sync_version,
4452
sync_version=sync_version,
4553
force_full_load=full_refresh_required,
46-
data_changed_since_last_sync=data_changed_since_last_sync)
54+
data_changed_since_last_sync=data_changed_since_last_sync,
55+
)
4756
source_table_info = SourceTableInfo(columns_in_database, change_tracking_info)
4857
return source_table_info
4958

5059
@prevent_senstive_data_logging
51-
def get_table_data_frame(self, table_config, columns, batch_config, batch_tracker, batch_key_tracker,
52-
full_refresh, change_tracking_info):
60+
def get_table_data_frame(
61+
self,
62+
table_config,
63+
columns,
64+
batch_config,
65+
batch_tracker,
66+
batch_key_tracker,
67+
full_refresh,
68+
change_tracking_info,
69+
):
5370
self.logger.debug(f"Starting read data from lambda.. : \n{None}")
54-
column_names, data = self.__get_table_data(table_config, batch_config, change_tracking_info, full_refresh, columns, batch_key_tracker)
71+
column_names, data = self.__get_table_data(
72+
table_config,
73+
batch_config,
74+
change_tracking_info,
75+
full_refresh,
76+
columns,
77+
batch_key_tracker,
78+
)
5579
self.logger.debug(f"Finished read data from lambda.. : \n{None}")
5680
# should we log size of data extracted?
57-
data_frame = pandas.DataFrame(data=data, columns=column_names)
81+
data_frame = self.__get_data_frame(data, column_names)
5882
batch_tracker.extract_completed_successfully(len(data_frame))
5983
return data_frame
6084

6185
def __get_table_info(self, table_config, last_known_sync_version):
6286
pay_load = {
6387
"command": "GetTableInfo",
64-
"tenantId": self.connection_data['tenant'],
65-
"table": {
66-
"schema": table_config['schema'],
67-
"name": table_config['name']
68-
},
69-
"commandPayload": {
70-
"lastSyncVersion": last_known_sync_version,
71-
}
88+
"tenantId": self.connection_data["tenant"],
89+
"table": {"schema": table_config["schema"], "name": table_config["name"]},
90+
"commandPayload": {"lastSyncVersion": last_known_sync_version},
7291
}
7392

7493
result = self.__invoke_lambda(pay_load)
7594

76-
return result['ColumnNames'], result['Data']
77-
78-
def __get_table_data(self, table_config, batch_config, change_tracking_info, full_refresh, columns, batch_key_tracker):
95+
return result["ColumnNames"], result["Data"]
96+
97+
def __get_table_data(
98+
self,
99+
table_config,
100+
batch_config,
101+
change_tracking_info,
102+
full_refresh,
103+
columns,
104+
batch_key_tracker,
105+
):
79106
pay_load = {
80107
"command": "GetTableData",
81108
"tenantId": 543, # self.connection_string.tenant.split('_')[0] as int
82-
"table": {
83-
"schema": table_config['schema'],
84-
"name": table_config['name']
85-
},
109+
"table": {"schema": table_config["schema"], "name": table_config["name"]},
86110
"commandPayload": {
87111
"auditColumnNameForChangeVersion": Providers.AuditColumnsNames.CHANGE_VERSION,
88112
"auditColumnNameForDeletionFlag": Providers.AuditColumnsNames.IS_DELETED,
89-
"batchSize": batch_config['size'],
113+
"batchSize": batch_config["size"],
90114
"lastSyncVersion": change_tracking_info.last_sync_version,
91115
"fullRefresh": full_refresh,
92116
"columnNames": columns,
93-
"primaryKeyColumnNames": table_config['primary_keys'],
94-
"lastBatchPrimaryKeys": [{k: v} for k, v in batch_key_tracker.bookmarks.items()]
95-
}
117+
"primaryKeyColumnNames": table_config["primary_keys"],
118+
"lastBatchPrimaryKeys": [
119+
{k: v} for k, v in batch_key_tracker.bookmarks.items()
120+
],
121+
},
96122
}
97123

98124
result = self.__invoke_lambda(pay_load)
99125

100-
return result['ColumnNames'], result['Data']
126+
return result["ColumnNames"], result["Data"]
127+
128+
def __get_data_frame(self, data: [[]], column_names: []):
129+
return pandas.DataFrame(data=data, columns=column_names)
101130

102131
def __invoke_lambda(self, pay_load):
103132
lambda_response = self.aws_lambda_client.invoke(
104-
FunctionName=self.connection_data['function'],
105-
InvocationType='RequestResponse',
106-
LogType='None', # |'Tail', Set to Tail to include the execution log in the response
107-
Payload=json.dump(pay_load).encode()
133+
FunctionName=self.connection_data["function"],
134+
InvocationType="RequestResponse",
135+
LogType="None", # |'Tail', Set to Tail to include the execution log in the response
136+
Payload=json.dump(pay_load).encode(),
108137
)
109138
result = json.loads(lambda_response.Payload.read()) # .decode()
110139
return result
111-

rdl/data_sources/DataSourceFactory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,4 @@ def is_prefix_supported(self, connection_string):
2323
return False
2424

2525
def get_supported_source_prefixes(self):
26-
return list(map(lambda source: source.connection_string_prefix(), self.sources))
26+
return list(map(lambda source: source.get_connection_string_prefix(), self.sources))

rdl/data_sources/MsSqlDataSource.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def can_handle_connection_string(connection_string):
3737
return MsSqlDataSource.__connection_string_regex_match(connection_string) is not None
3838

3939
@staticmethod
40-
def connection_string_prefix():
40+
def get_connection_string_prefix():
4141
return 'mssql+pyodbc://'
4242

4343
def get_table_info(self, table_config, last_known_sync_version):
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import unittest
2+
3+
from rdl.data_sources.AWSLambdaDataSource import AWSLambdaDataSource
4+
5+
6+
class TestAWSLambdaDataSource(unittest.TestCase):
7+
data_source = None
8+
table_configs = []
9+
10+
@classmethod
11+
def setUpClass(cls):
12+
TestAWSLambdaDataSource.data_source = AWSLambdaDataSource(
13+
"aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;"
14+
)
15+
16+
@classmethod
17+
def tearDownClass(cls):
18+
TestAWSLambdaDataSource.data_source = None
19+
20+
def test_can_handle_valid_connection_string(self):
21+
self.assertTrue(
22+
self.data_source.can_handle_connection_string(
23+
"aws-lambda://tenant=543_dc2;function=123456789012:function:my-function;"
24+
)
25+
)
26+
27+
def test_can_handle_invalid_connection_string(self):
28+
self.assertFalse(
29+
self.data_source.can_handle_connection_string(
30+
"lambda-aws://tenant=543_dc2;function=123456789012:function:my-function;"
31+
)
32+
)
33+
34+
35+
if __name__ == "__main__":
36+
unittest.main()

0 commit comments

Comments
 (0)