diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java b/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java new file mode 100644 index 000000000..0a1f475c7 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlKindFromRel.java @@ -0,0 +1,240 @@ +package io.substrait.isthmus; + +import io.substrait.relation.Aggregate; +import io.substrait.relation.ConsistentPartitionWindow; +import io.substrait.relation.Cross; +import io.substrait.relation.EmptyScan; +import io.substrait.relation.Expand; +import io.substrait.relation.ExtensionDdl; +import io.substrait.relation.ExtensionLeaf; +import io.substrait.relation.ExtensionMulti; +import io.substrait.relation.ExtensionSingle; +import io.substrait.relation.ExtensionTable; +import io.substrait.relation.ExtensionWrite; +import io.substrait.relation.Fetch; +import io.substrait.relation.Filter; +import io.substrait.relation.Join; +import io.substrait.relation.LocalFiles; +import io.substrait.relation.NamedDdl; +import io.substrait.relation.NamedScan; +import io.substrait.relation.NamedUpdate; +import io.substrait.relation.NamedWrite; +import io.substrait.relation.Project; +import io.substrait.relation.RelVisitor; +import io.substrait.relation.Set; +import io.substrait.relation.Sort; +import io.substrait.relation.VirtualTableScan; +import io.substrait.relation.physical.HashJoin; +import io.substrait.relation.physical.MergeJoin; +import io.substrait.relation.physical.NestedLoopJoin; +import io.substrait.util.EmptyVisitationContext; +import org.apache.calcite.sql.SqlKind; + +/** + * A visitor to infer the general SqlKind from the root of a Substrait Rel tree. Note: This infers + * the general operation type, as the original SQL syntax is not preserved in the Substrait plan. + */ +public class SqlKindFromRel + implements RelVisitor { + + // Most common query operations map to SELECT. + private static final SqlKind QUERY_KIND = SqlKind.SELECT; + + @Override + public SqlKind visit(Aggregate aggregate, EmptyVisitationContext context) + throws RuntimeException { + + return QUERY_KIND; + } + + @Override + public SqlKind visit(EmptyScan emptyScan, EmptyVisitationContext context) + throws RuntimeException { + // An empty scan is typically the result of a query that returns no rows. + return QUERY_KIND; + } + + @Override + public SqlKind visit(Fetch fetch, EmptyVisitationContext context) throws RuntimeException { + return QUERY_KIND; + } + + @Override + public SqlKind visit(Filter filter, EmptyVisitationContext context) throws RuntimeException { + return QUERY_KIND; + } + + @Override + public SqlKind visit(Join join, EmptyVisitationContext context) throws RuntimeException { + return SqlKind.JOIN; + } + + @Override + public SqlKind visit(Set set, EmptyVisitationContext context) throws RuntimeException { + switch (set.getSetOp()) { + case UNION_ALL: + case UNION_DISTINCT: + return SqlKind.UNION; + case INTERSECTION_PRIMARY: + case INTERSECTION_MULTISET: + case INTERSECTION_MULTISET_ALL: + return SqlKind.INTERSECT; + case MINUS_PRIMARY: + case MINUS_PRIMARY_ALL: + case MINUS_MULTISET: + return SqlKind.EXCEPT; + case UNKNOWN: + default: + return SqlKind.OTHER; + } + } + + @Override + public SqlKind visit(NamedScan namedScan, EmptyVisitationContext context) + throws RuntimeException { + return QUERY_KIND; + } + + @Override + public SqlKind visit(LocalFiles localFiles, EmptyVisitationContext context) + throws RuntimeException { + return QUERY_KIND; + } + + @Override + public SqlKind visit(Project project, EmptyVisitationContext context) throws RuntimeException { + return QUERY_KIND; + } + + @Override + public SqlKind visit(Expand expand, EmptyVisitationContext context) throws RuntimeException { + return QUERY_KIND; + } + + @Override + public SqlKind visit(Sort sort, EmptyVisitationContext context) throws RuntimeException { + return SqlKind.ORDER_BY; + } + + @Override + public SqlKind visit(Cross cross, EmptyVisitationContext context) throws RuntimeException { + return SqlKind.JOIN; + } + + @Override + public SqlKind visit(VirtualTableScan virtualTableScan, EmptyVisitationContext context) + throws RuntimeException { + // A virtual table scan corresponds to a VALUES clause. + return SqlKind.VALUES; + } + + @Override + public SqlKind visit(ExtensionLeaf extensionLeaf, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER; + } + + @Override + public SqlKind visit(ExtensionSingle extensionSingle, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER; + } + + @Override + public SqlKind visit(ExtensionMulti extensionMulti, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER; + } + + @Override + public SqlKind visit(ExtensionTable extensionTable, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER; + } + + @Override + public SqlKind visit(HashJoin hashJoin, EmptyVisitationContext context) throws RuntimeException { + return SqlKind.JOIN; + } + + @Override + public SqlKind visit(MergeJoin mergeJoin, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.JOIN; + } + + @Override + public SqlKind visit(NestedLoopJoin nestedLoopJoin, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.JOIN; + } + + @Override + public SqlKind visit( + ConsistentPartitionWindow consistentPartitionWindow, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OVER; + } + + @Override + public SqlKind visit(NamedWrite write, EmptyVisitationContext context) throws RuntimeException { + switch (write.getOperation()) { + case INSERT: + return SqlKind.INSERT; + case DELETE: + return SqlKind.DELETE; + case UPDATE: + return SqlKind.UPDATE; + case CTAS: + return SqlKind.CREATE_TABLE; + default: + return SqlKind.OTHER; + } + } + + @Override + public SqlKind visit(ExtensionWrite write, EmptyVisitationContext context) + throws RuntimeException { + return SqlKind.OTHER_DDL; + } + + @Override + public SqlKind visit(NamedDdl ddl, EmptyVisitationContext context) throws RuntimeException { + switch (ddl.getOperation()) { + case CREATE: + case CREATE_OR_REPLACE: + if (ddl.getObject() == NamedDdl.DdlObject.TABLE) { + return SqlKind.CREATE_TABLE; + } else if (ddl.getObject() == NamedDdl.DdlObject.VIEW) { + return SqlKind.CREATE_VIEW; + } + break; + case DROP: + case DROP_IF_EXIST: + if (ddl.getObject() == NamedDdl.DdlObject.TABLE) { + return SqlKind.DROP_TABLE; + } else if (ddl.getObject() == NamedDdl.DdlObject.VIEW) { + return SqlKind.DROP_VIEW; + } + break; + case ALTER: + if (ddl.getObject() == NamedDdl.DdlObject.TABLE) { + return SqlKind.ALTER_TABLE; + } else if (ddl.getObject() == NamedDdl.DdlObject.VIEW) { + return SqlKind.ALTER_VIEW; + } + break; + } + return SqlKind.OTHER_DDL; + } + + @Override + public SqlKind visit(ExtensionDdl ddl, EmptyVisitationContext context) throws RuntimeException { + return SqlKind.OTHER_DDL; + } + + @Override + public SqlKind visit(NamedUpdate update, EmptyVisitationContext context) throws RuntimeException { + return SqlKind.UPDATE; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 801110c5f..aa8281fe3 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -10,12 +10,16 @@ import io.substrait.expression.Expression.SortDirection; import io.substrait.expression.FunctionArg; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.rel.CreateTable; +import io.substrait.isthmus.calcite.rel.CreateView; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.relation.AbstractDdlRel; import io.substrait.relation.AbstractRelVisitor; import io.substrait.relation.AbstractUpdate; +import io.substrait.relation.AbstractWriteRel; import io.substrait.relation.Aggregate; import io.substrait.relation.Cross; import io.substrait.relation.EmptyScan; @@ -24,6 +28,7 @@ import io.substrait.relation.Join; import io.substrait.relation.Join.JoinType; import io.substrait.relation.LocalFiles; +import io.substrait.relation.NamedDdl; import io.substrait.relation.NamedScan; import io.substrait.relation.NamedUpdate; import io.substrait.relation.NamedWrite; @@ -548,8 +553,29 @@ public RelNode visit(NamedUpdate update, Context context) { } @Override - public RelNode visit(VirtualTableScan virtualTableScan, Context context) { + public RelNode visit(NamedDdl namedDdl, Context context) { + if (namedDdl.getOperation() != AbstractDdlRel.DdlOp.CREATE + || namedDdl.getObject() != AbstractDdlRel.DdlObject.VIEW) { + throw new UnsupportedOperationException( + String.format( + "Can only handle NamedDdl with (%s, %s), given (%s, %s)", + AbstractDdlRel.DdlOp.CREATE, + AbstractDdlRel.DdlObject.VIEW, + namedDdl.getOperation(), + namedDdl.getObject())); + } + + if (namedDdl.getViewDefinition().isEmpty()) { + throw new IllegalArgumentException("NamedDdl view definition must be set"); + } + Rel viewDefinition = namedDdl.getViewDefinition().get(); + RelNode relNode = viewDefinition.accept(this, context); + return new CreateView(namedDdl.getNames(), relNode); + } + + @Override + public RelNode visit(VirtualTableScan virtualTableScan, Context context) { final RelDataType typeInfoOnly = typeConverter.toCalcite(typeFactory, virtualTableScan.getInitialSchema().struct()); @@ -584,15 +610,29 @@ public RelNode visit(VirtualTableScan virtualTableScan, Context context) { relBuilder.getCluster(), rowTypeWithNames, ImmutableList.copyOf(tuples)); } + private RelNode handleCreateTableAs(NamedWrite namedWrite, Context context) { + if (namedWrite.getCreateMode() != AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS + || namedWrite.getOutputMode() != AbstractWriteRel.OutputMode.NO_OUTPUT) { + throw new UnsupportedOperationException( + String.format( + "Can only handle CTAS NamedWrite with (%s, %s), given (%s, %s)", + AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS, + AbstractWriteRel.OutputMode.NO_OUTPUT, + namedWrite.getCreateMode(), + namedWrite.getOutputMode())); + } + + Rel input = namedWrite.getInput(); + RelNode relNode = input.accept(this, context); + return new CreateTable(namedWrite.getNames(), relNode); + } + @Override public RelNode visit(NamedWrite write, Context context) { RelNode input = write.getInput().accept(this, context); assert relBuilder.getRelOptSchema() != null; - final RelOptTable table = relBuilder.getRelOptSchema().getTableForMember(write.getNames()); - - if (table == null) { - throw new IllegalStateException("Table not found in Calcite catalog: " + write.getNames()); - } + final RelOptTable targetTable = + relBuilder.getRelOptSchema().getTableForMember(write.getNames()); TableModify.Operation operation; switch (write.getOperation()) { @@ -602,16 +642,20 @@ public RelNode visit(NamedWrite write, Context context) { case DELETE: operation = TableModify.Operation.DELETE; break; + case CTAS: + return handleCreateTableAs(write, context); default: throw new UnsupportedOperationException( - "Write operation '" - + write.getOperation() - + "' is not supported by the NamedWrite visitor. " - + "Check if a more specific relation type (e.g., NamedUpdate) should be used."); + String.format( + "NamedWrite with WriteOp %s cannot be converted to a Calcite RelNode. Consider using a more specific Rel (e.g NamedUpdate)", + write.getOperation())); } + // checked by validation + assert targetTable != null; + return LogicalTableModify.create( - table, + targetTable, (Prepare.CatalogReader) relBuilder.getRelOptSchema(), input, operation, diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index cc41e84a0..79449ee9e 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -6,6 +6,8 @@ import io.substrait.expression.ExpressionCreator; import io.substrait.expression.FieldReference; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.rel.CreateTable; +import io.substrait.isthmus.calcite.rel.CreateView; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.CallConverters; import io.substrait.isthmus.expression.LiteralConverter; @@ -13,6 +15,7 @@ import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.plan.Plan; +import io.substrait.relation.AbstractDdlRel; import io.substrait.relation.AbstractWriteRel; import io.substrait.relation.Aggregate; import io.substrait.relation.Aggregate.Grouping; @@ -25,6 +28,7 @@ import io.substrait.relation.ImmutableMeasure.Builder; import io.substrait.relation.Join; import io.substrait.relation.Join.JoinType; +import io.substrait.relation.NamedDdl; import io.substrait.relation.NamedScan; import io.substrait.relation.NamedUpdate; import io.substrait.relation.NamedWrite; @@ -49,6 +53,7 @@ import org.apache.calcite.rel.core.AggregateCall; import org.apache.calcite.rel.core.JoinRelType; import org.apache.calcite.rel.core.TableModify; +import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; @@ -445,8 +450,50 @@ public Rel visit(TableModify modify) { } } + private NamedStruct getSchema(final RelNode queryRelRoot) { + final RelDataType rowType = queryRelRoot.getRowType(); + return typeConverter.toNamedStruct(rowType); + } + + public Rel handleCreateTable(CreateTable createTable) { + RelNode input = createTable.getInput(); + Rel inputRel = apply(input); + NamedStruct schema = getSchema(input); + return NamedWrite.builder() + .input(inputRel) + .tableSchema(schema) + .operation(AbstractWriteRel.WriteOp.CTAS) + .createMode(AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS) + .outputMode(AbstractWriteRel.OutputMode.NO_OUTPUT) + .names(createTable.getTableName()) + .build(); + } + + public Rel handleCreateView(CreateView createView) { + RelNode input = createView.getInput(); + Rel inputRel = apply(input); + + final Expression.StructLiteral defaults = ExpressionCreator.struct(false); + + return NamedDdl.builder() + .viewDefinition(inputRel) + .tableSchema(getSchema(input)) + .tableDefaults(defaults) + .operation(AbstractDdlRel.DdlOp.CREATE) + .object(AbstractDdlRel.DdlObject.VIEW) + .names(createView.getViewName()) + .build(); + } + @Override public Rel visitOther(RelNode other) { + if (other instanceof CreateTable) { + return handleCreateTable((CreateTable) other); + + } else if (other instanceof CreateView) { + return handleCreateView((CreateView) other); + } + throw new UnsupportedOperationException("Unable to handle node: " + other); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index f0d3e2137..8dcfbf9e0 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -158,10 +158,11 @@ public RelRoot convert(Plan.Root root) { } return RelRoot.of(tableModify, tableRowType, kind); } - + SqlKindFromRel sqlKindFromRel = new SqlKindFromRel(); + SqlKind kind = root.getInput().accept(sqlKindFromRel, EmptyVisitationContext.INSTANCE); RelDataType inputRowType = convertedNode.getRowType(); RelDataType newRowType = renameFields(inputRowType, root.getNames(), 0).right; - return RelRoot.of(convertedNode, newRowType, SqlKind.SELECT); + return RelRoot.of(convertedNode, newRowType, kind); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateTable.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateTable.java new file mode 100644 index 000000000..66a030b8b --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateTable.java @@ -0,0 +1,43 @@ +package io.substrait.isthmus.calcite.rel; + +import java.util.List; +import org.apache.calcite.rel.AbstractRelNode; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.type.RelDataType; + +public class CreateTable extends AbstractRelNode { + + private final List tableName; + private final RelNode input; + + public CreateTable(List tableName, RelNode input) { + super(input.getCluster(), input.getTraitSet()); + + this.tableName = tableName; + this.input = input; + } + + @Override + protected RelDataType deriveRowType() { + return input.getRowType(); + } + + @Override + public RelWriter explainTerms(RelWriter pw) { + return super.explainTerms(pw).input("input", getInput()).item("tableName", getTableName()); + } + + @Override + public List getInputs() { + return List.of(input); + } + + public List getTableName() { + return tableName; + } + + public RelNode getInput() { + return input; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateView.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateView.java new file mode 100644 index 000000000..ef1e228cb --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/CreateView.java @@ -0,0 +1,41 @@ +package io.substrait.isthmus.calcite.rel; + +import java.util.List; +import org.apache.calcite.rel.AbstractRelNode; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelWriter; +import org.apache.calcite.rel.type.RelDataType; + +public class CreateView extends AbstractRelNode { + private final List viewName; + private final RelNode input; + + public CreateView(List viewName, RelNode input) { + super(input.getCluster(), input.getTraitSet()); + this.viewName = viewName; + this.input = input; + } + + @Override + protected RelDataType deriveRowType() { + return input.getRowType(); + } + + @Override + public RelWriter explainTerms(RelWriter pw) { + return super.explainTerms(pw).input("input", getInput()).item("viewName", getViewName()); + } + + @Override + public List getInputs() { + return List.of(input); + } + + public List getViewName() { + return viewName; + } + + public RelNode getInput() { + return input; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/DdlSqlToRelConverter.java b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/DdlSqlToRelConverter.java new file mode 100644 index 000000000..6a237b366 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/calcite/rel/DdlSqlToRelConverter.java @@ -0,0 +1,65 @@ +package io.substrait.isthmus.calcite.rel; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.rel.RelRoot; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.ddl.SqlCreateTable; +import org.apache.calcite.sql.ddl.SqlCreateView; +import org.apache.calcite.sql.util.SqlBasicVisitor; +import org.apache.calcite.sql2rel.SqlToRelConverter; + +public class DdlSqlToRelConverter extends SqlBasicVisitor { + + protected final Map, Function> ddlHandlers = + new ConcurrentHashMap<>(); + private final SqlToRelConverter converter; + + private Function findDdlHandler(final SqlCall call) { + Class currentClass = call.getClass(); + while (SqlCall.class.isAssignableFrom(currentClass)) { + final Function found = ddlHandlers.get(currentClass); + if (found != null) { + return found; + } + currentClass = currentClass.getSuperclass(); + } + return null; + } + + public DdlSqlToRelConverter(SqlToRelConverter converter) { + this.converter = converter; + + ddlHandlers.put(SqlCreateTable.class, sqlCall -> handleCreateTable((SqlCreateTable) sqlCall)); + ddlHandlers.put(SqlCreateView.class, sqlCall -> handleCreateView((SqlCreateView) sqlCall)); + } + + @Override + public RelRoot visit(SqlCall sqlCall) { + Function ddlHandler = findDdlHandler(sqlCall); + if (ddlHandler != null) { + return ddlHandler.apply(sqlCall); + } + return handleNonDdl(sqlCall); + } + + protected RelRoot handleNonDdl(final SqlNode sqlNode) { + return converter.convertQuery(sqlNode, true, true); + } + + protected RelRoot handleCreateTable(final SqlCreateTable sqlCreateTable) { + if (sqlCreateTable.query == null) { + throw new IllegalArgumentException("Only create table as select statements are supported"); + } + final RelNode input = converter.convertQuery(sqlCreateTable.query, true, true).rel; + return RelRoot.of(new CreateTable(sqlCreateTable.name.names, input), sqlCreateTable.getKind()); + } + + protected RelRoot handleCreateView(final SqlCreateView sqlCreateView) { + final RelNode input = converter.convertQuery(sqlCreateView.query, true, true).rel; + return RelRoot.of(new CreateTable(sqlCreateView.name.names, input), sqlCreateView.getKind()); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlStatementParser.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlStatementParser.java index 4376a3eba..0f7891b5b 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlStatementParser.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlStatementParser.java @@ -5,6 +5,7 @@ import org.apache.calcite.sql.SqlNode; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl; import org.apache.calcite.sql.validate.SqlConformanceEnum; /** @@ -18,7 +19,8 @@ public class SubstraitSqlStatementParser { // TODO: switch to Casing.UNCHANGED .withUnquotedCasing(Casing.TO_UPPER) // use LENIENT conformance to allow for parsing a wide variety of dialects - .withConformance(SqlConformanceEnum.LENIENT); + .withConformance(SqlConformanceEnum.LENIENT) + .withParserFactory(SqlDdlParserImpl.FACTORY); /** * Parse one or more SQL statements to a list of {@link SqlNode}s. diff --git a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java index 499f33f1d..ed34c57e2 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/sql/SubstraitSqlToCalcite.java @@ -1,6 +1,7 @@ package io.substrait.isthmus.sql; import io.substrait.isthmus.SubstraitTypeSystem; +import io.substrait.isthmus.calcite.rel.DdlSqlToRelConverter; import java.util.List; import java.util.stream.Collectors; import org.apache.calcite.jdbc.JavaTypeFactoryImpl; @@ -132,10 +133,17 @@ static List convert( boolean needsValidation = true; // query is the root of the tree boolean top = true; + DdlSqlToRelConverter ddlSqlToRelConverter = new DdlSqlToRelConverter(converter); return sqlNodes.stream() .map( - sqlNode -> - removeRedundantProjects(converter.convertQuery(sqlNode, needsValidation, top))) + sqlNode -> { + RelRoot relRoot = sqlNode.accept(ddlSqlToRelConverter); + if (relRoot == null) { + relRoot = + removeRedundantProjects(converter.convertQuery(sqlNode, needsValidation, top)); + } + return relRoot; + }) .collect(Collectors.toList()); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/DdlRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/DdlRoundtripTest.java new file mode 100644 index 000000000..1f14d9459 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/DdlRoundtripTest.java @@ -0,0 +1,29 @@ +package io.substrait.isthmus; + +import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.sql.parser.SqlParseException; +import org.junit.jupiter.api.Test; + +class DdlRoundtripTest extends PlanTestBase { + final Prepare.CatalogReader catalogReader = + SubstraitCreateStatementParser.processCreateStatementsToCatalog( + "create table src1 (intcol int, charcol varchar(10))", + "create table src2 (intcol int, charcol varchar(10))"); + + public DdlRoundtripTest() throws SqlParseException { + super(); + } + + @Test + void testCreateTable() throws Exception { + String sql = "create table dst1 as select * from src1"; + assertFullRoundTripWithIdentityProjectionWorkaround(sql, catalogReader); + } + + @Test + void testCreateView() throws Exception { + String sql = "create view dst1 as select * from src1"; + assertFullRoundTripWithIdentityProjectionWorkaround(sql, catalogReader); + } +} diff --git a/isthmus/src/test/java/io/substrait/isthmus/DmlRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/DmlRoundtripTest.java index e1408f658..71db7d28f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/DmlRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/DmlRoundtripTest.java @@ -16,18 +16,21 @@ public DmlRoundtripTest() throws SqlParseException {} @Test void testDelete() throws SqlParseException { - assertFullRoundTrip("delete from src1 where intcol=10", catalogReader); + assertFullRoundTripWithIdentityProjectionWorkaround( + "delete from src1 where intcol=10", catalogReader); } @Test void testUpdate() throws SqlParseException { - assertFullRoundTrip("update src1 set intcol=10 where charcol='a'", catalogReader); + assertFullRoundTripWithIdentityProjectionWorkaround( + "update src1 set intcol=10 where charcol='a'", catalogReader); } @Test void testInsert() throws SqlParseException { - assertFullRoundTrip("insert into src1 (intcol, charcol) values (1,'a'); ", catalogReader); - assertFullRoundTrip( + assertFullRoundTripWithIdentityProjectionWorkaround( + "insert into src1 (intcol, charcol) values (1,'a'); ", catalogReader); + assertFullRoundTripWithIdentityProjectionWorkaround( "insert into src1 (intcol, charcol) select intcol,charcol from src2;", catalogReader); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index d462ea05d..d4671736d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -27,6 +27,7 @@ void preserveNamesFromSql() throws Exception { org.apache.calcite.rel.RelRoot calciteRelRoot1 = SubstraitSqlToCalcite.convertQuery(query, catalogReader); + assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames()); io.substrait.plan.Plan.Root substraitRelRoot = diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 56173f1ea..aba04f047 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -208,6 +208,75 @@ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalo assertEquals(root1, root3); } + /** + * Verifies that the given query can be converted from its Calcite representation to Substrait + * proto and back. Due to the various ways in which Calcite plans are produced, some plans contain + * identity projections and others do not. Fully removing this behaviour is quite tricky. As a + * workaround, this method prepares the plan such that identity projections are removed. + * + *

In the long-term, we should work to remove this test method. + * + *

Preparation: + * SQL -> Calcite 0 -> Substrait POJO 0 -> Substrait Proto 0 -> Substrait POJO 1 -> Calcite 1 + * this code also checks that: Main cycle: + * + *

    + *
  • Substrait POJO 0 == Substrait POJO 1 + *
+ * + * Calcite 1 -> Substrait POJO 2 -> Substrait Proto 2 -> Substrait POJO 3 -> Calcite 2 -> + * Substrait POJO 4 + * + *
    + *
  • Substrait POJO 2 == Substrait POJO 4 + *
+ */ + protected void assertFullRoundTripWithIdentityProjectionWorkaround( + String sqlQuery, Prepare.CatalogReader catalogReader) throws SqlParseException { + ExtensionCollector extensionCollector = new ExtensionCollector(); + + // Preparation + // SQL -> Calcite 0 + RelRoot calcite0 = SubstraitSqlToCalcite.convertQuery(sqlQuery, catalogReader); + + // Calcite 0 -> Substrait POJO 0 + Plan.Root root0 = SubstraitRelVisitor.convert(calcite0, extensions); + + // Substrait POJO 0 -> Substrait Proto 0 + io.substrait.proto.RelRoot proto0 = new RelProtoConverter(extensionCollector).toProto(root0); + + // Substrait Proto -> Substrait POJO 1 + Plan.Root root1 = new ProtoRelConverter(extensionCollector, extensions).from(proto0); + + // Verify that POJOs are the same + assertEquals(root0, root1); + + final SubstraitToCalcite substraitToCalcite = + new SubstraitToCalcite(extensions, typeFactory, catalogReader); + + // Substrait POJO 1 -> Calcite 1 + RelRoot calcite1 = substraitToCalcite.convert(root1); + + // End Preparation + + // Calcite 1 -> Substrait POJO 2 + Plan.Root root2 = SubstraitRelVisitor.convert(calcite1, extensions); + + // Substrait POJO 2 -> Substrait Proto 1 + io.substrait.proto.RelRoot proto1 = new RelProtoConverter(extensionCollector).toProto(root2); + + // Substrait Proto1 -> Substrait POJO 3 + Plan.Root root3 = new ProtoRelConverter(extensionCollector, extensions).from(proto1); + + // Substrait POJO 3 -> Calcite 2 + RelRoot calcite2 = substraitToCalcite.convert(root3); + // Calcite 2 -> Substrait POJO 4 + Plan.Root root4 = SubstraitRelVisitor.convert(calcite2, extensions); + + // Verify that POJOs are the same + assertEquals(root2, root4); + } + /** * Verifies that the given POJO can be converted: *