Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 26 additions & 25 deletions isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,31 +11,35 @@
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.sql.type.ReturnTypes;

/**
* Provides Substrait-specific variants of Calcite aggregate functions to ensure type inference
* matches Substrait expectations.
*
* <p>Default Calcite implementations may infer return types that differ from Substrait, causing
* conversion issues. This class overrides those behaviors.
*/
public class AggregateFunctions {

// For some arithmetic aggregate functions, the default Calcite aggregate function implementations
// will infer return types that differ from those expected by Substrait.
// This type mismatch can cause conversion and planning failures.

/** Substrait-specific MIN aggregate function (nullable return type). */
public static final SqlAggFunction MIN = new SubstraitSqlMinMaxAggFunction(SqlKind.MIN);

/** Substrait-specific MAX aggregate function (nullable return type). */
public static final SqlAggFunction MAX = new SubstraitSqlMinMaxAggFunction(SqlKind.MAX);

/** Substrait-specific AVG aggregate function (nullable return type). */
public static final SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG);

/** Substrait-specific SUM aggregate function (nullable return type). */
public static final SqlAggFunction SUM = new SubstraitSumAggFunction();

/** Substrait-specific SUM0 aggregate function (non-null BIGINT return type). */
public static final SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction();

/**
* Converts default Calcite aggregate functions to Substrait-specific variants when needed.
*
* <p>Some Calcite rules, like {@link
* org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}, introduce the default
* Calcite aggregate functions into plans. When converting these Calcite plans to Substrait, we
* need to convert the default Calcite aggregate calls to the Substrait specific variants.
*
* @param aggFunction the Calcite {@link SqlAggFunction} to convert to a Substrait specific variant
* @return an optional containing the Substrait equivalent of the given {@code aggFunction} if
* conversion was needed, empty otherwise
*/
public static Optional<SqlAggFunction> toSubstraitAggVariant(SqlAggFunction aggFunction) {
if (aggFunction instanceof SqlMinMaxAggFunction) {
Expand All @@ -53,7 +57,7 @@ public static Optional<SqlAggFunction> toSubstraitAggVariant(SqlAggFunction aggF
}
}

/** Substrait variant of {@link SqlMinMaxAggFunction} that ALWAYS infers a nullable return type. */
private static class SubstraitSqlMinMaxAggFunction extends SqlMinMaxAggFunction {
public SubstraitSqlMinMaxAggFunction(SqlKind kind) {
super(kind);
Expand All @@ -65,12 +69,10 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
}
}

/** Substrait variant of {@link SqlSumAggFunction} that ALWAYS infers a nullable return type. */
private static class SubstraitSumAggFunction extends SqlSumAggFunction {
/** Creates the Substrait SUM variant. */
public SubstraitSumAggFunction() {
  // The null argument is intentional: it matches the instantiation of SqlSumAggFunction in
  // SqlStdOperatorTable. A constructor may invoke super(...) only once, as its first statement.
  super(null);
}

@Override
Expand All @@ -79,7 +81,7 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
}
}

/** Substrait variant of {@link SqlAvgAggFunction} that ALWAYS infers a nullable return type. */
private static class SubstraitAvgAggFunction extends SqlAvgAggFunction {
public SubstraitAvgAggFunction(SqlKind kind) {
super(kind);
Expand All @@ -92,8 +94,8 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
}

/**
* Substrait variant of {@link SqlSumEmptyIsZeroAggFunction} that ALWAYS infers a NOT NULL BIGINT
* return type and uses a user-friendly name.
*/
private static class SubstraitSumEmptyIsZeroAggFunction
extends org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction {
Expand All @@ -103,8 +105,7 @@ public SubstraitSumEmptyIsZeroAggFunction() {

@Override
public String getName() {
// The default name for this function is `$sum0`; override it with `sum0`, which is a nicer,
// more readable name to use in queries.
return "sum0";
}

Expand Down
15 changes: 15 additions & 0 deletions isthmus/src/main/java/io/substrait/isthmus/CallConverter.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,22 @@
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;

/**
* Functional interface for converting Calcite {@link RexCall} expressions into Substrait {@link
* Expression}s.
*
* <p>Implementations return an {@link Optional} containing the converted expression, or {@link
* Optional#empty()} when they do not handle the given call. Because the interface declares a
* single abstract method, it can be implemented with a lambda or a method reference.
*/
@FunctionalInterface
public interface CallConverter {

/**
* Converts a Calcite {@link RexCall} into a Substrait {@link Expression}.
*
* @param call the Calcite function/operator call to convert
* @param topLevelConverter a function for converting nested {@link RexNode} operands into
* Substrait {@link Expression}s
* @return an {@link Optional} containing the converted expression, or an empty {@link Optional}
* if this converter does not apply to {@code call}
*/
Optional<Expression> convert(RexCall call, Function<RexNode, Expression> topLevelConverter);
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
import java.util.Set;
import java.util.stream.Collectors;

/**
* Utility methods for working with Substrait extensions.
*
* <p>Provides helpers to identify and extract dynamic (custom/user-defined) functions from an
* {@link io.substrait.extension.SimpleExtension.ExtensionCollection}.
*/
public class ExtensionUtils {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,71 @@
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil.SubQueryCollector;

/**
* Resolves correlated variables and computes a depth map for {@link RexFieldAccess}.
*
* <p>Traverses a {@link RelNode} tree and:
*
* <ul>
* <li>Tracks nesting depth of {@link CorrelationId}s across filters, projects, subqueries, and
* correlates
* <li>Computes "steps out" for each {@link RexFieldAccess} referencing a {@link
* RexCorrelVariable}
* </ul>
*
* <p>See OuterReferenceResolver.md for details on how the depth map is computed.
*/
public class OuterReferenceResolver extends RelNodeVisitor<RelNode, RuntimeException> {

// Current nesting depth per correlation id; RexVisitor#visitSubQuery increments every entry when
// it enters a subquery.
private final Map<CorrelationId, Integer> nestedDepth;
// "Steps out" per field access. Created as an IdentityHashMap in the constructor, so lookups
// match keys by reference rather than by equals().
private final Map<RexFieldAccess, Integer> fieldAccessDepthMap;

// Expression visitor bound to this resolver instance.
private final RexVisitor rexVisitor = new RexVisitor(this);

/** Constructs a resolver whose correlation-depth and field-access-depth maps start out empty. */
public OuterReferenceResolver() {
  this.nestedDepth = new HashMap<>();
  // Identity-based map: field accesses are keyed by reference rather than by equals().
  this.fieldAccessDepthMap = new IdentityHashMap<>();
}

/**
 * Returns the number of "steps out" (nesting depth) recorded for a given {@link RexFieldAccess}.
 *
 * <p>The backing map is identity-based, so the exact instance that was recorded must be passed.
 * The return type is a primitive {@code int}, so this method never returns {@code null}; a field
 * access without a recorded depth is reported with an exception instead.
 *
 * @param fieldAccess the field access referencing a {@link RexCorrelVariable}
 * @return the number of outer scopes between the access and its correlation source
 * @throws NullPointerException if no depth has been recorded for {@code fieldAccess}
 */
public int getStepsOut(RexFieldAccess fieldAccess) {
  Integer stepsOut = fieldAccessDepthMap.get(fieldAccess);
  if (stepsOut == null) {
    // Keep the exception type that auto-unboxing a missing entry used to raise, but say which
    // field access was not tracked instead of failing with a message-less exception.
    throw new NullPointerException("No depth recorded for field access: " + fieldAccess);
  }
  return stepsOut;
}

/**
* Applies the resolver to a {@link RelNode} tree by delegating to {@code reverseAccept}.
*
* <p>NOTE(review): {@code reverseAccept} is not defined in this class; presumably it dispatches
* {@code r} to the matching {@code visit} overload, which fills the depth map as a side effect.
* Confirm against {@link RelNodeVisitor}.
*
* @param r the root relational node
* @return whatever {@code reverseAccept(r)} returns
* @throws RuntimeException if a visit method fails during traversal
*/
public RelNode apply(RelNode r) {
return reverseAccept(r);
}

/**
* Returns the computed map from {@link RexFieldAccess} to depth (steps out).
*
* <p>The internal map itself is returned, not a copy: later traversals are visible through the
* returned reference, and modifications made by the caller affect this resolver.
*
* @return the live map of field access to depth
*/
public Map<RexFieldAccess, Integer> getFieldAccessDepthMap() {
return fieldAccessDepthMap;
}

/**
* Visits a {@link Filter}, registering any correlation variables and visiting its condition.
*
* @param filter the filter node
* @return the result of {@link RelNodeVisitor#visit(Filter)}
* @throws RuntimeException if traversal fails
*/
@Override
public RelNode visit(Filter filter) throws RuntimeException {
for (CorrelationId id : filter.getVariablesSet()) {
Expand All @@ -50,6 +89,16 @@ public RelNode visit(Filter filter) throws RuntimeException {
return super.visit(filter);
}

/**
* Visits a {@link Correlate}, handling correlation depth for both sides.
*
* <p>Special case: the right side is a correlated subquery in the rel tree (not a REX), so we
* manually adjust depth before/after visiting it.
*
* @param correlate the correlate (correlated join) node
* @return the correlate node
* @throws RuntimeException if traversal fails
*/
@Override
public RelNode visit(Correlate correlate) throws RuntimeException {
for (CorrelationId id : correlate.getVariablesSet()) {
Expand All @@ -70,6 +119,13 @@ public RelNode visit(Correlate correlate) throws RuntimeException {
return correlate;
}

/**
* Visits a generic {@link RelNode}, applying traversal to all inputs.
*
* @param other the node to visit
* @return the node
* @throws RuntimeException if traversal fails
*/
@Override
public RelNode visitOther(RelNode other) throws RuntimeException {
for (RelNode child : other.getInputs()) {
Expand All @@ -78,6 +134,14 @@ public RelNode visitOther(RelNode other) throws RuntimeException {
return other;
}

/**
* Visits a {@link Project}, registering correlation variables and visiting any subqueries within
* its expressions.
*
* @param project the project node
* @return the result of {@link RelNodeVisitor#visit(Project)}
* @throws RuntimeException if traversal fails
*/
@Override
public RelNode visit(Project project) throws RuntimeException {
for (CorrelationId id : project.getVariablesSet()) {
Expand All @@ -91,13 +155,25 @@ public RelNode visit(Project project) throws RuntimeException {
return super.visit(project);
}

/** Rex visitor used to track correlation depth within expressions and subqueries. */
private static class RexVisitor extends RexShuttle {
final OuterReferenceResolver referenceResolver;

/**
* Creates a new Rex visitor bound to the given reference resolver.
*
* @param referenceResolver the owning resolver whose depth maps this visitor works on (for
* example, {@code visitSubQuery} updates its {@code nestedDepth} map)
*/
RexVisitor(OuterReferenceResolver referenceResolver) {
this.referenceResolver = referenceResolver;
}

/**
* Increments correlation depth when entering a subquery and decrements when exiting.
*
* @param subQuery the subquery expression
* @return the same subquery
*/
@Override
public RexNode visitSubQuery(RexSubQuery subQuery) {
referenceResolver.nestedDepth.replaceAll((k, v) -> v + 1);
Expand All @@ -108,6 +184,12 @@ public RexNode visitSubQuery(RexSubQuery subQuery) {
return subQuery;
}

/**
* Records depth for {@link RexFieldAccess} referencing a {@link RexCorrelVariable}.
*
* @param fieldAccess the field access expression
* @return the same field access
*/
@Override
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
Expand Down
Loading