forked from inclusionAI/dInfer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodeling_llada2_moe.py
More file actions
executable file
·2003 lines (1668 loc) · 88.2 KB
/
modeling_llada2_moe.py
File metadata and controls
executable file
·2003 lines (1668 loc) · 88.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2025 Antgroup and The HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch BailingMoE model."""
import math
import warnings
from typing import List, Optional, Tuple, Union, Literal
import torch
import torch.nn.functional as F
from torch import nn
import tqdm
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import MoeModelOutputWithPast
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from transformers.utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_flash_attn_2_available,
is_flash_attn_greater_or_equal_2_10,
logging,
replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
from .configuration_llada2_moe import LLaDA2MoeConfig
from torch.nn.modules.normalization import RMSNorm
import torch.distributed as dist
from ..decoding.utils import KVCache
from transformers.generation.utils import GenerationMixin
from dataclasses import dataclass
from transformers.utils import ModelOutput
from pathlib import Path
import json
from safetensors.torch import load_file
from functools import partial
from vllm.model_executor.layers.fused_moe import FusedMoE
import re
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
ReplicatedLinear,
QKVParallelLinear,
RowParallelLinear)
def torch_all_reduce(tensor):
    """Drop-in stand-in for vLLM's ``tensor_model_parallel_all_reduce``.

    Runs ``torch.distributed.all_reduce`` on ``tensor`` (default op and default
    process group); the reduction is written into ``tensor`` in place and that
    same object is handed back to the caller.
    """
    reduced = tensor
    torch.distributed.all_reduce(reduced)
    return reduced
import vllm.distributed as vllm_distributed
# Replace vLLM's tensor-parallel all-reduce entry point on the `vllm.distributed`
# module with `torch_all_reduce`, i.e. a plain `torch.distributed.all_reduce`.
# NOTE(review): only code that looks the name up on `vllm.distributed` at call time
# sees this patch; modules that already did `from vllm.distributed import
# tensor_model_parallel_all_reduce` (possibly including the FusedMoE module imported
# earlier in this file) presumably keep the original binding — confirm the patch
# actually reaches the intended call sites.
vllm_distributed.tensor_model_parallel_all_reduce = torch_all_reduce
from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
if is_flash_attn_2_available():
from flash_attn import flash_attn_func, flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
if not is_torch_greater_or_equal_than_1_13:
import torch.fx
_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
# Module-level logger from transformers' logging utilities.
logger = logging.get_logger(__name__)
# Name of the config class as a string; presumably consumed by docstring helpers
# elsewhere in this file (no use is visible in this section) — verify before removing.
_CONFIG_FOR_DOC = "LLaDA2MoeConfig"
def roll_tensor(tensor, shifts=-1, dims=-1, fill_value=0):
    """Shift ``tensor`` along one dimension, filling the vacated positions.

    ``torch.roll`` is circular: elements pushed off one end re-enter at the
    other. This helper turns it into a plain shift by overwriting every
    position that received a wrapped-around element with ``fill_value``.

    Args:
        tensor: input tensor (left unmodified; ``torch.roll`` returns a copy).
        shifts: signed number of places to shift along ``dims``. Negative
            values shift toward lower indices, so the default ``-1`` moves
            every element one step left and fills the last slot. A shift whose
            magnitude is at least the dimension size fills the whole dimension.
        dims: the single dimension to shift along.
        fill_value: value written into the vacated positions.

    Returns:
        ``(rolled_tensor, rolled_tensor.sum())``.
    """
    rolled_tensor = torch.roll(tensor, shifts=shifts, dims=dims)
    size = rolled_tensor.size(dims)
    # Number of positions that now hold wrapped-around elements.
    num_filled = min(abs(shifts), size)
    if num_filled > 0:
        # A left shift (shifts < 0) wraps elements into the tail; a right
        # shift wraps them into the head.
        start = size - num_filled if shifts < 0 else 0
        rolled_tensor.narrow(dims, start, num_filled).fill_(fill_value)
    return rolled_tensor, rolled_tensor.sum()
def replace_linear_class(
    linear: nn.Linear, style: Literal["colwise", "rowwise", "qkv"],
    quant_config, model_config
) -> Union[ColumnParallelLinear, RowParallelLinear]:
    """
    Build a vLLM tensor-parallel linear layer shaped like ``linear``.

    Only ``linear``'s ``in_features``, ``out_features`` and whether it has a
    bias are read; its weights are not copied here.

    Args:
        linear (nn.Linear): layer whose geometry the replacement copies.
        style (str): "colwise" -> ColumnParallelLinear, "rowwise" ->
            RowParallelLinear, "qkv" -> QKVParallelLinear; any other string
            falls back to ReplicatedLinear.
        quant_config (QuantConfig): quantization config for the new layer.
        model_config: read only for "qkv"; must provide ``hidden_size``,
            ``head_dim``, ``num_attention_heads`` and ``num_key_value_heads``.

    Returns:
        The new vLLM linear layer, built with ``return_bias=False``.

    Raises:
        ValueError: if ``style`` is not a string.
    """
    if not isinstance(style, str):
        raise ValueError(
            f"Unsupported parallel style type {type(style)}, expected str")

    has_bias = linear.bias is not None

    if style == "qkv":
        # Fused QKV projection: sized from the model's head configuration
        # rather than from `linear`'s feature counts.
        return QKVParallelLinear(
            hidden_size=model_config.hidden_size,
            head_size=model_config.head_dim,
            total_num_heads=model_config.num_attention_heads,
            total_num_kv_heads=model_config.num_key_value_heads,
            bias=has_bias,
            quant_config=quant_config,
            return_bias=False,
        )

    if style == "colwise":
        layer_cls = ColumnParallelLinear
    elif style == "rowwise":
        layer_cls = RowParallelLinear
    else:
        layer_cls = ReplicatedLinear
    return layer_cls(
        input_size=linear.in_features,
        output_size=linear.out_features,
        bias=has_bias,
        quant_config=quant_config,
        return_bias=False,
    )
def _all_gather_cat(
    tensor: torch.Tensor,
    dim: int = 1,
    group: Optional[dist.ProcessGroup] = None,
    normal_len: int = 0,
    last_len: int = 0,
) -> torch.Tensor:
    """
    Gather tensors along `dim` from all ranks and concatenate them in rank order.

    Rank ``i < world_size - 1`` is expected to contribute ``normal_len`` rows along
    `dim` and the last rank ``last_len`` rows, for a total of
    ``normal_len * (world_size - 1) + last_len``. If that total is smaller than the
    per-rank budget allows (e.g. ``last_len <= 0`` from a ceil-division split of a
    short sequence), rows are assigned front to back, so trailing ranks contribute
    fewer or zero rows.

    Args:
        tensor: local tensor on current rank
        dim: dimension along which to concatenate
        group: process group to gather over (``None`` = default group)
        normal_len: length of the first (world_size-1) ranks along `dim`
        last_len: length of the last rank along `dim`
    Returns:
        Concatenated tensor of shape [total_len, ...] along `dim`; with a world size
        of 1 the input tensor itself is returned.
    """
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    if world_size == 1:
        return tensor

    # Move the concatenation dimension to the front so chunks are row slices.
    tensor = tensor.movedim(dim, 0)  # [L_local, ...]
    local_len = tensor.size(0)

    total_len = normal_len * (world_size - 1) + last_len

    # Valid row count contributed by each rank. This must be computed identically on
    # every rank (it does not depend on `rank` or `local_len`).
    chunk_lens = []
    remaining = total_len
    for i in range(world_size):
        expected = normal_len if i < world_size - 1 else last_len
        n_valid = max(0, min(expected, remaining))
        chunk_lens.append(n_valid)
        remaining -= n_valid

    # all_gather needs equally shaped buffers, so every chunk is padded to the
    # largest expected length.
    max_len = max(normal_len, last_len)
    gather_list = [
        torch.empty([max_len] + list(tensor.shape[1:]),
                    dtype=tensor.dtype,
                    device=tensor.device)
        for _ in range(world_size)
    ]
    # Only the first `local_len` rows of this rank's buffer are meaningful.
    gather_list[rank][:local_len] = tensor
    dist.all_gather(gather_list, gather_list[rank], group=group)

    # Trim each chunk to its own valid length *before* concatenating: padding can
    # sit in the middle of the result whenever a non-final chunk is short, so a
    # single trim of the concatenation's tail is not enough.
    gathered = torch.cat(
        [buf[:n_valid] for buf, n_valid in zip(gather_list, chunk_lens)], dim=0
    )
    return gathered.movedim(0, dim)
class H2Embed:
    """Token embedding that can add a softmax-weighted ("continuous") embedding.

    Wraps an ``nn.Embedding``. At positions selected by ``mask_index`` it returns
    ``embedding(x) + iter_cont_weight * (softmax(logits / tau) @ W_e)``; every
    other position gets the plain ``embedding(x)``. With ``sp_size > 1`` each rank
    embeds only its slice of the sequence and the slices are all-gathered.
    """

    def __init__(self, embedding: nn.Embedding, tau: float = 1.0):
        """
        Args:
            embedding: token embedding; ``W_e = embedding.weight`` is [V, d].
            tau: softmax temperature; lower values yield sharper distributions.
        """
        self.embedding = embedding
        self.W_e = embedding.weight
        self.tau = tau
        self.sp_size = 1  # no sequence parallel by default

    def __call__(
        self,
        x: torch.Tensor,
        mask_index: Optional[torch.Tensor] = None,
        logits: Optional[torch.Tensor] = None,
        iter_cont_weight: float = 0.0
    ) -> torch.Tensor:
        """
        Args:
            x: [B, L] token ids
            mask_index: [B, L] bool tensor, True where the continuous embedding is added
            logits: logits used to produce continuous embeddings; assumed [B, L, V]
            iter_cont_weight: scale of the continuous term added to the discrete
                embedding (an additive term, not a convex blend)
        Returns:
            Embedded representations [B, L, d]
        """
        seq_len = x.shape[1]
        if self.sp_size > 1:
            # The rank is needed only to pick this rank's slice, so it is queried
            # here rather than unconditionally; the non-sequence-parallel path then
            # does not depend on vLLM's tensor-parallel group being initialised.
            # NOTE(review): the slice is chosen by tensor-parallel rank, while
            # `_all_gather_cat` orders chunks by `torch.distributed` rank in the
            # default group; these presumably coincide in the intended setup — confirm.
            rank = get_tensor_model_parallel_rank()
            normal_seq_len = (seq_len + self.sp_size - 1) // self.sp_size
            last_seq_len = seq_len - normal_seq_len * (self.sp_size - 1)
            part_start = normal_seq_len * rank
            part_end = min(normal_seq_len * (rank + 1), seq_len)
            x_part = x[:, part_start:part_end]
            if mask_index is not None:
                mask_part = mask_index[:, part_start:part_end]
                logits_part = logits[:, part_start:part_end] if logits is not None else None
            else:
                mask_part = None
                logits_part = None
        else:
            x_part = x
            mask_part = mask_index
            logits_part = logits

        # Base discrete embedding.
        result_part = self.embedding(x_part)

        # Add the continuous embedding at the selected positions.
        if mask_part is not None and logits_part is not None:
            prob = torch.softmax(logits_part / self.tau, dim=-1)  # [B, L_part, V]
            input_embeds_h = prob @ self.W_e  # [B, L_part, d]
            result_part = torch.where(
                mask_part.unsqueeze(-1),
                iter_cont_weight * input_embeds_h + result_part,
                result_part,
            )

        # Gather and concatenate the sequence slices across ranks.
        if self.sp_size > 1:
            return _all_gather_cat(
                result_part,
                dim=1,
                group=None,
                normal_len=normal_seq_len,
                last_len=last_seq_len,
            )
        return result_part
@dataclass
class MoEV2CausalLMOutputWithPast(ModelOutput):
    """
    Base class for causal language model (or autoregressive) outputs as well as Mixture of Expert's router hidden
    states terms, to train a MoE model.
    Args:
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
            Language modeling loss (for next-token prediction).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
            Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
            `past_key_values` input) to speed up sequential decoding.
        hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
            Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +
            one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.
            Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.
        attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
            sequence_length)`.
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.
        z_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            z_loss for the sparse modules.
        aux_loss (`torch.FloatTensor`, *optional*, returned when `labels` is provided):
            aux_loss for the sparse modules.
        router_logits (`tuple(torch.FloatTensor)`, *optional*, returned when `output_router_logits=True` is passed or when `config.add_router_probs=True`):
            Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, sequence_length, num_experts)`.
            Router logits of each MoE layer, useful to compute the auxiliary loss and the z_loss for the sparse
            modules.
        mtp_loss (`torch.FloatTensor`, *optional*):
            Extra loss term; presumably the multi-token-prediction (MTP) loss — confirm against the producing model.
        mtp_logits (`tuple(torch.FloatTensor)`, *optional*):
            Extra logits, one tensor per entry; presumably one per multi-token-prediction head — confirm.
    """
    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Cache] = None
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
    z_loss: Optional[torch.FloatTensor] = None
    aux_loss: Optional[torch.FloatTensor] = None
    # One tensor per layer, hence a variable-length tuple.
    router_logits: Optional[tuple[torch.FloatTensor, ...]] = None
    mtp_loss: Optional[torch.FloatTensor] = None
    mtp_logits: Optional[tuple[torch.FloatTensor, ...]] = None
class MoeV2ModelOutputWithPast(MoeModelOutputWithPast):
    """``MoeModelOutputWithPast`` that also carries ``mtp_hidden_states``.

    Every keyword argument other than ``mtp_hidden_states`` is forwarded
    unchanged to the parent constructor; ``mtp_hidden_states`` is attached
    afterwards as an attribute.
    """

    def __init__(self, mtp_hidden_states=None, **base_fields):
        super(MoeV2ModelOutputWithPast, self).__init__(**base_fields)
        # NOTE(review): set after the parent dataclass constructor, so it is not one of
        # the parent's declared fields; whether it shows up in dict-style access or
        # `to_tuple()` depends on transformers' ModelOutput internals — confirm if needed.
        setattr(self, "mtp_hidden_states", mtp_hidden_states)
def _get_unpad_data(attention_mask):
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = seqlens_in_batch.max().item()
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0))
return (
indices,
cu_seqlens,
max_seqlen_in_batch,
)
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """Deprecated alias of ``_prepare_4d_attention_mask``; warns, then delegates to it.

    Args:
        mask: attention mask forwarded as ``mask=``.
        dtype: dtype forwarded as ``dtype=``.
        tgt_len: optional target length forwarded as ``tgt_len=``.

    Returns:
        Whatever ``_prepare_4d_attention_mask`` returns for these arguments.
    """
    # The message names this function (it previously named the replacement instead)
    # and `stacklevel=2` attributes the warning to the caller.
    warnings.warn(
        "Calling `transformers.models.LLaDA2Moe.modeling_LLaDA2Moe._expand_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask`",
        stacklevel=2,
    )
    return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """Deprecated alias of ``AttentionMaskConverter._make_causal_mask``; warns, then delegates.

    Args:
        input_ids_shape: shape forwarded as ``input_ids_shape=``.
        dtype: dtype of the mask to build.
        device: device of the mask to build.
        past_key_values_length: number of cached positions, forwarded unchanged.

    Returns:
        Whatever ``AttentionMaskConverter._make_causal_mask`` returns for these arguments.
    """
    # The replacement lives in `transformers.modeling_attn_mask_utils` (where this
    # file imports `AttentionMaskConverter` from); the previous message pointed to a
    # non-existent module path. `stacklevel=2` attributes the warning to the caller.
    warnings.warn(
        "Calling `transformers.models.LLaDA2Moe.modeling_LLaDA2Moe._make_causal_mask` is deprecated and will be removed in v4.37. Use `transformers.modeling_attn_mask_utils.AttentionMaskConverter._make_causal_mask`",
        stacklevel=2,
    )
    return AttentionMaskConverter._make_causal_mask(
        input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
    )
class LLaDA2MoeRMSNorm(nn.Module):
    """Root-mean-square norm over the last dimension (equivalent to T5LayerNorm).

    No mean-centering and no bias: ``x / sqrt(mean(x**2) + eps)`` is computed in
    float32, cast back to the input dtype, then scaled by the learned ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        h32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(h32.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
        normed = (h32 * inv_rms).to(orig_dtype)
        return self.weight * normed
# Register the custom RMSNorm in transformers' `ALL_LAYERNORM_LAYERS` list.
# NOTE(review): presumably so transformers utilities that special-case norm layers
# (e.g. excluding their weights from weight decay) also recognise this class — confirm.
ALL_LAYERNORM_LAYERS.append(LLaDA2MoeRMSNorm)
class LLaDA2MoeRotaryEmbedding(nn.Module):
    """Produces the RoPE ``(cos, sin)`` tables for given position ids.

    Inverse frequencies come from ``ROPE_INIT_FUNCTIONS[rope_type]``. ``rope_type``
    is taken from ``config.rope_scaling`` (key "rope_type", falling back to the
    legacy "type" key) and is "default" when there is no scaling config.
    """

    def __init__(self, config: LLaDA2MoeConfig, device=None):
        super().__init__()
        scaling_cfg = getattr(config, "rope_scaling", None)
        if scaling_cfg is None:
            self.rope_type = "default"
        else:
            # BC: "rope_type" was originally "type"
            self.rope_type = scaling_cfg.get("rope_type", scaling_cfg.get("type"))
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic RoPE variants can restore the initial frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        batch = position_ids.shape[0]
        # [batch, n_freq, 1] column of inverse frequencies, [batch, 1, seq] row of positions.
        freq_col = self.inv_freq[None, :, None].float().expand(batch, -1, 1).to(x.device)
        pos_row = position_ids[:, None, :].float()
        dev_type = x.device.type
        autocast_device = dev_type if isinstance(dev_type, str) and dev_type != "mps" else "cpu"
        with torch.autocast(device_type=autocast_device, enabled=False):  # Force float32
            # Outer product -> [batch, n_freq, seq], then -> [batch, seq, n_freq].
            freqs = torch.matmul(freq_col.float(), pos_row.float()).transpose(1, 2)
            angles = torch.cat([freqs, freqs], dim=-1)
            cos = angles.cos() * self.attention_scaling
            sin = angles.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Split the last dim into halves ``(a, b)`` at ``size // 2`` and return ``(-b, a)``."""
    half = x.shape[-1] // 2
    front, back = x.split([half, x.shape[-1] - half], dim=-1)
    return torch.cat((back.neg(), front), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Apply rotary position embedding to the query and key tensors.

    Only the leading ``cos.shape[-1]`` channels of the last dimension are
    rotated; any remaining channels are passed through unchanged, which supports
    partial rotary embeddings.

    Args:
        q (`torch.Tensor`): the query tensor.
        k (`torch.Tensor`): the key tensor.
        cos (`torch.Tensor`): cosine part of the rotary embedding.
        sin (`torch.Tensor`): sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension at which ``cos``/``sin`` get a singleton axis so they broadcast
            against ``q``/``k``. With cos/sin of shape [batch, seq_len, dim],
            use 1 for q/k shaped [batch, heads, seq_len, head_dim] and 2 for
            q/k shaped [batch, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)`: the rotated ``(query, key)`` tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    rotary_dim = cos.shape[-1]

    def _rotate(t):
        rotated, untouched = t[..., :rotary_dim], t[..., rotary_dim:]
        return torch.cat([(rotated * cos) + (rotate_half(rotated) * sin), untouched], dim=-1)

    return _rotate(q), _rotate(k)
class LLaDA2MoeMLP(nn.Module):
    """Gated feed-forward block: ``down_proj(act(gate_proj(x)) * up_proj(x))``.

    The three projections are bias-free; ``act`` is ``ACT2FN[config.hidden_act]``.
    """

    def __init__(self, config: LLaDA2MoeConfig, intermediate_size: int):
        super().__init__()
        self.config = config
        self.hidden_size = hidden = config.hidden_size
        self.intermediate_size = inner = intermediate_size
        self.gate_proj = nn.Linear(hidden, inner, bias=False)
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))
class LLaDA2MoeGate(nn.Module):
    """Sigmoid router with group-limited top-k expert selection.

    Experts are split into ``n_group`` equal groups. A token's score for a group
    is the sum of its two largest (bias-adjusted) expert scores in that group; only
    experts in the ``topk_group`` best groups are eligible, and the ``top_k`` best
    of those are selected. ``expert_bias`` affects *which* experts are chosen but
    not their weights: weights are the unbiased sigmoid scores of the chosen
    experts, renormalised to sum to 1 when ``top_k > 1``, then multiplied by
    ``routed_scaling_factor``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.num_experts_per_tok
        self.num_experts = config.num_experts
        self.n_group = config.n_group
        self.topk_group = config.topk_group
        # topk selection algorithm
        self.gating_dim = config.hidden_size
        self.weight = nn.Parameter(torch.empty((self.num_experts, self.gating_dim)))
        self.routed_scaling_factor = config.routed_scaling_factor
        self.register_buffer("expert_bias", torch.zeros((self.num_experts)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialise the router weight with Kaiming-uniform (a=sqrt(5)), as nn.Linear does."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def group_limited_topk(
        self,
        scores: torch.Tensor,
    ):
        """Select ``top_k`` experts per token from the ``topk_group`` best groups.

        Args:
            scores: [num_tokens, num_experts] routing scores.

        Returns:
            ``(values, indices)``, each [num_tokens, top_k]; experts outside the
            chosen groups are masked to ``-inf`` before the final top-k.
        """
        num_tokens, _ = scores.size()
        # Group score = sum of the two largest expert scores inside the group.
        group_scores = scores.view(num_tokens, self.n_group, -1).topk(2, dim=-1)[0].sum(dim=-1)
        group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=False)[1]
        group_mask = torch.zeros_like(group_scores)
        group_mask.scatter_(1, group_idx, 1)
        # Broadcast each group's selection to all experts of that group.
        score_mask = (
            group_mask.unsqueeze(-1)
            .expand(num_tokens, self.n_group, self.num_experts // self.n_group)
            .reshape(num_tokens, -1)
        )
        masked_scores = scores.masked_fill(~score_mask.bool(), float('-inf'))
        probs, top_indices = torch.topk(masked_scores, k=self.top_k, dim=-1)
        return probs, top_indices

    def forward(self, hidden_states):
        """Route ``hidden_states`` (flattened to [num_tokens, hidden_size]).

        Returns:
            ``(topk_idx, topk_weight, logits)``: expert indices and weights, each
            [num_tokens, top_k], plus float32 router logits [num_tokens, num_experts].
        """
        logits = self.get_logits(hidden_states)
        # Reuse the exact selection/weighting used by the FusedMoE routing hook so
        # the two code paths cannot drift apart.
        topk_weight, topk_idx = self.routing(hidden_states, logits, self.top_k, True)
        return topk_idx, topk_weight, logits

    def get_logits(self, hidden_states):
        """Router logits in float32, shape [num_tokens, num_experts]."""
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        logits = F.linear(hidden_states.type(torch.float32), self.weight.type(torch.float32))
        return logits

    def routing(self, hidden_states, gating_output, topk, renormalize):
        """Turn router logits into ``(topk_weight, topk_idx)``.

        The signature matches the call made through ``static_routing_function``.
        ``hidden_states``, ``topk`` and ``renormalize`` are accepted but not used:
        ``self.top_k`` is always used, and weights are always renormalised when
        ``self.top_k > 1``.
        """
        scores = torch.sigmoid(gating_output.float()).type_as(gating_output)
        # The bias only steers selection; the returned weights use unbiased scores.
        scores_for_routing = scores + self.expert_bias
        _, topk_idx = self.group_limited_topk(scores_for_routing)
        scores = torch.gather(scores, dim=1, index=topk_idx).type_as(gating_output)
        topk_weight = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if self.top_k > 1 else scores
        topk_weight = topk_weight * self.routed_scaling_factor
        return topk_weight, topk_idx
def static_routing_function(gate, hidden_states, gating_output, topk, renormalize):
    """Module-level routing hook that forwards every argument to ``gate.routing``.

    Bound as ``functools.partial(static_routing_function, gate)`` so FusedMoE can
    call it with ``(hidden_states, gating_output, topk, renormalize)``.
    """
    route = gate.routing
    return route(hidden_states, gating_output, topk, renormalize)
class LLaDA2MoeSparseMoeBlock(nn.Module):
    """Tensor-parallel sparse MoE block for LLaDA2Moe.

    Routed experts run through vLLM's ``FusedMoE``: each expert's weights are
    sharded across all ranks, a fused MoE kernel is used for the forward pass,
    and the outputs are reduced across ranks (``reduce_results=True``). Expert
    selection is delegated to ``LLaDA2MoeGate`` via ``static_routing_function``.
    When ``config.num_shared_experts`` is not None, a dense shared-expert MLP is
    added to the routed output.
    """

    def __init__(self,
                 config,
                 prefix: str = ""):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts
        self.top_k = config.num_experts_per_tok
        # Stored for reference; not read by this block.
        self.norm_topk_prob = config.norm_topk_prob
        # Gate always runs at half / full precision for now.
        self.gate = LLaDA2MoeGate(config)
        if config.num_shared_experts is not None:
            self.shared_experts = LLaDA2MoeMLP(
                config=config, intermediate_size=config.moe_intermediate_size * config.num_shared_experts
            )
        self.experts = FusedMoE(num_experts=self.num_experts,
                                top_k=self.top_k,
                                hidden_size=self.hidden_size,
                                intermediate_size=config.moe_intermediate_size,
                                reduce_results=True,
                                quant_config=None,
                                tp_size=None,
                                custom_routing_function=partial(static_routing_function, self.gate),
                                prefix=f"{prefix}.experts")
        # This is a hack. expert_map in FusedMoE isn't moved to GPU by default.
        # We have to register it explicitly so that it can be moved to GPU with FusedMoE
        expert_map = self.experts.expert_map
        del self.experts.expert_map
        self.experts.register_buffer('expert_map', expert_map)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply routed (plus optional shared) experts.

        Args:
            hidden_states: [batch, seq_len, hidden_size].

        Returns:
            Tensor of the same shape.
        """
        # `shared_experts` exists only when `num_shared_experts` is configured;
        # calling it unconditionally raised AttributeError otherwise.
        has_shared = self.config.num_shared_experts is not None
        shared_out = self.shared_experts(hidden_states) if has_shared else None

        bsz, seq_len, h = hidden_states.shape
        hidden_states_flat = hidden_states.view(-1, h)
        router_logits = self.gate.get_logits(hidden_states_flat)
        y = self.experts.forward_impl(hidden_states=hidden_states_flat,
                                      router_logits=router_logits)
        y = y.view(bsz, seq_len, h)
        if shared_out is not None:
            y = y + shared_out
        return y
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every KV head ``n_rep`` times along dim 1, like
    ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    (batch, num_key_value_heads, seqlen, head_dim) ->
    (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With ``n_rep == 1`` the input tensor itself is returned.
    """
    b, kv_heads, s, d = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into heads.
    widened = hidden_states.unsqueeze(2).expand(b, kv_heads, n_rep, s, d)
    return widened.reshape(b, kv_heads * n_rep, s, d)
# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->LLaDA2Moe
class LLaDA2MoeAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Eager (matmul + softmax) implementation with a fused QKV projection
    (``query_key_value``), grouped-query attention (``num_key_value_heads`` KV heads
    repeated to ``num_heads``), per-head RMSNorm on queries and keys, and rotary
    embeddings applied via ``apply_rotary_pos_emb``. ``is_causal`` is False and no
    causal mask is built here: masking comes only from the additive
    ``attention_mask`` argument.
    """

    def __init__(self, config: LLaDA2MoeConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = config.head_dim or self.hidden_size // self.num_heads
        partial_rotary_factor = config.partial_rotary_factor if hasattr(config, "partial_rotary_factor") else 1.0
        # Not read in this class; the rotated width is taken from the cos/sin tables.
        self.rope_dim = int(self.head_dim * partial_rotary_factor)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = False
        self.tp_size = 1
        # Fused projection producing all query, key and value heads in one matmul.
        self.query_key_value = nn.Linear(
            self.hidden_size,
            (self.num_heads + 2 * self.num_key_value_heads) * self.head_dim,
            bias=config.use_qkv_bias,
        )
        # QK-norm layers are always created (the `config.use_qk_norm` switch is not consulted).
        self.query_layernorm = LLaDA2MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.key_layernorm = LLaDA2MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.dense = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.use_bias)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        """Reshape [bsz, seq_len, num_heads * head_dim] to contiguous [bsz, num_heads, seq_len, head_dim]."""
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute attention over ``hidden_states`` ([bsz, q_len, hidden_size]).

        Args:
            attention_mask: optional additive mask of shape (bsz, 1, q_len, kv_len),
                added to the attention logits before softmax.
            past_key_value: optional cache; updated in place via ``update``.
            output_attentions: when False the returned attention weights are None.
            position_embeddings: required ``(cos, sin)`` rotary tables.

        Returns:
            ``(attn_output, attn_weights_or_None, past_key_value)``.

        Raises:
            ValueError: if ``position_embeddings`` is missing, if a cache is given
                without ``layer_idx``, or on shape mismatches.
        """
        if position_embeddings is None:
            raise ValueError(
                f"{self.__class__.__name__}.forward requires `position_embeddings` as a `(cos, sin)` tuple."
            )
        bsz, q_len, _ = hidden_states.size()
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )
        # -> [bsz, heads, q_len, head_dim]
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        # QK-norm is applied unconditionally.
        query_states = self.query_layernorm(query_states)
        key_states = self.key_layernorm(key_states)
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
        if past_key_value is not None:
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
        # Expand the KV heads to match the number of query heads (GQA).
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        kv_seq_len = key_states.shape[-2]
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)
        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, -1)
        attn_output = self.dense(attn_output)
        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
# Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2 with Llama->LLaDA2Moe
class LLaDA2MoeFlashAttention2(LLaDA2MoeAttention):
    """
    LLaDA2Moe flash attention module. This module inherits from `LLaDA2MoeAttention` as the weights of the module stay
    untouched. The only required change is in the forward pass, which calls the public flash-attention API and deals
    with padding tokens in case the input contains any of them.

    `output_attentions` is not supported: the attention-weights slot of the returned tuple is always `None`.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # TODO: Should be removed once Flash Attention for ROCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right
        # alignment, which became the default for flash_attn>=2.1. This attribute is used to handle this difference.
        # Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1)
        # produces a wrong (top-left) mask.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention with flash-attn.

        Args:
            hidden_states: rank-3 input `(bsz, q_len, hidden)`.
            attention_mask: optional padding mask passed unchanged to `_flash_attention_forward`.
                When not None, the variable-length (unpadded) kernel is used.
            position_ids: not read by this method; rotary tables come from `position_embeddings`.
            past_key_value: optional cache. When given, the new keys/values are written via
                `past_key_value.update(..., self.layer_idx, ...)`. Attention then runs over what the cache returns.
            output_attentions: ignored; flash-attn never materializes attention weights.
            use_cache: not read by this method.
            position_embeddings: `(cos, sin)` rotary tables. Required despite the `None` default.

        Returns:
            `(attn_output, None, past_key_value)`, where `attn_output` is the `self.dense` projection of the
            attention result reshaped to `(bsz, q_len, -1)`.

        Raises:
            ValueError: if `past_key_value` is given but the module was constructed without a `layer_idx`.
        """
        bsz, q_len, _ = hidden_states.size()

        # Fused QKV projection. Along the head axis it is laid out as
        # [num_heads query heads | num_key_value_heads key heads | num_key_value_heads value heads].
        qkv = self.query_key_value(hidden_states)
        qkv = qkv.view(bsz, q_len, self.num_heads + 2 * self.num_key_value_heads, self.head_dim)
        query_states, key_states, value_states = qkv.split(
            [self.num_heads, self.num_key_value_heads, self.num_key_value_heads], dim=-2
        )

        # -> (bsz, heads, q_len, head_dim): the layout used by the QK norms, RoPE and the KV cache below.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # QK normalization is applied unconditionally; the `config.use_qk_norm` switch is not consulted here.
        query_states = self.query_layernorm(query_states)
        key_states = self.key_layernorm(key_states)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # Same guard as the eager attention path: the cache is indexed per layer.
            if self.layer_idx is None:
                raise ValueError(
                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                    "with a layer index."
                )
            cache_kwargs = {"sin": sin, "cos": cos}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # flash-attn expects (batch, seq_len, num_heads, head_dim).
        # TODO: these transposes are inefficient. Avoiding them requires refactoring the KV cache layout.
        # Grouped KV heads are passed as-is: flash-attn supports GQA/MQA natively, so no repeat_kv is needed.
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        dropout_rate = self.attention_dropout if self.training else 0.0

        # In PEFT the layer norms are often kept in float32 for training stability. This silently upcasts the
        # activations, but flash-attn only supports fp16/bf16, so cast back to the intended half-precision dtype.
        # (Keeping LayerNorms in fp32 slows things down; LLaDA2MoeRMSNorm already handles dtypes correctly.)
        if query_states.dtype == torch.float32:
            target_dtype = self._flash_attn_target_dtype()
            logger.warning_once(
                f"The input hidden states seems to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        attn_output = self._flash_attention_forward(
            query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate
        )
        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.dense(attn_output)

        # flash-attn does not return attention probabilities, so this slot is always None.
        attn_weights = None
        return attn_output, attn_weights, past_key_value

    def _flash_attn_target_dtype(self) -> torch.dtype:
        """Return the dtype that float32 activations are cast back to before calling flash-attn.

        Priority order:
          1. the pre-quantization dtype recorded on the config (quantized models);
          2. the active CUDA autocast dtype;
          3. the dtype of the fused QKV projection weight.
        """
        if hasattr(self.config, "_pre_quantization_dtype"):
            return self.config._pre_quantization_dtype
        if torch.is_autocast_enabled():
            # `torch.get_autocast_gpu_dtype` is deprecated in recent PyTorch in favour of
            # `torch.get_autocast_dtype(device_type)`. Fall back to it on older releases.
            if hasattr(torch, "get_autocast_dtype"):
                return torch.get_autocast_dtype("cuda")
            return torch.get_autocast_gpu_dtype()
        return self.query_key_value.weight.dtype

    def _flash_attention_forward(
        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
    ):
        """
        Call Flash Attention. When a padding mask is supplied, unpad the inputs, run the variable-length kernel,
        and re-pad the output.

        Args:
            query_states (`torch.Tensor`):
                Query states in `(batch_size, query_length, num_heads, head_dim)` layout.
            key_states (`torch.Tensor`):
                Key states in `(batch_size, kv_seq_len, num_key_value_heads, head_dim)` layout.
            value_states (`torch.Tensor`):
                Value states, same layout as `key_states`.
            attention_mask (`torch.Tensor`, *optional*):
                The padding mask of shape `(batch_size, kv_seq_len)`, where 0 marks padding tokens and 1 marks
                non-padding tokens. If not None, the unpadded (varlen) kernel is used.
            query_length (`int`):
                The number of query tokens along the sequence dimension of `query_states`. It selects the
                unpadding strategy and, on ROCm flash_attn<2.1, whether the causal mask is applied.
            dropout (`float`, *optional*):
                Attention dropout probability.
            softmax_scale (`float`, *optional*):
                The scaling of QK^T before applying softmax. Defaults to 1 / sqrt(head_dim) inside flash-attn.

        Returns:
            `torch.Tensor` of shape `(batch_size, query_length, num_heads, head_dim)`.
        """
        if not self._flash_attn_uses_top_left_mask:
            causal = self.is_causal
        else:
            # TODO: Remove the `query_length != 1` check once Flash Attention for ROCm is bumped to 2.1.
            # With a top-left aligned mask, a single query token would only see the first key, so causal
            # masking is disabled for that case (see the note in `__init__`).
            causal = self.is_causal and query_length != 1

        # Any explicit mask routes through the variable-length kernel (padding tokens are removed first).
        if attention_mask is not None:
            batch_size = query_states.shape[0]
            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
                query_states, key_states, value_states, attention_mask, query_length
            )

            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

            attn_output_unpad = flash_attn_varlen_func(
                query_states,
                key_states,
                value_states,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_in_batch_q,
                max_seqlen_k=max_seqlen_in_batch_k,
                dropout_p=dropout,
                softmax_scale=softmax_scale,
                causal=causal,
            )
            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
        else:
            attn_output = flash_attn_func(
                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
            )

        return attn_output

    def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
        """Strip padding tokens so the inputs can be fed to `flash_attn_varlen_func`.

        Keys and values are always unpadded with `attention_mask`. How queries are unpadded depends on how
        `query_length` compares with the key length:
          * equal lengths: reuse the key indices and cumulative lengths;
          * `query_length == 1`: one query token per sequence, so no gather is needed;
          * otherwise: keep the last `query_length` mask columns (this assumes left padding) and unpad.

        Returns:
            `(query, key, value, indices_q, (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k))`.
        """
        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

        key_layer = index_first_axis(
            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )
        value_layer = index_first_axis(
            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
        )

        if query_length == kv_seq_len:
            query_layer = index_first_axis(
                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
            )
            cu_seqlens_q = cu_seqlens_k
            max_seqlen_in_batch_q = max_seqlen_in_batch_k
            indices_q = indices_k
        elif query_length == 1:
            max_seqlen_in_batch_q = 1
            cu_seqlens_q = torch.arange(
                batch_size + 1, dtype=torch.int32, device=query_layer.device
            )  # There is a memcpy here, that is very bad.
            indices_q = cu_seqlens_q[:-1]
            query_layer = query_layer.squeeze(1)
        else:
            # The -q_len: slice assumes left padding.
            attention_mask = attention_mask[:, -query_length:]
            # flash-attn >= 2.7 returns an extra `seqused` tensor from `unpad_input`. Ignore any trailing
            # values so both the 4-tuple and the 5-tuple signatures are supported.
            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, *_ = unpad_input(
                query_layer, attention_mask
            )

        return (
            query_layer,
            key_layer,
            value_layer,
            indices_q,
            (cu_seqlens_q, cu_seqlens_k),
            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
        )
# Copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->LLaDA2Moe
class LLaDA2MoeSdpaAttention(LLaDA2MoeAttention):
"""
LLaDA2Moe attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
`LLaDA2MoeAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
SDPA API.
"""
# Adapted from LLaDA2MoeAttention.forward
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
cache_position: Optional[torch.LongTensor] = None,
replace_position= None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if output_attentions: