Motivation
cuLA's decode kernel uses a pretransposed state layout [V, K] (K-last, i.e., [B*H, V, K]), which is bank-conflict-friendly: with K contiguous in SMEM, threads in a warp reading different V-rows at the same K-offset land on different banks. This layout should be kept — it is efficient for decode.
However, the prefill/chunk kernels (chunk_gated_delta_rule_fwd_h, kda_fwd_prefill) currently output the final state in [K, V] layout (K-first, FLA convention: [N, H, K, V]), so every prefill → decode transition requires a caller-side transpose.
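For concreteness, a runnable sketch of that caller-side workaround (tensor names here are illustrative; the permute pattern is the one this document references in the serving path):

```python
import torch

# Illustrative reconstruction of the caller-side workaround (names are ours).
# Prefill emits FLA-convention [N, H, K, V]; decode wants K-last [N, H, V, K].
N, H, K, V = 1, 2, 128, 128
final_state = torch.randn(N, H, K, V)

# The workaround: .contiguous() materializes a full transpose-copy of the
# state (K * V * 4 = 64 KB per head at K = V = 128 fp32) on every
# prefill -> decode hop.
state_vk = final_state.permute(0, 1, 3, 2).contiguous()

assert state_vk.shape == (N, H, V, K)
assert torch.equal(state_vk[0, 0], final_state[0, 0].T)
```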
This transpose is a full GPU memcpy (64KB per head at K=V=128 fp32), and it happens on every decode step in serving. For B=256, H=64, total transpose traffic is ~1 GB — pure waste.
The goal is to make the prefill kernels optionally output [V, K] state directly, so decode can consume it zero-copy.
Background
chunk_delta_h.py → final_state: [N, H, K, V] (chunk_delta_h.py:2029)
kda_fwd_prefill → final_state (hopper_fused_fwd.py:106)
la_decode.py → state: [B*H, V, K], pretransposed (la_decode.py:86)
Decode tile config: TILE_V=8, TILE_K=128, 4 warps, 2-stage cp.async pipeline — designed for K-contiguous (K-last) access
Tasks
Phase 1: Analysis
Trace all paths that produce final_state in prefill (chunk_gated_delta_rule_fwd_h, kda_fwd_prefill, hopper_fused_fwd) — identify where the [K, V] layout is materialized (kernel-level write pattern vs. post-kernel reshape)
Determine whether the prefill kernel's state accumulation loop naturally writes [K, V] or if it can be trivially reordered to write [V, K] — e.g., swapping the inner/outer loop or transposing the register tile before writeback
Quantify the transpose cost: measure .permute(0,1,3,2).contiguous() latency for typical sizes (B=1..256, H=64, K=V=128, fp32)
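A sketch of how that quantification might be set up (function names are ours; `bench_transpose` assumes a CUDA-enabled PyTorch install):

```python
def state_bytes(B: int, H: int, K: int, V: int, itemsize: int = 4) -> int:
    """Bytes in one [B, H, K, V] state tensor (fp32 by default)."""
    return B * H * K * V * itemsize

# Sanity-check the figures quoted above: 64 KB per head, and ~1 GB of
# transpose traffic at B=256, H=64, K=V=128 fp32.
assert state_bytes(1, 1, 128, 128) == 64 * 1024
assert state_bytes(256, 64, 128, 128) == 1 << 30  # 1 GiB

def bench_transpose(B: int, H: int = 64, K: int = 128, V: int = 128,
                    iters: int = 50) -> float:
    """Mean latency (ms) of the caller-side transpose at a given batch size."""
    import torch  # local import: only needed for the GPU measurement
    s = torch.randn(B, H, K, V, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        s.permute(0, 1, 3, 2).contiguous()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```

Sweeping B over 1..256 with this harness gives the per-step cost the transpose adds in serving.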
Phase 2: Prefill Kernel Modification
Add a transpose_state (or state_layout) parameter to chunk_gated_delta_rule_fwd_h — when enabled, write final_state as [N, H, V, K] instead of [N, H, K, V]
Implement the transposed writeback in the CuTe DSL kernel: either swap the store loop order, or transpose the register tile in-register before the final store (register transpose is free compared to GMEM transpose)
Add the same option to kda_fwd_prefill / Hopper fused path if applicable
Ensure initial_state input still accepts [K, V] (standard FLA convention) — the transpose only applies to the output
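The layout change itself amounts to swapping the store indices in the epilogue. A pure-Python reference for the two writeback orders (the real change lives in the CuTe DSL store loop; `transpose_state` is the proposed flag):

```python
def write_final_state(h_tile, K, V, transpose_state=False):
    """Reference semantics for the final-state writeback.

    h_tile[k][v] models the accumulated state tile held in registers.
    transpose_state=False stores FLA-convention [K, V]; True stores
    K-last [V, K] by swapping the output indices: same elements, only a
    different address computation per store, which is why an in-register
    transpose is essentially free next to a GMEM round trip.
    """
    if not transpose_state:
        return [[h_tile[k][v] for v in range(V)] for k in range(K)]
    return [[h_tile[k][v] for k in range(K)] for v in range(V)]

K, V = 3, 4
h = [[10 * k + v for v in range(V)] for k in range(K)]
out_kv = write_final_state(h, K, V, transpose_state=False)
out_vk = write_final_state(h, K, V, transpose_state=True)
# The two layouts agree element-wise: out_vk[v][k] == out_kv[k][v].
assert all(out_vk[v][k] == out_kv[k][v] for k in range(K) for v in range(V))
```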
Phase 3: End-to-End Integration
Wire the transpose_state flag through the chunk forward orchestration (chunk_fwd.py) so it's exposed to callers
Update decode call sites to consume pretransposed state directly — remove the .permute(0,1,3,2).contiguous() workaround
Update tests to verify prefill [V, K] output feeds directly into decode without transpose
Benchmark the end-to-end prefill → decode pipeline with and without the transpose elimination
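At the shape level, the integration contract the updated tests should pin down can be sketched with stand-in stubs (all names hypothetical; the real entry points are in chunk_fwd.py and the decode call sites):

```python
def chunk_fwd_stub(B, H, K, V, transpose_state=False):
    """Stand-in for the prefill orchestration: shape of the final_state it returns."""
    return (B, H, V, K) if transpose_state else (B, H, K, V)

def decode_state_shape(B, H, K, V):
    """Stand-in for the decode kernel's required K-last state layout."""
    return (B, H, V, K)

# Use K != V so a layout mistake surfaces as a shape mismatch.
B, H, K, V = 2, 4, 128, 64
# With the flag set, prefill output matches what decode expects directly,
# so the .permute(0, 1, 3, 2).contiguous() hop disappears from the path.
assert chunk_fwd_stub(B, H, K, V, transpose_state=True) == decode_state_shape(B, H, K, V)
# Without it, shapes disagree and the caller-side transpose is still needed.
assert chunk_fwd_stub(B, H, K, V) != decode_state_shape(B, H, K, V)
```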
Alternative Approaches
A. V-Last [K, V] Decode Kernel
Instead of changing prefill output, write a new decode kernel that natively consumes [K, V] (V-last / V-contiguous) state. This eliminates the transpose by adapting the decode side.
Recommended Approach
Primary: Prefill transposed writeback (Phase 2 above) — minimal code change, zero decode regression, register transpose is essentially free.
Secondary: If prefill modification proves difficult (e.g., fused Hopper kernel has rigid output layout), fall back to Alternative A (V-Last decode kernel with SMEM swizzle). This is more work but self-contained on the decode side.
Success Criteria
Prefill can output [V, K] state directly via a flag — no post-kernel transpose needed
Decode consumes prefill output zero-copy — no .permute().contiguous() in the serving path
No prefill performance regression from transposed writeback (target: within 1%)
End-to-end latency improvement measurable at large batch sizes (B>=64)