From d32b872b1ce54b34ea557837adfb1530d273a25d Mon Sep 17 00:00:00 2001 From: Niels Pardon Date: Sat, 28 Feb 2026 11:11:25 +0100 Subject: [PATCH] fix(isthmus): fix argument order for trim function mapping Signed-off-by: Niels Pardon --- .../expression/TrimFunctionMapper.java | 20 ++- .../expression/TrimFunctionMapperTest.java | 140 ++++++++++++++++++ 2 files changed, 153 insertions(+), 7 deletions(-) create mode 100644 isthmus/src/test/java/io/substrait/isthmus/expression/TrimFunctionMapperTest.java diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java b/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java index 06adbc662..87350a6c8 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/TrimFunctionMapper.java @@ -23,12 +23,13 @@ /** * Custom mapping for the Calcite TRIM function to various Substrait functions. The first TRIM * operand indicates the Substrait function to which it should be mapped. The first operand is then - * omitted from the arguments supplied to the Substrait function. + * omitted from the arguments supplied to the Substrait function. The second and third operands are + * swapped in order. * * */ final class TrimFunctionMapper implements ScalarFunctionMapper { @@ -99,8 +100,11 @@ public Optional toSubstrait(final RexCall call) { } String name = trim.substraitName(); - List operands = - call.getOperands().stream().skip(1).collect(Collectors.toUnmodifiableList()); + List operands = call.getOperands().stream().skip(1).collect(Collectors.toList()); + + // Substrait expects (string, characters) while Calcite has (characters, string) + Collections.swap(operands, 0, 1); + return new SubstraitFunctionMapping(name, operands, functions); }); } @@ -129,7 +133,9 @@ public Optional> getExpressionArguments( .map(EnumArg::of) .map( trimTypeArg -> { - LinkedList args = new LinkedList<>(expression.arguments()); + LinkedList args = new LinkedList<>(expression.arguments()); + // Substrait expects (string, characters) while Calcite has (characters, string) + Collections.swap(args, 0, 1); args.addFirst(trimTypeArg); return args; }); diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/TrimFunctionMapperTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/TrimFunctionMapperTest.java new file mode 100644 index 000000000..a71a2ca85 --- /dev/null +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/TrimFunctionMapperTest.java @@ -0,0 +1,140 @@ +package io.substrait.isthmus.expression; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import io.substrait.expression.EnumArg; +import io.substrait.expression.Expression.ScalarFunctionInvocation; +import io.substrait.expression.Expression.StrLiteral; +import io.substrait.expression.FunctionArg; +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.isthmus.PlanTestBase; +import io.substrait.type.TypeCreator; +import java.util.List; +import java.util.Optional; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.rex.RexCall; +import org.apache.calcite.rex.RexNode; +import org.apache.calcite.sql.fun.SqlStdOperatorTable; +import org.apache.calcite.sql.fun.SqlTrimFunction; +import org.junit.jupiter.api.Test; + +class TrimFunctionMapperTest extends PlanTestBase { + final TrimFunctionMapper trimFunctionMapper = + new TrimFunctionMapper(DefaultExtensionCatalog.DEFAULT_COLLECTION.scalarFunctions()); + + final RexBuilder rexBuilder = builder.getRexBuilder(); + + @Test + void calciteTrimBothArgumentOrder() { + final RexNode characters = rexBuilder.makeLiteral(" "); + final RexNode input = rexBuilder.makeLiteral(" whitespace "); + final RexNode trimBothRexCall = + rexBuilder.makeCall( + SqlStdOperatorTable.TRIM, + List.of(rexBuilder.makeFlag(SqlTrimFunction.Flag.BOTH), characters, input)); + + Optional substraitCall = + trimFunctionMapper.toSubstrait((RexCall) trimBothRexCall); + + assertEquals("trim", substraitCall.get().substraitName()); + // operands should be swapped now + assertEquals(input, substraitCall.get().operands().get(0)); + assertEquals(characters, substraitCall.get().operands().get(1)); + } + + @Test + void calciteTrimLeadingArgumentOrder() { + final RexNode characters = rexBuilder.makeLiteral(" "); + final RexNode input = rexBuilder.makeLiteral(" whitespace "); + final RexNode trimBothRexCall = + rexBuilder.makeCall( + SqlStdOperatorTable.TRIM, + List.of(rexBuilder.makeFlag(SqlTrimFunction.Flag.LEADING), characters, input)); + + Optional substraitCall = + trimFunctionMapper.toSubstrait((RexCall) trimBothRexCall); + + assertEquals("ltrim", substraitCall.get().substraitName()); + // operands should be swapped now + assertEquals(input, substraitCall.get().operands().get(0)); + assertEquals(characters, substraitCall.get().operands().get(1)); + } + + @Test + void calciteTrimTrailingArgumentOrder() { + final RexNode characters = rexBuilder.makeLiteral(" "); + final RexNode input = rexBuilder.makeLiteral(" whitespace "); + final RexNode trimBothRexCall = + rexBuilder.makeCall( + SqlStdOperatorTable.TRIM, + List.of(rexBuilder.makeFlag(SqlTrimFunction.Flag.TRAILING), characters, input)); + + Optional substraitCall = + trimFunctionMapper.toSubstrait((RexCall) trimBothRexCall); + + assertEquals("rtrim", substraitCall.get().substraitName()); + // operands should be swapped now + assertEquals(input, substraitCall.get().operands().get(0)); + assertEquals(characters, substraitCall.get().operands().get(1)); + } + + @Test + void substraitTrimArgumentOrder() { + final StrLiteral characters = sb.str(" "); + final StrLiteral input = sb.str(" whitespace "); + ScalarFunctionInvocation trimFn = + sb.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_STRING, + "trim:str_str", + TypeCreator.REQUIRED.STRING, + input, + characters); + + Optional> arguments = trimFunctionMapper.getExpressionArguments(trimFn); + + assertEquals(EnumArg.of(SqlTrimFunction.Flag.BOTH.name()), arguments.get().get(0)); + // arguments should be swapped now + assertEquals(characters, arguments.get().get(1)); + assertEquals(input, arguments.get().get(2)); + } + + @Test + void substraitLtrimArgumentOrder() { + final StrLiteral characters = sb.str(" "); + final StrLiteral input = sb.str(" whitespace "); + ScalarFunctionInvocation trimFn = + sb.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_STRING, + "ltrim:str_str", + TypeCreator.REQUIRED.STRING, + input, + characters); + + Optional> arguments = trimFunctionMapper.getExpressionArguments(trimFn); + + assertEquals(EnumArg.of(SqlTrimFunction.Flag.LEADING.name()), arguments.get().get(0)); + // arguments should be swapped now + assertEquals(characters, arguments.get().get(1)); + assertEquals(input, arguments.get().get(2)); + } + + @Test + void substraitRtrimArgumentOrder() { + final StrLiteral characters = sb.str(" "); + final StrLiteral input = sb.str(" whitespace "); + ScalarFunctionInvocation trimFn = + sb.scalarFn( + DefaultExtensionCatalog.FUNCTIONS_STRING, + "rtrim:str_str", + TypeCreator.REQUIRED.STRING, + input, + characters); + + Optional> arguments = trimFunctionMapper.getExpressionArguments(trimFn); + + assertEquals(EnumArg.of(SqlTrimFunction.Flag.TRAILING.name()), arguments.get().get(0)); + // arguments should be swapped now + assertEquals(characters, arguments.get().get(1)); + assertEquals(input, arguments.get().get(2)); + } +}