
Expand dimensional mapreduce / reduce #83

Merged: maleadt merged 39 commits into JuliaGPU:main from shreyas-omkar:sh/mapreduce_nd_dims on Jun 19, 2026

@shreyas-omkar

@shreyas-omkar shreyas-omkar commented Jun 1, 2026

Member

This PR extends AcceleratedKernels.mapreduce and reduce from single-dimension reductions to a more Base-compatible dimensional reduction implementation. It adds tuple dims, dims=:, dims=(), oversized and duplicate dims handling, type-changing mapreduce, and multi-input mapreduce(f, op, A, B, ...; dims=...).

Notable API changes include:

  • dims is now accepted as Union{Nothing, Int, Tuple{Vararg{Int}}, Colon} across reduce, mapreduce, and the arithmetic wrappers. Dimensional reductions now follow Base behavior for tuple dims, duplicate dims, dims beyond ndims, empty dims, colon dims, and zero-sized reduced or kept dimensions.
  • Multi-input mapreduce is also supported, including an explicit backend argument after the input arrays.
  • The default neutral is now derived from typeof(init), preserving the documented result-type behavior for type-changing reductions.
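The Base behavior that these bullets refer to can be illustrated with plain Base calls. This is a minimal sketch of the reference semantics only; it does not call AcceleratedKernels, and the array sizes are arbitrary choices made for illustration.

```julia
# Plain Base calls, shown only as the reference behavior the PR aims to match.
A = reshape(collect(1:24), 2, 3, 4)
B = A .* 10

r_tuple = sum(A; dims=(1, 3))   # tuple dims: result has size (1, 3, 1)
r_colon = sum(A; dims=:)        # colon dims: reduce everything to a scalar
r_empty = sum(A; dims=())       # empty dims: nothing is reduced
r_over  = sum(A; dims=(1, 4))   # dims beyond ndims(A) are ignored

# Type-changing mapreduce: Int input, Float64 accumulator taken from `init`.
r_float = mapreduce(x -> x / 2, +, A; dims=2, init=0.0)

# Multi-input mapreduce: f is applied elementwise across A and B.
r_multi = mapreduce((a, b) -> a + b, +, A, B; dims=1)
```

Comparing a GPU result against these Base results on the same data is one way to check the compatibility claims above.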

Implementation

The original CartesianIndices-only GPU approach was correct but too expensive for common reductions because it moved index decoding and integer division into hot loops. The final implementation keeps Cartesian indexing as a generic fallback, but uses stride-based fast paths for dense and strided GPU sources backed by a single dense column-major buffer.

For fast-path sources, mapreduce_nd resolves the layout to (buffer, base_offset, strides). This lets views, offset views, adjoints, permuted dims, and reshapes use optimized flat-buffer kernels while still indexing the right underlying storage. Broadcasted, lazy, or otherwise non-dense sources use the generic fallback, which mirrors the CPU Cartesian path.
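The (buffer, base_offset, strides) resolution can be sketched on the CPU for a strided view. The helper names `strided_layout` and `flat_getindex` are hypothetical and are not the PR's internals; the sketch assumes a `SubArray` over a dense `Array`.

```julia
# Hypothetical helper (not the PR's code): resolve a strided view of a dense
# Array to (buffer, base_offset, strides).
function strided_layout(v::SubArray)
    buf = parent(v)
    firsts = map(first, parentindices(v))             # first parent index per parent dim
    base_offset = sum((firsts .- 1) .* strides(buf))  # 0-based element offset into buf
    return buf, base_offset, strides(v)
end

# Read element I of the view straight from the flat buffer.
function flat_getindex(buf, base_offset, strd, I::Vararg{Int})
    return buf[base_offset + sum((I .- 1) .* strd) + 1]
end

A = reshape(collect(1:210), 5, 6, 7)
v = view(A, 2:4, 3, 1:2:5)   # nonzero offset, a dropped dim, and a stride-2 range
buf, base_offset, strd = strided_layout(v)
```

Here the view starts 11 elements into the parent and has strides (1, 60), so a flat-buffer kernel never needs to index the wrapper itself.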

The kernel setup canonicalizes reduced and kept dimensions into contiguous stride segments. Common cases such as dims=2 or dims=(1, 2) collapse to a single reduce segment and avoid per-element division entirely. Non-contiguous reductions such as dims=(1, 3) still use multi-segment decoding, with power-of-two segment sizes optimized via shift/mask operations.
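One way to picture the canonicalization is as a merge of memory-adjacent dimensions of the same class into (extent, stride) segments. The sketch below is an illustration under the assumption of a dense column-major array; `merge_segments` and `canonical_segments` are hypothetical names, not the PR's functions.

```julia
# Fuse memory-adjacent dimensions of one class (reduced or kept) into
# (extent, stride) segments. Assumes a dense column-major array.
function merge_segments(sizes, strides, ds)
    segs = Tuple{Int,Int}[]
    for d in ds
        if !isempty(segs) && segs[end][1] * segs[end][2] == strides[d]
            ext, st = segs[end]
            segs[end] = (ext * sizes[d], st)      # contiguous with the previous segment: fuse
        else
            push!(segs, (sizes[d], strides[d]))   # gap in memory: start a new segment
        end
    end
    return segs
end

function canonical_segments(sizes::Tuple, dims)
    n = length(sizes)
    strides = cumprod(vcat(1, collect(sizes)[1:end-1]))  # dense column-major strides
    reduced = [d for d in 1:n if d in dims]
    kept    = [d for d in 1:n if !(d in dims)]
    return merge_segments(sizes, strides, reduced), merge_segments(sizes, strides, kept)
end

reduced12, kept12 = canonical_segments((128, 128, 128), (1, 2))
reduced13, kept13 = canonical_segments((128, 128, 128), (1, 3))
```

With dims=(1, 2) the reduced side collapses to one segment of extent 16384 and stride 1, while dims=(1, 3) leaves two reduced segments that need decoding.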

GPU reductions dispatch between one-thread-per-output, tiled strided, one-block-per-output, and multigroup strategies. The heuristics were tuned to avoid the earlier regressions from overusing multigroup reductions while preserving occupancy and coalescing-sensitive cases.

Performance

Several performance fixes were made: fast paths avoid hot-loop Cartesian indexing, common single-segment reductions avoid per-element division, multigroup reductions are gated to low-output large-reduction cases, by_block can grid-stride over outputs, and strided GPU sources now route through fast kernels instead of falling back unnecessarily.

The current implementation is broadly competitive with, and often faster than, the previous AK implementation and GPUArrays-style reductions across CUDA, Metal, and oneAPI. Remaining gaps are mostly backend-specialization opportunities such as subgroup collectives, vectorized loads/stores, small fixed-extent output-vectorized kernels, hardware-aware launch sizing, and selected vendor primitive dispatch.

@shreyas-omkar

Member Author

@maleadt and @christiangnrd, please take a look.

@maleadt

maleadt commented Jun 2, 2026

Member

This seems to crash on Base.mapreduce(f, op, A; dims=(1,4)) with a 3D input, which is legal in Base.

Also, please keep the LLM attribution.

Comment thread src/reduce/mapreduce_nd.jl
Comment thread src/reduce/mapreduce_nd.jl Outdated
@shreyas-omkar shreyas-omkar marked this pull request as draft June 6, 2026 06:51
@shreyas-omkar shreyas-omkar force-pushed the sh/mapreduce_nd_dims branch from 9997d93 to 3f9e269 Compare June 10, 2026 05:36
@shreyas-omkar

Member Author
| Workload | AK | GPUArrays ref | PyTorch | AK vs PyTorch |
|---|---|---|---|---|
| 100×100×100, dims=2 | 52.77μs | 56.38μs | 25μs | 2.11× |
| 3×100000×3, dims=2 | 98.82μs | 56.89μs | 33μs | 2.99× |
| 100×50×100, dims=(1,3) | 79.72μs | 57.84μs | 28μs | 2.85× |
| 500×5×500, dims=(1,3) | 99.5μs | 63.94μs | 36μs | 2.76× |
| 200×200×200, all dims | 132.38μs | 116.82μs | 31μs | 4.27× |

AK is currently 2-4× slower than PyTorch, while beating or matching GPUArrays' own kernel on most cases. The implementation uses dimension canonicalization, Val{sizes} for compile-time index decoding, multi-group reduction, and input staging.
The remaining gap to PyTorch comes down to three things I haven't been able to do yet:

Warp shuffle reduction: PyTorch replaces the last 5 levels of the shared memory tree with __shfl_down_sync, eliminating 5 @synchronize calls and shared memory writes. This alone is probably worth ~1.5× on the block reduction.
Vectorized loads: PyTorch loads float4 (4 elements per memory transaction). The inner loop currently loads one element per thread per cycle. Would need @load intrinsics in KA to do this portably.
Occupancy-based block size tuning: PyTorch queries the device for optimal thread count per block. We hardcode 256. For Blackwell this may not be optimal.
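The shuffle-down pattern mentioned in the first item can be simulated on the CPU to show why it takes exactly five steps for 32 lanes. This is only a host-side illustration of the data movement, assuming a 32-lane warp; it is not GPU code, and `simulated_warp_reduce` is a hypothetical name.

```julia
# CPU simulation of a 32-lane shuffle-down reduction: at each of the 5 steps,
# lane i combines its value with the value held by lane i + offset.
function simulated_warp_reduce(op, vals::AbstractVector)
    @assert length(vals) == 32
    lanes = collect(vals)
    offset = 16
    while offset >= 1
        shifted = copy(lanes)   # snapshot: every lane reads before any lane writes
        for lane in 1:(32 - offset)
            lanes[lane] = op(shifted[lane], shifted[lane + offset])
        end
        offset >>= 1
    end
    return lanes[1]             # lane 1 ends up holding the full reduction
end

total = simulated_warp_reduce(+, 1:32)
```

On a GPU each step is a register exchange, so none of the five steps needs a barrier or a shared-memory write.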

@shreyas-omkar shreyas-omkar force-pushed the sh/mapreduce_nd_dims branch from 3f9e269 to 93f3a8d Compare June 10, 2026 05:46
Comment thread src/reduce/mapreduce_nd.jl Outdated
@maleadt

maleadt commented Jun 10, 2026

Member

Nit: I don't see you detecting duplicate dims (e.g. reduce(+, x; dims=(2,2))) which Base errors on.

Performance wise, I think you should also compare against the current AK.jl implementation; this work shouldn't significantly regress that. Testing locally with Metal.jl:

| Workload | AK main | PR #83 | Δ |
|---|---|---|---|
| 100×100×100, dims=1 | 291.0 μs | 368.2 μs | 1.27× slower |
| 100×100×100, dims=2 | 203.0 μs | 231.9 μs | 1.14× slower |
| 100×100×100, dims=3 | 184.1 μs | 286.0 μs | 1.55× slower |
| 3×100000×3, dims=2 | 263.1 μs | 562.3 μs | 2.14× slower |
| 10×1000000, dims=2 | 1542.1 μs | 4627.6 μs | 3.00× slower |
| 1000000×10, dims=1 | 1348.8 μs | 4633.6 μs | 3.44× slower |

I'm guessing this is because of still doing _reduce_offset (integer division) inside the inner loop, and allocating a fresh partial array on every call.

AFAIU you're also specializing on the exact runtime array sizes (Val(outer_sizes_tup) / Val(reduce_sizes_tup)). That's not a viable path. PyTorch's OffsetCalculator exists exactly to avoid this.

@shreyas-omkar shreyas-omkar force-pushed the sh/mapreduce_nd_dims branch from 93f3a8d to 38ef5c5 Compare June 10, 2026 11:05
@maleadt

maleadt commented Jun 11, 2026

Member

Pushed some optimizations that improve performance significantly on my M1 (purposefully chosen as an old, weak GPU).

Core fixes:

  • Avoid per-element idiv
  • Use the multi-group reduction less often: now only when there are too few outputs to fill the GPU and the reduction is large (see TARGET_BLOCKS)
  • Remove the unroll that causes undue register pressure
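The multi-group gate in the second bullet could look roughly like the following. This is a hypothetical sketch, not the PR's heuristic: the function name, the exact conditions, and the default `block_size` are assumptions, and the `TARGET_BLOCKS` value of 256 is the one quoted elsewhere in this thread.

```julia
# Hypothetical gate, not the PR's code: names and thresholds are illustrative.
const TARGET_BLOCKS = 256   # value quoted in the discussion

function choose_reduce_groups(n_outputs::Int, reduce_len::Int; block_size::Int=256)
    # Enough outputs to fill the GPU, or a reduction one block can handle: no split.
    if n_outputs >= TARGET_BLOCKS || reduce_len <= block_size
        return 1
    end
    wanted = cld(TARGET_BLOCKS, max(n_outputs, 1))  # groups needed to reach the target block count
    avail  = cld(reduce_len, block_size)            # never more groups than blocks of work
    return min(wanted, avail, block_size)           # cap keeps the second pass single-level
end

groups_square = choose_reduce_groups(1000, 1000)      # many outputs: stays single-group
groups_wide   = choose_reduce_groups(3, 1_000_000)    # few outputs, long reduction: split
```

Under these assumptions a 1000×1000 reduction is not split, while three outputs over a million elements each are spread across 86 groups per output.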

Performance, in microseconds:

| shape | dims | main | AK | GPUArrays | PyTorch-MPS |
|---|---|---|---|---|---|
| (3, 1000000) | 1 | 874 | 737 | 743 | 590 |
| (3, 1000000) | 2 | 1549 | 638 | 1258 | 478 |
| (1000, 1000) | 1 | 890 | 754 | 561 | 288 |
| (1000, 1000) | 2 | 617 | 687 | 571 | 294 |
| (256, 4096) | 1 | 540 | 565 | 557 | 263 |
| (256, 4096) | 2 | 542 | 541 | 622 | 299 |
| (4096, 256) | 1 | 342 | 379 | 345 | 286 |
| (4096, 256) | 2 | 352 | 387 | 566 | 409 |
| (128, 128, 128) | 1 | 792 | 812 | 803 | 335 |
| (128, 128, 128) | 2 | 453 | 456 | 486 | 323 |
| (128, 128, 128) | 3 | 440 | 468 | 372 | 323 |
| (128, 128, 128) | (1,2) | n/a | 526 | 620 | 318 |
| (128, 128, 128) | (2,3) | n/a | 855 | 1140 | 341 |
| (128, 128, 128) | (1,3) | n/a | 1654 | 611 | 327 |
| (64, 64, 64, 64) | 4 | 1586 | 1614 | 1588 | 2063 |
| (1000000, 8) | 1 | 1522 | 990 | 1168 | 1041 |
| (1000000, 8) | 2 | 1182 | 1142 | 1070 | 1667 |

@shreyas-omkar shreyas-omkar force-pushed the sh/mapreduce_nd_dims branch from 3194f0d to b65af66 Compare June 12, 2026 13:21
@shreyas-omkar

Member Author

With TARGET_BLOCKS = 256, the numbers on an SM_86-based GPU are:

GPU: NVIDIA GeForce RTX 3060

| Shape | Dims | AK (μs) | Ref (μs), CUDA.@sync sum(x; dims=dims) | Ratio |
|---|---|---|---|---|
| (3, 1000000) | 1 | 68.0 | 59.1 | 1.15× |
| (3, 1000000) | 2 | 66.4 | 106.8 | 0.62× |
| (1000, 1000) | 1 | 38.6 | 68.6 | 0.56× |
| (1000, 1000) | 2 | 67.7 | 76.0 | 0.89× |
| (256, 4096) | 1 | 66.6 | 83.1 | 0.80× |
| (256, 4096) | 2 | 51.7 | 55.8 | 0.93× |
| (4096, 256) | 1 | 35.8 | 43.4 | 0.83× |
| (4096, 256) | 2 | 34.2 | 87.8 | 0.39× |
| (128, 128, 128) | 1 | 76.6 | 209.0 | 0.37× |
| (128, 128, 128) | 2 | 45.9 | 207.9 | 0.22× |
| (128, 128, 128) | 3 | 45.6 | 214.8 | 0.21× |
| (128, 128, 128) | (1, 2) | 48.0 | 87.9 | 0.55× |
| (128, 128, 128) | (2, 3) | 83.1 | 101.6 | 0.82× |
| (128, 128, 128) | (1, 3) | 87.0 | 87.8 | 0.99× |
| (64, 64, 64, 64) | 4 | 230.9 | 269.0 | 0.86× |
| (1000000, 8) | 1 | 129.0 | 157.8 | 0.82× |
| (1000000, 8) | 2 | 130.9 | 126.1 | 1.04× |

@shreyas-omkar

Member Author

The per-element offset decode for non-contiguous reduced dim sets (e.g.
dims=(1,3)) walked every segment with an integer div + remainder. Integer
division is ~20+ cycles on GPU vs ~1 for shift/and, and runs once per
element in the hot reduce loop.

For power-of-two segment sizes (the common case), replace div/rem with
trailing_zeros-shift and (sz-1)-mask; non-pow2 segments keep division.
The final segment needs no decode (the remaining quotient is the index),
saving one op. Single-segment reductions early-return before this loop and
are unaffected.

Results (dims=(1,3) on 128^3, the multi-segment case):
RTX 5080 : 1.76x -> 1.39x vs cuBLAS
RX 9060 XT: 66us -> 34us
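The shift/mask decode described in this commit message can be sketched as a plain Julia function. The name `decode_offset` and its signature are hypothetical; the sketch assumes a 0-based linear index `j` over the reduced segments and per-segment (size, stride) pairs.

```julia
# Decode a 0-based linear index j over `sizes` into a memory offset.
# Power-of-two segments use mask/shift; other segments fall back to divrem.
# The final segment needs no decode: the remaining quotient is its index.
function decode_offset(j::Int, sizes::NTuple{N,Int}, strides::NTuple{N,Int}) where {N}
    offset = 0
    for s in 1:(N - 1)
        sz = sizes[s]
        if ispow2(sz)
            idx = j & (sz - 1)          # remainder via mask
            j >>= trailing_zeros(sz)    # quotient via shift
        else
            j, idx = divrem(j, sz)      # generic path keeps the division
        end
        offset += idx * strides[s]
    end
    return offset + j * strides[N]
end

# Example: three segments, the first a power of two, the second not.
example_offset = decode_offset(13, (8, 3, 4), (1, 40, 200))
```

For j = 13 the decoded indices are (5, 1, 0), giving an offset of 5·1 + 1·40 + 0·200 = 45, the same value a div/rem decode would produce.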

shreyas-omkar and others added 14 commits June 17, 2026 20:42
The single reduce/outer segment that canonicalization produces in the common
case needs no division: the index is in-range, so the offset is just j*stride.
Restores the multiply-add inner loop the divmod decode had regressed.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

The 4-way @unroll hurt coalesced stride-1 reductions and gave only noisy,
non-reproducible gains on strided ones (in-process A/B on Metal). Revert to the
plain strided while-loop, matching the pre-existing 1D reduction.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

The old gate fired multi-group whenever dst_size >= block_size, exploding normal
reductions into thousands of blocks + a second pass (e.g. 1000x1000 dims=1).
Only split when there are too few outputs to fill the GPU and the reduction is
large; cap reduce_groups at block_size so the second pass is single-level,
dropping the recursive fallback.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Extract the repeated alloc-or-validate-temp block into _alloc_or_temp (also
adds the backend check the zero-dim case was missing), flatten the GPU path
behind an early return, drop the unused CPU `dims` argument, and rewrite the
stale header/section comments to match the implementation.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@maleadt maleadt force-pushed the sh/mapreduce_nd_dims branch from 9bb4f68 to 4bd1fc1 Compare June 18, 2026 09:18
@maleadt

maleadt commented Jun 18, 2026

Member

Did a bunch of work yesterday, and managed to improve performance across the board some more, as verified on M1/M3/5080/Iris Xe:

| Host/backend | Case | main ms | PR start ms | ratio |
|---|---|---|---|---|
| M3 Metal | wide (3, 1_000_000), dims=1, Float32 | 1.7675 | 1.3185 | 0.746 |
| M3 Metal | wide (3, 1_000_000), dims=2, Float32 | 2.0915 | 0.6807 | 0.325 |
| M3 Metal | square (1024, 1024), dims=1, Float32 | 0.8879 | 0.5768 | 0.650 |
| M3 Metal | square (1024, 1024), dims=2, Float32 | 1.2720 | 1.5223 | 1.197 |
| M1 Metal | wide (3, 1_000_000), dims=1, Float32 | 1.8229 | 1.7677 | 0.970 |
| M1 Metal | wide (3, 1_000_000), dims=2, Float32 | 2.2677 | 0.8643 | 0.381 |
| M1 Metal | square (1024, 1024), dims=1, Float32 | 0.9338 | 0.6399 | 0.685 |
| M1 Metal | square (1024, 1024), dims=2, Float32 | 0.8068 | 0.8229 | 1.020 |
| RTX 5080 CUDA | wide (3, 1_000_000), dims=1, Float32 | 0.9715 | 0.9073 | 0.934 |
| RTX 5080 CUDA | wide (3, 1_000_000), dims=2, Float32 | 0.5900 | 0.0290 | 0.049 |
| RTX 5080 CUDA | square (1024, 1024), dims=1, Float32 | 0.1093 | 0.0221 | 0.202 |
| RTX 5080 CUDA | square (1024, 1024), dims=2, Float32 | 0.0350 | 0.0369 | 1.055 |
| Iris Xe oneAPI | wide (3, 1_000_000), dims=1, Float32 | 2.6961 | 2.8972 | 1.075 |
| Iris Xe oneAPI | wide (3, 1_000_000), dims=2, Float32 | 3.3583 | 1.1255 | 0.335 |
| Iris Xe oneAPI | square (1024, 1024), dims=1, Float32 | 1.1104 | 0.3612 | 0.325 |
| Iris Xe oneAPI | square (1024, 1024), dims=2, Float32 | 0.8543 | 0.8330 | 0.975 |

I also added support for multi-input mapreduce (AK.mapreduce(f, op, A, B, ...; init, dims=...)).

A bit of analysis by Codex:

Known Performance Deficiencies

  • wide_reduce_dim1 on CUDA is the top issue. The current by_thread path gives one thread one
    output and performs a three-element scalar loop. That is simple and portable, but it does not look
    like PyTorch's output-vectorized path. A better kernel would have each thread or lane group produce
    multiple adjacent outputs, use vectorized loads/stores when alignment permits, and avoid treating a
    three-element reduction like a general reduction tree problem.
  • The portable kernels still use scalar loads/stores. Vectorized memory operations should help the
    many-output cases and some tiled-strided cases, provided the implementation can prove alignment and
    avoid hurting non-contiguous inputs.
  • Generic multi-segment reductions such as dims=(1,3) still carry dynamic division code in the
    generated IR because segment sizes are runtime kernel arguments. Specializing segment sizes would
    remove more codegen overhead, but it would also increase compile-cache pressure and method
    multiplicity.
  • Multigroup's second pass is a separate kernel launch. For tiny outputs this is visible overhead.
    Atomics could combine partials in one pass for selected associative and commutative operations,
    but that changes ordering, narrows the supported operation set, and should not be the default
    portable semantics path.
  • The generic TARGET_BLOCKS=256 is intentionally conservative. PyTorch uses hardware-aware target
    grid sizing. A generic bump to 512 was mixed across CUDA and Metal and was rejected, but backend
    packages could choose hardware-specific launch targets.
  • M3 Metal still trails native on several current rows. The existing portable kernels are close, but
    Metal-specific SIMD-group reductions or vectorized memory paths may recover that remaining 10% to
    20%.

Simplification and Specialization

A KernelIntrinsics-style layer would be the best way to simplify the kernel internals without losing
performance. The hand-written shared-memory reduction trees, barriers, subgroup-sized assumptions,
and eventually vectorized load/store decisions could move behind portable collectives such as:

  • workgroup reduction for arbitrary associative operators
  • subgroup or warp reduction for CUDA warps, Metal SIMD-groups, and oneAPI/OpenCL subgroups
  • backend-lowered vectorized loads/stores where alignment and contiguous layout are known
  • possibly scoped atomics for explicitly commutative/associative fast paths

That would let AK keep the four high-level work decompositions while avoiding backend branches inside
AK's generic kernels. It would also give CUDA/Metal/oneAPI backends a place to lower the same source
to warp shuffles, SIMD-group reductions, or subgroup SPIR-V operations.

Backend-specific specialization is worth pursuing, but not in this PR's portable implementation:

  • CUDA could use CUB/CCCL-like block-reduce policies, warp shuffles, vectorized output paths,
    hardware-aware launch sizing, and cuBLAS/GEMV dispatch for pure matrix sums where semantics allow.
  • Metal could use SIMD-group collectives and Metal-specific threadgroup-memory tuning.
  • oneAPI and OpenCL could use subgroup collectives, but the current AK path already beats the tested
    native paths by a large margin, so this is lower priority.

Recommended next work:

  1. Add a CUDA-focused output-vectorized path for small fixed reduction extents, starting with
    (3, N), dims=1.
  2. Add vectorized load/store support for proven-contiguous layouts.
  3. Prototype subgroup/workgroup collectives in a KernelIntrinsics-style layer and replace the local
    reduction trees once the abstraction is stable.
  4. Consider optional backend dispatch for pure + reductions to vendor primitives where the result
    type, order expectations, and shape semantics line up.
  5. Keep atomics out of the generic path unless the API explicitly opts into commutative/associative
    fast semantics.

maleadt and others added 4 commits June 18, 2026 11:58
The stride-based fast-path kernels index the source by a flat linear offset
(`src[offset + 1]`), which is only valid for dense column-major arrays. Until
now any other GPU source was rejected with an ArgumentError, and a Broadcasted
over a non-trivial wrapper (e.g. PermutedDimsArray) failed to compile at all
because `@Const(src)` keeps the wrapper's bounds-throw from being elided.

Replace the hard dense-stride rejection with a route to a new generic kernel
`_mapreduce_nd_generic!`: one thread per output, reducing over the reduced
extents via Cartesian indexing (`src[J]`, `J = max(Iother, Ireduce)`), mirroring
the CPU `_mapreduce_nd_cpu_sections!` path. It makes no layout assumption and
deliberately does not wrap the source in `@Const`, so it compiles for strided
views, adjoints, permuted dims, and broadcasts over them. Dense arrays keep the
fast paths via the `_mapreduce_fastpath_dense` predicate; all Broadcasted
sources currently take the fallback, which also removes the synthetic
`_mapreduce_strides(::Broadcasted)` values that previously misled the
coalescing/tiled dispatch heuristics.

Tests that asserted rejection of strided GPU sources now assert correct results
(strided view and PermutedDimsArray). Verified on Metal (non-dense now correct,
dense fast paths unchanged) and the CPU suite.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
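The `J = max(Iother, Ireduce)` indexing used by the generic fallback can be sketched as a serial CPU loop. This is an illustration of the indexing idea only, assuming one loop iteration stands in for one GPU thread; `mapreduce_generic` is a hypothetical name and not `_mapreduce_nd_generic!` itself.

```julia
# Serial sketch of the generic fallback: one "thread" per output element,
# reducing over the reduced extents via Cartesian indexing.
function mapreduce_generic(f, op, src::AbstractArray; init, dims)
    reduced = ntuple(d -> d in dims, ndims(src))
    # Kept extents have length 1 along reduced dims, and vice versa, so the
    # elementwise max of the two indices is the full source index.
    Rother  = CartesianIndices(ntuple(d -> reduced[d] ? 1 : size(src, d), ndims(src)))
    Rreduce = CartesianIndices(ntuple(d -> reduced[d] ? size(src, d) : 1, ndims(src)))
    dst = fill(init, size(Rother))
    for Iother in Rother
        acc = init
        for Ireduce in Rreduce
            J = max(Iother, Ireduce)
            acc = op(acc, f(src[J]))
        end
        dst[Iother] = acc
    end
    return dst
end

A = reshape(collect(1:24), 2, 3, 4)
P = PermutedDimsArray(A, (3, 1, 2))      # non-dense wrapper, no layout assumption needed
r_perm = mapreduce_generic(abs2, +, P; init=0, dims=2)
```

Because the loop only ever calls `src[J]`, it works for any `AbstractArray`, which is the property the fallback relies on.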
Extend the stride-based fast paths to any source backed by a single dense
column-major buffer — strided views, adjoints, permuted dims, and reshapes —
instead of sending them to the generic Cartesian fallback. The kernels already
do stride arithmetic; they were only limited by indexing the wrapper logically
(`src[offset+1]`), which assumed offset == linear index (dense only).

`_mapreduce_strided_layout` resolves a source to `(buffer, base_offset, strides)`:
the dense buffer to index, the wrapper's element offset within it, and the real
strides. The kernels now index `buffer[base_offset + Σ coordᵈ·strideᵈ + 1]`, with
`base_offset` folded into the per-output base so dense reductions pay at most one
extra add per output (never per reduced element). Passing the dense parent buffer
rather than the wrapper also keeps `@Const` valid, so this compiles for wrappers
(e.g. PermutedDimsArray) that the wrapper-as-`@Const` path cannot. The coalescing
and tiled-strided dispatch heuristics now see the true strides instead of the
fabricated dense strides a Broadcasted source used to supply.

Only `Broadcasted` and sources not backed by a dense buffer (lazy/computed arrays,
nested wrappers, complex adjoints without `strides`) take the generic fallback.

Verified vs Base on CPU (Pkg.test), Metal (M1/M3), and CUDA, covering strided
views with nonzero base offset, adjoints, permuted dims over the by_thread /
by_block / multigroup paths, and multi-input. No measurable dense-path regression
(CUDA/Metal A/B within run-to-run noise); strided sources that previously hit the
generic fallback are now ~8x faster on CUDA (e.g. strided-view dims=2 0.17→0.02 ms).

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…ard)

The earlier "destructure Val dims, drop div/rem intrinsics" cleanup changed the
tiled-strided and multigroup row/lane and output/group decode to
`unsigned(x) ÷ unsigned(const)`. That regressed square `dims=2` (tiled) by ~15-20%
and wide `dims=2` (multigroup) by ~4% on CUDA only.

Root cause is signedness, not the divide. The divide-by-constant folds to shift/and
on every backend (no div/rem instruction, no divide-by-zero check). But the unsigned
*result* flows into the signed index arithmetic (`iout = iblock*rows + row`, …),
forcing a checked `Int(::Unsigned)` conversion and a cold `throw(InexactError)`
guard. GPUCompiler's PTX backend (ptx.jl `lower_unreachable!`) keeps that guard in
the kernel — the compare plus a `call julia_throw_inexacterror` — only turning the
final `unreachable` into a trap; the guard then sits in the latency-bound hot path.
The Metal backend (metal.jl `lower_unreachable_control_flow!`) force-inlines the
throwing function and rewrites unreachable→ret, scrubbing it, so Metal/oneAPI never
regressed.

Fix: wrap the constant-divisor div/rem in `Int(...)` so the result stays signed. The
divide still folds to shift/and, and the signed result avoids the Int-conversion
guard entirely. PTX for the tiled kernel goes 157→134 lines, 3→0 inexact-throw sites.

Confirmed on an RTX 5080 vs tb/mapreduce_wip (the MAPREDUCE reference): square dims=2
0.0207→0.0178 ms, 512² 0.0184→0.0167, wide dims=2 0.0281→0.0280; all dense shapes
now within run-to-run noise of WIP. CPU/Metal/CUDA correctness re-verified.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@maleadt maleadt changed the title Support tuple dims in mapreduce_nd using CartesianIndices Expand dimensional mapreduce / reduce Jun 19, 2026
@maleadt maleadt marked this pull request as ready for review June 19, 2026 06:42
@maleadt maleadt merged commit c485695 into JuliaGPU:main Jun 19, 2026
38 of 39 checks passed