@@ -6,59 +6,69 @@
 
 
 class BatchDataLoader(object):
-    def __init__(self, data_source, source_table_configuration, target_table_configuration, columns, data_load_tracker, batch_configuration, target_engine, logger=None):
+    def __init__(self, data_source, source_table_configuration, target_schema, target_table, columns, data_load_tracker, batch_configuration, target_engine, logger=None):
         self.logger = logger or logging.getLogger(__name__)
         self.source_table_configuration = source_table_configuration
         self.columns = columns
         self.data_source = data_source
-        self.target_table_configuration = target_table_configuration
+        self.target_schema = target_schema
+        self.target_table = target_table
         self.data_load_tracker = data_load_tracker
         self.batch_configuration = batch_configuration
         self.target_engine = target_engine
 
     # Imports rows, returns True if >0 rows were found
-    def import_batch(self, previous_batch_key):
+    def load_batch(self, previous_batch_key):
         batch_tracker = self.data_load_tracker.start_batch()
 
         self.logger.debug("ImportBatch Starting from previous_batch_key: {0}".format(previous_batch_key))
 
         data_frame = self.data_source.get_next_data_frame(self.source_table_configuration, self.columns, self.batch_configuration, batch_tracker, previous_batch_key)
 
-        if len(data_frame) == 0:
+        if data_frame is None or len(data_frame) == 0:
             self.logger.debug("There are no rows to import, returning -1")
             batch_tracker.load_skipped_due_to_zero_rows()
             return -1
 
         data_frame = self.attach_column_transformers(data_frame)
 
-        self.write_data_frame_to_table(data_frame, self.target_table_configuration, self.target_engine)
+        self.write_data_frame_to_table(data_frame)
         batch_tracker.load_completed_successfully()
 
         last_key_returned = data_frame.iloc[-1][self.batch_configuration['source_unique_column']]
 
         self.logger.info("Batch key {0} Completed. {1}".format(last_key_returned, batch_tracker.get_statistics()))
         return last_key_returned
 
-    def write_data_frame_to_table(self, data_frame, table_configuration, target_engine):
-        destination_table = "{0}.{1}".format(table_configuration['schema'], table_configuration['name'])
-        self.logger.debug("Starting write to table {0}".format(destination_table))
+    def write_data_frame_to_table(self, data_frame):
+        qualified_target_table = "{0}.{1}".format(self.target_schema, self.target_table)
+        self.logger.debug("Starting write to table {0}".format(qualified_target_table))
         data = StringIO()
         data_frame.to_csv(data, header=False, index=False, na_rep='')
         data.seek(0)
-        raw = target_engine.raw_connection()
+        raw = self.target_engine.raw_connection()
         curs = raw.cursor()
 
-        #TODO: This is assuming that our destination schema column order matches the columns in the dataframe. This
-        #isn't always correct (especially in csv sources) - therefore, we should derive the column_array from the
-        #data frames' columns.
-        column_array = list(map(lambda cfg: cfg['destination']['name'], self.columns))
+        column_array = list(map(lambda source_colum_name: self.get_destination_column_name(source_colum_name), data_frame.columns))
+        column_list = ','.join(map(str, column_array))
 
-        curs.copy_from(data, destination_table, sep=',', columns=column_array, null='')
-        self.logger.debug("Completed write to table {0}".format(destination_table))
+        sql = "COPY {0}({1}) FROM STDIN with csv".format(qualified_target_table, column_list)
+        self.logger.debug("Writing to table using command {0}".format(sql))
+        curs.copy_expert(sql=sql, file=data)
+
+        self.logger.debug("Completed write to table {0}".format(qualified_target_table))
 
         curs.connection.commit()
         return
 
+    def get_destination_column_name(self, source_column_name):
+        for column in self.columns:
+            if column['source_name'] == source_column_name:
+                return column['destination']['name']
+
+        message = 'A source column with name {0} was not found in the column configuration'.format(source_column_name)
+        raise ValueError(message)
+
     def attach_column_transformers(self, data_frame):
        self.logger.debug("Attaching column transformers")
        for column in self.columns:
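Note on the write path introduced above: the batch is now streamed into Postgres with a COPY ... FROM STDIN with csv statement through the cursor's copy_expert, instead of copy_from, so the statement itself carries the schema-qualified table name and an explicit column list while Postgres handles CSV quoting. Below is a minimal standalone sketch of that technique, not the project's code; the connection string and the analytics.customer table and its columns are placeholders, and it assumes the SQLAlchemy engine is backed by psycopg2 (whose cursors expose copy_expert).

from io import StringIO

import pandas as pd
from sqlalchemy import create_engine

engine = create_engine("postgresql+psycopg2://user:password@localhost/warehouse")  # placeholder DSN

data_frame = pd.DataFrame({"id": [1, 2], "full_name": ["Ada", "Grace, Hopper"]})

# Serialise the frame to an in-memory CSV, the same way write_data_frame_to_table does.
buffer = StringIO()
data_frame.to_csv(buffer, header=False, index=False, na_rep='')
buffer.seek(0)

raw = engine.raw_connection()
try:
    cursor = raw.cursor()
    # The explicit column list means the CSV only has to match the DataFrame's
    # column order, not the physical column order of the target table.
    sql = "COPY analytics.customer(id, full_name) FROM STDIN with csv"
    cursor.copy_expert(sql=sql, file=buffer)
    raw.commit()
finally:
    raw.close()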
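The new get_destination_column_name helper is what lets the column list be derived from the DataFrame rather than from the configured column order (the concern raised in the removed TODO). A rough illustration of the column configuration shape it reads, with invented names: each entry maps a source column name to a destination column definition, and the COPY column list follows the DataFrame's column order after the lookup.

# Hypothetical column configuration in the shape the loader expects.
columns = [
    {"source_name": "CustomerId", "destination": {"name": "customer_id"}},
    {"source_name": "FullName", "destination": {"name": "full_name"}},
]

destination_by_source = {c["source_name"]: c["destination"]["name"] for c in columns}

# A CSV source might yield columns in a different order, e.g. ["FullName", "CustomerId"];
# the generated COPY column list then follows that order, not the configured one.
column_list = ",".join(destination_by_source[name] for name in ["FullName", "CustomerId"])
# -> "full_name,customer_id"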