diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java
index 06adbc662..87350a6c8 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java
@@ -23,12 +23,13 @@
/**
* Custom mapping for the Calcite TRIM function to various Substrait functions. The first TRIM
* operand indicates the Substrait function to which it should be mapped. The first operand is then
- * omitted from the arguments supplied to the Substrait function.
+ * omitted from the arguments supplied to the Substrait function. The second and third operands are
+ * swapped in order.
*
*
- * - TRIM('BOTH', characters, string) -> trim(characters, string)
- *
- TRIM('LEADING', characters, string) -> ltrim(characters, string)
- *
- TRIM('TRAILING', .characters, string) -> rtrim(characters, string)
+ *
- TRIM('BOTH', characters, string) -> trim(string, characters)
+ *
- TRIM('LEADING', characters, string) -> ltrim(string, characters)
+ *
- TRIM('TRAILING', .characters, string) -> rtrim(string, characters)
*
*/
final class TrimFunctionMapper implements ScalarFunctionMapper {
@@ -99,8 +100,11 @@ public Optional toSubstrait(final RexCall call) {
}
String name = trim.substraitName();
- List operands =
- call.getOperands().stream().skip(1).collect(Collectors.toUnmodifiableList());
+ List operands = call.getOperands().stream().skip(1).collect(Collectors.toList());
+
+ // Substrait expects (string, characters) while Calcite has (characters, string)
+ Collections.swap(operands, 0, 1);
+
return new SubstraitFunctionMapping(name, operands, functions);
});
}
@@ -129,7 +133,9 @@ public Optional> getExpressionArguments(
.map(EnumArg::of)
.map(
trimTypeArg -> {
- LinkedList args = new LinkedList<>(expression.arguments());
+ LinkedList args = new LinkedList<>(expression.arguments());
+ // Substrait expects (string, characters) while Calcite has (characters, string)
+ Collections.swap(args, 0, 1);
args.addFirst(trimTypeArg);
return args;
});
diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/TrimFunctionMapperTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/TrimFunctionMapperTest.java
new file mode 100644
index 000000000..a71a2ca85
--- /dev/null
+++ b/isthmus/src/test/java/io/substrait/isthmus/expression/TrimFunctionMapperTest.java
@@ -0,0 +1,140 @@
+package io.substrait.isthmus.expression;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+import io.substrait.expression.EnumArg;
+import io.substrait.expression.Expression.ScalarFunctionInvocation;
+import io.substrait.expression.Expression.StrLiteral;
+import io.substrait.expression.FunctionArg;
+import io.substrait.extension.DefaultExtensionCatalog;
+import io.substrait.isthmus.PlanTestBase;
+import io.substrait.type.TypeCreator;
+import java.util.List;
+import java.util.Optional;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexCall;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.fun.SqlTrimFunction;
+import org.junit.jupiter.api.Test;
+
+class TrimFunctionMapperTest extends PlanTestBase {
+ final TrimFunctionMapper trimFunctionMapper =
+ new TrimFunctionMapper(DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions());
+
+ final RexBuilder rexBuilder = builder.getRexBuilder();
+
+ @Test
+ void calciteTrimBothArgumentOrder() {
+ final RexNode characters = rexBuilder.makeLiteral(" ");
+ final RexNode input = rexBuilder.makeLiteral(" whitespace ");
+ final RexNode trimBothRexCall =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.TRIM,
+ List.of(rexBuilder.makeFlag(SqlTrimFunction.Flag.BOTH), characters, input));
+
+ Optional substraitCall =
+ trimFunctionMapper.toSubstrait((RexCall) trimBothRexCall);
+
+ assertEquals("trim", substraitCall.get().substraitName());
+ // operands should be swapped now
+ assertEquals(input, substraitCall.get().operands().get(0));
+ assertEquals(characters, substraitCall.get().operands().get(1));
+ }
+
+ @Test
+ void calciteTrimLeadingArgumentOrder() {
+ final RexNode characters = rexBuilder.makeLiteral(" ");
+ final RexNode input = rexBuilder.makeLiteral(" whitespace ");
+ final RexNode trimBothRexCall =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.TRIM,
+ List.of(rexBuilder.makeFlag(SqlTrimFunction.Flag.LEADING), characters, input));
+
+ Optional substraitCall =
+ trimFunctionMapper.toSubstrait((RexCall) trimBothRexCall);
+
+ assertEquals("ltrim", substraitCall.get().substraitName());
+ // operands should be swapped now
+ assertEquals(input, substraitCall.get().operands().get(0));
+ assertEquals(characters, substraitCall.get().operands().get(1));
+ }
+
+ @Test
+ void calciteTrimTrailingArgumentOrder() {
+ final RexNode characters = rexBuilder.makeLiteral(" ");
+ final RexNode input = rexBuilder.makeLiteral(" whitespace ");
+ final RexNode trimBothRexCall =
+ rexBuilder.makeCall(
+ SqlStdOperatorTable.TRIM,
+ List.of(rexBuilder.makeFlag(SqlTrimFunction.Flag.TRAILING), characters, input));
+
+ Optional substraitCall =
+ trimFunctionMapper.toSubstrait((RexCall) trimBothRexCall);
+
+ assertEquals("rtrim", substraitCall.get().substraitName());
+ // operands should be swapped now
+ assertEquals(input, substraitCall.get().operands().get(0));
+ assertEquals(characters, substraitCall.get().operands().get(1));
+ }
+
+ @Test
+ void substraitTrimArgumentOrder() {
+ final StrLiteral characters = sb.str(" ");
+ final StrLiteral input = sb.str(" whitespace ");
+ ScalarFunctionInvocation trimFn =
+ sb.scalarFn(
+ DefaultExtensionCatalog.FUNCTIONS_STRING,
+ "trim:str_str",
+ TypeCreator.REQUIRED.STRING,
+ input,
+ characters);
+
+ Optional> arguments = trimFunctionMapper.getExpressionArguments(trimFn);
+
+ assertEquals(EnumArg.of(SqlTrimFunction.Flag.BOTH.name()), arguments.get().get(0));
+ // arguments should be swapped now
+ assertEquals(characters, arguments.get().get(1));
+ assertEquals(input, arguments.get().get(2));
+ }
+
+ @Test
+ void substraitLtrimArgumentOrder() {
+ final StrLiteral characters = sb.str(" ");
+ final StrLiteral input = sb.str(" whitespace ");
+ ScalarFunctionInvocation trimFn =
+ sb.scalarFn(
+ DefaultExtensionCatalog.FUNCTIONS_STRING,
+ "ltrim:str_str",
+ TypeCreator.REQUIRED.STRING,
+ input,
+ characters);
+
+ Optional> arguments = trimFunctionMapper.getExpressionArguments(trimFn);
+
+ assertEquals(EnumArg.of(SqlTrimFunction.Flag.LEADING.name()), arguments.get().get(0));
+ // arguments should be swapped now
+ assertEquals(characters, arguments.get().get(1));
+ assertEquals(input, arguments.get().get(2));
+ }
+
+ @Test
+ void substraitRtrimArgumentOrder() {
+ final StrLiteral characters = sb.str(" ");
+ final StrLiteral input = sb.str(" whitespace ");
+ ScalarFunctionInvocation trimFn =
+ sb.scalarFn(
+ DefaultExtensionCatalog.FUNCTIONS_STRING,
+ "rtrim:str_str",
+ TypeCreator.REQUIRED.STRING,
+ input,
+ characters);
+
+ Optional> arguments = trimFunctionMapper.getExpressionArguments(trimFn);
+
+ assertEquals(EnumArg.of(SqlTrimFunction.Flag.TRAILING.name()), arguments.get().get(0));
+ // arguments should be swapped now
+ assertEquals(characters, arguments.get().get(1));
+ assertEquals(input, arguments.get().get(2));
+ }
+}