diff --git a/src/main/java/org/apache/paimon/trino/FixedBucketTableShuffleFunction.java b/src/main/java/org/apache/paimon/trino/FixedBucketTableShuffleFunction.java index 839ad0c..20856bc 100644 --- a/src/main/java/org/apache/paimon/trino/FixedBucketTableShuffleFunction.java +++ b/src/main/java/org/apache/paimon/trino/FixedBucketTableShuffleFunction.java @@ -44,18 +44,34 @@ public class FixedBucketTableShuffleFunction implements BucketFunction { private final int bucketCount; private final boolean isRowId; private final ThreadLocal projectionContext; + private final TableSchema schema; + private final List bucketKeys; // 🔧 改为通用的 bucketKeys public FixedBucketTableShuffleFunction( List partitionChannelTypes, TrinoPartitioningHandle partitioningHandle, int workerCount) { - TableSchema schema = partitioningHandle.getOriginalSchema(); - this.projectionContext = - ThreadLocal.withInitial( - () -> - CodeGenUtils.newProjection( - schema.logicalPrimaryKeysType(), schema.primaryKeys())); + this.schema = partitioningHandle.getOriginalSchema(); + + // 🔧 关键修改:根据是否分区表选择不同的 keys + List partitionKeys = schema.partitionKeys(); + if (!partitionKeys.isEmpty()) { + // 分区表:使用 partition keys + this.bucketKeys = partitionKeys; + this.projectionContext = + ThreadLocal.withInitial( + () -> + CodeGenUtils.newProjection( + schema.logicalPartitionType(), bucketKeys)); + } else { + // 非分区表:使用 primary keys + this.bucketKeys = schema.primaryKeys(); + this.projectionContext = + ThreadLocal.withInitial( + () -> CodeGenUtils.newProjection(schema.logicalRowType(), bucketKeys)); + } + this.bucketCount = new CoreOptions(schema.options()).bucket(); this.workerCount = workerCount; this.isRowId = @@ -65,23 +81,59 @@ public FixedBucketTableShuffleFunction( @Override public int getBucket(Page page, int position) { + Page processedPage = page; + + // 处理 RowBlock 的情况 if (isRowId) { RowBlock rowBlock = (RowBlock) page.getBlock(0); try { Method method = RowBlock.class.getDeclaredMethod("getRawFieldBlocks"); method.setAccessible(true); - page = new Page(rowBlock.getPositionCount(), (Block[]) method.invoke(rowBlock)); - } catch (NoSuchMethodException e) { - throw new RuntimeException(e); - } catch (InvocationTargetException e) { - throw new RuntimeException(e); - } catch (IllegalAccessException e) { - throw new RuntimeException(e); + Block[] rawBlocks = (Block[]) method.invoke(rowBlock); + processedPage = new Page(rowBlock.getPositionCount(), rawBlocks); + } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) { + throw new RuntimeException("Failed to extract raw field blocks from RowBlock", e); } } - TrinoRow trinoRow = new TrinoRow(page.getSingleValuePage(position), RowKind.INSERT); - BinaryRow pk = projectionContext.get().apply(trinoRow); + // 🔧 修改验证逻辑:验证 bucketKeys 数量 + int expectedBlockCount = bucketKeys.size(); + int actualBlockCount = processedPage.getChannelCount(); + + if (actualBlockCount != expectedBlockCount) { + throw new IllegalStateException( + String.format( + "Page block count mismatch: expected %d (bucket keys), but got %d. " + + "Bucket keys: %s, Partition keys: %s, Primary keys: %s, Schema fields: %s", + expectedBlockCount, + actualBlockCount, + bucketKeys, + schema.partitionKeys(), + schema.primaryKeys(), + schema.fieldNames())); + } + + // 使用 processedPage 创建 TrinoRow + TrinoRow trinoRow = + new TrinoRow(processedPage.getSingleValuePage(position), RowKind.INSERT); + + // 🔧 修改错误信息:显示 bucketKeys 相关信息 + BinaryRow pk; + try { + pk = projectionContext.get().apply(trinoRow); + } catch (IndexOutOfBoundsException e) { + throw new RuntimeException( + String.format( + "Failed to extract bucket keys from row. " + + "Row field count: %d, Bucket keys: %s, " + + "Page block count: %d, Position: %d", + trinoRow.getFieldCount(), + bucketKeys, + processedPage.getChannelCount(), + position), + e); + } + int bucket = KeyAndBucketExtractor.bucket( KeyAndBucketExtractor.bucketKeyHashCode(pk), bucketCount);