diff --git a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
index 6cba80781..0d5d5bf0e 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/AggregateFunctions.java
@@ -11,31 +11,35 @@
import org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction;
import org.apache.calcite.sql.type.ReturnTypes;
+/**
+ * Provides Substrait-specific variants of Calcite aggregate functions to ensure type inference
+ * matches Substrait expectations.
+ *
+ * <p>Default Calcite implementations may infer return types that differ from Substrait, causing
+ * conversion issues. This class overrides those behaviors.
+ */
public class AggregateFunctions {
- // For some arithmetic aggregate functions, the default Calcite aggregate function implementations
- // will infer return types that differ from those expected by Substrait.
- // This type mismatch can cause conversion and planning failures.
-
+ /** Substrait-specific MIN aggregate function (nullable return type). */
public static SqlAggFunction MIN = new SubstraitSqlMinMaxAggFunction(SqlKind.MIN);
+
+ /** Substrait-specific MAX aggregate function (nullable return type). */
public static SqlAggFunction MAX = new SubstraitSqlMinMaxAggFunction(SqlKind.MAX);
+
+ /** Substrait-specific AVG aggregate function (nullable return type). */
public static SqlAggFunction AVG = new SubstraitAvgAggFunction(SqlKind.AVG);
+
+ /** Substrait-specific SUM aggregate function (nullable return type). */
public static SqlAggFunction SUM = new SubstraitSumAggFunction();
+
+ /** Substrait-specific SUM0 aggregate function (non-null BIGINT return type). */
public static SqlAggFunction SUM0 = new SubstraitSumEmptyIsZeroAggFunction();
/**
- * Some Calcite rules, like {@link
- * org.apache.calcite.rel.rules.AggregateExpandDistinctAggregatesRule}, introduce the default
- * Calcite aggregate functions into plans.
- *
- * <p>When converting these Calcite plans to Substrait, we need to convert the default Calcite
- * aggregate calls to the Substrait specific variants.
- *
- * <p>This function attempts to convert the given {@code aggFunction} to its Substrait equivalent
+ * Converts default Calcite aggregate functions to Substrait-specific variants when needed.
*
- * @param aggFunction the {@link SqlAggFunction} to convert to a Substrait specific variant
- * @return an optional containing the Substrait equivalent of the given {@code aggFunction} if
- * conversion was needed, empty otherwise.
+ * @param aggFunction the Calcite aggregate function
+ * @return optional containing Substrait equivalent if conversion applies
*/
public static Optional toSubstraitAggVariant(SqlAggFunction aggFunction) {
if (aggFunction instanceof SqlMinMaxAggFunction) {
@@ -53,7 +57,7 @@ public static Optional toSubstraitAggVariant(SqlAggFunction aggF
}
}
- /** Extension of {@link SqlMinMaxAggFunction} that ALWAYS infers a nullable return type */
+ /** Substrait variant of {@link SqlMinMaxAggFunction} that forces nullable return type. */
private static class SubstraitSqlMinMaxAggFunction extends SqlMinMaxAggFunction {
public SubstraitSqlMinMaxAggFunction(SqlKind kind) {
super(kind);
@@ -65,12 +69,10 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
}
}
- /** Extension of {@link SqlSumAggFunction} that ALWAYS infers a nullable return type */
+ /** Substrait variant of {@link SqlSumAggFunction} that forces nullable return type. */
private static class SubstraitSumAggFunction extends SqlSumAggFunction {
public SubstraitSumAggFunction() {
- // This is intentionally null
- // See the instantiation of SqlSumAggFunction in SqlStdOperatorTable
- super(null);
+ super(null); // Matches Calcite's instantiation pattern
}
@Override
@@ -79,7 +81,7 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
}
}
- /** Extension of {@link SqlAvgAggFunction} that ALWAYS infers a nullable return type */
+ /** Substrait variant of {@link SqlAvgAggFunction} that forces nullable return type. */
private static class SubstraitAvgAggFunction extends SqlAvgAggFunction {
public SubstraitAvgAggFunction(SqlKind kind) {
super(kind);
@@ -92,8 +94,8 @@ public RelDataType inferReturnType(SqlOperatorBinding opBinding) {
}
/**
- * Extension of {@link SqlSumEmptyIsZeroAggFunction} that ALWAYS infers a NOT NULL BIGINT return
- * type
+ * Substrait variant of {@link SqlSumEmptyIsZeroAggFunction} that forces a NOT NULL BIGINT return
+ * type and reports its name as {@code sum0} instead of the default {@code $sum0}.
*/
private static class SubstraitSumEmptyIsZeroAggFunction
extends org.apache.calcite.sql.fun.SqlSumEmptyIsZeroAggFunction {
@@ -103,8 +105,7 @@ public SubstraitSumEmptyIsZeroAggFunction() {
@Override
public String getName() {
- // the default name for this function is `$sum0`
- // override this to `sum0` which is a nicer name to use in queries
+ // Override default `$sum0` with `sum0` for readability
return "sum0";
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/CallConverter.java b/isthmus/src/main/java/io/substrait/isthmus/CallConverter.java
index 8d68ef612..bc1f465c5 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/CallConverter.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/CallConverter.java
@@ -6,7 +6,22 @@
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexNode;
+/**
+ * Functional interface for converting Calcite {@link RexCall} expressions into Substrait {@link
+ * Expression}s.
+ *
+ * <p>Implementations should return an {@link Optional} containing the converted expression, or
+ * {@link Optional#empty()} if the call is not handled.
+ */
@FunctionalInterface
public interface CallConverter {
+
+ /**
+ * Converts a Calcite {@link RexCall} into a Substrait {@link Expression}.
+ *
+ * @param call the Calcite function/operator call to convert
+ * @param topLevelConverter a function for converting nested {@link RexNode} operands
+ * @return an {@link Optional} containing the converted expression, or empty if not applicable
+ */
Optional convert(RexCall call, Function topLevelConverter);
}
diff --git a/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java
index 377020bb3..ba273f0a6 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/ExtensionUtils.java
@@ -7,6 +7,12 @@
import java.util.Set;
import java.util.stream.Collectors;
+/**
+ * Utility methods for working with Substrait extensions.
+ *
+ * <p>Provides helpers to identify and extract dynamic (custom/user-defined) functions from an
+ * {@link io.substrait.extension.SimpleExtension.ExtensionCollection}.
+ */
public class ExtensionUtils {
/**
diff --git a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java
index eeb645175..64e67e878 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/OuterReferenceResolver.java
@@ -15,8 +15,20 @@
import org.apache.calcite.rex.RexSubQuery;
import org.apache.calcite.rex.RexUtil.SubQueryCollector;
-/** Resolve correlated variable and get Depth map for RexFieldAccess */
-// See OuterReferenceResolver.md for explanation how the Depth map is computed.
+/**
+ * Resolve correlated variables and compute a depth map for {@link RexFieldAccess}.
+ *
+ * <p>Traverses a {@link RelNode} tree and:
+ *
+ * <ul>
+ *   <li>Tracks nesting depth of {@link CorrelationId}s across filters, projects, subqueries, and
+ *       correlates
+ *   <li>Computes "steps out" for each {@link RexFieldAccess} referencing a {@link
+ *       RexCorrelVariable}
+ * </ul>
+ *
+ * <p>See OuterReferenceResolver.md for details on how the depth map is computed.
+ */
public class OuterReferenceResolver extends RelNodeVisitor {
private final Map nestedDepth;
@@ -24,23 +36,50 @@ public class OuterReferenceResolver extends RelNodeVisitor();
fieldAccessDepthMap = new IdentityHashMap<>();
}
+ /**
+ * Returns the number of "steps out" (nesting depth) for a given {@link RexFieldAccess}.
+ *
+ * @param fieldAccess the field access referencing a {@link RexCorrelVariable}
+ * @return the number of outer scopes between the access and its correlation source
+ * @throws NullPointerException if no depth was recorded for {@code fieldAccess}
+ */
public int getStepsOut(RexFieldAccess fieldAccess) {
return fieldAccessDepthMap.get(fieldAccess);
}
+ /**
+ * Applies the resolver to a {@link RelNode} tree, computing the depth map.
+ *
+ * @param r the root relational node
+ * @return the same node after traversal
+ * @throws RuntimeException if the visitor encounters an unrecoverable condition
+ */
public RelNode apply(RelNode r) {
return reverseAccept(r);
}
+ /**
+ * Returns the computed map from {@link RexFieldAccess} to depth (steps out).
+ *
+ * @return map of field access to depth
+ */
public Map getFieldAccessDepthMap() {
return fieldAccessDepthMap;
}
+ /**
+ * Visits a {@link Filter}, registering any correlation variables and visiting its condition.
+ *
+ * @param filter the filter node
+ * @return the result of {@link RelNodeVisitor#visit(Filter)}
+ * @throws RuntimeException if traversal fails
+ */
@Override
public RelNode visit(Filter filter) throws RuntimeException {
for (CorrelationId id : filter.getVariablesSet()) {
@@ -50,6 +89,16 @@ public RelNode visit(Filter filter) throws RuntimeException {
return super.visit(filter);
}
+ /**
+ * Visits a {@link Correlate}, handling correlation depth for both sides.
+ *
+ * <p>Special case: the right side is a correlated subquery in the rel tree (not a REX), so we
+ * manually adjust depth before/after visiting it.
+ *
+ * @param correlate the correlate (correlated join) node
+ * @return the correlate node
+ * @throws RuntimeException if traversal fails
+ */
@Override
public RelNode visit(Correlate correlate) throws RuntimeException {
for (CorrelationId id : correlate.getVariablesSet()) {
@@ -70,6 +119,13 @@ public RelNode visit(Correlate correlate) throws RuntimeException {
return correlate;
}
+ /**
+ * Visits a generic {@link RelNode}, applying traversal to all inputs.
+ *
+ * @param other the node to visit
+ * @return the node
+ * @throws RuntimeException if traversal fails
+ */
@Override
public RelNode visitOther(RelNode other) throws RuntimeException {
for (RelNode child : other.getInputs()) {
@@ -78,6 +134,14 @@ public RelNode visitOther(RelNode other) throws RuntimeException {
return other;
}
+ /**
+ * Visits a {@link Project}, registering correlation variables and visiting any subqueries within
+ * its expressions.
+ *
+ * @param project the project node
+ * @return the result of {@link RelNodeVisitor#visit(Project)}
+ * @throws RuntimeException if traversal fails
+ */
@Override
public RelNode visit(Project project) throws RuntimeException {
for (CorrelationId id : project.getVariablesSet()) {
@@ -91,13 +155,25 @@ public RelNode visit(Project project) throws RuntimeException {
return super.visit(project);
}
+ /** Rex visitor used to track correlation depth within expressions and subqueries. */
private static class RexVisitor extends RexShuttle {
final OuterReferenceResolver referenceResolver;
+ /**
+ * Creates a new Rex visitor bound to the given reference resolver.
+ *
+ * @param referenceResolver the parent resolver maintaining depth maps
+ */
RexVisitor(OuterReferenceResolver referenceResolver) {
this.referenceResolver = referenceResolver;
}
+ /**
+ * Increments correlation depth when entering a subquery and decrements when exiting.
+ *
+ * @param subQuery the subquery expression
+ * @return the same subquery
+ */
@Override
public RexNode visitSubQuery(RexSubQuery subQuery) {
referenceResolver.nestedDepth.replaceAll((k, v) -> v + 1);
@@ -108,6 +184,12 @@ public RexNode visitSubQuery(RexSubQuery subQuery) {
return subQuery;
}
+ /**
+ * Records depth for {@link RexFieldAccess} referencing a {@link RexCorrelVariable}.
+ *
+ * @param fieldAccess the field access expression
+ * @return the same field access
+ */
@Override
public RexNode visitFieldAccess(RexFieldAccess fieldAccess) {
if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
diff --git a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java
index f2419ab01..80d17d1c0 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/PreCalciteAggregateValidator.java
@@ -12,23 +12,23 @@
import java.util.stream.Collectors;
/**
- * Not all Substrait {@link Aggregate} rels are convertable to {@link
- * org.apache.calcite.rel.core.Aggregate} rels
+ * Validates and rewrites Substrait {@link Aggregate} relations for compatibility with Calcite
+ * {@link org.apache.calcite.rel.core.Aggregate}.
*
- * <p>The code in this class can:
+ * <p>Responsibilities:
*
* <ul>
- *   <li>Check for these cases
- *   <li>Rewrite the Substrait {@link Aggregate} such that it can be converted to Calcite
+ *   <li>Check if an {@link Aggregate} can be converted directly to Calcite
+ *   <li>Rewrite invalid aggregates into a form acceptable to Calcite
* </ul>
*/
public class PreCalciteAggregateValidator {
/**
- * Checks that the given {@link Aggregate} is valid for use in Calcite
+ * Checks whether the given {@link Aggregate} is valid for Calcite conversion.
*
- * @param aggregate
- * @return
+ * @param aggregate the Substrait aggregate relation
+ * @return {@code true} if valid for Calcite, {@code false} otherwise
*/
public static boolean isValidCalciteAggregate(Aggregate aggregate) {
return aggregate.getMeasures().stream()
@@ -38,12 +38,11 @@ public static boolean isValidCalciteAggregate(Aggregate aggregate) {
}
/**
- * Checks that all expressions present in the given {@link Aggregate.Measure} are {@link
- * FieldReference}s, as Calcite expects all expressions in {@link
- * org.apache.calcite.rel.core.Aggregate}s to be field references.
+ * Checks if an {@link Aggregate.Measure} uses only {@link FieldReference}s for arguments, sort
+ * fields, and pre-measure filter.
*
- * @return true if the {@code measure} can be converted to a Calcite equivalent without changes,
- * false otherwise.
+ * @param measure the aggregate measure to validate
+ * @return {@code true} if valid, {@code false} otherwise
*/
private static boolean isValidCalciteMeasure(Aggregate.Measure measure) {
return
@@ -58,32 +57,19 @@ private static boolean isValidCalciteMeasure(Aggregate.Measure measure) {
}
/**
- * Checks that all expressions present in the given {@link Aggregate.Grouping} are {@link
- * FieldReference}s, as Calcite expects all expressions in {@link
- * org.apache.calcite.rel.core.Aggregate}s to be field references.
+ * Checks if an {@link Aggregate.Grouping} uses only {@link FieldReference}s and ensures grouping
+ * fields are in ascending order.
*
- * Additionally, checks that all grouping fields are specified in ascending order.
- *
- * @return true if the {@code grouping} can be converted to a Calcite equivalent without changes,
- * false otherwise.
+ * @param grouping the aggregate grouping to validate
+ * @return {@code true} if valid, {@code false} otherwise
*/
private static boolean isValidCalciteGrouping(Aggregate.Grouping grouping) {
if (!grouping.getExpressions().stream().allMatch(e -> isSimpleFieldReference(e))) {
- // all grouping expressions must be field references
return false;
}
- // Calcite stores grouping fields in an ImmutableBitSet and does not track the order of the
- // grouping fields. The output record shape that Calcite generates ALWAYS has the groupings in
- // ascending field order. This causes issues with Substrait in cases where the grouping fields
- // in Substrait are not defined in ascending order.
-
- // For example, if a grouping is defined as (0, 2, 1) in Substrait, Calcite will output it as
- // (0, 1, 2), which means that the Calcite output will no longer line up with the expectations
- // of the Substrait plan.
List groupingFields =
grouping.getExpressions().stream()
- // isSimpleFieldReference above guarantees that the expr is a FieldReference
.map(expr -> getFieldRefOffset((FieldReference) expr))
.collect(Collectors.toList());
@@ -112,6 +98,10 @@ private static boolean isOrdered(List list) {
return true;
}
+ /**
+ * Transforms invalid aggregates into Calcite-compatible form by projecting non-field expressions
+ * and reordering groupings.
+ */
public static class PreCalciteAggregateTransformer {
// New expressions to include in the project before the aggregate
@@ -122,18 +112,19 @@ public static class PreCalciteAggregateTransformer {
private PreCalciteAggregateTransformer(Aggregate aggregate) {
this.newExpressions = new ArrayList<>();
- // The Substrait project output includes all input fields, followed by expressions
this.expressionOffset = aggregate.getInput().getRecordType().fields().size();
}
/**
- * Transforms an {@link Aggregate} that cannot be handled by Calcite into an equivalent that can
- * be handled by:
+ * Rewrites an {@link Aggregate} so that it can be converted to Calcite by:
*
* <ul>
- *   <li>Moving all non-field references into a project before the aggregation
- *   <li>Adding all groupings to this project so that they are referenced in "order"
+ *   <li>Projecting non-field references before aggregation
+ *   <li>Ensuring groupings are in ascending order
* </ul>
+ *
+ * @param aggregate the original Substrait aggregate
+ * @return a transformed Calcite-compatible aggregate
*/
public static Aggregate transformToValidCalciteAggregate(Aggregate aggregate) {
PreCalciteAggregateTransformer at = new PreCalciteAggregateTransformer(aggregate);
@@ -189,8 +180,6 @@ private Aggregate.Measure updateMeasure(Aggregate.Measure measure) {
}
private Aggregate.Grouping updateGrouping(Aggregate.Grouping grouping) {
- // project out all groupings unconditionally, even field references
- // this ensures that out of order groupings are re-projected into in order groupings
List newGroupingExpressions =
grouping.getExpressions().stream().map(this::projectOut).collect(Collectors.toList());
return Aggregate.Grouping.builder().expressions(newGroupingExpressions).build();
@@ -212,14 +201,15 @@ private Expression projectOutNonFieldReference(Expression expr) {
}
/**
- * Adds a new expression to the project at {@link
- * PreCalciteAggregateTransformer#expressionOffset} and returns a field reference to the new
- * expression
+ * Adds a new expression to the pre-aggregate project and returns a field reference pointing to
+ * it.
+ *
+ * @param expr the expression to project out
+ * @return a {@link FieldReference} to the projected expression
*/
private Expression projectOut(Expression expr) {
newExpressions.add(expr);
return FieldReference.builder()
- // create a field reference to the new expression, then update the expression offset
.addSegments(FieldReference.StructField.of(expressionOffset++))
.type(expr.getType())
.build();
diff --git a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java
index 81c4e9a49..9a45ee40c 100644
--- a/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java
+++ b/isthmus/src/main/java/io/substrait/isthmus/RelNodeVisitor.java
@@ -18,80 +18,215 @@
import org.apache.calcite.rel.core.Union;
import org.apache.calcite.rel.core.Values;
-/** A more generic version of RelShuttle that allows an alternative return value. */
+/**
+ * A generic visitor for {@link RelNode} trees that supports custom return types and checked
+ * exceptions.
+ *
+ * <p>It provides type-safe methods for common Calcite relational operators and a fallback for
+ * unhandled types. It is useful when implementing transformations or analysis logic over relational
+ * expressions without extending Calcite's built-in visitor classes.
+ *
+ * @param