Skip to content

Commit e08956d

Browse files
committed
support merge mode
1 parent 77b1e56 commit e08956d

File tree

4 files changed

+260
-2
lines changed

4 files changed

+260
-2
lines changed

README.md

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@ Snowflake output plugin for Embulk loads records to Snowflake.
2020
- **retry_limit**: max retry count for database operations (integer, default: 12). When intermediate table to create already created by another process, this plugin will retry with another table name to avoid collision.
2121
- **retry_wait**: initial retry wait time in milliseconds (integer, default: 1000 (1 second))
2222
- **max_retry_wait**: upper limit of retry wait, which will be doubled at every retry (integer, default: 1800000 (30 minutes))
23-
- **mode**: "insert", "insert_direct", "truncate_insert" or "replace". See below. (string, required)
23+
- **mode**: "insert", "insert_direct", "truncate_insert", "replace" or "merge". See below. (string, required)
24+
- **merge_keys**: key column names for merging records in merge mode (string array, required in merge mode if table doesn't have primary key)
25+
- **merge_rule**: list of column assignments for updating existing records used in merge mode, for example `"foo" = T."foo" + S."foo"` (`T` means target table and `S` means source table). (string array, default: always overwrites with new values)
2426
- **batch_size**: size of a single batch insert (integer, default: 16777216)
2527
- **default_timezone**: If input column type (embulk type) is timestamp, this plugin needs to format the timestamp into a SQL string. This default_timezone option is used to control the timezone. You can overwrite timezone for each columns using column_options option. (string, default: `UTC`)
2628
- **column_options**: advanced: a key-value pairs where key is a column name and value is options for the column.
@@ -49,6 +51,10 @@ Snowflake output plugin for Embulk loads records to Snowflake.
4951
* Behavior: This mode writes rows to an intermediate table first. If all those tasks run correctly, drops the target table and alters the name of the intermediate table into the target table name.
5052
* Transactional: Yes.
5153
* Resumable: No.
54+
* **merge**:
55+
* Behavior: This mode writes rows to some intermediate tables first. If all those tasks run correctly, runs MERGE INTO ... WHEN MATCHED THEN UPDATE ... WHEN NOT MATCHED THEN INSERT ... query. Namely, if merge keys of a record in the intermediate tables already exist in the target table, the target record is updated by the intermediate record, otherwise the intermediate record is inserted. If the target table doesn't exist, it is created automatically.
56+
* Transactional: Yes.
57+
* Resumable: No.
5258

5359
## Build
5460

src/main/java/org/embulk/output/SnowflakeOutputPlugin.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,12 @@ protected Features getFeatures(PluginTask task) {
6868
.setMaxTableNameLength(127)
6969
.setSupportedModes(
7070
new HashSet<>(
71-
Arrays.asList(Mode.INSERT, Mode.INSERT_DIRECT, Mode.TRUNCATE_INSERT, Mode.REPLACE)))
71+
Arrays.asList(
72+
Mode.INSERT,
73+
Mode.INSERT_DIRECT,
74+
Mode.TRUNCATE_INSERT,
75+
Mode.REPLACE,
76+
Mode.MERGE)))
7277
.setIgnoreMergeKeys(false);
7378
}
7479

src/main/java/org/embulk/output/snowflake/SnowflakeOutputConnection.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,12 @@
44
import java.sql.Connection;
55
import java.sql.SQLException;
66
import java.sql.Statement;
7+
import java.util.List;
78
import net.snowflake.client.jdbc.SnowflakeConnection;
89
import org.embulk.output.jdbc.JdbcColumn;
910
import org.embulk.output.jdbc.JdbcOutputConnection;
11+
import org.embulk.output.jdbc.JdbcSchema;
12+
import org.embulk.output.jdbc.MergeConfig;
1013
import org.embulk.output.jdbc.TableIdentifier;
1114

1215
public class SnowflakeOutputConnection extends JdbcOutputConnection {
@@ -75,6 +78,71 @@ protected String buildColumnTypeName(JdbcColumn c) {
7578
}
7679
}
7780

81+
@Override
82+
protected String buildCollectMergeSql(
83+
List<TableIdentifier> fromTables,
84+
JdbcSchema schema,
85+
TableIdentifier toTable,
86+
MergeConfig mergeConfig)
87+
throws SQLException {
88+
StringBuilder sb = new StringBuilder();
89+
90+
sb.append("MERGE INTO ");
91+
sb.append(quoteTableIdentifier(toTable));
92+
sb.append(" AS T ");
93+
sb.append(" USING (");
94+
for (int i = 0; i < fromTables.size(); i++) {
95+
if (i != 0) {
96+
sb.append(" UNION ALL ");
97+
}
98+
sb.append(" SELECT ");
99+
sb.append(buildColumns(schema, ""));
100+
sb.append(" FROM ");
101+
sb.append(quoteTableIdentifier(fromTables.get(i)));
102+
}
103+
sb.append(") AS S");
104+
sb.append(" ON (");
105+
for (int i = 0; i < mergeConfig.getMergeKeys().size(); i++) {
106+
if (i != 0) {
107+
sb.append(" AND ");
108+
}
109+
String mergeKey = quoteIdentifierString(mergeConfig.getMergeKeys().get(i));
110+
sb.append("T.");
111+
sb.append(mergeKey);
112+
sb.append(" = S.");
113+
sb.append(mergeKey);
114+
}
115+
sb.append(")");
116+
sb.append(" WHEN MATCHED THEN");
117+
sb.append(" UPDATE SET ");
118+
if (mergeConfig.getMergeRule().isPresent()) {
119+
for (int i = 0; i < mergeConfig.getMergeRule().get().size(); i++) {
120+
if (i != 0) {
121+
sb.append(", ");
122+
}
123+
sb.append(mergeConfig.getMergeRule().get().get(i));
124+
}
125+
} else {
126+
for (int i = 0; i < schema.getCount(); i++) {
127+
if (i != 0) {
128+
sb.append(", ");
129+
}
130+
String column = quoteIdentifierString(schema.getColumnName(i));
131+
sb.append(column);
132+
sb.append(" = S.");
133+
sb.append(column);
134+
}
135+
}
136+
sb.append(" WHEN NOT MATCHED THEN");
137+
sb.append(" INSERT (");
138+
sb.append(buildColumns(schema, ""));
139+
sb.append(") VALUES (");
140+
sb.append(buildColumns(schema, "S."));
141+
sb.append(");");
142+
143+
return sb.toString();
144+
}
145+
78146
protected String buildCreateStageSQL(StageIdentifier stageIdentifier) {
79147
StringBuilder sb = new StringBuilder();
80148
sb.append("CREATE STAGE IF NOT EXISTS ");
@@ -137,4 +205,16 @@ protected String quoteInternalStoragePath(
137205
sb.append(".csv.gz");
138206
return sb.toString();
139207
}
208+
209+
private String buildColumns(JdbcSchema schema, String prefix) {
210+
StringBuilder sb = new StringBuilder();
211+
for (int i = 0; i < schema.getCount(); i++) {
212+
if (i != 0) {
213+
sb.append(", ");
214+
}
215+
sb.append(prefix);
216+
sb.append(quoteIdentifierString(schema.getColumnName(i)));
217+
}
218+
return sb.toString();
219+
}
140220
}

src/test/java/org/embulk/output/snowflake/TestSnowflakeOutputPlugin.java

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import java.sql.SQLException;
1414
import java.sql.Statement;
1515
import java.util.ArrayList;
16+
import java.util.Arrays;
1617
import java.util.List;
1718
import java.util.Optional;
1819
import java.util.Properties;
@@ -434,6 +435,7 @@ public void testExecuteMultiQuery() throws IOException {
434435
String targetTableFullName =
435436
String.format(
436437
"\"%s\".\"%s\".\"%s\"", TEST_SNOWFLAKE_DB, TEST_SNOWFLAKE_SCHEMA, targetTableName);
438+
System.out.println(targetTableFullName);
437439
runQuery(
438440
String.format("create table %s (c0 FLOAT, c1 STRING)", targetTableFullName),
439441
foreachResult(rs_ -> {}));
@@ -519,6 +521,171 @@ public void testRuntimeInsertStringTable() throws IOException {
519521
}));
520522
}
521523

524+
@Test
525+
public void testRuntimeMergeTable() throws IOException {
526+
File in = testFolder.newFile(SnowflakeUtils.randomString(8) + ".csv");
527+
List<String> lines =
528+
Stream.of("c0:double,c1:double,c2:string", "1,1,aaa", "1,2,bbb", "1,3,ccc")
529+
.collect(Collectors.toList());
530+
Files.write(in.toPath(), lines);
531+
532+
String targetTableName = generateTemporaryTableName();
533+
String targetTableFullName =
534+
String.format(
535+
"\"%s\".\"%s\".\"%s\"", TEST_SNOWFLAKE_DB, TEST_SNOWFLAKE_SCHEMA, targetTableName);
536+
runQuery(
537+
String.format(
538+
"create table %s (\"c0\" NUMBER, \"c1\" NUMBER, \"c2\" STRING)", targetTableFullName),
539+
foreachResult(rs_ -> {}));
540+
runQuery(
541+
String.format("insert into %s values (1, 1, 'ddd'), (1, 4, 'eee');", targetTableFullName),
542+
foreachResult(rs_ -> {}));
543+
544+
final ConfigSource config =
545+
CONFIG_MAPPER_FACTORY
546+
.newConfigSource()
547+
.set("type", "snowflake")
548+
.set("user", TEST_SNOWFLAKE_USER)
549+
.set("password", TEST_SNOWFLAKE_PASSWORD)
550+
.set("host", TEST_SNOWFLAKE_HOST)
551+
.set("database", TEST_SNOWFLAKE_DB)
552+
.set("warehouse", TEST_SNOWFLAKE_WAREHOUSE)
553+
.set("schema", TEST_SNOWFLAKE_SCHEMA)
554+
.set("mode", "merge")
555+
.set("merge_keys", new ArrayList<String>(Arrays.asList("c0", "c1")))
556+
.set("table", targetTableName);
557+
embulk.runOutput(config, in.toPath());
558+
559+
runQuery(
560+
String.format("select count(*) from %s;", targetTableFullName),
561+
foreachResult(
562+
rs -> {
563+
assertEquals(4, rs.getInt(1));
564+
}));
565+
List<String> results = new ArrayList();
566+
runQuery(
567+
"select \"c0\",\"c1\",\"c2\" from " + targetTableFullName + " order by 1, 2",
568+
foreachResult(
569+
rs -> {
570+
results.add(rs.getString(1) + "," + rs.getString(2) + "," + rs.getString(3));
571+
}));
572+
List<String> expected =
573+
Stream.of("1,1,aaa", "1,2,bbb", "1,3,ccc", "1,4,eee").collect(Collectors.toList());
574+
for (int i = 0; i < results.size(); i++) {
575+
assertEquals(expected.get(i), results.get(i));
576+
}
577+
}
578+
579+
@Test
580+
public void testRuntimeMergeTableWithMergeRule() throws IOException {
581+
File in = testFolder.newFile(SnowflakeUtils.randomString(8) + ".csv");
582+
List<String> lines =
583+
Stream.of("c0:double,c1:string", "0.0,aaa", "0.1,bbb", "1.2,ccc")
584+
.collect(Collectors.toList());
585+
Files.write(in.toPath(), lines);
586+
587+
String targetTableName = generateTemporaryTableName();
588+
String targetTableFullName =
589+
String.format(
590+
"\"%s\".\"%s\".\"%s\"", TEST_SNOWFLAKE_DB, TEST_SNOWFLAKE_SCHEMA, targetTableName);
591+
runQuery(
592+
String.format("create table %s (\"c0\" FLOAT, \"c1\" STRING)", targetTableFullName),
593+
foreachResult(rs_ -> {}));
594+
runQuery(
595+
String.format("insert into %s values (0.0, 'ddd'), (1.3, 'eee');", targetTableFullName),
596+
foreachResult(rs_ -> {}));
597+
598+
final ConfigSource config =
599+
CONFIG_MAPPER_FACTORY
600+
.newConfigSource()
601+
.set("type", "snowflake")
602+
.set("user", TEST_SNOWFLAKE_USER)
603+
.set("password", TEST_SNOWFLAKE_PASSWORD)
604+
.set("host", TEST_SNOWFLAKE_HOST)
605+
.set("database", TEST_SNOWFLAKE_DB)
606+
.set("warehouse", TEST_SNOWFLAKE_WAREHOUSE)
607+
.set("schema", TEST_SNOWFLAKE_SCHEMA)
608+
.set("mode", "merge")
609+
.set("merge_keys", new ArrayList<String>(Arrays.asList("c0")))
610+
.set(
611+
"merge_rule", new ArrayList<String>(Arrays.asList("\"c1\" = T.\"c1\" || S.\"c1\"")))
612+
.set("table", targetTableName);
613+
embulk.runOutput(config, in.toPath());
614+
615+
runQuery(
616+
String.format("select count(*) from %s;", targetTableFullName),
617+
foreachResult(
618+
rs -> {
619+
assertEquals(4, rs.getInt(1));
620+
}));
621+
List<String> results = new ArrayList();
622+
runQuery(
623+
"select \"c0\",\"c1\" from " + targetTableFullName + " order by 1",
624+
foreachResult(
625+
rs -> {
626+
results.add(rs.getString(1) + "," + rs.getString(2));
627+
}));
628+
List<String> expected =
629+
Stream.of("0.0,dddaaa", "0.1,bbb", "1.2,ccc", "1.3,eee").collect(Collectors.toList());
630+
for (int i = 0; i < results.size(); i++) {
631+
assertEquals(expected.get(i), results.get(i));
632+
}
633+
}
634+
635+
@Test
636+
public void testRuntimeMergeTableWithoutMergeKey() throws IOException {
637+
File in = testFolder.newFile(SnowflakeUtils.randomString(8) + ".csv");
638+
List<String> lines =
639+
Stream.of("c0:double,c1:string", "0.0,aaa", "0.1,bbb", "1.2,ccc")
640+
.collect(Collectors.toList());
641+
Files.write(in.toPath(), lines);
642+
643+
String targetTableName = generateTemporaryTableName();
644+
String targetTableFullName =
645+
String.format(
646+
"\"%s\".\"%s\".\"%s\"", TEST_SNOWFLAKE_DB, TEST_SNOWFLAKE_SCHEMA, targetTableName);
647+
runQuery(
648+
String.format(
649+
"create table %s (\"c0\" FLOAT PRIMARY KEY, \"c1\" STRING)", targetTableFullName),
650+
foreachResult(rs_ -> {}));
651+
runQuery(
652+
String.format("insert into %s values (0.0, 'ddd'), (1.3, 'eee');", targetTableFullName),
653+
foreachResult(rs_ -> {}));
654+
655+
final ConfigSource config =
656+
CONFIG_MAPPER_FACTORY
657+
.newConfigSource()
658+
.set("type", "snowflake")
659+
.set("user", TEST_SNOWFLAKE_USER)
660+
.set("password", TEST_SNOWFLAKE_PASSWORD)
661+
.set("host", TEST_SNOWFLAKE_HOST)
662+
.set("database", TEST_SNOWFLAKE_DB)
663+
.set("warehouse", TEST_SNOWFLAKE_WAREHOUSE)
664+
.set("schema", TEST_SNOWFLAKE_SCHEMA)
665+
.set("mode", "merge")
666+
.set("table", targetTableName);
667+
embulk.runOutput(config, in.toPath());
668+
669+
runQuery(
670+
String.format("select count(*) from %s;", targetTableFullName),
671+
foreachResult(
672+
rs -> {
673+
assertEquals(4, rs.getInt(1));
674+
}));
675+
List<String> results = new ArrayList();
676+
runQuery(
677+
"select \"c0\",\"c1\" from " + targetTableFullName + " order by 1",
678+
foreachResult(
679+
rs -> {
680+
results.add(rs.getString(1) + "," + rs.getString(2));
681+
}));
682+
List<String> expected =
683+
Stream.of("0.0,aaa", "0.1,bbb", "1.2,ccc", "1.3,eee").collect(Collectors.toList());
684+
for (int i = 0; i < results.size(); i++) {
685+
assertEquals(expected.get(i), results.get(i));
686+
}
687+
}
688+
522689
@Ignore(
523690
"This test takes so long time because it needs to create more than 1000 tables, so ignored...")
524691
@Test(expected = Test.None.class /* no exception expected */)

0 commit comments

Comments
 (0)