Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 39 additions & 8 deletions docs/source/user_guide/subgroup.md

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

doc looks ok to me

Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,18 @@ The old names remain as deprecated aliases that emit a `DeprecationWarning` on f

### Voting and predicate ops

All three take a `log2_size` template parameter and reduce over each `2**log2_size` group of consecutive lanes, broadcasting the `i32` (`0` or `1`) result to every lane in the group. Same shape as `reduce_all_add` / `inclusive_*` / `exclusive_*`.
`subgroup.ballot` is a single-instruction hardware primitive that always operates over the full subgroup; it does not take a `log2_size`. The remaining three (`all_true` / `any_true` / `all_equal`) take a `log2_size` template parameter and reduce over each `2**log2_size` group of consecutive lanes, broadcasting the `i32` (`0` or `1`) result to every lane in the group. Same shape as `reduce_all_add` / `inclusive_*` / `exclusive_*`.

| Op | CUDA | AMDGPU | SPIR-V (Vulkan / Metal) |
|---------------------------------------------|---------------|--------|-------------------------|
| `subgroup.all_true(predicate, log2_size)` | yes (fast at `log2_size==5`) | yes | yes |
| `subgroup.any_true(predicate, log2_size)` | yes (fast at `log2_size==5`) | yes | yes |
| `subgroup.all_equal(value, log2_size)` | yes (fast at `log2_size==5`, transitively via `all_true`) | yes | yes |
| Op | CUDA | AMDGPU | SPIR-V (Vulkan / Metal) | dtypes / return |
|---------------------------------------------|---------------|--------|-------------------------|----------------------------|
| `subgroup.ballot(predicate)` | yes | yes | yes | `i32` predicate → `u32` bitmask |
| `subgroup.all_true(predicate, log2_size)` | yes (fast at `log2_size==5`) | yes | yes | `i32` predicate → `i32` (0/1) |
| `subgroup.any_true(predicate, log2_size)` | yes (fast at `log2_size==5`) | yes | yes | `i32` predicate → `i32` (0/1) |
| `subgroup.all_equal(value, log2_size)` | yes (fast at `log2_size==5`, transitively via `all_true`) | yes | yes | any value supporting `==` → `i32` (0/1) |

CUDA shortcut: when `log2_size == 5` (full warp), `all_true` / `any_true` lower to a single `__all_sync(0xFFFFFFFF, p)` / `__any_sync(0xFFFFFFFF, p)` (one `vote.all` / `vote.any` instruction). The shortcut is selected at trace time via `qd.static()` on `impl.current_cfg().arch` and the compile-time `log2_size`, so partial-warp uses (and every other backend) cleanly fall back to a portable `shuffle_xor` butterfly with no branch in the emitted IR.
`ballot` lowers to one instruction on every backend (`__ballot_sync` on CUDA, `v_ballot_b32` on AMDGPU, `OpGroupNonUniformBallot` on SPIR-V); see the dedicated semantics section below. The result covers the first 32 lanes; on AMDGPU CDNA wave64 only the low 32 bits are returned, consistent with the `u32` return type.

CUDA shortcut for the log2_size voters: when `log2_size == 5` (full warp), `all_true` / `any_true` lower to a single `__all_sync(0xFFFFFFFF, p)` / `__any_sync(0xFFFFFFFF, p)` (one `vote.all` / `vote.any` instruction). The shortcut is selected at trace time via `qd.static()` on `impl.current_cfg().arch` and the compile-time `log2_size`, so partial-warp uses (and every other backend) cleanly fall back to a portable `shuffle_xor` butterfly with no branch in the emitted IR.

`all_equal` always uses the broadcast-and-`all_true` form: every lane reads the value at the start of its group via `shuffle`, compares it with its own value, and `all_true`-reduces the per-lane equality bit. Cost: `1 + log2_size` shuffles in the portable case, or `1 shuffle + 1 vote.all` on CUDA at full-warp. We deliberately do *not* use `__match_all_sync` even on CUDA: it requires sm_70+, and it does bit-equality on floats, contradicting this op's documented `OpGroupNonUniformAllEqual` semantics (`NaN != NaN`, `+0.0 == -0.0`). Callers wanting bit-equality on floats should bit-cast to the same-width integer dtype before calling.

Expand Down Expand Up @@ -171,6 +174,17 @@ Returns `1` on lane 0 of every subgroup and `0` on every other lane. Useful for
- Caller contract on every backend: call from uniform control flow with all lanes active. Calling either op from divergent control flow has implementation-defined behaviour (CUDA's `nvvm.bar.warp.sync` will deadlock if the mask does not match the active set; AMDGPU's `wave.barrier` is a no-op on most chips so divergent calls silently pass through).
- The legacy names `subgroup.barrier()` and `subgroup.memory_barrier()` are still available as deprecated aliases. They forward to `sync()` / `mem_fence()` and emit a `DeprecationWarning` on first use; prefer the new names in new code.

### `ballot(predicate)`

Each lane evaluates `predicate` (an `i32`; non-zero is true, zero is false) and the result is a `u32` bitmask where bit `i` is set if lane `i`'s predicate was non-zero.

- Returns a `u32`. Bit 0 corresponds to lane 0, bit 1 to lane 1, etc.
- On CUDA, maps to `__ballot_sync(0xFFFFFFFF, predicate)`. On SPIR-V, maps to `OpGroupNonUniformBallot` (component 0 of the uvec4 result). On AMDGPU, maps to the `ballot.i32` intrinsic.
- The result covers the first 32 lanes. On AMDGPU CDNA with 64-wide wavefronts only the low 32 bits are returned; the upper 32 lanes are not represented. This is consistent with the 32-bit return type.
- Must be called from uniform control flow (all active lanes must execute the ballot).

Ballot is a building block for warp-cooperative algorithms: population counts (`popcount(ballot(cond))` counts how many lanes satisfy `cond`), prefix masks, and lane compaction.

### `reduce_add(value, log2_size)`

Sums `value` across `2**log2_size` consecutive lanes via a `shuffle_down` tree. The result is valid **in lane 0** of each group; other lanes hold partial sums and should be considered undefined.
Expand Down Expand Up @@ -245,6 +259,21 @@ def broadcast(a: qd.types.ndarray(dtype=qd.f32, ndim=1)):

After the kernel, every lane in a subgroup holds the original value of its lane 0. `subgroup.broadcast(a[i], qd.u32(0))` is interchangeable here.

### Ballot: count how many lanes satisfy a condition

```python
@qd.kernel
def count_positive(a: qd.types.ndarray(dtype=qd.f32, ndim=1),
counts: qd.types.ndarray(dtype=qd.u32, ndim=1)):
qd.loop_config(block_dim=32)
for i in range(a.shape[0]):
mask = subgroup.ballot(qd.i32(a[i] > 0.0))
if subgroup.invocation_id() == 0:
counts[i // 32] = mask
```

After the kernel, `counts[g]` contains a bitmask of which lanes in group `g` had positive values. Use `popcount(mask)` on the host to get the count.

### Identity shuffle (each lane reads its own id)

Useful as a sanity check:
Expand Down Expand Up @@ -362,6 +391,7 @@ After the call, lane `k` (within each group of 32) holds `a[group_start] + a[gro
- Shuffles are register-to-register on CUDA (`__shfl_sync`, `__shfl_down_sync`, `__shfl_up_sync`) and on SPIR-V where the GPU has hardware support — typically a handful of cycles, no memory traffic.
- AMDGPU `shuffle`, `shuffle_down`, and `shuffle_up` all go through `ds_permute` / `ds_bpermute` today (LDS-routed, roughly tens of cycles).
- `shuffle_xor` and `broadcast_first` are `@qd.func` wrappers over `shuffle` / `broadcast` and inline at trace time, so on every backend they cost exactly the same as the underlying op.
- `ballot` is a single hardware instruction on all backends — one cycle on CUDA (`__ballot_sync`), one instruction on AMDGPU (`v_ballot_b32`), and `OpGroupNonUniformBallot` on SPIR-V.
- `reduce_add` and `reduce_all_add` both issue exactly `log2_size` shuffles and `log2_size` adds per call. No barriers, no shared memory, no launch overhead (they inline).
- Pick `reduce_all_add` over `reduce_add + broadcast` when you need the result in every lane — same cost, one fewer shuffle.
- 64-bit dtypes (`i64`, `u64`, `f64`) are emulated as two 32-bit shuffles on AMDGPU. Prefer 32-bit values when you have a choice.
Expand All @@ -370,4 +400,5 @@ After the call, lane `k` (within each group of 32) holds `a[group_start] + a[gro
## Related

- [tile16](tile16.md) — `Tile16x16` builds on `subgroup.shuffle` to implement register-resident 16×16 matrix tiles.
- `qd.simt.warp.*` — CUDA-only counterparts (`warp.all_nonzero`, `warp.any_nonzero`, `warp.unique`, `warp.ballot`, `warp.match_*`, `warp.active_mask`, ...). The voting ops (`all_nonzero` / `any_nonzero` / `unique`) overlap with the new portable `subgroup.{all_true, any_true}`; the rest stay CUDA-bound. Useful when you need explicit active-mask control or an op that has no portable equivalent yet.
- `subgroup.ballot` — single-instruction u32 bitmask of lanes where the predicate is non-zero (see above).
- `qd.simt.warp.*` — CUDA-only counterparts (`warp.all_nonzero`, `warp.any_nonzero`, `warp.unique`, `warp.match_*`, `warp.active_mask`, ...). The voting ops (`all_nonzero` / `any_nonzero` / `unique`) overlap with the new portable `subgroup.{all_true, any_true}` and `warp.ballot` overlaps with `subgroup.ballot`; the rest stay CUDA-bound. Useful when you need explicit active-mask control or an op that has no portable equivalent yet.
11 changes: 11 additions & 0 deletions python/quadrants/lang/simt/subgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ def elect():
return i32(invocation_id() == 0)


def ballot(predicate):
"""Return a ``u32`` bitmask whose bit ``i`` is set iff lane ``i``'s ``predicate`` is non-zero.

Single hardware instruction on every backend (``__ballot_sync`` on CUDA, ``v_ballot_b32`` on AMDGPU,
``OpGroupNonUniformBallot`` on SPIR-V). The result covers the first 32 lanes; on AMDGPU CDNA wave64 only the low
32 bits are returned, consistent with the ``u32`` return type. Caller contract: uniform CF + all lanes active.
"""
return impl.call_internal("subgroupBallot", predicate, with_runtime_context=False)


# --- Voting / predicate ops ------------------------------------------------------------
#
# All three are group-scoped over ``2**log2_size`` consecutive lanes, mirror the API of ``reduce_all_add`` /
Expand Down Expand Up @@ -443,6 +453,7 @@ def shuffle_down(value, offset):
"barrier",
"memory_barrier",
"elect",
"ballot",
"all_true",
"any_true",
"all_equal",
Expand Down
2 changes: 2 additions & 0 deletions quadrants/codegen/amdgpu/codegen_amdgpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -403,6 +403,8 @@ class TaskCodeGenAMDGPU : public TaskCodeGenLLVM {
llvm_val[stmt] = emit_amdgpu_shuffle_up(
/* value=*/llvm_val[stmt->args[0]],
/* dt=*/stmt->args[0]->ret_type, offset);
} else if (stmt->func_name == "subgroupBallot") {
llvm_val[stmt] = call("amdgpu_ballot_i32", llvm_val[stmt->args[0]]);
} else if (stmt->func_name == "subgroupInvocationId") {
llvm_val[stmt] = call("amdgpu_lane_id");
} else if (stmt->func_name == "subgroupSize") {
Expand Down
2 changes: 2 additions & 0 deletions quadrants/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
/* value=*/llvm_val[stmt->args[0]],
/* dt=*/stmt->args[0]->ret_type,
/* offset=*/llvm_val[stmt->args[1]]);
} else if (stmt->func_name == "subgroupBallot") {
llvm_val[stmt] = call("cuda_ballot_i32", llvm_val[stmt->args[0]]);
} else if (stmt->func_name == "subgroupInvocationId") {
llvm_val[stmt] = call("cuda_lane_id");
} else if (stmt->func_name == "subgroupSize") {
Expand Down
7 changes: 7 additions & 0 deletions quadrants/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1452,6 +1452,13 @@ void TaskCodegen::visit(InternalFuncStmt *stmt) {
auto index = ir_->query_value(stmt->args[1]->raw_name());
val = ir_->make_value(spv::OpGroupNonUniformBroadcast, value.stype,
ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), value, index);
} else if (stmt->func_name == "subgroupBallot") {
auto predicate = ir_->query_value(stmt->args[0]->raw_name());
auto pred_bool =
ir_->make_value(spv::OpINotEqual, ir_->bool_type(), predicate, ir_->int_immediate_number(ir_->i32_type(), 0));
auto ballot_vec = ir_->make_value(spv::OpGroupNonUniformBallot, ir_->v4_u32_type(),
ir_->int_immediate_number(ir_->i32_type(), spv::ScopeSubgroup), pred_bool);
val = ir_->make_value(spv::OpCompositeExtract, ir_->u32_type(), ballot_vec, 0);
} else if (shuffle_ops.find(stmt->func_name) != shuffle_ops.end()) {
auto arg0 = ir_->query_value(stmt->args[0]->raw_name());
auto arg1 = ir_->query_value(stmt->args[1]->raw_name());
Expand Down
3 changes: 3 additions & 0 deletions quadrants/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,9 @@ void IRBuilder::init_pre_defs() {
t_v3_uint_.id = id_counter_++;
ib_.begin(spv::OpTypeVector).add(t_v3_uint_).add_seq(t_uint32_, 3).commit(&global_);

t_v4_uint_.id = id_counter_++;
ib_.begin(spv::OpTypeVector).add(t_v4_uint_).add_seq(t_uint32_, 4).commit(&global_);

t_v4_fp32_.id = id_counter_++;
ib_.begin(spv::OpTypeVector).add(t_v4_fp32_).add_seq(t_fp32_, 4).commit(&global_);

Expand Down
4 changes: 4 additions & 0 deletions quadrants/codegen/spirv/spirv_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,9 @@ class IRBuilder {
SType f32_type() const {
return t_fp32_;
}
SType v4_u32_type() const {
return t_v4_uint_;
}

SType i16_type() const {
return t_int16_;
Expand Down Expand Up @@ -574,6 +577,7 @@ class IRBuilder {
SType t_v2_int_;
SType t_v3_int_;
SType t_v3_uint_;
SType t_v4_uint_;
SType t_v4_fp32_;
SType t_v3_fp32_;
SType t_v2_fp32_;
Expand Down
1 change: 1 addition & 0 deletions quadrants/inc/internal_ops.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ PER_INTERNAL_OP(subgroupBroadcast)
PER_INTERNAL_OP(subgroupShuffle)
PER_INTERNAL_OP(subgroupShuffleDown)
PER_INTERNAL_OP(subgroupShuffleUp)
PER_INTERNAL_OP(subgroupBallot)
PER_INTERNAL_OP(subgroupSize)
PER_INTERNAL_OP(subgroupInvocationId)
// subgroupAdd / subgroupMul / subgroupMin / subgroupMax / subgroupAnd / subgroupOr / subgroupXor removed: use portable
Expand Down
1 change: 1 addition & 0 deletions quadrants/ir/type_system.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,7 @@ void Operations::init_internals() {
POLY_OP(subgroupShuffle, false, Signature({}, {ValueT, !u32}, ValueT));
POLY_OP(subgroupShuffleDown, false, Signature({}, {ValueT, !u32}, ValueT));
POLY_OP(subgroupShuffleUp, false, Signature({}, {ValueT, !u32}, ValueT));
PLAIN_OP(subgroupBallot, u32, false, i32);
PLAIN_OP(subgroupSize, i32, false);
PLAIN_OP(subgroupInvocationId, i32, false);
// subgroupAdd / subgroupMul / subgroupMin / subgroupMax / subgroupAnd / subgroupOr / subgroupXor
Expand Down
1 change: 1 addition & 0 deletions quadrants/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,7 @@ std::unique_ptr<llvm::Module> QuadrantsLLVMContext::module_from_file(const std::
patch_intrinsic("amdgpu_ds_bpermute", llvm::Intrinsic::amdgcn_ds_bpermute);
patch_intrinsic("amdgpu_mbcnt_lo", llvm::Intrinsic::amdgcn_mbcnt_lo);
patch_intrinsic("amdgpu_mbcnt_hi", llvm::Intrinsic::amdgcn_mbcnt_hi);
patch_intrinsic("amdgpu_ballot_w32", llvm::Intrinsic::amdgcn_ballot, true, {llvm::Type::getInt32Ty(*ctx)});

link_module_with_amdgpu_libdevice(module);
patch_amdgpu_kernel_dim("block_dim", llvm::ConstantInt::get(llvm::Type::getInt32Ty(*ctx), 0));
Expand Down
9 changes: 9 additions & 0 deletions quadrants/runtime/llvm/runtime_module/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -990,6 +990,15 @@ i32 amdgpu_mbcnt_hi(i32 mask, i32 base) {
return 0;
}

i32 amdgpu_ballot_w32(bool bit) {
__builtin_trap();
return 0;
}

i32 amdgpu_ballot_i32(i32 predicate) {
return amdgpu_ballot_w32((bool)predicate);
}

i32 amdgpu_lane_id() {
return amdgpu_mbcnt_hi(-1, amdgpu_mbcnt_lo(-1, 0));
}
Expand Down
83 changes: 83 additions & 0 deletions tests/python/test_simt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1282,6 +1282,89 @@ def run_and_check(label, src_values, expected_per_group):
run_and_check("all_nan", [nan] * N, lambda g: 0)


@test_utils.test(arch=qd.gpu)
def test_subgroup_ballot_all_true():
"""Ballot with all lanes voting true should return a full bitmask."""
N = 32
result = qd.field(dtype=qd.u32, shape=N)

@qd.kernel
def foo():
qd.loop_config(block_dim=N)
for i in range(N):
result[i] = subgroup.ballot(1)

foo()

for i in range(N):
assert result[i] != 0, f"lane {i}: ballot returned 0, expected non-zero"


@test_utils.test(arch=qd.gpu)
def test_subgroup_ballot_all_false():
"""Ballot with all lanes voting false should return zero."""
N = 32
result = qd.field(dtype=qd.u32, shape=N)

@qd.kernel
def foo():
qd.loop_config(block_dim=N)
for i in range(N):
result[i] = subgroup.ballot(0)

foo()

for i in range(N):
assert result[i] == 0, f"lane {i}: ballot returned {result[i]}, expected 0"


@test_utils.test(arch=qd.gpu)
def test_subgroup_ballot_even_lanes():
"""Even-numbered lanes vote true; odd lanes vote false."""
N = 32
result = qd.field(dtype=qd.u32, shape=N)

@qd.kernel
def foo():
qd.loop_config(block_dim=N)
for i in range(N):
lane = subgroup.invocation_id()
result[i] = subgroup.ballot(1 - lane % 2)

foo()

mask = result[0]
assert mask & 0x1, "lane 0 should have voted true"
assert not (mask & 0x2), "lane 1 should have voted false"
assert mask & 0x4, "lane 2 should have voted true"
assert not (mask & 0x8), "lane 3 should have voted false"


@test_utils.test(arch=qd.gpu)
def test_subgroup_ballot_popcount():
"""Verify popcount of ballot(1) equals the subgroup size."""
N = 32
ballot_val = qd.field(dtype=qd.u32, shape=N)
sg_size = qd.field(dtype=qd.i32, shape=N)

@qd.kernel
def foo():
qd.loop_config(block_dim=N)
for i in range(N):
ballot_val[i] = subgroup.ballot(1)
sg_size[i] = subgroup.group_size()

foo()

bv = ballot_val[0]
sz = sg_size[0]
actual_popcount = bin(bv).count("1")
expected = min(sz, N)
assert (
actual_popcount == expected
), f"popcount({bv:#x}) = {actual_popcount}, expected {expected} (subgroup size {sz})"


@test_utils.test(arch=qd.gpu)
def test_subgroup_invocation_id_range():
"""Verify invocation IDs are non-negative."""
Expand Down
Loading