diff --git a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java index c8726c387..58a7e123f 100644 --- a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java @@ -10,6 +10,7 @@ import io.substrait.isthmus.SqlExpressionToSubstrait; import io.substrait.isthmus.SqlToSubstrait; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; +import io.substrait.plan.PlanProtoConverter; import io.substrait.proto.ExtendedExpression; import io.substrait.proto.Plan; import java.io.IOException; @@ -94,7 +95,7 @@ public Integer call() throws Exception { Prepare.CatalogReader catalog = SubstraitCreateStatementParser.processCreateStatementsToCatalog( createStatements.toArray(String[]::new)); - Plan plan = converter.execute(sql, catalog); + Plan plan = new PlanProtoConverter().toProto(converter.convert(sql, catalog)); printMessage(plan); } return 0; diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index 8a5e19bf8..e901397d1 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -3,9 +3,9 @@ import com.google.common.annotations.VisibleForTesting; import io.substrait.isthmus.sql.SubstraitSqlValidator; import io.substrait.plan.ImmutablePlan.Builder; +import io.substrait.plan.Plan; import io.substrait.plan.Plan.Version; import io.substrait.plan.PlanProtoConverter; -import io.substrait.proto.Plan; import java.util.List; import org.apache.calcite.plan.hep.HepPlanner; import org.apache.calcite.plan.hep.HepProgram; @@ -31,35 +31,51 @@ public SqlToSubstrait(FeatureBoard features) { super(features); } - public Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException { - SqlValidator validator = new SubstraitSqlValidator(catalogReader); - return executeInner(sql, validator, catalogReader); - } - - List sqlToRelNode(String sql, Prepare.CatalogReader catalogReader) + /** + * Converts a SQL statements string into a Substrait proto {@link io.substrait.proto.Plan}. + * + * @param sql the SQL statements string containing one more SQL statements + * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in + * the SQL statements string + * @return the Substrait proto {@link io.substrait.proto.Plan} + * @throws SqlParseException if there is an error while parsing the SQL statements string + * @deprecated use {@link #convert(String, org.apache.calcite.prepare.Prepare.CatalogReader)} + * instead to get a {@link Plan} and convert that to a {@link io.substrait.proto.Plan} using + * {@link PlanProtoConverter#toProto(Plan)} + */ + @Deprecated + public io.substrait.proto.Plan execute(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException { - SqlValidator validator = new SubstraitSqlValidator(catalogReader); - return sqlToRelNode(sql, validator, catalogReader); + PlanProtoConverter planToProto = new PlanProtoConverter(); + + return planToProto.toProto(convert(sql, catalogReader)); } - private Plan executeInner(String sql, SqlValidator validator, Prepare.CatalogReader catalogReader) - throws SqlParseException { + /** + * Converts a SQL statements string into a Substrait {@link Plan}. + * + * @param sql the SQL statements string containing one more SQL statements + * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in + * the SQL statements string + * @return the Substrait {@link Plan} + * @throws SqlParseException if there is an error while parsing the SQL statements string + */ + public Plan convert(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException { Builder builder = io.substrait.plan.Plan.builder(); builder.version(Version.builder().from(Version.DEFAULT_VERSION).producer("isthmus").build()); // TODO: consider case in which one sql passes conversion while others don't - sqlToRelNode(sql, validator, catalogReader).stream() + sqlToRelNode(sql, catalogReader).stream() .map(root -> SubstraitRelVisitor.convert(root, EXTENSION_COLLECTION, featureBoard)) .forEach(root -> builder.addRoots(root)); - PlanProtoConverter planToProto = new PlanProtoConverter(); - - return planToProto.toProto(builder.build()); + return builder.build(); } - private List sqlToRelNode( - String sql, SqlValidator validator, Prepare.CatalogReader catalogReader) + @VisibleForTesting + List sqlToRelNode(String sql, Prepare.CatalogReader catalogReader) throws SqlParseException { + SqlValidator validator = new SubstraitSqlValidator(catalogReader); SqlParser parser = SqlParser.create(sql, parserConfig); SqlNodeList parsedList = parser.parseStmtList(); SqlToRelConverter converter = createSqlToRelConverter(validator, catalogReader); @@ -70,8 +86,7 @@ private List sqlToRelNode( return roots; } - @VisibleForTesting - SqlToRelConverter createSqlToRelConverter( + protected SqlToRelConverter createSqlToRelConverter( SqlValidator validator, Prepare.CatalogReader catalogReader) { SqlToRelConverter converter = new SqlToRelConverter( @@ -84,8 +99,7 @@ SqlToRelConverter createSqlToRelConverter( return converter; } - @VisibleForTesting - static RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) { + protected RelRoot getBestExpRelRoot(SqlToRelConverter converter, SqlNode parsed) { RelRoot root = converter.convertQuery(parsed, true, true); { // RelBuilder seems to implicitly use the rule below, diff --git a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java index 8fb704977..7804c2b46 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ApplyJoinPlanTest.java @@ -62,7 +62,7 @@ public void lateralJoinQuery() throws SqlParseException { SqlToSubstrait sE2E = new SqlToSubstrait(); Assertions.assertThrows( UnsupportedOperationException.class, - () -> sE2E.execute(sql, TPCDS_CATALOG), + () -> sE2E.convert(sql, TPCDS_CATALOG), "Lateral join is not supported"); } @@ -83,7 +83,7 @@ public void outerApplyQuery() throws SqlParseException { // TODO validate end to end conversion Assertions.assertThrows( UnsupportedOperationException.class, - () -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG), + () -> new SqlToSubstrait().convert(sql, TPCDS_CATALOG), "APPLY is not supported"); } @@ -123,7 +123,7 @@ public void nestedApplyJoinQuery() throws SqlParseException { // TODO validate end to end conversion Assertions.assertThrows( UnsupportedOperationException.class, - () -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG), + () -> new SqlToSubstrait().convert(sql, TPCDS_CATALOG), "APPLY is not supported"); } @@ -138,7 +138,7 @@ public void crossApplyQuery() throws SqlParseException { // TODO validate end to end conversion Assertions.assertThrows( UnsupportedOperationException.class, - () -> new SqlToSubstrait().execute(sql, TPCDS_CATALOG), + () -> new SqlToSubstrait().convert(sql, TPCDS_CATALOG), "APPLY is not supported"); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/KeyConstraintsTest.java b/isthmus/src/test/java/io/substrait/isthmus/KeyConstraintsTest.java index 7c43502d9..ff00123bb 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/KeyConstraintsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/KeyConstraintsTest.java @@ -14,6 +14,6 @@ public void tpcds(int query) throws Exception { String values = asString("keyconstraints_schema.sql"); Prepare.CatalogReader catalog = SubstraitCreateStatementParser.processCreateStatementsToCatalog(values); - s.execute(asString(String.format("tpcds/queries/%02d.sql", query)), catalog); + s.convert(asString(String.format("tpcds/queries/%02d.sql", query)), catalog); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index ee89f608c..6eb89cbdf 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -25,14 +25,10 @@ void preserveNamesFromSql() throws Exception { String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b"; List expectedNames = List.of("a", "B"); - List calciteRelRoots = s.sqlToRelNode(query, catalogReader); - assertEquals(1, calciteRelRoots.size()); + Plan plan = s.convert(query, catalogReader); + assertEquals(1, plan.getRoots().size()); - org.apache.calcite.rel.RelRoot calciteRelRoot1 = calciteRelRoots.get(0); - assertEquals(expectedNames, calciteRelRoot1.validatedRowType.getFieldNames()); - - io.substrait.plan.Plan.Root substraitRelRoot = - SubstraitRelVisitor.convert(calciteRelRoot1, EXTENSION_COLLECTION); + io.substrait.plan.Plan.Root substraitRelRoot = plan.getRoots().get(0); assertEquals(expectedNames, substraitRelRoot.getNames()); org.apache.calcite.rel.RelRoot calciteRelRoot2 = substraitToCalcite.convert(substraitRelRoot); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java index e645f29cb..05aa5944f 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedStructQueryTest.java @@ -60,7 +60,7 @@ private void test(Table table, String query, String expectedExpressionText) final Schema schema = new SubstraitSchema(Map.of("my_table", table)); final CalciteCatalogReader catalog = schemaToCatalog("nested", schema); final SqlToSubstrait sqlToSubstrait = new SqlToSubstrait(); - Plan plan = sqlToSubstrait.execute(query, catalog); + Plan plan = toProto(sqlToSubstrait.convert(query, catalog)); Expression obtainedExpression = plan.getRelations(0).getRoot().getInput().getProject().getExpressions(0); Expression expectedExpression = TextFormat.parse(expectedExpressionText, Expression.class); diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index 27ae7eea9..8bac5d854 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -87,24 +87,26 @@ protected Plan assertProtoPlanRoundrip(String query, SqlToSubstrait s, String cr protected Plan assertProtoPlanRoundrip( String query, SqlToSubstrait s, Prepare.CatalogReader catalogReader) throws SqlParseException { - io.substrait.proto.Plan protoPlan1 = s.execute(query, catalogReader); - Plan plan = new ProtoPlanConverter(extensions).from(protoPlan1); - io.substrait.proto.Plan protoPlan2 = new PlanProtoConverter().toProto(plan); + Plan plan1 = s.convert(query, catalogReader); + io.substrait.proto.Plan protoPlan1 = toProto(plan1); + + Plan plan2 = new ProtoPlanConverter(extensions).from(protoPlan1); + io.substrait.proto.Plan protoPlan2 = toProto(plan2); assertEquals(protoPlan1, protoPlan2); - List rootRels = s.sqlToRelNode(query, catalogReader); - assertEquals(rootRels.size(), plan.getRoots().size()); - for (int i = 0; i < rootRels.size(); i++) { - Plan.Root rootRel = SubstraitRelVisitor.convert(rootRels.get(i), extensions); + + assertEquals(plan1.getRoots().size(), plan2.getRoots().size()); + for (int i = 0; i < plan1.getRoots().size(); i++) { assertEquals( - rootRel.getInput().getRecordType(), plan.getRoots().get(i).getInput().getRecordType()); + plan1.getRoots().get(i).getInput().getRecordType(), + plan2.getRoots().get(i).getInput().getRecordType()); } - return plan; + + return plan2; } protected void assertPlanRoundtrip(Plan plan) { - io.substrait.proto.Plan protoPlan1 = new PlanProtoConverter().toProto(plan); - io.substrait.proto.Plan protoPlan2 = - new PlanProtoConverter().toProto(new ProtoPlanConverter().from(protoPlan1)); + io.substrait.proto.Plan protoPlan1 = toProto(plan); + io.substrait.proto.Plan protoPlan2 = toProto(new ProtoPlanConverter().from(protoPlan1)); assertEquals(protoPlan1, protoPlan2); } @@ -129,13 +131,11 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( SqlToSubstrait s = new SqlToSubstrait(); - // 1. SQL -> Calcite RelRoot - List relRoots = s.sqlToRelNode(query, catalogReader); - assertEquals(1, relRoots.size()); - RelRoot relRoot1 = relRoots.get(0); + // 1. SQL -> Substrait Plan + Plan plan1 = s.convert(query, catalogReader); - // 2. Calcite RelRoot -> Substrait Rel - Plan.Root pojo1 = SubstraitRelVisitor.convert(relRoot1, extensions); + // 2. Substrait Plan -> Substrait Rel + Plan.Root pojo1 = plan1.getRoots().get(0); // 3. Substrait Rel -> Calcite RelNode RelRoot relRoot2 = substraitToCalcite.convert(pojo1); @@ -178,37 +178,36 @@ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalo SqlToSubstrait sqlConverter = new SqlToSubstrait(); ExtensionCollector extensionCollector = new ExtensionCollector(); - // SQL -> Calcite 1 - List relRoots = sqlConverter.sqlToRelNode(sqlQuery, catalogReader); - assertEquals(1, relRoots.size()); - RelRoot calcite1 = relRoots.get(0); + // SQL -> Substrait Plan 1 + Plan plan1 = sqlConverter.convert(sqlQuery, catalogReader); + assertEquals(1, plan1.getRoots().size()); - // Calcite 1 -> Substrait POJO 1 - Plan.Root pojo1 = SubstraitRelVisitor.convert(calcite1, extensions); + // Substrait Plan 1 -> Substrait Root 1 + Plan.Root root1 = plan1.getRoots().get(0); - // Substrait POJO 1 -> Substrait Proto - io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(pojo1); + // Substrait Root 1 -> Substrait Proto + io.substrait.proto.RelRoot proto = new RelProtoConverter(extensionCollector).toProto(root1); - // Substrait Proto -> Substrait Pojo 2 - Plan.Root pojo2 = new ProtoRelConverter(extensionCollector, extensions).from(proto); + // Substrait Proto -> Substrait Root 2 + Plan.Root root2 = new ProtoRelConverter(extensionCollector, extensions).from(proto); - // Verify that POJOs are the same - assertEquals(pojo1, pojo2); + // Verify that roots are the same + assertEquals(root1, root2); - // Substrait POJO 2 -> Calcite 2 + // Substrait Root 2 -> Calcite 2 final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory, catalogReader); - RelRoot calcite2 = substraitToCalcite.convert(pojo2); + RelRoot calcite2 = substraitToCalcite.convert(root2); // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to // do so assertNotNull(calcite2); - // Calcite 2 -> Substrait POJO 3 - Plan.Root pojo3 = SubstraitRelVisitor.convert(calcite2, extensions); + // Calcite 2 -> Substrait Root 3 + Plan.Root root3 = SubstraitRelVisitor.convert(calcite2, extensions); // Verify that POJOs are the same - assertEquals(pojo1, pojo3); + assertEquals(root1, root3); } /** @@ -285,9 +284,9 @@ protected void assertRowMatch(RelDataType actual, List expected) { assertEquals(expected, struct.fields()); } - protected io.substrait.proto.Plan toSubstraitPlan(String sql, CalciteCatalogReader catalog) + protected Plan toSubstraitPlan(String sql, CalciteCatalogReader catalog) throws SqlParseException { - return new SqlToSubstrait().execute(sql, catalog); + return new SqlToSubstrait().convert(sql, catalog); } protected String toSql(io.substrait.proto.Plan protoPlan) { @@ -305,6 +304,10 @@ protected String toSql(Plan plan) { return SubstraitSqlDialect.toSql(project).getSql(); } + protected io.substrait.proto.Plan toProto(Plan plan) { + return new PlanProtoConverter().toProto(plan); + } + protected static CalciteCatalogReader schemaToCatalog(String schemaName, Schema schema) { CalciteSchema rootSchema = CalciteSchema.createRootSchema(false); rootSchema.add(schemaName, schema); diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java index 6d033b809..54af92458 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java @@ -4,7 +4,6 @@ import io.substrait.isthmus.utils.SetUtils; import io.substrait.plan.Plan; -import io.substrait.plan.PlanProtoConverter; import io.substrait.plan.ProtoPlanConverter; import io.substrait.proto.AggregateFunction; import io.substrait.relation.Cross; @@ -23,7 +22,7 @@ public class ProtoPlanConverterTest extends PlanTestBase { private io.substrait.proto.Plan getProtoPlan(String query1) throws SqlParseException { SqlToSubstrait s = new SqlToSubstrait(); - return s.execute(query1, TPCH_CATALOG); + return toProto(s.convert(query1, TPCH_CATALOG)); } @Test @@ -54,8 +53,7 @@ public void distinctCount() throws IOException, SqlParseException { String distinctQuery = "select count(DISTINCT L_ORDERKEY) from lineitem"; io.substrait.proto.Plan protoPlan = getProtoPlan(distinctQuery); assertAggregateInvocationDistinct(protoPlan); - assertAggregateInvocationDistinct( - new PlanProtoConverter().toProto(new ProtoPlanConverter().from(protoPlan))); + assertAggregateInvocationDistinct(toProto(new ProtoPlanConverter().from(protoPlan))); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java index 6e95e3dcd..973f8829d 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/RelCopyOnWriteVisitorTest.java @@ -9,7 +9,6 @@ import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.sql.SubstraitSqlDialect; import io.substrait.plan.Plan; -import io.substrait.plan.ProtoPlanConverter; import io.substrait.relation.Aggregate; import io.substrait.relation.CopyOnWriteUtils; import io.substrait.relation.NamedScan; @@ -77,8 +76,7 @@ public class RelCopyOnWriteVisitorTest extends PlanTestBase { private Plan buildPlanFromQuery(String query) throws IOException, SqlParseException { SqlToSubstrait s = new SqlToSubstrait(); - io.substrait.proto.Plan protoPlan1 = s.execute(query, TPCH_CATALOG); - return new ProtoPlanConverter().from(protoPlan1); + return s.convert(query, TPCH_CATALOG); } @Test diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java index 23a946d92..9fcba2cc0 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubqueryPlanTest.java @@ -18,9 +18,10 @@ public class SubqueryPlanTest extends PlanTestBase { public void existsCorrelatedSubquery() throws SqlParseException { SqlToSubstrait s = new SqlToSubstrait(); Plan plan = - s.execute( - "select l_partkey from lineitem where exists (select o_orderdate from orders where o_orderkey = l_orderkey)", - TPCH_CATALOG); + toProto( + s.convert( + "select l_partkey from lineitem where exists (select o_orderdate from orders where o_orderkey = l_orderkey)", + TPCH_CATALOG)); Expression.Subquery subquery = plan.getRelations(0) @@ -59,9 +60,10 @@ public void existsCorrelatedSubquery() throws SqlParseException { public void uniqueCorrelatedSubquery() throws IOException, SqlParseException { SqlToSubstrait s = new SqlToSubstrait(); Plan plan = - s.execute( - "select l_partkey from lineitem where unique (select o_orderdate from orders where o_orderkey = l_orderkey)", - TPCH_CATALOG); + toProto( + s.convert( + "select l_partkey from lineitem where unique (select o_orderdate from orders where o_orderkey = l_orderkey)", + TPCH_CATALOG)); Expression.Subquery subquery = plan.getRelations(0) @@ -104,7 +106,7 @@ public void inPredicateCorrelatedSubQuery() throws IOException, SqlParseExceptio SqlToSubstrait s = new SqlToSubstrait(); String sql = "select l_orderkey from lineitem where l_partkey in (select p_partkey from part where p_partkey = l_partkey)"; - Plan plan = s.execute(sql, TPCH_CATALOG); + Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); Expression.Subquery subquery = plan.getRelations(0) @@ -142,7 +144,7 @@ public void notInPredicateCorrelatedSubquery() throws IOException, SqlParseExcep SqlToSubstrait s = new SqlToSubstrait(); String sql = "select l_orderkey from lineitem where l_partkey not in (select p_partkey from part where p_partkey = l_partkey)"; - Plan plan = s.execute(sql, TPCH_CATALOG); + Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); Expression.Subquery subquery = plan.getRelations(0) .getRoot() @@ -192,7 +194,7 @@ public void existsNestedCorrelatedSubquery() throws IOException, SqlParseExcepti + " FROM partsupp ps\n" + " WHERE ps.ps_partkey = p.p_partkey\n" + " AND PS.ps_suppkey = l.l_suppkey))"; - Plan plan = s.execute(sql, TPCH_CATALOG); + Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); Expression.Subquery outer_subquery = plan.getRelations(0) @@ -266,7 +268,7 @@ public void existsNestedCorrelatedSubquery() throws IOException, SqlParseExcepti public void nestedScalarCorrelatedSubquery() throws IOException, SqlParseException { SqlToSubstrait s = new SqlToSubstrait(); String sql = asString("subquery/nested_scalar_subquery_in_filter.sql"); - Plan plan = s.execute(sql, TPCH_CATALOG); + Plan plan = toProto(s.convert(sql, TPCH_CATALOG)); String planText = JsonFormat.printer().includingDefaultValueFields().print(plan); System.out.println(planText); @@ -344,7 +346,7 @@ public void correlatedScalarSubQInSelect() throws IOException { Assertions.assertThrows( UnsupportedOperationException.class, () -> { - s.execute(sql, TPCH_CATALOG); + s.convert(sql, TPCH_CATALOG); }); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java index 9be238bc3..828ffab71 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpcdsQueryTest.java @@ -3,7 +3,7 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; import static org.junit.jupiter.api.Assertions.assertThrows; -import io.substrait.proto.Plan; +import io.substrait.plan.Plan; import java.io.IOException; import java.util.Set; import java.util.stream.IntStream; @@ -14,7 +14,8 @@ /** TPC-DS test to convert SQL to Substrait and then convert those plans back to SQL. */ public class TpcdsQueryTest extends PlanTestBase { private static final Set toSubstraitExclusions = Set.of(9, 27, 36, 70, 86); - private static final Set fromSubstraitExclusions = Set.of(1, 30, 67, 81); + private static final Set fromSubstraitPojoExclusions = Set.of(1, 30, 81); + private static final Set fromSubstraitProtoExclusions = Set.of(1, 30, 67, 81); static IntStream testCases() { return IntStream.rangeClosed(1, 99).filter(n -> !toSubstraitExclusions.contains(n)); @@ -29,12 +30,21 @@ static IntStream testCases() { public void testQuery(int query) throws IOException { String inputSql = asString(String.format("tpcds/queries/%02d.sql", query)); - Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait"); + Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait POJO"); - if (!fromSubstraitExclusions.contains(query)) { - assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL"); + if (!fromSubstraitPojoExclusions.contains(query)) { + assertDoesNotThrow(() -> toSql(plan), "Substrait POJO to SQL"); } else { - assertThrows(Throwable.class, () -> toSql(plan), "Substrait to SQL"); + assertThrows(Throwable.class, () -> toSql(plan), "Substrait POJO to SQL"); + } + + io.substrait.proto.Plan proto = + assertDoesNotThrow(() -> toProto(plan), "Substrait POJO to Substrait PROTO"); + + if (!fromSubstraitProtoExclusions.contains(query)) { + assertDoesNotThrow(() -> toSql(proto), "Substrait PROTO to SQL"); + } else { + assertThrows(Throwable.class, () -> toSql(proto), "Substrait PROTO to SQL"); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java index 0a8b0e8af..1938689a1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/TpchQueryTest.java @@ -2,7 +2,7 @@ import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; -import io.substrait.proto.Plan; +import io.substrait.plan.Plan; import java.io.IOException; import java.util.stream.IntStream; import org.apache.calcite.sql.parser.SqlParseException; @@ -24,9 +24,14 @@ static IntStream testCases() { public void testQuery(int query) throws IOException { String inputSql = asString(String.format("tpch/queries/%02d.sql", query)); - Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait"); + Plan plan = assertDoesNotThrow(() -> toSubstraitPlan(inputSql), "SQL to Substrait POJO"); - assertDoesNotThrow(() -> toSql(plan), "Substrait to SQL"); + assertDoesNotThrow(() -> toSql(plan), "Substrait POJO to SQL"); + + io.substrait.proto.Plan proto = + assertDoesNotThrow(() -> toProto(plan), "Substrait POJO to Substrait PROTO"); + + assertDoesNotThrow(() -> toSql(proto), "Substrait PROTO to SQL"); } private Plan toSubstraitPlan(String sql) throws SqlParseException {