From 74f3e5f7598acf0985f749891ad7448ad1a3b3f0 Mon Sep 17 00:00:00 2001
From: Anton Zorin
Date: Tue, 8 Jul 2025 10:25:12 +0200
Subject: [PATCH] feat(core,isthmus): add dml support to SqlToSubstrait

Convert Calcite TableModify nodes (INSERT, DELETE, UPDATE) to Substrait
NamedWrite / NamedUpdate relations and back:

- RelCopyOnWriteVisitor: implement visit(NamedWrite) and visit(NamedUpdate)
  instead of throwing UnsupportedOperationException.
- SqlToSubstrait: add CoreRules.PROJECT_REMOVE to the HepProgram.
- SubstraitRelNodeConverter: add visitors for NamedUpdate, VirtualTableScan
  and NamedWrite.
- SubstraitRelVisitor: map TableModify to NamedWrite / NamedUpdate.
- SubstraitToCalcite: accept an optional Prepare.CatalogReader and return a
  RelRoot with the matching SqlKind when the converted root is a TableModify.
- Add DmlRoundtripTest and pass the catalog reader in assertFullRoundTrip.

---
 .../relation/RelCopyOnWriteVisitor.java       |  40 +++++-
 .../io/substrait/isthmus/SqlToSubstrait.java  |   6 +-
 .../isthmus/SubstraitRelNodeConverter.java    | 121 ++++++++++++++++++
 .../isthmus/SubstraitRelVisitor.java          |  76 ++++++++++-
 .../substrait/isthmus/SubstraitToCalcite.java |  61 +++++++--
 .../substrait/isthmus/DmlRoundtripTest.java   |  33 +++++
 .../io/substrait/isthmus/PlanTestBase.java    |   5 +-
 7 files changed, 327 insertions(+), 15 deletions(-)
 create mode 100644 isthmus/src/test/java/io/substrait/isthmus/DmlRoundtripTest.java

diff --git a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java
index f5a3e8081..14a3b9fda 100644
--- a/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java
+++ b/core/src/main/java/io/substrait/relation/RelCopyOnWriteVisitor.java
@@ -215,7 +215,15 @@ public Optional<Rel> visit(Expand expand, EmptyVisitationContext context) throws
 
   @Override
   public Optional<Rel> visit(NamedWrite write, EmptyVisitationContext context) throws E {
-    throw new UnsupportedOperationException();
+
+    Optional<Rel> input = write.getInput().accept(this, context);
+
+    if (allEmpty(input)) {
+      return Optional.empty();
+    }
+
+    return Optional.of(
+        NamedWrite.builder().from(write).input(input.orElse(write.getInput())).build());
   }
 
   @Override
@@ -233,9 +241,37 @@ public Optional<Rel> visit(ExtensionDdl ddl, EmptyVisitationContext context) thr
     throw new UnsupportedOperationException();
   }
 
+  protected Optional<NamedUpdate.TransformExpression> visitTransformExpression(
+      NamedUpdate.TransformExpression transform, EmptyVisitationContext context) throws E {
+    return transform
+        .getTransformation()
+        .accept(getExpressionCopyOnWriteVisitor(), context)
+        .map(
+            expr ->
+                NamedUpdate.TransformExpression.builder()
+                    .from(transform)
+                    .transformation(expr)
+                    .build());
+  }
+
   @Override
   public Optional<Rel> visit(NamedUpdate update, EmptyVisitationContext context) throws E {
-    throw new UnsupportedOperationException();
+    Optional<Expression> condition =
+        update.getCondition().accept(getExpressionCopyOnWriteVisitor(), context);
+
+    Optional<List<NamedUpdate.TransformExpression>> transformations =
+        transformList(update.getTransformations(), context, this::visitTransformExpression);
+
+    if (allEmpty(condition, transformations)) {
+      return Optional.empty();
+    }
+
+    return Optional.of(
+        NamedUpdate.builder()
+            .from(update)
+            .condition(condition.orElse(update.getCondition()))
+            .transformations(transformations.orElse(update.getTransformations()))
+            .build());
   }
 
   @Override
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
index 928df2c9d..dacbd55c7 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
@@ -10,6 +10,7 @@
 import org.apache.calcite.plan.hep.HepProgram;
 import org.apache.calcite.prepare.Prepare;
 import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.rules.CoreRules;
 import org.apache.calcite.sql.SqlNode;
 import org.apache.calcite.sql.parser.SqlParseException;
 import org.apache.calcite.sql.parser.SqlParser;
@@ -33,7 +34,6 @@ public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlP
     return executeInner(sql, validator, catalogReader);
   }
 
-  // Package protected for testing
   List<RelRoot> sqlToRelNode(String sql, Prepare.CatalogReader catalogReader)
       throws SqlParseException {
     SqlValidator validator = new SubstraitSqlValidator(catalogReader);
@@ -86,7 +86,9 @@ SqlToRelConverter createSqlToRelConverter(
   static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
     RelRoot root = converter.convertQuery(parsed, true, true);
     {
-      var program = HepProgram.builder().build();
+      // RelBuilder seems to implicitly use the rule below,
+      // need to add to avoid discrepancies in assertFullRoundTrip
+      var program = HepProgram.builder().addRuleInstance(CoreRules.PROJECT_REMOVE).build();
       HepPlanner hepPlanner = new HepPlanner(program);
       hepPlanner.setRoot(root.rel);
       root = root.withRel(hepPlanner.findBestExp());
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
index fd7673100..bbbfc76db 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java
@@ -11,6 +11,7 @@
 import io.substrait.isthmus.expression.ScalarFunctionConverter;
 import io.substrait.isthmus.expression.WindowFunctionConverter;
 import io.substrait.relation.AbstractRelVisitor;
+import io.substrait.relation.AbstractUpdate;
 import io.substrait.relation.Aggregate;
 import io.substrait.relation.Cross;
 import io.substrait.relation.EmptyScan;
@@ -20,10 +21,14 @@
 import io.substrait.relation.Join.JoinType;
 import io.substrait.relation.LocalFiles;
 import io.substrait.relation.NamedScan;
+import io.substrait.relation.NamedUpdate;
+import io.substrait.relation.NamedWrite;
 import io.substrait.relation.Project;
 import io.substrait.relation.Rel;
 import io.substrait.relation.Set;
 import io.substrait.relation.Sort;
+import io.substrait.relation.VirtualTableScan;
+import io.substrait.type.NamedStruct;
 import io.substrait.util.VisitationContext;
 import java.util.ArrayList;
 import java.util.Collection;
@@ -34,6 +39,7 @@
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
 import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptTable;
 import org.apache.calcite.plan.RelTraitDef;
 import org.apache.calcite.prepare.Prepare;
 import org.apache.calcite.rel.RelCollation;
@@ -42,12 +48,15 @@
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.TableModify;
+import org.apache.calcite.rel.logical.LogicalTableModify;
 import org.apache.calcite.rel.logical.LogicalValues;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rel.type.RelDataTypeField;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexInputRef;
+import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.rex.RexNode;
 import org.apache.calcite.rex.RexSlot;
 import org.apache.calcite.sql.SqlAggFunction;
@@ -473,6 +482,118 @@ private RelFieldCollation toRelFieldCollation(Expression.SortField sortField, Co
     return new RelFieldCollation(fieldIndex, fieldDirection, nullDirection);
   }
 
+  @Override
+  public RelNode visit(NamedUpdate update, Context context) {
+    relBuilder.scan(update.getNames());
+    RexNode condition = update.getCondition().accept(expressionRexConverter, context);
+    relBuilder.filter(condition);
+    RelNode inputForModify = relBuilder.build();
+
+    NamedStruct tableSchema = update.getTableSchema();
+    List<String> fieldNames = tableSchema.names();
+
+    List<String> updateColumnList = new ArrayList<>();
+    List<RexNode> sourceExpressionList = new ArrayList<>();
+
+    for (AbstractUpdate.TransformExpression transform : update.getTransformations()) {
+
+      updateColumnList.add(fieldNames.get(transform.getColumnTarget()));
+      sourceExpressionList.add(
+          transform.getTransformation().accept(expressionRexConverter, context));
+    }
+
+    assert relBuilder.getRelOptSchema() != null;
+    final RelOptTable table = relBuilder.getRelOptSchema().getTableForMember(update.getNames());
+
+    if (table == null) {
+      throw new IllegalStateException("Table not found in Calcite catalog: " + update.getNames());
+    }
+    final Prepare.CatalogReader catalogReader = (Prepare.CatalogReader) table.getRelOptSchema();
+
+    assert catalogReader != null;
+    return LogicalTableModify.create(
+        table,
+        catalogReader,
+        inputForModify,
+        TableModify.Operation.UPDATE,
+        updateColumnList,
+        sourceExpressionList,
+        false);
+  }
+
+  @Override
+  public RelNode visit(VirtualTableScan virtualTableScan, Context context) {
+
+    final RelDataType typeInfoOnly =
+        typeConverter.toCalcite(typeFactory, virtualTableScan.getInitialSchema().struct());
+
+    final List<String> correctFieldNames = virtualTableScan.getInitialSchema().names();
+
+    final List<RelDataType> fieldTypes =
+        typeInfoOnly.getFieldList().stream()
+            .map(RelDataTypeField::getType)
+            .collect(Collectors.toList());
+
+    final RelDataType rowTypeWithNames =
+        typeFactory.createStructType(fieldTypes, correctFieldNames);
+
+    final List<ImmutableList<RexLiteral>> tuples = new ArrayList<>();
+    for (final Expression.StructLiteral row : virtualTableScan.getRows()) {
+      final List<RexLiteral> rexRow = new ArrayList<>();
+      for (final Expression.Literal literal : row.fields()) {
+        final RexNode rexNode = literal.accept(expressionRexConverter, context);
+        if (rexNode instanceof RexLiteral) {
+          final RexLiteral rexLiteral = (RexLiteral) rexNode;
+          rexRow.add(rexLiteral);
+        } else {
+          throw new UnsupportedOperationException(
+              "VirtualTableScan only supports literal values, found: "
+                  + rexNode.getClass().getName());
+        }
+      }
+      tuples.add(ImmutableList.copyOf(rexRow));
+    }
+
+    return LogicalValues.create(
+        relBuilder.getCluster(), rowTypeWithNames, ImmutableList.copyOf(tuples));
+  }
+
+  @Override
+  public RelNode visit(NamedWrite write, Context context) {
+    RelNode input = write.getInput().accept(this, context);
+    assert relBuilder.getRelOptSchema() != null;
+    final RelOptTable table = relBuilder.getRelOptSchema().getTableForMember(write.getNames());
+
+    if (table == null) {
+      throw new IllegalStateException("Table not found in Calcite catalog: " + write.getNames());
+    }
+
+    TableModify.Operation operation;
+    switch (write.getOperation()) {
+      case INSERT:
+        operation = TableModify.Operation.INSERT;
+        break;
+      case DELETE:
+        operation = TableModify.Operation.DELETE;
+        break;
+      default:
+        throw new UnsupportedOperationException(
+            "Write operation '"
+                + write.getOperation()
+                + "' is not supported by the NamedWrite visitor. "
+                + "Check if a more specific relation type (e.g., NamedUpdate) should be used.");
+    }
+
+    return LogicalTableModify.create(
+        table,
+        (Prepare.CatalogReader) relBuilder.getRelOptSchema(),
+        input,
+        operation,
+        null,
+        null,
+        false);
+  }
+
   @Override
   public RelNode visitFallback(Rel rel, Context context) throws RuntimeException {
     throw new UnsupportedOperationException(
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
index c8ff01ee3..220b9fb2e 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
@@ -11,6 +11,7 @@
 import io.substrait.isthmus.expression.ScalarFunctionConverter;
 import io.substrait.isthmus.expression.WindowFunctionConverter;
 import io.substrait.plan.Plan;
+import io.substrait.relation.AbstractWriteRel;
 import io.substrait.relation.Aggregate;
 import io.substrait.relation.Cross;
 import io.substrait.relation.EmptyScan;
@@ -18,6 +19,8 @@
 import io.substrait.relation.Filter;
 import io.substrait.relation.Join;
 import io.substrait.relation.NamedScan;
+import io.substrait.relation.NamedUpdate;
+import io.substrait.relation.NamedWrite;
 import io.substrait.relation.Project;
 import io.substrait.relation.Rel;
 import io.substrait.relation.Set;
@@ -37,6 +40,7 @@
 import org.apache.calcite.rel.RelRoot;
 import org.apache.calcite.rel.core.AggregateCall;
 import org.apache.calcite.rel.core.JoinRelType;
+import org.apache.calcite.rel.core.TableModify;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexFieldAccess;
@@ -374,8 +378,76 @@ public Rel visit(org.apache.calcite.rel.core.Exchange exchange) {
   }
 
   @Override
-  public Rel visit(org.apache.calcite.rel.core.TableModify modify) {
-    return super.visit(modify);
+  public Rel visit(TableModify modify) {
+    switch (modify.getOperation()) {
+      case INSERT:
+      case DELETE:
+        {
+          final Rel input = apply(modify.getInput());
+          final AbstractWriteRel.WriteOp op =
+              modify.getOperation() == TableModify.Operation.INSERT
+                  ? AbstractWriteRel.WriteOp.INSERT
+                  : AbstractWriteRel.WriteOp.DELETE;
+
+          assert modify.getTable() != null;
+          return NamedWrite.builder()
+              .input(input)
+              .tableSchema(typeConverter.toNamedStruct(modify.getTable().getRowType()))
+              .operation(op)
+              .createMode(AbstractWriteRel.CreateMode.UNSPECIFIED)
+              .outputMode(AbstractWriteRel.OutputMode.MODIFIED_RECORDS)
+              .names(modify.getTable().getQualifiedName())
+              .build();
+        }
+
+      case UPDATE:
+        {
+          assert modify.getTable() != null;
+
+          Expression condition;
+          if (modify.getInput() instanceof org.apache.calcite.rel.core.Filter) {
+            org.apache.calcite.rel.core.Filter filter =
+                (org.apache.calcite.rel.core.Filter) modify.getInput();
+            condition = toExpression(filter.getCondition());
+          } else {
+            condition = Expression.BoolLiteral.builder().nullable(false).value(true).build();
+          }
+
+          List<String> updateColumnNames = modify.getUpdateColumnList();
+          List<RexNode> sourceExpressions = modify.getSourceExpressionList();
+          List<String> allTableColumnNames = modify.getTable().getRowType().getFieldNames();
+          List<NamedUpdate.TransformExpression> transformations = new ArrayList<>();
+
+          for (int i = 0; i < updateColumnNames.size(); i++) {
+            String colName = updateColumnNames.get(i);
+            RexNode rexExpr = sourceExpressions.get(i);
+
+            int columnIndex = allTableColumnNames.indexOf(colName);
+            if (columnIndex == -1) {
+              throw new IllegalStateException(
+                  "Updated column '" + colName + "' not found in table schema.");
+            }
+
+            Expression substraitExpr = toExpression(rexExpr);
+
+            transformations.add(
+                NamedUpdate.TransformExpression.builder()
+                    .columnTarget(columnIndex)
+                    .transformation(substraitExpr)
+                    .build());
+          }
+
+          return NamedUpdate.builder()
+              .tableSchema(typeConverter.toNamedStruct(modify.getTable().getRowType()))
+              .names(modify.getTable().getQualifiedName())
+              .condition(condition)
+              .transformations(transformations)
+              .build();
+        }
+
+      default:
+        return super.visit(modify);
+    }
   }
 
   @Override
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
index 375f3fb9d..d584c4e46 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java
@@ -14,8 +14,10 @@
 import java.util.Map;
 import java.util.Optional;
 import org.apache.calcite.jdbc.CalciteSchema;
+import org.apache.calcite.prepare.Prepare;
 import org.apache.calcite.rel.RelNode;
 import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.core.TableModify;
 import org.apache.calcite.rel.type.RelDataType;
 import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rel.type.RelDataTypeField;
@@ -35,19 +37,36 @@ public class SubstraitToCalcite {
   protected final SimpleExtension.ExtensionCollection extensions;
   protected final RelDataTypeFactory typeFactory;
   protected final TypeConverter typeConverter;
+  protected final Prepare.CatalogReader catalogReader;
 
   public SubstraitToCalcite(
       SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) {
-    this(extensions, typeFactory, TypeConverter.DEFAULT);
+    this(extensions, typeFactory, TypeConverter.DEFAULT, null);
+  }
+
+  public SubstraitToCalcite(
+      SimpleExtension.ExtensionCollection extensions,
+      RelDataTypeFactory typeFactory,
+      Prepare.CatalogReader catalogReader) {
+    this(extensions, typeFactory, TypeConverter.DEFAULT, catalogReader);
   }
 
   public SubstraitToCalcite(
       SimpleExtension.ExtensionCollection extensions,
       RelDataTypeFactory typeFactory,
       TypeConverter typeConverter) {
+    this(extensions, typeFactory, typeConverter, null);
+  }
+
+  public SubstraitToCalcite(
+      SimpleExtension.ExtensionCollection extensions,
+      RelDataTypeFactory typeFactory,
+      TypeConverter typeConverter,
+      Prepare.CatalogReader catalogReader) {
     this.extensions = extensions;
     this.typeFactory = typeFactory;
     this.typeConverter = typeConverter;
+    this.catalogReader = catalogReader;
   }
 
   /**
@@ -89,9 +108,15 @@ protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder r
    * @return {@link RelNode}
    */
   public RelNode convert(Rel rel) {
-    CalciteSchema rootSchema = toSchema(rel);
-    RelBuilder relBuilder = createRelBuilder(rootSchema);
+    RelBuilder relBuilder;
+    if (catalogReader != null) {
+      relBuilder = createRelBuilder(catalogReader.getRootSchema());
+    } else {
+      CalciteSchema rootSchema = toSchema(rel);
+      relBuilder = createRelBuilder(rootSchema);
+    }
     SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder);
+
     return rel.accept(converter, Context.newContext());
   }
 
@@ -110,13 +135,33 @@ public RelNode convert(Rel rel) {
    * @return {@link RelRoot}
    */
   public RelRoot convert(Plan.Root root) {
-    RelNode input = convert(root.getInput());
-    RelDataType inputRowType = input.getRowType();
+    RelNode convertedNode = convert(root.getInput());
+
+    if (convertedNode instanceof TableModify) {
+      final TableModify tableModify = (TableModify) convertedNode;
+      final RelDataType tableRowType = tableModify.getTable().getRowType();
+      final SqlKind kind;
+
+      switch (tableModify.getOperation()) {
+        case INSERT:
+          kind = SqlKind.INSERT;
+          break;
+        case UPDATE:
+          kind = SqlKind.UPDATE;
+          break;
+        case DELETE:
+          kind = SqlKind.DELETE;
+          break;
+        default:
+          throw new IllegalArgumentException(
+              "Unsupported table modify operation: " + tableModify.getOperation());
+      }
+      return RelRoot.of(tableModify, tableRowType, kind);
+    }
+    RelDataType inputRowType = convertedNode.getRowType();
 
     RelDataType newRowType = renameFields(inputRowType, root.getNames(), 0).right;
-    RelRoot calciteRoot = RelRoot.of(input, newRowType, SqlKind.SELECT);
-
-    return calciteRoot;
+    return RelRoot.of(convertedNode, newRowType, SqlKind.SELECT);
   }
 
   /**
diff --git a/isthmus/src/test/java/io/substrait/isthmus/DmlRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/DmlRoundtripTest.java
new file mode 100644
index 000000000..e1408f658
--- /dev/null
+++ b/isthmus/src/test/java/io/substrait/isthmus/DmlRoundtripTest.java
@@ -0,0 +1,33 @@
+package io.substrait.isthmus;
+
+import io.substrait.isthmus.sql.SubstraitCreateStatementParser;
+import org.apache.calcite.prepare.Prepare;
+import org.apache.calcite.sql.parser.SqlParseException;
+import org.junit.jupiter.api.Test;
+
+public class DmlRoundtripTest extends PlanTestBase {
+
+  final Prepare.CatalogReader catalogReader =
+      SubstraitCreateStatementParser.processCreateStatementsToCatalog(
+          "create table src1 (intcol int, charcol varchar(10))",
+          "create table src2 (intcol int, charcol varchar(10))");
+
+  public DmlRoundtripTest() throws SqlParseException {}
+
+  @Test
+  void testDelete() throws SqlParseException {
+    assertFullRoundTrip("delete from src1 where intcol=10", catalogReader);
+  }
+
+  @Test
+  void testUpdate() throws SqlParseException {
+    assertFullRoundTrip("update src1 set intcol=10 where charcol='a'", catalogReader);
+  }
+
+  @Test
+  void testInsert() throws SqlParseException {
+    assertFullRoundTrip("insert into src1 (intcol, charcol) values (1,'a'); ", catalogReader);
+    assertFullRoundTrip(
+        "insert into src1 (intcol, charcol) select intcol,charcol from src2;", catalogReader);
+  }
+}
diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
index 644123675..f85d0b67b 100644
--- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
+++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java
@@ -196,7 +196,10 @@ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalo
     assertEquals(pojo1, pojo2);
 
     // Substrait POJO 2 -> Calcite 2
-    RelRoot calcite2 = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2);
+    final SubstraitToCalcite substraitToCalcite =
+        new SubstraitToCalcite(extensions, typeFactory, catalogReader);
+
+    RelRoot calcite2 = substraitToCalcite.convert(pojo2);
     // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to
     // do so
     assertNotNull(calcite2);