Skip to content

Commit 17ef6a7

Browse files
committed
logicalplan: distribute throgh nested aggregations
Signed-off-by: Michael Hoffmann <[email protected]>
1 parent a956a63 commit 17ef6a7

File tree

3 files changed

+172
-8
lines changed

3 files changed

+172
-8
lines changed

engine/distributed_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,28 @@ func TestDistributedAggregations(t *testing.T) {
282282
{name: "group_right with partition label in include", query: `bar{pod="nginx-1"} * on (pod) group_right (zone) bar`},
283283
{name: "group_left without partition label", query: `bar * on (zone) group_left (pod) bar{zone="east-1"}`},
284284
{name: "group_right without partition label", query: `bar{zone="east-1"} * on (zone) group_right (pod) bar`},
285+
// Nested aggregation tests - using deterministic aggregations (max/min/sum/count)
286+
// instead of topk/bottomk to avoid flaky tests due to tie-breaking.
287+
{name: "max over sum by partition", query: `max(sum by (zone, pod) (bar))`},
288+
{name: "max over max by partition", query: `max(max by (zone) (bar))`},
289+
{name: "min over sum by partition", query: `min(sum by (zone, pod) (bar))`},
290+
{name: "min over min by partition", query: `min(min by (zone) (bar))`},
291+
{name: "count over sum by partition", query: `count(sum by (zone, pod) (bar))`},
292+
{name: "sum over max by partition", query: `sum(max by (zone) (bar))`},
293+
{name: "min over count by partition", query: `min(count by (zone) (bar))`},
294+
{name: "sum over max by zone", query: `sum(max by (zone) (bar))`},
295+
{name: "max over sum over rate by partition", query: `max(sum by (zone) (rate(bar[1m])))`},
296+
{name: "min over avg by partition", query: `min(avg by (zone) (bar))`},
297+
{name: "max over max over sum by partition", query: `max(max(sum by (zone) (bar)))`},
298+
{name: "sum over min over max by partition", query: `sum(min(max by (zone, pod) (bar)))`},
299+
{name: "max over sum without partition", query: `max(sum by (pod) (bar))`},
300+
{name: "min over max without partition", query: `min(max(bar))`},
301+
{name: "max over binary op by partition", query: `max(sum by (zone) (bar) / count by (zone) (bar))`},
302+
{name: "count over max by partition", query: `count(max by (zone) (bar))`},
303+
{name: "group over sum by partition", query: `group(sum by (zone) (bar))`},
304+
{name: "max over binary with on() by partition", query: `max(bar * on (zone, pod) bar)`},
305+
{name: "max over sum with without() by partition", query: `max(sum without (pod) (bar))`},
306+
{name: "max over sum with without() without partition", query: `max(sum without (zone) (bar))`},
285307
}
286308

287309
lookbackDeltas := []time.Duration{0, 30 * time.Second, 5 * time.Minute}

logicalplan/distribute.go

Lines changed: 90 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -221,9 +221,30 @@ func (m DistributedExecutionOptimizer) Optimize(plan Node, opts *query.Options)
221221
return true
222222
}
223223

224-
// If the current node is an aggregation, distribute the operation and
225-
// stop the traversal.
224+
// Handle absent functions specially
225+
if isAbsent(*current) {
226+
*current = m.distributeAbsent(*current, engines, calculateStartOffset(current, opts.LookbackDelta), m.subqueryOpts(parents, current, opts))
227+
return true
228+
}
229+
230+
// If the current node is an aggregation, check if we should distribute here
231+
// or continue traversing up.
226232
if aggr, ok := (*current).(*Aggregation); ok {
233+
// If this aggregation preserves partition labels, check if the parent
234+
// is also a distributive aggregation that we could push through.
235+
// This enables patterns like: topk(10, sum by (P, instance) (X))
236+
// where P is a partition label - we can push the entire expression
237+
// to remote engines.
238+
if preservesPartitionLabels(*current, engineLabels) {
239+
if parent != nil {
240+
if parentAggr, ok := (*parent).(*Aggregation); ok {
241+
if _, ok := distributiveAggregations[parentAggr.Op]; ok {
242+
// Parent is a distributive aggregation, continue up
243+
return false
244+
}
245+
}
246+
}
247+
}
227248
localAggregation := aggr.Op
228249
if aggr.Op == parser.COUNT {
229250
localAggregation = parser.SUM
@@ -240,10 +261,6 @@ func (m DistributedExecutionOptimizer) Optimize(plan Node, opts *query.Options)
240261
}
241262
return true
242263
}
243-
if isAbsent(*current) {
244-
*current = m.distributeAbsent(*current, engines, calculateStartOffset(current, opts.LookbackDelta), m.subqueryOpts(parents, current, opts))
245-
return true
246-
}
247264

248265
// If the parent operation is distributive, continue the traversal.
249266
if isDistributive(parent, m.SkipBinaryPushdown, engineLabels, warns) {
@@ -523,6 +540,73 @@ func numSteps(start, end time.Time, step time.Duration) int64 {
523540
return (end.UnixMilli()-start.UnixMilli())/step.Milliseconds() + 1
524541
}
525542

543+
// preservesPartitionLabels checks if an expression preserves all partition labels.
544+
// An expression preserves partition labels if the output series will still have
545+
// those labels, meaning results from different engines won't overlap and can be
546+
// coalesced without deduplication.
547+
//
548+
// This enables pushing more operations to remote engines. For example:
549+
//
550+
// topk(10, sum by (P, instance) (X))
551+
//
552+
// If P is a partition label, the sum preserves P, so topk can also be pushed
553+
// down since each engine's top 10 won't overlap with other engines' top 10.
554+
func preservesPartitionLabels(expr Node, partitionLabels map[string]struct{}) bool {
555+
if len(partitionLabels) == 0 {
556+
return true
557+
}
558+
559+
switch e := expr.(type) {
560+
case *VectorSelector, *MatrixSelector, *NumberLiteral, *StringLiteral:
561+
return true
562+
case *Aggregation:
563+
for lbl := range partitionLabels {
564+
if slices.Contains(e.Grouping, lbl) == e.Without {
565+
return false
566+
}
567+
}
568+
return true
569+
case *Binary:
570+
if e.VectorMatching != nil {
571+
for lbl := range partitionLabels {
572+
inMatching := slices.Contains(e.VectorMatching.MatchingLabels, lbl)
573+
inInclude := slices.Contains(e.VectorMatching.Include, lbl)
574+
if !inInclude && inMatching != e.VectorMatching.On {
575+
return false
576+
}
577+
}
578+
}
579+
return preservesPartitionLabels(e.LHS, partitionLabels) &&
580+
preservesPartitionLabels(e.RHS, partitionLabels)
581+
case *FunctionCall:
582+
if e.Func.Name == "label_replace" {
583+
if _, ok := partitionLabels[UnsafeUnwrapString(e.Args[1])]; ok {
584+
return false
585+
}
586+
}
587+
for _, arg := range e.Args {
588+
if arg.ReturnType() == parser.ValueTypeVector || arg.ReturnType() == parser.ValueTypeMatrix {
589+
if !preservesPartitionLabels(arg, partitionLabels) {
590+
return false
591+
}
592+
}
593+
}
594+
return true
595+
case *Unary:
596+
return preservesPartitionLabels(e.Expr, partitionLabels)
597+
case *Parens:
598+
return preservesPartitionLabels(e.Expr, partitionLabels)
599+
case *StepInvariantExpr:
600+
return preservesPartitionLabels(e.Expr, partitionLabels)
601+
case *CheckDuplicateLabels:
602+
return preservesPartitionLabels(e.Expr, partitionLabels)
603+
case *Subquery:
604+
return preservesPartitionLabels(e.Expr, partitionLabels)
605+
default:
606+
return false
607+
}
608+
}
609+
526610
func isDistributive(expr *Node, skipBinaryPushdown bool, engineLabels map[string]struct{}, warns *annotations.Annotations) bool {
527611
if expr == nil {
528612
return false

logicalplan/distribute_test.go

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,8 +442,6 @@ count by (cluster) (
442442
skipBinopPushdown: true,
443443
},
444444
{
445-
// group_left/group_right with partition label cannot be distributed because
446-
// match cardinality changes when each partition only sees one value for that label.
447445
name: "topk over binary with group_left including partition label does not distribute",
448446
expr: `topk(5, metric_a * on (pod) group_left(region) metric_b)`,
449447
expected: `topk(5, dedup(remote(metric_a), remote(metric_a)) * on (pod) group_left (region) dedup(remote(metric_b), remote(metric_b)))`,
@@ -528,6 +526,66 @@ count by (cluster) (
528526
expr: `metric_a unless on () vector(0)`,
529527
expected: `dedup(remote(metric_a), remote(metric_a)) unless on () vector(0)`,
530528
},
529+
{
530+
name: "topk over sum by partition",
531+
expr: `topk(10, sum by (region, instance) (http_requests_total))`,
532+
expected: `topk(10, dedup(remote(topk by (region) (10, sum by (region, instance) (http_requests_total))), remote(topk by (region) (10, sum by (region, instance) (http_requests_total)))))`,
533+
},
534+
{
535+
name: "bottomk over max by partition",
536+
expr: `bottomk(5, max by (region, pod) (cpu_usage))`,
537+
expected: `bottomk(5, dedup(remote(bottomk by (region) (5, max by (region, pod) (cpu_usage))), remote(bottomk by (region) (5, max by (region, pod) (cpu_usage)))))`,
538+
},
539+
{
540+
name: "topk over sum without partition",
541+
expr: `topk(10, sum by (instance) (http_requests_total))`,
542+
expected: `topk(10, sum by (instance) (dedup(remote(sum by (instance, region) (http_requests_total)), remote(sum by (instance, region) (http_requests_total)))))`,
543+
},
544+
{
545+
name: "count over sum by partition",
546+
expr: `count(sum by (region, pod) (http_requests_total))`,
547+
expected: `sum(dedup(remote(count by (region) (sum by (region, pod) (http_requests_total))), remote(count by (region) (sum by (region, pod) (http_requests_total)))))`,
548+
},
549+
{
550+
name: "topk over binary with group_left preserving partition",
551+
expr: `topk(5, metric_a * on (pod) group_left(region) metric_b)`,
552+
expected: `topk(5, dedup(remote(topk by (region) (5, metric_a * on (pod) group_left (region) metric_b)), remote(topk by (region) (5, metric_a * on (pod) group_left (region) metric_b))))`,
553+
},
554+
{
555+
name: "topk over binary with group_right preserving partition",
556+
expr: `topk(5, metric_a * on (pod) group_right(region) metric_b)`,
557+
expected: `topk(5, dedup(remote(topk by (region) (5, metric_a * on (pod) group_right (region) metric_b)), remote(topk by (region) (5, metric_a * on (pod) group_right (region) metric_b))))`,
558+
},
559+
{
560+
name: "topk over binary with on() including partition",
561+
expr: `topk(5, metric_a * on (region, pod) metric_b)`,
562+
expected: `topk(5, dedup(remote(topk by (region) (5, metric_a * on (region, pod) metric_b)), remote(topk by (region) (5, metric_a * on (region, pod) metric_b))))`,
563+
},
564+
{
565+
name: "topk over binary with on() excluding partition",
566+
expr: `topk(5, metric_a * on (pod) metric_b)`,
567+
expected: `topk(5, dedup(remote(metric_a), remote(metric_a)) * on (pod) dedup(remote(metric_b), remote(metric_b)))`,
568+
},
569+
{
570+
name: "topk over binary with ignoring() excluding partition",
571+
expr: `topk(5, metric_a * ignoring (pod) metric_b)`,
572+
expected: `topk(5, dedup(remote(topk by (region) (5, metric_a * ignoring (pod) metric_b)), remote(topk by (region) (5, metric_a * ignoring (pod) metric_b))))`,
573+
},
574+
{
575+
name: "topk over binary with ignoring() including partition",
576+
expr: `topk(5, metric_a * ignoring (region) metric_b)`,
577+
expected: `topk(5, dedup(remote(metric_a), remote(metric_a)) * ignoring (region) dedup(remote(metric_b), remote(metric_b)))`,
578+
},
579+
{
580+
name: "topk over sum with without() excluding partition",
581+
expr: `topk(5, sum without (pod) (metric_a))`,
582+
expected: `topk(5, dedup(remote(topk by (region) (5, sum without (pod) (metric_a))), remote(topk by (region) (5, sum without (pod) (metric_a)))))`,
583+
},
584+
{
585+
name: "topk over sum with without() including partition",
586+
expr: `topk(5, sum without (region) (metric_a))`,
587+
expected: `topk(5, sum without (region) (dedup(remote(sum without () (metric_a)), remote(sum without () (metric_a)))))`,
588+
},
531589
}
532590

533591
engines := []api.RemoteEngine{

0 commit comments

Comments
 (0)