Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ public ConnectorSplitSource getSplits(
ConnectorTableHandle table,
SplitSchedulingStrategy splitSchedulingStrategy,
DynamicFilter dynamicFilter) {
return getSplits(table, session);
return getSplits(table, session, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ public ConnectorSplitSource getSplits(
ConnectorTableHandle table,
SplitSchedulingStrategy splitSchedulingStrategy,
DynamicFilter dynamicFilter) {
return getSplits(table, session);
return getSplits(table, session, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,6 @@ public ConnectorSplitSource getSplits(
ConnectorTableHandle table,
SplitSchedulingStrategy splitSchedulingStrategy,
DynamicFilter dynamicFilter) {
return getSplits(table, session);
return getSplits(table, session, null);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ public ConnectorSplitSource getSplits(
SplitSchedulingStrategy splitSchedulingStrategy,
DynamicFilter dynamicFilter,
Constraint constraint) {
return getSplits(table, session);
return getSplits(table, session, constraint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,52 +18,28 @@

package org.apache.paimon.trino;

import org.apache.paimon.data.BinaryString;
import org.apache.paimon.data.Decimal;
import org.apache.paimon.data.Timestamp;
import org.apache.paimon.predicate.Predicate;
import org.apache.paimon.predicate.PredicateBuilder;
import org.apache.paimon.types.RowType;

import io.airlift.slice.Slice;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.Range;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.ArrayType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.CharType;
import io.trino.spi.type.DateType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.LongTimestampWithTimeZone;
import io.trino.spi.type.MapType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarbinaryType;
import io.trino.spi.type.VarcharType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;

import static io.trino.spi.type.TimeType.TIME_MILLIS;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MILLISECOND;
import static java.lang.Float.intBitsToFloat;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
import static org.apache.paimon.predicate.PredicateBuilder.and;
import static org.apache.paimon.predicate.PredicateBuilder.or;
import static org.apache.paimon.trino.TrinoTypeUtils.convertTrinoValueToPaimon;

/** Trino filter to flink predicate. */
public class TrinoFilterConverter {
Expand Down Expand Up @@ -159,7 +135,7 @@ private Predicate toPredicate(int columnIndex, Type type, Domain domain) {
List<Predicate> predicates = new ArrayList<>();
for (Range range : orderedRanges) {
if (range.isSingleValue()) {
values.add(getLiteralValue(type, range.getLowBoundedValue()));
values.add(convertTrinoValueToPaimon(type, range.getLowBoundedValue()));
} else {
predicates.add(toPredicate(columnIndex, range));
}
Expand All @@ -182,13 +158,13 @@ private Predicate toPredicate(int columnIndex, Range range) {
Type type = range.getType();

if (range.isSingleValue()) {
Object value = getLiteralValue(type, range.getSingleValue());
Object value = convertTrinoValueToPaimon(type, range.getSingleValue());
return builder.equal(columnIndex, value);
}

List<Predicate> conjuncts = new ArrayList<>(2);
if (!range.isLowUnbounded()) {
Object low = getLiteralValue(type, range.getLowBoundedValue());
Object low = convertTrinoValueToPaimon(type, range.getLowBoundedValue());
Predicate lowBound;
if (range.isLowInclusive()) {
lowBound = builder.greaterOrEqual(columnIndex, low);
Expand All @@ -199,7 +175,7 @@ private Predicate toPredicate(int columnIndex, Range range) {
}

if (!range.isHighUnbounded()) {
Object high = getLiteralValue(type, range.getHighBoundedValue());
Object high = convertTrinoValueToPaimon(type, range.getHighBoundedValue());
Predicate highBound;
if (range.isHighInclusive()) {
highBound = builder.lessOrEqual(columnIndex, high);
Expand All @@ -211,83 +187,4 @@ private Predicate toPredicate(int columnIndex, Range range) {

return and(conjuncts);
}

private Object getLiteralValue(Type type, Object trinoNativeValue) {
requireNonNull(trinoNativeValue, "trinoNativeValue is null");

if (type instanceof BooleanType) {
return trinoNativeValue;
}

if (type instanceof TinyintType) {
return ((Long) trinoNativeValue).byteValue();
}

if (type instanceof SmallintType) {
return ((Long) trinoNativeValue).shortValue();
}

if (type instanceof IntegerType) {
return toIntExact((long) trinoNativeValue);
}

if (type instanceof BigintType) {
return trinoNativeValue;
}

if (type instanceof RealType) {
return intBitsToFloat(toIntExact((long) trinoNativeValue));
}

if (type instanceof DoubleType) {
return trinoNativeValue;
}

if (type instanceof DateType) {
return toIntExact(((Long) trinoNativeValue));
}

if (type.equals(TIME_MILLIS)) {
return (int) ((long) trinoNativeValue / PICOSECONDS_PER_MILLISECOND);
}

if (type.equals(TIMESTAMP_MILLIS)) {
return Timestamp.fromEpochMillis((long) trinoNativeValue / 1000);
}

if (type.equals(TIMESTAMP_TZ_MILLIS)) {
if (trinoNativeValue instanceof Long) {
return trinoNativeValue;
}
return Timestamp.fromEpochMillis(
((LongTimestampWithTimeZone) trinoNativeValue).getEpochMillis());
}

if (type instanceof VarcharType || type instanceof CharType) {
return BinaryString.fromBytes(((Slice) trinoNativeValue).getBytes());
}

if (type instanceof VarbinaryType) {
return ((Slice) trinoNativeValue).getBytes();
}

if (type instanceof DecimalType) {
DecimalType decimalType = (DecimalType) type;
BigDecimal bigDecimal;
if (trinoNativeValue instanceof Long) {
bigDecimal =
BigDecimal.valueOf((long) trinoNativeValue)
.movePointLeft(decimalType.getScale());
} else {
bigDecimal =
new BigDecimal(
DecimalUtils.toBigInteger(trinoNativeValue),
decimalType.getScale());
}
return Decimal.fromBigDecimal(
bigDecimal, decimalType.getPrecision(), decimalType.getScale());
}

throw new UnsupportedOperationException("Unsupported type: " + type);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,6 @@ public ConnectorSplitSource getSplits(
ConnectorTableHandle table,
DynamicFilter dynamicFilter,
Constraint constraint) {
return getSplits(table, session);
return getSplits(table, session, constraint);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,43 @@

package org.apache.paimon.trino;

import org.apache.paimon.annotation.VisibleForTesting;
import org.apache.paimon.data.BinaryRow;
import org.apache.paimon.shade.guava30.com.google.common.collect.Sets;
import org.apache.paimon.table.Table;
import org.apache.paimon.table.source.DataSplit;
import org.apache.paimon.table.source.ReadBuilder;
import org.apache.paimon.table.source.Split;
import org.apache.paimon.types.RowType;
import org.apache.paimon.utils.RowDataToObjectArrayConverter;
import org.apache.paimon.utils.TypeUtils;

import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.connector.ConnectorSplitManager;
import io.trino.spi.connector.ConnectorSplitSource;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.Constraint;
import io.trino.spi.predicate.NullableValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/** Trino {@link ConnectorSplitManager}. */
public abstract class TrinoSplitManagerBase implements ConnectorSplitManager {

private static final Logger LOG = LoggerFactory.getLogger(TrinoSplitManagerBase.class);

protected ConnectorSplitSource getSplits(
ConnectorTableHandle connectorTableHandle, ConnectorSession session) {
ConnectorTableHandle connectorTableHandle,
ConnectorSession session,
Constraint constraint) {
// TODO dynamicFilter?
// TODO what is constraint?

TrinoTableHandle tableHandle = (TrinoTableHandle) connectorTableHandle;
Table table = tableHandle.tableWithDynamicOptions(session);
Expand All @@ -47,6 +65,14 @@ protected ConnectorSplitSource getSplits(
tableHandle.getLimit().ifPresent(limit -> readBuilder.withLimit((int) limit));
List<Split> splits = readBuilder.newScan().plan().splits();

// Filter partition with trino function, suck as length(partition_column) > 10;
RowType partitionType = TypeUtils.project(table.rowType(), table.partitionKeys());
List<TrinoColumnHandle> partitionColumnHandles =
table.partitionKeys().stream()
.map(tableHandle::columnHandle)
.collect(Collectors.toList());
splits = filterByPartition(constraint, partitionColumnHandles, partitionType, splits);

long maxRowCount = splits.stream().mapToLong(Split::rowCount).max().orElse(0L);
double minimumSplitWeight = TrinoSessionProperties.getMinimumSplitWeight(session);
return new TrinoSplitSource(
Expand All @@ -63,4 +89,54 @@ protected ConnectorSplitSource getSplits(
1.0)))
.collect(Collectors.toList()));
}

@VisibleForTesting
static List<Split> filterByPartition(
Constraint constraint,
List<TrinoColumnHandle> parititonColumnHandles,
RowType partitionType,
List<Split> splits) {
if (!(constraint == null
|| parititonColumnHandles.isEmpty()
|| constraint.predicate().isEmpty()
|| Sets.intersection(
constraint.getPredicateColumns().orElseThrow(),
new HashSet<>(parititonColumnHandles))
.isEmpty())) {
RowDataToObjectArrayConverter rowDataToObjectArrayConverter =
new RowDataToObjectArrayConverter(partitionType);
return splits.stream()
.filter(
split -> {
if (!(split instanceof DataSplit)) {
return true;
}
BinaryRow partition = ((DataSplit) split).partition();
Map<ColumnHandle, NullableValue> bindings = new HashMap<>();
Object[] partitionObject =
rowDataToObjectArrayConverter.convert(partition);
for (int i = 0; i < parititonColumnHandles.size(); i++) {
TrinoColumnHandle trinoColumnHandle =
parititonColumnHandles.get(i);
try {
bindings.put(
trinoColumnHandle,
NullableValue.of(
trinoColumnHandle.getTrinoType(),
TrinoTypeUtils.convertPaimonValueToTrino(
trinoColumnHandle.logicalType(),
partitionObject[i])));
} catch (UnsupportedOperationException e) {
LOG.warn(
"Unsupported predicate, maybe the type of column is not supported yet.",
e);
return true;
}
}
return constraint.predicate().get().test(bindings);
})
.collect(Collectors.toList());
}
return splits;
}
}
Loading