
# RFC: Split cuLA into Core Kernel Layer and FLA Transitional Wrapper #25

@icavan


Status: Draft
Created: 2026-04-04

Summary

Split the cula package into two layers:

  1. cula (core) — pure CUDA kernels with zero FLA dependency. FLA calls into this directly via its backend dispatch. This is the long-term deliverable.
  2. cula-fla (transitional wrapper) — a temporary package that exists only during the transition period before cuLA is fully integrated into FLA. It allows early adopters to try cuLA kernels directly with a familiar FLA-like API, without waiting for FLA upstream integration to complete. cula-fla is expected to be deprecated and removed once FLA natively dispatches to cuLA.
```
End state (target):  FLA user API → @dispatch → cula (core kernels)
Transition period:   User → cula-fla → cula (core kernels) + fla (gate, bwd, etc.)
```

Motivation

Current Problem

cuLA imports from FLA at multiple levels:

| Category | Examples | Count |
| --- | --- | --- |
| Utils | tensor_cache, prepare_chunk_indices, prepare_lens, RCP_LN2, autocast_custom_fwd/bwd, input_guard | ~10 imports across 6 files |
| Algorithm logic | kda_gate_chunk_cumsum, recompute_w_u_fwd, chunk_kda_bwd, l2norm_fwd/bwd, chunk_local_cumsum | ~8 imports across 5 files |
| CP / Distributed | FLACPContext, chunk_gated_delta_rule_fwd_h_pre_process, compress_h0 | 3 imports across 2 files |

Today the dependency is one-directional (cula → fla), so no circular import exists. But the end goal is for FLA to call cuLA kernels directly:

```
fla.ops.kda.chunk_kda() → @dispatch('common') → cula.ops.chunk_delta_h.chunk_gated_delta_rule_fwd_h()
```

If cula core still imports fla, this creates a circular dependency at import time. The solution: cula core must have zero FLA imports.
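The hazard can be reproduced in isolation. The sketch below uses two throwaway modules with hypothetical names (`fla_demo`, `cula_demo`, a `helper` function) that stand in for the real packages: if `fla` imports `cula` at module scope while `cula` still imports back from `fla`, the second import hits a partially initialized module and fails.

```python
# Hypothetical two-module reproduction of the fla <-> cula import cycle.
# Module names are illustrative only; the real packages are not involved.
import pathlib
import sys
import tempfile

src = pathlib.Path(tempfile.mkdtemp())

# "fla_demo" imports "cula_demo" at module scope (backend dispatch)...
(src / "fla_demo.py").write_text(
    "import cula_demo\n"
    "def helper():\n"
    "    return 'fla helper'\n"
)
# ...while "cula_demo" still imports back from "fla_demo".
(src / "cula_demo.py").write_text(
    "from fla_demo import helper\n"
)

sys.path.insert(0, str(src))
try:
    import fla_demo  # fla_demo -> cula_demo -> fla_demo (partially initialized)
    cycle_detected = False
except ImportError as exc:
    cycle_detected = True
    print(f"circular import: {exc}")
```

Making cula core FLA-free removes the back edge, so the cycle can never form.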

Architecture

```
┌──────────────────────────────────────────────────────────┐
│                     FLA                                  │
│  (user-facing API, Triton fallback, algorithm logic)     │
│           │                                              │
│     @dispatch('common')                                  │
│     @dispatch('kda')    ← backend dispatch               │
│           │                                              │
│     try: import cula    ← optional dependency            │
└───────────┼──────────────────────────────────────────────┘
            ↓
┌───────────────────────────────┐    ┌──────────────────────────────┐
│         cula (core)           │    │       cula-fla (trial)       │
│  CUDA kernels, zero FLA dep   │    │  Quick-start wrapper for     │
│                               │    │  users to try cuLA directly  │
│  • chunk_delta_h (fwd_h)      │    │                              │
│  • fwd_o                      │    │  • chunk_kda() autograd fn   │
│  • recompute_wu               │    │  • kda_prefill_hopper()      │
│  • lightning_attn (prefill)   │    │  • kda_prefill_blackwell()   │
│  • la_decode                  │    │                              │
│  • C++ kernels (sm90/sm10x)   │    │  depends on: cula + fla      │
│                               │    │  (uses fla for gate, bwd,    │
│  depends on: torch, cutlass,  │    │   l2norm, CP, intra-chunk)   │
│  tvm-ffi, triton              │    │                              │
└───────────────────────────────┘    └──────────────────────────────┘
```

Key insight: FLA already owns the orchestration (gating, intra-chunk WY repr, backward pass, CP coordination). It just needs to call cuLA's optimized CUDA kernels for the compute-heavy inner loops. FLA does this via @dispatch — no adapter layer needed.

cula-fla is a transitional package that bridges the gap during the period when:

  • FLA has not yet added cuLA backend dispatch upstream
  • cuLA has not yet implemented backward pass, gate, and intra-chunk as CUDA kernels
  • Early adopters want to test cuLA's performance improvements immediately

It is today's cula/kda/ extracted into a separate package — it reuses FLA's algorithm logic (gate, backward, etc.) and plugs in cuLA's CUDA kernels. Once FLA integration is complete, users switch from `from cula_fla.kda import chunk_kda` to `from fla.ops.kda import chunk_kda`, and cula-fla gets archived.

Current Dependency Analysis

Files with zero or trivially-removable FLA imports (core-ready)

| File | FLA Imports | Action |
| --- | --- | --- |
| cula/ops/chunk_delta_h.py | prepare_chunk_indices, prepare_lens, tensor_cache | Internalize 3 small utils |
| cula/ops/fwd_o.py | prepare_chunk_indices | Same |
| cula/ops/lightning_attn.py | None | Ready |
| cula/ops/linear_attn.py | None | Ready |
| cula/ops/inv.py | None | Ready |
| cula/ops/recompute_wu.py | None (FLA import only in `__main__` benchmark) | Ready |
| cula/ops/recompute_wu_occ.py | None (FLA import only in `__main__` benchmark) | Ready |
| cula/lightning/la_decode.py | None | Ready |
| cula/utils.py | None | Ready |
| `csrc/**` (C++/CUDA) | None | Ready |

Files that belong in cula-fla (heavy FLA dependency)

| File | FLA Dependencies | Role |
| --- | --- | --- |
| cula/kda/chunk.py | l2norm_fwd/bwd, chunk_kda_bwd, FLACPContext, autocast_custom_fwd/bwd, input_guard | torch.autograd.Function — forward uses cuLA, backward is entirely FLA |
| cula/kda/chunk_fwd.py | kda_gate_chunk_cumsum, chunk_local_cumsum, RCP_LN2, FLACPContext, CP pre-process, compress_h0 | Forward orchestration: gate → intra → fwd_h → fwd_o |
| cula/kda/chunk_intra.py | recompute_w_u_fwd, prepare_chunk_indices, Triton ops (exp2, gather) | Triton intra-chunk kernel (not a CUDA kernel) |
| cula/kda/hopper_fused_fwd.py | l2norm_fwd, kda_gate_chunk_cumsum, chunk_local_cumsum, autocast utils | Hopper fused forward orchestration |
| cula/kda/blackwell_fused_fwd.py | Same as above + kda_gate_fwd | Blackwell fused forward orchestration |

FLA utils to internalize into cula core

| Utility | Lines | Current Location | Action |
| --- | --- | --- | --- |
| prepare_chunk_indices | ~15 | fla.ops.utils.index | Copy to cula/utils.py |
| prepare_lens | ~3 | fla.ops.utils | Copy to cula/utils.py |
| RCP_LN2 / INV_LN2 | 1 | fla.ops.utils.constant | Already exists as INV_LN2 in cula/ops/chunk_delta_h.py |
| tensor_cache | ~30 | fla.utils | Copy to cula/utils.py |

These are stable, trivial utilities unlikely to change. Internalizing avoids any FLA dependency in core.
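For illustration, the two smallest of these can be sketched in a few lines of torch. This is a paraphrase of the behavior described above, not a verbatim copy of FLA's implementation; the signatures are assumed from the call sites in this RFC.

```python
import torch

def prepare_lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
    # Per-sequence lengths from cumulative sequence boundaries.
    return cu_seqlens[1:] - cu_seqlens[:-1]

def prepare_chunk_indices(cu_seqlens: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # One (sequence_idx, chunk_idx) row per chunk across the packed batch,
    # so a kernel grid can map a flat chunk id back to its sequence.
    n_chunks = (prepare_lens(cu_seqlens) + chunk_size - 1) // chunk_size
    chunk_idx = torch.cat([torch.arange(n) for n in n_chunks.tolist()])
    seq_idx = chunk_idx.eq(0).cumsum(0) - 1  # new sequence starts at chunk 0
    return torch.stack([seq_idx, chunk_idx], dim=1).to(cu_seqlens)
```

For example, `cu_seqlens = [0, 5, 12]` with `chunk_size = 4` yields lengths `[5, 7]` and chunk rows `[[0, 0], [0, 1], [1, 0], [1, 1]]`.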

Proposed Package Structure

cula (core) — pip install cula

```
cula/
├── __init__.py              # exports kernel functions
├── _version.py
├── utils.py                 # device detection, prepare_chunk_indices,
│                            # prepare_lens, tensor_cache, constants
├── ops/
│   ├── __init__.py
│   ├── chunk_delta_h.py     # fwd_h kernel (CuTe DSL, SM10X)
│   ├── fwd_o.py             # fwd_o kernel (CuTe DSL, SM10X)
│   ├── recompute_wu.py      # WU recomputation (CuTe DSL)
│   ├── inv.py               # matrix inverse
│   ├── lightning_attn.py    # Lightning Attention prefill
│   └── linear_attn.py       # Lightning Attention varlen
├── lightning/
│   ├── __init__.py
│   └── la_decode.py         # Lightning Attention decode
└── cudac/                   # C++/CUDA extension (sm90/sm100/sm103)
```

Dependencies: torch, nvidia-cutlass-dsl, apache-tvm-ffi, triton

Zero FLA dependency. Every function is a pure kernel: tensors in, tensors out.

cula-fla (transitional wrapper) — pip install cula-fla

Lifecycle: This package is a transitional artifact. It exists because cuLA's FLA integration is not yet complete — FLA has not yet added cuLA backend dispatch, and cuLA has not yet implemented all required kernels (backward pass, gate, intra-chunk) to be self-sufficient. Once FLA natively dispatches to cuLA core, cula-fla will be deprecated with a notice pointing users to `from fla.ops.kda import chunk_kda` instead.

```
cula_fla/
├── __init__.py
└── kda/
    ├── __init__.py            # exports: chunk_kda, kda_prefill_hopper, etc.
    ├── chunk.py               # torch.autograd.Function (fwd: cuLA, bwd: FLA)
    ├── chunk_fwd.py           # forward orchestration
    ├── chunk_intra.py         # Triton intra-chunk (FLA's recompute_w_u_fwd)
    ├── hopper_fused_fwd.py    # Hopper fused forward
    └── blackwell_fused_fwd.py # Blackwell fused forward
```

Dependencies: cula, flash-linear-attention

This is what cula/kda/ is today — moved into a separate package. Usage:

```python
# Quick-start: try cuLA KDA directly (same interface as fla.ops.kda)
from cula_fla.kda import chunk_kda
o, final_state = chunk_kda(q, k, v, g, beta, ...)
```

FLA integration (in FLA's repo, not cuLA's)

FLA registers cuLA as a backend via its existing BackendRegistry. This lives in FLA's codebase:

```python
# fla/ops/kda/backends/cula.py (in FLA repo)
from fla.ops.backends import BaseBackend

class CuLAKDABackend(BaseBackend):
    backend_type = "cula"
    package_name = "cula"       # is_available() → checks if cula is installed
    env_var = "FLA_CULA"        # FLA_CULA=0 to disable
    priority = 3                # higher than default Triton (5)

    def chunk_gated_delta_rule_fwd_h_verifier(self, k, **kw):
        from cula.utils import is_blackwell
        if not is_blackwell(k.device):
            return False, "cuLA fwd_h requires Blackwell GPU"
        return True, None

    def chunk_gated_delta_rule_fwd_h(self, k, w, u, **kw):
        from cula.ops.chunk_delta_h import chunk_gated_delta_rule_fwd_h
        return chunk_gated_delta_rule_fwd_h(k=k, w=w, u=u, **kw)
```

No import cycle: FLA imports cula.ops.* (core) which has no FLA dependency.

FLA can add more dispatch points as cuLA implements more kernels:

```python
# Future: FLA dispatches fwd_o to cuLA
class CuLAFwdOBackend(BaseBackend):
    def chunk_gla_fwd_o(self, q, v, g, A, h, o, **kw):
        from cula.ops.fwd_o import chunk_gla_fwd_o
        return chunk_gla_fwd_o(q=q, v=v, g=g, A=A, h=h, o=o, **kw)
```
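The verifier-plus-priority selection can be illustrated with a minimal registry sketch. This is hypothetical glue code, not FLA's actual BackendRegistry or @dispatch internals; the `blackwell` keyword stands in for a real device check.

```python
# Minimal sketch of priority-based backend selection: lower priority number
# wins; a backend is skipped if unavailable or if its verifier rejects the call.
class TritonFallback:
    priority = 5
    def is_available(self):
        return True
    def verify(self, **kw):
        return True, None  # fallback accepts anything
    def chunk_gated_delta_rule_fwd_h(self, **kw):
        return "triton"

class CuLABackend:
    priority = 3  # preferred over the Triton fallback
    def is_available(self):
        return True  # real code would check that cula is importable
    def verify(self, blackwell=False, **kw):
        if not blackwell:
            return False, "cuLA fwd_h requires Blackwell GPU"
        return True, None
    def chunk_gated_delta_rule_fwd_h(self, **kw):
        return "cula"

def dispatch(backends, fn_name, **kw):
    # Try backends in priority order; fall through on rejection.
    for b in sorted(backends, key=lambda b: b.priority):
        if not b.is_available():
            continue
        ok, _reason = b.verify(**kw)
        if ok:
            return getattr(b, fn_name)(**kw)
    raise RuntimeError("no backend accepted the call")

backends = [TritonFallback(), CuLABackend()]
print(dispatch(backends, "chunk_gated_delta_rule_fwd_h", blackwell=True))   # cula
print(dispatch(backends, "chunk_gated_delta_rule_fwd_h", blackwell=False))  # triton
```

The point of the verifier split is that rejection is silent: on non-Blackwell hardware the call falls through to Triton with no user-visible change.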

Migration Plan

Phase 1: Internalize utils — make cula core FLA-free

Low-risk refactor within the current single-package structure:

  1. Copy prepare_chunk_indices, prepare_lens, tensor_cache into cula/utils.py
  2. Update imports in cula/ops/chunk_delta_h.py, cula/ops/fwd_o.py to use cula.utils
  3. Verify: all files under cula/ops/ and cula/lightning/ have zero fla.* imports
  4. Keep cula/kda/ unchanged — still imports FLA, still works as today

No user-facing change. The cula/ops/ and cula/lightning/ modules are now independently importable without FLA installed.
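The verification step can be automated with a small AST-based check suitable for CI. A sketch (the directory layout is assumed from this RFC; function names are illustrative):

```python
# Report any `import fla...` / `from fla... import ...` in Python sources,
# so CI can assert that cula core stays FLA-free after Phase 1.
import ast
import pathlib

def fla_imports(path: pathlib.Path) -> list[str]:
    hits = []
    for node in ast.walk(ast.parse(path.read_text())):
        if isinstance(node, ast.Import):
            hits += [a.name for a in node.names if a.name.split(".")[0] == "fla"]
        elif isinstance(node, ast.ImportFrom):
            if node.module and node.module.split(".")[0] == "fla":
                hits.append(node.module)
    return hits

def check_tree(root: str) -> dict[str, list[str]]:
    # Map each offending file to the fla modules it imports.
    return {
        str(p): imports
        for p in pathlib.Path(root).rglob("*.py")
        if (imports := fla_imports(p))
    }
```

Running `check_tree` over cula/ops/ and cula/lightning/ should return an empty dict once Phase 1 lands.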

Phase 2: Extract cula-fla package

  1. Move cula/kda/ contents into a new cula-fla package (separate repo or subdirectory)
  2. cula-fla depends on cula (core) + fla
  3. cula retains cula/ops/, cula/lightning/, cula/utils.py
  4. Publish both to PyPI: pip install cula and pip install cula-fla

Users who currently do `from cula.kda import chunk_kda` switch to `from cula_fla.kda import chunk_kda`. Breaking change, acceptable at pre-1.0.

Phase 3: FLA backend registration

Work with FLA-org to add cuLA backend registration in FLA's codebase:

  1. FLA adds fla/ops/kda/backends/cula.py (or fla/ops/common/backends/cula.py)
  2. @dispatch decorated functions in FLA automatically try cuLA when installed
  3. FLA users get cuLA acceleration with zero code changes:
    pip install cula  # just install the core package
    from fla.ops.kda import chunk_kda  # automatically uses cuLA kernels on Blackwell

Phase 4: Expand cuLA core, shrink cula-fla

As cuLA implements more CUDA kernels, they move into core and FLA adds dispatch points:

| Kernel | Today | Future |
| --- | --- | --- |
| fwd_h (chunk_delta_h) | cuLA core | cuLA core |
| fwd_o | cuLA core | cuLA core |
| recompute_wu | cuLA core (WIP) | cuLA core |
| chunk_intra | cula-fla (uses FLA Triton) | cuLA core (CUDA) |
| gate_cumsum | cula-fla (uses FLA Triton) | cuLA core (CUDA) |
| backward | cula-fla (uses FLA Triton) | cuLA core (CUDA) |
| l2norm | cula-fla (uses FLA) | cuLA core (CUDA) |
| Lightning kernels | cuLA core | cuLA core |

End state: cula-fla is deprecated and archived. Users keep `from fla.ops.kda import chunk_kda`, which transparently dispatches to cuLA CUDA kernels. The transitional wrapper has served its purpose and is no longer needed.

Alternatives Considered

1. No split — just vendor FLA utils

Copy all needed FLA code (gate, backward, intra-chunk, l2norm) into cuLA.

Rejected:

  • Duplicates ~2000 lines of algorithm logic with ongoing maintenance burden
  • Backward pass alone is ~500 lines of complex Triton code that evolves with FLA
  • Doesn't scale as cuLA adds GDN, GLA, and other algorithms

2. Single package with optional FLA extras (cula[fla])

Keep everything in one package, gate FLA-dependent code behind try: import fla.

Rejected:

  • cula/kda/ would fail at import time without FLA — confusing to users
  • Makes it unclear which parts of cuLA work standalone vs. require FLA
  • Circular dependency still exists when FLA tries to import cula

3. Monorepo (merge cuLA into FLA)

Rejected for now due to:

  • cuLA uses CuTe DSL + CUTLASS C++ — fundamentally different build system from FLA's pure Triton
  • Different contributor pools

However, cuLA plans to migrate all its kernels fully to CuTe DSL, phasing out the current CUTLASS C++ code. Once cuLA completes this transition, the build system and toolchain gap between the two projects shrinks significantly — both would use Python-based kernel authoring (Triton for FLA, CuTe DSL for cuLA) with similar packaging and CI requirements. At that point, merging cuLA into FLA as a monorepo becomes a viable and potentially preferable option — a single repo with unified CI, shared infrastructure, and a single pip install fla that includes both Triton fallbacks and optimized CUDA kernels. This should be revisited with the FLA-org maintainers once cuLA's CuTe DSL migration is further along.

Open Questions

  1. Package naming: cula-fla or keep it as a subpackage within the same repo (cula/contrib/fla_wrapper/)?
    • Separate package is cleaner for dependency management
    • Same repo subdirectory is easier to maintain during rapid development
