[RFC] Geometric Sparse Attention Layer (AETHER) #133

@teerthsharma

Proposal

I propose adding GeometricSparseAttention, a new modular attention layer that enables data-dependent, mathematically safe sparsity for long-context inference.

Unlike static sparse patterns (e.g., Sliding Window, BigBird), which fix the sparsity layout ahead of time, this layer uses AETHER (Adaptive Event-driven Threshold Hybrid Entangled Rendering) logic to prune key blocks at runtime based on the geometry of the keys, summarized per block by a centroid and a radius.

Motivation

Current attention mechanisms in Penzai (pz.nn.Attention) scale quadratically, $O(N^2)$, in the sequence length $N$. While Penzai excels at model surgery and interpretability, analyzing long-context behavior (100k+ tokens) is currently computationally prohibitive.

By introducing a geometric decision gate, we can enable Penzai users to:

  1. Scale: Run inference on massive contexts at a fraction of the dense attention cost, since exact scores are computed only for key blocks that pass the gate.
  2. Inspect: Visualize the "Manifold Mask" in Treescope to see which key blocks the model treats as important for each query.

Technical Approach

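The core logic relies on a Cauchy-Schwarz upper bound to guarantee safety, meaning that no key in a skipped block could have scored above the threshold. For a query $q$ and a key block $B$ with centroid $\mu$ and radius $r$ (the largest distance from $\mu$ to any key in $B$):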

$$\max_{k \in B} (q \cdot k) \le q \cdot \mu + \lVert q \rVert \cdot r$$

If this upper bound is below the threshold $\tau$, the block is skipped.
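
A short derivation of the bound above, as a sketch. It assumes $r = \max_{k \in B} \lVert k - \mu \rVert$ with $\lVert \cdot \rVert$ the Euclidean norm, which the proposal does not state explicitly. Writing each key as the centroid plus an offset and applying Cauchy-Schwarz to the offset term:

$$
\begin{aligned}
q \cdot k &= q \cdot \mu + q \cdot (k - \mu) \\
&\le q \cdot \mu + \lVert q \rVert \, \lVert k - \mu \rVert \\
&\le q \cdot \mu + \lVert q \rVert \, r \qquad \text{for every } k \in B.
\end{aligned}
$$

Taking the maximum over $k \in B$ gives the stated inequality. Consequently, if $q \cdot \mu + \lVert q \rVert \, r < \tau$, then $q \cdot k < \tau$ for every key in the block, so skipping it cannot discard a score at or above $\tau$.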

Proposed API

The layer would follow the pz.nn.Layer interface and fully support NamedArray for axis safety.

from penzai import pz

@pz.pytree_dataclass
class GeometricSparseAttention(pz.nn.Layer):
    """A geometric sparse attention layer compatible with pz.select()."""

    block_size: int = 64     # number of keys per block
    threshold: float = 0.15  # tau: blocks whose upper bound is below this are skipped
    # ... implementation details ...

    def __call__(self, query, key, value, mask=None):
        # 1. Compute block centroids/radii (JAX-friendly reshape)
        # 2. Compute upper-bound scores: q . mu + |q| * r
        # 3. Create boolean mask (True where upper bound >= threshold)
        # 4. Apply masked attention
        raise NotImplementedError  # API sketch only
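
The four steps in the stub above can be sketched in plain JAX as follows. This is a minimal dense reference under stated assumptions, not the proposed Penzai layer: the function names block_upper_bounds and geometric_sparse_attention are hypothetical, the arrays are unnamed and single-head, queries are assumed to be already scaled, the number of keys is assumed to be divisible by block_size, and the radius is taken to be the largest key-to-centroid distance in each block.

```python
import jax.numpy as jnp


def block_upper_bounds(query, key, block_size):
    """Upper bound on q . k for every (query, key block) pair.

    Hypothetical helper. query: [num_queries, dim], key: [num_keys, dim].
    Assumes num_keys is divisible by block_size.
    """
    num_keys, dim = key.shape
    num_blocks = num_keys // block_size

    # Step 1: block centroids and radii via a reshape.
    key_blocks = key.reshape(num_blocks, block_size, dim)
    centroids = key_blocks.mean(axis=1)                      # [num_blocks, dim]
    offsets = key_blocks - centroids[:, None, :]
    radii = jnp.linalg.norm(offsets, axis=-1).max(axis=1)    # [num_blocks]

    # Step 2: q . mu + ||q|| * r for each query and block.
    query_norms = jnp.linalg.norm(query, axis=-1, keepdims=True)
    return query @ centroids.T + query_norms * radii[None, :]


def geometric_sparse_attention(query, key, value, block_size=64, threshold=0.15):
    """Dense reference: masks out key blocks whose upper bound is below threshold."""
    upper_bounds = block_upper_bounds(query, key, block_size)

    # Step 3: boolean mask per block, expanded to one entry per key.
    keep_block = upper_bounds >= threshold                   # [num_queries, num_blocks]
    keep_key = jnp.repeat(keep_block, block_size, axis=1)    # [num_queries, num_keys]

    # Step 4: masked softmax attention. Queries whose blocks are all pruned
    # get an all-zero output here (an assumption, the proposal does not say).
    scores = jnp.where(keep_key, query @ key.T, -jnp.inf)
    row_max = scores.max(axis=-1, keepdims=True)
    row_max = jnp.where(jnp.isfinite(row_max), row_max, 0.0)
    unnormalized = jnp.exp(scores - row_max)
    denom = unnormalized.sum(axis=-1, keepdims=True)
    weights = unnormalized / jnp.where(denom > 0, denom, 1.0)
    return weights @ value
```

Because this version still computes every score before masking, it illustrates the gating logic rather than the compute savings; a real implementation would need to avoid touching pruned blocks at all. Setting the threshold to negative infinity keeps every block and should reproduce ordinary softmax attention, which makes a convenient correctness check.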
