diff --git a/examples/isthmus-api/src/main/java/io/substrait/examples/ToSql.java b/examples/isthmus-api/src/main/java/io/substrait/examples/ToSql.java index 1cb8b70ad..12c1e315b 100644 --- a/examples/isthmus-api/src/main/java/io/substrait/examples/ToSql.java +++ b/examples/isthmus-api/src/main/java/io/substrait/examples/ToSql.java @@ -3,6 +3,7 @@ import io.substrait.examples.IsthmusAppExamples.Action; import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.ConverterProvider; import io.substrait.isthmus.SubstraitToCalcite; import io.substrait.isthmus.SubstraitTypeSystem; import io.substrait.plan.Plan; @@ -11,7 +12,6 @@ import java.io.IOException; import java.nio.file.Files; import java.nio.file.Paths; -import org.apache.calcite.jdbc.JavaTypeFactoryImpl; import org.apache.calcite.rel.rel2sql.RelToSqlConverter; import org.apache.calcite.sql.SqlDialect; @@ -52,9 +52,9 @@ public void run(String[] args) { final SimpleExtension.ExtensionCollection extensions = DefaultExtensionCatalog.DEFAULT_COLLECTION; - final SubstraitToCalcite converter = - new SubstraitToCalcite( - extensions, new JavaTypeFactoryImpl(SubstraitTypeSystem.TYPE_SYSTEM)); + final ConverterProvider converterProvider = + new ConverterProvider(SubstraitTypeSystem.TYPE_FACTORY, extensions); + final SubstraitToCalcite converter = new SubstraitToCalcite(converterProvider); // Determine which SQL Dialect we want the converted queries to be in final SqlDialect sqlDialect = SqlDialect.DatabaseProduct.MYSQL.getDialect(); diff --git a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java index 35d5c6ed2..106de4810 100644 --- a/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java +++ b/isthmus-cli/src/main/java/io/substrait/isthmus/cli/IsthmusEntryPoint.java @@ -1,12 +1,9 @@ package io.substrait.isthmus.cli; -import 
com.google.common.annotations.VisibleForTesting; import com.google.protobuf.Message; import com.google.protobuf.TextFormat; import com.google.protobuf.util.JsonFormat; import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.isthmus.FeatureBoard; -import io.substrait.isthmus.ImmutableFeatureBoard; import io.substrait.isthmus.SqlExpressionToSubstrait; import io.substrait.isthmus.SqlToSubstrait; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; @@ -16,7 +13,6 @@ import java.io.IOException; import java.util.List; import java.util.concurrent.Callable; -import org.apache.calcite.avatica.util.Casing; import org.apache.calcite.prepare.Prepare; import picocli.CommandLine; import picocli.CommandLine.Command; @@ -56,11 +52,6 @@ enum OutputFormat { BINARY, // protobuf BINARY format } - @Option( - names = {"--unquotedcasing"}, - description = "Calcite's casing policy for unquoted identifiers: ${COMPLETION-CANDIDATES}") - private Casing unquotedCasing = Casing.TO_UPPER; - public static void main(String... args) { CommandLine commandLine = new CommandLine(new IsthmusEntryPoint()); commandLine.setCaseInsensitiveEnumValuesAllowed(true); @@ -83,15 +74,14 @@ public static void main(String... 
args) { @Override public Integer call() throws Exception { - FeatureBoard featureBoard = buildFeatureBoard(); // Isthmus image is parsing SQL Expression if that argument is defined if (sqlExpressions != null) { SqlExpressionToSubstrait converter = - new SqlExpressionToSubstrait(featureBoard, DefaultExtensionCatalog.DEFAULT_COLLECTION); + new SqlExpressionToSubstrait(DefaultExtensionCatalog.DEFAULT_COLLECTION); ExtendedExpression extendedExpression = converter.convert(sqlExpressions, createStatements); printMessage(extendedExpression); } else { // by default Isthmus image are parsing SQL Query - SqlToSubstrait converter = new SqlToSubstrait(featureBoard); + SqlToSubstrait converter = new SqlToSubstrait(); Prepare.CatalogReader catalog = SubstraitCreateStatementParser.processCreateStatementsToCatalog( createStatements.toArray(String[]::new)); @@ -110,9 +100,4 @@ private void printMessage(Message message) throws IOException { message.writeTo(System.out); } } - - @VisibleForTesting - FeatureBoard buildFeatureBoard() { - return ImmutableFeatureBoard.builder().unquotedCasing(unquotedCasing).build(); - } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/ConverterProvider.java b/isthmus/src/main/java/io/substrait/isthmus/ConverterProvider.java new file mode 100644 index 000000000..3027addad --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/ConverterProvider.java @@ -0,0 +1,228 @@ +package io.substrait.isthmus; + +import io.substrait.extension.DefaultExtensionCatalog; +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.SubstraitOperatorTable; +import io.substrait.isthmus.expression.AggregateFunctionConverter; +import io.substrait.isthmus.expression.CallConverters; +import io.substrait.isthmus.expression.ExpressionRexConverter; +import io.substrait.isthmus.expression.FieldSelectionConverter; +import io.substrait.isthmus.expression.RexExpressionConverter; +import io.substrait.isthmus.expression.ScalarFunctionConverter; 
+import io.substrait.isthmus.expression.SqlArrayValueConstructorCallConverter; +import io.substrait.isthmus.expression.SqlMapValueConstructorCallConverter; +import io.substrait.isthmus.expression.WindowFunctionConverter; +import io.substrait.relation.Rel; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Function; +import org.apache.calcite.avatica.util.Casing; +import org.apache.calcite.config.CalciteConnectionConfig; +import org.apache.calcite.config.CalciteConnectionProperty; +import org.apache.calcite.jdbc.CalciteSchema; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.rex.RexBuilder; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.parser.SqlParser; +import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl; +import org.apache.calcite.sql.validate.SqlConformanceEnum; +import org.apache.calcite.sql2rel.SqlToRelConverter; +import org.apache.calcite.tools.Frameworks; +import org.apache.calcite.tools.RelBuilder; + +/** + * ConverterProvider provides a single point of configuration for a number of conversions: {@code + * SQL <-> Calcite <-> Substrait} + * + * <p>It is consumed by all conversion classes as their primary source of configuration. + * + * <p>The no argument constructor {@link #ConverterProvider()} provides reasonable system defaults. + * + * <p>Other constructors allow for further customization of conversion behaviours. + * + * <p>
More in-depth customization can be achieved by extending this class, as is done in {@link + * DynamicConverterProvider}. + */ +public class ConverterProvider { + + protected RelDataTypeFactory typeFactory; + + protected ScalarFunctionConverter scalarFunctionConverter; + protected AggregateFunctionConverter aggregateFunctionConverter; + protected WindowFunctionConverter windowFunctionConverter; + + protected TypeConverter typeConverter; + + public ConverterProvider() { + this(SubstraitTypeSystem.TYPE_FACTORY, DefaultExtensionCatalog.DEFAULT_COLLECTION); + } + + public ConverterProvider(SimpleExtension.ExtensionCollection extensions) { + this(SubstraitTypeSystem.TYPE_FACTORY, extensions); + } + + public ConverterProvider( + RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { + this( + typeFactory, + new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory), + new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory), + new WindowFunctionConverter(extensions.windowFunctions(), typeFactory), + TypeConverter.DEFAULT); + } + + public ConverterProvider( + RelDataTypeFactory typeFactory, + ScalarFunctionConverter sfc, + AggregateFunctionConverter afc, + WindowFunctionConverter wfc, + TypeConverter tc) { + this.typeFactory = typeFactory; + this.scalarFunctionConverter = sfc; + this.aggregateFunctionConverter = afc; + this.windowFunctionConverter = wfc; + this.typeConverter = tc; + } + + // SQL to Calcite Processing + + /** + * A {@link SqlParser.Config} is a Calcite class which controls SQL parsing behaviour like + * identifier casing. + */ + public SqlParser.Config getSqlParserConfig() { + return SqlParser.Config.DEFAULT + .withUnquotedCasing(Casing.TO_UPPER) + .withParserFactory(SqlDdlParserImpl.FACTORY) + .withConformance(SqlConformanceEnum.LENIENT); + } + + /** + * A {@link CalciteConnectionConfig} is a Calcite class which controls SQL processing behaviour + * like table name case-sensitivity. 
+ */ + public CalciteConnectionConfig getCalciteConnectionConfig() { + return CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false"); + } + + /** + * A {@link SqlToRelConverter.Config} is a Calcite class which controls SQL processing behaviour + * like field-trimming. + */ + public SqlToRelConverter.Config getSqlToRelConverterConfig() { + return SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(false); + } + + /** + * A {@link SqlOperatorTable} is a Calcite class which stores the {@link + * org.apache.calcite.sql.SqlOperator}s available and controls valid identifiers during SQL + * processing. + */ + public SqlOperatorTable getSqlOperatorTable() { + return SubstraitOperatorTable.INSTANCE; + } + + // Calcite to Substrait Processing + + /** + * A {@link SubstraitRelVisitor} converts Calcite {@link org.apache.calcite.rel.RelNode}s to + * Substrait {@link Rel}s + */ + public SubstraitRelVisitor getSubstraitRelVisitor() { + return new SubstraitRelVisitor(this); + } + + /** + * A {@link RexExpressionConverter} converts Calcite {@link org.apache.calcite.rex.RexNode}s to + * Substrait equivalents. + */ + public RexExpressionConverter getRexExpressionConverter(SubstraitRelVisitor srv) { + return new RexExpressionConverter( + srv, getCallConverters(), getWindowFunctionConverter(), getTypeConverter()); + } + + /** + * {@link CallConverter}s are used to convert Calcite {@link org.apache.calcite.rex.RexCall}s to + * Substrait equivalents. 
+ */ + public List<CallConverter> getCallConverters() { + List<CallConverter> callConverters = new ArrayList<>(); + callConverters.add(new FieldSelectionConverter(typeConverter)); + callConverters.add(CallConverters.CASE); + callConverters.add(CallConverters.CAST.apply(typeConverter)); + callConverters.add(CallConverters.REINTERPRET.apply(typeConverter)); + callConverters.add(new SqlArrayValueConstructorCallConverter(typeConverter)); + callConverters.add(new SqlMapValueConstructorCallConverter()); + callConverters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); + callConverters.add(scalarFunctionConverter); + return callConverters; + } + + // Substrait To Calcite Processing + + /** + * When converting from Substrait to Calcite, Calcite needs to have a schema available. The + * default strategy uses a {@link SchemaCollector} to generate a {@link CalciteSchema} on the fly + * based on the leaf nodes of a Substrait plan. + * + * <p>Override to customize the schema generation behaviour. + */ + public Function<Rel, CalciteSchema> getSchemaResolver() { + SchemaCollector schemaCollector = new SchemaCollector(this); + return schemaCollector::toSchema; + } + + /** + * A {@link SubstraitRelNodeConverter} is used when converting from Substrait {@link Rel}s to + * Calcite {@link org.apache.calcite.rel.RelNode}s. + */ + public SubstraitRelNodeConverter getSubstraitRelNodeConverter(RelBuilder relBuilder) { + return new SubstraitRelNodeConverter(relBuilder, this); + } + + /** + * A {@link ExpressionRexConverter} converts Substrait {@link io.substrait.expression.Expression} + * to Calcite equivalents + */ + public ExpressionRexConverter getExpressionRexConverter( + SubstraitRelNodeConverter relNodeConverter) { + ExpressionRexConverter erc = + new ExpressionRexConverter( + getTypeFactory(), + getScalarFunctionConverter(), + getWindowFunctionConverter(), + getTypeConverter()); + erc.setRelNodeConverter(relNodeConverter); + return erc; + } + + /** + * A {@link RelBuilder} is a Calcite class used as a factory for creating {@link + * org.apache.calcite.rel.RelNode}s. 
+ */ + public RelBuilder getRelBuilder(CalciteSchema schema) { + return RelBuilder.create(Frameworks.newConfigBuilder().defaultSchema(schema.plus()).build()); + } + + // Utility Getters + + public RelDataTypeFactory getTypeFactory() { + return typeFactory; + } + + public ScalarFunctionConverter getScalarFunctionConverter() { + return scalarFunctionConverter; + } + + public AggregateFunctionConverter getAggregateFunctionConverter() { + return aggregateFunctionConverter; + } + + public WindowFunctionConverter getWindowFunctionConverter() { + return windowFunctionConverter; + } + + public TypeConverter getTypeConverter() { + return typeConverter; + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/DynamicConverterProvider.java b/isthmus/src/main/java/io/substrait/isthmus/DynamicConverterProvider.java new file mode 100644 index 000000000..297220deb --- /dev/null +++ b/isthmus/src/main/java/io/substrait/isthmus/DynamicConverterProvider.java @@ -0,0 +1,106 @@ +package io.substrait.isthmus; + +import io.substrait.extension.SimpleExtension; +import io.substrait.isthmus.calcite.SubstraitOperatorTable; +import io.substrait.isthmus.expression.FunctionMappings; +import io.substrait.isthmus.expression.ScalarFunctionConverter; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.calcite.rel.type.RelDataTypeFactory; +import org.apache.calcite.sql.SqlOperator; +import org.apache.calcite.sql.SqlOperatorTable; +import org.apache.calcite.sql.util.SqlOperatorTables; + +/** + * A {@link ConverterProvider} that also makes the dynamic functions of the supplied extension + * collection available as generated Calcite {@link SqlOperator}s. + */ +public class DynamicConverterProvider extends ConverterProvider { + + private final SimpleExtension.ExtensionCollection extensions; + + public DynamicConverterProvider( + RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { + super(typeFactory, extensions); + this.extensions = extensions; + this.scalarFunctionConverter = createScalarFunctionConverter(extensions, typeFactory); + } + + @Override + public List<CallConverter> 
getCallConverters() { + List<CallConverter> callConverters = super.getCallConverters(); + + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + ExtensionUtils.getDynamicExtensions(extensions); + List<SqlOperator> dynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + List<FunctionMappings.Sig> additionalSignatures = + dynamicOperators.stream() + .map(op -> FunctionMappings.s(op, op.getName())) + .collect(Collectors.toList()); + callConverters.add( + new ScalarFunctionConverter( + extensions.scalarFunctions(), + additionalSignatures, + typeFactory, + TypeConverter.DEFAULT)); + return callConverters; + } + + @Override + public SqlOperatorTable getSqlOperatorTable() { + SimpleExtension.ExtensionCollection dynamicExtensionCollection = + ExtensionUtils.getDynamicExtensions(extensions); + if (!dynamicExtensionCollection.scalarFunctions().isEmpty() + || !dynamicExtensionCollection.aggregateFunctions().isEmpty()) { + List<SqlOperator> generatedDynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + return SqlOperatorTables.chain( + SubstraitOperatorTable.INSTANCE, SqlOperatorTables.of(generatedDynamicOperators)); + } + + return SubstraitOperatorTable.INSTANCE; + } + + @Override + public ScalarFunctionConverter getScalarFunctionConverter() { + return scalarFunctionConverter; + } + + private static ScalarFunctionConverter createScalarFunctionConverter( + SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { + + List<FunctionMappings.Sig> additionalSignatures; + + Set<String> knownFunctionNames = + FunctionMappings.SCALAR_SIGS.stream() + .map(FunctionMappings.Sig::name) + .collect(Collectors.toSet()); + + List<SimpleExtension.ScalarFunctionVariant> dynamicFunctions = + extensions.scalarFunctions().stream() + .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase(Locale.ROOT))) + .collect(Collectors.toList()); + + if (dynamicFunctions.isEmpty()) { + additionalSignatures = Collections.emptyList(); + } else { + SimpleExtension.ExtensionCollection dynamicExtensionCollection 
= + SimpleExtension.ExtensionCollection.builder().scalarFunctions(dynamicFunctions).build(); + + List<SqlOperator> dynamicOperators = + SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); + + additionalSignatures = + dynamicOperators.stream() + .map(op -> FunctionMappings.s(op, op.getName())) + .collect(Collectors.toList()); + } + + return new ScalarFunctionConverter( + extensions.scalarFunctions(), additionalSignatures, typeFactory, TypeConverter.DEFAULT); + } +} diff --git a/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java b/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java deleted file mode 100644 index a54f24146..000000000 --- a/isthmus/src/main/java/io/substrait/isthmus/FeatureBoard.java +++ /dev/null @@ -1,35 +0,0 @@ -package io.substrait.isthmus; - -import org.apache.calcite.avatica.util.Casing; -import org.immutables.value.Value; - -/** - * A feature board is a collection of flags that are enabled or configurations that control the - * handling of a request to convert query [batch] to Substrait plans. - */ -@Value.Immutable -public abstract class FeatureBoard { - - /** - * @return Calcite's identifier casing policy for unquoted identifiers. - */ - @Value.Default - public Casing unquotedCasing() { - return Casing.TO_UPPER; - } - - /** - * Controls whether to support dynamic user-defined functions (UDFs) during SQL to Substrait plan - * conversion. - * - * <p>
When enabled, custom functions defined in extension YAML files are available for use in SQL - * queries. These functions will be dynamically converted to SQL operators during plan conversion. - * This feature must be explicitly enabled by users and is disabled by default. - * - * @return true if dynamic UDFs should be supported; false otherwise (default) - */ - @Value.Default - public boolean allowDynamicUdfs() { - return false; - } -} diff --git a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java index 99eaac1ab..45581b2fb 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SchemaCollector.java @@ -30,6 +30,11 @@ public SchemaCollector(RelDataTypeFactory typeFactory, TypeConverter typeConvert this.typeConverter = typeConverter; } + public SchemaCollector(ConverterProvider converterProvider) { + this.typeFactory = converterProvider.getTypeFactory(); + this.typeConverter = converterProvider.getTypeConverter(); + } + /** * Returns a {@link CalciteSchema} containing all tables and schemas defined in {@link NamedScan}s * and {@link NamedWrite}s within the provided relation operation tree. 
diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java index f667deab0..170dc4f94 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlConverterBase.java @@ -1,7 +1,5 @@ package io.substrait.isthmus; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; import org.apache.calcite.config.CalciteConnectionConfig; import org.apache.calcite.config.CalciteConnectionProperty; import org.apache.calcite.plan.Contexts; @@ -14,12 +12,10 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.parser.ddl.SqlDdlParserImpl; -import org.apache.calcite.sql.validate.SqlConformanceEnum; import org.apache.calcite.sql2rel.SqlToRelConverter; public class SqlConverterBase { - protected final SimpleExtension.ExtensionCollection extensionCollection; + protected final ConverterProvider converterProvider; public static final CalciteConnectionConfig CONNECTION_CONFIG = CalciteConnectionConfig.DEFAULT.set( @@ -32,15 +28,11 @@ public class SqlConverterBase { final SqlParser.Config parserConfig; - protected static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); - final FeatureBoard featureBoard; - - protected SqlConverterBase( - FeatureBoard features, SimpleExtension.ExtensionCollection extensionCollection) { - this.factory = SubstraitTypeSystem.TYPE_FACTORY; - this.config = - CalciteConnectionConfig.DEFAULT.set(CalciteConnectionProperty.CASE_SENSITIVE, "false"); - this.converterConfig = SqlToRelConverter.config().withTrimUnusedFields(true).withExpand(false); + protected SqlConverterBase(ConverterProvider converterProvider) { + this.converterProvider = converterProvider; + this.factory = converterProvider.getTypeFactory(); + 
this.config = converterProvider.getCalciteConnectionConfig(); + this.converterConfig = converterProvider.getSqlToRelConverterConfig(); VolcanoPlanner planner = new VolcanoPlanner(RelOptCostImpl.FACTORY, Contexts.of("hello")); this.relOptCluster = RelOptCluster.create(planner, new RexBuilder(factory)); relOptCluster.setMetadataQuerySupplier( @@ -49,17 +41,6 @@ protected SqlConverterBase( new ProxyingMetadataHandlerProvider(DefaultRelMetadataProvider.INSTANCE); return new RelMetadataQuery(handler); }); - featureBoard = features == null ? FEATURES_DEFAULT : features; - parserConfig = - SqlParser.Config.DEFAULT - .withUnquotedCasing(featureBoard.unquotedCasing()) - .withParserFactory(SqlDdlParserImpl.FACTORY) - .withConformance(SqlConformanceEnum.LENIENT); - - this.extensionCollection = extensionCollection; - } - - protected SqlConverterBase(FeatureBoard features) { - this(features, DefaultExtensionCatalog.DEFAULT_COLLECTION); + parserConfig = converterProvider.getSqlParserConfig(); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java index 3d45f8bde..3bdc2e56c 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlExpressionToSubstrait.java @@ -7,7 +7,6 @@ import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.calcite.SubstraitTable; import io.substrait.isthmus.expression.RexExpressionConverter; -import io.substrait.isthmus.expression.ScalarFunctionConverter; import io.substrait.isthmus.sql.SubstraitCreateStatementParser; import io.substrait.isthmus.sql.SubstraitSqlValidator; import io.substrait.type.NamedStruct; @@ -35,15 +34,17 @@ public class SqlExpressionToSubstrait extends SqlConverterBase { protected final RexExpressionConverter rexConverter; public SqlExpressionToSubstrait() { - this(FEATURES_DEFAULT, DefaultExtensionCatalog.DEFAULT_COLLECTION); + 
this(DefaultExtensionCatalog.DEFAULT_COLLECTION); } - public SqlExpressionToSubstrait( - FeatureBoard features, SimpleExtension.ExtensionCollection extensions) { - super(features, extensions); - ScalarFunctionConverter scalarFunctionConverter = - new ScalarFunctionConverter(extensions.scalarFunctions(), factory); - this.rexConverter = new RexExpressionConverter(scalarFunctionConverter); + /** Use {@link #SqlExpressionToSubstrait(ConverterProvider)} instead */ + public SqlExpressionToSubstrait(SimpleExtension.ExtensionCollection extensions) { + this(new ConverterProvider(extensions)); + } + + public SqlExpressionToSubstrait(ConverterProvider converterProvider) { + super(converterProvider); + this.rexConverter = new RexExpressionConverter(converterProvider.getScalarFunctionConverter()); } private static final class Result { diff --git a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java index e60494244..0a61986f9 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SqlToSubstrait.java @@ -2,70 +2,42 @@ import io.substrait.extension.DefaultExtensionCatalog; import io.substrait.extension.SimpleExtension; -import io.substrait.isthmus.calcite.SubstraitOperatorTable; import io.substrait.isthmus.sql.SubstraitSqlToCalcite; import io.substrait.plan.ImmutablePlan.Builder; import io.substrait.plan.Plan; import io.substrait.plan.Plan.Version; -import io.substrait.plan.PlanProtoConverter; -import java.util.List; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.sql.SqlDialect; -import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.SqlOperatorTable; import org.apache.calcite.sql.parser.SqlParseException; import org.apache.calcite.sql.parser.SqlParser; -import org.apache.calcite.sql.util.SqlOperatorTables; -/** Take a SQL statement and a set of table definitions and return a substrait plan. */ +/** + * Take a SQL statement and a set of table definitions and return a substrait plan. + * + * <p>Conversion behaviours can be customized using a {@link ConverterProvider} + */ public class SqlToSubstrait extends SqlConverterBase { private final SqlOperatorTable operatorTable; public SqlToSubstrait() { - this(DefaultExtensionCatalog.DEFAULT_COLLECTION, null); + this(DefaultExtensionCatalog.DEFAULT_COLLECTION); } - public SqlToSubstrait(FeatureBoard features) { - this(DefaultExtensionCatalog.DEFAULT_COLLECTION, features); - } - - public SqlToSubstrait(SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - super(features, extensions); - - if (featureBoard.allowDynamicUdfs()) { - SimpleExtension.ExtensionCollection dynamicExtensionCollection = - ExtensionUtils.getDynamicExtensions(extensions); - if (!dynamicExtensionCollection.scalarFunctions().isEmpty() - || !dynamicExtensionCollection.aggregateFunctions().isEmpty()) { - List<SqlOperator> generatedDynamicOperators = - SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, this.factory); - this.operatorTable = - SqlOperatorTables.chain( - SubstraitOperatorTable.INSTANCE, SqlOperatorTables.of(generatedDynamicOperators)); - return; - } - } - this.operatorTable = SubstraitOperatorTable.INSTANCE; + /** Use {@link SqlToSubstrait#SqlToSubstrait(ConverterProvider)} instead */ + @Deprecated + public SqlToSubstrait(SimpleExtension.ExtensionCollection extensions) { + this(new ConverterProvider(extensions)); } - /** - * Converts one or more SQL statements into a Substrait {@link io.substrait.proto.Plan}. 
- * - * @param sqlStatements a string containing one more SQL statements - * @param catalogReader the {@link Prepare.CatalogReader} for finding tables/views referenced in - * the SQL statements - * @return a Substrait proto {@link io.substrait.proto.Plan} - * @throws SqlParseException if there is an error while parsing the SQL statements string - * @deprecated use {@link #convert(String, org.apache.calcite.prepare.Prepare.CatalogReader)} - * instead to get a {@link Plan} and convert that to a {@link io.substrait.proto.Plan} using - * {@link PlanProtoConverter#toProto(Plan)} - */ - @Deprecated - public io.substrait.proto.Plan execute(String sqlStatements, Prepare.CatalogReader catalogReader) - throws SqlParseException { - PlanProtoConverter planToProto = new PlanProtoConverter(); - return planToProto.toProto( - convert(sqlStatements, catalogReader, SqlDialect.DatabaseProduct.CALCITE.getDialect())); + /** + * Creates a converter configured by the given {@link ConverterProvider}. + * + * @param converterProvider supplies the SQL operator table and the conversion settings + */ + public SqlToSubstrait(ConverterProvider converterProvider) { + super(converterProvider); + this.operatorTable = converterProvider.getSqlOperatorTable(); } /** @@ -84,7 +56,7 @@ public Plan convert(final String sqlStatements, final Prepare.CatalogReader cata // TODO: consider case in which one sql passes conversion while others don't SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader, operatorTable).stream() - .map(root -> SubstraitRelVisitor.convert(root, extensionCollection, featureBoard)) + .map(root -> SubstraitRelVisitor.convert(root, converterProvider)) .forEach(root -> builder.addRoots(root)); return builder.build(); @@ -112,7 +84,7 @@ public Plan convert( // TODO: consider case in which one sql passes conversion while others don't SubstraitSqlToCalcite.convertQueries(sqlStatements, catalogReader, sqlParserConfig).stream() - .map(root -> SubstraitRelVisitor.convert(root, extensionCollection, featureBoard)) + .map(root -> SubstraitRelVisitor.convert(root, converterProvider)) .forEach(root -> 
builder.addRoots(root)); return builder.build(); diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java index 47daf97e2..ac625986d 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelNodeConverter.java @@ -12,9 +12,7 @@ import io.substrait.isthmus.calcite.rel.CreateView; import io.substrait.isthmus.expression.AggregateFunctionConverter; import io.substrait.isthmus.expression.ExpressionRexConverter; -import io.substrait.isthmus.expression.FunctionMappings; import io.substrait.isthmus.expression.ScalarFunctionConverter; -import io.substrait.isthmus.expression.WindowFunctionConverter; import io.substrait.relation.AbstractDdlRel; import io.substrait.relation.AbstractRelVisitor; import io.substrait.relation.AbstractUpdate; @@ -57,7 +55,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; import java.util.stream.Stream; -import org.apache.calcite.plan.RelOptCluster; import org.apache.calcite.plan.RelOptTable; import org.apache.calcite.plan.RelTraitDef; import org.apache.calcite.prepare.Prepare; @@ -84,7 +81,6 @@ import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlOperator; import org.apache.calcite.sql.fun.SqlStdOperatorTable; -import org.apache.calcite.sql.parser.SqlParser; import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; @@ -105,138 +101,37 @@ public class SubstraitRelNodeConverter protected final RexBuilder rexBuilder; private final TypeConverter typeConverter; + /** Use {@link #SubstraitRelNodeConverter(RelBuilder, ConverterProvider)} instead */ + @Deprecated public SubstraitRelNodeConverter( SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory, RelBuilder relBuilder) { - this(extensions, typeFactory, relBuilder, ImmutableFeatureBoard.builder().build()); 
+ this(relBuilder, new ConverterProvider(typeFactory, extensions)); } - public SubstraitRelNodeConverter( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - RelBuilder relBuilder, - FeatureBoard featureBoard) { - this( - typeFactory, - relBuilder, - createScalarFunctionConverter(extensions, typeFactory, featureBoard.allowDynamicUdfs()), - new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory), - new WindowFunctionConverter(extensions.windowFunctions(), typeFactory), - TypeConverter.DEFAULT); - } - - public SubstraitRelNodeConverter( - RelDataTypeFactory typeFactory, - RelBuilder relBuilder, - ScalarFunctionConverter scalarFunctionConverter, - AggregateFunctionConverter aggregateFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter) { - this( - typeFactory, - relBuilder, - scalarFunctionConverter, - aggregateFunctionConverter, - windowFunctionConverter, - typeConverter, - new ExpressionRexConverter( - typeFactory, scalarFunctionConverter, windowFunctionConverter, typeConverter)); - } - - public SubstraitRelNodeConverter( - RelDataTypeFactory typeFactory, - RelBuilder relBuilder, - ScalarFunctionConverter scalarFunctionConverter, - AggregateFunctionConverter aggregateFunctionConverter, - WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter, - ExpressionRexConverter expressionRexConverter) { - this.typeFactory = typeFactory; - this.typeConverter = typeConverter; + public SubstraitRelNodeConverter(RelBuilder relBuilder, ConverterProvider converterProvider) { + this.typeFactory = converterProvider.getTypeFactory(); + this.typeConverter = converterProvider.getTypeConverter(); this.relBuilder = relBuilder; this.rexBuilder = new RexBuilder(typeFactory); - this.scalarFunctionConverter = scalarFunctionConverter; - this.aggregateFunctionConverter = aggregateFunctionConverter; - this.expressionRexConverter = expressionRexConverter; - 
this.expressionRexConverter.setRelNodeConverter(this); - } - - private static ScalarFunctionConverter createScalarFunctionConverter( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - boolean allowDynamicUdfs) { - - List additionalSignatures; - - if (allowDynamicUdfs) { - java.util.Set knownFunctionNames = - FunctionMappings.SCALAR_SIGS.stream() - .map(FunctionMappings.Sig::name) - .collect(Collectors.toSet()); - - List dynamicFunctions = - extensions.scalarFunctions().stream() - .filter(f -> !knownFunctionNames.contains(f.name().toLowerCase())) - .collect(Collectors.toList()); - - if (dynamicFunctions.isEmpty()) { - additionalSignatures = Collections.emptyList(); - } else { - SimpleExtension.ExtensionCollection dynamicExtensionCollection = - SimpleExtension.ExtensionCollection.builder().scalarFunctions(dynamicFunctions).build(); - - List dynamicOperators = - SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); - - additionalSignatures = - dynamicOperators.stream() - .map(op -> FunctionMappings.s(op, op.getName())) - .collect(Collectors.toList()); - } - } else { - additionalSignatures = Collections.emptyList(); - } - - return new ScalarFunctionConverter( - extensions.scalarFunctions(), additionalSignatures, typeFactory, TypeConverter.DEFAULT); + this.scalarFunctionConverter = converterProvider.getScalarFunctionConverter(); + this.aggregateFunctionConverter = converterProvider.getAggregateFunctionConverter(); + this.expressionRexConverter = converterProvider.getExpressionRexConverter(this); } public static RelNode convert( - Rel relRoot, - RelOptCluster relOptCluster, - Prepare.CatalogReader catalogReader, - SqlParser.Config parserConfig, - SimpleExtension.ExtensionCollection extensions) { - return convert( - relRoot, - relOptCluster, - catalogReader, - parserConfig, - extensions, - ImmutableFeatureBoard.builder().build()); - } - - public static RelNode convert( - Rel relRoot, - RelOptCluster relOptCluster, - 
Prepare.CatalogReader catalogReader, - SqlParser.Config parserConfig, - SimpleExtension.ExtensionCollection extensions, - FeatureBoard featureBoard) { + Rel relRoot, Prepare.CatalogReader catalogReader, ConverterProvider converterProvider) { RelBuilder relBuilder = RelBuilder.create( Frameworks.newConfigBuilder() - .parserConfig(parserConfig) + .parserConfig(converterProvider.getSqlParserConfig()) .defaultSchema(catalogReader.getRootSchema().plus()) .traitDefs((List) null) .programs() .build()); - return relRoot.accept( - new SubstraitRelNodeConverter( - extensions, relOptCluster.getTypeFactory(), relBuilder, featureBoard), - Context.newContext()); + converterProvider.getSubstraitRelNodeConverter(relBuilder), Context.newContext()); } @Override diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java index 835d8493d..b43468bae 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitRelVisitor.java @@ -8,8 +8,6 @@ import io.substrait.isthmus.calcite.rel.CreateTable; import io.substrait.isthmus.calcite.rel.CreateView; import io.substrait.isthmus.expression.AggregateFunctionConverter; -import io.substrait.isthmus.expression.CallConverters; -import io.substrait.isthmus.expression.FunctionMappings; import io.substrait.isthmus.expression.LiteralConverter; import io.substrait.isthmus.expression.RexExpressionConverter; import io.substrait.isthmus.expression.ScalarFunctionConverter; @@ -59,88 +57,58 @@ import org.apache.calcite.rel.core.TableModify; import org.apache.calcite.rel.type.RelDataType; import org.apache.calcite.rel.type.RelDataTypeFactory; -import org.apache.calcite.rex.RexBuilder; import org.apache.calcite.rex.RexFieldAccess; import org.apache.calcite.rex.RexInputRef; import org.apache.calcite.rex.RexNode; -import org.apache.calcite.sql.SqlOperator; import 
org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.util.ImmutableBitSet; import org.immutables.value.Value; +/** + * SubstraitRelVisitor is used to convert Calcite {@link RelNode}s to Substrait {@link Rel}s. + * + *

Conversion behaviours can be customized by using a {@link ConverterProvider} and/or extending + * this class + */ @SuppressWarnings("UnstableApiUsage") @Value.Enclosing public class SubstraitRelVisitor extends RelNodeVisitor { - private static final FeatureBoard FEATURES_DEFAULT = ImmutableFeatureBoard.builder().build(); private static final Expression.BoolLiteral TRUE = ExpressionCreator.bool(false, true); protected final RexExpressionConverter rexExpressionConverter; protected final AggregateFunctionConverter aggregateFunctionConverter; protected final TypeConverter typeConverter; - protected final FeatureBoard featureBoard; private Map fieldAccessDepthMap; + /** Use {@link SubstraitRelVisitor#SubstraitRelVisitor(ConverterProvider)} */ + @Deprecated public SubstraitRelVisitor( RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { - this(typeFactory, extensions, FEATURES_DEFAULT); - } - - public SubstraitRelVisitor( - RelDataTypeFactory typeFactory, - SimpleExtension.ExtensionCollection extensions, - FeatureBoard features) { - - this.typeConverter = TypeConverter.DEFAULT; - ArrayList converters = new ArrayList<>(); - converters.addAll(CallConverters.defaults(typeConverter)); - - if (features.allowDynamicUdfs()) { - SimpleExtension.ExtensionCollection dynamicExtensionCollection = - ExtensionUtils.getDynamicExtensions(extensions); - List dynamicOperators = - SimpleExtensionToSqlOperator.from(dynamicExtensionCollection, typeFactory); - - List additionalSignatures = - dynamicOperators.stream() - .map(op -> FunctionMappings.s(op, op.getName())) - .collect(Collectors.toList()); - converters.add( - new ScalarFunctionConverter( - extensions.scalarFunctions(), - additionalSignatures, - typeFactory, - TypeConverter.DEFAULT)); - } else { - converters.add(new ScalarFunctionConverter(extensions.scalarFunctions(), typeFactory)); - } - - converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); - 
this.aggregateFunctionConverter = - new AggregateFunctionConverter(extensions.aggregateFunctions(), typeFactory); - WindowFunctionConverter windowFunctionConverter = - new WindowFunctionConverter(extensions.windowFunctions(), typeFactory); - this.rexExpressionConverter = - new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter); - this.featureBoard = features; + this(new ConverterProvider(typeFactory, extensions)); } + /** Use {@link SubstraitRelVisitor#SubstraitRelVisitor(ConverterProvider)} */ + @Deprecated public SubstraitRelVisitor( RelDataTypeFactory typeFactory, ScalarFunctionConverter scalarFunctionConverter, AggregateFunctionConverter aggregateFunctionConverter, WindowFunctionConverter windowFunctionConverter, - TypeConverter typeConverter, - FeatureBoard features) { - ArrayList converters = new ArrayList(); - converters.addAll(CallConverters.defaults(typeConverter)); - converters.add(scalarFunctionConverter); - converters.add(CallConverters.CREATE_SEARCH_CONV.apply(new RexBuilder(typeFactory))); - this.aggregateFunctionConverter = aggregateFunctionConverter; - this.rexExpressionConverter = - new RexExpressionConverter(this, converters, windowFunctionConverter, typeConverter); - this.typeConverter = typeConverter; - this.featureBoard = features; + TypeConverter typeConverter) { + this( + new ConverterProvider( + typeFactory, + scalarFunctionConverter, + aggregateFunctionConverter, + windowFunctionConverter, + typeConverter)); + } + + public SubstraitRelVisitor(ConverterProvider converterProvider) { + this.typeConverter = converterProvider.getTypeConverter(); + this.aggregateFunctionConverter = converterProvider.getAggregateFunctionConverter(); + this.rexExpressionConverter = converterProvider.getRexExpressionConverter(this); } protected Expression toExpression(RexNode node) { @@ -628,38 +596,32 @@ public List apply(List inputs) { } /** - * Converts a Calcite {@link RelRoot} to a Substrait {@link Plan.Root} using default 
features. - * - *

This is a convenience method that delegates to {@link #convert(RelRoot, - * SimpleExtension.ExtensionCollection, FeatureBoard)} using {@link #FEATURES_DEFAULT}. + * Deprecated, use {@link #convert(RelRoot, ConverterProvider)} directly * * @param relRoot The Calcite RelRoot to convert. * @param extensions The extension collection to use for the conversion. * @return The resulting Substrait Plan.Root. */ + @Deprecated public static Plan.Root convert(RelRoot relRoot, SimpleExtension.ExtensionCollection extensions) { - return convert(relRoot, extensions, FEATURES_DEFAULT); + return convert(relRoot, new ConverterProvider(extensions)); } /** - * Converts a Calcite {@link RelRoot} to a Substrait {@link Plan.Root} using a custom visitor. - * - *

This is the main conversion entry point for a complete plan. It applies the provided {@link - * SubstraitRelVisitor} to the final projected {@link RelNode} from the {@code relRoot}, and wraps - * the resulting {@link Rel} in a {@link Plan.Root}. + * Converts a Calcite {@link RelRoot} to a Substrait {@link Plan.Root} * - *

This method also correctly extracts the final output field names, paying special attention - * to nested types (structs, maps) via the visitor's type converter, rather than using the names - * from {@code relRoot.validatedRowType} directly. + *

Converts the output of {@link RelRoot#project()} to a Substrait {@link Rel} and wraps it in + * a {@link Plan.Root}. Handles the extraction of final output field names, paying special + * attention to nested types (structs, maps) via the visitor's type converter, rather than using + * the names from {@link RelRoot#validatedRowType} directly. * - * @param relRoot The Calcite RelRoot to convert. This is expected to be a complete, optimized - * plan. - * @param visitor {@link SubstraitRelVisitor} or its subclass. This allows for custom visitor - * behavior. - * @return The resulting Substrait Plan.Root, containing the converted relational tree and the - * output names. + * @param relRoot The Calcite RelRoot to convert. This is expected to be a complete plan. + * @param converterProvider The {@link ConverterProvider} controlling conversion behaviours. + * @return The resulting Substrait {@link Plan.Root}, containing the converted relational tree and + * the output names. */ - public static Plan.Root convert(RelRoot relRoot, SubstraitRelVisitor visitor) { + public static Plan.Root convert(RelRoot relRoot, ConverterProvider converterProvider) { + SubstraitRelVisitor visitor = converterProvider.getSubstraitRelVisitor(); visitor.popFieldAccessDepthMap(relRoot.rel); Rel rel = visitor.apply(relRoot.project()); @@ -670,80 +632,31 @@ public static Plan.Root convert(RelRoot relRoot, SubstraitRelVisitor visitor) { } /** - * Converts a Calcite {@link RelRoot} to a Substrait {@link Plan.Root} using the specified - * features. - * - *

This is a convenience method that delegates to {@link #convert(RelRoot, - * SubstraitRelVisitor)} using an instance of the {@link SubstraitRelVisitor} as the visitor. - * - * @param relRoot The Calcite RelRoot to convert. - * @param extensions The extension collection to use for the conversion. - * @param features The feature board specifying enabled Substrait features. - * @return The resulting Substrait Plan.Root. - */ - public static Plan.Root convert( - RelRoot relRoot, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - return convert( - relRoot, - new SubstraitRelVisitor(relRoot.rel.getCluster().getTypeFactory(), extensions, features)); - } - - /** - * Converts a Calcite {@link RelNode} to a Substrait {@link Rel} using default features. + * Deprecated, use {@link #convert(RelNode, ConverterProvider)} directly * *

This method is suitable for converting a relational sub-tree, but it does not produce a * {@link Plan.Root}. For a complete plan conversion, use {@link #convert(RelRoot, * SimpleExtension.ExtensionCollection)}. * - *

This is a convenience method that delegates to {@link #convert(RelNode, - * SimpleExtension.ExtensionCollection, FeatureBoard)} using {@link #FEATURES_DEFAULT}. - * * @param relNode The Calcite RelNode (and its subtree) to convert. * @param extensions The extension collection to use for the conversion. * @return The resulting Substrait Rel. */ + @Deprecated public static Rel convert(RelNode relNode, SimpleExtension.ExtensionCollection extensions) { - return convert(relNode, extensions, FEATURES_DEFAULT); + return convert(relNode, new ConverterProvider(extensions)); } /** - * Converts a Calcite {@link RelNode} to a Substrait {@link Rel} using a custom visitor. - * - *

This is the main conversion entry point for a partial plan or a single node (and its - * children). It applies the provided {@link SubstraitRelVisitor} to the given {@code relNode}. - * - *

This method does not wrap the result in a {@link Plan.Root} or extract output names. For - * that, use {@link #convert(RelRoot, SubstraitRelVisitor)}. + * Converts a Calcite {@link RelNode} to a Substrait {@link Rel} * - * @param relNode The Calcite RelNode (and its subtree) to convert. - * @param visitor {@link SubstraitRelVisitor} or its subclass. This allows for custom visitor - * behavior. + * @param relNode The Calcite RelNode to convert. + * @param converterProvider The {@link ConverterProvider} controlling conversion behaviours. * @return The resulting Substrait Rel. */ - public static Rel convert(RelNode relNode, SubstraitRelVisitor visitor) { + public static Rel convert(RelNode relNode, ConverterProvider converterProvider) { + SubstraitRelVisitor visitor = converterProvider.getSubstraitRelVisitor(); visitor.popFieldAccessDepthMap(relNode); return visitor.apply(relNode); } - - /** - * Converts a Calcite {@link RelNode} to a Substrait {@link Rel} using the specified features. - * - *

This method is suitable for converting a relational sub-tree, but it does not produce a - * {@link Plan.Root}. For a complete plan conversion, use {@link #convert(RelRoot, - * SimpleExtension.ExtensionCollection, FeatureBoard)}. - * - *

This is a convenience method that delegates to {@link #convert(RelNode, - * SubstraitRelVisitor)} using an instance of the {@link SubstraitRelVisitor} as the visitor. - * - * @param relNode The Calcite RelNode (and its subtree) to convert. - * @param extensions The extension collection to use for the conversion. - * @param features The feature board specifying enabled Substrait features. - * @return The resulting Substrait Rel. - */ - public static Rel convert( - RelNode relNode, SimpleExtension.ExtensionCollection extensions, FeatureBoard features) { - return convert( - relNode, - new SubstraitRelVisitor(relNode.getCluster().getTypeFactory(), extensions, features)); - } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java index 772a3e192..9418ad6bd 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToCalcite.java @@ -1,18 +1,11 @@ package io.substrait.isthmus; -import io.substrait.extension.SimpleExtension; import io.substrait.isthmus.SubstraitRelNodeConverter.Context; import io.substrait.plan.Plan; -import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; -import io.substrait.relation.RelCopyOnWriteVisitor; -import io.substrait.type.NamedStruct; import io.substrait.util.EmptyVisitationContext; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; -import java.util.Optional; import org.apache.calcite.jdbc.CalciteSchema; import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; @@ -22,116 +15,47 @@ import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeField; import org.apache.calcite.sql.SqlKind; -import org.apache.calcite.tools.Frameworks; import org.apache.calcite.tools.RelBuilder; import org.apache.calcite.util.Pair; /** * Converts between 
Substrait {@link Rel}s and Calcite {@link RelNode}s. * - *

Can be extended to customize the {@link RelBuilder} and {@link SubstraitRelNodeConverter} used - * in the conversion. + *

Conversion behaviours can be customized using a {@link ConverterProvider} */ public class SubstraitToCalcite { - protected final SimpleExtension.ExtensionCollection extensions; protected final RelDataTypeFactory typeFactory; - protected final TypeConverter typeConverter; protected final Prepare.CatalogReader catalogReader; - protected final FeatureBoard featureBoard; + protected ConverterProvider converterProvider; - public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, RelDataTypeFactory typeFactory) { - this(extensions, typeFactory, TypeConverter.DEFAULT, null); - } - - public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - Prepare.CatalogReader catalogReader) { - this(extensions, typeFactory, TypeConverter.DEFAULT, catalogReader); - } - - public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { - this(extensions, typeFactory, typeConverter, null); - } - - public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter, - Prepare.CatalogReader catalogReader) { - this( - extensions, - typeFactory, - typeConverter, - catalogReader, - ImmutableFeatureBoard.builder().build()); + public SubstraitToCalcite(ConverterProvider converterProvider) { + this(converterProvider, null); } public SubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter, - Prepare.CatalogReader catalogReader, - FeatureBoard featureBoard) { - this.extensions = extensions; - this.typeFactory = typeFactory; - this.typeConverter = typeConverter; + ConverterProvider converterProvider, Prepare.CatalogReader catalogReader) { + this.converterProvider = converterProvider; + this.typeFactory = converterProvider.getTypeFactory(); this.catalogReader = catalogReader; - this.featureBoard = 
featureBoard; - } - - /** - * Extracts a {@link CalciteSchema} from a {@link Rel} - * - *

Override this method to customize schema extraction. - */ - protected CalciteSchema toSchema(Rel rel) { - SchemaCollector schemaCollector = new SchemaCollector(typeFactory, typeConverter); - return schemaCollector.toSchema(rel); - } - - /** - * Creates a {@link RelBuilder} from the extracted {@link CalciteSchema} - * - *

Override this method to customize the {@link RelBuilder}. - */ - protected RelBuilder createRelBuilder(CalciteSchema schema) { - return RelBuilder.create(Frameworks.newConfigBuilder().defaultSchema(schema.plus()).build()); - } - - /** - * Creates a {@link SubstraitRelNodeConverter} from the {@link RelBuilder} - * - *

Override this method to customize the {@link SubstraitRelNodeConverter}. - */ - protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter(extensions, typeFactory, relBuilder, featureBoard); } /** * Converts a Substrait {@link Rel} to a Calcite {@link RelNode} * - *

Generates a {@link CalciteSchema} based on the contents of the {@link Rel}, which will be - * used to construct a {@link RelBuilder} with the required schema information to build {@link - * RelNode}s, and a then a {@link SubstraitRelNodeConverter} to perform the actual conversion. - * * @param rel {@link Rel} to convert * @return {@link RelNode} */ public RelNode convert(Rel rel) { RelBuilder relBuilder; if (catalogReader != null) { - relBuilder = createRelBuilder(catalogReader.getRootSchema()); + relBuilder = converterProvider.getRelBuilder(catalogReader.getRootSchema()); } else { - CalciteSchema rootSchema = toSchema(rel); - relBuilder = createRelBuilder(rootSchema); + CalciteSchema rootSchema = converterProvider.getSchemaResolver().apply(rel); + relBuilder = converterProvider.getRelBuilder(rootSchema); } - SubstraitRelNodeConverter converter = createSubstraitRelNodeConverter(relBuilder); + SubstraitRelNodeConverter converter = + converterProvider.getSubstraitRelNodeConverter(relBuilder); return rel.accept(converter, Context.newContext()); } @@ -229,29 +153,4 @@ private Pair renameFields( return Pair.of(currentIndex, type); } } - - private static class NamedStructGatherer extends RelCopyOnWriteVisitor { - Map, NamedStruct> tableMap; - - private NamedStructGatherer() { - super(); - this.tableMap = new HashMap<>(); - } - - public static Map, NamedStruct> gatherTables(Rel rel) { - NamedStructGatherer visitor = new NamedStructGatherer(); - rel.accept(visitor, EmptyVisitationContext.INSTANCE); - return visitor.tableMap; - } - - @Override - public Optional visit(NamedScan namedScan, EmptyVisitationContext context) { - Optional result = super.visit(namedScan, context); - - List tableName = namedScan.getNames(); - tableMap.put(tableName, namedScan.getInitialSchema()); - - return result; - } - } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java index e327ab007..f6e06cb88 100644 
--- a/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java +++ b/isthmus/src/main/java/io/substrait/isthmus/SubstraitToSql.java @@ -5,18 +5,27 @@ import org.apache.calcite.prepare.Prepare; import org.apache.calcite.rel.RelNode; +/** + * SubstraitToSql assists with converting Substrait to SQL + * + *

Conversion behaviours can be customized using a {@link ConverterProvider} + */ public class SubstraitToSql extends SqlConverterBase { public SubstraitToSql() { - super(FEATURES_DEFAULT); + this(new ConverterProvider()); } - public SubstraitToSql(SimpleExtension.ExtensionCollection extensions) { - super(FEATURES_DEFAULT, extensions); + /** Deprecated, use {@link #SubstraitToSql(ConverterProvider)} instead */ + public SubstraitToSql(SimpleExtension.ExtensionCollection extensionCollection) { + this(new ConverterProvider(extensionCollection)); + } + + public SubstraitToSql(ConverterProvider converterProvider) { + super(converterProvider); } public RelNode substraitRelToCalciteRel(Rel relRoot, Prepare.CatalogReader catalog) { - return SubstraitRelNodeConverter.convert( - relRoot, relOptCluster, catalog, parserConfig, extensionCollection); + return SubstraitRelNodeConverter.convert(relRoot, catalog, converterProvider); } } diff --git a/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java b/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java index 8cf4958d8..65d24a0ed 100644 --- a/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java +++ b/isthmus/src/main/java/io/substrait/isthmus/expression/SqlMapValueConstructorCallConverter.java @@ -15,7 +15,7 @@ public class SqlMapValueConstructorCallConverter implements CallConverter { - SqlMapValueConstructorCallConverter() {} + public SqlMapValueConstructorCallConverter() {} @Override public Optional convert( diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java index 7a0d4ddc5..3dac4b65e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexAggregateTest.java @@ -6,8 +6,6 @@ import 
io.substrait.expression.AggregateFunctionInvocation; import io.substrait.expression.Expression; import io.substrait.expression.ImmutableAggregateFunctionInvocation; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; import io.substrait.relation.Aggregate; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; @@ -18,8 +16,6 @@ import org.junit.jupiter.api.Test; class ComplexAggregateTest extends PlanTestBase { - protected static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = - DefaultExtensionCatalog.DEFAULT_COLLECTION; final TypeCreator R = TypeCreator.of(false); SubstraitBuilder b = new SubstraitBuilder(extensions); @@ -59,7 +55,7 @@ protected void validateAggregateTransformation(Aggregate pojo, Rel expectedTrans assertEquals(expectedTransform, converterPojo); // Substrait POJO -> Calcite - new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(pojo); + substraitToCalcite.convert(pojo); } @Test @@ -192,7 +188,7 @@ void outOfOrderGroupingKeysHaveCorrectCalciteType() { input -> b.grouping(input, 2, 0), input -> List.of(), b.namedScan(List.of("foo"), List.of("a", "b", "c"), List.of(R.I64, R.I64, R.STRING))); - RelNode relNode = new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory).convert(rel); + RelNode relNode = substraitToCalcite.convert(rel); assertRowMatch(relNode.getRowType(), R.STRING, R.I64); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java index 300786f21..297b04a1e 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ComplexSortTest.java @@ -4,8 +4,6 @@ import io.substrait.dsl.SubstraitBuilder; import io.substrait.expression.Expression; -import io.substrait.extension.DefaultExtensionCatalog; -import io.substrait.extension.SimpleExtension; import io.substrait.relation.Rel; import 
io.substrait.type.TypeCreator; import java.io.PrintWriter; @@ -21,15 +19,9 @@ class ComplexSortTest extends PlanTestBase { - private static final SimpleExtension.ExtensionCollection EXTENSION_COLLECTION = - DefaultExtensionCatalog.DEFAULT_COLLECTION; - final TypeCreator R = TypeCreator.of(false); SubstraitBuilder b = new SubstraitBuilder(extensions); - final SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); - /** * A {@link RelWriterImpl} that annotates each {@link RelNode} with its {@link RelCollation} trait * information. A {@link RelNode} is only annotated if its {@link RelCollation} is not empty. diff --git a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java index 34e06d0ac..a0404a8bd 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/CustomFunctionTest.java @@ -26,7 +26,6 @@ import java.util.List; import org.apache.calcite.rel.RelNode; import org.apache.calcite.rel.type.RelDataType; -import org.apache.calcite.rel.type.RelDataTypeFactory; import org.apache.calcite.rel.type.RelDataTypeSystem; import org.apache.calcite.sql.SqlAggFunction; import org.apache.calcite.sql.SqlFunction; @@ -35,7 +34,6 @@ import org.apache.calcite.sql.type.ReturnTypes; import org.apache.calcite.sql.type.SqlTypeFactoryImpl; import org.apache.calcite.sql.type.SqlTypeName; -import org.apache.calcite.tools.RelBuilder; import org.jspecify.annotations.Nullable; import org.junit.jupiter.api.Test; @@ -55,10 +53,10 @@ class CustomFunctionTest extends PlanTestBase { } // Load custom extension into an ExtensionCollection - static final SimpleExtension.ExtensionCollection extensionCollection = - SimpleExtension.load("custom.yaml", FUNCTIONS_CUSTOM); + static final SimpleExtension.ExtensionCollection CUSTOM_EXTENSIONS = + SimpleExtension.load(URN, FUNCTIONS_CUSTOM); - final 
SubstraitBuilder b = new SubstraitBuilder(extensionCollection); + final SubstraitBuilder b = new SubstraitBuilder(CUSTOM_EXTENSIONS); // Create user-defined types static final String aTypeName = "a_type"; @@ -239,52 +237,33 @@ public RelDataType toCalcite(Type.UserDefined type) { // Create Function Converters that can handle the custom functions ScalarFunctionConverter scalarFunctionConverter = new ScalarFunctionConverter( - extensionCollection.scalarFunctions(), + CUSTOM_EXTENSIONS.scalarFunctions(), additionalScalarSignatures, typeFactory, typeConverter); AggregateFunctionConverter aggregateFunctionConverter = new AggregateFunctionConverter( - extensionCollection.aggregateFunctions(), + CUSTOM_EXTENSIONS.aggregateFunctions(), additionalAggregateSignatures, typeFactory, typeConverter); WindowFunctionConverter windowFunctionConverter = - new WindowFunctionConverter(extensionCollection.windowFunctions(), typeFactory); + new WindowFunctionConverter(CUSTOM_EXTENSIONS.windowFunctions(), typeFactory); - final SubstraitToCalcite substraitToCalcite = - new CustomSubstraitToCalcite(extensionCollection, typeFactory, typeConverter); - - // Create a SubstraitRelVisitor that uses the custom Function Converters - final SubstraitRelVisitor calciteToSubstrait = - new SubstraitRelVisitor( + ConverterProvider converterProvider = + new ConverterProvider( typeFactory, scalarFunctionConverter, aggregateFunctionConverter, windowFunctionConverter, - typeConverter, - ImmutableFeatureBoard.builder().build()); - - // Create a SubstraitToCalcite converter that has access to the custom Function Converters - class CustomSubstraitToCalcite extends SubstraitToCalcite { + typeConverter); - public CustomSubstraitToCalcite( - SimpleExtension.ExtensionCollection extensions, - RelDataTypeFactory typeFactory, - TypeConverter typeConverter) { - super(extensions, typeFactory, typeConverter); - } + // Create a SubstraitRelVisitor that uses the custom Function Converters + final SubstraitRelVisitor 
calciteToSubstrait = new SubstraitRelVisitor(converterProvider); + final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(converterProvider); - @Override - protected SubstraitRelNodeConverter createSubstraitRelNodeConverter(RelBuilder relBuilder) { - return new SubstraitRelNodeConverter( - typeFactory, - relBuilder, - scalarFunctionConverter, - aggregateFunctionConverter, - windowFunctionConverter, - typeConverter); - } + CustomFunctionTest() { + super(CUSTOM_EXTENSIONS); } @Test @@ -602,7 +581,7 @@ void customTypesLiteralInFunctionsRoundtrip() { ExtensionCollector extensionCollector = new ExtensionCollector(); io.substrait.proto.Rel protoRel = new RelProtoConverter(extensionCollector).toProto(rel1); - Rel rel3 = new ProtoRelConverter(extensionCollector, extensionCollection).from(protoRel); + Rel rel3 = new ProtoRelConverter(extensionCollector, CUSTOM_EXTENSIONS).from(protoRel); assertEquals(rel1, rel3); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java index a99c5ff35..930e02e7c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NameRoundtripTest.java @@ -23,9 +23,6 @@ void preserveNamesFromSql() throws Exception { CalciteCatalogReader catalogReader = SubstraitCreateStatementParser.processCreateStatementsToCatalog(createStatement); - SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(EXTENSION_COLLECTION, typeFactory); - String query = "SELECT \"a\", \"B\" FROM foo GROUP BY a, b"; List expectedNames = List.of("a", "B"); diff --git a/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java b/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java index e0c4b8023..d1f677ba1 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/NestedExpressionsTest.java @@ 
-22,7 +22,6 @@ class NestedExpressionsTest extends PlanTestBase { protected static final SimpleExtension.ExtensionCollection defaultExtensionCollection = DefaultExtensionCatalog.DEFAULT_COLLECTION; protected SubstraitBuilder b = new SubstraitBuilder(defaultExtensionCollection); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); Expression literalExpression = Expression.BoolLiteral.builder().value(true).build(); Expression.ScalarFunctionInvocation nonLiteralExpression = b.add(b.i32(7), b.i32(42)); diff --git a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java index a37916bc4..628abe166 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java +++ b/isthmus/src/test/java/io/substrait/isthmus/PlanTestBase.java @@ -48,6 +48,9 @@ public class PlanTestBase { protected static final TypeCreator R = TypeCreator.of(false); protected static final TypeCreator N = TypeCreator.of(true); + protected SubstraitToCalcite substraitToCalcite; + protected ConverterProvider converterProvider; + protected static final CalciteCatalogReader TPCH_CATALOG; static { @@ -69,8 +72,15 @@ protected PlanTestBase() { } protected PlanTestBase(SimpleExtension.ExtensionCollection extensions) { + this(extensions, new ConverterProvider(extensions)); + } + + protected PlanTestBase( + SimpleExtension.ExtensionCollection extensions, ConverterProvider converterProvider) { this.extensions = extensions; this.substraitBuilder = new SubstraitBuilder(extensions); + this.converterProvider = converterProvider; + this.substraitToCalcite = new SubstraitToCalcite(converterProvider); } public static String asString(String resource) throws IOException { @@ -134,7 +144,8 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( // Return list of sql -> Substrait rel -> Calcite rel. 
SqlToSubstrait s2s = new SqlToSubstrait(); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); + SubstraitToCalcite substraitToCalcite = + new SubstraitToCalcite(converterProvider, catalogReader); // 1. SQL -> Substrait Plan Plan plan1 = s2s.convert(query, catalogReader); @@ -146,7 +157,7 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( RelRoot relRoot2 = substraitToCalcite.convert(pojo1); // 4. Calcite RelNode -> Substrait Rel - Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, extensions); + Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, converterProvider); assertEquals(pojo1, pojo2); return relRoot2; @@ -170,22 +181,15 @@ protected RelRoot assertSqlSubstraitRelRoundTrip( * * @param query the SQL query to test * @param catalogReader the Calcite catalog with table definitions - * @param featureBoard optional FeatureBoard to control conversion behavior (e.g., dynamic UDFs). - * If null, a default FeatureBoard is used. */ protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( - String query, Prepare.CatalogReader catalogReader, FeatureBoard featureBoard) - throws Exception { - // Use provided FeatureBoard, or create default if null - FeatureBoard features = - featureBoard != null ? featureBoard : ImmutableFeatureBoard.builder().build(); - + String query, Prepare.CatalogReader catalogReader) throws Exception { SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(extensions, typeFactory, TypeConverter.DEFAULT, null, features); - SqlToSubstrait s = new SqlToSubstrait(extensions, features); + new SubstraitToCalcite(converterProvider, catalogReader); + SqlToSubstrait sqlToSubstrait = new SqlToSubstrait(converterProvider); // 1. SQL -> Substrait Plan - Plan plan1 = s.convert(query, catalogReader); + Plan plan1 = sqlToSubstrait.convert(query, catalogReader); // 2. 
Substrait Plan -> Substrait Root (POJO 1) Plan.Root pojo1 = plan1.getRoots().get(0); @@ -194,7 +198,7 @@ protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( RelRoot relRoot2 = substraitToCalcite.convert(pojo1); // 4. Calcite RelNode -> Substrait Root (POJO 2) - Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, extensions, features); + Plan.Root pojo2 = SubstraitRelVisitor.convert(relRoot2, converterProvider); // Note: pojo1 and pojo2 may differ due to different optimization strategies applied by: // - SqlNode->RelRoot conversion during SQL->Substrait conversion @@ -205,23 +209,13 @@ protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( RelRoot relRoot3 = substraitToCalcite.convert(pojo2); // 6. Calcite RelNode -> Substrait Root (POJO 3) - Plan.Root pojo3 = SubstraitRelVisitor.convert(relRoot3, extensions, features); + Plan.Root pojo3 = SubstraitRelVisitor.convert(relRoot3, converterProvider); // Verify that subsequent round trips are stable (pojo2 and pojo3 should be identical) assertEquals(pojo2, pojo3); return relRoot2; } - /** - * Convenience overload of {@link #assertSqlSubstraitRelRoundTripLoosePojoComparison(String, - * Prepare.CatalogReader, FeatureBoard)} with default FeatureBoard behavior (no dynamic UDFs). 
- */ - protected RelRoot assertSqlSubstraitRelRoundTripLoosePojoComparison( - String query, Prepare.CatalogReader catalogReader) throws Exception { - return assertSqlSubstraitRelRoundTripLoosePojoComparison( - query, catalogReader, ImmutableFeatureBoard.builder().build()); - } - @Beta protected void assertFullRoundTrip(String query) throws SqlParseException { assertFullRoundTrip(query, TPCH_CATALOG); @@ -269,7 +263,7 @@ protected void assertFullRoundTrip(String sqlQuery, Prepare.CatalogReader catalo // Substrait Root 2 -> Calcite 2 final SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(extensions, typeFactory, catalogReader); + new SubstraitToCalcite(converterProvider, catalogReader); RelRoot calcite2 = substraitToCalcite.convert(root2); // It would be ideal to compare calcite1 and calcite2, however there isn't a good mechanism to @@ -327,7 +321,7 @@ protected void assertFullRoundTripWithIdentityProjectionWorkaround( assertEquals(root0, root1); final SubstraitToCalcite substraitToCalcite = - new SubstraitToCalcite(extensions, typeFactory, catalogReader); + new SubstraitToCalcite(converterProvider, catalogReader); // Substrait POJO 1 -> Calcite 1 RelRoot calcite1 = substraitToCalcite.convert(root1); @@ -375,7 +369,7 @@ protected void assertFullRoundTrip(Rel pojo1) { assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelNode calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); + RelNode calcite = new SubstraitToCalcite(converterProvider).convert(pojo2); // Calcite -> Substrait POJO 3 io.substrait.relation.Rel pojo3 = SubstraitRelVisitor.convert(calcite, extensions); @@ -406,7 +400,7 @@ protected void assertFullRoundTrip(Plan.Root pojo1) { assertEquals(pojo1, pojo2); // Substrait POJO 2 -> Calcite - RelRoot calcite = new SubstraitToCalcite(extensions, typeFactory).convert(pojo2); + RelRoot calcite = new SubstraitToCalcite(converterProvider).convert(pojo2); // Calcite -> Substrait POJO 3 io.substrait.plan.Plan.Root pojo3 = 
SubstraitRelVisitor.convert(calcite, extensions); @@ -441,7 +435,7 @@ protected String toSql(Plan plan) { assertEquals(1, roots.size(), "number of roots"); Root root = roots.get(0); - RelRoot relRoot = new SubstraitToCalcite(extensions, typeFactory).convert(root); + RelRoot relRoot = new SubstraitToCalcite(converterProvider).convert(root); RelNode project = relRoot.project(true); return SubstraitSqlDialect.toSql(project).getSql(); } diff --git a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java index da8423c03..f71a81c1a 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/ProtoPlanConverterTest.java @@ -73,7 +73,6 @@ public Optional visit(Cross cross, EmptyVisitationContext context) return super.visit(cross, context); } }; - ImmutableFeatureBoard featureBoard = ImmutableFeatureBoard.builder().build(); String query1 = "select\n" @@ -82,7 +81,7 @@ public Optional visit(Cross cross, EmptyVisitationContext context) + "from\n" + " \"customer\" c cross join\n" + " \"orders\" o"; - Plan plan1 = assertProtoPlanRoundrip(query1, new SqlToSubstrait(featureBoard)); + Plan plan1 = assertProtoPlanRoundrip(query1, new SqlToSubstrait()); plan1 .getRoots() .forEach( @@ -96,7 +95,7 @@ public Optional visit(Cross cross, EmptyVisitationContext context) + "from\n" + " \"customer\" c,\n" + " \"orders\" o"; - Plan plan2 = assertProtoPlanRoundrip(query2, new SqlToSubstrait(featureBoard)); + Plan plan2 = assertProtoPlanRoundrip(query2, new SqlToSubstrait()); plan2 .getRoots() .forEach( diff --git a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java index fc1c4d812..d71d550d6 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java +++ 
b/isthmus/src/test/java/io/substrait/isthmus/RelExtensionRoundtripTest.java @@ -83,7 +83,9 @@ void roundtrip(Rel pojo1) { Context.newContext()); // Calcite -> Substrait POJO 3 - Rel pojo3 = (new CustomSubstraitRelVisitor(typeFactory, extensions)).apply(calcite); + Rel pojo3 = + (new CustomSubstraitRelVisitor(new ConverterProvider(typeFactory, extensions))) + .apply(calcite); assertEquals(pojo1, pojo3); } @@ -248,9 +250,8 @@ public RelNode visit(ExtensionMulti extensionMulti, Context context) throws Runt /** Extends the standard {@link SubstraitRelVisitor} to handle the {@link ColumnAppenderRel} */ static class CustomSubstraitRelVisitor extends SubstraitRelVisitor { - public CustomSubstraitRelVisitor( - RelDataTypeFactory typeFactory, SimpleExtension.ExtensionCollection extensions) { - super(typeFactory, extensions); + public CustomSubstraitRelVisitor(ConverterProvider converterProvider) { + super(converterProvider); } @Override diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java index d2ffd60ef..9f0c00867 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitExpressionConverterTest.java @@ -74,7 +74,6 @@ void scalarSubQuery() { Project query = b.project(input -> List.of(expr), b.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); @@ -95,7 +94,6 @@ void existsSetPredicate() { Project query = b.project(input -> List.of(expr), b.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); @@ -116,7 +114,6 @@ void uniqueSetPredicate() { 
Project query = b.project(input -> List.of(expr), b.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); RelNode calciteRel = substraitToCalcite.convert(query); assertInstanceOf(LogicalProject.class, calciteRel); @@ -137,7 +134,6 @@ void unspecifiedSetPredicate() { Project query = b.project(input -> List.of(expr), b.emptyScan()); - SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); Exception exception = assertThrows( UnsupportedOperationException.class, diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java index e9cc9e02a..bfe7c96d5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitRelNodeConverterTest.java @@ -31,8 +31,6 @@ class SubstraitRelNodeConverterTest extends PlanTestBase { final Rel commonTable = b.namedScan(List.of("example"), List.of("a", "b", "c", "d"), commonTableType); - final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); - @Nested class Aggregate { @Test @@ -44,7 +42,7 @@ void direct() { input -> List.of(b.count(input, 0)), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING, R.I64); } @@ -58,7 +56,7 @@ void emit() { b.remap(1, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, R.I64); } } @@ -69,7 +67,7 @@ class Cross { void direct() { Plan.Root root = b.root(b.cross(commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); 
assertRowMatch(relNode.getRowType(), commonTableTypeTwice); } @@ -77,7 +75,7 @@ void direct() { void emit() { Plan.Root root = b.root(b.cross(commonTable, commonTable, b.remap(0, 1, 4, 6))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, R.FP32, R.I32, N.STRING); } } @@ -88,7 +86,7 @@ class Fetch { void direct() { Plan.Root root = b.root(b.fetch(20, 40, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @@ -96,7 +94,7 @@ void direct() { void emit() { Plan.Root root = b.root(b.fetch(20, 40, b.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -107,7 +105,7 @@ class Filter { void direct() { Plan.Root root = b.root(b.filter(input -> b.bool(true), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @@ -115,7 +113,7 @@ void direct() { void emit() { Plan.Root root = b.root(b.filter(input -> b.bool(true), b.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -126,7 +124,7 @@ class Join { void direct() { Plan.Root root = b.root(b.innerJoin(input -> b.bool(true), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableTypeTwice); } @@ -135,7 +133,7 @@ void emit() { Plan.Root root = b.root(b.innerJoin(input -> b.bool(true), 
b.remap(0, 6), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } @@ -151,7 +149,7 @@ void leftJoin() { b.remap(6, 7, 8), b.join(ji -> b.bool(true), JoinType.LEFT, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.STRING, R.FP64, N.STRING); } @@ -167,7 +165,7 @@ void rightJoin() { b.remap(6, 7, 8), b.join(ji -> b.bool(true), JoinType.RIGHT, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, R.STRING); } @@ -183,7 +181,7 @@ void outerJoin() { b.remap(6, 7, 8), b.join(ji -> b.bool(true), JoinType.OUTER, joinTable, joinTable))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), N.STRING, N.FP64, N.STRING); } } @@ -195,7 +193,7 @@ void direct() { Plan.Root root = b.root(b.namedScan(List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, R.FP32); } @@ -206,7 +204,7 @@ void emit() { b.namedScan( List.of("example"), List.of("a", "b"), List.of(R.I32, R.FP32), b.remap(1))); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.FP32); } } @@ -217,7 +215,7 @@ class Project { void direct() { Plan.Root root = b.root(b.project(input -> b.fieldReferences(input, 1, 0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = 
substraitToCalcite.convert(root.getInput()); assertRowMatch( relNode.getRowType(), R.I32, R.FP32, N.STRING, N.BOOLEAN, R.FP32, R.I32, N.STRING); } @@ -229,7 +227,7 @@ void emit() { b.project( input -> b.fieldReferences(input, 1, 0, 2), b.remap(0, 2, 4, 6), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING, R.FP32, N.STRING); } } @@ -240,7 +238,7 @@ class Set { void direct() { Plan.Root root = b.root(b.set(SetOp.UNION_ALL, commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @@ -248,7 +246,7 @@ void direct() { void emit() { Plan.Root root = b.root(b.set(SetOp.UNION_ALL, b.remap(0, 2), commonTable, commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -259,7 +257,7 @@ class Sort { void direct() { Plan.Root root = b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), commonTableType); } @@ -268,7 +266,7 @@ void emit() { Plan.Root root = b.root(b.sort(input -> b.sortFields(input, 0, 1, 2), b.remap(0, 2), commonTable)); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32, N.STRING); } } @@ -284,7 +282,7 @@ void direct() { .build(); Plan.Root root = b.root(emptyScan); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), List.of(R.I32, N.STRING)); } @@ -297,7 
+295,7 @@ void emit() { .build(); Plan.Root root = b.root(emptyScanWithRemap); - RelNode relNode = converter.convert(root.getInput()); + RelNode relNode = substraitToCalcite.convert(root.getInput()); assertRowMatch(relNode.getRowType(), R.I32); } } diff --git a/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java index d5d8ada75..4488bbae7 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/SubstraitToCalciteTest.java @@ -13,7 +13,6 @@ import org.junit.jupiter.api.Test; class SubstraitToCalciteTest extends PlanTestBase { - final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); @Test void testConvertRootSingleColumn() { @@ -24,7 +23,7 @@ void testConvertRootSingleColumn() { .addNames("store") .build(); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); assertEquals(root.getNames(), relRoot.fields.rightList()); } @@ -38,7 +37,7 @@ void testConvertRootMultipleColumns() { .addNames("s_store_id", "store") .build(); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); assertEquals(root.getNames(), relRoot.fields.rightList()); } @@ -58,7 +57,7 @@ void testConvertRootStructField() { assertEquals(List.of("store", "store_id", "store_name"), root.getNames()); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -84,7 +83,7 @@ void testConvertRootArrayWithStructField() { .addNames("store", "store_id", "store_name") .build(); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names 
assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -114,7 +113,7 @@ void testConvertRootMapWithStructValues() { .addNames("store", "store_id", "store_name") .build(); - final RelRoot relRoot = converter.convert(root); + final RelRoot relRoot = substraitToCalcite.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); @@ -144,7 +143,7 @@ void testConvertRootMapWithStructKeys() { .addNames("store", "store_id", "store_name") .build(); - RelRoot relRoot = converter.convert(root); + RelRoot relRoot = substraitToCalcite.convert(root); // Apache Calcite's RelRoot.fields only contains the top level field names assertEquals(List.of("store"), relRoot.fields.rightList()); diff --git a/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java index 69b8be3b9..3bab99bd5 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/UdfSqlSubstraitTest.java @@ -13,6 +13,7 @@ class UdfSqlSubstraitTest extends PlanTestBase { UdfSqlSubstraitTest() { super(loadExtensions(List.of(CUSTOM_FUNCTION_PATH))); + this.converterProvider = new DynamicConverterProvider(typeFactory, extensions); } @Test @@ -22,16 +23,14 @@ void customUdfTest() throws Exception { SubstraitCreateStatementParser.processCreateStatementsToCatalog( "CREATE TABLE t(x VARCHAR NOT NULL)"); - FeatureBoard featureBoard = ImmutableFeatureBoard.builder().allowDynamicUdfs(true).build(); - assertSqlSubstraitRelRoundTripLoosePojoComparison( - "SELECT regexp_extract_custom(x, 'ab') from t", catalogReader, featureBoard); + "SELECT regexp_extract_custom(x, 'ab') from t", catalogReader); assertSqlSubstraitRelRoundTripLoosePojoComparison( - "SELECT format_text('UPPER', x) FROM t", catalogReader, featureBoard); + "SELECT format_text('UPPER', x) FROM t", catalogReader); 
assertSqlSubstraitRelRoundTripLoosePojoComparison( - "SELECT system_property_get(x) FROM t", catalogReader, featureBoard); + "SELECT system_property_get(x) FROM t", catalogReader); assertSqlSubstraitRelRoundTripLoosePojoComparison( - "SELECT safe_divide_custom(10,0) FROM t", catalogReader, featureBoard); + "SELECT safe_divide_custom(10,0) FROM t", catalogReader); } private static SimpleExtension.ExtensionCollection loadExtensions( diff --git a/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java b/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java index 7d0b26fd8..a3c5db18c 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/VirtualTableScanTest.java @@ -20,7 +20,6 @@ class VirtualTableScanTest extends PlanTestBase { final SubstraitBuilder b = new SubstraitBuilder(extensions); - final SubstraitToCalcite substraitToCalcite = new SubstraitToCalcite(extensions, typeFactory); @Test void literalOnlyVirtualTable() { diff --git a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java index ea82e0a8d..a25912538 100644 --- a/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java +++ b/isthmus/src/test/java/io/substrait/isthmus/expression/SubqueryConversionTest.java @@ -4,7 +4,6 @@ import io.substrait.expression.FieldReference; import io.substrait.isthmus.PlanTestBase; -import io.substrait.isthmus.SubstraitToCalcite; import io.substrait.isthmus.sql.SubstraitSqlDialect; import io.substrait.relation.NamedScan; import io.substrait.relation.Rel; @@ -15,7 +14,6 @@ import org.junit.jupiter.api.Test; class SubqueryConversionTest extends PlanTestBase { - protected final SubstraitToCalcite converter = new SubstraitToCalcite(extensions, typeFactory); private final Rel customerTableScan = substraitBuilder.namedScan( @@ -68,7 +66,7 @@ void 
testOuterFieldReferenceOneStep() { Remap.of(List.of(2, 3)), orderTableScan); - final RelNode calciteRel = converter.convert(root); + final RelNode calciteRel = substraitToCalcite.convert(root); // LogicalFilter has field reference with $cor0 correlation variable // outer LogicalProject has variablesSet containing $cor0 correlation variable @@ -147,7 +145,7 @@ void testOuterFieldReferenceTwoSteps() { Remap.of(List.of(2, 3)), orderTableScan); - final RelNode calciteRel = converter.convert(root); + final RelNode calciteRel = substraitToCalcite.convert(root); // most inner LogicalFilter has field reference with $cor0 correlation variable // most outer LogicalProject has variablesSet containing $cor0 correlation variable @@ -225,7 +223,7 @@ void testInPredicateOuterFieldReference() { Remap.of(List.of(2, 3)), orderTableScan); - final RelNode calciteRel = converter.convert(root); + final RelNode calciteRel = substraitToCalcite.convert(root); // most inner LogicalFilter has field reference with $cor0 correlation variable // most outer LogicalProject has variablesSet containing $cor0 correlation variable @@ -318,7 +316,7 @@ void testSetPredicateOuterFieldReference() { Remap.of(List.of(2, 3)), orderTableScan); - final RelNode calciteRel = converter.convert(root); + final RelNode calciteRel = substraitToCalcite.convert(root); // most inner LogicalFilter has field references with $cor0 and $cor1 correlation variables // most outer LogicalProject has variablesSet containing $cor0 correlation variable