Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,34 @@ public class FixedBucketTableShuffleFunction implements BucketFunction {
private final int bucketCount;
private final boolean isRowId;
private final ThreadLocal<Projection> projectionContext;
private final TableSchema schema;
private final List<String> bucketKeys; // 🔧 改为通用的 bucketKeys

public FixedBucketTableShuffleFunction(
List<Type> partitionChannelTypes,
TrinoPartitioningHandle partitioningHandle,
int workerCount) {

TableSchema schema = partitioningHandle.getOriginalSchema();
this.projectionContext =
ThreadLocal.withInitial(
() ->
CodeGenUtils.newProjection(
schema.logicalPrimaryKeysType(), schema.primaryKeys()));
this.schema = partitioningHandle.getOriginalSchema();

// 🔧 关键修改:根据是否分区表选择不同的 keys
List<String> partitionKeys = schema.partitionKeys();
if (!partitionKeys.isEmpty()) {
// 分区表:使用 partition keys
this.bucketKeys = partitionKeys;
this.projectionContext =
ThreadLocal.withInitial(
() ->
CodeGenUtils.newProjection(
schema.logicalPartitionType(), bucketKeys));
} else {
// 非分区表:使用 primary keys
this.bucketKeys = schema.primaryKeys();
this.projectionContext =
ThreadLocal.withInitial(
() -> CodeGenUtils.newProjection(schema.logicalRowType(), bucketKeys));
}

this.bucketCount = new CoreOptions(schema.options()).bucket();
this.workerCount = workerCount;
this.isRowId =
Expand All @@ -65,23 +81,59 @@ public FixedBucketTableShuffleFunction(

@Override
public int getBucket(Page page, int position) {
Page processedPage = page;

// 处理 RowBlock 的情况
if (isRowId) {
RowBlock rowBlock = (RowBlock) page.getBlock(0);
try {
Method method = RowBlock.class.getDeclaredMethod("getRawFieldBlocks");
method.setAccessible(true);
page = new Page(rowBlock.getPositionCount(), (Block[]) method.invoke(rowBlock));
} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
} catch (InvocationTargetException e) {
throw new RuntimeException(e);
} catch (IllegalAccessException e) {
throw new RuntimeException(e);
Block[] rawBlocks = (Block[]) method.invoke(rowBlock);
processedPage = new Page(rowBlock.getPositionCount(), rawBlocks);
} catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
throw new RuntimeException("Failed to extract raw field blocks from RowBlock", e);
}
}

TrinoRow trinoRow = new TrinoRow(page.getSingleValuePage(position), RowKind.INSERT);
BinaryRow pk = projectionContext.get().apply(trinoRow);
// 🔧 修改验证逻辑:验证 bucketKeys 数量
int expectedBlockCount = bucketKeys.size();
int actualBlockCount = processedPage.getChannelCount();

if (actualBlockCount != expectedBlockCount) {
throw new IllegalStateException(
String.format(
"Page block count mismatch: expected %d (bucket keys), but got %d. "
+ "Bucket keys: %s, Partition keys: %s, Primary keys: %s, Schema fields: %s",
expectedBlockCount,
actualBlockCount,
bucketKeys,
schema.partitionKeys(),
schema.primaryKeys(),
schema.fieldNames()));
}

// 使用 processedPage 创建 TrinoRow
TrinoRow trinoRow =
new TrinoRow(processedPage.getSingleValuePage(position), RowKind.INSERT);

// 🔧 修改错误信息:显示 bucketKeys 相关信息
BinaryRow pk;
try {
pk = projectionContext.get().apply(trinoRow);
} catch (IndexOutOfBoundsException e) {
throw new RuntimeException(
String.format(
"Failed to extract bucket keys from row. "
+ "Row field count: %d, Bucket keys: %s, "
+ "Page block count: %d, Position: %d",
trinoRow.getFieldCount(),
bucketKeys,
processedPage.getChannelCount(),
position),
e);
}

int bucket =
KeyAndBucketExtractor.bucket(
KeyAndBucketExtractor.bucketKeyHashCode(pk), bucketCount);
Expand Down