Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,15 @@ public Optional<Rel> visit(Expand expand, EmptyVisitationContext context) throws

@Override
public Optional<Rel> visit(NamedWrite write, EmptyVisitationContext context) throws E {
throw new UnsupportedOperationException();

Optional<Rel> input = write.getInput().accept(this, context);

if (allEmpty(input)) {
return Optional.empty();
}

return Optional.of(
NamedWrite.builder().from(write).input(input.orElse(write.getInput())).build());
}

@Override
Expand All @@ -233,9 +241,37 @@ public Optional<Rel> visit(ExtensionDdl ddl, EmptyVisitationContext context) thr
throw new UnsupportedOperationException();
}

protected Optional<NamedUpdate.TransformExpression> visitTransformExpression(
NamedUpdate.TransformExpression transform, EmptyVisitationContext context) throws E {
return transform
.getTransformation()
.accept(getExpressionCopyOnWriteVisitor(), context)
.map(
expr ->
NamedUpdate.TransformExpression.builder()
.from(transform)
.transformation(expr)
.build());
}

@Override
public Optional<Rel> visit(NamedUpdate update, EmptyVisitationContext context) throws E {
throw new UnsupportedOperationException();
Optional<Expression> condition =
update.getCondition().accept(getExpressionCopyOnWriteVisitor(), context);

Optional<List<AbstractUpdate.TransformExpression>> transformations =
transformList(update.getTransformations(), context, this::visitTransformExpression);

if (allEmpty(condition, transformations)) {
return Optional.empty();
}

return Optional.of(
NamedUpdate.builder()
.from(update)
.condition(condition.orElse(update.getCondition()))
.transformations(transformations.orElse(update.getTransformations()))
.build());
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import org.apache.calcite.plan.hep.HepProgram;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.rules.CoreRules;
import org.apache.calcite.sql.SqlNode;
import org.apache.calcite.sql.parser.SqlParseException;
import org.apache.calcite.sql.parser.SqlParser;
Expand All @@ -33,7 +34,6 @@ public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlP
return executeInner(sql, validator, catalogReader);
}

// Package protected for testing
List<RelRoot> sqlToRelNode(String sql, Prepare.CatalogReader catalogReader)
throws SqlParseException {
SqlValidator validator = new SubstraitSqlValidator(catalogReader);
Expand Down Expand Up @@ -86,7 +86,9 @@ SqlToRelConverter createSqlToRelConverter(
static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) {
RelRoot root = converter.convertQuery(parsed, true, true);
{
var program = HepProgram.builder().build();
// RelBuilder seems to implicitly use the rule below,
// need to add to avoid discrepancies in assertFullRoundTrip
var program = HepProgram.builder().addRuleInstance(CoreRules.PROJECT_REMOVE).build();
HepPlanner hepPlanner = new HepPlanner(program);
hepPlanner.setRoot(root.rel);
root = root.withRel(hepPlanner.findBestExp());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.relation.AbstractRelVisitor;
import io.substrait.relation.AbstractUpdate;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
Expand All @@ -20,10 +21,14 @@
import io.substrait.relation.Join.JoinType;
import io.substrait.relation.LocalFiles;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NamedUpdate;
import io.substrait.relation.NamedWrite;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
import io.substrait.relation.Sort;
import io.substrait.relation.VirtualTableScan;
import io.substrait.type.NamedStruct;
import io.substrait.util.VisitationContext;
import java.util.ArrayList;
import java.util.Collection;
Expand All @@ -34,6 +39,7 @@
import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptTable;
import org.apache.calcite.plan.RelTraitDef;
import org.apache.calcite.prepare.Prepare;
import org.apache.calcite.rel.RelCollation;
Expand All @@ -42,12 +48,15 @@
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.logical.LogicalTableModify;
import org.apache.calcite.rel.logical.LogicalValues;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexSlot;
import org.apache.calcite.sql.SqlAggFunction;
Expand Down Expand Up @@ -473,6 +482,118 @@ private RelFieldCollation toRelFieldCollation(Expression.SortField sortField, Co
return new RelFieldCollation(fieldIndex, fieldDirection, nullDirection);
}

@Override
public RelNode visit(NamedUpdate update, Context context) {
relBuilder.scan(update.getNames());
RexNode condition = update.getCondition().accept(expressionRexConverter, context);
relBuilder.filter(condition);
RelNode inputForModify = relBuilder.build();

NamedStruct tableSchema = update.getTableSchema();
List<String> fieldNames = tableSchema.names();

List<String> updateColumnList = new ArrayList<>();
List<RexNode> sourceExpressionList = new ArrayList<>();

for (AbstractUpdate.TransformExpression transform : update.getTransformations()) {

updateColumnList.add(fieldNames.get(transform.getColumnTarget()));
sourceExpressionList.add(
transform.getTransformation().accept(expressionRexConverter, context));
}

assert relBuilder.getRelOptSchema() != null;
final RelOptTable table = relBuilder.getRelOptSchema().getTableForMember(update.getNames());

if (table == null) {
throw new IllegalStateException("Table not found in Calcite catalog: " + update.getNames());
}
final Prepare.CatalogReader catalogReader = (Prepare.CatalogReader) table.getRelOptSchema();

assert catalogReader != null;
return LogicalTableModify.create(
table,
catalogReader,
inputForModify,
TableModify.Operation.UPDATE,
updateColumnList,
sourceExpressionList,
false);
}

@Override
public RelNode visit(VirtualTableScan virtualTableScan, Context context) {

final RelDataType typeInfoOnly =
typeConverter.toCalcite(typeFactory, virtualTableScan.getInitialSchema().struct());

final List<String> correctFieldNames = virtualTableScan.getInitialSchema().names();

final List<RelDataType> fieldTypes =
typeInfoOnly.getFieldList().stream()
.map(RelDataTypeField::getType)
.collect(Collectors.toList());

final RelDataType rowTypeWithNames =
typeFactory.createStructType(fieldTypes, correctFieldNames);

final List<ImmutableList<RexLiteral>> tuples = new ArrayList<>();
for (final Expression.StructLiteral row : virtualTableScan.getRows()) {
final List<RexLiteral> rexRow = new ArrayList<>();
for (final Expression.Literal literal : row.fields()) {
final RexNode rexNode = literal.accept(expressionRexConverter, context);
if (rexNode instanceof RexLiteral) {
final RexLiteral rexLiteral = (RexLiteral) rexNode;
rexRow.add(rexLiteral);
} else {
throw new UnsupportedOperationException(
"VirtualTableScan only supports literal values, found: "
+ rexNode.getClass().getName());
}
}
tuples.add(ImmutableList.copyOf(rexRow));
}

return LogicalValues.create(
relBuilder.getCluster(), rowTypeWithNames, ImmutableList.copyOf(tuples));
}

@Override
public RelNode visit(NamedWrite write, Context context) {
RelNode input = write.getInput().accept(this, context);
assert relBuilder.getRelOptSchema() != null;
final RelOptTable table = relBuilder.getRelOptSchema().getTableForMember(write.getNames());

if (table == null) {
throw new IllegalStateException("Table not found in Calcite catalog: " + write.getNames());
}

TableModify.Operation operation;
switch (write.getOperation()) {
case INSERT:
operation = TableModify.Operation.INSERT;
break;
case DELETE:
operation = TableModify.Operation.DELETE;
break;
default:
throw new UnsupportedOperationException(
"Write operation '"
+ write.getOperation()
+ "' is not supported by the NamedWrite visitor. "
+ "Check if a more specific relation type (e.g., NamedUpdate) should be used.");
}

return LogicalTableModify.create(
table,
(Prepare.CatalogReader) relBuilder.getRelOptSchema(),
input,
operation,
null,
null,
false);
}

@Override
public RelNode visitFallback(Rel rel, Context context) throws RuntimeException {
throw new UnsupportedOperationException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
import io.substrait.isthmus.expression.ScalarFunctionConverter;
import io.substrait.isthmus.expression.WindowFunctionConverter;
import io.substrait.plan.Plan;
import io.substrait.relation.AbstractWriteRel;
import io.substrait.relation.Aggregate;
import io.substrait.relation.Cross;
import io.substrait.relation.EmptyScan;
import io.substrait.relation.Fetch;
import io.substrait.relation.Filter;
import io.substrait.relation.Join;
import io.substrait.relation.NamedScan;
import io.substrait.relation.NamedUpdate;
import io.substrait.relation.NamedWrite;
import io.substrait.relation.Project;
import io.substrait.relation.Rel;
import io.substrait.relation.Set;
Expand All @@ -37,6 +40,7 @@
import org.apache.calcite.rel.RelRoot;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.core.TableModify;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexFieldAccess;
Expand Down Expand Up @@ -374,8 +378,76 @@ public Rel visit(org.apache.calcite.rel.core.Exchange exchange) {
}

@Override
public Rel visit(org.apache.calcite.rel.core.TableModify modify) {
return super.visit(modify);
public Rel visit(TableModify modify) {
switch (modify.getOperation()) {
case INSERT:
case DELETE:
{
final Rel input = apply(modify.getInput());
final AbstractWriteRel.WriteOp op =
modify.getOperation() == TableModify.Operation.INSERT
? AbstractWriteRel.WriteOp.INSERT
: AbstractWriteRel.WriteOp.DELETE;

assert modify.getTable() != null;
return NamedWrite.builder()
.input(input)
.tableSchema(typeConverter.toNamedStruct(modify.getTable().getRowType()))
.operation(op)
.createMode(AbstractWriteRel.CreateMode.UNSPECIFIED)
.outputMode(AbstractWriteRel.OutputMode.MODIFIED_RECORDS)
.names(modify.getTable().getQualifiedName())
.build();
}

case UPDATE:
{
assert modify.getTable() != null;

Expression condition;
if (modify.getInput() instanceof org.apache.calcite.rel.core.Filter) {
org.apache.calcite.rel.core.Filter filter =
(org.apache.calcite.rel.core.Filter) modify.getInput();
condition = toExpression(filter.getCondition());
} else {
condition = Expression.BoolLiteral.builder().nullable(false).value(true).build();
}

List<String> updateColumnNames = modify.getUpdateColumnList();
List<RexNode> sourceExpressions = modify.getSourceExpressionList();
List<String> allTableColumnNames = modify.getTable().getRowType().getFieldNames();
List<NamedUpdate.TransformExpression> transformations = new ArrayList<>();

for (int i = 0; i < updateColumnNames.size(); i++) {
String colName = updateColumnNames.get(i);
RexNode rexExpr = sourceExpressions.get(i);

int columnIndex = allTableColumnNames.indexOf(colName);
if (columnIndex == -1) {
throw new IllegalStateException(
"Updated column '" + colName + "' not found in table schema.");
}

Expression substraitExpr = toExpression(rexExpr);

transformations.add(
NamedUpdate.TransformExpression.builder()
.columnTarget(columnIndex)
.transformation(substraitExpr)
.build());
}

return NamedUpdate.builder()
.tableSchema(typeConverter.toNamedStruct(modify.getTable().getRowType()))
.names(modify.getTable().getQualifiedName())
.condition(condition)
.transformations(transformations)
.build();
}

default:
return super.visit(modify);
}
}

@Override
Expand Down
Loading
Loading