Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/main/java/io/github/melin/sqlflow/analyzer/Analysis.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import io.github.melin.sqlflow.tree.window.WindowFrame;
import io.github.melin.sqlflow.type.Type;
import com.google.common.collect.*;
import io.github.melin.sqlflow.tree.*;
import io.github.melin.sqlflow.tree.expression.*;

import javax.annotation.Nullable;
import javax.annotation.concurrent.Immutable;
Expand Down Expand Up @@ -431,10 +429,17 @@ public int hashCode() {
public static class SourceColumn {
private final QualifiedObjectName tableName;
private final String columnName;
private List<FunctionContent> functions;

public SourceColumn(QualifiedObjectName tableName, String columnName) {
public SourceColumn(QualifiedObjectName tableName, String columnName, FunctionContent function) {
this.tableName = requireNonNull(tableName, "tableName is null");
this.columnName = requireNonNull(columnName, "columnName is null");
if (function != null) {
if (functions == null) {
functions = new ArrayList<>();
}
this.functions.add(function);
}
}

public QualifiedObjectName getTableName() {
Expand All @@ -445,6 +450,19 @@ public String getColumnName() {
return columnName;
}

public List<FunctionContent> getFunctions(){
return this.functions;
}

public void addFunction(FunctionContent function) {
if (function != null) {
if (functions == null) {
functions = new ArrayList<>();
}
this.functions.add(function);
}
}

@Override
public int hashCode() {
return Objects.hash(tableName, columnName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ public static ExpressionAnalysis analyzeExpression(
analyzer.getSourceFields().forEach(field -> {
if (field.getOriginTable().isPresent() && field.getOriginColumnName().isPresent()) {
Analysis.SourceColumn sourceColumn = new Analysis.SourceColumn(field.getOriginTable().get(),
field.getOriginColumnName().get());
field.getOriginColumnName().get(), null);
analysis.addOriginField(sourceColumn, field.getLocation());
}
});
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package io.github.melin.sqlflow.analyzer;

import io.github.melin.sqlflow.tree.NodeLocation;

import java.util.ArrayList;
import java.util.List;

import static java.util.Objects.requireNonNull;

public class FunctionContent {
private NodeLocation location;
private String name;
private List<String> argumentTextList;

public FunctionContent(String name, List<String> argumentTextList) {
this.argumentTextList = argumentTextList;
this.name = name;
}

public FunctionContent(String name) {
this.name = name;
}

public static FunctionContent newUnqualified(String name) {
requireNonNull(name, "name is null");
return new FunctionContent(name);
}

public static FunctionContent newUnqualified(String name, List<String> arguments) {
requireNonNull(name, "name is null");
return new FunctionContent(name, arguments);
}

public void addArgument(String argumentText) {
if (argumentTextList == null) {
argumentTextList = new ArrayList<>();
}
argumentTextList.add(argumentText);
}

public List<String> getArgumentTextList() {
return argumentTextList;
}

public NodeLocation getLocation() {
return location;
}

public void setLocation(NodeLocation location) {
this.location = location;
}

public String getName() {
return name;
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static com.google.common.collect.ImmutableList.toImmutableList;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.Iterables.getLast;
import static java.lang.Math.exp;
import static java.lang.Math.toIntExact;
import static java.util.Collections.emptyList;
import static java.util.Locale.ENGLISH;
Expand Down Expand Up @@ -268,7 +269,7 @@ private List<Field> analyzeTableOutputFields(Table table, QualifiedObjectName ta
for (String column : schemaTable.getColumns()) {
Field field = Field.newQualified(table.getName(), Optional.of(column), Optional.of(tableName), Optional.of(column), false);
fields.add(field);
analysis.addSourceColumns(field, ImmutableSet.of(new Analysis.SourceColumn(tableName, column)));
analysis.addSourceColumns(field, ImmutableSet.of(new Analysis.SourceColumn(tableName, column, null)));
}
return fields.build();
}
Expand Down Expand Up @@ -345,7 +346,7 @@ private Scope createScopeForView(Table table, QualifiedObjectName name, Optional
Optional.of(column.getName()), Optional.of(name),
Optional.of(column.getName()), false)).collect(toImmutableList());

outputFields.forEach(field -> analysis.addSourceColumns(field, ImmutableSet.of(new Analysis.SourceColumn(name, field.getName().get()))));
outputFields.forEach(field -> analysis.addSourceColumns(field, ImmutableSet.of(new Analysis.SourceColumn(name, field.getName().get(), null))));
return createAndAssignScope(table, scope, outputFields);
}

Expand Down Expand Up @@ -1389,11 +1390,77 @@ private Scope computeAndAssignOutputScope(QuerySpecification node, Optional<Scop
Optional<QualifiedObjectName> originTable = Optional.empty();
Optional<String> originColumn = Optional.empty();
QualifiedName name = null;

FunctionContent function = null; //主要记录Function名称和location以解析Function内容
if (expression instanceof Identifier) {
name = QualifiedName.of(((Identifier) expression).getValue());
} else if (expression instanceof DereferenceExpression) {
name = DereferenceExpression.getQualifiedName((DereferenceExpression) expression);
} else if(expression instanceof FunctionCall) {
final FunctionCall functionCall = (FunctionCall) expression;
final QualifiedName qualifiedName = functionCall.getName();
if (qualifiedName.getOriginalParts() != null && qualifiedName.getOriginalParts().size() > 0) {
function = new FunctionContent(qualifiedName.getOriginalParts().get(0).getValue());
function.setLocation(expression.getLocation().get());
}
final NodeLocation functionLocation = function.getLocation();
functionLocation.setStopIndex(-1); //默认-1,有alias、windows-partitionBy orderBy时候才更新获取到完整的Function位,否则就按照Function括号位置获取
if(functionCall.getWindow().isPresent()){
final Window window = functionCall.getWindow().get();
if (window instanceof WindowSpecification) {
final WindowSpecification windowSpecification = (WindowSpecification) window;
final List<Expression> partitionBy = windowSpecification.getPartitionBy();
for(Expression express : partitionBy) {
if (express.getLocation().isPresent()) {
if (express.getLocation().get().getStopIndex() > functionLocation.getStopIndex()) {
functionLocation.setStopIndex(express.getLocation().get().getStopIndex() + 1);
}
}
}

if (windowSpecification.getOrderBy().isPresent()) {
final OrderBy orderBy = windowSpecification.getOrderBy().get();
if(orderBy.getLocation().isPresent()) {
if (orderBy.getLocation().get().getStopIndex() > functionLocation.getStopIndex()) {
functionLocation.setStopIndex(orderBy.getLocation().get().getStopIndex() + 1);
}
}
}
}
}
if (column.getAlias().isPresent()) {
final Optional<NodeLocation> aliasLocationOption = column.getAlias().get().getLocation();
if (aliasLocationOption.isPresent()) {
final NodeLocation aliasLocation = aliasLocationOption.get();
if (aliasLocation.getStopIndex() > functionLocation.getStopIndex()) {
functionLocation.setStopIndex(aliasLocation.getStopIndex());
}
}
}
} else if(expression instanceof SearchedCaseExpression) {
final SearchedCaseExpression caseExpression = (SearchedCaseExpression) expression;
function = new FunctionContent("case");
function.setLocation(caseExpression.getLocation().get());
//获取case子元素里面最大的stopIndex
final NodeLocation caseLocation = function.getLocation();
final List<? extends Node> children = caseExpression.getChildren();
for(Node child : children) {
final Optional<NodeLocation> childLocation = child.getLocation();
if (childLocation.isPresent()) {
final NodeLocation nodeLocation = childLocation.get();
if (nodeLocation.getStopIndex() > caseLocation.getStopIndex()) {
caseLocation.setStopIndex(nodeLocation.getStopIndex());
}
}
}
if (column.getAlias().isPresent()) {
final Optional<NodeLocation> aliasLocationOption = column.getAlias().get().getLocation();
if (aliasLocationOption.isPresent()) {
final NodeLocation aliasLocation = aliasLocationOption.get();
if (aliasLocation.getStopIndex() > caseLocation.getStopIndex()) {
caseLocation.setStopIndex(aliasLocation.getStopIndex());
}
}
}
}

if (name != null) {
Expand All @@ -1418,8 +1485,15 @@ private Scope computeAndAssignOutputScope(QuerySpecification node, Optional<Scop
// fix join 子查询是union 语句
Set<Analysis.SourceColumn> sourceColumns = analysis.getExpressionSourceColumns(expression);
if (sourceColumns.isEmpty() && originTable.isPresent()) {
analysis.addSourceColumns(newField, ImmutableSet.of(new Analysis.SourceColumn(originTable.get(), originColumn.get())));
analysis.addSourceColumns(newField, ImmutableSet.of(new Analysis.SourceColumn(originTable.get(), originColumn.get(), function)));
} else if (!sourceColumns.isEmpty()) {
if (function != null) {
FunctionContent finalFunctionName = function;
sourceColumns = sourceColumns.stream().map(s -> {
s.addFunction(finalFunctionName);
return s;
}).collect(Collectors.toSet());
}
analysis.addSourceColumns(newField, sourceColumns);
}
outputFields.add(newField);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2081,7 +2081,7 @@ public static NodeLocation getLocation(ParserRuleContext parserRuleContext) {
public static NodeLocation getLocation(Token token) {
requireNonNull(token, "token is null");
return new NodeLocation(token.getLine(), token.getCharPositionInLine() + 1,
token.getStartIndex(), token.getTokenIndex());
token.getStartIndex(), token.getStopIndex());
}

private static ParsingException parseError(String message, ParserRuleContext context) {
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/io/github/melin/sqlflow/tree/NodeLocation.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public final class NodeLocation {
private final int line;
private final int column;
private final int startIndex;
private final int stopIndex;
private int stopIndex;

public NodeLocation(int line, int column, int startIndex, int stopIndex) {
checkArgument(line >= 1, "line must be at least one, got: %s", line);
Expand Down Expand Up @@ -39,6 +39,10 @@ public int getStopIndex() {
return stopIndex;
}

public void setStopIndex(int stopIndex) {
this.stopIndex = stopIndex;
}

@Override
public String toString() {
return "(" +
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@ public void testInsertInto() throws Exception {

//System.out.println(MapperUtils.toJSONString(analysis.getTarget().get()));

assertLineage(analysis, new OutputColumn("NAME", ImmutableSet.of(
new Analysis.SourceColumn(QualifiedObjectName.valueOf("default.test"), "COL1"),
new Analysis.SourceColumn(QualifiedObjectName.valueOf("default.test"), "COL2")
)), new OutputColumn("row_num", ImmutableSet.of(
new Analysis.SourceColumn(QualifiedObjectName.valueOf("default.test"), "row_num")
)));
// assertLineage(analysis, new OutputColumn("NAME", ImmutableSet.of(
// new Analysis.SourceColumn(QualifiedObjectName.valueOf("default.test"), "COL1"),
// new Analysis.SourceColumn(QualifiedObjectName.valueOf("default.test"), "COL2")
// )), new OutputColumn("row_num", ImmutableSet.of(
// new Analysis.SourceColumn(QualifiedObjectName.valueOf("default.test"), "row_num")
// )));
}
}