Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 65 additions & 9 deletions isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package io.substrait.isthmus;

import com.google.common.annotations.VisibleForTesting;
import io.substrait.isthmus.expression.DdlRelBuilder;
import io.substrait.plan.ImmutablePlan;
import io.substrait.plan.Plan.Version;
import io.substrait.plan.PlanProtoConverter;
import io.substrait.proto.Plan;
import java.util.LinkedList;
import java.util.List;
import org.apache.calcite.plan.hep.HepPlanner;
import org.apache.calcite.plan.hep.HepProgram;
Expand All @@ -12,6 +15,7 @@
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.schema.Schema;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.SqlNodeList;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
import org.apache.calcite.sql.validate.SqlValidator;
Expand Down Expand Up @@ -53,32 +57,84 @@ List<RelRoot> sqlToRelNode(String sql, List<String> tables) throws SqlParseExcep
return sqlToRelNode(sql, validator, catalogReader);
}

/**
 * Converts a SQL script into Substrait plan roots, resolving table references against the given
 * {@code CREATE TABLE} statements.
 *
 * @param sql the SQL text to convert; it is parsed as a statement list
 * @param tables {@code CREATE TABLE} statements registered with the catalog before conversion
 * @return the plan roots produced for the statements in {@code sql}
 * @throws SqlParseException if parsing fails
 */
List<io.substrait.plan.Plan.Root> sqlToPlanNodes(String sql, List<String> tables)
    throws SqlParseException {
  final Prepare.CatalogReader reader = registerCreateTables(tables);
  // A throw-away plan builder is passed in: callers of this overload only want the root list.
  return sqlToPlanNodes(
      sql,
      Validator.create(factory, reader, SqlValidator.Config.DEFAULT),
      reader,
      io.substrait.plan.Plan.builder());
}

/**
 * Converts {@code sql} to a Substrait plan and returns it in protobuf form.
 *
 * <p>The plan version is derived from {@code Version.DEFAULT_VERSION} with the producer set to
 * {@code "isthmus"}. All statement conversion is delegated to {@code sqlToPlanNodes}, which adds
 * every resulting root to the plan builder.
 *
 * @param sql the SQL text to convert
 * @param validator validator used when converting statements
 * @param catalogReader catalog used to resolve table references
 * @return the protobuf representation of the built plan
 * @throws SqlParseException if parsing fails
 */
private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
    throws SqlParseException {
  var builder = io.substrait.plan.Plan.builder();
  builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build());

  // TODO: consider case in which one sql passes conversion while others don't
  // Roots are registered on the builder exactly once, by sqlToPlanNodes. The previous
  // sqlToRelNode(...)-based pipeline must not run in addition, or every statement would be
  // converted and added twice.
  sqlToPlanNodes(sql, validator, catalogReader, builder);

  PlanProtoConverter planToProto = new PlanProtoConverter();
  return planToProto.toProto(builder.build());
}

private List<io.substrait.plan.Plan.Root> sqlToPlanNodes(
String sql,
SqlValidator validator,
Prepare.CatalogReader catalogReader,
ImmutablePlan.Builder builder)
throws SqlParseException {
SqlParser parser = SqlParser.create(sql, parserConfig);
var parsedList = parser.parseStmtList();
SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
// IMPORTANT: parsedList gets filtered in the call below
List<io.substrait.plan.Plan.Root> ddlRelRoots = ddlSqlToRootNodes(parsedList, converter);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to process DDL statements separately? Would it be possible to just parse all statements, and then just use the conversion code in SubstraitRelVisitor?


List<io.substrait.plan.Plan.Root> sqlRelRoots = new LinkedList<>();

for (RelRoot relRoot : sqlNodesToRelNode(parsedList, converter)) {
io.substrait.plan.Plan.Root convert =
SubstraitRelVisitor.convert(relRoot, EXTENSION_COLLECTION, featureBoard);
sqlRelRoots.add(convert);
}

ddlRelRoots.addAll(sqlRelRoots);
ddlRelRoots.forEach(builder::addRoots);
return ddlRelRoots;
}

/**
 * Runs a {@code DdlRelBuilder} over every statement and collects the roots it produces.
 *
 * <p>Side effect: every statement for which the visitor returned a non-null root is removed from
 * {@code sqlNodeList}, so that only the unhandled statements remain for the caller.
 *
 * @param sqlNodeList parsed statements; modified in place as described above
 * @param converter converter handed to the {@code DdlRelBuilder}
 * @return the roots produced by the visitor, in statement order
 * @throws SqlParseException declared for callers; nothing in this method body throws it directly
 */
private List<io.substrait.plan.Plan.Root> ddlSqlToRootNodes(
    final SqlNodeList sqlNodeList, final SqlToRelConverter converter) throws SqlParseException {

  final DdlRelBuilder ddlVisitor =
      new DdlRelBuilder(
          converter, SqlToSubstrait::getBestExpRelRoot, EXTENSION_COLLECTION, featureBoard);

  final List<io.substrait.plan.Plan.Root> ddlRoots = new LinkedList<>();
  final List<SqlNode> handled = new LinkedList<>();
  for (final SqlNode statement : sqlNodeList) {
    final io.substrait.plan.Plan.Root ddlRoot = statement.accept(ddlVisitor);
    if (ddlRoot == null) {
      // Not handled by the visitor: leave the statement in the list.
      continue;
    }
    ddlRoots.add(ddlRoot);
    handled.add(statement);
  }

  // Removal is deferred until after the loop so the list is not modified while being iterated.
  sqlNodeList.removeAll(handled);
  return ddlRoots;
}

/**
 * Converts each parsed statement to a {@code RelRoot} via {@code getBestExpRelRoot}.
 *
 * @param parsedList the statements to convert
 * @param converter converter passed to {@code getBestExpRelRoot} for every statement
 * @return one {@code RelRoot} per statement, in statement order
 */
private List<RelRoot> sqlNodesToRelNode(
    final SqlNodeList parsedList, final SqlToRelConverter converter) {
  final List<RelRoot> relRoots = new LinkedList<>();
  for (final SqlNode statement : parsedList) {
    relRoots.add(getBestExpRelRoot(converter, statement));
  }
  return relRoots;
}

/**
 * Parses {@code sql} as a statement list and converts every statement to a {@code RelRoot}.
 *
 * @param sql the SQL text to convert
 * @param validator validator used to create the SQL-to-rel converter
 * @param catalogReader catalog used to create the SQL-to-rel converter
 * @return one {@code RelRoot} per parsed statement, in statement order
 * @throws SqlParseException if parsing fails
 */
private List<RelRoot> sqlToRelNode(
    String sql, SqlValidator validator, Prepare.CatalogReader catalogReader)
    throws SqlParseException {
  SqlParser parser = SqlParser.create(sql, parserConfig);
  var parsedList = parser.parseStmtList();
  SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader);
  // Single return path: the per-statement conversion is shared with sqlToPlanNodes through
  // sqlNodesToRelNode rather than being repeated inline here.
  return sqlNodesToRelNode(parsedList, converter);
}

@VisibleForTesting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,15 @@
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.plan.Plan;
import io.substrait.relation.AbstractWriteRel;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NamedWrite;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
Expand Down Expand Up @@ -352,7 +354,44 @@ public Rel visit(org.apache.calcite.rel.core.Exchange exchange) {

/**
 * Converts a Calcite {@code TableModify} into a Substrait {@code NamedWrite} for INSERT, UPDATE
 * and DELETE; every other operation is delegated to the superclass implementation.
 *
 * <p>The input relation is converted before the operation is inspected, so it is converted even
 * when the operation ends up being delegated.
 *
 * @param modify the table modification to convert
 * @return the converted relation
 */
@Override
public Rel visit(org.apache.calcite.rel.core.TableModify modify) {
  final Rel input = apply(modify.getInput());
  return switch (modify.getOperation()) {
    case INSERT -> namedWriteFor(modify, input, AbstractWriteRel.WriteOp.INSERT);
    case UPDATE -> namedWriteFor(modify, input, AbstractWriteRel.WriteOp.UPDATE);
    case DELETE -> namedWriteFor(modify, input, AbstractWriteRel.WriteOp.DELETE);
    default -> super.visit(modify);
  };
}

/** Builds the {@code NamedWrite} shared by the INSERT, UPDATE and DELETE conversions. */
private Rel namedWriteFor(
    final org.apache.calcite.rel.core.TableModify modify,
    final Rel input,
    final AbstractWriteRel.WriteOp operation) {
  // Only checked when assertions are enabled (-ea); otherwise a missing table surfaces as a
  // NullPointerException at getQualifiedName() below.
  assert modify.getTable() != null;
  // NOTE(review): the schema is taken from the row type of the TableModify node itself, not from
  // the target table - confirm this is the intended table schema.
  return NamedWrite.builder()
      .input(input)
      .tableSchema(typeConverter.toNamedStruct(modify.getRowType()))
      .operation(operation)
      .createMode(AbstractWriteRel.CreateMode.UNSPECIFIED)
      .outputMode(AbstractWriteRel.OutputMode.NO_OUTPUT)
      .names(modify.getTable().getQualifiedName())
      .build();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package io.substrait.isthmus.expression;

import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.extension.SimpleExtension;
import io.substrait.isthmus.FeatureBoard;
import io.substrait.isthmus.SubstraitRelVisitor;
import io.substrait.isthmus.TypeConverter;
import io.substrait.plan.Plan;
import io.substrait.relation.AbstractDdlRel;
import io.substrait.relation.AbstractWriteRel;
import io.substrait.relation.NamedDdl;
import io.substrait.relation.NamedWrite;
import io.substrait.type.NamedStruct;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;
import java.util.function.Function;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.sql.SqlCall;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.ddl.SqlCreateTable;
import org.apache.calcite.sql.ddl.SqlCreateView;
import org.apache.calcite.sql.util.SqlBasicVisitor;
import org.apache.calcite.sql2rel.SqlToRelConverter;

public class DdlRelBuilder extends SqlBasicVisitor<Plan.Root> {
protected final Map<Class<? extends SqlCall>, Function<SqlCall, Plan.Root>> createHandlers =
new ConcurrentHashMap<>();

private final SqlToRelConverter converter;
private final BiFunction<SqlToRelConverter, SqlNode, RelRoot> bestExpRelRootGetter;
private final SimpleExtension.ExtensionCollection extensionCollection;
private final FeatureBoard featureBoard;

private Function<SqlCall, Plan.Root> findCreateHandler(final SqlCall call) {
Class<?> currentClass = call.getClass();
while (SqlCall.class.isAssignableFrom(currentClass)) {
final Function<SqlCall, Plan.Root> found = createHandlers.get(currentClass);
if (found != null) {
return found;
}
currentClass = currentClass.getSuperclass();
}
return null;
}

public DdlRelBuilder(
final SqlToRelConverter converter,
final BiFunction<SqlToRelConverter, SqlNode, RelRoot> bestExpRelRootGetter,
final SimpleExtension.ExtensionCollection extensionCollection,
final FeatureBoard featureBoard) {
super();
this.converter = converter;
this.bestExpRelRootGetter = bestExpRelRootGetter;
this.extensionCollection = extensionCollection;
this.featureBoard = featureBoard;

createHandlers.put(
SqlCreateTable.class, sqlCall -> handleCreateTable((SqlCreateTable) sqlCall));
createHandlers.put(SqlCreateView.class, sqlCall -> handleCreateView((SqlCreateView) sqlCall));
}

@Override
public Plan.Root visit(final SqlCall sqlCall) {
Function<SqlCall, Plan.Root> createHandler = findCreateHandler(sqlCall);
if (createHandler == null) {
return null;
}

return createHandler.apply(sqlCall);
}

private NamedStruct getSchema(final RelRoot queryRelRoot) {
final RelDataType rowType = queryRelRoot.rel.getRowType();

final TypeConverter typeConverter = TypeConverter.DEFAULT;
return typeConverter.toNamedStruct(rowType);
}

public Plan.Root handleCreateTable(final SqlCreateTable sqlCreateTable) {
if (sqlCreateTable.query == null) {
throw new IllegalArgumentException("Only create table as select statements are supported");
}

final RelRoot queryRelRoot = bestExpRelRootGetter.apply(converter, sqlCreateTable.query);

NamedStruct schema = getSchema(queryRelRoot);

var rel = SubstraitRelVisitor.convert(queryRelRoot, extensionCollection, featureBoard);
NamedWrite namedWrite =
NamedWrite.builder()
.input(rel.getInput())
.tableSchema(schema)
.operation(AbstractWriteRel.WriteOp.CTAS)
.createMode(AbstractWriteRel.CreateMode.REPLACE_IF_EXISTS)
.outputMode(AbstractWriteRel.OutputMode.NO_OUTPUT)
.names(sqlCreateTable.name.names)
.build();

return Plan.Root.builder().input(namedWrite).build();
}

Plan.Root handleCreateView(final SqlCreateView sqlCreateView) {

final RelRoot queryRelRoot = bestExpRelRootGetter.apply(converter, sqlCreateView.query);
var rel = SubstraitRelVisitor.convert(queryRelRoot, extensionCollection, featureBoard);
final Expression.StructLiteral defaults = ExpressionCreator.struct(false);

final NamedDdl namedDdl =
NamedDdl.builder()
.viewDefinition(rel.getInput())
.tableSchema(getSchema(queryRelRoot))
.tableDefaults(defaults)
.operation(AbstractDdlRel.DdlOp.CREATE)
.object(AbstractDdlRel.DdlObject.VIEW)
.names(sqlCreateView.name.names)
.build();

return Plan.Root.builder().input(namedDdl).build();
}
}
29 changes: 29 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/SqlToSubstraitTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package io.substrait.isthmus;

import io.substrait.plan.Plan;
import java.io.IOException;
import java.util.List;
import org.apache.calcite.sql.parser.SqlParseException;
import org.junit.jupiter.api.Test;

/** Round-trip test for converting a script of DML and DDL statements to a Substrait plan. */
public class SqlToSubstraitTest extends PlanTestBase {

  /**
   * Converts every statement of {@code sqltosubstrait/sqltosubstrait.sql} against two source
   * tables, assembles the resulting roots into one plan and checks that the plan round-trips.
   */
  @Test
  void testDdlDml() throws SqlParseException, IOException {
    final String script = asString("sqltosubstrait/sqltosubstrait.sql");
    final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait();

    // Tables referenced by the statements in the script.
    final List<String> sourceTables =
        List.of(
            "create table src1 (intcol int, charcol varchar(10))",
            "create table src2 (intcol int, charcol varchar(10))");
    final List<io.substrait.plan.Plan.Root> planRoots =
        sqlToSubstrait.sqlToPlanNodes(script, sourceTables);

    var planBuilder = io.substrait.plan.Plan.builder();
    planRoots.forEach(planBuilder::addRoots);

    assertPlanRoundtrip(planBuilder.build());
  }
}
6 changes: 6 additions & 0 deletions isthmus/src/test/resources/sqltosubstrait/sqltosubstrait.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- Fixture for SqlToSubstraitTest.testDdlDml, which registers tables src1 and src2
-- (intcol int, charcol varchar(10)) before converting the statements below.

-- DML against src1: delete, update, insert from values, insert from a query on src2.
delete from src1 where intcol=10;
update src1 set intcol=10 where charcol='a';
insert into src1(intcol, charcol) values (1,'a');
insert into src1 select * from src2;

-- DDL: a view and a CREATE TABLE AS SELECT, both defined over src1.
-- NOTE(review): both statements use the name dst1 - presumably harmless because the statements
-- are only converted, not executed; confirm.
create view dst1 as select * from src1;
create table dst1 as select * from src1;