The Lightning Attention decode kernel currently lacks any performance data on H20 — one of the most widely deployed GPUs for inference serving. We need to establish a baseline and optimize, especially for the big batch regime.
H20 hardware specs:
| Spec | H20 |
| --- | --- |
| Architecture | Hopper SM90 |
| SMs | 78 |
| TF32 | 74 TFLOPS |
| HBM | 96 GB HBM3 |
| Memory BW | 4.0 TB/s |
| L2 Cache | 60 MB |
At big batch (B>=64), the current kernel is on par with FLA Triton (~1.0x) on Blackwell. Big batch is the primary inference serving scenario and the core optimization target. The small batch speedup seen on Blackwell is mainly due to lower TVM-FFI launch overhead rather than kernel-level compute advantages.
Current Kernel Architecture
Dispatch:
- B <= 32 → small_batch (8 blocks per (batch, head))
- B > 32 → big_batch (1 block per (batch, head))
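The dispatch rule above can be modeled in a few lines of Python. This is a hypothetical sketch for illustration only: `decode_grid` and its arguments mirror the names in this issue (small_batch/big_batch paths, 8 blocks per state), not an actual API in the codebase.

```python
# Hypothetical model of the dispatch rule; names mirror this issue, not a real API.
def decode_grid(batch: int, heads: int, num_blocks_per_state: int = 8):
    """Return (kernel_path, total_blocks) for a given batch size."""
    if batch <= 32:
        # small-batch path: 8 blocks cooperate on each (batch, head) state
        return "small_batch", batch * heads * num_blocks_per_state
    # big-batch path: one block owns each (batch, head) state
    return "big_batch", batch * heads

print(decode_grid(1, 64))   # -> ("small_batch", 512)
print(decode_grid(64, 64))  # -> ("big_batch", 4096)
```

At B=1, H=64 this yields the 512-block grid referenced in the Phase 3 occupancy discussion.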
Both paths share the same micro-architecture:
- 128 threads / 4 warps per block
- TILE_V=8, TILE_K=128, 2-stage cp.async pipeline
- State [V,K] fp32: GMEM → SMEM → registers, warp shuffle reduction for output
Tasks
Phase 1: H20 Baseline & Profiling
- Run bench_la_decode_vs_fla.py on H20 and publish BENCHMARK_H20.md (covering B=1,4,16,32,64,128,256)
- Profile with ncu for B=64,128,256:
  - DRAM throughput vs. peak 4.0 TB/s — confirm bandwidth utilization
  - SM occupancy / achieved occupancy
  - Warp stall breakdown (memory dependency vs. barrier vs. compute)
  - L2 hit rate — state tile reuse
- Determine the big batch bottleneck: is bandwidth saturated, or is kernel efficiency the limiting factor?
Phase 2: Big batch optimization (B>=64) — primary goal
At big batch, each (batch, head) pair reads + writes a full [V, K] fp32 state (128x128x4 = 64 KB). At B=64, H=64, total state traffic is ~8.6 GB; theoretical minimum at 4.0 TB/s is ~2.15ms. Need to analyze the gap between actual and theoretical time to identify optimization targets.
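A quick back-of-envelope check of the numbers above, as a sketch rather than measured data: the per-(batch, head) state size follows from the tile dimensions, and the theoretical minimum time is a roofline bound (traffic divided by peak DRAM bandwidth).

```python
# Roofline sanity check for the figures quoted above (a sketch, not measurements).
V, K = 128, 128
state_bytes = V * K * 4  # fp32 state tile: 128 x 128 x 4 B = 64 KB
assert state_bytes == 64 * 1024

def min_time_ms(traffic_gb: float, bw_tb_s: float = 4.0) -> float:
    """Bandwidth-bound lower bound on kernel time: traffic / peak DRAM BW."""
    return traffic_gb / (bw_tb_s * 1e3) * 1e3  # GB / (GB/s) -> s -> ms

print(min_time_ms(8.6))  # ~2.15 ms at 4.0 TB/s
```

Any gap between the measured kernel time and this bound is the kernel-efficiency headroom Phase 2 targets.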
K-dimension splitting: FLA Triton uses NK=2 to split the K dimension, each block handling K/2. cuLA currently loads full K=128 per block (vec_size=4, 4x32 lanes). Splitting K can: (1) halve per-block state load, (2) improve L2 locality when two K-half blocks co-schedule on the same SM, (3) tradeoff: output requires cross-block reduction.
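The cross-block reduction tradeoff in (3) can be illustrated with a toy model: each half-block computes a partial output over its K slice, and summing the partials recovers the full output. This is a pure-Python sketch with tiny dimensions, not kernel code; the variable names are illustrative only.

```python
# Toy model of the NK=2 K-split: partial outputs per K half + cross-block sum.
K, V = 4, 3
q = [1.0, 2.0, 3.0, 4.0]
state = [[float(v * K + k) for k in range(K)] for v in range(V)]  # [V, K]

def partial_out(k_lo: int, k_hi: int):
    # one block's contribution: o[v] = sum_k q[k] * state[v][k] over its K slice
    return [sum(q[k] * state[v][k] for k in range(k_lo, k_hi)) for v in range(V)]

full = partial_out(0, K)                                   # unsplit block
halves = [partial_out(0, K // 2), partial_out(K // 2, K)]  # NK=2 split
reduced = [a + b for a, b in zip(*halves)]                 # cross-block reduction
assert reduced == full
```

The reduction is exact (a plain sum), so correctness is not at stake; the cost is the extra inter-block communication (e.g., via global memory atomics or a second kernel).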
Increase warp count (4→8): Currently 4 warps, processing 4 rows per iteration (out of TILE_V=8). With 8 warps, each iteration processes 8 rows, halving V-loop iterations and increasing per-block compute density. Need to verify register pressure and occupancy tradeoff.
Increase TILE_V (8→16 or 32): With TILE_V=8, V=128 requires 16 loop iterations, each incurring cp.async + barrier overhead. Larger TILE_V reduces loop count; tradeoff is increased SMEM usage (TILE_V x TILE_K x 4B per stage).
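The SMEM-vs-loop-count tradeoff above works out as follows (a sketch of the arithmetic, assuming the 2-stage pipeline holds one TILE_V x TILE_K fp32 tile per stage):

```python
# SMEM cost and V-loop count for the TILE_V candidates above.
TILE_K, STAGES, V = 128, 2, 128

def smem_kib(tile_v: int) -> float:
    """Pipeline SMEM: one TILE_V x TILE_K fp32 tile per stage, 2 stages."""
    return tile_v * TILE_K * 4 * STAGES / 1024

for tile_v in (8, 16, 32):
    print(f"TILE_V={tile_v}: {smem_kib(tile_v)} KiB SMEM, "
          f"{V // tile_v} V-loop iterations")
```

TILE_V=8 matches the 16 iterations stated above at 8 KiB of pipeline SMEM; TILE_V=32 cuts the loop to 4 iterations but quadruples SMEM, which may reduce the number of co-resident blocks per SM.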
Vectorize state writeback: Currently state writes are element-wise (gDst[...] = r_h[i], scalar stores). For big batch, state writeback is half of total GMEM traffic. Using vectorized stores (float4) or TMA store can improve write bandwidth utilization.
Evaluate TMA to replace cp.async: SM90 natively supports TMA. State tiles are regular contiguous memory blocks (TILE_V x TILE_K fp32); TMA can replace multiple cp.async instructions with a single descriptor-based load, reducing instruction overhead and potentially improving SMEM fill efficiency.
Phase 3: Small batch evaluation (B<=32)
Lower priority than big batch, but still need to verify basic performance on H20.
Confirm SM utilization at small batch on H20 (78 SMs). At B=1, H=64, grid = 512 blocks, which requires ~7 waves on 78 SMs — should be sufficient.
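The wave estimate above can be checked with a one-liner (a conservative model that assumes one resident block per SM per wave; higher occupancy would reduce the wave count):

```python
# Wave-count sanity check for B=1 on H20 (assumes 1 resident block per SM).
import math

SMS = 78                # H20 SM count
grid = 1 * 64 * 8       # B=1, H=64, NUM_BLOCKS_PER_STATE=8 -> 512 blocks
waves = math.ceil(grid / SMS)
print(grid, waves)      # 512 blocks, 7 waves
```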
Evaluate whether NUM_BLOCKS_PER_STATE needs adjustment for 78 SMs (currently fixed at 8).
Success Criteria
Publish BENCHMARK_H20.md with complete decode performance data
Big batch (B>=64): achieve >= 1.2x speedup over FLA on H20 (currently ~1.0x on Blackwell)
Small batch (B<=16): maintain reasonable performance on H20 (no regression)