diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 89cf2b21585e7d..1ee4cae40f27cf 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -2497,36 +2497,37 @@ public PlanFragment visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, P
         PlanFragment inputPlanFragment = repeat.child(0).accept(this, context);
         List<List<Expr>> distributeExprLists = getDistributeExprs(repeat.child(0));
 
-        ImmutableSet<Expression> flattenGroupingSetExprs = ImmutableSet.copyOf(
-                ExpressionUtils.flatExpressions(repeat.getGroupingSets()));
+        List<Expression> flattenGroupingExpressions = repeat.getGroupByExpressions();
+        Set<Slot> preRepeatExpressions = Sets.newLinkedHashSet();
+        // keep group by expression coming first
+        for (Expression groupByExpr : flattenGroupingExpressions) {
+            // NormalizeRepeat had converted group by expression to slot
+            preRepeatExpressions.add((Slot) groupByExpr);
+        }
 
-        List<Slot> aggregateFunctionUsedSlots = repeat.getOutputExpressions()
-                .stream()
-                .filter(output -> !flattenGroupingSetExprs.contains(output))
-                .filter(output -> !output.containsType(GroupingScalarFunction.class))
-                .distinct()
-                .map(NamedExpression::toSlot)
+        // add aggregate function used expressions
+        for (NamedExpression outputExpr : repeat.getOutputExpressions()) {
+            if (!outputExpr.containsType(GroupingScalarFunction.class)) {
+                preRepeatExpressions.add(outputExpr.toSlot());
+            }
+        }
+
+        List<Expr> preRepeatExprs = preRepeatExpressions.stream()
+                .map(expr -> ExpressionTranslator.translate(expr, context))
                 .collect(ImmutableList.toImmutableList());
 
-        // keep flattenGroupingSetExprs comes first
-        List<Expr> preRepeatExprs = Stream.concat(flattenGroupingSetExprs.stream(), aggregateFunctionUsedSlots.stream())
-                .map(expr -> ExpressionTranslator.translate(expr, context)).collect(ImmutableList.toImmutableList());
-
-        // outputSlots's order need same with preRepeatExprs
-        List<Slot> outputSlots = Stream.concat(Stream
-                .concat(repeat.getOutputExpressions().stream()
-                                .filter(output -> flattenGroupingSetExprs.contains(output)),
-                        repeat.getOutputExpressions().stream()
-                                .filter(output -> !flattenGroupingSetExprs.contains(output))
-                                .filter(output -> !output.containsType(GroupingScalarFunction.class))
-                                .distinct()
-                ),
-                Stream.concat(Stream.of(repeat.getGroupingId().toSlot()),
-                        repeat.getOutputExpressions().stream()
-                                .filter(output -> output.containsType(GroupingScalarFunction.class)))
-        )
-                .map(NamedExpression::toSlot).collect(ImmutableList.toImmutableList());
+        // outputSlots's order must match preRepeatExprs, then grouping id, then grouping function slots
+        ImmutableList.Builder<Slot> outputSlotsBuilder
+                = ImmutableList.builderWithExpectedSize(repeat.getOutputExpressions().size() + 1);
+        outputSlotsBuilder.addAll(preRepeatExpressions);
+        outputSlotsBuilder.add(repeat.getGroupingId().toSlot());
+        for (NamedExpression outputExpr : repeat.getOutputExpressions()) {
+            if (outputExpr.containsType(GroupingScalarFunction.class)) {
+                outputSlotsBuilder.add(outputExpr.toSlot());
+            }
+        }
+        List<Slot> outputSlots = outputSlotsBuilder.build();
 
         // NOTE: we should first translate preRepeatExprs, then generate output tuple,
         // or else the preRepeatExprs can not find the bottom slotRef and throw
         // exception: invalid slot id
diff --git a/fe/fe-core/src/main/java/org/apache/doris/planner/RepeatNode.java b/fe/fe-core/src/main/java/org/apache/doris/planner/RepeatNode.java
index fdfde662dd2e3c..4289b5b32178c6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/RepeatNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/RepeatNode.java
@@ -107,4 +107,8 @@ public String getNodeExplainString(String detailPrefix, TExplainLevel detailLeve
     public boolean isSerialOperator() {
         return children.get(0).isSerialOperator();
     }
+
+    public GroupingInfo getGroupingInfo() {
+        return groupingInfo;
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java
index 938187fef2b14b..fdf1c726cfc0e1 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslatorTest.java
@@ -17,6 +17,11 @@
 
 package org.apache.doris.nereids.glue.translator;
 
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.GroupingInfo;
+import org.apache.doris.analysis.SlotRef;
+import org.apache.doris.analysis.TupleDescriptor;
+import org.apache.doris.catalog.Column;
 import org.apache.doris.catalog.KeysType;
 import org.apache.doris.catalog.OlapTable;
 import org.apache.doris.nereids.properties.DataTrait;
@@ -34,16 +39,19 @@
 import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
 import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.util.PlanChecker;
 import org.apache.doris.nereids.util.PlanConstructor;
 import org.apache.doris.planner.AggregationNode;
 import org.apache.doris.planner.OlapScanNode;
 import org.apache.doris.planner.PlanFragment;
 import org.apache.doris.planner.PlanNode;
 import org.apache.doris.planner.Planner;
+import org.apache.doris.planner.RepeatNode;
 import org.apache.doris.utframe.TestWithFeService;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Sets;
 import mockit.Injectable;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
@@ -53,9 +61,18 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Optional;
+import java.util.Set;
 
 public class PhysicalPlanTranslatorTest extends TestWithFeService {
 
+    @Override
+    protected void runBeforeAll() throws Exception {
+        createDatabase("test_db");
+        createTable("create table test_db.t(a int, b int) distributed by hash(a) buckets 3 "
+                + "properties('replication_num' = '1');");
+        connectContext.getSessionVariable().setDisableNereidsRules("prune_empty_partition");
+    }
+
     @Test
     public void testOlapPrune(@Injectable LogicalProperties placeHolder) throws Exception {
         OlapTable t1 = PlanConstructor.newOlapTable(0, "t1", 0, KeysType.AGG_KEYS);
@@ -93,10 +110,6 @@ public void testOlapPrune(@Injectable LogicalProperties placeHolder) throws Exce
 
     @Test
     public void testAggNeedsFinalize() throws Exception {
-        createDatabase("test_db");
-        createTable("create table test_db.t(a int, b int) distributed by hash(a) buckets 3 "
-                + "properties('replication_num' = '1');");
-        connectContext.getSessionVariable().setDisableNereidsRules("prune_empty_partition");
         String querySql = "select b from test_db.t group by b";
         Planner planner = getSQLPlanner(querySql);
         Assertions.assertNotNull(planner);
@@ -125,4 +138,31 @@ public void testAggNeedsFinalize() throws Exception {
         Assertions.assertTrue(upperNeedsFinalize,
                 "upper AggregationNode needsFinalize should be true");
     }
+
+    @Test
+    public void testRepeatInputOutputOrder() throws Exception {
+        String sql = "select grouping(a), grouping(b), grouping_id(a, b), sum(a + 2 * b), sum(a + 3 * b) + grouping_id(b, a, b), b, a, b, a"
+                + " from test_db.t"
+                + " group by grouping sets((a, b), (), (b), (a, b), (a + b), (a * b))";
+        PlanChecker.from(connectContext).checkPlannerResult(sql,
+                planner -> {
+                    Set<RepeatNode> repeatNodes = Sets.newHashSet();
+                    planner.getFragments().stream()
+                            .map(PlanFragment::getPlanRoot)
+                            .forEach(plan -> plan.collect(RepeatNode.class, repeatNodes));
+                    Assertions.assertEquals(1, repeatNodes.size());
+                    RepeatNode repeatNode = repeatNodes.iterator().next();
+                    GroupingInfo groupingInfo = repeatNode.getGroupingInfo();
+                    List<Expr> preRepeatExprs = groupingInfo.getPreRepeatExprs();
+                    TupleDescriptor outputs = groupingInfo.getOutputTupleDesc();
+                    for (int i = 0; i < preRepeatExprs.size(); i++) {
+                        Expr inputExpr = preRepeatExprs.get(i);
+                        Assertions.assertInstanceOf(SlotRef.class, inputExpr);
+                        Column inputColumn = ((SlotRef) inputExpr).getColumn();
+                        Column outputColumn = outputs.getSlots().get(i).getColumn();
+                        Assertions.assertEquals(inputColumn, outputColumn);
+                    }
+                }
+        );
+    }
 }
diff --git a/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out b/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out
new file mode 100644
index 00000000000000..c39bd2cf92ca7f
--- /dev/null
+++ b/regression-test/data/nereids_p0/repeat/test_repeat_output_slot.out
@@ -0,0 +1,51 @@
+-- This file is automatically generated. You should know what you did if you want to edit this
+-- !sql_1_shape --
+PhysicalResultSink
+--PhysicalProject
+----hashAgg[GLOBAL]
+------hashAgg[LOCAL]
+--------PhysicalRepeat
+----------PhysicalProject
+------------PhysicalOlapScan[tbl_test_repeat_output_slot]
+
+-- !sql_1_result --
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+100000
+
+-- !sql_2_shape --
+PhysicalResultSink
+--PhysicalProject
+----filter((GROUPING_PREFIX_col_varchar_50__undef_signed__index_inverted_col_datetime_6__undef_signed_col_varchar_50__undef_signed > 0))
+------hashAgg[GLOBAL]
+--------hashAgg[LOCAL]
+----------PhysicalRepeat
+------------PhysicalProject
+--------------PhysicalOlapScan[tbl_test_repeat_output_slot]
+
+-- !sql_2_result --
+\N	ALL	1	6	\N	\N	\N
+\N	ALL	1	6	\N	\N	\N
+2020-01-04T00:00	ALL	1	6	\N	\N	a
+2020-01-04T00:00	ALL	1	6	\N	\N	a
+2020-01-04T00:00	ALL	1	6	\N	\N	b
+2020-01-04T00:00	ALL	1	6	\N	\N	b
+2020-01-04T00:00	ALL	1	7	\N	\N	\N
+
diff --git a/regression-test/suites/nereids_p0/repeat/test_repeat_output_slot.groovy b/regression-test/suites/nereids_p0/repeat/test_repeat_output_slot.groovy
new file mode 100644
index 00000000000000..c0b9f322b6acd1
--- /dev/null
+++ b/regression-test/suites/nereids_p0/repeat/test_repeat_output_slot.groovy
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_repeat_output_slot") {
+    sql """
+        SET enable_fallback_to_original_planner=false;
+        SET enable_nereids_planner=true;
+        SET ignore_shape_nodes='PhysicalDistribute';
+        SET disable_nereids_rules='PRUNE_EMPTY_PARTITION';
+        SET runtime_filter_mode=OFF;
+        SET disable_join_reorder=true;
+
+        DROP TABLE IF EXISTS tbl_test_repeat_output_slot FORCE;
+
+    """
+
+    sql """
+        CREATE TABLE tbl_test_repeat_output_slot (
+            col_datetime_6__undef_signed datetime(6),
+            col_varchar_50__undef_signed varchar(50),
+            col_varchar_50__undef_signed__index_inverted varchar(50)
+        ) engine=olap
+        distributed by hash(col_datetime_6__undef_signed) buckets 10
+        properties('replication_num' = '1');
+    """
+
+    sql """
+        INSERT INTO tbl_test_repeat_output_slot VALUES
+        (null, null, null), (null, "a", "x"), (null, "a", "y"),
+        ('2020-01-02', "b", "x"), ('2020-01-02', 'a', 'x'), ('2020-01-02', 'b', 'y'),
+        ('2020-01-03', 'a', 'x'), ('2020-01-03', 'a', 'y'), ('2020-01-03', 'b', 'x'), ('2020-01-03', 'b', 'y'),
+        ('2020-01-04', 'a', 'x'), ('2020-01-04', 'a', 'y'), ('2020-01-04', 'b', 'x'), ('2020-01-04', 'b', 'y');
+    """
+
+    explainAndOrderResult 'sql_1', '''
+        SELECT 100000
+        FROM tbl_test_repeat_output_slot
+        GROUP BY GROUPING SETS (
+            (col_datetime_6__undef_signed, col_varchar_50__undef_signed)
+            , ()
+            , (col_varchar_50__undef_signed)
+            , (col_datetime_6__undef_signed, col_varchar_50__undef_signed)
+        );
+    '''
+
+    explainAndOrderResult 'sql_2', '''
+        SELECT MAX(col_datetime_6__undef_signed) AS total_col_datetime,
+               CASE WHEN GROUPING(col_varchar_50__undef_signed__index_inverted) = 1 THEN 'ALL'
+                    ELSE CAST(col_varchar_50__undef_signed__index_inverted AS VARCHAR)
+               END AS pretty_val,
+               IF(GROUPING_ID(col_varchar_50__undef_signed__index_inverted,
+                              col_datetime_6__undef_signed,
+                              col_varchar_50__undef_signed) > 0, 1, 0) AS is_agg_row,
+               GROUPING_ID(col_varchar_50__undef_signed__index_inverted,
+                           col_datetime_6__undef_signed, col_varchar_50__undef_signed) AS having_filter_col,
+               col_varchar_50__undef_signed__index_inverted,
+               col_datetime_6__undef_signed,
+               col_varchar_50__undef_signed
+        FROM tbl_test_repeat_output_slot
+        GROUP BY GROUPING SETS (
+            (col_varchar_50__undef_signed__index_inverted, col_datetime_6__undef_signed, col_varchar_50__undef_signed),
+            (),
+            (col_varchar_50__undef_signed),
+            (col_varchar_50__undef_signed__index_inverted, col_datetime_6__undef_signed, col_varchar_50__undef_signed),
+            (col_varchar_50__undef_signed))
+        HAVING having_filter_col > 0;
+    '''
+}