Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import java.util.stream.Collectors;

import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;

import org.apache.wayang.api.sql.calcite.converter.functions.JoinFlattenResult;
import org.apache.wayang.api.sql.calcite.converter.functions.MultiConditionJoinKeyExtractor;
import org.apache.wayang.api.sql.calcite.rel.WayangJoin;
import org.apache.wayang.api.sql.calcite.rel.WayangTableScan;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.basic.data.Tuple2;
import org.apache.wayang.basic.operators.JoinOperator;
Expand All @@ -48,26 +50,21 @@ public class WayangMultiConditionJoinVisitor extends WayangRelNodeVisitor<Wayang
*
* @param wayangRelConverter
*/
WayangMultiConditionJoinVisitor(final WayangRelConverter wayangRelConverter) {
public WayangMultiConditionJoinVisitor(final WayangRelConverter wayangRelConverter) {
super(wayangRelConverter);
}

@Override
Operator visit(WayangJoin wayangRelNode) {
public Operator visit(WayangJoin wayangRelNode) {
final Operator childOpLeft = wayangRelConverter.convert(wayangRelNode.getInput(0));
final Operator childOpRight = wayangRelConverter.convert(wayangRelNode.getInput(1));
final RexNode condition = ((Join) wayangRelNode).getCondition();
final RexCall call = (RexCall) condition;

//
final List<RexCall> subConditions = call.operands.stream()
.map(RexCall.class::cast)
.collect(Collectors.toList());

// calcite generates the RexInputRef indexes via looking at the union
// field list of the left and right input of a join.
// since the left input is always the first in this joined field list
// we can eagerly get the fields in the left input
final List<RexInputRef> leftTableInputRefs = subConditions.stream()
.map(sub -> sub.getOperands().stream()
.map(RexInputRef.class::cast)
Expand All @@ -79,9 +76,6 @@ Operator visit(WayangJoin wayangRelNode) {
.map(RexInputRef::getIndex)
.toArray(Integer[]::new);

// for the right table input refs, the indexes are offset by the amount of rows
// in the left
// input to the join
final List<RexInputRef> rightTableInputRefs = subConditions.stream()
.map(sub -> sub.getOperands().stream()
.map(RexInputRef.class::cast)
Expand All @@ -91,36 +85,40 @@ Operator visit(WayangJoin wayangRelNode) {

final Integer[] rightTableKeyIndexes = rightTableInputRefs.stream()
.map(RexInputRef::getIndex)
.map(key -> key - wayangRelNode.getLeft().getRowType().getFieldCount()) // apply offset
.map(key -> key - wayangRelNode.getLeft().getRowType().getFieldCount())
.toArray(Integer[]::new);

/*
final List<RelDataTypeField> leftFields = Arrays.stream(leftTableKeyIndexes)
.map(key -> wayangRelNode.getLeft().getRowType().getFieldList().get(key))
final List<RelDataTypeField> leftFields = leftTableInputRefs.stream()
.map(ref -> wayangRelNode.getLeft().getRowType().getFieldList().get(ref.getIndex()))
.collect(Collectors.toList());

final List<RelDataTypeField> rightFields = Arrays.stream(rightTableKeyIndexes)
.map(key -> wayangRelNode.getRight().getRowType().getFieldList().get(key))
final List<RelDataTypeField> rightFields = rightTableInputRefs.stream()
.map(ref -> wayangRelNode.getRight().getRowType().getFieldList().get(ref.getIndex() - wayangRelNode.getLeft().getRowType().getFieldCount()))
.collect(Collectors.toList());

final String joiningTableName = childOpLeft instanceof WayangTableScan ? childOpLeft.getName() : childOpRight.getName();
*/

// if join is joining the LHS of a join condition "JOIN left ON left = right"
// then we pick the first case, otherwise the 2nd "JOIN right ON left = right"
final JoinOperator<Record, Record, Record> join = this.getJoinOperator(
final String leftTableName = extractTableName(wayangRelNode.getLeft());
final String rightTableName = extractTableName(wayangRelNode.getRight());

final String leftFieldNames = leftFields.stream()
.map(RelDataTypeField::getName)
.collect(Collectors.joining(","));

final String rightFieldNames = rightFields.stream()
.map(RelDataTypeField::getName)
.collect(Collectors.joining(","));

final JoinOperator<Record, Record, Record> join = getJoinOperator(
leftTableKeyIndexes,
rightTableKeyIndexes,
wayangRelNode,
"",
"",
"",
"");
leftTableName,
leftFieldNames,
rightTableName,
rightFieldNames);

childOpLeft.connectTo(0, join, 0);
childOpRight.connectTo(0, join, 1);

// Join returns Tuple2 - map to a Record
final SerializableFunction<Tuple2<Record, Record>, Record> mp = new JoinFlattenResult();

final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>(
Expand All @@ -133,33 +131,34 @@ Operator visit(WayangJoin wayangRelNode) {
return mapOperator;
}

/**
* This method handles the {@link JoinOperator} creation
*
* @param wayangRelNode
* @param leftKeyIndex
* @param rightKeyIndex
* @return
*/
private String extractTableName(org.apache.calcite.rel.RelNode relNode) {
if (relNode instanceof WayangTableScan) {
return ((WayangTableScan) relNode).getTableName();
}
if (relNode.getInputs() != null && !relNode.getInputs().isEmpty()) {
return extractTableName(relNode.getInput(0));
}
return "UNKNOWN";
}

protected JoinOperator<Record, Record, Record> getJoinOperator(final Integer[] leftKeyIndexes,
final Integer[] rightKeyIndexes,
final WayangJoin wayangRelNode, final String leftTableName, final String leftFieldNames,
final String rightTableName, final String rightFieldNames) {
// TODO: needs withSqlImplementation() for sql support

if (wayangRelNode.getInputs().size() != 2)
throw new UnsupportedOperationException("Join had an unexpected amount of inputs, found: "
+ wayangRelNode.getInputs().size() + ", expected: 2");

final TransformationDescriptor<Record, Record> leftProjectionDescriptor = new TransformationDescriptor<Record, Record>(
new MultiConditionJoinKeyExtractor(leftKeyIndexes),
Record.class, Record.class);
// .withSqlImplementation(""," ")
Record.class, Record.class)
.withSqlImplementation(leftTableName, leftFieldNames);

final TransformationDescriptor<Record, Record> rightProjectionDescriptor = new TransformationDescriptor<Record, Record>(
new MultiConditionJoinKeyExtractor(rightKeyIndexes),
Record.class, Record.class);
// .withSqlImplementation(""," ")
Record.class, Record.class)
.withSqlImplementation(rightTableName, rightFieldNames);

final JoinOperator<Record, Record, Record> join = new JoinOperator<>(
leftProjectionDescriptor,
Expand Down
Loading
Loading