From b9a1c61d2c0987cf93376f62fae55998462e1c68 Mon Sep 17 00:00:00 2001
From: Anton Zorin
Date: Tue, 1 Jul 2025 14:07:55 +0200
Subject: [PATCH] add support of ddl and dml to SqlToSubstrait

---
 .../io/substrait/isthmus/SqlToSubstrait.java  |  75 +++++++--
 .../isthmus/SubstraitRelVisitor.java          |  26 ++-
 .../isthmus/expression/DdlRelBuilder.java     | 156 ++++++++++++++++++
 .../substrait/isthmus/SqlToSubstraitTest.java |  29 ++++
 .../sqltosubstrait/sqltosubstrait.sql         |   6 +
 5 files changed, 282 insertions(+), 10 deletions(-)
 create mode 100644 isthmus/src/main/java/io/substrait/isthmus/expression/DdlRelBuilder.java
 create mode 100644 isthmus/src/test/java/io/substrait/isthmus/SqlToSubstraitTest.java
 create mode 100644 isthmus/src/test/resources/sqltosubstrait/sqltosubstrait.sql

diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
index 42b15a4b5..93ef6f0e8 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
@@ -1,9 +1,12 @@
 package io.substrait.isthmus;
 
 import com.google.common.annotations.VisibleForTesting;
+import io.substrait.isthmus.expression.DdlRelBuilder;
+import io.substrait.plan.ImmutablePlan;
 import io.substrait.plan.Plan.Version;
 import io.substrait.plan.PlanProtoConverter;
 import io.substrait.proto.Plan;
+import java.util.ArrayList;
 import java.util.List;
 import org.apache.calcite.plan.hep.HepPlanner;
 import org.apache.calcite.plan.hep.HepProgram;
@@ -12,6 +15,7 @@
 import org.apache.calcite.rel.RelRoot;
 import org.apache.calcite.schema.Schema;
 import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNodeList;
 import org.apache.calcite.sql.parser.SqlParseException;
 import org.apache.calcite.sql.parser.SqlParser;
 import org.apache.calcite.sql.validate.SqlValidator;
@@ -53,32 +57,85 @@ List<RelRoot> sqlToRelNode(String sql, List<String> tables) throws SqlParseExcep
     return sqlToRelNode(sql, validator, catalogReader);
   }
 
+  /**
+   * Converts every statement of a SQL script (queries, DML and the supported DDL) into
+   * Substrait plan roots.
+   *
+   * @param sql one or more SQL statements
+   * @param tables {@code CREATE TABLE} statements describing the tables used by {@code sql}
+   * @return one plan root per statement, in the order the statements appear in {@code sql}
+   * @throws SqlParseException if a statement cannot be parsed
+   */
+  List<io.substrait.plan.Plan.Root> sqlToPlanNodes(String sql, List<String> tables)
+      throws SqlParseException {
+    Prepare.CatalogReader catalogReader = registerCreateTables(tables);
+    SqlValidator validator = Validator.create(factory, catalogReader, SqlValidator.Config.DEFAULT);
+    return sqlToPlanNodes(sql, validator, catalogReader, io.substrait.plan.Plan.builder());
+  }
+
   private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
       throws SqlParseException {
     var builder = io.substrait.plan.Plan.builder();
     builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build());
 
     // TODO: consider case in which one sql passes conversion while others don't
-    sqlToRelNode(sql, validator, catalogReader).stream()
-        .map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard))
-        .forEach(root -> builder.addRoots(root));
-
+    sqlToPlanNodes(sql, validator, catalogReader, builder);
     PlanProtoConverter planToProto = new PlanProtoConverter();
 
     return planToProto.toProto(builder.build());
   }
 
+  /**
+   * Parses {@code sql} and converts each statement, in script order, into a plan root that is
+   * also added to {@code builder}.
+   *
+   * <p>Statements that {@link DdlRelBuilder} has a handler for are converted there; all other
+   * statements go through the regular relational conversion.
+   */
+  private List<io.substrait.plan.Plan.Root> sqlToPlanNodes(
+      String sql,
+      SqlValidator validator,
+      Prepare.CatalogReader catalogReader,
+      ImmutablePlan.Builder builder)
+      throws SqlParseException {
+    SqlParser parser = SqlParser.create(sql, parserConfig);
+    var parsedList = parser.parseStmtList();
+    SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
+    final DdlRelBuilder ddlRelBuilder =
+        new DdlRelBuilder(
+            converter, SqlToSubstrait::getBestExpRelRoot, EXTENSION_COLLECTION, featureBoard);
+
+    List<io.substrait.plan.Plan.Root> roots = new ArrayList<>();
+    for (final SqlNode sqlNode : parsedList) {
+      // The DDL visitor answers null for statements it has no handler for.
+      io.substrait.plan.Plan.Root root = sqlNode.accept(ddlRelBuilder);
+      if (root == null) {
+        root =
+            SubstraitRelVisitor.convert(
+                getBestExpRelRoot(converter, sqlNode), EXTENSION_COLLECTION, featureBoard);
+      }
+      roots.add(root);
+    }
+
+    roots.forEach(builder::addRoots);
+    return roots;
+  }
+
+  /** Converts each already-parsed statement into a Calcite {@link RelRoot}. */
+  private List<RelRoot> sqlNodesToRelNode(
+      final SqlNodeList parsedList, final SqlToRelConverter converter) {
+    return parsedList.stream()
+        .map(parsed -> getBestExpRelRoot(converter, parsed))
+        .collect(java.util.stream.Collectors.toList());
+  }
+
   private List<RelRoot> sqlToRelNode(
       String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
       throws SqlParseException {
     SqlParser parser = SqlParser.create(sql, parserConfig);
     var parsedList = parser.parseStmtList();
     SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
-    List<RelRoot> roots =
-        parsedList.stream()
-            .map(parsed -> getBestExpRelRoot(converter, parsed))
-            .collect(java.util.stream.Collectors.toList());
-    return roots;
+    return sqlNodesToRelNode(parsedList, converter);
   }
 
   @VisibleForTesting
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
index fdb3f8aec..e8ac26280 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java
@@ -11,6 +11,7 @@
 import io.substrait.isthmus.expression.ScalarFunctionConverter;
 import io.substrait.isthmus.expression.WindowFunctionConverter;
 import io.substrait.plan.Plan;
+import io.substrait.relation.AbstractWriteRel;
 import io.substrait.relation.Aggregate;
 import io.substrait.relation.Cross;
 import io.substrait.relation.EmptyScan;
@@ -18,6 +19,7 @@
 import io.substrait.relation.Filter;
 import io.substrait.relation.Join;
 import io.substrait.relation.NamedScan;
+import io.substrait.relation.NamedWrite;
 import io.substrait.relation.Project;
 import io.substrait.relation.Rel;
 import io.substrait.relation.Set;
@@ -352,7 +354,29 @@ public Rel visit(org.apache.calcite.rel.core.Exchange exchange) {
 
   @Override
   public Rel visit(org.apache.calcite.rel.core.TableModify modify) {
-    return super.visit(modify);
+    // Only plain INSERT / UPDATE / DELETE map onto a write op here; defer everything else.
+    final AbstractWriteRel.WriteOp writeOp =
+        switch (modify.getOperation()) {
+          case INSERT -> AbstractWriteRel.WriteOp.INSERT;
+          case UPDATE -> AbstractWriteRel.WriteOp.UPDATE;
+          case DELETE -> AbstractWriteRel.WriteOp.DELETE;
+          default -> null;
+        };
+    if (writeOp == null) {
+      return super.visit(modify);
+    }
+
+    assert modify.getTable() != null;
+    // A TableModify's own row type is Calcite's single ROWCOUNT column, so the schema of the
+    // written table has to come from the target table itself.
+    return NamedWrite.builder()
+        .input(apply(modify.getInput()))
+        .tableSchema(typeConverter.toNamedStruct(modify.getTable().getRowType()))
+        .operation(writeOp)
+        .createMode(AbstractWriteRel.CreateMode.UNSPECIFIED)
+        .outputMode(AbstractWriteRel.OutputMode.NO_OUTPUT)
+        .names(modify.getTable().getQualifiedName())
+        .build();
   }
 
   @Override
diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/DdlRelBuilder.java b/isthmus/src/main/java/io/substrait/isthmus/expression/DdlRelBuilder.java
new file mode 100644
index 000000000..9a53c3f7b
--- /dev/null
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/DdlRelBuilder.java
@@ -0,0 +1,156 @@
+package io.substrait.isthmus.expression;
+
+import io.substrait.expression.Expression;
+import io.substrait.expression.ExpressionCreator;
+import io.substrait.extension.SimpleExtension;
+import io.substrait.isthmus.FeatureBoard;
+import io.substrait.isthmus.SubstraitRelVisitor;
+import io.substrait.isthmus.TypeConverter;
+import io.substrait.plan.Plan;
+import io.substrait.relation.AbstractDdlRel;
+import io.substrait.relation.AbstractWriteRel;
+import io.substrait.relation.NamedDdl;
+import io.substrait.relation.NamedWrite;
+import io.substrait.type.NamedStruct;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.ddl.SqlCreateTable;
+import org.apache.calcite.sql.ddl.SqlCreateView;
+import org.apache.calcite.sql.util.SqlBasicVisitor;
+import org.apache.calcite.sql2rel.SqlToRelConverter;
+
+/**
+ * SQL visitor that converts the DDL statements it has a handler for into Substrait plan roots:
+ * {@code CREATE TABLE ... AS SELECT} and {@code CREATE VIEW}.
+ *
+ * <p>{@link #visit(SqlCall)} returns {@code null} for every call without a registered handler.
+ */
+public class DdlRelBuilder extends SqlBasicVisitor<Plan.Root> {
+  /** Handlers keyed by the {@link SqlCall} subclass they convert. */
+  protected final Map<Class<? extends SqlCall>, Function<SqlCall, Plan.Root>> createHandlers =
+      new ConcurrentHashMap<>();
+
+  private final SqlToRelConverter converter;
+  private final BiFunction<SqlToRelConverter, SqlNode, RelRoot> bestExpRelRootGetter;
+  private final SimpleExtension.ExtensionCollection extensionCollection;
+  private final FeatureBoard featureBoard;
+
+  /**
+   * Finds the handler registered for the call's class or for its nearest superclass.
+   *
+   * @return the handler, or {@code null} when none is registered
+   */
+  private Function<SqlCall, Plan.Root> findCreateHandler(final SqlCall call) {
+    Class<?> currentClass = call.getClass();
+    while (SqlCall.class.isAssignableFrom(currentClass)) {
+      final Function<SqlCall, Plan.Root> found = createHandlers.get(currentClass);
+      if (found != null) {
+        return found;
+      }
+      currentClass = currentClass.getSuperclass();
+    }
+    return null;
+  }
+
+  /**
+   * Creates a visitor with the CREATE TABLE and CREATE VIEW handlers registered.
+   *
+   * @param converter converter handed to {@code bestExpRelRootGetter} for the embedded query
+   * @param bestExpRelRootGetter turns a query {@link SqlNode} into its {@link RelRoot}
+   * @param extensionCollection extensions passed to the Substrait conversion of the query
+   * @param featureBoard feature board passed to the Substrait conversion of the query
+   */
+  public DdlRelBuilder(
+      final SqlToRelConverter converter,
+      final BiFunction<SqlToRelConverter, SqlNode, RelRoot> bestExpRelRootGetter,
+      final SimpleExtension.ExtensionCollection extensionCollection,
+      final FeatureBoard featureBoard) {
+    super();
+    this.converter = converter;
+    this.bestExpRelRootGetter = bestExpRelRootGetter;
+    this.extensionCollection = extensionCollection;
+    this.featureBoard = featureBoard;
+
+    createHandlers.put(
+        SqlCreateTable.class, sqlCall -> handleCreateTable((SqlCreateTable) sqlCall));
+    createHandlers.put(SqlCreateView.class, sqlCall -> handleCreateView((SqlCreateView) sqlCall));
+  }
+
+  /**
+   * Converts {@code sqlCall} with the handler registered for it.
+   *
+   * @return the resulting plan root, or {@code null} if no handler is registered for the call
+   */
+  @Override
+  public Plan.Root visit(final SqlCall sqlCall) {
+    Function<SqlCall, Plan.Root> createHandler = findCreateHandler(sqlCall);
+    if (createHandler == null) {
+      return null;
+    }
+
+    return createHandler.apply(sqlCall);
+  }
+
+  /** Derives the Substrait schema from the row type of the converted query. */
+  private NamedStruct getSchema(final RelRoot queryRelRoot) {
+    final RelDataType rowType = queryRelRoot.rel.getRowType();
+
+    final TypeConverter typeConverter = TypeConverter.DEFAULT;
+    return typeConverter.toNamedStruct(rowType);
+  }
+
+  /**
+   * Converts {@code CREATE TABLE ... AS SELECT} into a CTAS {@link NamedWrite} root.
+   *
+   * @throws IllegalArgumentException if the statement has no query part
+   */
+  public Plan.Root handleCreateTable(final SqlCreateTable sqlCreateTable) {
+    if (sqlCreateTable.query == null) {
+      throw new IllegalArgumentException("Only create table as select statements are supported");
+    }
+
+    final RelRoot queryRelRoot = bestExpRelRootGetter.apply(converter, sqlCreateTable.query);
+
+    NamedStruct schema = getSchema(queryRelRoot);
+
+    var rel = SubstraitRelVisitor.convert(queryRelRoot, extensionCollection, featureBoard);
+    NamedWrite namedWrite =
+        NamedWrite.builder()
+            .input(rel.getInput())
+            .tableSchema(schema)
+            .operation(AbstractWriteRel.WriteOp.CTAS)
+            // NOTE(review): applied to every CTAS, also without OR REPLACE - confirm intended.
+            .createMode(AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS)
+            .outputMode(AbstractWriteRel.OutputMode.NO_OUTPUT)
+            .names(sqlCreateTable.name.names)
+            .build();
+
+    return Plan.Root.builder().input(namedWrite).build();
+  }
+
+  /** Converts {@code CREATE VIEW} into a {@link NamedDdl} root that carries the view query. */
+  Plan.Root handleCreateView(final SqlCreateView sqlCreateView) {
+
+    final RelRoot queryRelRoot = bestExpRelRootGetter.apply(converter, sqlCreateView.query);
+    var rel = SubstraitRelVisitor.convert(queryRelRoot, extensionCollection, featureBoard);
+    final Expression.StructLiteral defaults = ExpressionCreator.struct(false);
+
+    final NamedDdl namedDdl =
+        NamedDdl.builder()
+            .viewDefinition(rel.getInput())
+            .tableSchema(getSchema(queryRelRoot))
+            .tableDefaults(defaults)
+            .operation(AbstractDdlRel.DdlOp.CREATE)
+            .object(AbstractDdlRel.DdlObject.VIEW)
+            .names(sqlCreateView.name.names)
+            .build();
+
+    return Plan.Root.builder().input(namedDdl).build();
+  }
+}
diff --git a/isthmus/src/test/java/io/substrait/isthmus/SqlToSubstraitTest.java b/isthmus/src/test/java/io/substrait/isthmus/SqlToSubstraitTest.java
new file mode 100644
index 000000000..c7ced581e
--- /dev/null
+++ b/isthmus/src/test/java/io/substrait/isthmus/SqlToSubstraitTest.java
@@ -0,0 +1,29 @@
+package io.substrait.isthmus;
+
+import io.substrait.plan.Plan;
+import java.io.IOException;
+import java.util.List;
+import org.apache.calcite.sql.parser.SqlParseException;
+import org.junit.jupiter.api.Test;
+
+public class SqlToSubstraitTest extends PlanTestBase {
+
+  @Test
+  void testDdlDml() throws SqlParseException, IOException {
+    final String sqlStatements = asString("sqltosubstrait/sqltosubstrait.sql");
+
+    SqlToSubstrait sql2subst = new SqlToSubstrait();
+    final List<Plan.Root> relRoots =
+        sql2subst.sqlToPlanNodes(
+            sqlStatements,
+            List.of(
+                "create table src1 (intcol int, charcol varchar(10))",
+                "create table src2 (intcol int, charcol varchar(10))"));
+    var builder = io.substrait.plan.Plan.builder();
+    for (final io.substrait.plan.Plan.Root planRoot : relRoots) {
+      builder.addRoots(planRoot);
+    }
+    final Plan plan = builder.build();
+    assertPlanRoundtrip(plan);
+  }
+}
diff --git a/isthmus/src/test/resources/sqltosubstrait/sqltosubstrait.sql b/isthmus/src/test/resources/sqltosubstrait/sqltosubstrait.sql
new file mode 100644
index 000000000..b0faad454
--- /dev/null
+++ b/isthmus/src/test/resources/sqltosubstrait/sqltosubstrait.sql
@@ -0,0 +1,6 @@
+delete from src1 where intcol=10;
+update src1 set intcol=10 where charcol='a';
+insert into src1(intcol, charcol) values (1,'a');
+insert into src1 select * from src2;
+create view dst1 as select * from src1;
+create table dst1 as select * from src1;
\ No newline at end of file