Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,13 @@
/**
* Custom mapping for the Calcite TRIM function to various Substrait functions. The first TRIM
* operand indicates the Substrait function to which it should be mapped. The first operand is then
* omitted from the arguments supplied to the Substrait function.
* omitted from the arguments supplied to the Substrait function. The second and third operands are
* swapped in order.
*
* <ul>
* <li>TRIM('BOTH', characters, string) -> trim(characters, string)
* <li>TRIM('LEADING', characters, string) -> ltrim(characters, string)
* <li>TRIM('TRAILING', .characters, string) -> rtrim(characters, string)
* <li>TRIM('BOTH', characters, string) -> trim(string, characters)
* <li>TRIM('LEADING', characters, string) -> ltrim(string, characters)
* <li>TRIM('TRAILING', characters, string) -> rtrim(string, characters)
* </ul>
*/
final class TrimFunctionMapper implements ScalarFunctionMapper {
Expand Down Expand Up @@ -99,8 +100,11 @@ public Optional<SubstraitFunctionMapping> toSubstrait(final RexCall call) {
}

String name = trim.substraitName();
List<RexNode> operands =
call.getOperands().stream().skip(1).collect(Collectors.toUnmodifiableList());
List<RexNode> operands = call.getOperands().stream().skip(1).collect(Collectors.toList());

// Substrait expects (string, characters) while Calcite has (characters, string)
Collections.swap(operands, 0, 1);

return new SubstraitFunctionMapping(name, operands, functions);
});
}
Expand Down Expand Up @@ -129,7 +133,9 @@ public Optional<List<FunctionArg>> getExpressionArguments(
.map(EnumArg::of)
.map(
trimTypeArg -> {
LinkedList args = new LinkedList<>(expression.arguments());
LinkedList<FunctionArg> args = new LinkedList<>(expression.arguments());
// Substrait expects (string, characters) while Calcite has (characters, string)
Collections.swap(args, 0, 1);
args.addFirst(trimTypeArg);
return args;
});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package io.substrait.isthmus.expression;

import static org.junit.jupiter.api.Assertions.assertEquals;

import io.substrait.expression.EnumArg;
import io.substrait.expression.Expression.ScalarFunctionInvocation;
import io.substrait.expression.Expression.StrLiteral;
import io.substrait.expression.FunctionArg;
import io.substrait.extension.DefaultExtensionCatalog;
import io.substrait.isthmus.PlanTestBase;
import io.substrait.type.TypeCreator;
import java.util.List;
import java.util.Optional;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.fun.SqlTrimFunction;
import org.junit.jupiter.api.Test;

/**
 * Checks the operand order produced by {@link TrimFunctionMapper} in both directions: a Calcite
 * TRIM call is built as (flag, characters, string), while the Substrait trim, ltrim and rtrim
 * invocations are built as (string, characters).
 */
class TrimFunctionMapperTest extends PlanTestBase {
  final TrimFunctionMapper trimFunctionMapper =
      new TrimFunctionMapper(DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions());

  final RexBuilder rexBuilder = builder.getRexBuilder();

  @Test
  void calciteTrimBothArgumentOrder() {
    verifyCalciteToSubstrait(SqlTrimFunction.Flag.BOTH, "trim");
  }

  @Test
  void calciteTrimLeadingArgumentOrder() {
    verifyCalciteToSubstrait(SqlTrimFunction.Flag.LEADING, "ltrim");
  }

  @Test
  void calciteTrimTrailingArgumentOrder() {
    verifyCalciteToSubstrait(SqlTrimFunction.Flag.TRAILING, "rtrim");
  }

  @Test
  void substraitTrimArgumentOrder() {
    verifySubstraitToCalcite("trim:str_str", SqlTrimFunction.Flag.BOTH);
  }

  @Test
  void substraitLtrimArgumentOrder() {
    verifySubstraitToCalcite("ltrim:str_str", SqlTrimFunction.Flag.LEADING);
  }

  @Test
  void substraitRtrimArgumentOrder() {
    verifySubstraitToCalcite("rtrim:str_str", SqlTrimFunction.Flag.TRAILING);
  }

  /**
   * Builds a Calcite TRIM(flag, characters, string) call, maps it with the mapper under test and
   * asserts the resulting Substrait function name and the (string, characters) operand order.
   *
   * @param flag the Calcite trim flag placed as the first operand of the call
   * @param expectedName the Substrait function name the mapping is expected to report
   */
  private void verifyCalciteToSubstrait(
      final SqlTrimFunction.Flag flag, final String expectedName) {
    final RexNode trimChars = rexBuilder.makeLiteral(" ");
    final RexNode subject = rexBuilder.makeLiteral(" whitespace ");
    final List<RexNode> calciteOperands = List.of(rexBuilder.makeFlag(flag), trimChars, subject);
    final RexCall trimCall =
        (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.TRIM, calciteOperands);

    final Optional<SubstraitFunctionMapping> result = trimFunctionMapper.toSubstrait(trimCall);

    assertEquals(expectedName, result.get().substraitName());
    // the string operand must now precede the characters operand
    final List<RexNode> mappedOperands = result.get().operands();
    assertEquals(subject, mappedOperands.get(0));
    assertEquals(trimChars, mappedOperands.get(1));
  }

  /**
   * Builds a Substrait string-function invocation with arguments (string, characters), asks the
   * mapper for the expression arguments and asserts they come back as (flag, characters, string).
   *
   * @param functionKey the function key looked up in the string extension, e.g. "trim:str_str"
   * @param expectedFlag the trim flag whose name is expected as the leading enum argument
   */
  private void verifySubstraitToCalcite(
      final String functionKey, final SqlTrimFunction.Flag expectedFlag) {
    final StrLiteral trimChars = sb.str(" ");
    final StrLiteral subject = sb.str(" whitespace ");
    final ScalarFunctionInvocation invocation =
        sb.scalarFn(
            DefaultExtensionCatalog.FUNCTIONS_STRING,
            functionKey,
            TypeCreator.REQUIRED.STRING,
            subject,
            trimChars);

    final Optional<List<FunctionArg>> result =
        trimFunctionMapper.getExpressionArguments(invocation);

    assertEquals(EnumArg.of(expectedFlag.name()), result.get().get(0));
    // the characters argument must now precede the string argument
    assertEquals(trimChars, result.get().get(1));
    assertEquals(subject, result.get().get(2));
  }
}