Identify and implement CUDA optimization opportunities for Intracard CP (single-card sequence splitting) #20

@icavan

Description

FLA has implemented Intracard CP (fla/ops/common/intracard_cp.py) — a technique that splits long sequences into sub-sequences and processes them in parallel to break the sequential bottleneck of the fwd_h state recurrence. cuLA already has a highly optimized CUDA fwd_h kernel (cula/ops/chunk_delta_h.py) using SM10X tcgen05 MMA + register-carry.

This issue is to analyze the full Intracard CP pipeline, identify which components have real CUDA optimization/fusion opportunities with measurable speedup, and implement them.

Current FLA Pipeline Breakdown

When Intracard CP activates (long sequence + varlen + inference mode), the execution flow is:

1. [Python] compute_subseq_len           — CPU: compute optimal split length
2. [Python] prepare_subseq_cu_seqlens    — CPU: insert split points into cu_seqlens
3. [Python] _precompute_intracard_indices — CPU: compute scatter/gather indices
4. [Python] torch.tensor(...)            — CPU→GPU: create cu_seqlens_subseq, cu_seqlens_split tensors
5. [Triton] pre_process_fwd_kernel_merged — GPU: for each sub-seq, compute (h_ext[K,V], M[K,K])
6. [Triton] merge_fwd_bwd_kernel         — GPU: chain multiply M @ h + h_ext across sub-seqs
7. [Python] initial_state_expanded[...]   — GPU: scatter merged states into expanded tensor
8. [Triton] prepare_chunk_indices         — GPU: compute chunk indices for sub-seq cu_seqlens
9. [Triton] chunk_gated_delta_rule_fwd_h  — GPU: main fwd_h with corrected initial states
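
The CPU-side preparation in steps 1–4 boils down to inserting split points into `cu_seqlens`. A minimal sketch of that logic (pure Python; the function name mirrors the FLA step, but the actual layout in `fla/ops/common/intracard_cp.py` may differ):

```python
def prepare_subseq_cu_seqlens(cu_seqlens, subseq_len):
    """Insert split points so every sub-sequence is at most subseq_len long.

    cu_seqlens: cumulative sequence boundaries, e.g. [0, 10000, 30000].
    Hypothetical sketch of FLA's CPU-side preparation; returns the expanded
    boundary list used to drive pre_scan over sub-sequences.
    """
    out = [cu_seqlens[0]]
    for start, end in zip(cu_seqlens, cu_seqlens[1:]):
        pos = start + subseq_len
        while pos < end:
            out.append(pos)   # extra split point inside this sequence
            pos += subseq_len
        out.append(end)       # original sequence boundary is preserved
    return out

print(prepare_subseq_cu_seqlens([0, 10000, 30000], 8192))
# -> [0, 8192, 10000, 18192, 26384, 30000]
```

Each original boundary survives, so per-sequence semantics (resets of the recurrent state) are unchanged; only extra sub-sequence boundaries are added.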

Optimization Opportunity Analysis

Opportunity 1 (HIGH): CUDA pre_scan kernel — replace step 5

The pre_process_fwd_kernel_merged Triton kernel is structurally almost identical to fwd_h: it iterates over chunks running the same recurrence (h = decay * h + k^T @ v_new), and additionally computes the transition matrix M via a chain multiply (M = (diag(decay) - k^T @ w) @ M).
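
As a correctness reference (not an optimization), the per-sub-sequence recurrence can be written in a few lines of NumPy. This sketch assumes a scalar per-chunk decay for simplicity, where the real kernel applies per-key gating (diag(decay)):

```python
import numpy as np

def pre_scan_reference(ks, ws, vs, decays):
    """Unoptimized reference for the pre_scan recurrence over one sub-sequence.

    ks, ws: per-chunk [C, K] arrays; vs: per-chunk [C, V]; decays: scalars
    (simplifying assumption; the real kernel gates per key).
    Returns (h_ext, M): the state assuming a zero initial state, and the
    accumulated [K, K] transition matrix.
    """
    K, V = ks[0].shape[1], vs[0].shape[1]
    h_ext = np.zeros((K, V))
    M = np.eye(K)
    for k, w, v, d in zip(ks, ws, vs, decays):
        v_new = v - w @ h_ext              # delta-rule correction term
        h_ext = d * h_ext + k.T @ v_new    # h = decay * h + k^T @ v_new
        M = (d * np.eye(K) - k.T @ w) @ M  # transition chain multiply
    return h_ext, M
```

The point of the (h_ext, M) pair is that the true state after the sub-sequence is M @ h0 + h_ext for any incoming state h0, which is exactly the identity the merge step exploits.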

cuLA's CUDA fwd_h is already faster than FLA's Triton version on SM10X thanks to register-carry + tcgen05 MMA + TMA pipeline. The same optimization techniques apply directly:

  • Register-carry for h_ext accumulation (eliminate GMEM roundtrip)
  • tcgen05 MMA for k^T @ v_new and k^T @ w matmuls
  • TMA async pipeline for k, w, u loads

M ([K,K], K=128, i.e. a 128x128 fp32 matrix) is an additional output that can be produced in the same kernel by the warp group that is already doing the gating/accumulation work.

Key question to verify: is pre_scan actually on the critical path? When there are few splits (e.g., 4), step 5 launches one block per sub-seq per head — this may be too small to matter. Profile first.

Opportunity 2 (MEDIUM): Fuse steps 5+6 — pre_scan + merge in single kernel

Currently pre_scan writes (h_ext, M) to GMEM, then merge reads them back. For a typical 4-way split:

  • pre_scan output: 4 × H × K × (K+V) fp32 = 4 × H × 128 × 256 × 4 bytes
  • merge: reads the same, does a small chain multiply

If both are in one persistent kernel, the pre_scan CTA can signal completion via mbarrier, and the merge CTA can consume (h_ext, M) without GMEM roundtrip (via L2 or shared memory between CTAs in the same cluster). However, the merge is sequentially dependent on pre_scan completing for all sub-seqs, so the fusion benefit depends on whether the GMEM traffic is actually the bottleneck vs. the serial dependency.
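
For concreteness, a back-of-the-envelope check on the intermediate traffic for the H=8, 4-way-split sizes used in this issue (fp32 throughout):

```python
# Example sizes from this issue; H is an assumed head count.
splits, H, K, V = 4, 8, 128, 128
fp32 = 4                                   # bytes per element
h_ext_bytes = splits * H * K * V * fp32    # per-sub-seq [K, V] states
m_bytes = splits * H * K * K * fp32        # per-sub-seq [K, K] transitions
total = h_ext_bytes + m_bytes              # written by pre_scan, re-read by merge
print(total / 2**20, "MiB")
```

A few MiB of intermediates plausibly stay resident in L2 between the two kernels anyway, which is another reason to profile before committing to the fusion: the win may be latency (launch + sync), not bandwidth.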

Opportunity 3 (MEDIUM): Fuse steps 6+7+8 — merge + scatter + chunk_indices

After merge, there are three small operations:

  1. Scatter merged states into initial_state_expanded (Python indexing, ~4 small copies)
  2. prepare_chunk_indices — a Triton kernel to compute chunk indices from new cu_seqlens
  3. The actual fwd_h kernel reads initial_state_expanded

These are tiny kernels with high launch overhead relative to work. Options:

  • Fuse the scatter into the merge kernel (merge kernel directly writes to initial_state_expanded at the correct offsets)
  • Pre-compute chunk indices on CPU (they're deterministic from cu_seqlens) to eliminate kernel launch
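
The second option is cheap because chunk indices are a pure function of cu_seqlens and the chunk size. A hypothetical CPU-side version, assuming the output layout is one (sequence index, chunk index within sequence) pair per chunk (the real prepare_chunk_indices layout may differ):

```python
def prepare_chunk_indices_cpu(cu_seqlens, chunk_size):
    """Hypothetical CPU replacement for the Triton prepare_chunk_indices:
    emit one (seq_idx, chunk_idx) pair per chunk, in launch order."""
    indices = []
    for seq_idx, (start, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
        num_chunks = -(-(end - start) // chunk_size)   # ceil division
        indices.extend((seq_idx, c) for c in range(num_chunks))
    return indices
```

Since this runs once per (cu_seqlens, chunk_size) pair, it can be computed alongside steps 1–3 and uploaded with the other metadata tensors, eliminating one kernel launch from the critical path.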

Opportunity 4 (LOW): Merge kernel CUDA rewrite — step 6

The merge kernel does a chain multiply, h = M_j @ h + h_ext_j, for ~4 iterations of [K,K] @ [K,V] matmuls (K=128, V=128). This is a tiny amount of compute (four 128x128x128 GEMMs). Triton handles this fine; a CUDA rewrite would save kernel launch overhead (~5 us), but the absolute speedup is small.
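
Written out in NumPy (per head), the merge step's entire workload is the loop below, which also shows why its output is exactly the set of corrected initial states step 7 scatters:

```python
import numpy as np

def merge_states(h0, h_exts, Ms):
    """Chain-combine per-sub-sequence (h_ext, M) pairs.

    h0: [K, V] state entering the first sub-sequence.
    Returns the corrected initial state for each sub-sequence plus the
    final state; each iteration is one [K,K] @ [K,V] GEMM and an add.
    """
    initial_states = []
    h = h0
    for h_ext, M in zip(h_exts, Ms):
        initial_states.append(h)   # state entering this sub-sequence
        h = M @ h + h_ext          # state leaving this sub-sequence
    return initial_states, h
```

For a 4-way split with K = V = 128 this is four 128x128x128 GEMMs, roughly 17 MFLOP, so launch overhead dominates, consistent with rating this opportunity LOW.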

Opportunity 5 (LOW): Eliminate CPU→GPU transfers — step 4

FLA already caches GPU tensors keyed by cu_seqlens identity. The CPU→GPU transfer only happens on cache miss. In steady-state serving (same batch structure repeated), this is already amortized to near-zero.
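
The caching pattern described here can be sketched as a small keyed dict. FLA keys on cu_seqlens identity; this simplified sketch keys on value, and `make_gpu_tensors` stands in for the actual CPU-to-GPU construction:

```python
_cu_seqlens_cache = {}

def cached_gpu_tensors(cu_seqlens_list, make_gpu_tensors):
    """Reuse device-side metadata when the same cu_seqlens recurs
    (steady-state serving). make_gpu_tensors is the expensive
    CPU->GPU construction, invoked only on a cache miss."""
    key = tuple(cu_seqlens_list)
    if key not in _cu_seqlens_cache:
        _cu_seqlens_cache[key] = make_gpu_tensors(cu_seqlens_list)
    return _cu_seqlens_cache[key]
```

With a repeated batch structure, every request after the first is a cache hit, which is why this transfer amortizes to near-zero and the opportunity is rated LOW.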

Recommended Approach

  1. Profile first: Run the FLA intracard path on representative long-sequence workloads (e.g., 128K+ tokens, H=8, K=V=128) on SM10X. Use nsys to measure time spent in each step. This determines which opportunities are worth pursuing.
  2. Implement CUDA pre_scan (Opp 1) if profiling confirms it's on the critical path — high confidence of speedup since the kernel structure matches cuLA's existing optimized fwd_h.
  3. Fuse merge + scatter (Opp 3) — small effort, eliminates launch overhead.
  4. Evaluate pre_scan + merge fusion (Opp 2) only if GMEM traffic between them is significant.

Tasks

  • Profile the full FLA intracard pipeline on SM10X with nsys to identify actual bottlenecks
  • Evaluate pre_scan kernel speedup opportunity: compare FLA Triton pre_process_fwd_kernel_merged vs. a CUDA version reusing cuLA's ChunkDeltaRuleFwdH infrastructure
  • Evaluate fusion opportunities between pre_scan, merge, scatter, and chunk_indices
  • Implement the highest-ROI optimizations based on profiling results
  • Integrate into chunk_kda_fwd (transparent activation in varlen inference mode)
  • Add correctness tests and benchmarks
