From 23bf9754f96dc524bfea143ae20ac642404b7590 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 31 Mar 2026 16:59:59 -0400 Subject: [PATCH 1/2] Add PyTorch backup implementations for GDN and KDA SSM layers - Refactor torch_chunk_gated_delta_rule (GDN) for readability: rename variables, add type hints and section comments, split into a compiled inner function (_torch_chunk_gated_delta_rule_single) and a public wrapper that handles cu_seqlens by iterating over document boundaries. - Add torch_chunk_kda / _torch_chunk_kda_single and torch_kda_gate as pure-PyTorch fallbacks for KimiDeltaAttention (KDA), following the same pattern. KDA's per-dim vector decay and post-correction beta application differ structurally from GDN. - GatedDeltaNet and KimiDeltaAttention now fall back gracefully to the backup when FLA kernels are unavailable instead of raising ImportError. - Parametrize test_gdn and test_kda with fast/backup variants; backup uses threshold=1e-2 to accommodate float32-vs-kernel precision gaps. 
Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/layers/ssm/gdn.py | 212 ++++++++++++++++++++++++++----------- fast_llm/layers/ssm/kda.py | 197 ++++++++++++++++++++++++++++++++-- tests/layers/test_ssm.py | 42 +++++++- 3 files changed, 378 insertions(+), 73 deletions(-) diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index cf5bc0bc4..ef6c6154f 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -38,85 +38,175 @@ def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: @torch.compile -def torch_chunk_gated_delta_rule( - query, - key, - value, - g, - beta, - chunk_size=64, - initial_state=None, - output_final_state=False, - use_qk_l2norm_in_kernel=False, - cu_seqlens=None, -): - initial_dtype = query.dtype +def _torch_chunk_gated_delta_rule_single( + query: torch.Tensor, # batch, sequence, heads, key_head_dim + key: torch.Tensor, # batch, sequence, heads, key_head_dim + value: torch.Tensor, # batch, sequence, heads, value_head_dim + g: torch.Tensor, # batch, sequence, heads (log decay rates) + beta: torch.Tensor, # batch, sequence, heads (write gate strengths) + chunk_size: int = 64, + initial_state: torch.Tensor | None = None, # batch, heads, key_head_dim, value_head_dim + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + input_dtype = query.dtype + if use_qk_l2norm_in_kernel: query = _l2norm(query, dim=-1, eps=1e-6) key = _l2norm(key, dim=-1, eps=1e-6) + + # Transpose to head-first layout and upcast for numerical stability. 
+ # batch, sequence, heads, dim -> batch, heads, sequence, dim query, key, value, beta, g = ( x.transpose(1, 2).contiguous().to(torch.float32) for x in (query, key, value, beta, g) ) - batch_size, num_heads, sequence_length, k_head_dim = key.shape - v_head_dim = value.shape[-1] + batch_size, num_heads, sequence_length, key_head_dim = key.shape + value_head_dim = value.shape[-1] + + # Pad sequence length to a multiple of chunk_size. pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size query = torch.nn.functional.pad(query, (0, 0, 0, pad_size)) key = torch.nn.functional.pad(key, (0, 0, 0, pad_size)) value = torch.nn.functional.pad(value, (0, 0, 0, pad_size)) beta = torch.nn.functional.pad(beta, (0, pad_size)) g = torch.nn.functional.pad(g, (0, pad_size)) - total_sequence_length = sequence_length + pad_size - scale = 1 / (query.shape[-1] ** 0.5) - query = query * scale - - v_beta = value * beta.unsqueeze(-1) - k_beta = key * beta.unsqueeze(-1) - # reshape to chunks - query, key, value, k_beta, v_beta = ( - x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) + padded_sequence_length = sequence_length + pad_size + num_chunks = padded_sequence_length // chunk_size + + query = query * (key_head_dim**-0.5) + + # Beta-weighted keys and values for the delta rule write operations. + key_beta = key * beta.unsqueeze(-1) # batch, heads, sequence, key_head_dim + value_beta = value * beta.unsqueeze(-1) # batch, heads, sequence, value_head_dim + + # Reshape into chunks: batch, heads, num_chunks, chunk_size, dim + query, key, value, key_beta, value_beta = ( + x.reshape(batch_size, num_heads, num_chunks, chunk_size, x.shape[-1]) + for x in (query, key, value, key_beta, value_beta) + ) + g = g.reshape(batch_size, num_heads, num_chunks, chunk_size) + + # Cumulative sum of log-decay rates within each chunk. 
+ # log_decay_cumsum[..., t] = sum_{s=0}^{t} g[s] + log_decay_cumsum = g.cumsum(dim=-1) # batch, heads, num_chunks, chunk_size + + # Intra-chunk decay matrix: entry [t, s] = exp(log_decay_cumsum[t] - log_decay_cumsum[s]) for t >= s. + intra_chunk_decay = (log_decay_cumsum.unsqueeze(-1) - log_decay_cumsum.unsqueeze(-2)).tril().exp() + # batch, heads, num_chunks, chunk_size, chunk_size + + # --- Intra-chunk delta rule transformation --- + # Build the triangular transformation matrix that encodes how prior writes within a chunk + # are corrected by later writes (the delta rule update). + # Initial: T[t, s] = -(key_beta[t] · key[s]) * decay[t, s], strictly lower-triangular. + upper_triangular_mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0 + ) + intra_chunk_transform = -(key_beta @ key.transpose(-1, -2) * intra_chunk_decay).masked_fill( + upper_triangular_mask, 0 ) - g = g.reshape(g.shape[0], g.shape[1], -1, chunk_size) - mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0) - - # chunk decay - g = g.cumsum(dim=-1) - decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() - attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) - for i in range(1, chunk_size): - row = attn[..., i, :i].clone() - sub = attn[..., :i, :i].clone() - attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - value = attn @ v_beta - k_cumdecay = attn @ (k_beta * g.exp().unsqueeze(-1)) - last_recurrent_state = ( - torch.zeros(batch_size, num_heads, k_head_dim, v_head_dim).to(value) - if initial_state is None - else initial_state.to(value) + # Iteratively apply the delta rule to build up the full transformation.
+ for chunk_pos in range(1, chunk_size): + row = intra_chunk_transform[..., chunk_pos, :chunk_pos].clone() + above = intra_chunk_transform[..., :chunk_pos, :chunk_pos].clone() + intra_chunk_transform[..., chunk_pos, :chunk_pos] = row + (row.unsqueeze(-1) * above).sum(-2) + intra_chunk_transform = intra_chunk_transform + torch.eye( + chunk_size, dtype=intra_chunk_transform.dtype, device=query.device ) - core_attn_out = torch.zeros_like(value) - mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) - - # for each chunk - for i in range(0, total_sequence_length // chunk_size): - q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] - attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - v_new = v_i - v_prime - attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - core_attn_out[:, :, i] = attn_inter + attn @ v_new - last_recurrent_state = ( - last_recurrent_state * g[:, :, i, -1, None, None].exp() - + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new + + # Apply the transformation to get: corrected intra-chunk values and keys scaled by cumulative decay. 
+ intra_chunk_value = intra_chunk_transform @ value_beta + # batch, heads, num_chunks, chunk_size, value_head_dim + key_cumulative_decay = intra_chunk_transform @ (key_beta * log_decay_cumsum.exp().unsqueeze(-1)) + # batch, heads, num_chunks, chunk_size, key_head_dim + + # --- Recurrent loop over chunks --- + if initial_state is None: + recurrent_state = torch.zeros( + batch_size, num_heads, key_head_dim, value_head_dim, device=query.device, dtype=query.dtype + ) + else: + recurrent_state = initial_state.to(query) + output = torch.zeros_like(intra_chunk_value) + causal_mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1) + + for chunk_index in range(num_chunks): + q = query[:, :, chunk_index] # batch, heads, chunk_size, key_head_dim + k = key[:, :, chunk_index] # batch, heads, chunk_size, key_head_dim + v = intra_chunk_value[:, :, chunk_index] # batch, heads, chunk_size, value_head_dim + log_decay = log_decay_cumsum[:, :, chunk_index] # batch, heads, chunk_size + + # Intra-chunk causal attention weighted by decay. + intra_chunk_attn = (q @ k.transpose(-1, -2) * intra_chunk_decay[:, :, chunk_index]).masked_fill_( + causal_mask, 0 + ) # batch, heads, chunk_size, chunk_size + + # Delta rule correction: subtract the recurrent state's contribution from the values. + state_contribution = key_cumulative_decay[:, :, chunk_index] @ recurrent_state + corrected_value = v - state_contribution # batch, heads, chunk_size, value_head_dim + + # Combine cross-chunk output (from recurrent state) and intra-chunk output. + cross_chunk_output = (q * log_decay.exp().unsqueeze(-1)) @ recurrent_state + output[:, :, chunk_index] = cross_chunk_output + intra_chunk_attn @ corrected_value + + # Update recurrent state: decay existing state and add new writes from this chunk. 
+ last_log_decay = log_decay[:, :, -1] # batch, heads + recurrent_state = ( + recurrent_state * last_log_decay.exp()[..., None, None] + + (k * (last_log_decay.unsqueeze(-1) - log_decay).exp().unsqueeze(-1)).transpose(-1, -2) @ corrected_value ) if not output_final_state: - last_recurrent_state = None - core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1]) - core_attn_out = core_attn_out[:, :, :sequence_length] - core_attn_out = core_attn_out.transpose(1, 2).contiguous().to(initial_dtype) - return core_attn_out, last_recurrent_state + recurrent_state = None + + # Restore original layout: batch, sequence, heads, value_head_dim + output = output.reshape(batch_size, num_heads, padded_sequence_length, value_head_dim) + output = output[:, :, :sequence_length] + output = output.transpose(1, 2).contiguous().to(input_dtype) + + return output, recurrent_state + + +def torch_chunk_gated_delta_rule( + query: torch.Tensor, # batch, sequence, heads, key_head_dim + key: torch.Tensor, # batch, sequence, heads, key_head_dim + value: torch.Tensor, # batch, sequence, heads, value_head_dim + g: torch.Tensor, # batch, sequence, heads (log decay rates) + beta: torch.Tensor, # batch, sequence, heads (write gate strengths) + chunk_size: int = 64, + initial_state: torch.Tensor | None = None, # batch, heads, key_head_dim, value_head_dim + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None]: + if cu_seqlens is None: + return _torch_chunk_gated_delta_rule_single( + query, + key, + value, + g, + beta, + chunk_size=chunk_size, + initial_state=initial_state, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + # Process each document independently and concatenate results. + # Inputs have batch=1 with documents packed along the sequence dimension. 
+ sequence_boundaries = cu_seqlens.tolist() + outputs = [] + for seq_start, seq_end in zip(sequence_boundaries, sequence_boundaries[1:]): + out, _ = _torch_chunk_gated_delta_rule_single( + query[:, seq_start:seq_end], + key[:, seq_start:seq_end], + value[:, seq_start:seq_end], + g[:, seq_start:seq_end], + beta[:, seq_start:seq_end], + chunk_size=chunk_size, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + outputs.append(out) + return torch.cat(outputs, dim=1), None class GatedDeltaNet[ConfigType: GatedDeltaNetConfig](BlockWithBias[ConfigType]): diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 1c313102f..c59bfe036 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -2,6 +2,7 @@ import typing import torch +import torch.nn.functional from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ @@ -25,6 +26,182 @@ _kda_available = False +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + + +def torch_kda_gate( + g: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """Pure PyTorch backup for fused_kda_gate.""" + num_heads, head_dim = g.shape[-2:] + g = g.float() + if dt_bias is not None: + g = g + dt_bias.view(num_heads, head_dim) + return (-A_log.view(num_heads, 1).float().exp() * torch.nn.functional.softplus(g)).to(output_dtype) + + +@torch.compile +def _torch_chunk_kda_single( + q: torch.Tensor, # batch, sequence, heads, head_dim + k: torch.Tensor, # batch, sequence, heads, head_dim + v: torch.Tensor, # batch, sequence, heads, head_dim + g: torch.Tensor, # batch, sequence, heads, head_dim (log decay rates per dim) + beta: torch.Tensor, # batch, sequence, heads (write gate strengths) + chunk_size: int = 64, + 
initial_state: torch.Tensor | None = None, # batch, heads, head_dim, head_dim + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor | None]: + input_dtype = q.dtype + + if use_qk_l2norm_in_kernel: + q = _l2norm(q, dim=-1, eps=1e-6) + k = _l2norm(k, dim=-1, eps=1e-6) + + # Transpose to head-first layout and upcast for numerical stability. + # batch, sequence, heads, dim -> batch, heads, sequence, dim + q, k, v, g = (x.transpose(1, 2).contiguous().to(torch.float32) for x in (q, k, v, g)) + beta = beta.transpose(1, 2).contiguous().to(torch.float32) + + batch_size, num_heads, sequence_length, head_dim = q.shape + + # Pad sequence length to a multiple of chunk_size. + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + q = torch.nn.functional.pad(q, (0, 0, 0, pad_size)) + k = torch.nn.functional.pad(k, (0, 0, 0, pad_size)) + v = torch.nn.functional.pad(v, (0, 0, 0, pad_size)) + g = torch.nn.functional.pad(g, (0, 0, 0, pad_size)) + beta = torch.nn.functional.pad(beta, (0, pad_size)) + padded_sequence_length = sequence_length + pad_size + num_chunks = padded_sequence_length // chunk_size + + q = q * (head_dim**-0.5) + + # Reshape to chunks: (batch, heads, num_chunks, chunk_size, head_dim) + q, k, v, g = (x.reshape(batch_size, num_heads, num_chunks, chunk_size, head_dim) for x in (q, k, v, g)) + # beta: (batch, heads, num_chunks, chunk_size) + beta = beta.reshape(batch_size, num_heads, num_chunks, chunk_size) + + # Cumulative sum of log-decays within each chunk (over the position dimension). + g = g.cumsum(dim=-2) # batch, heads, num_chunks, chunk_size, head_dim + + # Build the per-chunk intra-sequence delta-rule transform matrix A. + # decay_matrix[..., c, i, d] = exp(g[c, d] - g[i, d]) — decay from position i to position c.
+ # g.unsqueeze(-2): (batch, heads, num_chunks, chunk_size, 1, head_dim) — "c" positions + # g.unsqueeze(-3): (batch, heads, num_chunks, 1, chunk_size, head_dim) — "i" positions + decay_matrix = (g.unsqueeze(-2) - g.unsqueeze(-3)).exp() + # intra_chunk_A[..., c, i] = sum_d(k[c, d] * k[i, d] * decay_matrix[c, i, d]) + intra_chunk_A = (k.unsqueeze(-2) * k.unsqueeze(-3) * decay_matrix).sum(-1) + # Multiply each row c by beta[c] (write gate applied before delta-rule correction). + intra_chunk_A = intra_chunk_A * beta.unsqueeze(-1) + # Mask upper triangular (including diagonal) and flip sign for delta-rule update. + upper_triangular_mask = torch.triu( + torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=0 + ) + intra_chunk_A = -intra_chunk_A.masked_fill(upper_triangular_mask, 0) + # Iterative delta-rule refinement. + for chunk_pos in range(1, chunk_size): + row = intra_chunk_A[..., chunk_pos, :chunk_pos].clone() + above = intra_chunk_A[..., :chunk_pos, :chunk_pos].clone() + intra_chunk_A[..., chunk_pos, :chunk_pos] = row + (row.unsqueeze(-1) * above).sum(-2) + # Add identity and multiply each column i by beta[i] (write gate applied after correction). + intra_chunk_A = ( + intra_chunk_A + torch.eye(chunk_size, dtype=intra_chunk_A.dtype, device=q.device) + ) * beta.unsqueeze(-2) + + # Precompute per-chunk write keys and corrected values for the recurrent state update. + # intra_chunk_w[..., c, d] = sum_i(A[c, i] * exp(g[i, d]) * k[i, d]) + intra_chunk_w = intra_chunk_A @ (g.exp() * k) # batch, heads, num_chunks, chunk_size, head_dim + # intra_chunk_u[..., c, d] = sum_i(A[c, i] * v[i, d]) + intra_chunk_u = intra_chunk_A @ v # batch, heads, num_chunks, chunk_size, head_dim + + # Precompute intra-chunk causal attention scores.
+ # intra_chunk_attn[..., c, j] = sum_d(q[c, d] * k[j, d] * decay_matrix[c, j, d]) + intra_chunk_attn = (q.unsqueeze(-2) * k.unsqueeze(-3) * decay_matrix).sum(-1) + causal_mask = torch.triu(torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=q.device), diagonal=1) + intra_chunk_attn = intra_chunk_attn.masked_fill(causal_mask, 0) + + if initial_state is None: + recurrent_state = torch.zeros(batch_size, num_heads, head_dim, head_dim, device=q.device, dtype=q.dtype) + else: + recurrent_state = initial_state.to(q) + + output = torch.zeros_like(intra_chunk_u) + for chunk_index in range(num_chunks): + q_chunk = q[:, :, chunk_index] # batch, heads, chunk_size, head_dim + k_chunk = k[:, :, chunk_index] + u_chunk = intra_chunk_u[:, :, chunk_index] + w_chunk = intra_chunk_w[:, :, chunk_index] + g_chunk = g[:, :, chunk_index] + attn_chunk = intra_chunk_attn[:, :, chunk_index] + + # Remove the state's contribution from the intra-chunk corrected values. + value_corrected = u_chunk - w_chunk @ recurrent_state # batch, heads, chunk_size, head_dim + # Cross-chunk contribution: queries attend to the recurrent state via the cumulative decay. + cross_chunk_output = (q_chunk * g_chunk.exp()) @ recurrent_state + output[:, :, chunk_index] = cross_chunk_output + attn_chunk @ value_corrected + + # Decay the state by the cumulative log-decay at the last position in the chunk. + last_g = g_chunk[:, :, -1] # batch, heads, head_dim + recurrent_state = recurrent_state * last_g.exp().unsqueeze(-1) # broadcast over value dim + # Write new key-value associations, weighted by the decay from each position to the chunk end. + inter_chunk_decay = (last_g.unsqueeze(-2) - g_chunk).exp() # batch, heads, chunk_size, head_dim + recurrent_state = recurrent_state + (inter_chunk_decay * k_chunk).transpose(-1, -2) @ value_corrected + + if not output_final_state: + recurrent_state = None + + # Remove padding and restore sequence-first layout. 
+ output = output.reshape(batch_size, num_heads, padded_sequence_length, head_dim) + output = output[:, :, :sequence_length] + output = output.transpose(1, 2).contiguous().to(input_dtype) + return output, recurrent_state + + +def torch_chunk_kda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + chunk_size: int = 64, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + use_qk_l2norm_in_kernel: bool = False, + cu_seqlens: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None]: + if cu_seqlens is None: + return _torch_chunk_kda_single( + q, + k, + v, + g, + beta, + chunk_size=chunk_size, + initial_state=initial_state, + output_final_state=output_final_state, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + sequence_boundaries = cu_seqlens.tolist() + outputs = [] + for seq_start, seq_end in zip(sequence_boundaries, sequence_boundaries[1:]): + out, _ = _torch_chunk_kda_single( + q[:, seq_start:seq_end], + k[:, seq_start:seq_end], + v[:, seq_start:seq_end], + g[:, seq_start:seq_end], + beta[:, seq_start:seq_end], + chunk_size=chunk_size, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + outputs.append(out) + return torch.cat(outputs, dim=1), None + + class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): """ Implementation of the Kimi Delta Attention mixer. @@ -46,11 +223,6 @@ def __init__( super().__init__( config, distributed_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, return_bias=return_bias ) - if not _kda_available: - raise ImportError( - "KimiDeltaAttention requires the `fla-core` package. " - "Please install it with `pip install -U fla-core`." 
- ) self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) self._heads_dim = TensorDim( @@ -188,6 +360,17 @@ def __init__( peft=self._peft, ) + if _kda_available and distributed_config.use_cuda: + self._chunk_kda = chunk_kda + self._kda_gate = fused_kda_gate + else: + logger.warning( + "Fast implementation for KimiDeltaAttention is not available. " + "Please ensure that 'fla-core' is properly installed." + ) + self._chunk_kda = torch_chunk_kda + self._kda_gate = torch_kda_gate + def _forward( self, input_: torch.Tensor, @@ -226,9 +409,9 @@ def _forward( g_kernel = ( self.f_b_proj(self.f_a_proj(input_)).unsqueeze(0).unflatten(-1, (self._local_heads, self._config.head_dim)) ) - g_kernel = fused_kda_gate(g_kernel, self.A_log.float(), dt_bias=self.dt_bias) + g_kernel = self._kda_gate(g_kernel, self.A_log.float(), dt_bias=self.dt_bias) - out, _ = chunk_kda( + out, _ = self._chunk_kda( q=q, k=k, v=v, diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index 7613fa67f..9c31ec80f 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -10,6 +10,7 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.layers.decoder.config import MixerConfig from fast_llm.layers.ssm.config import GatedDeltaNetConfig, KimiDeltaAttentionConfig, MambaConfig +from fast_llm.layers.ssm.gdn import _fast_gdn_available from fast_llm.layers.ssm.kda import _kda_available from fast_llm.utils import Assert from tests.utils.utils import get_stage @@ -24,6 +25,7 @@ except ImportError: Apriel2GatedDeltaNet = None Apriel2Mamba = None + KimiDeltaAttention = None is_fast_path_available = False HIDDEN_SIZE = 16 @@ -98,7 +100,20 @@ def _compare_mixers( @pytest.mark.slow # Arguments ('seq_idx',) not implemented for torch implementation of 1d convolution. 
@pytest.mark.skipif(not is_fast_path_available, reason="GDN deps missing") -def test_gdn(testing_device): +@pytest.mark.parametrize( + "use_backup", + [ + pytest.param(False, marks=pytest.mark.skipif(not _fast_gdn_available, reason="FLA not available")), + True, + ], + ids=["fast", "backup"], +) +def test_gdn(testing_device, use_backup, monkeypatch): + if use_backup: + import fast_llm.layers.ssm.gdn as gdn_module + + monkeypatch.setattr(gdn_module, "_fast_gdn_available", False) + dtype = torch.bfloat16 NUM_V_HEADS = 4 @@ -120,12 +135,27 @@ def test_gdn(testing_device): .eval() ) fast_llm_config = GatedDeltaNetConfig.from_dict(config_common, {"normalization": {"epsilon": 1e-5}}) - _compare_mixers(fast_llm_config, hf_layer, {}) + # The backup uses float32 arithmetic while the reference uses the FLA kernel, so + # bfloat16-level numerical differences are expected; use a looser threshold. + _compare_mixers(fast_llm_config, hf_layer, {}, threshold=1e-2 if use_backup else 1e-5) @pytest.mark.slow -@pytest.mark.skipif(not _kda_available, reason="KDA fused kernels not available") -def test_kda(): +@pytest.mark.skipif(KimiDeltaAttention is None, reason="KDA external model not available") +@pytest.mark.parametrize( + "use_backup", + [ + pytest.param(False, marks=pytest.mark.skipif(not _kda_available, reason="KDA fused kernels not available")), + True, + ], + ids=["fast", "backup"], +) +def test_kda(testing_device, use_backup, monkeypatch): + if use_backup: + import fast_llm.layers.ssm.kda as kda_module + + monkeypatch.setattr(kda_module, "_kda_available", False) + NUM_HEADS = 4 HEAD_DIM = 4 KERNEL_SIZE = 4 @@ -141,7 +171,9 @@ def test_kda(): fast_llm_config = KimiDeltaAttentionConfig.from_dict(kda_config, {}) - _compare_mixers(fast_llm_config, hf_layer, {}) + # The backup uses float32 arithmetic while the reference uses FLA kernels, so + # bfloat16-level numerical differences are expected; use a looser threshold. 
+ _compare_mixers(fast_llm_config, hf_layer, {}, threshold=1e-2 if use_backup else 1e-5) @pytest.mark.slow From 75b6fd371f7a8edcb833596b59b14b93b57c931d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 31 Mar 2026 17:01:40 -0400 Subject: [PATCH 2/2] Upgrade to transformers 5.x - Bump transformers requirement to >=5.4.0,<6.0.0. - Adapt HuggingfaceModelConfig to transformers 5.x API: use __post_init__ instead of __init__, dtype instead of torch_dtype, and move fast_llm_config type narrowing to TYPE_CHECKING guards. - Import no_init_weights from transformers.initialization (moved in 5.x). - Support transformers 5.x rope_parameters format in LlamaAttentionConverter import_config/export_config, keeping backward compatibility with 4.x format. - Update external models and tests accordingly. Co-Authored-By: Claude Sonnet 4.6 --- fast_llm/engine/inference/config.py | 22 +-- fast_llm/engine/inference/huggingface.py | 5 +- fast_llm/models/gpt/conversion/llama.py | 142 +++++++++++------- fast_llm/models/gpt/huggingface.py | 4 +- fast_llm/models/multimodal/huggingface.py | 4 +- .../apriel2/modeling_apriel2.py | 3 +- .../mtp_llama/modeling_mtp_llama.py | 20 ++- setup.cfg | 2 +- tests/models/test_checkpoint.py | 6 +- 9 files changed, 129 insertions(+), 79 deletions(-) diff --git a/fast_llm/engine/inference/config.py b/fast_llm/engine/inference/config.py index d19e2478d..7ab8667be 100644 --- a/fast_llm/engine/inference/config.py +++ b/fast_llm/engine/inference/config.py @@ -16,16 +16,16 @@ class HuggingfaceModelConfig(transformers.PretrainedConfig): model_type = "fast_llm" model_config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig + fast_llm_config: FastLLMModelConfig | None = None + use_cache: bool = True - def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs): + def __post_init__(self, **kwargs): # Needed for `to_diff_dict` (`__repr__`) - if fast_llm_config is None: - fast_llm_config = self.model_config_class() - 
self.fast_llm_config = fast_llm_config - self.use_cache = kwargs.pop("use_cache", True) - super().__init__(**kwargs) - if self.torch_dtype is not None: - assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch + if self.fast_llm_config is None: + self.fast_llm_config = self.model_config_class() + super().__post_init__(**kwargs) + if self.dtype is not None: + assert self.dtype == self.fast_llm_config.distributed.compute_dtype.torch def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs) -> None: # Hack the method to save at the right place. @@ -88,9 +88,9 @@ def _get_config_dict( ) metadata = cls.model_config_class.load_metadata(pretrained) updates = {} - torch_dtype = kwargs.pop("torch_dtype", None) - if torch_dtype is not None: - updates[("distributed", "compute_dtype")] = torch_dtype + dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None)) + if dtype is not None: + updates[("distributed", "compute_dtype")] = dtype fast_llm_config = cls.model_config_class.from_metadata( pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates ) diff --git a/fast_llm/engine/inference/huggingface.py b/fast_llm/engine/inference/huggingface.py index 5a07bd51b..27b33933b 100644 --- a/fast_llm/engine/inference/huggingface.py +++ b/fast_llm/engine/inference/huggingface.py @@ -7,6 +7,7 @@ import transformers.generation.utils import transformers.modeling_outputs import transformers.utils.generic +from transformers.initialization import no_init_weights as transformers_no_init_weights from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat @@ -38,7 +39,7 @@ def __init__( **kwargs, ): if config is None: - config = self.config_class(fast_llm_model.config) + config = self.config_class(fast_llm_config=fast_llm_model.config) assert self.runner_class.model_class.config_class is 
config.model_config_class assert config.fast_llm_config is fast_llm_model.config @@ -70,7 +71,7 @@ def __init__( # Transformers needs to be able to inspect the base model. self.fast_llm_base_model = fast_llm_model.base_model - with transformers.modeling_utils.no_init_weights(): + with transformers_no_init_weights(): self.post_init() if fast_llm_model.config.multi_stage.zero_stage == 3: diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 38dc38586..6737d172e 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ -188,36 +188,68 @@ def import_weight( class LlamaAttentionConverter: @classmethod def import_config(cls, config: dict) -> dict: - try: - rope_type = config["rope_scaling"]["rope_type"] - except (KeyError, TypeError): - rope_type = "default" - rotary_config = { - "type": rope_type, - "theta": config["rope_theta"], - } - if rope_type == "default": - pass - elif rope_type == "llama3": - rotary_config.update( - { - "scale_factor": config["rope_scaling"]["factor"], - "low_frequency_factor": config["rope_scaling"]["low_freq_factor"], - "high_frequency_factor": config["rope_scaling"]["high_freq_factor"], - "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], - } - ) - elif rope_type == "yarn": - rotary_config.update( - { - "attention_factor": config["rope_scaling"]["attention_factor"], - "beta_fast": config["rope_scaling"]["beta_fast"], - "beta_slow": config["rope_scaling"]["beta_slow"], - "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], - } - ) + # transformers 5.x consolidates rope_theta + rope_scaling into rope_parameters + if "rope_parameters" in config: + rope_params = config["rope_parameters"] + rope_type = rope_params.get("rope_type", "default") + rotary_config = { + "type": rope_type, + "theta": rope_params["rope_theta"], + } + if rope_type == "default": + pass + elif rope_type == "llama3": + 
rotary_config.update( + { + "scale_factor": rope_params["factor"], + "low_frequency_factor": rope_params["low_freq_factor"], + "high_frequency_factor": rope_params["high_freq_factor"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": rope_params["attention_factor"], + "beta_fast": rope_params["beta_fast"], + "beta_slow": rope_params["beta_slow"], + "original_context_length": rope_params["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") else: - raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") + # transformers 4.x format: rope_theta at top level, rope_scaling separate + try: + rope_type = config["rope_scaling"]["rope_type"] + except (KeyError, TypeError): + rope_type = "default" + rotary_config = { + "type": rope_type, + "theta": config["rope_theta"], + } + if rope_type == "default": + pass + elif rope_type == "llama3": + rotary_config.update( + { + "scale_factor": config["rope_scaling"]["factor"], + "low_frequency_factor": config["rope_scaling"]["low_freq_factor"], + "high_frequency_factor": config["rope_scaling"]["high_freq_factor"], + "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], + } + ) + elif rope_type == "yarn": + rotary_config.update( + { + "attention_factor": config["rope_scaling"]["attention_factor"], + "beta_fast": config["rope_scaling"]["beta_fast"], + "beta_slow": config["rope_scaling"]["beta_slow"], + "original_context_length": config["rope_scaling"]["original_max_position_embeddings"], + } + ) + else: + raise NotImplementedError(f"Unsupported rotary type: {rope_type}") out = { "rotary": rotary_config, "heads": config["num_attention_heads"], @@ -235,36 +267,40 @@ def import_config(cls, config: dict) -> dict: def export_config(cls, config: AttentionConfig) -> dict: cls._check_config(config) 
Assert.eq(config.softmax_scale_power, 0.5) - out = { - "num_attention_heads": config.heads, - "num_key_value_heads": config.head_groups, - "head_dim": config.head_size, - "attention_bias": config.add_linear_biases, - "attention_dropout": config.dropout, - "rope_theta": config.rotary.theta, - } + rope_parameters = {"rope_theta": config.rotary.theta} if type(config.rotary) is DefaultRotaryConfig: - pass + rope_parameters["rope_type"] = "default" elif type(config.rotary) is Llama3RotaryConfig: - out["rope_scaling"] = { - "rope_type": "llama3", - "factor": config.rotary.scale_factor, - "low_freq_factor": config.rotary.low_frequency_factor, - "high_freq_factor": config.rotary.high_frequency_factor, - "original_max_position_embeddings": config.rotary.original_context_length, - } + rope_parameters.update( + { + "rope_type": "llama3", + "factor": config.rotary.scale_factor, + "low_freq_factor": config.rotary.low_frequency_factor, + "high_freq_factor": config.rotary.high_frequency_factor, + "original_max_position_embeddings": config.rotary.original_context_length, + } + ) elif type(config.rotary) is YarnRotaryConfig: - out["rope_scaling"] = { - "rope_type": "yarn", - "attention_factor": config.rotary.attention_factor, - "beta_fast": config.rotary.beta_fast, - "beta_slow": config.rotary.beta_slow, - "original_max_position_embeddings": config.rotary.original_context_length, - } + rope_parameters.update( + { + "rope_type": "yarn", + "attention_factor": config.rotary.attention_factor, + "beta_fast": config.rotary.beta_fast, + "beta_slow": config.rotary.beta_slow, + "original_max_position_embeddings": config.rotary.original_context_length, + } + ) else: raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}") - return out + return { + "num_attention_heads": config.heads, + "num_key_value_heads": config.head_groups, + "head_dim": config.head_size, + "attention_bias": config.add_linear_biases, + "attention_dropout": config.dropout, + "rope_parameters": 
rope_parameters, + } @classmethod def _check_config(cls, config: AttentionConfig) -> None: diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index 55c30c7ee..1fcb3fc25 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -20,7 +20,9 @@ class HuggingfaceGPTModelConfig(HuggingfaceModelConfig): model_type = "fast_llm_gpt" model_config_class = GPTModelConfig - fast_llm_config: GPTModelConfig + + if typing.TYPE_CHECKING: + fast_llm_config: GPTModelConfig class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel): diff --git a/fast_llm/models/multimodal/huggingface.py b/fast_llm/models/multimodal/huggingface.py index 8bf14d715..93770b446 100644 --- a/fast_llm/models/multimodal/huggingface.py +++ b/fast_llm/models/multimodal/huggingface.py @@ -22,7 +22,9 @@ class HuggingfaceMultiModalModelConfig(HuggingfaceGPTModelConfig): model_type = "fast_llm_multi_modal" model_config_class = MultiModalModelConfig - fast_llm_config: MultiModalModelConfig + + if typing.TYPE_CHECKING: + fast_llm_config: MultiModalModelConfig class HuggingfaceMultiModalModelForCausalLM(HuggingfaceGPTModelForCausalLM): diff --git a/fast_llm_external_models/apriel2/modeling_apriel2.py b/fast_llm_external_models/apriel2/modeling_apriel2.py index 9e82dfc4f..2e5933e79 100644 --- a/fast_llm_external_models/apriel2/modeling_apriel2.py +++ b/fast_llm_external_models/apriel2/modeling_apriel2.py @@ -289,7 +289,7 @@ def get_mask_sizes(self, cache_position, layer_idx): For SSM/linear layers: kv_offset = 0, kv_length = query_length (no KV cache to attend to) """ - query_length = cache_position.shape[0] + query_length = cache_position if isinstance(cache_position, int) else cache_position.shape[0] layer = self.layers[layer_idx] # Handle stochastic layers by getting the active mixer's cache @@ -794,6 +794,7 @@ def setup( hidden_size=hidden_size, num_attention_heads=num_heads, partial_rotary_factor=1.0, + rope_parameters={"rope_theta": 
rope_theta, "rope_type": "default"}, ) return nn.ModuleDict({"rotary_emb": MistralRotaryEmbedding(config=rotary_config)}) diff --git a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py index d0e1988f1..4c8909eaf 100644 --- a/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py +++ b/fast_llm_external_models/mtp_llama/modeling_mtp_llama.py @@ -56,21 +56,29 @@ def extra_repr(self): class LlamaRotaryEmbedding(nn.Module): def __init__(self, config: MTPLlamaConfig, device=None): super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and config.rope_scaling is not None: - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" + self.rope_type = config.rope_parameters.get("rope_type", "default") self.max_seq_len_cached = config.max_position_embeddings self.original_max_seq_len = config.max_position_embeddings self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] + if self.rope_type == "default": + self.rope_init_fn = self.compute_default_rope_parameters + else: + self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) self.register_buffer("inv_freq", inv_freq, persistent=False) self.original_inv_freq = self.inv_freq + @staticmethod + def compute_default_rope_parameters(config, device=None, seq_len=None): + base = config.rope_parameters["rope_theta"] + dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads + inv_freq = 1.0 / ( + base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim) + ) + return inv_freq, 1.0 + def _dynamic_frequency_update(self, position_ids, device): """ dynamic RoPE layers should recompute `inv_freq` in the following situations: diff --git a/setup.cfg b/setup.cfg index 955702907..aa1fb6ed8 100644 --- 
a/setup.cfg +++ b/setup.cfg @@ -42,7 +42,7 @@ OPTIONAL = # Huggingface tools HUGGINGFACE = - transformers>=4.57.3,<5.0.0 + transformers>=5.4.0,<6.0.0 hf-transfer>=0.1.9 datasets>=4.4.1 huggingface-hub>=0.36.0 diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 5f0f5a80f..3501ff1c2 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -391,13 +391,13 @@ def test_huggingface_model(model_testing_config, get_convert_path, testing_devic hidden_states = output.hidden_states + (output.logits,) # Llava models doesn't return vision hidden states, so we run the vision model directly instead. if model_testing_config.model_type == "multimodal": - if hasattr(model, "vision_tower"): - vision_output = model.vision_tower( + if hasattr(model.model, "vision_tower"): + vision_output = model.model.vision_tower( pixel_values=kwargs["pixel_values"], image_sizes=kwargs["image_sizes"], output_hidden_states=True, ) - adapter_output = model.multi_modal_projector(vision_output.hidden_states[-1]) + adapter_output = model.model.multi_modal_projector(vision_output.hidden_states[-1]) hidden_states = vision_output.hidden_states + (adapter_output,) + hidden_states hidden_states_ref_ = hidden_states_ref.copy() # Adjust the vision hidden states