From 610e8a60c99e66da16b1509770dc722f7f9196ce Mon Sep 17 00:00:00 2001 From: Shriram Chandirasekaran Date: Wed, 22 Apr 2026 14:58:28 +0530 Subject: [PATCH] fix: qualify struct field access to avoid binder ambiguity post-join When a dimension SQL references a struct field like `stage.stage_id`, the aliaser previously skipped it because `column_names.length !== 1`. After joining tables that have a column with the same name as the struct root (e.g. `devusers.stage` alongside `issue.stage`), DuckDB's binder raises an ambiguous column reference error before the query can run. Make the aliaser schema-aware: introspect each table's physical columns via `DESCRIBE` in the browser and node wrappers, then rewrite a multi-part column ref as `..` only when the leading identifier is a known column on the current table. Cross-table references (`customers.id`), already-qualified refs, lambda-bound identifiers, and unknown multi-part refs are left untouched. Falls back to the legacy length-1 behavior when schema information is not supplied, preserving backwards compatibility. --- .../ensure-table-schema-alias.spec.ts | 4 +- .../ensure-table-schema-alias.ts | 1 + ...re-sql-expression-column-alias.fixtures.ts | 35 +++ ...ensure-sql-expression-column-alias.spec.ts | 296 ++++++++++++++++++ .../ensure-sql-expression-column-alias.ts | 162 +++++++--- .../ensure-table-schema-alias-sql.spec.ts | 178 ++++++++--- .../utils/ensure-table-schema-alias-sql.ts | 13 + .../ensure-table-schema-alias.spec.ts | 14 +- .../ensure-table-schema-alias.ts | 1 + 9 files changed, 625 insertions(+), 79 deletions(-) diff --git a/meerkat-browser/src/ensure-table-schema-alias/ensure-table-schema-alias.spec.ts b/meerkat-browser/src/ensure-table-schema-alias/ensure-table-schema-alias.spec.ts index d00ecaa5..e4368063 100644 --- a/meerkat-browser/src/ensure-table-schema-alias/ensure-table-schema-alias.spec.ts +++ b/meerkat-browser/src/ensure-table-schema-alias/ensure-table-schema-alias.spec.ts @@ -93,10 +93,10 @@ describe('ensureTableSchemasAlias', () => { expect(ensureColumnAliasBatch).toHaveBeenCalledWith({ items: [ - { + expect.objectContaining({ sql: 'SUM(order_amount)', tableName: 'orders', - }, + }), ], executeQuery: expect.any(Function), }); diff --git a/meerkat-browser/src/ensure-table-schema-alias/ensure-table-schema-alias.ts b/meerkat-browser/src/ensure-table-schema-alias/ensure-table-schema-alias.ts index ebc8e0e7..d404716f 100644 --- a/meerkat-browser/src/ensure-table-schema-alias/ensure-table-schema-alias.ts +++ b/meerkat-browser/src/ensure-table-schema-alias/ensure-table-schema-alias.ts @@ -33,6 +33,7 @@ export const ensureTableSchemasAlias = async ({ items: items.map((item) => ({ sql: item.sql, tableName: item.context.tableName, + knownTableNames: item.context.knownTableNames, })), executeQuery, }); diff --git a/meerkat-core/src/utils/__fixtures__/ensure-sql-expression-column-alias.fixtures.ts b/meerkat-core/src/utils/__fixtures__/ensure-sql-expression-column-alias.fixtures.ts index db452cc1..2f88685b 100644 --- a/meerkat-core/src/utils/__fixtures__/ensure-sql-expression-column-alias.fixtures.ts +++ b/meerkat-core/src/utils/__fixtures__/ensure-sql-expression-column-alias.fixtures.ts @@ -7,6 +7,7 @@ export interface EnsureColumnAliasScenario { expectedSql: string; shouldChange: boolean; notes?: string; + knownTableNames?: string[]; } export interface DeferredEnsureColumnAliasScenario @@ -159,6 +160,40 @@ export const ENSURE_COLUMN_ALIAS_SCENARIOS: EnsureColumnAliasScenario[] = [ expectedSql: 'customers.id', shouldChange: false, }, + { + description: 'struct field access on local column is qualified', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + inputSql: 'stage.stage_id', + expectedSql: 'issue.stage.stage_id', + shouldChange: true, + }, + { + description: + 'struct field access inside aggregate on local column is qualified', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + inputSql: 'COUNT(stage.stage_id)', + expectedSql: 'COUNT(issue.stage.stage_id)', + shouldChange: true, + }, + { + description: + 'multi-part ref where leading identifier is another table stays untouched', + tableName: 'orders', + knownTableNames: ['orders', 'customers'], + inputSql: 'customers.id', + expectedSql: 'customers.id', + shouldChange: false, + }, + { + description: 'already-qualified struct access is not double-qualified', + tableName: 'issue', + knownTableNames: ['issue'], + inputSql: 'issue.stage.stage_id', + expectedSql: 'issue.stage.stage_id', + shouldChange: false, + }, ]; export const DEFERRED_ENSURE_COLUMN_ALIAS_SCENARIOS: DeferredEnsureColumnAliasScenario[] = diff --git a/meerkat-core/src/utils/ensure-sql-expression-column-alias.spec.ts b/meerkat-core/src/utils/ensure-sql-expression-column-alias.spec.ts index de43e4ea..f6d5aa42 100644 --- a/meerkat-core/src/utils/ensure-sql-expression-column-alias.spec.ts +++ b/meerkat-core/src/utils/ensure-sql-expression-column-alias.spec.ts @@ -328,6 +328,67 @@ const expressionAstBySql: Record = { }), "'customer_id'": createStringConstant('customer_id'), '"Order ID"': createColumnRef('Order ID'), + 'stage.stage_id': createColumnRef(['stage', 'stage_id']), + 'issue.stage.stage_id': createColumnRef(['issue', 'stage', 'stage_id']), + 'COUNT(stage.stage_id)': createFunction({ + functionName: 'COUNT', + children: [createColumnRef(['stage', 'stage_id'])], + }), + 'foo.bar': createColumnRef(['foo', 'bar']), + missing_column: createColumnRef('missing_column'), + 'stage.stage_id + amount': createFunction({ + functionName: '+', + children: [ + createColumnRef(['stage', 'stage_id']), + createColumnRef('amount'), + ], + isOperator: true, + }), + 'stage.stage_id = devusers.id': createComparison({ + type: ExpressionType.COMPARE_EQUAL, + left: createColumnRef(['stage', 'stage_id']), + right: createColumnRef(['devusers', 'id']), + }), + 'CASE WHEN stage.stage_id = 1 THEN owner.id END': createCase({ + whenExpr: createComparison({ + type: ExpressionType.COMPARE_EQUAL, + left: createColumnRef(['stage', 'stage_id']), + right: { + class: ExpressionClass.CONSTANT, + type: ExpressionType.VALUE_CONSTANT, + alias: '', + value: { + type: { id: 'INTEGER', type_info: null }, + is_null: false, + value: 1, + }, + } as ParsedExpression, + }), + thenExpr: createColumnRef(['owner', 'id']), + elseExpr: { + class: ExpressionClass.CONSTANT, + type: ExpressionType.VALUE_CONSTANT, + alias: '', + value: { + type: { id: 'NULL', type_info: null }, + is_null: true, + }, + } as ParsedExpression, + }), + 'SUM(stage.stage_id)': createFunction({ + functionName: 'SUM', + children: [createColumnRef(['stage', 'stage_id'])], + }), + 'list_transform(stage.items, x -> x)': createFunction({ + functionName: 'list_transform', + children: [ + createColumnRef(['stage', 'items']), + createLambda({ + lhs: createColumnRef('x'), + expr: createColumnRef('x'), + }), + ], + }), "list_transform(priority_tags, x -> CASE WHEN x = 1 THEN 'P1' ELSE 'Unknown' END)": createFunction({ functionName: 'list_transform', @@ -567,6 +628,241 @@ describe('column refs with quotes and dots are not re-aliased', () => { }); }); +describe('schema-aware struct field aliasing', () => { + const runWithContext = async ({ + sql, + tableName, + knownTableNames, + }: { + sql: string; + tableName: string; + knownTableNames?: string[]; + }) => { + const [result] = await ensureColumnAliasBatch({ + items: [ + { + sql, + tableName, + knownTableNames: knownTableNames + ? new Set(knownTableNames) + : undefined, + }, + ], + executeQuery: dummyGetQueryOutput, + }); + if (!result) { + throw new Error('Missing alias result'); + } + return result.sql; + }; + + it('qualifies struct field access on a local column', async () => { + const result = await runWithContext({ + sql: 'stage.stage_id', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + }); + expect(result).toBe('issue.stage.stage_id'); + }); + + it('qualifies struct access inside an aggregate', async () => { + const result = await runWithContext({ + sql: 'COUNT(stage.stage_id)', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + }); + expect(result).toBe('COUNT(issue.stage.stage_id)'); + }); + + it('leaves cross-table references to a known table alias untouched', async () => { + const result = await runWithContext({ + sql: 'customers.id', + tableName: 'orders', + knownTableNames: ['orders', 'customers'], + }); + expect(result).toBe('customers.id'); + }); + + it('does not double-qualify already-qualified struct access', async () => { + const result = await runWithContext({ + sql: 'issue.stage.stage_id', + tableName: 'issue', + knownTableNames: ['issue'], + }); + expect(result).toBe('issue.stage.stage_id'); + }); + + it('falls back to legacy behavior when knownTableNames is omitted', async () => { + const result = await runWithContext({ + sql: 'customer_id', + tableName: 'orders', + }); + expect(result).toBe('orders.customer_id'); + }); +}); + +describe('schema-aware struct aliasing — extended coverage', () => { + const run = async ({ + sql, + tableName, + knownTableNames, + }: { + sql: string; + tableName: string; + knownTableNames?: string[]; + }) => { + const [result] = await ensureColumnAliasBatch({ + items: [ + { + sql, + tableName, + knownTableNames: knownTableNames + ? new Set(knownTableNames) + : undefined, + }, + ], + executeQuery: dummyGetQueryOutput, + }); + return result?.sql; + }; + + it('qualifies struct access mixed with bare columns in arithmetic', async () => { + const result = await run({ + sql: 'stage.stage_id + amount', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + }); + expect(result).toBe('issue.stage.stage_id + issue.amount'); + }); + + it('qualifies local struct but preserves cross-table ref in comparison', async () => { + const result = await run({ + sql: 'stage.stage_id = devusers.id', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + }); + expect(result).toBe('issue.stage.stage_id = devusers.id'); + }); + + it('qualifies struct access inside CASE branches', async () => { + const result = await run({ + sql: 'CASE WHEN stage.stage_id = 1 THEN owner.id END', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + }); + expect(result).toBe( + 'CASE WHEN issue.stage.stage_id = 1 THEN issue.owner.id END' + ); + }); + + it('qualifies struct access inside SUM aggregate', async () => { + const result = await run({ + sql: 'SUM(stage.stage_id)', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + }); + expect(result).toBe('SUM(issue.stage.stage_id)'); + }); + + it('qualifies struct access inside lambda function argument but not lambda-bound identifier', async () => { + const result = await run({ + sql: 'list_transform(stage.items, x -> x)', + tableName: 'issue', + knownTableNames: ['issue', 'devusers'], + }); + expect(result).toBe('LIST_TRANSFORM(issue.stage.items, x -> x)'); + }); + + it('treats ambiguous multi-part ref as cross-table when knownTableNames contains root', async () => { + const result = await run({ + sql: 'customers.id', + tableName: 'orders', + knownTableNames: ['orders', 'customers'], + }); + expect(result).toBe('customers.id'); + }); + + it('qualifies multi-part ref with unknown root as struct when schema batch is present', async () => { + const result = await run({ + sql: 'foo.bar', + tableName: 'issue', + knownTableNames: ['issue'], + }); + expect(result).toBe('issue.foo.bar'); + }); + + it('stays conservative on multi-part ref when knownTableNames is omitted', async () => { + const result = await run({ + sql: 'foo.bar', + tableName: 'issue', + }); + expect(result).toBe('foo.bar'); + }); + + it('does not double-qualify when root matches table even without knownTableNames', async () => { + const result = await run({ + sql: 'orders.order_amount', + tableName: 'orders', + }); + expect(result).toBe('orders.order_amount'); + }); +}); + +describe('ensureColumnAliasBatch — batched mixed tables', () => { + it('applies per-item knownTableNames correctly within a single batch', async () => { + const results = await ensureColumnAliasBatch({ + items: [ + { + sql: 'stage.stage_id', + tableName: 'issue', + knownTableNames: new Set(['issue', 'devusers']), + }, + { + sql: 'customers.id', + tableName: 'orders', + knownTableNames: new Set(['orders', 'customers']), + }, + { + sql: 'customer_id', + tableName: 'orders', + }, + ], + executeQuery: dummyGetQueryOutput, + }); + + expect(results.map((r) => r.sql)).toEqual([ + 'issue.stage.stage_id', + 'customers.id', + 'orders.customer_id', + ]); + expect(results.map((r) => r.didChange)).toEqual([true, false, true]); + }); + + it('preserves context per-item through batched aliasing', async () => { + const results = await ensureColumnAliasBatch({ + items: [ + { + sql: 'customer_id', + tableName: 'orders', + context: { memberName: 'a' }, + }, + { + sql: 'stage.stage_id', + tableName: 'issue', + knownTableNames: new Set(['issue']), + context: { memberName: 'b' }, + }, + ], + executeQuery: dummyGetQueryOutput, + }); + + expect(results.map((r) => r.context)).toEqual([ + { memberName: 'a' }, + { memberName: 'b' }, + ]); + }); +}); + describe.skip('single-item batch aliasing pending scenarios', () => { for (const scenario of ENSURE_COLUMN_ALIAS_SCENARIOS) { if (expressionAstBySql[scenario.inputSql]) { diff --git a/meerkat-core/src/utils/ensure-sql-expression-column-alias.ts b/meerkat-core/src/utils/ensure-sql-expression-column-alias.ts index 1b005079..c5ead212 100644 --- a/meerkat-core/src/utils/ensure-sql-expression-column-alias.ts +++ b/meerkat-core/src/utils/ensure-sql-expression-column-alias.ts @@ -23,6 +23,12 @@ import { getChildExpressions } from './get-child-expressions'; export interface EnsureColumnAliasBatchItem { sql: string; tableName: string; + /** + * Names of all tables in the current schema batch. Used to distinguish a + * cross-table reference (e.g. `customers.id`) from struct field access + * (e.g. `stage.stage_id`) on multi-part column references. + */ + knownTableNames?: Set; context?: TContext; } @@ -38,29 +44,56 @@ export interface EnsureColumnAliasBatchParams { const ALIASABLE_IDENTIFIER_REGEX = /^[A-Za-z_][A-Za-z0-9_]*$/; +const isAliasableIdentifier = (identifier: string | undefined): boolean => { + if (!identifier) { + return false; + } + if (/\s/.test(identifier)) { + return false; + } + return ALIASABLE_IDENTIFIER_REGEX.test(identifier); +}; + const shouldEnsureColumnRefAlias = ( columnNames: string[], - scopedIdentifiers: Set + scopedIdentifiers: Set, + tableName: string, + knownTableNames?: Set ): boolean => { - if (columnNames.length !== 1) { + if (columnNames.length === 0) { return false; } - const [columnName] = columnNames; + const [root] = columnNames; - if (!columnName) { + if (!isAliasableIdentifier(root)) { return false; } - if (/\s/.test(columnName)) { + if (scopedIdentifiers.has(root)) { return false; } - if (scopedIdentifiers.has(columnName)) { - return false; + if (columnNames.length === 1) { + return true; } - return ALIASABLE_IDENTIFIER_REGEX.test(columnName); + if (columnNames.length === 2) { + if (root === tableName) { + return false; + } + // Without a schema batch, a two-part ref is ambiguous between + // `table.column` and `struct.field`; stay conservative. + if (!knownTableNames) { + return false; + } + if (knownTableNames.has(root)) { + return false; + } + return true; + } + + return false; }; const getLambdaBoundIdentifiers = ( @@ -84,14 +117,16 @@ const getLambdaBoundIdentifiers = ( const ensureOrderByNodesAlias = ( orders: OrderByNode[], scopedIdentifiers: Set, - tableName?: string + tableName?: string, + knownTableNames?: Set ): boolean => { return orders.reduce( (changed, order) => ensureParsedExpressionAlias( order.expression, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) || changed, false ); @@ -124,7 +159,8 @@ const isLimitLikeModifier = ( const ensureQueryNodeAlias = ( node: QueryNode, tableName?: string, - scopedIdentifiers: Set = new Set() + scopedIdentifiers: Set = new Set(), + knownTableNames?: Set ): boolean => { if (isSelectNode(node)) { let changed = false; @@ -134,7 +170,8 @@ const ensureQueryNodeAlias = ( ensureParsedExpressionAlias( expression, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) || changed; }); changed = @@ -142,7 +179,8 @@ const ensureQueryNodeAlias = ( ? ensureParsedExpressionAlias( node.where_clause, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) : false) || changed; node.group_expressions.forEach((expression) => { @@ -150,7 +188,8 @@ const ensureQueryNodeAlias = ( ensureParsedExpressionAlias( expression, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) || changed; }); changed = @@ -158,7 +197,8 @@ const ensureQueryNodeAlias = ( ? ensureParsedExpressionAlias( node.having, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) : false) || changed; changed = @@ -166,7 +206,8 @@ const ensureQueryNodeAlias = ( ? ensureParsedExpressionAlias( node.qualify, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) : false) || changed; @@ -176,7 +217,8 @@ const ensureQueryNodeAlias = ( ensureOrderByNodesAlias( modifier.orders, scopedIdentifiers, - tableName + tableName, + knownTableNames ) || changed; } @@ -187,7 +229,8 @@ const ensureQueryNodeAlias = ( ensureParsedExpressionAlias( target, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) || changed) ); } @@ -198,7 +241,8 @@ const ensureQueryNodeAlias = ( ? ensureParsedExpressionAlias( modifier.limit, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) : false) || changed; changed = @@ -206,7 +250,8 @@ const ensureQueryNodeAlias = ( ? ensureParsedExpressionAlias( modifier.offset, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) : false) || changed; } @@ -218,27 +263,57 @@ const ensureQueryNodeAlias = ( if (node.type === QueryNodeType.SET_OPERATION_NODE) { let changed = false; changed = - ensureQueryNodeAlias(node.left, tableName, scopedIdentifiers) || changed; + ensureQueryNodeAlias( + node.left, + tableName, + scopedIdentifiers, + knownTableNames + ) || changed; changed = - ensureQueryNodeAlias(node.right, tableName, scopedIdentifiers) || changed; + ensureQueryNodeAlias( + node.right, + tableName, + scopedIdentifiers, + knownTableNames + ) || changed; return changed; } if (node.type === QueryNodeType.RECURSIVE_CTE_NODE) { let changed = false; changed = - ensureQueryNodeAlias(node.left, tableName, scopedIdentifiers) || changed; + ensureQueryNodeAlias( + node.left, + tableName, + scopedIdentifiers, + knownTableNames + ) || changed; changed = - ensureQueryNodeAlias(node.right, tableName, scopedIdentifiers) || changed; + ensureQueryNodeAlias( + node.right, + tableName, + scopedIdentifiers, + knownTableNames + ) || changed; return changed; } if (node.type === QueryNodeType.CTE_NODE) { let changed = false; changed = - ensureQueryNodeAlias(node.query, tableName, scopedIdentifiers) || changed; + ensureQueryNodeAlias( + node.query, + tableName, + scopedIdentifiers, + knownTableNames + ) || changed; changed = - ensureQueryNodeAlias(node.child, tableName, scopedIdentifiers) || changed; + ensureQueryNodeAlias( + node.child, + tableName, + scopedIdentifiers, + knownTableNames + ) || changed; return changed; } @@ -248,15 +323,23 @@ const ensureQueryNodeAlias = ( const ensureParsedExpressionAlias = ( node: ParsedExpression, tableName?: string, - scopedIdentifiers: Set = new Set() + scopedIdentifiers: Set = new Set(), + knownTableNames?: Set ): boolean => { if (!node || !tableName) { return false; } if (isColumnRefExpression(node)) { - if (shouldEnsureColumnRefAlias(node.column_names, scopedIdentifiers)) { - node.column_names = [tableName, node.column_names[0]]; + if ( + shouldEnsureColumnRefAlias( + node.column_names, + scopedIdentifiers, + tableName, + knownTableNames + ) + ) { + node.column_names = [tableName, ...node.column_names]; return true; } return false; @@ -272,7 +355,8 @@ const ensureParsedExpressionAlias = ( ? ensureParsedExpressionAlias( node.expr, tableName, - lambdaScopedIdentifiers + lambdaScopedIdentifiers, + knownTableNames ) : false; } @@ -281,14 +365,16 @@ const ensureParsedExpressionAlias = ( let changed = ensureQueryNodeAlias( node.subquery.node, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ); if (node.child) { changed = ensureParsedExpressionAlias( node.child, tableName, - scopedIdentifiers + scopedIdentifiers, + knownTableNames ) || changed; } return changed; @@ -296,8 +382,12 @@ const ensureParsedExpressionAlias = ( return getChildExpressions(node).reduce( (changed, child) => - ensureParsedExpressionAlias(child, tableName, scopedIdentifiers) || - changed, + ensureParsedExpressionAlias( + child, + tableName, + scopedIdentifiers, + knownTableNames + ) || changed, false ); }; @@ -336,7 +426,9 @@ export const ensureColumnAliasBatch = async ({ parsedExpressions.forEach((parsedExpression, index) => { const didChange = ensureParsedExpressionAlias( parsedExpression, - items[index].tableName + items[index].tableName, + new Set(), + items[index].knownTableNames ); if (!didChange) { return; diff --git a/meerkat-core/src/utils/ensure-table-schema-alias-sql.spec.ts b/meerkat-core/src/utils/ensure-table-schema-alias-sql.spec.ts index e0dd2c51..a08c083a 100644 --- a/meerkat-core/src/utils/ensure-table-schema-alias-sql.spec.ts +++ b/meerkat-core/src/utils/ensure-table-schema-alias-sql.spec.ts @@ -88,14 +88,14 @@ describe('ensureTableSchemaAliasSql', () => { expect(ensureExpressionAlias).toHaveBeenCalledTimes(2); expect(ensureExpressionAlias).toHaveBeenCalledWith({ items: expect.arrayContaining([ - { + expect.objectContaining({ sql: 'SUM(order_amount)', - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'gross_amount', memberType: 'measure', - }, - }, + }), + }), ]), }); }); @@ -220,38 +220,38 @@ describe('ensureTableSchemaAliasSql', () => { expect(ensureExpressionAlias).toHaveBeenCalledTimes(2); expect(ensureExpressionAlias).toHaveBeenNthCalledWith(1, { items: [ - { + expect.objectContaining({ sql: 'SUM(order_amount)', - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'gross_amount', memberType: 'measure', - }, - }, - { + }), + }), + expect.objectContaining({ sql: 'SUM(order_amount - discount_amount)', - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'net_amount', memberType: 'measure', - }, - }, - { + }), + }), + expect.objectContaining({ sql: 'customer_id', - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'customer_id', memberType: 'dimension', - }, - }, - { + }), + }), + expect.objectContaining({ sql: "DATE_TRUNC('month', created_at)", - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'order_month', memberType: 'dimension', - }, - }, + }), + }), ], }); expect(result[0].measures.map((measure) => measure.sql)).toEqual([ @@ -284,39 +284,145 @@ describe('ensureTableSchemaAliasSql', () => { expect(ensureExpressionAlias).toHaveBeenCalled(); expect(ensureExpressionAlias).toHaveBeenCalledWith({ items: [ - { + expect.objectContaining({ sql: 'SUM(order_amount)', - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'gross_amount', memberType: 'measure', - }, - }, - { + }), + }), + expect.objectContaining({ sql: 'SUM(order_amount - discount_amount)', - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'net_amount', memberType: 'measure', - }, - }, - { + }), + }), + expect.objectContaining({ sql: 'customer_id', - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'customer_id', memberType: 'dimension', - }, - }, - { + }), + }), + expect.objectContaining({ sql: "DATE_TRUNC('month', created_at)", - context: { + context: expect.objectContaining({ tableName: 'orders', memberName: 'order_month', memberType: 'dimension', - }, - }, + }), + }), ], }); }); + + it('threads knownTableNames through the expression aliaser', async () => { + const tableSchemas = createEnsureTableSchemaAliasSqlFixture(); + const ensureExpressionAlias = jest.fn(async ({ items }) => + items.map(({ sql }: { sql: string }) => sql) + ); + + await ensureTableSchemaAliasSql({ + tableSchemas, + ensureExpressionAlias, + }); + + const [firstCallArgs] = ensureExpressionAlias.mock.calls; + const [{ items: ordersItems }] = firstCallArgs; + expect(ordersItems[0].context.knownTableNames).toEqual( + new Set(['orders', 'customers']) + ); + }); + + it('passes the same knownTableNames to every table in the batch', async () => { + const tableSchemas = createEnsureTableSchemaAliasSqlFixture(); + const ensureExpressionAlias = jest.fn(async ({ items }) => + items.map(({ sql }: { sql: string }) => sql) + ); + + await ensureTableSchemaAliasSql({ + tableSchemas, + ensureExpressionAlias, + }); + + const expected = new Set(['orders', 'customers']); + for (const [call] of ensureExpressionAlias.mock.calls) { + for (const item of call.items) { + expect(item.context.knownTableNames).toEqual(expected); + } + } + }); + + it('uses a knownTableNames set derived from all provided schemas', async () => { + const schemas = [ + ...createEnsureTableSchemaAliasSqlFixture(), + { + name: 'audit', + sql: 'SELECT * FROM audit', + measures: [ + { name: 'total', sql: 'COUNT(id)', type: 'number' as const }, + ], + dimensions: [], + }, + ]; + const ensureExpressionAlias = jest.fn(async ({ items }) => + items.map(({ sql }: { sql: string }) => sql) + ); + + await ensureTableSchemaAliasSql({ + tableSchemas: schemas, + ensureExpressionAlias, + }); + + const expected = new Set(['orders', 'customers', 'audit']); + const lastCall = + ensureExpressionAlias.mock.calls[ + ensureExpressionAlias.mock.calls.length - 1 + ][0]; + expect(lastCall.items[0].context.knownTableNames).toEqual(expected); + }); + + it('does not mutate measures/dimensions when ensureExpressionAlias is a no-op', async () => { + const tableSchemas = createEnsureTableSchemaAliasSqlFixture(); + const before = JSON.parse(JSON.stringify(tableSchemas)); + + const result = await ensureTableSchemaAliasSql({ + tableSchemas, + ensureExpressionAlias: async ({ items }) => + items.map(({ sql }: { sql: string }) => sql), + }); + + expect(tableSchemas).toEqual(before); + expect(result[0].measures[0].sql).toBe( + tableSchemas[0].measures[0].sql + ); + expect(result[0]).not.toBe(tableSchemas[0]); + expect(result[0].measures).not.toBe(tableSchemas[0].measures); + }); + + it('handles schema with no measures or dimensions without calling aliaser', async () => { + const schemas = [ + { + name: 'empty', + sql: 'SELECT 1', + measures: [], + dimensions: [], + }, + ]; + const ensureExpressionAlias = jest.fn(async ({ items }) => + items.map(({ sql }: { sql: string }) => sql) + ); + + const result = await ensureTableSchemaAliasSql({ + tableSchemas: schemas, + ensureExpressionAlias, + }); + + expect(ensureExpressionAlias).not.toHaveBeenCalled(); + expect(result).toEqual(schemas); + }); }); diff --git a/meerkat-core/src/utils/ensure-table-schema-alias-sql.ts b/meerkat-core/src/utils/ensure-table-schema-alias-sql.ts index a7bf8067..b13e28bc 100644 --- a/meerkat-core/src/utils/ensure-table-schema-alias-sql.ts +++ b/meerkat-core/src/utils/ensure-table-schema-alias-sql.ts @@ -7,6 +7,12 @@ export interface EnsureAliasExpressionContext { tableName: string; memberName: string; memberType: MemberType; + /** + * Names of all tables in the current schema batch. Used to treat + * `otherTable.col` as an intentional cross-table reference rather than a + * struct access. + */ + knownTableNames?: Set; } export interface EnsureTableSchemaAliasSqlParams { @@ -29,11 +35,13 @@ const collectAliasableDescriptors = ({ members, memberType, tableName, + knownTableNames, descriptors, }: { members: AliasableMember[]; memberType: MemberType; tableName: string; + knownTableNames?: Set; descriptors: AliasableMemberDescriptor[]; }): void => { members.forEach((member) => { @@ -43,6 +51,7 @@ const collectAliasableDescriptors = ({ tableName, memberName: member.name, memberType, + knownTableNames, }, apply: (aliasedSql: string) => { member.sql = aliasedSql; @@ -93,6 +102,8 @@ export const ensureTableSchemaAliasSql = async ({ tableSchemas, ensureExpressionAlias, }: EnsureTableSchemaAliasSqlParams): Promise => { + const knownTableNames = new Set(tableSchemas.map((s) => s.name)); + return Promise.all( tableSchemas.map(async (tableSchema) => { const aliasedTableSchema: TableSchema = { @@ -108,12 +119,14 @@ export const ensureTableSchemaAliasSql = async ({ members: aliasedTableSchema.measures, memberType: 'measure', tableName: tableSchema.name, + knownTableNames, descriptors, }); collectAliasableDescriptors({ members: aliasedTableSchema.dimensions, memberType: 'dimension', tableName: tableSchema.name, + knownTableNames, descriptors, }); diff --git a/meerkat-node/src/ensure-table-schema-alias/ensure-table-schema-alias.spec.ts b/meerkat-node/src/ensure-table-schema-alias/ensure-table-schema-alias.spec.ts index e362006d..6dfc06dc 100644 --- a/meerkat-node/src/ensure-table-schema-alias/ensure-table-schema-alias.spec.ts +++ b/meerkat-node/src/ensure-table-schema-alias/ensure-table-schema-alias.spec.ts @@ -89,16 +89,18 @@ describe('ensureTableSchemasAlias', () => { it('ensures alias for schemas using duckdbExec plumbing', async () => { const result = await ensureTableSchemasAlias(tableSchemas); - expect(ensureTableSchemaAliasSql).toHaveBeenCalledWith({ - tableSchemas, - ensureExpressionAlias: expect.any(Function), - }); + expect(ensureTableSchemaAliasSql).toHaveBeenCalledWith( + expect.objectContaining({ + tableSchemas, + ensureExpressionAlias: expect.any(Function), + }) + ); expect(ensureColumnAliasBatch).toHaveBeenCalledWith({ items: [ - { + expect.objectContaining({ sql: 'SUM(order_amount)', tableName: 'orders', - }, + }), ], executeQuery: expect.any(Function), }); diff --git a/meerkat-node/src/ensure-table-schema-alias/ensure-table-schema-alias.ts b/meerkat-node/src/ensure-table-schema-alias/ensure-table-schema-alias.ts index 7d13221e..6a44c7cb 100644 --- a/meerkat-node/src/ensure-table-schema-alias/ensure-table-schema-alias.ts +++ b/meerkat-node/src/ensure-table-schema-alias/ensure-table-schema-alias.ts @@ -15,6 +15,7 @@ export const ensureTableSchemasAlias = async ( items: items.map((item) => ({ sql: item.sql, tableName: item.context.tableName, + knownTableNames: item.context.knownTableNames, })), executeQuery: (query) => duckdbExec[]>(query),