Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions core/src/main/java/io/substrait/expression/Expression.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ interface Literal extends Expression {
/**
 * Reports whether this literal is nullable.
 *
 * <p>This default implementation always answers {@code false}; implementations that can be
 * nullable provide their own answer.
 */
default boolean nullable() {
return false;
}

/**
 * Returns a copy of this literal with the specified nullability.
 *
 * <p>This method is implemented by all concrete Literal classes via Immutables code generation.
 * Implementations may refuse a nullability they cannot represent; for example, {@code
 * NullLiteral} rejects {@code false}.
 *
 * @param nullable the nullability the returned literal should report
 * @return a literal carrying the requested nullability
 * @throws IllegalArgumentException if the implementation cannot represent the requested
 *     nullability
 */
Literal withNullable(boolean nullable);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oooh, that's handy.

I think we can use this to replace the

private static final class NullableSettingTypeVisitor
which uses a whole visitor just to change the nullability 😅

Would be worth a quick follow-up after this PR.

}

interface Nested extends Expression {
Expand All @@ -52,6 +59,14 @@ public boolean nullable() {
return true;
}

/**
 * Returns this literal unchanged when nullability is requested.
 *
 * @param nullable the requested nullability; only {@code true} is accepted
 * @return this instance
 * @throws IllegalArgumentException if {@code nullable} is {@code false}
 */
@Override
public NullLiteral withNullable(boolean nullable) {
  if (nullable) {
    // Already nullable, so no copy is needed.
    return this;
  }
  throw new IllegalArgumentException("NullLiteral cannot be made non-nullable");
}

@Value.Check
protected void check() {
if (!type().nullable()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ public List<CallConverter> getCallConverters() {
ArrayList<CallConverter> callConverters = new ArrayList<>();
callConverters.add(new FieldSelectionConverter(typeConverter));
callConverters.add(CallConverters.CASE);
callConverters.add(CallConverters.ROW);
callConverters.add(CallConverters.CAST.apply(typeConverter));
callConverters.add(CallConverters.REINTERPRET.apply(typeConverter));
callConverters.add(new SqlArrayValueConstructorCallConverter(typeConverter));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
import io.substrait.expression.Expression;
import io.substrait.expression.ExpressionCreator;
import io.substrait.isthmus.CallConverter;
import io.substrait.isthmus.SubstraitRelNodeConverter;
import io.substrait.isthmus.TypeConverter;
import io.substrait.type.Type;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
Expand Down Expand Up @@ -44,16 +47,18 @@ public class CallConverters {
* {@link SqlKind#REINTERPRET} is utilized by Isthmus to represent and store {@link
* Expression.UserDefinedLiteral}s within Calcite.
*
* <p>When converting from Substrait to Calcite, the {@link
* Expression.UserDefinedAnyLiteral#value()} is stored within a {@link
* org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link org.apache.calcite.rex.RexLiteral} and
* then re-interpreted to have the correct type.
* <p>When converting from Substrait to Calcite, the user-defined literal value is stored either
* as a {@link org.apache.calcite.sql.type.SqlTypeName#BINARY} {@link
* org.apache.calcite.rex.RexLiteral} (for ANY-encoded values) or a {@link SqlKind#ROW} (for
* struct-encoded values) and then re-interpreted to have the correct user-defined type.
*
* <p>See {@link ExpressionRexConverter#visit(Expression.UserDefinedAnyLiteral,
* SubstraitRelNodeConverter.Context)} and {@link
* ExpressionRexConverter#visit(Expression.UserDefinedStructLiteral,
* SubstraitRelNodeConverter.Context)} for this conversion.
*
* <p>When converting from Calcite to Substrait, this call converter extracts the {@link
* Expression.UserDefinedAnyLiteral} that was stored.
* <p>When converting from Calcite to Substrait, this call converter extracts the stored {@link
* Expression.UserDefinedLiteral}.
*/
public static Function<TypeConverter, SimpleCallConverter> REINTERPRET =
typeConverter ->
Expand All @@ -64,8 +69,7 @@ public class CallConverters {
Expression operand = visitor.apply(call.getOperands().get(0));
Type type = typeConverter.toSubstrait(call.getType());

// For now, we only support handling of SqlKind.REINTEPRETET for the case of stored
// user-defined literals
// Calcite encoded Expression.UserDefinedAnyLiteral
if (operand instanceof Expression.FixedBinaryLiteral
&& type instanceof Type.UserDefined) {
Expression.FixedBinaryLiteral literal = (Expression.FixedBinaryLiteral) operand;
Expand All @@ -87,19 +91,64 @@ public class CallConverters {
throw new IllegalStateException("Failed to parse UserDefinedAnyLiteral value", e);
}
}
// Calcite encoded Expression.UserDefinedStructLiteral
else if (operand instanceof Expression.StructLiteral
&& type instanceof Type.UserDefined) {
Expression.StructLiteral structLiteral = (Expression.StructLiteral) operand;
Type.UserDefined t = (Type.UserDefined) type;

return Expression.UserDefinedStructLiteral.builder()
.nullable(t.nullable())
.urn(t.urn())
.name(t.name())
.addAllTypeParameters(t.typeParameters())
.addAllFields(structLiteral.fields())
.build();
}
return null;
};

// public static SimpleCallConverter OrAnd(FunctionConverter c) {
// return (call, visitor) -> {
// if (call.getKind() != SqlKind.AND && call.getKind() != SqlKind.OR) {
// return null;
// }
//
//
// return null;
// };
// }
/**
 * Converts Calcite ROW constructors into Substrait {@link Expression.StructLiteral}s.
 *
 * <p>Calls of any other kind are declined by returning {@code null}. Every operand must convert
 * to an {@link Expression.Literal}; otherwise an {@link IllegalArgumentException} is thrown.
 *
 * <p>The produced struct literal is always created with nullable=false: a ROW is a concrete
 * value, and an actual null is represented by a NullLiteral rather than a StructLiteral. For
 * user-defined struct literals, the struct-level nullability is carried by the REINTERPRET
 * target type instead.
 *
 * <p>Each field literal is re-created with the nullability of the matching field in the ROW
 * call's type.
 */
public static SimpleCallConverter ROW =
    (call, visitor) -> {
      if (call.getKind() != SqlKind.ROW) {
        return null;
      }

      // Convert every operand first, then validate, so that the literal check happens only
      // after all operands have been visited.
      List<Expression> converted = new ArrayList<>();
      for (RexNode operand : call.getOperands()) {
        converted.add(visitor.apply(operand));
      }
      for (Expression candidate : converted) {
        if (!(candidate instanceof Expression.Literal)) {
          throw new IllegalArgumentException("ROW operands must be literals.");
        }
      }

      // Field nullability is dictated by the field types of the ROW call's type, so align
      // each literal with the field at the same position.
      List<RelDataTypeField> rowFields = call.getType().getFieldList();
      List<Expression.Literal> structFields = new ArrayList<>(converted.size());
      for (int position = 0; position < converted.size(); position++) {
        Expression.Literal field = (Expression.Literal) converted.get(position);
        boolean nullableField = rowFields.get(position).getType().isNullable();
        structFields.add(field.withNullable(nullableField));
      }

      // The struct value itself is concrete, hence nullable=false.
      return ExpressionCreator.struct(false, structFields);
    };

/** */
public static SimpleCallConverter CASE =
(call, visitor) -> {
Expand All @@ -112,7 +161,7 @@ public class CallConverters {
assert call.getOperands().size() % 2 == 1;

List<Expression> caseArgs =
call.getOperands().stream().map(visitor).collect(java.util.stream.Collectors.toList());
call.getOperands().stream().map(visitor).collect(Collectors.toList());

int last = caseArgs.size() - 1;
// for if/else, process in reverse to maintain query order
Expand Down Expand Up @@ -150,6 +199,7 @@ public static List<CallConverter> defaults(TypeConverter typeConverter) {
return ImmutableList.of(
new FieldSelectionConverter(typeConverter),
CallConverters.CASE,
CallConverters.ROW,
CallConverters.CAST.apply(typeConverter),
CallConverters.REINTERPRET.apply(typeConverter),
new SqlArrayValueConstructorCallConverter(typeConverter),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,12 @@ public RexNode visit(Expression.UserDefinedAnyLiteral expr, Context context)
/**
 * Converts a struct-encoded user-defined literal into a Calcite expression.
 *
 * <p>The struct is only the encoding of the user-defined value: the fields are wrapped in a ROW
 * built by {@code toStructEncoding}, and that ROW is re-interpreted as the Calcite type obtained
 * from the literal's own type. Nullability of the user-defined value is therefore carried by the
 * REINTERPRET target type, i.e. {@code REINTERPRET(ROW(...), udt{nullable=true/false})}.
 *
 * @param expr the user-defined struct literal to convert
 * @param context the conversion context passed on to the field conversions
 * @return a reinterpret cast of the ROW encoding to the user-defined type
 */
@Override
public RexNode visit(Expression.UserDefinedStructLiteral expr, Context context)
    throws RuntimeException {
  // The former unconditional UnsupportedOperationException is gone: it made everything below
  // unreachable, and the struct representation is now supported.
  RelDataType type = typeConverter.toCalcite(typeFactory, expr.getType());
  RexNode structValue = toStructEncoding(expr.fields(), context);
  // The third argument is the reinterpret's overflow-check flag, passed as a FALSE literal.
  return rexBuilder.makeReinterpretCast(type, structValue, rexBuilder.makeLiteral(false));
}

@Override
Expand Down Expand Up @@ -320,6 +324,14 @@ public RexNode visit(Expression.DecimalLiteral expr, Context context) throws Run
return rexBuilder.makeLiteral(decimal, typeConverter.toCalcite(typeFactory, expr.getType()));
}

/**
 * Converts a struct literal into a Calcite ROW call.
 *
 * <p>Each field is converted with this visitor, and the resulting call is typed with the
 * Calcite equivalent of the struct literal's own type.
 *
 * @param expr the struct literal to convert
 * @param context the conversion context passed on to the field conversions
 * @return a ROW call over the converted fields
 */
@Override
public RexNode visit(Expression.StructLiteral expr, Context context) throws RuntimeException {
  // Fields are converted before the row type is resolved.
  List<RexNode> rowOperands =
      expr.fields().stream()
          .map(field -> field.accept(this, context))
          .collect(Collectors.toList());
  return rexBuilder.makeCall(
      typeConverter.toCalcite(typeFactory, expr.getType()),
      SqlStdOperatorTable.ROW,
      rowOperands);
}

@Override
public RexNode visit(Expression.ListLiteral expr, Context context) throws RuntimeException {
List<RexNode> args =
Expand Down Expand Up @@ -723,4 +735,35 @@ public RexNode visit(SetPredicate expr, Context context) throws RuntimeException
"Cannot handle SetPredicate when PredicateOp is %s.", expr.predicateOp().name()));
}
}

/**
 * Builds the Calcite ROW expression that encodes the fields of a user-defined struct literal.
 *
 * <p>Intended for {@link Expression.UserDefinedStructLiteral}, where the struct is merely the
 * encoding of the user-defined value. The ROW type is assembled here from the individual field
 * types; nullability of the user-defined value is carried by the REINTERPRET target type rather
 * than by this ROW.
 *
 * <p>For regular {@link Expression.StructLiteral}, use the struct's own type via {@code
 * expr.getType()} instead.
 *
 * @param fields the field literals, in positional order
 * @param context the conversion context passed on to the field conversions
 * @return a ROW call over the converted fields
 */
private RexNode toStructEncoding(List<? extends Expression.Literal> fields, Context context) {
  List<RexNode> rowOperands =
      fields.stream().map(field -> field.accept(this, context)).collect(Collectors.toList());

  // Calcite's ROW type requires field names, so positional placeholders ("field0", "field1",
  // ...) are used. Substrait struct literals are position-based, so these names are discarded
  // when converting back and only the field values are preserved.
  RelDataTypeFactory.Builder rowType = typeFactory.builder();
  for (int position = 0; position < fields.size(); position++) {
    // Each field keeps the type (and thus the nullability) of its own literal.
    RelDataType fieldType = typeConverter.toCalcite(typeFactory, fields.get(position).getType());
    rowType.add("field" + position, fieldType);
  }

  return rexBuilder.makeCall(rowType.build(), SqlStdOperatorTable.ROW, rowOperands);
}
}
69 changes: 69 additions & 0 deletions isthmus/src/test/java/io/substrait/isthmus/CalciteLiteralTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.nio.charset.StandardCharsets;
import java.time.LocalDate;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.TimeUnit;
import org.apache.calcite.rex.RexLiteral;
Expand Down Expand Up @@ -388,6 +389,74 @@ void tStruct() {
false));
}

@Test
void tStructRoundtripNullableFields() {
  // A non-nullable struct whose two i32 fields are both nullable.
  Expression.StructLiteral original =
      ExpressionCreator.struct(
          false, ExpressionCreator.i32(true, 4), ExpressionCreator.i32(true, -1));

  // Substrait -> Calcite -> Substrait must reproduce the original literal.
  RexNode calciteNode = original.accept(expressionRexConverter, Context.newContext());
  assertEquals(original, calciteNode.accept(rexExpressionConverter));
}

@Test
void tStructRoundtripMixedFieldNullability() {
  // A non-nullable struct with one nullable and one non-nullable i32 field.
  Expression.StructLiteral original =
      ExpressionCreator.struct(
          false, ExpressionCreator.i32(true, 4), ExpressionCreator.i32(false, -1));

  // Substrait -> Calcite -> Substrait must reproduce the original literal.
  RexNode calciteNode = original.accept(expressionRexConverter, Context.newContext());
  assertEquals(original, calciteNode.accept(rexExpressionConverter));
}

@Test
void tStructRoundtripWithNullFieldValues() {
  // The first field is an actual NULL of nullable i32 type; the second is a concrete value.
  Expression.NullLiteral nullI32 =
      ExpressionCreator.typedNull(io.substrait.type.Type.I32.builder().nullable(true).build());
  Expression.StructLiteral original =
      ExpressionCreator.struct(false, nullI32, ExpressionCreator.i32(false, 100));

  // Substrait -> Calcite -> Substrait must reproduce the original literal.
  RexNode calciteNode = original.accept(expressionRexConverter, Context.newContext());
  assertEquals(original, calciteNode.accept(rexExpressionConverter));
}

@Test
void tStructRoundtripNested() {
  // A struct whose first field is itself a struct literal.
  Expression.StructLiteral inner =
      ExpressionCreator.struct(
          false, ExpressionCreator.i32(false, 1), ExpressionCreator.i32(false, 2));
  Expression.StructLiteral outer =
      ExpressionCreator.struct(false, inner, ExpressionCreator.i32(false, 3));

  // Substrait -> Calcite -> Substrait must reproduce the original literal.
  RexNode calciteNode = outer.accept(expressionRexConverter, Context.newContext());
  assertEquals(outer, calciteNode.accept(rexExpressionConverter));
}

@Test
void tStructRoundtripEmpty() {
  // A struct literal without any fields.
  Expression.StructLiteral original = ExpressionCreator.struct(false, Collections.emptyList());

  // Substrait -> Calcite -> Substrait must reproduce the original literal.
  RexNode calciteNode = original.accept(expressionRexConverter, Context.newContext());
  assertEquals(original, calciteNode.accept(rexExpressionConverter));
}

@Test
void tFixedBinary() {
byte[] val = "my test".getBytes(StandardCharsets.UTF_8);
Expand Down
Loading