diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java index 036ffcd225177e..014e156ff6e10c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/mv/AbstractMaterializedViewRule.java @@ -44,7 +44,6 @@ import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.functions.scalar.DateTrunc; @@ -866,14 +865,6 @@ private boolean containsNullRejectSlot(Set> requireNoNullableViewSlot, CascadesContext cascadesContext) { Set queryPulledUpPredicates = queryPredicates.stream() .flatMap(expr -> ExpressionUtils.extractConjunction(expr).stream()) - .map(expr -> { - // NOTICE inferNotNull generate Not with isGeneratedIsNotNull = false, - // so, we need set this flag to false before comparison. - if (expr instanceof Not) { - return ((Not) expr).withGeneratedIsNotNull(false); - } - return expr; - }) .collect(Collectors.toSet()); Set queryNullRejectPredicates = ExpressionUtils.inferNotNull(queryPulledUpPredicates, cascadesContext); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java index 5467de2a9f25d7..34a4bcb7be89a0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java @@ -164,7 +164,7 @@ public ValueDesc visitNot(Not not, ExpressionRewriteContext context) { if (childValue instanceof DiscreteValue) { return new NotDiscreteValue(context, childValue.getReference(), ((DiscreteValue) childValue).values); } else if (childValue instanceof IsNullValue) { - return new IsNotNullValue(context, childValue.getReference(), not); + return new IsNotNullValue(context, childValue.getReference()); } else { return new UnknownValue(context, not); } @@ -190,8 +190,7 @@ private ValueDesc processCompound(ExpressionRewriteContext context, List isNotNull = expression -> expression instanceof Not - && expression.child(0) instanceof IsNull - && !((Not) expression).isGeneratedIsNotNull(); + && expression.child(0) instanceof IsNull; for (Expression predicate : predicates) { hasNullExpression = hasNullExpression || predicate.isNullLiteral(); hasIsNullExpression = hasIsNullExpression || predicate instanceof IsNull; @@ -278,14 +277,7 @@ private ValueDesc intersect(ExpressionRewriteContext context, Expression referen // = (TA is not null or null) and (TA is not null) // = TA is not null // = IsNotNull(TA) - if (rangeValue.isRangeAll() && collector.isNotNullValueOpt.isPresent()) { - // Notice that if collector has only isGenerateNotNullValueOpt, we should not keep the rangeAll here - // for expression: (Not(IsNull(TA)) OR NULL) AND GeneratedNot(IsNull(TA)) - // will be converted to RangeAll(TA) AND IsNotNullValue(TA, generated=true) - // if we skip this RangeAll, the final result will be IsNotNullValue(TA, generated=true) - // then convert back to expression: GeneratedNot(IsNull(TA)), - // but later EliminateNotNull rule will remove this generated Not expression, - // then the final result will be TRUE, which is wrong. + if (rangeValue.isRangeAll() && collector.hasIsNotNullValue) { continue; } if (mergeRangeValueDesc == null) { @@ -366,7 +358,7 @@ private ValueDesc intersect(ExpressionRewriteContext context, Expression referen resultValues.add(new EmptyValue(context, reference)); } if (collector.hasIsNullValue) { - if (collector.hasIsNotNullValue()) { + if (collector.hasIsNotNullValue) { return new UnknownValue(context, BooleanLiteral.FALSE); } // nullable's EmptyValue have contains IsNull, no need to add @@ -374,12 +366,11 @@ private ValueDesc intersect(ExpressionRewriteContext context, Expression referen resultValues.add(new IsNullValue(context, reference)); } } - if (collector.hasIsNotNullValue()) { + if (collector.hasIsNotNullValue) { if (collector.hasEmptyValue) { return new UnknownValue(context, BooleanLiteral.FALSE); } - collector.isNotNullValueOpt.ifPresent(resultValues::add); - collector.isGenerateNotNullValueOpt.ifPresent(resultValues::add); + resultValues.add(new IsNotNullValue(context, reference)); } Optional shortCutResult = mergeCompoundValues(context, reference, resultValues, collector, true); if (shortCutResult.isPresent()) { @@ -397,7 +388,7 @@ private ValueDesc intersect(ExpressionRewriteContext context, Expression referen } private ValueDesc union(ExpressionRewriteContext context, Expression reference, ValueDescCollector collector) { - if (collector.hasIsNotNullValue()) { + if (collector.hasIsNotNullValue) { if (!collector.rangeValues.isEmpty() || !collector.discreteValues.isEmpty() || !collector.notDiscreteValues.isEmpty()) { @@ -471,12 +462,12 @@ private ValueDesc union(ExpressionRewriteContext context, Expression reference, } if (collector.hasIsNullValue) { - if (collector.hasIsNotNullValue() || hasRangeAll) { + if (collector.hasIsNotNullValue || hasRangeAll) { return new UnknownValue(context, BooleanLiteral.TRUE); } resultValues.add(new IsNullValue(context, reference)); } - if (collector.hasIsNotNullValue()) { + if (collector.hasIsNotNullValue) { if (collector.hasEmptyValue) { // EmptyValue(TA) or TA is not null // = TA is null and null or TA is not null @@ -484,8 +475,7 @@ private ValueDesc union(ExpressionRewriteContext context, Expression reference, // = RangeAll(TA) resultValues.add(new RangeValue(context, reference, Range.all())); } else { - collector.isNotNullValueOpt.ifPresent(resultValues::add); - collector.isGenerateNotNullValueOpt.ifPresent(resultValues::add); + resultValues.add(new IsNotNullValue(context, reference)); } } @@ -615,11 +605,8 @@ public interface ValueDescVisitor { } private static class ValueDescCollector implements ValueDescVisitor { - // generated not is null != not is null - Optional isNotNullValueOpt = Optional.empty(); - Optional isGenerateNotNullValueOpt = Optional.empty(); - boolean hasIsNullValue = false; + boolean hasIsNotNullValue = false; boolean hasEmptyValue = false; List rangeValues = Lists.newArrayList(); List discreteValues = Lists.newArrayList(); @@ -635,10 +622,6 @@ int size() { return rangeValues.size() + discreteValues.size() + compoundValues.size() + unknownValues.size(); } - boolean hasIsNotNullValue() { - return isNotNullValueOpt.isPresent() || isGenerateNotNullValueOpt.isPresent(); - } - @Override public Void visitEmptyValue(EmptyValue emptyValue, Void context) { hasEmptyValue = true; @@ -671,11 +654,7 @@ public Void visitIsNullValue(IsNullValue isNullValue, Void context) { @Override public Void visitIsNotNullValue(IsNotNullValue isNotNullValue, Void context) { - if (isNotNullValue.not.isGeneratedIsNotNull()) { - isGenerateNotNullValueOpt = Optional.of(isNotNullValue); - } else { - isNotNullValueOpt = Optional.of(isNotNullValue); - } + hasIsNotNullValue = true; return null; } @@ -1214,15 +1193,8 @@ protected UnionType getUnionType(ValueDesc other, int depth) { * a is not null */ public static class IsNotNullValue extends ValueDesc { - final Not not; - - public IsNotNullValue(ExpressionRewriteContext context, Expression reference, Not not) { + public IsNotNullValue(ExpressionRewriteContext context, Expression reference) { super(context, reference); - this.not = not; - } - - public Not getNotExpression() { - return this.not; } @Override @@ -1238,7 +1210,7 @@ protected boolean nullable() { @Override protected boolean containsAll(ValueDesc other, int depth) { if (other instanceof IsNotNullValue) { - return not.isGeneratedIsNotNull() == ((IsNotNullValue) other).not.isGeneratedIsNotNull(); + return true; } else if (other instanceof CompoundValue) { return ((CompoundValue) other).isContainedAllBy(this, depth); } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index cd86283afe8ebe..db20fd71369f70 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -177,7 +177,7 @@ public Expression visitIsNullValue(IsNullValue value, Void context) { @Override public Expression visitIsNotNullValue(IsNotNullValue value, Void context) { - return value.getNotExpression(); + return new Not(new IsNull(value.getReference())); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java index 2f160aba3fced6..55abc121c831d1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagation.java @@ -481,11 +481,6 @@ private boolean needReplaceWithConstant(Expression expression, Map buildRules() { .when(EliminateNotNull::containsNot) .thenApply(ctx -> { LogicalFilter filter = ctx.root; - List predicates = removeGeneratedNotNull(filter.getConjuncts(), + List predicates = removeNotNull(filter.getConjuncts(), ctx.cascadesContext); if (predicates.size() == filter.getConjuncts().size()) { return null; @@ -65,7 +65,7 @@ public List buildRules() { .when(EliminateNotNull::containsNot) .thenApply(ctx -> { LogicalJoin join = ctx.root; - List newOtherJoinConjuncts = removeGeneratedNotNull( + List newOtherJoinConjuncts = removeNotNull( join.getOtherJoinConjuncts(), ctx.cascadesContext); if (newOtherJoinConjuncts.size() == join.getOtherJoinConjuncts().size()) { return null; @@ -77,7 +77,7 @@ public List buildRules() { ); } - private List removeGeneratedNotNull(Collection exprs, CascadesContext ctx) { + private List removeNotNull(Collection exprs, CascadesContext ctx) { // Example: `id > 0 and id is not null and name is not null(generated)` // predicatesNotContainIsNotNull: `id > 0` // predicatesNotContainIsNotNull infer nonNullable slots: `id` @@ -87,14 +87,11 @@ private List removeGeneratedNotNull(Collection exprs, Ca List slotsFromIsNotNull = Lists.newArrayList(); for (Expression expr : exprs) { - // remove generated `is not null` - if (!(expr instanceof Not) || !((Not) expr).isGeneratedIsNotNull()) { - Optional notNullSlot = TypeUtils.isNotNull(expr); - if (notNullSlot.isPresent()) { - slotsFromIsNotNull.add(notNullSlot.get()); - } else { - predicatesNotContainIsNotNull.add(expr); - } + Optional notNullSlot = TypeUtils.isNotNull(expr); + if (notNullSlot.isPresent()) { + slotsFromIsNotNull.add(notNullSlot.get()); + } else { + predicatesNotContainIsNotNull.add(expr); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java index b8704f5610b9ea..db8e4ebbce515c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java @@ -75,7 +75,7 @@ public Rule build() { boolean conjunctsChanged = false; if (!notNullSlots.isEmpty()) { for (Slot slot : notNullSlots) { - Not isNotNull = new Not(new IsNull(slot), true); + Not isNotNull = new Not(new IsNull(slot)); conjunctsChanged |= conjuncts.add(isNotNull); } } @@ -134,11 +134,11 @@ private JoinType tryEliminateOuterJoin(JoinType joinType, boolean canFilterLeftN private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, Collection container) { boolean containerChanged = false; if (swapedEqualTo.left().nullable()) { - Not not = new Not(new IsNull(swapedEqualTo.left()), true); + Not not = new Not(new IsNull(swapedEqualTo.left())); containerChanged |= container.add(not); } if (swapedEqualTo.right().nullable()) { - Not not = new Not(new IsNull(swapedEqualTo.right()), true); + Not not = new Not(new IsNull(swapedEqualTo.right())); containerChanged |= container.add(not); } return containerChanged; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java index e30190592a6b3b..ce331629907c4b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNull.java @@ -20,7 +20,6 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; @@ -30,8 +29,8 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Filter; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.PlanUtils; import com.google.common.collect.ImmutableSet; @@ -46,7 +45,7 @@ public class InferAggNotNull extends OneRewriteRuleFactory { @Override public Rule build() { return logicalAggregate() - .when(agg -> agg.getGroupByExpressions().size() == 0) + .when(agg -> agg.getGroupByExpressions().isEmpty()) .when(agg -> agg.getAggregateFunctions().size() == 1) .when(agg -> { Set funcs = agg.getAggregateFunctions(); @@ -64,20 +63,16 @@ public Rule build() { if ((agg.child() instanceof Filter)) { predicates = ((Filter) agg.child()).getConjuncts(); } - ImmutableSet.Builder needGenerateNotNullsBuilder = ImmutableSet.builder(); - for (Expression isNotNull : isNotNulls) { - if (!predicates.contains(isNotNull)) { - isNotNull = ((Not) isNotNull).withGeneratedIsNotNull(true); - if (!predicates.contains(isNotNull)) { - needGenerateNotNullsBuilder.add(isNotNull); - } - } - } - Set needGenerateNotNulls = needGenerateNotNullsBuilder.build(); - if (needGenerateNotNulls.isEmpty()) { + if (predicates.containsAll(isNotNulls)) { return null; } - return agg.withChildren(PlanUtils.filter(needGenerateNotNulls, agg.child()).get()); + ImmutableSet.Builder newPredicateBuilder + = ImmutableSet.builderWithExpectedSize(predicates.size() + isNotNulls.size()); + newPredicateBuilder.addAll(predicates); + newPredicateBuilder.addAll(isNotNulls); + Plan newFilterChild = agg.child() instanceof Filter ? agg.child().child(0) : agg.child(); + LogicalFilter newFilter = new LogicalFilter<>(newPredicateBuilder.build(), newFilterChild); + return agg.withChildren(newFilter); }).toRule(RuleType.INFER_AGG_NOT_NULL); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNull.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNull.java index 002a40dd16c1ce..94102c47f39735 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNull.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/InferFilterNotNull.java @@ -20,7 +20,6 @@ import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Not; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.util.ExpressionUtils; @@ -44,34 +43,17 @@ public class InferFilterNotNull extends OneRewriteRuleFactory { @Override public Rule build() { return logicalFilter() - .when(filter -> { - for (Expression conjunct : filter.getConjuncts()) { - if (conjunct.containsType(Not.class) - && conjunct.anyMatch(n -> n instanceof Not && ((Not) n).isGeneratedIsNotNull())) { - return false; - } - } - return true; - }) .thenApply(ctx -> { LogicalFilter filter = ctx.root; Set predicates = filter.getConjuncts(); Set isNotNulls = ExpressionUtils.inferNotNull(predicates, ctx.cascadesContext); - Builder needGenerateNotNullsBuilder = ImmutableSet.builder(); - for (Expression isNotNull : isNotNulls) { - if (!predicates.contains(isNotNull)) { - needGenerateNotNullsBuilder.add(((Not) isNotNull).withGeneratedIsNotNull(true)); - } - } - Set needGenerateNotNulls = needGenerateNotNullsBuilder.build(); - if (needGenerateNotNulls.isEmpty()) { + if (predicates.containsAll(isNotNulls)) { return null; } - Builder conjuncts = ImmutableSet.builderWithExpectedSize( - predicates.size() + needGenerateNotNulls.size()); + predicates.size() + isNotNulls.size()); conjuncts.addAll(predicates); - conjuncts.addAll(needGenerateNotNulls); + conjuncts.addAll(isNotNulls); return PlanUtils.filter(conjuncts.build(), filter.child()).get(); }).toRule(RuleType.INFER_FILTER_NOT_NULL); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java index e12276ff57fb4d..880279ee41e935 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Not.java @@ -29,7 +29,6 @@ import com.google.common.collect.ImmutableList; import java.util.List; -import java.util.Objects; /** * Not expression: not a. @@ -38,27 +37,16 @@ public class Not extends Expression implements UnaryExpression, ExpectsInputType public static final List EXPECTS_INPUT_TYPES = ImmutableList.of(BooleanType.INSTANCE); - private final boolean isGeneratedIsNotNull; - public Not(Expression child) { - this(child, false); - } - - public Not(List child, boolean isGeneratedIsNotNull, boolean inferred) { - super(child, inferred); - this.isGeneratedIsNotNull = isGeneratedIsNotNull; + this(ImmutableList.of(child)); } - public Not(Expression child, boolean isGeneratedIsNotNull) { - this(ImmutableList.of(child), isGeneratedIsNotNull); - } - - private Not(List child, boolean isGeneratedIsNotNull) { - this(child, isGeneratedIsNotNull, false); + private Not(List child) { + this(child, false); } - public boolean isGeneratedIsNotNull() { - return isGeneratedIsNotNull; + private Not(List child, boolean inferred) { + super(child, inferred); } @Override @@ -76,24 +64,6 @@ public R accept(ExpressionVisitor visitor, C context) { return visitor.visitNot(this, context); } - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - Not other = (Not) o; - return Objects.equals(child(), other.child()) - && isGeneratedIsNotNull == other.isGeneratedIsNotNull; - } - - @Override - public int computeHashCode() { - return Objects.hash(child().hashCode(), isGeneratedIsNotNull); - } - @Override public String toString() { return "( not " + child().toString() + ")"; @@ -114,11 +84,7 @@ public String computeToSql() { @Override public Not withChildren(List children) { Preconditions.checkArgument(children.size() == 1); - return new Not(children, isGeneratedIsNotNull); - } - - public Not withGeneratedIsNotNull(boolean isGeneratedIsNotNull) { - return new Not(children, isGeneratedIsNotNull); + return new Not(children); } @Override @@ -128,6 +94,6 @@ public List expectedInputTypes() { @Override public Expression withInferred(boolean inferred) { - return new Not(this.children, false, inferred); + return new Not(this.children, inferred); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index dadfc9fa0f406a..2af92339f68fb7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -768,7 +768,7 @@ public static Set inferNotNullSlots(Set predicates, CascadesCo public static Set inferNotNull(Set predicates, CascadesContext cascadesContext) { ImmutableSet.Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size()); for (Slot slot : inferNotNullSlots(predicates, cascadesContext)) { - newPredicates.add(new Not(new IsNull(slot), false)); + newPredicates.add(new Not(new IsNull(slot))); } return newPredicates.build(); } @@ -781,18 +781,12 @@ public static Set inferNotNull(Set predicates, Set ImmutableSet.Builder newPredicates = ImmutableSet.builderWithExpectedSize(predicates.size()); for (Slot slot : inferNotNullSlots(predicates, cascadesContext)) { if (slots.contains(slot)) { - newPredicates.add(new Not(new IsNull(slot), true)); + newPredicates.add(new Not(new IsNull(slot))); } } return newPredicates.build(); } - public static boolean isGeneratedNotNull(Expression expression) { - return expression instanceof Not - && ((Not) expression).isGeneratedIsNotNull() - && ((Not) expression).child() instanceof IsNull; - } - /** flatExpressions */ public static List flatExpressions(List> expressionLists) { int num = 0; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagationTest.java index 285e8313e6a48c..fa3c9a53e2a0ec 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/ConstantPropagationTest.java @@ -166,15 +166,11 @@ void testExpressionNotReplace() { assertRewrite("t.a = 1 and t.a = t.b", "a = 1 and b = 1"); assertRewrite("t1.a = 1 and t1.a = t2.b", "a = 1 and a = b and b = 1"); - // for `a is not null`, if this Not isGeneratedIsNotNull, then will not rewrite it SlotReference a = new SlotReference("a", IntegerType.INSTANCE, true); - Expression expr1 = ExpressionUtils.and(new EqualTo(a, new IntegerLiteral(1)), new Not(new IsNull(a), false)); + Expression expr1 = ExpressionUtils.and(new EqualTo(a, new IntegerLiteral(1)), new Not(new IsNull(a))); Expression rewrittenExpr1 = rewriteExpression(expr1, true); Expression expectExpr1 = new EqualTo(a, new IntegerLiteral(1)); Assertions.assertEquals(expectExpr1, rewrittenExpr1); - Expression expr2 = ExpressionUtils.and(new EqualTo(a, new IntegerLiteral(1)), new Not(new IsNull(a), true)); - Expression rewrittenExpr2 = rewriteExpression(expr2, true); - Assertions.assertEquals(expr2, rewrittenExpr2); // for `a match_any xx`, don't replace it, because the match require left child is column, not literal SlotReference b = new SlotReference("b", StringType.INSTANCE, true); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java index 7d20c2f22a6f2b..5b3b8d27da4cd5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/InferAggNotNullTest.java @@ -18,7 +18,9 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.Not; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; @@ -36,17 +38,18 @@ class InferAggNotNullTest implements MemoPatternMatchSupported { @Test void testInfer() { + Slot slot = scan1.getOutput().get(1); LogicalPlan plan = new LogicalPlanBuilder(scan1) .aggGroupUsingIndex(ImmutableList.of(), - ImmutableList.of(new Alias(new Count(true, scan1.getOutput().get(1)), "dnt"))) + ImmutableList.of(new Alias(new Count(true, slot), "dnt"))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyTopDown(new InferAggNotNull()) .matches( logicalAggregate( - logicalFilter().when(filter -> filter.getConjuncts().stream() - .allMatch(e -> ((Not) e).isGeneratedIsNotNull())) + logicalFilter().when(filter -> filter.getConjuncts() + .contains(new Not(new IsNull(slot)))) ) ); }