Phase 6: Fused Metal paged-attention kernel #123

@inureyes

Description

Context

Gather-then-SDPA (Phase 2) re-materializes a contiguous K/V on every decode step, so its cost grows with context length. Phase 0 (ADR 0001) confirmed that this cost is material at long or batched context, so this issue replaces the gather with a fused Metal kernel that reads the scattered blocks directly.

Phase 0 outcome (ADR 0001, #117): the gather-then-SDPA cost is material at long or batched context. For a single sequence, Layout-A overhead over contiguous SDPA is ~56% at 16k and ~67% at 32k. Under batch 4 it is already ~48% at 1k (2-3x past 4k), because batched SDPA stays cheap while the per-step gather scales with the sequence count. This kernel is therefore the planned optimization for the >=16k or batched-decode regime. Until it lands, the #119 gather path is the active path, and it remains the correctness reference and fallback. See docs/adr/0001-paged-attention-gather-vs-fused-kernel.md.
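
To illustrate the difference between the two paths, here is a minimal pure-Python sketch for a single head and a single query. It is not the Metal kernel and not the repository's code: the function names, the list-of-lists pool layout (`pool[block_id][slot]` is one vector), and the use of an online softmax to walk the blocks are all assumptions chosen for illustration. The gather path first copies the sequence's blocks into contiguous lists, while the fused path reads each block through the block table and never builds that copy.

```python
import math

def sdpa_decode(q, keys, values):
    """Single-query attention over contiguous K/V (lists of vectors)."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total
            for d in range(dim)]

def gather_then_sdpa(q, k_pool, v_pool, block_table, seq_len):
    """Reference path: copy scattered blocks into contiguous K/V, then attend."""
    keys, values = [], []
    for block_id in block_table:
        keys.extend(k_pool[block_id])    # this per-step copy grows with context
        values.extend(v_pool[block_id])
    return sdpa_decode(q, keys[:seq_len], values[:seq_len])

def fused_paged_decode(q, k_pool, v_pool, block_table, seq_len, block_size):
    """Fused-style path: read blocks in place, accumulate with an online softmax."""
    scale = 1.0 / math.sqrt(len(q))
    acc = [0.0] * len(v_pool[0][0])
    running_max = -math.inf
    denom = 0.0
    for logical_block, block_id in enumerate(block_table):
        start = logical_block * block_size
        valid = min(block_size, seq_len - start)  # last block may be partial
        if valid <= 0:
            break
        for slot in range(valid):
            k = k_pool[block_id][slot]
            v = v_pool[block_id][slot]
            score = scale * sum(a * b for a, b in zip(q, k))
            new_max = max(running_max, score)
            # Rescale what was accumulated so far to the new running maximum.
            correction = math.exp(running_max - new_max)
            weight = math.exp(score - new_max)
            denom = denom * correction + weight
            acc = [a * correction + weight * x for a, x in zip(acc, v)]
            running_max = new_max
    return [a / denom for a in acc]
```

Both functions should agree up to floating-point rounding for the same pool, block table, and sequence length, which is the property that lets the gather path serve as the correctness reference.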

Tasks

  • Implement a paged decode attention Metal kernel under src/lib/mlx-cpp/turbo/, modeled on sparse_v_sdpa.metal, consuming Q + block table + the pool tensors.
  • Gate it behind use_native_paged_kernel plus a feature/env flag; keep gather-then-SDPA as the fallback path.
  • Follow the kernel re-validation discipline in CLAUDE.md: compile on Apple Silicon and confirm RMS < 5e-3 vs the graph reference over 200 steps.
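
The gating in the second task can be sketched as a small selection function. This is only an illustration in Python: `use_native_paged_kernel` comes from the task list above, but the environment variable name `EXAMPLE_NATIVE_PAGED_KERNEL`, the `kernel_available` argument, and the function itself are hypothetical, and the real flag name and build-time feature mechanism are not specified in this issue.

```python
import os

# Hypothetical environment flag name, used only for this sketch.
ENV_FLAG = "EXAMPLE_NATIVE_PAGED_KERNEL"

def select_decode_path(use_native_paged_kernel, kernel_available, environ=None):
    """Return "fused" only when every gate is open, otherwise "gather".

    use_native_paged_kernel: the config switch named in the task list.
    kernel_available: assumed to report whether the kernel compiled and loaded.
    environ: mapping to read the flag from (defaults to os.environ).
    """
    env = os.environ if environ is None else environ
    env_enabled = env.get(ENV_FLAG, "0") == "1"
    if use_native_paged_kernel and env_enabled and kernel_available:
        return "fused"
    # Any closed gate falls back to the gather-then-SDPA path.
    return "gather"
```

Defaulting to the gather path whenever any gate is closed keeps the existing behavior unchanged for users who have not opted in.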

Acceptance criteria

  • The fused path matches the reference within RMS < 5e-3 and removes the per-step gather copy.
  • Decode throughput at 16k context is within the target gap of the dense backend established in the Phase 0 ADR.
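
The RMS criterion can be expressed as a small check. The sketch below assumes the threshold applies to each decode step separately and that each step's output can be flattened to a list of floats; the issue does not say whether RMS is taken per step or over all 200 steps together, and `rms_error`, `validate_over_steps`, and the two step callables are hypothetical names rather than existing test helpers.

```python
import math

def rms_error(actual, expected):
    """Root-mean-square difference between two equal-length float sequences."""
    if len(actual) != len(expected):
        raise ValueError("outputs must have the same length")
    squared = sum((a - e) ** 2 for a, e in zip(actual, expected))
    return math.sqrt(squared / len(actual))

def validate_over_steps(fused_step, reference_step, steps=200, tol=5e-3):
    """Compare fused and reference outputs step by step.

    fused_step / reference_step: callables taking a step index and returning
    that step's flattened output (assumed interface for this sketch).
    Returns (passed, worst_rms), where passed means every step had RMS < tol.
    """
    worst = 0.0
    for step in range(steps):
        worst = max(worst, rms_error(fused_step(step), reference_step(step)))
    return worst < tol, worst
```

Reporting the worst per-step RMS alongside the pass/fail result makes it easier to see how close a kernel change is to the 5e-3 limit.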

Dependencies

Blocked by Phase 2; gated by the Phase 0 kernel-strategy decision.

Part of #116

Metadata

Assignees

No one assigned

    Labels

    area:core - mlxcel-core: MLX FFI, primitives, KV cache, layers
    platform:macos - macOS (Apple Silicon) specific
    status:backlog - In the backlog, not yet ready
    type:performance - Performance improvements
