Description
FLA has implemented Intracard CP (fla/ops/common/intracard_cp.py) — a technique that splits long sequences into sub-sequences and processes them in parallel to break the sequential bottleneck of the fwd_h state recurrence. cuLA already has a highly optimized CUDA fwd_h kernel (cula/ops/chunk_delta_h.py) using SM10X tcgen05 MMA + register-carry.
This issue is to analyze the full Intracard CP pipeline, identify which components have real CUDA optimization/fusion opportunities with measurable speedup, and implement them.
Current FLA Pipeline Breakdown
When Intracard CP activates (long sequence + varlen + inference mode), the execution flow is:
1. [Python] compute_subseq_len — CPU: compute optimal split length
2. [Python] prepare_subseq_cu_seqlens — CPU: insert split points into cu_seqlens
3. [Python] _precompute_intracard_indices — CPU: compute scatter/gather indices
4. [Python] torch.tensor(...) — CPU→GPU: create cu_seqlens_subseq, cu_seqlens_split tensors
5. [Triton] pre_process_fwd_kernel_merged — GPU: for each sub-seq, compute (h_ext[K,V], M[K,K])
6. [Triton] merge_fwd_bwd_kernel — GPU: chain multiply M @ h + h_ext across sub-seqs
7. [Python] initial_state_expanded[...] — GPU: scatter merged states into expanded tensor
8. [Triton] prepare_chunk_indices — GPU: compute chunk indices for sub-seq cu_seqlens
9. [Triton] chunk_gated_delta_rule_fwd_h — GPU: main fwd_h with corrected initial states
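Steps 1-4 are CPU bookkeeping. As a minimal sketch, the split-point insertion of step 2 amounts to the following, assuming greedy fixed-length splitting (the real compute_subseq_len chooses the split length adaptively; the function below only mirrors the role of the FLA one, it is not its implementation):

```python
def prepare_subseq_cu_seqlens(cu_seqlens, subseq_len):
    """Sketch of step 2: insert split points so no sub-sequence exceeds
    subseq_len; original sequence boundaries are preserved."""
    out = [cu_seqlens[0]]
    for end in cu_seqlens[1:]:
        start = out[-1]           # last emitted boundary
        while end - start > subseq_len:
            start += subseq_len   # insert a split point every subseq_len tokens
            out.append(start)
        out.append(end)           # keep the original sequence boundary
    return out
```

E.g. a 300-token sequence with subseq_len=128 becomes boundaries [0, 128, 256, 300]: two full sub-sequences plus a 44-token tail.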
Optimization Opportunity Analysis
Opportunity 1 (HIGH): CUDA pre_scan kernel — replace step 5
The pre_process_fwd_kernel_merged Triton kernel is structurally almost identical to fwd_h: it iterates over chunks doing the same recurrence (h = decay * h + k^T @ v_new), and additionally computes the transition matrix M via a chain multiply (M = (diag(decay) - k^T @ w) @ M).
cuLA's CUDA fwd_h is already faster than FLA's Triton version on SM10X thanks to register-carry + tcgen05 MMA + TMA pipeline. The same optimization techniques apply directly:
- Register-carry for h_ext accumulation (eliminates the GMEM roundtrip)
- tcgen05 MMA for the k^T @ v_new and k^T @ w matmuls
- TMA async pipeline for k, w, u loads
The M computation ([K,K] matrix, K=128 → 128x128 fp32) is an additional output that can be computed in the same kernel using the CUDA warp group that's already doing the gating/accumulation work.
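The split/merge contract behind (h_ext, M) can be checked with a minimal NumPy reference (per-timestep instead of per-chunk, illustrative names, not the kernel; in the real pipeline the k^T @ w contribution is folded into v_new, while here the transition A_t = decay_t*I - k_t^T w_t is applied to both h and M so that the identity final_state = M @ h0 + h_ext holds exactly for any initial state):

```python
import numpy as np

def pre_scan_reference(k, w, v_new, decay):
    """Toy reference for step 5 on one sub-sequence: with per-step transition
    A_t = decay_t*I - k_t^T w_t, compute h_ext (final state assuming a zero
    initial state) and M (product of the per-step transitions)."""
    T, K = k.shape
    V = v_new.shape[1]
    h_ext = np.zeros((K, V))
    M = np.eye(K)
    for t in range(T):
        A = decay[t] * np.eye(K) - np.outer(k[t], w[t])
        h_ext = A @ h_ext + np.outer(k[t], v_new[t])  # state recurrence
        M = A @ M                                     # transition chain multiply
    return h_ext, M
```

For any real initial state h0, the final state of the sub-sequence is then M @ h0 + h_ext, which is exactly the correction the merge step applies.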
Key question to verify: is pre_scan actually on the critical path? When there are few splits (e.g., 4), step 5 launches one block per sub-seq per head — this may be too small to matter. Profile first.
Opportunity 2 (MEDIUM): Fuse steps 5+6 — pre_scan + merge in single kernel
Currently pre_scan writes (h_ext, M) to GMEM, then merge reads them back. For a typical 4-way split:
- pre_scan output: 4 × H × K × (K+V) fp32 = 4 × H × 128 × 256 × 4 bytes (4 MiB at H=8)
- merge: reads the same, does a small chain multiply
If both are in one persistent kernel, the pre_scan CTA can signal completion via mbarrier, and the merge CTA can consume (h_ext, M) without GMEM roundtrip (via L2 or shared memory between CTAs in the same cluster). However, the merge is sequentially dependent on pre_scan completing for all sub-seqs, so the fusion benefit depends on whether the GMEM traffic is actually the bottleneck vs. the serial dependency.
Opportunity 3 (MEDIUM): Fuse steps 6+7+8 — merge + scatter + chunk_indices
After merge, there are three small operations:
- Scatter merged states into initial_state_expanded (Python indexing, ~4 small copies)
- prepare_chunk_indices — a Triton kernel to compute chunk indices from the new cu_seqlens
- The actual fwd_h kernel reads initial_state_expanded
These are tiny kernels with high launch overhead relative to the work. Options:
- Fuse the scatter into the merge kernel (the merge kernel writes directly to initial_state_expanded at the correct offsets)
- Pre-compute chunk indices on CPU (they are deterministic from cu_seqlens) to eliminate a kernel launch
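The CPU pre-compute option is straightforward; a sketch of a CPU-side equivalent (chunk size and output layout are assumptions here, and this does not reproduce FLA's prepare_chunk_indices, only its deterministic input/output relationship):

```python
def chunk_indices_cpu(cu_seqlens, chunk_size=64):
    """Sketch: enumerate (seq_idx, chunk_idx) pairs deterministically from
    cu_seqlens, one pair per chunk of each (sub-)sequence."""
    out = []
    for i in range(len(cu_seqlens) - 1):
        seqlen = cu_seqlens[i + 1] - cu_seqlens[i]
        num_chunks = -(-seqlen // chunk_size)  # ceil division
        out.extend((i, c) for c in range(num_chunks))
    return out
```

Since the result depends only on cu_seqlens, it can be computed (and cached) alongside the other CPU-side bookkeeping in steps 1-4, removing one small kernel launch from the hot path.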
Opportunity 4 (LOW): Merge kernel CUDA rewrite — step 6
The merge kernel does a chain multiply: h = M_j @ h + h_ext_j for ~4 iterations with [K,K] @ [K,V] matmuls (K=128, V=128). This is a tiny amount of compute (~4 GEMMs over 128×128 matrices). Triton handles this fine; a CUDA rewrite would save kernel launch overhead (~5 μs), but the absolute speedup is small.
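The chain itself is simple enough to state as a NumPy sketch (hypothetical signature; the per-sub-seq (M_j, h_ext_j) pairs come from pre_scan, and the outputs are the corrected initial states that step 7 scatters into initial_state_expanded):

```python
import numpy as np

def merge_reference(h0, Ms, h_exts):
    """Sketch of the merge chain (step 6): the final state of sub-seq j,
    M_j @ h + h_ext_j, is the corrected initial state of sub-seq j+1."""
    init_states = [h0]
    h = h0
    for M, h_ext in zip(Ms, h_exts):
        h = M @ h + h_ext
        init_states.append(h)
    # one corrected initial state per sub-sequence, plus the overall final state
    return init_states[:-1], h
```

The strictly sequential dependency across sub-sequences is visible here, which is why Opp 2's fusion benefit hinges on GMEM traffic rather than parallelism.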
Opportunity 5 (LOW): Eliminate CPU→GPU transfers — step 4
FLA already caches GPU tensors keyed by cu_seqlens identity. The CPU→GPU transfer only happens on cache miss. In steady-state serving (same batch structure repeated), this is already amortized to near-zero.
Recommended Approach
- Profile first: Run the FLA intracard path on representative long-sequence workloads (e.g., 128K+ tokens, H=8, K=V=128) on SM10X. Use nsys to measure time spent in each step. This determines which opportunities are worth pursuing.
- Implement CUDA pre_scan (Opp 1) if profiling confirms it's on the critical path — high confidence of speedup, since the kernel structure matches cuLA's existing optimized fwd_h.
- Fuse merge + scatter (Opp 3) — small effort, eliminates launch overhead.
- Evaluate pre_scan + merge fusion (Opp 2) only if GMEM traffic between them is significant.
Tasks
- Profile the Intracard CP path with nsys to identify the actual bottlenecks
- Benchmark pre_process_fwd_kernel_merged vs. a CUDA version reusing cuLA's ChunkDeltaRuleFwdH infrastructure
- Integrate into chunk_kda_fwd (transparent activation in varlen inference mode)
References
- fla/ops/common/intracard_cp.py
- fla/ops/cp/chunk_delta_h.py
- fla/ops/cp/KCP.md
- cula/ops/chunk_delta_h.py — register-carry + tcgen05 MMA baseline to extend