Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
212 changes: 151 additions & 61 deletions fast_llm/layers/ssm/gdn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,85 +38,175 @@ def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor:


@torch.compile
def _torch_chunk_gated_delta_rule_single(
    query: torch.Tensor,  # batch, sequence, heads, key_head_dim
    key: torch.Tensor,  # batch, sequence, heads, key_head_dim
    value: torch.Tensor,  # batch, sequence, heads, value_head_dim
    g: torch.Tensor,  # batch, sequence, heads (log decay rates)
    beta: torch.Tensor,  # batch, sequence, heads (write gate strengths)
    chunk_size: int = 64,
    initial_state: torch.Tensor | None = None,  # batch, heads, key_head_dim, value_head_dim
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Chunk-wise gated delta rule for a single (un-packed) sequence.

    Computes the gated delta-rule recurrence in chunks of `chunk_size` steps:
    within a chunk the update is materialized as a triangular transformation,
    across chunks a recurrent state of shape (batch, heads, key_head_dim,
    value_head_dim) is decayed and updated.

    Returns:
        output: batch, sequence, heads, value_head_dim, in the input dtype.
        recurrent_state: final state if `output_final_state`, else None.
    """
    input_dtype = query.dtype

    if use_qk_l2norm_in_kernel:
        query = _l2norm(query, dim=-1, eps=1e-6)
        key = _l2norm(key, dim=-1, eps=1e-6)

    # Transpose to head-first layout and upcast for numerical stability.
    # batch, sequence, heads, dim -> batch, heads, sequence, dim
    query, key, value, beta, g = (
        x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g)
    )

    batch_size, num_heads, sequence_length, key_head_dim = key.shape
    value_head_dim = value.shape[-1]

    # Pad sequence length to a multiple of chunk_size.
    pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size
    query = torch.nn.functional.pad(query, (0, 0, 0, pad_size))
    key = torch.nn.functional.pad(key, (0, 0, 0, pad_size))
    value = torch.nn.functional.pad(value, (0, 0, 0, pad_size))
    beta = torch.nn.functional.pad(beta, (0, pad_size))
    g = torch.nn.functional.pad(g, (0, pad_size))
    padded_sequence_length = sequence_length + pad_size
    num_chunks = padded_sequence_length // chunk_size

    # Standard attention-style scaling of the queries.
    query = query * (key_head_dim**-0.5)

    # Beta-weighted keys and values for the delta rule write operations.
    key_beta = key * beta.unsqueeze(-1)  # batch, heads, sequence, key_head_dim
    value_beta = value * beta.unsqueeze(-1)  # batch, heads, sequence, value_head_dim

    # Reshape into chunks: batch, heads, num_chunks, chunk_size, dim
    query, key, value, key_beta, value_beta = (
        x.reshape(batch_size, num_heads, num_chunks, chunk_size, x.shape[-1])
        for x in (query, key, value, key_beta, value_beta)
    )
    g = g.reshape(batch_size, num_heads, num_chunks, chunk_size)

    # Cumulative sum of log-decay rates within each chunk.
    # log_decay_cumsum[..., t] = sum_{s=0}^{t} g[s]
    log_decay_cumsum = g.cumsum(dim=-1)  # batch, heads, num_chunks, chunk_size

    # Intra-chunk decay matrix: entry [t, s] = exp(log_decay_cumsum[t] - log_decay_cumsum[s]) for t >= s.
    intra_chunk_decay = (log_decay_cumsum.unsqueeze(-1) - log_decay_cumsum.unsqueeze(-2)).tril().exp()
    # batch, heads, num_chunks, chunk_size, chunk_size

    # --- Intra-chunk delta rule transformation ---
    # Build the triangular transformation matrix that encodes how prior writes within a chunk
    # are corrected by later writes (the delta rule update).
    # Initial: T[t, s] = -(key_beta[t] · key[s]) * decay[t, s], strictly lower-triangular.
    upper_triangular_mask = torch.triu(
        torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0
    )
    intra_chunk_transform = -(key_beta @ key.transpose(-1, -2) * intra_chunk_decay).masked_fill(
        upper_triangular_mask, 0
    )
    # Iteratively apply the delta rule to build up the full transformation.
    for chunk_pos in range(1, chunk_size):
        row = intra_chunk_transform[..., chunk_pos, :chunk_pos].clone()
        above = intra_chunk_transform[..., :chunk_pos, :chunk_pos].clone()
        intra_chunk_transform[..., chunk_pos, :chunk_pos] = row + (row.unsqueeze(-1) * above).sum(-2)
    intra_chunk_transform = intra_chunk_transform + torch.eye(
        chunk_size, dtype=intra_chunk_transform.dtype, device=query.device
    )

    # Apply the transformation to get: corrected intra-chunk values and keys scaled by cumulative decay.
    intra_chunk_value = intra_chunk_transform @ value_beta
    # batch, heads, num_chunks, chunk_size, value_head_dim
    key_cumulative_decay = intra_chunk_transform @ (key_beta * log_decay_cumsum.exp().unsqueeze(-1))
    # batch, heads, num_chunks, chunk_size, key_head_dim

    # --- Recurrent loop over chunks ---
    if initial_state is None:
        recurrent_state = torch.zeros(
            batch_size, num_heads, key_head_dim, value_head_dim, device=query.device, dtype=query.dtype
        )
    else:
        recurrent_state = initial_state.to(query)
    output = torch.zeros_like(intra_chunk_value)
    causal_mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1)

    for chunk_index in range(num_chunks):
        q = query[:, :, chunk_index]  # batch, heads, chunk_size, key_head_dim
        k = key[:, :, chunk_index]  # batch, heads, chunk_size, key_head_dim
        v = intra_chunk_value[:, :, chunk_index]  # batch, heads, chunk_size, value_head_dim
        log_decay = log_decay_cumsum[:, :, chunk_index]  # batch, heads, chunk_size

        # Intra-chunk causal attention weighted by decay.
        intra_chunk_attn = (q @ k.transpose(-1, -2) * intra_chunk_decay[:, :, chunk_index]).masked_fill_(
            causal_mask, 0
        )  # batch, heads, chunk_size, chunk_size

        # Delta rule correction: subtract the recurrent state's contribution from the values.
        state_contribution = key_cumulative_decay[:, :, chunk_index] @ recurrent_state
        corrected_value = v - state_contribution  # batch, heads, chunk_size, value_head_dim

        # Combine cross-chunk output (from recurrent state) and intra-chunk output.
        cross_chunk_output = (q * log_decay.exp().unsqueeze(-1)) @ recurrent_state
        output[:, :, chunk_index] = cross_chunk_output + intra_chunk_attn @ corrected_value

        # Update recurrent state: decay existing state and add new writes from this chunk.
        last_log_decay = log_decay[:, :, -1]  # batch, heads
        recurrent_state = (
            recurrent_state * last_log_decay.exp()[..., None, None]
            + (k * (last_log_decay.unsqueeze(-1) - log_decay).exp().unsqueeze(-1)).transpose(-1, -2) @ corrected_value
        )

    if not output_final_state:
        recurrent_state = None

    # Restore original layout: batch, sequence, heads, value_head_dim
    output = output.reshape(batch_size, num_heads, padded_sequence_length, value_head_dim)
    output = output[:, :, :sequence_length]
    output = output.transpose(1, 2).contiguous().to(input_dtype)

    return output, recurrent_state


def torch_chunk_gated_delta_rule(
    query: torch.Tensor,  # batch, sequence, heads, key_head_dim
    key: torch.Tensor,  # batch, sequence, heads, key_head_dim
    value: torch.Tensor,  # batch, sequence, heads, value_head_dim
    g: torch.Tensor,  # batch, sequence, heads (log decay rates)
    beta: torch.Tensor,  # batch, sequence, heads (write gate strengths)
    chunk_size: int = 64,
    initial_state: torch.Tensor | None = None,  # batch, heads, key_head_dim, value_head_dim
    output_final_state: bool = False,
    use_qk_l2norm_in_kernel: bool = False,
    cu_seqlens: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None]:
    """
    Chunk-wise gated delta rule, with optional support for packed documents.

    Without `cu_seqlens`, this is a thin pass-through to the single-sequence
    kernel. With `cu_seqlens`, the inputs are assumed to hold batch=1 with
    documents packed along the sequence dimension; each document is processed
    independently and the outputs are concatenated back along that dimension.

    NOTE(review): on the `cu_seqlens` path, `initial_state` is not forwarded
    and the final state is always None regardless of `output_final_state` —
    presumably intentional for packed training batches; confirm with callers.
    """
    if cu_seqlens is None:
        return _torch_chunk_gated_delta_rule_single(
            query,
            key,
            value,
            g,
            beta,
            chunk_size=chunk_size,
            initial_state=initial_state,
            output_final_state=output_final_state,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )

    # Packed case: slice out each document, run it through the kernel, stitch
    # the per-document outputs back together along the sequence dimension.
    boundaries = cu_seqlens.tolist()
    document_outputs = []
    for index in range(len(boundaries) - 1):
        start, end = boundaries[index], boundaries[index + 1]
        document_output, _ = _torch_chunk_gated_delta_rule_single(
            query[:, start:end],
            key[:, start:end],
            value[:, start:end],
            g[:, start:end],
            beta[:, start:end],
            chunk_size=chunk_size,
            use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
        )
        document_outputs.append(document_output)
    return torch.cat(document_outputs, dim=1), None


class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]):
Expand Down
Loading
Loading