2525import io .cdap .plugin .db .TransactionIsolationLevel ;
2626import io .cdap .plugin .util .DBUtils ;
2727import org .apache .hadoop .conf .Configuration ;
28+ import org .apache .hadoop .mapreduce .JobContext ;
29+ import org .apache .hadoop .mapreduce .OutputCommitter ;
2830import org .apache .hadoop .mapreduce .RecordWriter ;
2931import org .apache .hadoop .mapreduce .TaskAttemptContext ;
3032import org .apache .hadoop .mapreduce .lib .db .DBConfiguration ;
4345import java .sql .Statement ;
4446import java .util .Map ;
4547import java .util .Properties ;
48+ import java .util .concurrent .ConcurrentHashMap ;
4649
4750import static io .cdap .plugin .db .ConnectionConfigAccessor .OPERATION_NAME ;
4851import static io .cdap .plugin .db .ConnectionConfigAccessor .RELATION_TABLE_KEY ;
5659public class ETLDBOutputFormat <K extends DBWritable , V > extends DBOutputFormat <K , V > {
5760 // Batch size before submitting a batch to the SQL engine. If set to 0, no batches will be submitted until commit.
5861 public static final String COMMIT_BATCH_SIZE = "io.cdap.plugin.db.output.commit.batch.size" ;
62+ public static final String STAGE_NAME = "io.cdap.plugin.db.output.stage_name" ;
5963 public static final int DEFAULT_COMMIT_BATCH_SIZE = 1000 ;
6064 private static final Character ESCAPE_CHAR = '"' ;
6165
66+ // Format for connection map's key will be "taskAttemptId_stageName"
67+ private static final String CONNECTION_MAP_KEY_FORMAT = "%s_%s" ;
68+
69+ // CONNECTION_MAP will be used to store connections with "taskAttemptId_stageName" as key and
70+ // connection object as value. Making it static to be accessed from multiple task attempts within same executor.
71+ private static final Map <String , Connection > CONNECTION_MAP = new ConcurrentHashMap <>();
6272 private static final Logger LOG = LoggerFactory .getLogger (ETLDBOutputFormat .class );
6373
6474 private Configuration conf ;
6575 private Driver driver ;
6676 private JDBCDriverShim driverShim ;
6777
78+ @ Override
79+ public OutputCommitter getOutputCommitter (TaskAttemptContext context )
80+ throws IOException , InterruptedException {
81+ return new OutputCommitter () {
82+ @ Override
83+ public void setupJob (JobContext jobContext ) throws IOException {
84+ // do nothing
85+ }
86+
87+ @ Override
88+ public void setupTask (TaskAttemptContext taskContext ) throws IOException {
89+ // do nothing
90+ }
91+
92+ @ Override
93+ public boolean needsTaskCommit (TaskAttemptContext taskContext ) throws IOException {
94+ return true ;
95+ }
96+
97+ @ Override
98+ public void commitTask (TaskAttemptContext taskContext ) throws IOException {
99+ conf = context .getConfiguration ();
100+ String stageName = conf .get (STAGE_NAME );
101+ String connectionId = getConnectionMapKeyFormat (context .getTaskAttemptID ().toString (), stageName );
102+ Connection connection ;
103+ if ((connection = CONNECTION_MAP .remove (connectionId )) != null ) {
104+ try {
105+ connection .commit ();
106+ } catch (SQLException e ) {
107+ try {
108+ connection .rollback ();
109+ } catch (SQLException ex ) {
110+ LOG .warn (StringUtils .stringifyException (ex ));
111+ }
112+ throw new IOException (e );
113+ } finally {
114+ try {
115+ connection .close ();
116+ LOG .debug ("Connection Closed after committing the task with taskAttemptId {}" , connectionId );
117+ } catch (SQLException ex ) {
118+ LOG .warn (StringUtils .stringifyException (ex ));
119+ }
120+ }
121+ }
122+ }
123+
124+ @ Override
125+ public void abortTask (TaskAttemptContext taskContext ) throws IOException {
126+ conf = context .getConfiguration ();
127+ String stageName = conf .get (STAGE_NAME );
128+ String connectionId = getConnectionMapKeyFormat (context .getTaskAttemptID ().toString (), stageName );
129+ Connection connection ;
130+ if ((connection = CONNECTION_MAP .remove (connectionId )) != null ) {
131+ try {
132+ connection .rollback ();
133+ } catch (SQLException e ) {
134+ throw new IOException (e );
135+ } finally {
136+ try {
137+ connection .close ();
138+ LOG .debug ("Connection Closed after rollback the task with taskAttemptId {}" , connectionId );
139+ } catch (SQLException ex ) {
140+ LOG .warn (StringUtils .stringifyException (ex ));
141+ }
142+ }
143+ }
144+ }
145+ };
146+ }
147+
68148 @ Override
69149 public RecordWriter <K , V > getRecordWriter (TaskAttemptContext context ) throws IOException {
70150 conf = context .getConfiguration ();
@@ -81,6 +161,11 @@ public RecordWriter<K, V> getRecordWriter(TaskAttemptContext context) throws IOE
81161
82162 try {
83163 Connection connection = getConnection (conf );
164+ String stageName = conf .get (STAGE_NAME );
165+ // If using multiple sinks, task attemptID can be same in that case, appending stage in the end for uniqueness.
166+ String connectionId = getConnectionMapKeyFormat (context .getTaskAttemptID ().toString (), stageName );
167+ CONNECTION_MAP .put (connectionId , connection );
168+ LOG .debug ("Connection Added to the map with connectionId : {}" , connectionId );
84169 PreparedStatement statement = connection .prepareStatement (constructQueryOnOperation (tableName , fieldNames ,
85170 operationName , listKeys ));
86171 return new DBRecordWriter (connection , statement ) {
@@ -98,23 +183,15 @@ public void close(TaskAttemptContext context) throws IOException {
98183 if (!emptyData ) {
99184 getStatement ().executeBatch ();
100185 }
101- getConnection ().commit ();
102186 } catch (SQLException e ) {
103- try {
104- getConnection ().rollback ();
105- } catch (SQLException ex ) {
106- LOG .warn (StringUtils .stringifyException (ex ));
107- }
108187 throw new IOException (e );
109188 } finally {
110189 try {
111190 getStatement ().close ();
112- getConnection ().close ();
113191 } catch (SQLException ex ) {
114192 throw new IOException (ex );
115193 }
116194 }
117-
118195 try {
119196 DriverManager .deregisterDriver (driverShim );
120197 } catch (SQLException e ) {
@@ -298,4 +375,8 @@ public String constructUpdateQuery(String table, String[] fieldNames, String[] l
298375 return query .toString ();
299376 }
300377 }
378+
379+ private String getConnectionMapKeyFormat (String taskAttemptId , String stageName ) {
380+ return String .format (CONNECTION_MAP_KEY_FORMAT , taskAttemptId , stageName );
381+ }
301382}
0 commit comments