diff --git a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java index 5cd446b32..03d4fef2b 100644 --- a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java @@ -1,12 +1,9 @@ package io.substrait.isthmus.cli; -import com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Message; import com.google.protobuf.TextFormat; import com.google.protobuf.util.JsonFormat; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.isthmus.FeatureBoard; -import io.substrait.isthmus.ImmutableFeatureBoard; +import io.substrait.isthmus.ConverterProvider; import io.substrait.isthmus.SqlExpressionToSubstrait; import io.substrait.isthmus.SqlToSubstrait; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; @@ -18,6 +15,7 @@ import java.util.concurrent.Callable; import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.sql.parser.SqlParser; import picocli.CommandLine; import picocli.CommandLine.Command; import picocli.CommandLine.Option; @@ -62,6 +60,15 @@ enum OutputFormat { description = "Calcite's casing policy for unquoted identifiers: ${COMPLETION-CANDIDATES}") private Casing unquotedCasing = Casing.TO_UPPER; + private ConverterProvider converterProvider() { + return new ConverterProvider() { + @Override + public SqlParser.Config getSqlParserConfig() { + return super.getSqlParserConfig().withUnquotedCasing(unquotedCasing); + } + }; + } + /** * Standard Java Main method invoked by the isthmus CLI command. * @@ -89,15 +96,13 @@ public static void main(String... args) { @Override public Integer call() throws Exception { - FeatureBoard featureBoard = buildFeatureBoard(); // Isthmus image is parsing SQL Expression if that argument is defined if (sqlExpressions != null) { - SqlExpressionToSubstrait converter = - new SqlExpressionToSubstrait(featureBoard, DefaultExtensionCatalog.DEFAULT_COLLECTION); + SqlExpressionToSubstrait converter = new SqlExpressionToSubstrait(converterProvider()); ExtendedExpression extendedExpression = converter.convert(sqlExpressions, createStatements); printMessage(extendedExpression); } else { // by default Isthmus image are parsing SQL Query - SqlToSubstrait converter = new SqlToSubstrait(featureBoard); + SqlToSubstrait converter = new SqlToSubstrait(converterProvider()); Prepare.CatalogReader catalog = SubstraitCreateStatementParser.processCreateStatementsToCatalog( createStatements.toArray(String[]::new)); @@ -116,9 +121,4 @@ private void printMessage(Message message) throws IOException { message.writeTo(System.out); } } - - @VisibleForTesting - FeatureBoard buildFeatureBoard() { - return ImmutableFeatureBoard.builder().unquotedCasing(unquotedCasing).build(); - } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/ConverterProvider.java b/isthmus/src/main/java/io/substrait/isthmus/ConverterProvider.java new file mode 100644 index 000000000..ce55dbd3a --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/ConverterProvider.java @@ -0,0 +1,261 @@ +package io.substrait.isthmus; + +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.SubstraitOperatorTable; +import io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.CallConverters; +import io.substrait.isthmus.expression.ExpressionRexConverter; +import io.substrait.isthmus.expression.FieldSelectionConverter; +import io.substrait.isthmus.expression.RexExpressionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import io.substrait.isthmus.expression.SqlArrayValueConstructorCallConverter; +import io.substrait.isthmus.expression.SqlMapValueConstructorCallConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.relation.Rel; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; +import org.apache.calcite.avatica.util.Casing; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.prepare.Prepare; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl; +import org.apache.calcite.sql.validate.SqlConformanceEnum; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.tools.Frameworks; +import org.apache.calcite.tools.RelBuilder; + +/** + * ConverterProvider provides a single-point of configuration for a number of conversions: {@code + * SQl <-> Calcite <-> Substrait} + * + *

It is consumed by all conversion classes as their primary source of configuration. + * + *

The no argument constructor {@link #ConverterProvider()} provides reasonable system defaults. + * + *

Other constructors allow for further customization of conversion behaviours. + * + *

More in-depth customization can be achieved by extending this class, as is done in {@link + * DynamicConverterProvider}. + */ +public class ConverterProvider { + + protected RelDataTypeFactory typeFactory; + + protected final SimpleExtension.ExtensionCollection extensions; + + protected ScalarFunctionConverter scalarFunctionConverter; + protected AggregateFunctionConverter aggregateFunctionConverter; + protected WindowFunctionConverter windowFunctionConverter; + + protected TypeConverter typeConverter; + + public ConverterProvider() { + this(DefaultExtensionCatalog.DEFAULT_COLLECTION, SubstraitTypeSystem.TYPE_FACTORY); + } + + public ConverterProvider(SimpleExtension.ExtensionCollection extensions) { + this(extensions, SubstraitTypeSystem.TYPE_FACTORY); + } + + public ConverterProvider( + SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { + this( + typeFactory, + extensions, + new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory), + new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory), + new WindowFunctionConverter(extensions.windowFunctions(), typeFactory), + TypeConverter.DEFAULT); + } + + public ConverterProvider( + RelDataTypeFactory typeFactory, + SimpleExtension.ExtensionCollection extensions, + ScalarFunctionConverter sfc, + AggregateFunctionConverter afc, + WindowFunctionConverter wfc, + TypeConverter tc) { + this.typeFactory = typeFactory; + this.extensions = extensions; + this.scalarFunctionConverter = sfc; + this.aggregateFunctionConverter = afc; + this.windowFunctionConverter = wfc; + this.typeConverter = tc; + } + + // SQL to Calcite Processing + + /** + * {@link SqlParser.Config} is a Calcite class which controls SQL parsing behaviour like + * identifier casing. + */ + public SqlParser.Config getSqlParserConfig() { + return SqlParser.Config.DEFAULT + .withUnquotedCasing(Casing.TO_UPPER) + .withParserFactory(SqlDdlParserImpl.FACTORY) + .withConformance(SqlConformanceEnum.LENIENT); + } + + /** + * {@link CalciteConnectionConfig} is a Calcite class which controls SQL processing behaviour like + * table name case-sensitivity. + */ + public CalciteConnectionConfig getCalciteConnectionConfig() { + return CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false"); + } + + /** + * {@link SqlToRelConverter.Config} is a Calcite class which controls SQL processing behaviour + * like field-trimming. + */ + public SqlToRelConverter.Config getSqlToRelConverterConfig() { + return SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(false); + } + + /** + * {@link SqlOperatorTable} is a Calcite class which stores the {@link + * org.apache.calcite.sql.SqlOperator}s available and controls valid identifiers during SQL + * processing. + */ + public SqlOperatorTable getSqlOperatorTable() { + return SubstraitOperatorTable.INSTANCE; + } + + // Substrait to Calcite Processing + + /** + * {@link SubstraitToCalcite} is an Isthmus class for converting a Substrait {@link Rel} or {@link + * io.substrait.plan.Plan.Root} to a Calcite {@link org.apache.calcite.rel.RelNode} or {@link + * org.apache.calcite.rel.RelRoot} + */ + protected SubstraitToCalcite getSubstraitToCalcite() { + return new SubstraitToCalcite(this); + } + + /** + * {@link SubstraitToCalcite} is an Isthmus class for converting a Substrait {@link Rel} or {@link + * io.substrait.plan.Plan.Root} to a Calcite {@link org.apache.calcite.rel.RelNode} or {@link + * org.apache.calcite.rel.RelRoot} + * + * @param catalogReader a Calcite {@link Prepare.CatalogReader} used to construct a {@link + * RelBuilder} for use in creating Calcite {@link org.apache.calcite.rel.RelNode}s + */ + protected SubstraitToCalcite getSubstraitToCalcite(Prepare.CatalogReader catalogReader) { + return new SubstraitToCalcite(this, catalogReader); + } + + // Calcite to Substrait Processing + + /** + * A {@link SubstraitRelVisitor} converts Calcite {@link org.apache.calcite.rel.RelNode}s to + * Substrait {@link Rel}s + */ + public SubstraitRelVisitor getSubstraitRelVisitor() { + return new SubstraitRelVisitor(this); + } + + /** + * A {@link RexExpressionConverter} converts Calcite {@link org.apache.calcite.rex.RexNode}s to + * Substrait equivalents. + */ + public RexExpressionConverter getRexExpressionConverter(SubstraitRelVisitor srv) { + return new RexExpressionConverter( + srv, getCallConverters(), getWindowFunctionConverter(), getTypeConverter()); + } + + /** + * {@link CallConverter}s are used to convert Calcite {@link org.apache.calcite.rex.RexCall}s to + * Substrait equivalents. + */ + public List getCallConverters() { + ArrayList callConverters = new ArrayList<>(); + callConverters.add(new FieldSelectionConverter(typeConverter)); + callConverters.add(CallConverters.CASE); + callConverters.add(CallConverters.CAST.apply(typeConverter)); + callConverters.add(CallConverters.REINTERPRET.apply(typeConverter)); + callConverters.add(new SqlArrayValueConstructorCallConverter(typeConverter)); + callConverters.add(new SqlMapValueConstructorCallConverter()); + callConverters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); + callConverters.add(scalarFunctionConverter); + return callConverters; + } + + // Substrait To Calcite Processing + + /** + * When converting from Substrait to Calcite, Calcite needs to have a schema available. The + * default strategy uses a {@link SchemaCollector} to generate a {@link CalciteSchema} on the fly + * based on the leaf nodes of a Substrait plan. + * + *

Override to customize the schema generation behaviour + */ + public Function getSchemaResolver() { + SchemaCollector schemaCollector = new SchemaCollector(this); + return schemaCollector::toSchema; + } + + /** + * A {@link SubstraitRelNodeConverter} is used when converting from Substrait {@link Rel}s to + * Calcite {@link org.apache.calcite.rel.RelNode}s. + */ + public SubstraitRelNodeConverter getSubstraitRelNodeConverter(RelBuilder relBuilder) { + return new SubstraitRelNodeConverter(relBuilder, this); + } + + /** + * A {@link ExpressionRexConverter} converts Substrait {@link io.substrait.expression.Expression} + * to Calcite equivalents + */ + public ExpressionRexConverter getExpressionRexConverter( + SubstraitRelNodeConverter relNodeConverter) { + ExpressionRexConverter erc = + new ExpressionRexConverter( + getTypeFactory(), + getScalarFunctionConverter(), + getWindowFunctionConverter(), + getTypeConverter()); + erc.setRelNodeConverter(relNodeConverter); + return erc; + } + + /** + * A {@link RelBuilder} is a Calcite class used as a factory for creating {@link + * org.apache.calcite.rel.RelNode}s. + */ + public RelBuilder getRelBuilder(CalciteSchema schema) { + return RelBuilder.create(Frameworks.newConfigBuilder().defaultSchema(schema.plus()).build()); + } + + // Utility Getters + + public RelDataTypeFactory getTypeFactory() { + return typeFactory; + } + + public SimpleExtension.ExtensionCollection getExtensions() { + return extensions; + } + + public ScalarFunctionConverter getScalarFunctionConverter() { + return scalarFunctionConverter; + } + + public AggregateFunctionConverter getAggregateFunctionConverter() { + return aggregateFunctionConverter; + } + + public WindowFunctionConverter getWindowFunctionConverter() { + return windowFunctionConverter; + } + + public TypeConverter getTypeConverter() { + return typeConverter; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/DynamicConverterProvider.java b/isthmus/src/main/java/io/substrait/isthmus/DynamicConverterProvider.java new file mode 100644 index 000000000..0b6462d18 --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/DynamicConverterProvider.java @@ -0,0 +1,92 @@ +package io.substrait.isthmus; + +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.expression.FunctionMappings; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.util.SqlOperatorTables; + +public class DynamicConverterProvider extends ConverterProvider { + + public DynamicConverterProvider() { + this(DefaultExtensionCatalog.DEFAULT_COLLECTION, SubstraitTypeSystem.TYPE_FACTORY); + } + + public DynamicConverterProvider(SimpleExtension.ExtensionCollection extensions) { + this(extensions, SubstraitTypeSystem.TYPE_FACTORY); + } + + public DynamicConverterProvider( + SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { + super(extensions, typeFactory); + this.scalarFunctionConverter = createScalarFunctionConverter(); + } + + @Override + public List getCallConverters() { + List callConverters = super.getCallConverters(); + + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + ExtensionUtils.getDynamicExtensions(extensions); + List dynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + List additionalSignatures = + dynamicOperators.stream() + .map(op -> FunctionMappings.s(op, op.getName())) + .collect(Collectors.toList()); + callConverters.add( + new ScalarFunctionConverter( + extensions.scalarFunctions(), additionalSignatures, typeFactory, typeConverter)); + return callConverters; + } + + @Override + public SqlOperatorTable getSqlOperatorTable() { + SqlOperatorTable operatorTable = super.getSqlOperatorTable(); + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + ExtensionUtils.getDynamicExtensions(extensions); + if (dynamicExtensionCollection.scalarFunctions().isEmpty() + && dynamicExtensionCollection.aggregateFunctions().isEmpty()) { + return operatorTable; + } + List generatedDynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + return SqlOperatorTables.chain(operatorTable, SqlOperatorTables.of(generatedDynamicOperators)); + } + + private ScalarFunctionConverter createScalarFunctionConverter() { + List additionalSignatures = Collections.emptyList(); + + java.util.Set knownFunctionNames = + FunctionMappings.SCALAR_SIGS.stream() + .map(FunctionMappings.Sig::name) + .collect(Collectors.toSet()); + + List dynamicFunctions = + extensions.scalarFunctions().stream() + .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase())) + .collect(Collectors.toList()); + + if (!dynamicFunctions.isEmpty()) { + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + SimpleExtension.ExtensionCollection.builder().scalarFunctions(dynamicFunctions).build(); + + List dynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + + additionalSignatures = + dynamicOperators.stream() + .map(op -> FunctionMappings.s(op, op.getName())) + .collect(Collectors.toList()); + } + + return new ScalarFunctionConverter( + extensions.scalarFunctions(), additionalSignatures, typeFactory, typeConverter); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java b/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java deleted file mode 100644 index a54f24146..000000000 --- a/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java +++ /dev/null @@ -1,35 +0,0 @@ -package io.substrait.isthmus; - -import org.apache.calcite.avatica.util.Casing; -import org.immutables.value.Value; - -/** - * A feature board is a collection of flags that are enabled or configurations that control the - * handling of a request to convert query [batch] to Substrait plans. - */ -@Value.Immutable -public abstract class FeatureBoard { - - /** - * @return Calcite's identifier casing policy for unquoted identifiers. - */ - @Value.Default - public Casing unquotedCasing() { - return Casing.TO_UPPER; - } - - /** - * Controls whether to support dynamic user-defined functions (UDFs) during SQL to Substrait plan - * conversion. - * - *

When enabled, custom functions defined in extension YAML files are available for use in SQL - * queries. These functions will be dynamically converted to SQL operators during plan conversion. - * This feature must be explicitly enabled by users and is disabled by default. - * - * @return true if dynamic UDFs should be supported; false otherwise (default) - */ - @Value.Default - public boolean allowDynamicUdfs() { - return false; - } -} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java index 99eaac1ab..09a3ebee9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java @@ -25,11 +25,17 @@ public class SchemaCollector { private final RelDataTypeFactory typeFactory; private final TypeConverter typeConverter; + @Deprecated public SchemaCollector(RelDataTypeFactory typeFactory, TypeConverter typeConverter) { this.typeFactory = typeFactory; this.typeConverter = typeConverter; } + public SchemaCollector(ConverterProvider converterProvider) { + this.typeFactory = converterProvider.getTypeFactory(); + this.typeConverter = converterProvider.getTypeConverter(); + } + /** * Returns a {@link CalciteSchema} containing all tables and schemas defined in {@link NamedScan}s * and {@link NamedWrite}s within the provided relation operation tree. diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index f667deab0..170dc4f94 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -1,7 +1,5 @@ package io.substrait.isthmus; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; import org.apache.calcite.plan.Contexts; @@ -14,12 +12,10 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl; -import org.apache.calcite.sql.validate.SqlConformanceEnum; import org.apache.calcite.sql2rel.SqlToRelConverter; public class SqlConverterBase { - protected final SimpleExtension.ExtensionCollection extensionCollection; + protected final ConverterProvider converterProvider; public static final CalciteConnectionConfig CONNECTION_CONFIG = CalciteConnectionConfig.DEFAULT.set( @@ -32,15 +28,11 @@ public class SqlConverterBase { final SqlParser.Config parserConfig; - protected static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); - final FeatureBoard featureBoard; - - protected SqlConverterBase( - FeatureBoard features, SimpleExtension.ExtensionCollection extensionCollection) { - this.factory = SubstraitTypeSystem.TYPE_FACTORY; - this.config = - CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false"); - this.converterConfig = SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(false); + protected SqlConverterBase(ConverterProvider converterProvider) { + this.converterProvider = converterProvider; + this.factory = converterProvider.getTypeFactory(); + this.config = converterProvider.getCalciteConnectionConfig(); + this.converterConfig = converterProvider.getSqlToRelConverterConfig(); VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.of("hello")); this.relOptCluster = RelOptCluster.create(planner, new RexBuilder(factory)); relOptCluster.setMetadataQuerySupplier( @@ -49,17 +41,6 @@ protected SqlConverterBase( new ProxyingMetadataHandlerProvider(DefaultRelMetadataProvider.INSTANCE); return new RelMetadataQuery(handler); }); - featureBoard = features == null ? FEATURES_DEFAULT : features; - parserConfig = - SqlParser.Config.DEFAULT - .withUnquotedCasing(featureBoard.unquotedCasing()) - .withParserFactory(SqlDdlParserImpl.FACTORY) - .withConformance(SqlConformanceEnum.LENIENT); - - this.extensionCollection = extensionCollection; - } - - protected SqlConverterBase(FeatureBoard features) { - this(features, DefaultExtensionCatalog.DEFAULT_COLLECTION); + parserConfig = converterProvider.getSqlParserConfig(); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 3d45f8bde..b43b3753f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -3,11 +3,8 @@ import io.substrait.extendedexpression.ExtendedExpression; import io.substrait.extendedexpression.ExtendedExpressionProtoConverter; import io.substrait.extendedexpression.ImmutableExtendedExpression.Builder; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.calcite.SubstraitTable; import io.substrait.isthmus.expression.RexExpressionConverter; -import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; import io.substrait.isthmus.sql.SubstraitSqlValidator; import io.substrait.type.NamedStruct; @@ -35,15 +32,12 @@ public class SqlExpressionToSubstrait extends SqlConverterBase { protected final RexExpressionConverter rexConverter; public SqlExpressionToSubstrait() { - this(FEATURES_DEFAULT, DefaultExtensionCatalog.DEFAULT_COLLECTION); + this(new ConverterProvider()); } - public SqlExpressionToSubstrait( - FeatureBoard features, SimpleExtension.ExtensionCollection extensions) { - super(features, extensions); - ScalarFunctionConverter scalarFunctionConverter = - new ScalarFunctionConverter(extensions.scalarFunctions(), factory); - this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); + public SqlExpressionToSubstrait(ConverterProvider converterProvider) { + super(converterProvider); + this.rexConverter = new RexExpressionConverter(converterProvider.getScalarFunctionConverter()); } private static final class Result { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index e60494244..fd1d8646f 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -1,71 +1,30 @@ package io.substrait.isthmus; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; -import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.ImmutablePlan.Builder; import io.substrait.plan.Plan; import io.substrait.plan.Plan.Version; -import io.substrait.plan.PlanProtoConverter; -import java.util.List; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.sql.SqlDialect; -import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.util.SqlOperatorTables; -/** Take a SQL statement and a set of table definitions and return a substrait plan. */ +/** + * Take a SQL statement and a set of table definitions and return a substrait plan. + * + *

Conversion behaviours can be customized using a {@link ConverterProvider} + */ public class SqlToSubstrait extends SqlConverterBase { private final SqlOperatorTable operatorTable; public SqlToSubstrait() { - this(DefaultExtensionCatalog.DEFAULT_COLLECTION, null); + this(new ConverterProvider()); } - public SqlToSubstrait(FeatureBoard features) { - this(DefaultExtensionCatalog.DEFAULT_COLLECTION, features); - } - - public SqlToSubstrait(SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - super(features, extensions); - - if (featureBoard.allowDynamicUdfs()) { - SimpleExtension.ExtensionCollection dynamicExtensionCollection = - ExtensionUtils.getDynamicExtensions(extensions); - if (!dynamicExtensionCollection.scalarFunctions().isEmpty() - || !dynamicExtensionCollection.aggregateFunctions().isEmpty()) { - List generatedDynamicOperators = - SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, this.factory); - this.operatorTable = - SqlOperatorTables.chain( - SubstraitOperatorTable.INSTANCE, SqlOperatorTables.of(generatedDynamicOperators)); - return; - } - } - this.operatorTable = SubstraitOperatorTable.INSTANCE; - } - - /** - * Converts one or more SQL statements into a Substrait {@link io.substrait.proto.Plan}. - * - * @param sqlStatements a string containing one more SQL statements - * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in - * the SQL statements - * @return a Substrait proto {@link io.substrait.proto.Plan} - * @throws SqlParseException if there is an error while parsing the SQL statements string - * @deprecated use {@link #convert(String, org.apache.calcite.prepare.Prepare.CatalogReader)} - * instead to get a {@link Plan} and convert that to a {@link io.substrait.proto.Plan} using - * {@link PlanProtoConverter#toProto(Plan)} - */ - @Deprecated - public io.substrait.proto.Plan execute(String sqlStatements, Prepare.CatalogReader catalogReader) - throws SqlParseException { - PlanProtoConverter planToProto = new PlanProtoConverter(); - return planToProto.toProto( - convert(sqlStatements, catalogReader, SqlDialect.DatabaseProduct.CALCITE.getDialect())); + public SqlToSubstrait(ConverterProvider converterProvider) { + super(converterProvider); + this.operatorTable = converterProvider.getSqlOperatorTable(); } /** @@ -84,7 +43,7 @@ public Plan convert(final String sqlStatements, final Prepare.CatalogReader cata // TODO: consider case in which one sql passes conversion while others don't SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader, operatorTable).stream() - .map(root -> SubstraitRelVisitor.convert(root, extensionCollection, featureBoard)) + .map(root -> SubstraitRelVisitor.convert(root, converterProvider)) .forEach(root -> builder.addRoots(root)); return builder.build(); @@ -112,7 +71,7 @@ public Plan convert( // TODO: consider case in which one sql passes conversion while others don't SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader, sqlParserConfig).stream() - .map(root -> SubstraitRelVisitor.convert(root, extensionCollection, featureBoard)) + .map(root -> SubstraitRelVisitor.convert(root, converterProvider)) .forEach(root -> builder.addRoots(root)); return builder.build(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index e1e6b6e21..f07a87a78 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -12,9 +12,7 @@ import io.substrait.isthmus.calcite.rel.CreateView; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; -import io.substrait.isthmus.expression.FunctionMappings; import io.substrait.isthmus.expression.ScalarFunctionConverter; -import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.relation.AbstractDdlRel; import io.substrait.relation.AbstractRelVisitor; import io.substrait.relation.AbstractUpdate; @@ -56,7 +54,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; -import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelTraitDef; import org.apache.calcite.prepare.Prepare; @@ -83,7 +80,6 @@ import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; @@ -104,138 +100,37 @@ public class SubstraitRelNodeConverter protected final RexBuilder rexBuilder; private final TypeConverter typeConverter; + /** Use {@link #SubstraitRelNodeConverter(RelBuilder, ConverterProvider)} instead */ + @Deprecated public SubstraitRelNodeConverter( SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory, RelBuilder relBuilder) { - this(extensions, typeFactory, relBuilder, ImmutableFeatureBoard.builder().build()); + this(relBuilder, new ConverterProvider(extensions, typeFactory)); } - public SubstraitRelNodeConverter( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - RelBuilder relBuilder, - FeatureBoard featureBoard) { - this( - typeFactory, - relBuilder, - createScalarFunctionConverter(extensions, typeFactory, featureBoard.allowDynamicUdfs()), - new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory), - new WindowFunctionConverter(extensions.windowFunctions(), typeFactory), - TypeConverter.DEFAULT); - } - - public SubstraitRelNodeConverter( - RelDataTypeFactory typeFactory, - RelBuilder relBuilder, - ScalarFunctionConverter scalarFunctionConverter, - AggregateFunctionConverter aggregateFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter) { - this( - typeFactory, - relBuilder, - scalarFunctionConverter, - aggregateFunctionConverter, - windowFunctionConverter, - typeConverter, - new ExpressionRexConverter( - typeFactory, scalarFunctionConverter, windowFunctionConverter, typeConverter)); - } - - public SubstraitRelNodeConverter( - RelDataTypeFactory typeFactory, - RelBuilder relBuilder, - ScalarFunctionConverter scalarFunctionConverter, - AggregateFunctionConverter aggregateFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter, - ExpressionRexConverter expressionRexConverter) { - this.typeFactory = typeFactory; - this.typeConverter = typeConverter; + public SubstraitRelNodeConverter(RelBuilder relBuilder, ConverterProvider converterProvider) { + this.typeFactory = converterProvider.getTypeFactory(); + this.typeConverter = converterProvider.getTypeConverter(); this.relBuilder = relBuilder; this.rexBuilder = new RexBuilder(typeFactory); - this.scalarFunctionConverter = scalarFunctionConverter; - this.aggregateFunctionConverter = aggregateFunctionConverter; - this.expressionRexConverter = expressionRexConverter; - this.expressionRexConverter.setRelNodeConverter(this); - } - - private static ScalarFunctionConverter createScalarFunctionConverter( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - boolean allowDynamicUdfs) { - - List additionalSignatures; - - if (allowDynamicUdfs) { - java.util.Set knownFunctionNames = - FunctionMappings.SCALAR_SIGS.stream() - .map(FunctionMappings.Sig::name) - .collect(Collectors.toSet()); - - List dynamicFunctions = - extensions.scalarFunctions().stream() - .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase())) - .collect(Collectors.toList()); - - if (dynamicFunctions.isEmpty()) { - additionalSignatures = Collections.emptyList(); - } else { - SimpleExtension.ExtensionCollection dynamicExtensionCollection = - SimpleExtension.ExtensionCollection.builder().scalarFunctions(dynamicFunctions).build(); - - List dynamicOperators = - SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); - - additionalSignatures = - dynamicOperators.stream() - .map(op -> FunctionMappings.s(op, op.getName())) - .collect(Collectors.toList()); - } - } else { - additionalSignatures = Collections.emptyList(); - } - - return new ScalarFunctionConverter( - extensions.scalarFunctions(), additionalSignatures, typeFactory, TypeConverter.DEFAULT); + this.scalarFunctionConverter = converterProvider.getScalarFunctionConverter(); + this.aggregateFunctionConverter = converterProvider.getAggregateFunctionConverter(); + this.expressionRexConverter = converterProvider.getExpressionRexConverter(this); } public static RelNode convert( - Rel relRoot, - RelOptCluster relOptCluster, - Prepare.CatalogReader catalogReader, - SqlParser.Config parserConfig, - SimpleExtension.ExtensionCollection extensions) { - return convert( - relRoot, - relOptCluster, - catalogReader, - parserConfig, - extensions, - ImmutableFeatureBoard.builder().build()); - } - - public static RelNode convert( - Rel relRoot, - RelOptCluster relOptCluster, - Prepare.CatalogReader catalogReader, - SqlParser.Config parserConfig, - SimpleExtension.ExtensionCollection extensions, - FeatureBoard featureBoard) { + Rel relRoot, Prepare.CatalogReader catalogReader, ConverterProvider converterProvider) { RelBuilder relBuilder = RelBuilder.create( Frameworks.newConfigBuilder() - .parserConfig(parserConfig) + .parserConfig(converterProvider.getSqlParserConfig()) .defaultSchema(catalogReader.getRootSchema().plus()) .traitDefs((List) null) .programs() .build()); - return relRoot.accept( - new SubstraitRelNodeConverter( - extensions, relOptCluster.getTypeFactory(), relBuilder, featureBoard), - Context.newContext()); + converterProvider.getSubstraitRelNodeConverter(relBuilder), Context.newContext()); } @Override diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 26cf37c0c..6bcc5b01a 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -8,12 +8,8 @@ import io.substrait.isthmus.calcite.rel.CreateTable; import io.substrait.isthmus.calcite.rel.CreateView; import io.substrait.isthmus.expression.AggregateFunctionConverter; -import io.substrait.isthmus.expression.CallConverters; -import io.substrait.isthmus.expression.FunctionMappings; import io.substrait.isthmus.expression.LiteralConverter; import io.substrait.isthmus.expression.RexExpressionConverter; -import io.substrait.isthmus.expression.ScalarFunctionConverter; -import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.plan.Plan; import io.substrait.relation.AbstractDdlRel; import io.substrait.relation.AbstractWriteRel; @@ -58,88 +54,41 @@ import org.apache.calcite.rel.core.TableModify; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; +/** + * SubstraitRelVisitor is used to convert Calcite {@link RelNode}s to Substrait {@link Rel}s. + * + *

Conversion behaviours can be customized by using a {@link ConverterProvider} and/or extending + * this class + */ @SuppressWarnings("UnstableApiUsage") @Value.Enclosing public class SubstraitRelVisitor extends RelNodeVisitor { - private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true); protected final RexExpressionConverter rexExpressionConverter; protected final AggregateFunctionConverter aggregateFunctionConverter; protected final TypeConverter typeConverter; - protected final FeatureBoard featureBoard; private Map fieldAccessDepthMap; + /** Use {@link SubstraitRelVisitor#SubstraitRelVisitor(ConverterProvider)} */ + @Deprecated public SubstraitRelVisitor( RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { - this(typeFactory, extensions, FEATURES_DEFAULT); + this(new ConverterProvider(extensions, typeFactory)); } - public SubstraitRelVisitor( - RelDataTypeFactory typeFactory, - SimpleExtension.ExtensionCollection extensions, - FeatureBoard features) { - - this.typeConverter = TypeConverter.DEFAULT; - ArrayList converters = new ArrayList<>(); - converters.addAll(CallConverters.defaults(typeConverter)); - - if (features.allowDynamicUdfs()) { - SimpleExtension.ExtensionCollection dynamicExtensionCollection = - ExtensionUtils.getDynamicExtensions(extensions); - List dynamicOperators = - SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); - - List additionalSignatures = - dynamicOperators.stream() - .map(op -> FunctionMappings.s(op, op.getName())) - .collect(Collectors.toList()); - converters.add( - new ScalarFunctionConverter( - extensions.scalarFunctions(), - additionalSignatures, - typeFactory, - TypeConverter.DEFAULT)); - } else { - converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)); - } - - converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); - this.aggregateFunctionConverter = - new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory); - WindowFunctionConverter windowFunctionConverter = - new WindowFunctionConverter(extensions.windowFunctions(), typeFactory); - this.rexExpressionConverter = - new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter); - this.featureBoard = features; - } - - public SubstraitRelVisitor( - RelDataTypeFactory typeFactory, - ScalarFunctionConverter scalarFunctionConverter, - AggregateFunctionConverter aggregateFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter, - FeatureBoard features) { - ArrayList converters = new ArrayList(); - converters.addAll(CallConverters.defaults(typeConverter)); - converters.add(scalarFunctionConverter); - converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); - this.aggregateFunctionConverter = aggregateFunctionConverter; - this.rexExpressionConverter = - new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter); - this.typeConverter = typeConverter; - this.featureBoard = features; + public SubstraitRelVisitor(ConverterProvider converterProvider) { + this.typeConverter = converterProvider.getTypeConverter(); + this.aggregateFunctionConverter = converterProvider.getAggregateFunctionConverter(); + this.rexExpressionConverter = converterProvider.getRexExpressionConverter(this); } protected Expression toExpression(RexNode node) { @@ -630,38 +579,32 @@ public List apply(List inputs) { } /** - * Converts a Calcite {@link RelRoot} to a Substrait {@link Plan.Root} using default features. - * - *

This is a convenience method that delegates to {@link #convert(RelRoot, - * SimpleExtension.ExtensionCollection, FeatureBoard)} using {@link #FEATURES_DEFAULT}. + * Deprecated, use {@link #convert(RelRoot, ConverterProvider)} directly * * @param relRoot The Calcite RelRoot to convert. * @param extensions The extension collection to use for the conversion. * @return The resulting Substrait Plan.Root. */ + @Deprecated public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions) { - return convert(relRoot, extensions, FEATURES_DEFAULT); + return convert(relRoot, new ConverterProvider(extensions)); } /** - * Converts a Calcite {@link RelRoot} to a Substrait {@link Plan.Root} using a custom visitor. - * - *

This is the main conversion entry point for a complete plan. It applies the provided {@link - * SubstraitRelVisitor} to the final projected {@link RelNode} from the {@code relRoot}, and wraps - * the resulting {@link Rel} in a {@link Plan.Root}. + * Converts a Calcite {@link RelRoot} to a Substrait {@link Plan.Root} * - *

This method also correctly extracts the final output field names, paying special attention - * to nested types (structs, maps) via the visitor's type converter, rather than using the names - * from {@code relRoot.validatedRowType} directly. + *

Converts the output of {@link RelRoot#project()} to a Substrait {@link Rel} and wraps it in + * a {@link Plan.Root}. Handles the extraction of final output field names, paying special + * attention to nested types (structs, maps) via the visitor's type converter, rather than using + * the names from {@link RelRoot#validatedRowType} directly. * - * @param relRoot The Calcite RelRoot to convert. This is expected to be a complete, optimized - * plan. - * @param visitor {@link SubstraitRelVisitor} or its subclass. This allows for custom visitor - * behavior. - * @return The resulting Substrait Plan.Root, containing the converted relational tree and the - * output names. + * @param relRoot The Calcite RelRoot to convert. This is expected to be a complete plan. + * @param converterProvider The {@link ConverterProvider} controlling conversion behaviours. + * @return The resulting Substrait {@link Plan.Root}, containing the converted relational tree and + * the output names. */ - public static Plan.Root convert(RelRoot relRoot, SubstraitRelVisitor visitor) { + public static Plan.Root convert(RelRoot relRoot, ConverterProvider converterProvider) { + SubstraitRelVisitor visitor = converterProvider.getSubstraitRelVisitor(); visitor.popFieldAccessDepthMap(relRoot.rel); Rel rel = visitor.apply(relRoot.project()); @@ -672,80 +615,31 @@ public static Plan.Root convert(RelRoot relRoot, SubstraitRelVisitor visitor) { } /** - * Converts a Calcite {@link RelRoot} to a Substrait {@link Plan.Root} using the specified - * features. - * - *

This is a convenience method that delegates to {@link #convert(RelRoot, - * SubstraitRelVisitor)} using an instance of the {@link SubstraitRelVisitor} as the visitor. - * - * @param relRoot The Calcite RelRoot to convert. - * @param extensions The extension collection to use for the conversion. - * @param features The feature board specifying enabled Substrait features. - * @return The resulting Substrait Plan.Root. - */ - public static Plan.Root convert( - RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - return convert( - relRoot, - new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features)); - } - - /** - * Converts a Calcite {@link RelNode} to a Substrait {@link Rel} using default features. + * Deprecated, use {@link #convert(RelNode, ConverterProvider)} directly * *

This method is suitable for converting a relational sub-tree, but it does not produce a * {@link Plan.Root}. For a complete plan conversion, use {@link #convert(RelRoot, * SimpleExtension.ExtensionCollection)}. * - *

This is a convenience method that delegates to {@link #convert(RelNode, - * SimpleExtension.ExtensionCollection, FeatureBoard)} using {@link #FEATURES_DEFAULT}. - * * @param relNode The Calcite RelNode (and its subtree) to convert. * @param extensions The extension collection to use for the conversion. * @return The resulting Substrait Rel. */ + @Deprecated public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions) { - return convert(relNode, extensions, FEATURES_DEFAULT); + return convert(relNode, new ConverterProvider(extensions)); } /** - * Converts a Calcite {@link RelNode} to a Substrait {@link Rel} using a custom visitor. - * - *

This is the main conversion entry point for a partial plan or a single node (and its - * children). It applies the provided {@link SubstraitRelVisitor} to the given {@code relNode}. + * Converts a Calcite {@link RelNode} to a Substrait {@link Rel} * - *

This method does not wrap the result in a {@link Plan.Root} or extract output names. For - * that, use {@link #convert(RelRoot, SubstraitRelVisitor)}. - * - * @param relNode The Calcite RelNode (and its subtree) to convert. - * @param visitor {@link SubstraitRelVisitor} or its subclass. This allows for custom visitor - * behavior. + * @param relNode The Calcite RelNode to convert. + * @param converterProvider The {@link ConverterProvider} controlling conversion behaviours. * @return The resulting Substrait Rel. */ - public static Rel convert(RelNode relNode, SubstraitRelVisitor visitor) { + public static Rel convert(RelNode relNode, ConverterProvider converterProvider) { + SubstraitRelVisitor visitor = converterProvider.getSubstraitRelVisitor(); visitor.popFieldAccessDepthMap(relNode); return visitor.apply(relNode); } - - /** - * Converts a Calcite {@link RelNode} to a Substrait {@link Rel} using the specified features. - * - *

This method is suitable for converting a relational sub-tree, but it does not produce a - * {@link Plan.Root}. For a complete plan conversion, use {@link #convert(RelRoot, - * SimpleExtension.ExtensionCollection, FeatureBoard)}. - * - *

This is a convenience method that delegates to {@link #convert(RelNode, - * SubstraitRelVisitor)} using an instance of the {@link SubstraitRelVisitor} as the visitor. - * - * @param relNode The Calcite RelNode (and its subtree) to convert. - * @param extensions The extension collection to use for the conversion. - * @param features The feature board specifying enabled Substrait features. - * @return The resulting Substrait Rel. - */ - public static Rel convert( - RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - return convert( - relNode, - new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features)); - } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index a0c5132e4..9418ad6bd 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -1,6 +1,5 @@ package io.substrait.isthmus; -import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.SubstraitRelNodeConverter.Context; import io.substrait.plan.Plan; import io.substrait.relation.Rel; @@ -16,116 +15,47 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.Pair; /** * Converts between Substrait {@link Rel}s and Calcite {@link RelNode}s. * - *

Can be extended to customize the {@link RelBuilder} and {@link SubstraitRelNodeConverter} used - * in the conversion. + *

Conversion behaviours can be customized using a {@link ConverterProvider} */ public class SubstraitToCalcite { - protected final SimpleExtension.ExtensionCollection extensions; protected final RelDataTypeFactory typeFactory; - protected final TypeConverter typeConverter; protected final Prepare.CatalogReader catalogReader; - protected final FeatureBoard featureBoard; + protected ConverterProvider converterProvider; - public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { - this(extensions, typeFactory, TypeConverter.DEFAULT, null); - } - - public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - Prepare.CatalogReader catalogReader) { - this(extensions, typeFactory, TypeConverter.DEFAULT, catalogReader); - } - - public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { - this(extensions, typeFactory, typeConverter, null); + public SubstraitToCalcite(ConverterProvider converterProvider) { + this(converterProvider, null); } public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter, - Prepare.CatalogReader catalogReader) { - this( - extensions, - typeFactory, - typeConverter, - catalogReader, - ImmutableFeatureBoard.builder().build()); - } - - public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter, - Prepare.CatalogReader catalogReader, - FeatureBoard featureBoard) { - this.extensions = extensions; - this.typeFactory = typeFactory; - this.typeConverter = typeConverter; + ConverterProvider converterProvider, Prepare.CatalogReader catalogReader) { + this.converterProvider = converterProvider; + this.typeFactory = converterProvider.getTypeFactory(); this.catalogReader = catalogReader; - this.featureBoard = featureBoard; - } - - /** - * Extracts a {@link CalciteSchema} from a {@link Rel} - * - *

Override this method to customize schema extraction. - */ - protected CalciteSchema toSchema(Rel rel) { - SchemaCollector schemaCollector = new SchemaCollector(typeFactory, typeConverter); - return schemaCollector.toSchema(rel); - } - - /** - * Creates a {@link RelBuilder} from the extracted {@link CalciteSchema} - * - *

Override this method to customize the {@link RelBuilder}. - */ - protected RelBuilder createRelBuilder(CalciteSchema schema) { - return RelBuilder.create(Frameworks.newConfigBuilder().defaultSchema(schema.plus()).build()); - } - - /** - * Creates a {@link SubstraitRelNodeConverter} from the {@link RelBuilder} - * - *

Override this method to customize the {@link SubstraitRelNodeConverter}. - */ - protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder, featureBoard); } /** * Converts a Substrait {@link Rel} to a Calcite {@link RelNode} * - *

Generates a {@link CalciteSchema} based on the contents of the {@link Rel}, which will be - * used to construct a {@link RelBuilder} with the required schema information to build {@link - * RelNode}s, and a then a {@link SubstraitRelNodeConverter} to perform the actual conversion. - * * @param rel {@link Rel} to convert * @return {@link RelNode} */ public RelNode convert(Rel rel) { RelBuilder relBuilder; if (catalogReader != null) { - relBuilder = createRelBuilder(catalogReader.getRootSchema()); + relBuilder = converterProvider.getRelBuilder(catalogReader.getRootSchema()); } else { - CalciteSchema rootSchema = toSchema(rel); - relBuilder = createRelBuilder(rootSchema); + CalciteSchema rootSchema = converterProvider.getSchemaResolver().apply(rel); + relBuilder = converterProvider.getRelBuilder(rootSchema); } - SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder); + SubstraitRelNodeConverter converter = + converterProvider.getSubstraitRelNodeConverter(relBuilder); return rel.accept(converter, Context.newContext()); } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index db86aeeee..2175fd871 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -11,23 +11,32 @@ import org.apache.calcite.rel.rel2sql.RelToSqlConverter; import org.apache.calcite.sql.SqlDialect; +/** + * SubstraitToSql assists with converting Substrait to SQL + * + *

Conversion behaviours can be customized using a {@link ConverterProvider} + */ public class SubstraitToSql extends SqlConverterBase { protected SubstraitToCalcite substraitToCalcite; public SubstraitToSql() { - super(FEATURES_DEFAULT); + this(new ConverterProvider()); } + /** Deprecated, use {@link #SubstraitToSql(ConverterProvider)} instead */ + @Deprecated public SubstraitToSql(SimpleExtension.ExtensionCollection extensions) { - super(FEATURES_DEFAULT, extensions); + this(new ConverterProvider(extensions)); + } - substraitToCalcite = new SubstraitToCalcite(extensions, factory); + public SubstraitToSql(ConverterProvider converterProvider) { + super(converterProvider); + substraitToCalcite = converterProvider.getSubstraitToCalcite(); } public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) { - return SubstraitRelNodeConverter.convert( - relRoot, relOptCluster, catalog, parserConfig, extensionCollection); + return SubstraitRelNodeConverter.convert(relRoot, catalog, converterProvider); } /** diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java index 8cf4958d8..65d24a0ed 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java @@ -15,7 +15,7 @@ public class SqlMapValueConstructorCallConverter implements CallConverter { - SqlMapValueConstructorCallConverter() {} + public SqlMapValueConstructorCallConverter() {} @Override public Optional convert( diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 9105316ac..3297e3c3d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -6,6 +6,7 @@ import com.google.protobuf.Any; import io.substrait.expression.Expression.UserDefinedLiteral; import io.substrait.expression.ExpressionCreator; +import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.FunctionMappings; @@ -14,7 +15,9 @@ import io.substrait.isthmus.utils.UserTypeFactory; import io.substrait.proto.Expression; import io.substrait.proto.Expression.Literal.Builder; +import io.substrait.relation.ProtoRelConverter; import io.substrait.relation.Rel; +import io.substrait.relation.RelProtoConverter; import io.substrait.type.Type; import io.substrait.type.TypeCreator; import java.io.IOException; @@ -22,7 +25,6 @@ import java.util.List; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunction; @@ -31,7 +33,6 @@ import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFactoryImpl; import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.tools.RelBuilder; import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; @@ -51,8 +52,8 @@ class CustomFunctionTest extends PlanTestBase { } // Load custom extension into an ExtensionCollection - static final SimpleExtension.ExtensionCollection extensionCollection = - SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM); + static final SimpleExtension.ExtensionCollection CUSTOM_EXTENSIONS = + SimpleExtension.load(URN, FUNCTIONS_CUSTOM); // Create user-defined types static final String aTypeName = "a_type"; @@ -95,22 +96,7 @@ public RelDataType toCalcite(Type.UserDefined type) { static final RelDataType varcharArrayType = new SqlTypeFactoryImpl(RelDataTypeSystem.DEFAULT).createArrayType(varcharType, -1); - // Define additional mapping signatures for the custom scalar functions - final List additionalScalarSignatures = - List.of( - FunctionMappings.s(customScalarFn), - FunctionMappings.s(customScalarAnyFn), - FunctionMappings.s(customScalarAnyToAnyFn), - FunctionMappings.s(customScalarAny1Any1ToAny1Fn), - FunctionMappings.s(customScalarAny1Any2ToAny2Fn), - FunctionMappings.s(customScalarListAnyFn), - FunctionMappings.s(customScalarListAnyAndAnyFn), - FunctionMappings.s(customScalarListStringFn), - FunctionMappings.s(customScalarListStringAndAnyFn), - FunctionMappings.s(customScalarListStringAndAnyVariadic0Fn), - FunctionMappings.s(customScalarListStringAndAnyVariadic1Fn), - FunctionMappings.s(toBType)); - + // Define additional signatures for the custom scalar functions static final SqlFunction customScalarFn = new SqlFunction( "custom_scalar", @@ -189,6 +175,7 @@ public RelDataType toCalcite(Type.UserDefined type) { null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION); + static final SqlFunction customScalarListStringAndAnyVariadic0Fn = new SqlFunction( "custom_scalar_liststring_anyvariadic0_to_liststring", @@ -197,6 +184,7 @@ public RelDataType toCalcite(Type.UserDefined type) { null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION); + static final SqlFunction customScalarListStringAndAnyVariadic1Fn = new SqlFunction( "custom_scalar_liststring_anyvariadic1_to_liststring", @@ -215,9 +203,22 @@ public RelDataType toCalcite(Type.UserDefined type) { null, SqlFunctionCategory.USER_DEFINED_FUNCTION); - // Define additional mapping signatures for the custom aggregate functions - final List additionalAggregateSignatures = - List.of(FunctionMappings.s(customAggregateFn)); + static final List additionalScalarSignatures = + List.of( + FunctionMappings.s(customScalarFn), + FunctionMappings.s(customScalarAnyFn), + FunctionMappings.s(customScalarAnyToAnyFn), + FunctionMappings.s(customScalarAny1Any1ToAny1Fn), + FunctionMappings.s(customScalarAny1Any2ToAny2Fn), + FunctionMappings.s(customScalarListAnyFn), + FunctionMappings.s(customScalarListAnyAndAnyFn), + FunctionMappings.s(customScalarListStringFn), + FunctionMappings.s(customScalarListStringAndAnyFn), + FunctionMappings.s(customScalarListStringAndAnyVariadic0Fn), + FunctionMappings.s(customScalarListStringAndAnyVariadic1Fn), + FunctionMappings.s(toBType)); + + // Define additional signatures for the custom aggregate functions static final SqlAggFunction customAggregateFn = new SqlAggFunction( @@ -228,61 +229,41 @@ public RelDataType toCalcite(Type.UserDefined type) { null, SqlFunctionCategory.USER_DEFINED_FUNCTION) {}; - TypeConverter typeConverter = new TypeConverter(userTypeMapper); + static final List additionalAggregateSignatures = + List.of(FunctionMappings.s(customAggregateFn)); + + static TypeConverter typeConverter = new TypeConverter(userTypeMapper); // Create Function Converters that can handle the custom functions - ScalarFunctionConverter scalarFunctionConverter = + static ScalarFunctionConverter scalarFunctionConverter = new ScalarFunctionConverter( - extensionCollection.scalarFunctions(), + CUSTOM_EXTENSIONS.scalarFunctions(), additionalScalarSignatures, - typeFactory, + SubstraitTypeSystem.TYPE_FACTORY, typeConverter); - AggregateFunctionConverter aggregateFunctionConverter = + static AggregateFunctionConverter aggregateFunctionConverter = new AggregateFunctionConverter( - extensionCollection.aggregateFunctions(), + CUSTOM_EXTENSIONS.aggregateFunctions(), additionalAggregateSignatures, - typeFactory, + SubstraitTypeSystem.TYPE_FACTORY, typeConverter); - WindowFunctionConverter windowFunctionConverter = - new WindowFunctionConverter(extensionCollection.windowFunctions(), typeFactory); - - final SubstraitToCalcite substraitToCalcite = - new CustomSubstraitToCalcite(extensionCollection, typeFactory, typeConverter); + static WindowFunctionConverter windowFunctionConverter = + new WindowFunctionConverter( + CUSTOM_EXTENSIONS.windowFunctions(), SubstraitTypeSystem.TYPE_FACTORY); // Create a SubstraitRelVisitor that uses the custom Function Converters - final SubstraitRelVisitor calciteToSubstrait = - new SubstraitRelVisitor( - typeFactory, - scalarFunctionConverter, - aggregateFunctionConverter, - windowFunctionConverter, - typeConverter, - ImmutableFeatureBoard.builder().build()); + final SubstraitRelVisitor calciteToSubstrait = new SubstraitRelVisitor(converterProvider); + final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(converterProvider); CustomFunctionTest() { - super(extensionCollection); - } - - // Create a SubstraitToCalcite converter that has access to the custom Function Converters - class CustomSubstraitToCalcite extends SubstraitToCalcite { - - public CustomSubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { - super(extensions, typeFactory, typeConverter); - } - - @Override - protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter( - typeFactory, - relBuilder, - scalarFunctionConverter, - aggregateFunctionConverter, - windowFunctionConverter, - typeConverter); - } + super( + new ConverterProvider( + SubstraitTypeSystem.TYPE_FACTORY, + CUSTOM_EXTENSIONS, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter)); } @Test @@ -601,6 +582,11 @@ void customTypesLiteralInFunctionsRoundtrip() { RelNode calciteRel = substraitToCalcite.convert(rel1); Rel rel2 = calciteToSubstrait.apply(calciteRel); assertEquals(rel1, rel2); + + ExtensionCollector extensionCollector = new ExtensionCollector(); + io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); + Rel rel3 = new ProtoRelConverter(extensionCollector, CUSTOM_EXTENSIONS).from(protoRel); + assertEquals(rel1, rel3); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index ca1e60c72..9d3c8135a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -7,7 +7,6 @@ import com.google.common.annotations.Beta; import com.google.common.io.Resources; import io.substrait.dsl.SubstraitBuilder; -import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.ExtensionCollector; import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; @@ -52,6 +51,8 @@ public class PlanTestBase { protected final SubstraitBuilder sb; protected final SubstraitToCalcite substraitToCalcite; + protected ConverterProvider converterProvider; + protected static final CalciteCatalogReader TPCH_CATALOG; static { @@ -69,13 +70,14 @@ public class PlanTestBase { PlanTestBase.schemaToCatalog("tpcds", TPCDS_SCHEMA); protected PlanTestBase() { - this(DefaultExtensionCatalog.DEFAULT_COLLECTION); + this(new ConverterProvider()); } - protected PlanTestBase(SimpleExtension.ExtensionCollection extensions) { - this.extensions = extensions; + protected PlanTestBase(ConverterProvider converterProvider) { + this.converterProvider = converterProvider; + this.extensions = converterProvider.getExtensions(); this.sb = new SubstraitBuilder(extensions); - this.substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + this.substraitToCalcite = new SubstraitToCalcite(converterProvider); } public static String asString(String resource) throws IOException { @@ -139,7 +141,8 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( // Return list of sql -> Substrait rel -> Calcite rel. SqlToSubstrait s2s = new SqlToSubstrait(); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + SubstraitToCalcite substraitToCalcite = + new SubstraitToCalcite(converterProvider, catalogReader); // 1. SQL -> Substrait Plan Plan plan1 = s2s.convert(query, catalogReader); @@ -151,7 +154,7 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( RelRoot relRoot2 = substraitToCalcite.convert(pojo1); // 4. Calcite RelNode -> Substrait Rel - Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, extensions); + Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, converterProvider); assertEquals(pojo1, pojo2); return relRoot2; @@ -175,22 +178,15 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( * * @param query the SQL query to test * @param catalogReader the Calcite catalog with table definitions - * @param featureBoard optional FeatureBoard to control conversion behavior (e.g., dynamic UDFs). - * If null, a default FeatureBoard is used. */ protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( - String query, Prepare.CatalogReader catalogReader, FeatureBoard featureBoard) - throws Exception { - // Use provided FeatureBoard, or create default if null - FeatureBoard features = - featureBoard != null ? featureBoard : ImmutableFeatureBoard.builder().build(); - + String query, Prepare.CatalogReader catalogReader) throws Exception { SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(extensions, typeFactory, TypeConverter.DEFAULT, null, features); - SqlToSubstrait s = new SqlToSubstrait(extensions, features); + new SubstraitToCalcite(converterProvider, catalogReader); + SqlToSubstrait sqlToSubstrait = new SqlToSubstrait(converterProvider); // 1. SQL -> Substrait Plan - Plan plan1 = s.convert(query, catalogReader); + Plan plan1 = sqlToSubstrait.convert(query, catalogReader); // 2. Substrait Plan -> Substrait Root (POJO 1) Plan.Root pojo1 = plan1.getRoots().get(0); @@ -199,7 +195,7 @@ protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( RelRoot relRoot2 = substraitToCalcite.convert(pojo1); // 4. Calcite RelNode -> Substrait Root (POJO 2) - Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, extensions, features); + Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, converterProvider); // Note: pojo1 and pojo2 may differ due to different optimization strategies applied by: // - SqlNode->RelRoot conversion during SQL->Substrait conversion @@ -210,23 +206,13 @@ protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( RelRoot relRoot3 = substraitToCalcite.convert(pojo2); // 6. Calcite RelNode -> Substrait Root (POJO 3) - Plan.Root pojo3 = SubstraitRelVisitor.convert(relRoot3, extensions, features); + Plan.Root pojo3 = SubstraitRelVisitor.convert(relRoot3, converterProvider); // Verify that subsequent round trips are stable (pojo2 and pojo3 should be identical) assertEquals(pojo2, pojo3); return relRoot2; } - /** - * Convenience overload of {@link #assertSqlSubstraitRelRoundTripLoosePojoComparison(String, - * Prepare.CatalogReader, FeatureBoard)} with default FeatureBoard behavior (no dynamic UDFs). - */ - protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( - String query, Prepare.CatalogReader catalogReader) throws Exception { - return assertSqlSubstraitRelRoundTripLoosePojoComparison( - query, catalogReader, ImmutableFeatureBoard.builder().build()); - } - @Beta protected void assertFullRoundTrip(String query) throws SqlParseException { assertFullRoundTrip(query, TPCH_CATALOG); @@ -274,7 +260,7 @@ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalo // Substrait Root 2 -> Calcite 2 final SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(extensions, typeFactory, catalogReader); + new SubstraitToCalcite(converterProvider, catalogReader); RelRoot calcite2 = substraitToCalcite.convert(root2); // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to @@ -332,7 +318,7 @@ protected void assertFullRoundTripWithIdentityProjectionWorkaround( assertEquals(root0, root1); final SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(extensions, typeFactory, catalogReader); + new SubstraitToCalcite(converterProvider, catalogReader); // Substrait POJO 1 -> Calcite 1 RelRoot calcite1 = substraitToCalcite.convert(root1); @@ -380,7 +366,7 @@ protected void assertFullRoundTrip(Rel pojo1) { assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelNode calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); + RelNode calcite = new SubstraitToCalcite(converterProvider).convert(pojo2); // Calcite -> Substrait POJO 3 io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite, extensions); @@ -411,7 +397,7 @@ protected void assertFullRoundTrip(Plan.Root pojo1) { assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelRoot calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); + RelRoot calcite = new SubstraitToCalcite(converterProvider).convert(pojo2); // Calcite -> Substrait POJO 3 io.substrait.plan.Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite, extensions); @@ -446,7 +432,7 @@ protected String toSql(Plan plan) { assertEquals(1, roots.size(), "number of roots"); Root root = roots.get(0); - RelRoot relRoot = new SubstraitToCalcite(extensions, typeFactory).convert(root); + RelRoot relRoot = new SubstraitToCalcite(converterProvider).convert(root); RelNode project = relRoot.project(true); return SubstraitSqlDialect.toSql(project).getSql(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java index da8423c03..f71a81c1a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java @@ -73,7 +73,6 @@ public Optional visit(Cross cross, EmptyVisitationContext context) return super.visit(cross, context); } }; - ImmutableFeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); String query1 = "select\n" @@ -82,7 +81,7 @@ public Optional visit(Cross cross, EmptyVisitationContext context) + "from\n" + " \"customer\" c cross join\n" + " \"orders\" o"; - Plan plan1 = assertProtoPlanRoundrip(query1, new SqlToSubstrait(featureBoard)); + Plan plan1 = assertProtoPlanRoundrip(query1, new SqlToSubstrait()); plan1 .getRoots() .forEach( @@ -96,7 +95,7 @@ public Optional visit(Cross cross, EmptyVisitationContext context) + "from\n" + " \"customer\" c,\n" + " \"orders\" o"; - Plan plan2 = assertProtoPlanRoundrip(query2, new SqlToSubstrait(featureBoard)); + Plan plan2 = assertProtoPlanRoundrip(query2, new SqlToSubstrait()); plan2 .getRoots() .forEach( diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java index 0cd785379..5ec26b92b 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java @@ -81,7 +81,9 @@ void roundtrip(Rel pojo1) { Context.newContext()); // Calcite -> Substrait POJO 3 - Rel pojo3 = (new CustomSubstraitRelVisitor(typeFactory, extensions)).apply(calcite); + Rel pojo3 = + (new CustomSubstraitRelVisitor(new ConverterProvider(extensions, typeFactory))) + .apply(calcite); assertEquals(pojo1, pojo3); } @@ -246,9 +248,8 @@ public RelNode visit(ExtensionMulti extensionMulti, Context context) throws Runt /** Extends the standard {@link SubstraitRelVisitor} to handle the {@link ColumnAppenderRel} */ static class CustomSubstraitRelVisitor extends SubstraitRelVisitor { - public CustomSubstraitRelVisitor( - RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { - super(typeFactory, extensions); + public CustomSubstraitRelVisitor(ConverterProvider converterProvider) { + super(converterProvider); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java index 69b8be3b9..cbdbe2fa3 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java @@ -12,7 +12,7 @@ class UdfSqlSubstraitTest extends PlanTestBase { private static final String CUSTOM_FUNCTION_PATH = "/extensions/scalar_functions_custom.yaml"; UdfSqlSubstraitTest() { - super(loadExtensions(List.of(CUSTOM_FUNCTION_PATH))); + super(new DynamicConverterProvider(loadExtensions(List.of(CUSTOM_FUNCTION_PATH)))); } @Test @@ -22,16 +22,14 @@ void customUdfTest() throws Exception { SubstraitCreateStatementParser.processCreateStatementsToCatalog( "CREATE TABLE t(x VARCHAR NOT NULL)"); - FeatureBoard featureBoard = ImmutableFeatureBoard.builder().allowDynamicUdfs(true).build(); - assertSqlSubstraitRelRoundTripLoosePojoComparison( - "SELECT regexp_extract_custom(x, 'ab') from t", catalogReader, featureBoard); + "SELECT regexp_extract_custom(x, 'ab') from t", catalogReader); assertSqlSubstraitRelRoundTripLoosePojoComparison( - "SELECT format_text('UPPER', x) FROM t", catalogReader, featureBoard); + "SELECT format_text('UPPER', x) FROM t", catalogReader); assertSqlSubstraitRelRoundTripLoosePojoComparison( - "SELECT system_property_get(x) FROM t", catalogReader, featureBoard); + "SELECT system_property_get(x) FROM t", catalogReader); assertSqlSubstraitRelRoundTripLoosePojoComparison( - "SELECT safe_divide_custom(10,0) FROM t", catalogReader, featureBoard); + "SELECT safe_divide_custom(10,0) FROM t", catalogReader); } private static SimpleExtension.ExtensionCollection loadExtensions(