diff --git a/paimon-trino-358/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java b/paimon-trino-358/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java index 8a57210..a89dbd6 100644 --- a/paimon-trino-358/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java +++ b/paimon-trino-358/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java @@ -35,6 +35,6 @@ public ConnectorSplitSource getSplits( ConnectorTableHandle table, SplitSchedulingStrategy splitSchedulingStrategy, DynamicFilter dynamicFilter) { - return getSplits(table, session); + return getSplits(table, session, null); } } diff --git a/paimon-trino-368/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java b/paimon-trino-368/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java index 8a57210..a89dbd6 100644 --- a/paimon-trino-368/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java +++ b/paimon-trino-368/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java @@ -35,6 +35,6 @@ public ConnectorSplitSource getSplits( ConnectorTableHandle table, SplitSchedulingStrategy splitSchedulingStrategy, DynamicFilter dynamicFilter) { - return getSplits(table, session); + return getSplits(table, session, null); } } diff --git a/paimon-trino-369/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java b/paimon-trino-369/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java index 8a57210..a89dbd6 100644 --- a/paimon-trino-369/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java +++ b/paimon-trino-369/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java @@ -35,6 +35,6 @@ public ConnectorSplitSource getSplits( ConnectorTableHandle table, SplitSchedulingStrategy splitSchedulingStrategy, DynamicFilter dynamicFilter) { - return getSplits(table, session); + return getSplits(table, session, null); } } diff --git a/paimon-trino-370/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java b/paimon-trino-370/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java index de6c091..09075f2 100644 --- a/paimon-trino-370/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java +++ b/paimon-trino-370/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java @@ -37,6 +37,6 @@ public ConnectorSplitSource getSplits( SplitSchedulingStrategy splitSchedulingStrategy, DynamicFilter dynamicFilter, Constraint constraint) { - return getSplits(table, session); + return getSplits(table, session, constraint); } } diff --git a/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoFilterConverter.java b/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoFilterConverter.java index a690fca..23dcfa7 100644 --- a/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoFilterConverter.java +++ b/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoFilterConverter.java @@ -18,52 +18,28 @@ package org.apache.paimon.trino; -import org.apache.paimon.data.BinaryString; -import org.apache.paimon.data.Decimal; -import org.apache.paimon.data.Timestamp; import org.apache.paimon.predicate.Predicate; import org.apache.paimon.predicate.PredicateBuilder; import org.apache.paimon.types.RowType; -import io.airlift.slice.Slice; import io.trino.spi.predicate.Domain; import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.type.ArrayType; -import io.trino.spi.type.BigintType; -import io.trino.spi.type.BooleanType; -import io.trino.spi.type.CharType; -import io.trino.spi.type.DateType; -import io.trino.spi.type.DecimalType; -import io.trino.spi.type.DoubleType; -import io.trino.spi.type.IntegerType; -import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.MapType; -import io.trino.spi.type.RealType; -import io.trino.spi.type.SmallintType; -import io.trino.spi.type.TinyintType; import io.trino.spi.type.Type; -import io.trino.spi.type.VarbinaryType; -import io.trino.spi.type.VarcharType; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.math.BigDecimal; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; -import static io.trino.spi.type.TimeType.TIME_MILLIS; -import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; -import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; -import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MILLISECOND; -import static java.lang.Float.intBitsToFloat; -import static java.lang.Math.toIntExact; -import static java.util.Objects.requireNonNull; import static org.apache.paimon.predicate.PredicateBuilder.and; import static org.apache.paimon.predicate.PredicateBuilder.or; +import static org.apache.paimon.trino.TrinoTypeUtils.convertTrinoValueToPaimon; /** Trino filter to flink predicate. */ public class TrinoFilterConverter { @@ -159,7 +135,7 @@ private Predicate toPredicate(int columnIndex, Type type, Domain domain) { List predicates = new ArrayList<>(); for (Range range : orderedRanges) { if (range.isSingleValue()) { - values.add(getLiteralValue(type, range.getLowBoundedValue())); + values.add(convertTrinoValueToPaimon(type, range.getLowBoundedValue())); } else { predicates.add(toPredicate(columnIndex, range)); } @@ -182,13 +158,13 @@ private Predicate toPredicate(int columnIndex, Range range) { Type type = range.getType(); if (range.isSingleValue()) { - Object value = getLiteralValue(type, range.getSingleValue()); + Object value = convertTrinoValueToPaimon(type, range.getSingleValue()); return builder.equal(columnIndex, value); } List conjuncts = new ArrayList<>(2); if (!range.isLowUnbounded()) { - Object low = getLiteralValue(type, range.getLowBoundedValue()); + Object low = convertTrinoValueToPaimon(type, range.getLowBoundedValue()); Predicate lowBound; if (range.isLowInclusive()) { lowBound = builder.greaterOrEqual(columnIndex, low); @@ -199,7 +175,7 @@ private Predicate toPredicate(int columnIndex, Range range) { } if (!range.isHighUnbounded()) { - Object high = getLiteralValue(type, range.getHighBoundedValue()); + Object high = convertTrinoValueToPaimon(type, range.getHighBoundedValue()); Predicate highBound; if (range.isHighInclusive()) { highBound = builder.lessOrEqual(columnIndex, high); @@ -211,83 +187,4 @@ private Predicate toPredicate(int columnIndex, Range range) { return and(conjuncts); } - - private Object getLiteralValue(Type type, Object trinoNativeValue) { - requireNonNull(trinoNativeValue, "trinoNativeValue is null"); - - if (type instanceof BooleanType) { - return trinoNativeValue; - } - - if (type instanceof TinyintType) { - return ((Long) trinoNativeValue).byteValue(); - } - - if (type instanceof SmallintType) { - return ((Long) trinoNativeValue).shortValue(); - } - - if (type instanceof IntegerType) { - return toIntExact((long) trinoNativeValue); - } - - if (type instanceof BigintType) { - return trinoNativeValue; - } - - if (type instanceof RealType) { - return intBitsToFloat(toIntExact((long) trinoNativeValue)); - } - - if (type instanceof DoubleType) { - return trinoNativeValue; - } - - if (type instanceof DateType) { - return toIntExact(((Long) trinoNativeValue)); - } - - if (type.equals(TIME_MILLIS)) { - return (int) ((long) trinoNativeValue / PICOSECONDS_PER_MILLISECOND); - } - - if (type.equals(TIMESTAMP_MILLIS)) { - return Timestamp.fromEpochMillis((long) trinoNativeValue / 1000); - } - - if (type.equals(TIMESTAMP_TZ_MILLIS)) { - if (trinoNativeValue instanceof Long) { - return trinoNativeValue; - } - return Timestamp.fromEpochMillis( - ((LongTimestampWithTimeZone) trinoNativeValue).getEpochMillis()); - } - - if (type instanceof VarcharType || type instanceof CharType) { - return BinaryString.fromBytes(((Slice) trinoNativeValue).getBytes()); - } - - if (type instanceof VarbinaryType) { - return ((Slice) trinoNativeValue).getBytes(); - } - - if (type instanceof DecimalType) { - DecimalType decimalType = (DecimalType) type; - BigDecimal bigDecimal; - if (trinoNativeValue instanceof Long) { - bigDecimal = - BigDecimal.valueOf((long) trinoNativeValue) - .movePointLeft(decimalType.getScale()); - } else { - bigDecimal = - new BigDecimal( - DecimalUtils.toBigInteger(trinoNativeValue), - decimalType.getScale()); - } - return Decimal.fromBigDecimal( - bigDecimal, decimalType.getPrecision(), decimalType.getScale()); - } - - throw new UnsupportedOperationException("Unsupported type: " + type); - } } diff --git a/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java b/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java index e942641..6e0d86a 100644 --- a/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java +++ b/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoSplitManager.java @@ -36,6 +36,6 @@ public ConnectorSplitSource getSplits( ConnectorTableHandle table, DynamicFilter dynamicFilter, Constraint constraint) { - return getSplits(table, session); + return getSplits(table, session, constraint); } } diff --git a/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoSplitManagerBase.java b/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoSplitManagerBase.java index 61a26d2..0e33de2 100644 --- a/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoSplitManagerBase.java +++ b/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoSplitManagerBase.java @@ -18,25 +18,43 @@ package org.apache.paimon.trino; +import org.apache.paimon.annotation.VisibleForTesting; +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.shade.guava30.com.google.common.collect.Sets; import org.apache.paimon.table.Table; +import org.apache.paimon.table.source.DataSplit; import org.apache.paimon.table.source.ReadBuilder; import org.apache.paimon.table.source.Split; +import org.apache.paimon.types.RowType; +import org.apache.paimon.utils.RowDataToObjectArrayConverter; +import org.apache.paimon.utils.TypeUtils; +import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorSplitManager; import io.trino.spi.connector.ConnectorSplitSource; import io.trino.spi.connector.ConnectorTableHandle; +import io.trino.spi.connector.Constraint; +import io.trino.spi.predicate.NullableValue; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import java.util.HashMap; +import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; /** Trino {@link ConnectorSplitManager}. */ public abstract class TrinoSplitManagerBase implements ConnectorSplitManager { + private static final Logger LOG = LoggerFactory.getLogger(TrinoSplitManagerBase.class); + protected ConnectorSplitSource getSplits( - ConnectorTableHandle connectorTableHandle, ConnectorSession session) { + ConnectorTableHandle connectorTableHandle, + ConnectorSession session, + Constraint constraint) { // TODO dynamicFilter? - // TODO what is constraint? TrinoTableHandle tableHandle = (TrinoTableHandle) connectorTableHandle; Table table = tableHandle.tableWithDynamicOptions(session); @@ -47,6 +65,14 @@ protected ConnectorSplitSource getSplits( tableHandle.getLimit().ifPresent(limit -> readBuilder.withLimit((int) limit)); List splits = readBuilder.newScan().plan().splits(); + // Filter partition with trino function, suck as length(partition_column) > 10; + RowType partitionType = TypeUtils.project(table.rowType(), table.partitionKeys()); + List partitionColumnHandles = + table.partitionKeys().stream() + .map(tableHandle::columnHandle) + .collect(Collectors.toList()); + splits = filterByPartition(constraint, partitionColumnHandles, partitionType, splits); + long maxRowCount = splits.stream().mapToLong(Split::rowCount).max().orElse(0L); double minimumSplitWeight = TrinoSessionProperties.getMinimumSplitWeight(session); return new TrinoSplitSource( @@ -63,4 +89,54 @@ protected ConnectorSplitSource getSplits( 1.0))) .collect(Collectors.toList())); } + + @VisibleForTesting + static List filterByPartition( + Constraint constraint, + List parititonColumnHandles, + RowType partitionType, + List splits) { + if (!(constraint == null + || parititonColumnHandles.isEmpty() + || constraint.predicate().isEmpty() + || Sets.intersection( + constraint.getPredicateColumns().orElseThrow(), + new HashSet<>(parititonColumnHandles)) + .isEmpty())) { + RowDataToObjectArrayConverter rowDataToObjectArrayConverter = + new RowDataToObjectArrayConverter(partitionType); + return splits.stream() + .filter( + split -> { + if (!(split instanceof DataSplit)) { + return true; + } + BinaryRow partition = ((DataSplit) split).partition(); + Map bindings = new HashMap<>(); + Object[] partitionObject = + rowDataToObjectArrayConverter.convert(partition); + for (int i = 0; i < parititonColumnHandles.size(); i++) { + TrinoColumnHandle trinoColumnHandle = + parititonColumnHandles.get(i); + try { + bindings.put( + trinoColumnHandle, + NullableValue.of( + trinoColumnHandle.getTrinoType(), + TrinoTypeUtils.convertPaimonValueToTrino( + trinoColumnHandle.logicalType(), + partitionObject[i]))); + } catch (UnsupportedOperationException e) { + LOG.warn( + "Unsupported predicate, maybe the type of column is not supported yet.", + e); + return true; + } + } + return constraint.predicate().get().test(bindings); + }) + .collect(Collectors.toList()); + } + return splits; + } } diff --git a/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoTypeUtils.java b/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoTypeUtils.java index 5c7f8d8..7f906e7 100644 --- a/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoTypeUtils.java +++ b/paimon-trino-common/src/main/java/org/apache/paimon/trino/TrinoTypeUtils.java @@ -18,6 +18,9 @@ package org.apache.paimon.trino; +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.Timestamp; import org.apache.paimon.types.ArrayType; import org.apache.paimon.types.BigIntType; import org.apache.paimon.types.BinaryType; @@ -43,8 +46,12 @@ import org.apache.paimon.types.VarBinaryType; import org.apache.paimon.types.VarCharType; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; import io.trino.spi.type.BigintType; +import io.trino.spi.type.Decimals; import io.trino.spi.type.IntegerType; +import io.trino.spi.type.LongTimestampWithTimeZone; import io.trino.spi.type.RealType; import io.trino.spi.type.SmallintType; import io.trino.spi.type.TimestampWithTimeZoneType; @@ -54,11 +61,27 @@ import io.trino.spi.type.VarbinaryType; import io.trino.spi.type.VarcharType; +import java.math.BigDecimal; import java.util.List; import java.util.Optional; import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import static io.trino.spi.type.DateTimeEncoding.packDateTimeWithZone; +import static io.trino.spi.type.LongTimestampWithTimeZone.fromEpochMillisAndFraction; +import static io.trino.spi.type.TimeType.TIME_MILLIS; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_MILLISECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MILLISECOND; +import static java.lang.Float.floatToIntBits; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.multiplyExact; +import static java.lang.Math.toIntExact; +import static java.util.Objects.requireNonNull; + /** Trino type from Paimon Type. */ public class TrinoTypeUtils { @@ -70,6 +93,159 @@ public static DataType toPaimonType(Type trinoType) { return TrinoToPaimonTypeVistor.INSTANCE.visit(trinoType); } + public static Object convertTrinoValueToPaimon(Type type, Object trinoValue) { + requireNonNull(trinoValue, "trinoValue is null"); + + if (type instanceof io.trino.spi.type.BooleanType) { + return trinoValue; + } + + if (type instanceof TinyintType) { + return ((Long) trinoValue).byteValue(); + } + + if (type instanceof SmallintType) { + return ((Long) trinoValue).shortValue(); + } + + if (type instanceof IntegerType) { + return toIntExact((long) trinoValue); + } + + if (type instanceof BigintType) { + return trinoValue; + } + + if (type instanceof RealType) { + return intBitsToFloat(toIntExact((long) trinoValue)); + } + + if (type instanceof io.trino.spi.type.DoubleType) { + return trinoValue; + } + + if (type instanceof io.trino.spi.type.DateType) { + return toIntExact(((Long) trinoValue)); + } + + if (type.equals(TIME_MILLIS)) { + return (int) ((long) trinoValue / PICOSECONDS_PER_MILLISECOND); + } + + if (type.equals(TIMESTAMP_MILLIS)) { + return Timestamp.fromEpochMillis((long) trinoValue / 1000); + } + + if (type.equals(TIMESTAMP_TZ_MILLIS)) { + if (trinoValue instanceof Long) { + return trinoValue; + } + return Timestamp.fromEpochMillis( + ((LongTimestampWithTimeZone) trinoValue).getEpochMillis()); + } + + if (type instanceof VarcharType || type instanceof io.trino.spi.type.CharType) { + return BinaryString.fromBytes(((Slice) trinoValue).getBytes()); + } + + if (type instanceof VarbinaryType) { + return ((Slice) trinoValue).getBytes(); + } + + if (type instanceof io.trino.spi.type.DecimalType) { + io.trino.spi.type.DecimalType decimalType = (io.trino.spi.type.DecimalType) type; + BigDecimal bigDecimal; + if (trinoValue instanceof Long) { + bigDecimal = + BigDecimal.valueOf((long) trinoValue).movePointLeft(decimalType.getScale()); + } else { + bigDecimal = + new BigDecimal( + DecimalUtils.toBigInteger(trinoValue), decimalType.getScale()); + } + return Decimal.fromBigDecimal( + bigDecimal, decimalType.getPrecision(), decimalType.getScale()); + } + + throw new UnsupportedOperationException("Unsupported type: " + type); + } + + public static Object convertPaimonValueToTrino(DataType paimonType, Object paimonValue) { + if (paimonValue == null) { + return null; + } + if (paimonType instanceof BooleanType) { + //noinspection RedundantCast + return (boolean) paimonValue; + } + if (paimonType instanceof TinyIntType) { + return ((Number) paimonValue).longValue(); + } + if (paimonType instanceof SmallIntType) { + //noinspection RedundantCast + return ((Number) paimonValue).longValue(); + } + if (paimonType instanceof IntType) { + //noinspection RedundantCast + return ((Number) paimonValue).longValue(); + } + if (paimonType instanceof BigIntType) { + //noinspection RedundantCast + return ((Number) paimonValue).longValue(); + } + if (paimonType instanceof FloatType) { + return (long) floatToIntBits((float) paimonValue); + } + if (paimonType instanceof DoubleType) { + //noinspection RedundantCast + return ((Number) paimonValue).doubleValue(); + } + if (paimonType instanceof DecimalType) { + DecimalType paimonDecimalType = (DecimalType) paimonType; + Decimal decimal = (Decimal) paimonValue; + io.trino.spi.type.DecimalType trinoDecimalType = + io.trino.spi.type.DecimalType.createDecimalType( + paimonDecimalType.getPrecision(), paimonDecimalType.getScale()); + if (trinoDecimalType.isShort()) { + return Decimals.encodeShortScaledValue( + decimal.toBigDecimal(), trinoDecimalType.getScale()); + } + return Decimals.encodeScaledValue(decimal.toBigDecimal(), trinoDecimalType.getScale()); + } + if (paimonType instanceof VarBinaryType) { + return Slices.wrappedBuffer(((byte[]) paimonValue).clone()); + } + if (paimonType instanceof CharType || paimonType instanceof VarCharType) { + return Slices.utf8Slice(((BinaryString) paimonValue).toString()); + } + if (paimonType instanceof DateType) { + //noinspection RedundantCast + return (long) paimonValue; + } + if (paimonType instanceof TimeType) { + return multiplyExact((long) paimonValue, PICOSECONDS_PER_MICROSECOND); + } + if (paimonType instanceof TimestampType) { + TimestampType timestampType = (TimestampType) paimonType; + Timestamp timestamp = (Timestamp) paimonValue; + if (timestampType.getPrecision() == TimestampType.MIN_PRECISION + || timestampType.getPrecision() == TimestampType.DEFAULT_PRECISION) { + return timestamp.getMillisecond() * MICROSECONDS_PER_MILLISECOND; + } + return timestamp.toMicros(); + } + if (paimonType instanceof LocalZonedTimestampType) { + LocalZonedTimestampType timestampType = (LocalZonedTimestampType) paimonType; + Timestamp timestamp = (Timestamp) paimonValue; + if (timestampType.getPrecision() <= 3) { + return packDateTimeWithZone(timestamp.getMillisecond(), UTC_KEY); + } + return fromEpochMillisAndFraction(timestamp.getMillisecond(), 0, UTC_KEY); + } + + throw new UnsupportedOperationException("Unsupported iceberg type: " + paimonType); + } + private static class PaimonToTrinoTypeVistor extends DataTypeDefaultVisitor { private static final PaimonToTrinoTypeVistor INSTANCE = new PaimonToTrinoTypeVistor(); diff --git a/paimon-trino-common/src/test/java/org/apache/paimon/trino/TestTrinoITCase.java b/paimon-trino-common/src/test/java/org/apache/paimon/trino/TestTrinoITCase.java index d70e7ab..5919f24 100644 --- a/paimon-trino-common/src/test/java/org/apache/paimon/trino/TestTrinoITCase.java +++ b/paimon-trino-common/src/test/java/org/apache/paimon/trino/TestTrinoITCase.java @@ -51,6 +51,7 @@ import java.nio.file.Files; import java.time.Instant; +import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.Arrays; @@ -59,6 +60,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.stream.Collectors; import static io.airlift.testing.Closeables.closeAllSuppress; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -191,6 +193,66 @@ protected QueryRunner createQueryRunner() throws Exception { commit.commit(0, writer.prepareCommit(true, 0)); } + { + Path tablePath = new Path(warehouse, "default.db/table_partition_filter"); + RowType rowType = + new RowType( + Arrays.asList( + new DataField(0, "a", DataTypes.BOOLEAN()), + new DataField(1, "b", DataTypes.TINYINT()), + new DataField(2, "c", DataTypes.SMALLINT()), + new DataField(3, "d", DataTypes.INT()), + new DataField(4, "e", DataTypes.BIGINT()), + new DataField(5, "f", DataTypes.FLOAT()), + new DataField(6, "g", DataTypes.DOUBLE()), + new DataField(7, "h", DataTypes.VARCHAR(5)), + new DataField(8, "i", DataTypes.TIMESTAMP(6)), + new DataField( + 9, "j", DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)), + new DataField(10, "k", DataTypes.DECIMAL(10, 5)))); + new SchemaManager(LocalFileIO.create(), tablePath) + .createTable( + new Schema( + rowType.getFields(), + rowType.getFields().stream() + .map(DataField::name) + .collect(Collectors.toList()), + Collections.emptyList(), + new HashMap<>(), + "")); + FileStoreTable table = FileStoreTableFactory.create(LocalFileIO.create(), tablePath); + InnerTableWrite writer = table.newWrite("user"); + InnerTableCommit commit = table.newCommit("user"); + writer.write( + GenericRow.of( + true, + (byte) 1, + (short) 1, + 1, + 1L, + 1.0f, + 1.0d, + BinaryString.fromString("abc"), + Timestamp.fromLocalDateTime(LocalDateTime.of(2023, 1, 1, 0, 0, 0, 0)), + Timestamp.fromLocalDateTime(LocalDateTime.of(2023, 1, 1, 0, 0, 0, 0)), + Decimal.zero(10, 5))); + + writer.write( + GenericRow.of( + false, + (byte) 0, + (short) 0, + 0, + 0L, + 0.0f, + 0.0d, + BinaryString.fromString("abcd"), + Timestamp.fromLocalDateTime(LocalDateTime.of(2022, 1, 1, 0, 0, 0, 0)), + Timestamp.fromLocalDateTime(LocalDateTime.of(2022, 1, 1, 0, 0, 0, 0)), + Decimal.fromUnscaledLong(10000, 10, 5))); + commit.commit(0, writer.prepareCommit(true, 0)); + } + { Path tablePath6 = new Path(warehouse, "default.db/t99"); RowType rowType = @@ -378,6 +440,57 @@ public void testLimitWithPartition() { .isEqualTo("[[1, 1, 2, 2, 2]]"); } + @Test + public void testPartitionFilterWithFunction() { + assertThat( + sql( + "SELECT * FROM paimon.default.table_partition_filter where cast(a AS VARCHAR) = 'true'")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where b + 1 = 2")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where c + 1 = 2")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where d + 1 = 2")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where e + 1 = 2")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where f + 1 = 2")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where g + 1 = 2")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where length(h) = 3")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where year(i) = 2023")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat(sql("SELECT * FROM paimon.default.table_partition_filter where year(j) = 2023")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + + assertThat( + sql( + "SELECT * FROM paimon.default.table_partition_filter where k + 1 = DECIMAL '1.0'")) + .isEqualTo( + "[[true, 1, 1, 1, 1, 1.0, 1.0, abc, 2023-01-01T00:00, 2023-01-01T00:00Z[UTC], 0.00000]]"); + } + @Test public void testShowCreateTable() { assertThat(sql("SHOW CREATE TABLE paimon.default.t3")) @@ -424,7 +537,8 @@ public void testCreateTable() { + "changelog_producer = 'input'" + ")"); assertThat(sql("SHOW TABLES FROM paimon.default")) - .isEqualTo("[[empty_t], [orders], [t1], [t2], [t3], [t4], [t99]]"); + .isEqualTo( + "[[empty_t], [orders], [t1], [t2], [t3], [t4], [t99], [table_partition_filter]]"); sql("DROP TABLE IF EXISTS paimon.default.orders"); } @@ -447,7 +561,8 @@ public void testRenameTable() { + ")"); sql("ALTER TABLE paimon.default.t5 RENAME TO t6"); assertThat(sql("SHOW TABLES FROM paimon.default")) - .isEqualTo("[[empty_t], [t1], [t2], [t3], [t4], [t6], [t99]]"); + .isEqualTo( + "[[empty_t], [t1], [t2], [t3], [t4], [t6], [t99], [table_partition_filter]]"); sql("DROP TABLE IF EXISTS paimon.default.t6"); } @@ -470,7 +585,7 @@ public void testDropTable() { + ")"); sql("DROP TABLE IF EXISTS paimon.default.t5"); assertThat(sql("SHOW TABLES FROM paimon.default")) - .isEqualTo("[[empty_t], [t1], [t2], [t3], [t4], [t99]]"); + .isEqualTo("[[empty_t], [t1], [t2], [t3], [t4], [t99], [table_partition_filter]]"); } @Test diff --git a/paimon-trino-common/src/test/java/org/apache/paimon/trino/TestTrinoSplitManager.java b/paimon-trino-common/src/test/java/org/apache/paimon/trino/TestTrinoSplitManager.java new file mode 100644 index 0000000..fa26600 --- /dev/null +++ b/paimon-trino-common/src/test/java/org/apache/paimon/trino/TestTrinoSplitManager.java @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.paimon.trino; + +import org.apache.paimon.data.BinaryRow; +import org.apache.paimon.data.BinaryRowWriter; +import org.apache.paimon.data.BinaryString; +import org.apache.paimon.data.Decimal; +import org.apache.paimon.data.Timestamp; +import org.apache.paimon.table.source.DataSplit; +import org.apache.paimon.table.source.Split; +import org.apache.paimon.types.DataField; +import org.apache.paimon.types.DataTypes; +import org.apache.paimon.types.RowType; + +import io.airlift.slice.Slices; +import io.trino.spi.connector.Constraint; +import io.trino.spi.predicate.TupleDomain; +import org.testng.annotations.Test; + +import java.time.LocalDateTime; +import java.util.Arrays; +import java.util.List; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +import static java.lang.Float.floatToIntBits; +import static org.assertj.core.api.AssertionsForClassTypes.assertThat; + +/** Test for {@link TrinoSplitManager}. */ +public class TestTrinoSplitManager { + + @Test + public void testPartitionFilter() { + RowType partitionType = + new RowType( + Arrays.asList( + new DataField(0, "a", DataTypes.BOOLEAN()), + new DataField(1, "b", DataTypes.TINYINT()), + new DataField(2, "c", DataTypes.SMALLINT()), + new DataField(3, "d", DataTypes.INT()), + new DataField(4, "e", DataTypes.BIGINT()), + new DataField(5, "f", DataTypes.FLOAT()), + new DataField(6, "g", DataTypes.DOUBLE()), + new DataField(7, "h", DataTypes.CHAR(4)), + new DataField(8, "i", DataTypes.TIMESTAMP(6)), + new DataField(9, "j", DataTypes.TIMESTAMP_WITH_LOCAL_TIME_ZONE(3)), + new DataField(10, "k", DataTypes.DECIMAL(10, 5)))); + + BinaryRow partition1 = new BinaryRow(11); + { + BinaryRowWriter writer = new BinaryRowWriter(partition1); + writer.writeBoolean(0, true); + writer.writeByte(1, Byte.MAX_VALUE); + writer.writeShort(2, Short.MAX_VALUE); + writer.writeInt(3, Integer.MAX_VALUE); + writer.writeLong(4, Long.MAX_VALUE); + writer.writeFloat(5, Float.MAX_VALUE); + writer.writeDouble(6, Double.MAX_VALUE); + writer.writeString(7, BinaryString.fromString("char1")); + writer.writeTimestamp(8, Timestamp.fromLocalDateTime(LocalDateTime.MAX), 6); + writer.writeTimestamp(9, Timestamp.fromLocalDateTime(LocalDateTime.MAX), 6); + writer.writeDecimal(10, Decimal.zero(10, 5), 10); + writer.complete(); + } + + BinaryRow partition2 = new BinaryRow(11); + { + BinaryRowWriter writer = new BinaryRowWriter(partition2); + writer.writeBoolean(0, false); + writer.writeByte(1, Byte.MIN_VALUE); + writer.writeShort(2, Short.MIN_VALUE); + writer.writeInt(3, Integer.MIN_VALUE); + writer.writeLong(4, Long.MIN_VALUE); + writer.writeFloat(5, Float.MIN_VALUE); + writer.writeDouble(6, Double.MIN_VALUE); + writer.writeString(7, BinaryString.fromString("char2")); + writer.writeTimestamp(8, Timestamp.fromLocalDateTime(LocalDateTime.MIN), 6); + writer.writeTimestamp(9, Timestamp.fromLocalDateTime(LocalDateTime.MIN), 6); + writer.writeDecimal(10, Decimal.fromUnscaledLong(10000, 10, 5), 10); + writer.complete(); + } + + List splits = + Arrays.asList( + new DataSplit() { + @Override + public BinaryRow partition() { + return partition1; + } + }, + new DataSplit() { + @Override + public BinaryRow partition() { + return partition2; + } + }); + + List trinoColumnHandles = + partitionType.getFields().stream() + .map(dataField -> TrinoColumnHandle.of(dataField.name(), dataField.type())) + .collect(Collectors.toList()); + + BiConsumer> checker = + (trinoColumnHandle, predicate) -> { + Constraint constraint = + new Constraint( + TupleDomain.all(), + columnHandleNullableValueMap -> + predicate.test( + columnHandleNullableValueMap + .get(trinoColumnHandle) + .getValue()), + Set.of(trinoColumnHandle)); + List filteredSplits = + TrinoSplitManager.filterByPartition( + constraint, trinoColumnHandles, partitionType, splits); + assertThat(filteredSplits.size()).isEqualTo(1); + }; + + // test boolean + checker.accept(trinoColumnHandles.get(0), value -> value.equals(true)); + + // test tinyint + checker.accept(trinoColumnHandles.get(1), value -> (long) value == 127); + + // test smallint + checker.accept(trinoColumnHandles.get(2), value -> (long) value > 0); + + // test int + checker.accept(trinoColumnHandles.get(3), value -> (long) value > 0); + + // test bigint + checker.accept(trinoColumnHandles.get(4), value -> (long) value > 0); + + // test float + checker.accept( + trinoColumnHandles.get(5), + value -> (long) value == floatToIntBits(Float.MAX_VALUE)); + + // test double + checker.accept(trinoColumnHandles.get(6), value -> (double) value == Double.MAX_VALUE); + + // test char + checker.accept(trinoColumnHandles.get(7), value -> value.equals(Slices.utf8Slice("char1"))); + + // test timestamp + checker.accept(trinoColumnHandles.get(8), value -> (long) value > 0); + + // test timestamp with local time zone + checker.accept(trinoColumnHandles.get(9), value -> (long) value == 1829587348619264L); + + // test decimal + checker.accept(trinoColumnHandles.get(10), value -> (long) value == 0); + } +}