Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions fast_llm/engine/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,16 @@
class HuggingfaceModelConfig(transformers.PretrainedConfig):
model_type = "fast_llm"
model_config_class: typing.ClassVar[type[FastLLMModelConfig]] = FastLLMModelConfig
fast_llm_config: FastLLMModelConfig | None = None
use_cache: bool = True

def __init__(self, fast_llm_config: FastLLMModelConfig | None = None, **kwargs):
def __post_init__(self, **kwargs):
# Needed for `to_diff_dict` (`__repr__`)
if fast_llm_config is None:
fast_llm_config = self.model_config_class()
self.fast_llm_config = fast_llm_config
self.use_cache = kwargs.pop("use_cache", True)
super().__init__(**kwargs)
if self.torch_dtype is not None:
assert self.torch_dtype == self.fast_llm_config.distributed.compute_dtype.torch
if self.fast_llm_config is None:
self.fast_llm_config = self.model_config_class()
super().__post_init__(**kwargs)
if self.dtype is not None:
assert self.dtype == self.fast_llm_config.distributed.compute_dtype.torch

def save_pretrained(self, save_directory: str | os.PathLike, push_to_hub: bool = False, **kwargs) -> None:
# Hack the method to save at the right place.
Expand Down Expand Up @@ -88,9 +88,9 @@ def _get_config_dict(
)
metadata = cls.model_config_class.load_metadata(pretrained)
updates = {}
torch_dtype = kwargs.pop("torch_dtype", None)
if torch_dtype is not None:
updates[("distributed", "compute_dtype")] = torch_dtype
dtype = kwargs.pop("dtype", kwargs.pop("torch_dtype", None))
if dtype is not None:
updates[("distributed", "compute_dtype")] = dtype
fast_llm_config = cls.model_config_class.from_metadata(
pretrained, metadata, default=kwargs.pop("fast_llm_config", None), updates=updates
)
Expand Down
5 changes: 3 additions & 2 deletions fast_llm/engine/inference/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import transformers.generation.utils
import transformers.modeling_outputs
import transformers.utils.generic
from transformers.initialization import no_init_weights as transformers_no_init_weights

from fast_llm.core.distributed import broadcast, broadcast_object, safe_barrier
from fast_llm.engine.checkpoint.config import CheckpointLoadConfig, FastLLMCheckpointFormat
Expand Down Expand Up @@ -38,7 +39,7 @@ def __init__(
**kwargs,
):
if config is None:
config = self.config_class(fast_llm_model.config)
config = self.config_class(fast_llm_config=fast_llm_model.config)

assert self.runner_class.model_class.config_class is config.model_config_class
assert config.fast_llm_config is fast_llm_model.config
Expand Down Expand Up @@ -70,7 +71,7 @@ def __init__(
# Transformers needs to be able to inspect the base model.
self.fast_llm_base_model = fast_llm_model.base_model

with transformers.modeling_utils.no_init_weights():
with transformers_no_init_weights():
self.post_init()

if fast_llm_model.config.multi_stage.zero_stage == 3:
Expand Down
142 changes: 89 additions & 53 deletions fast_llm/models/gpt/conversion/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,36 +188,68 @@ def import_weight(
class LlamaAttentionConverter:
@classmethod
def import_config(cls, config: dict) -> dict:
try:
rope_type = config["rope_scaling"]["rope_type"]
except (KeyError, TypeError):
rope_type = "default"
rotary_config = {
"type": rope_type,
"theta": config["rope_theta"],
}
if rope_type == "default":
pass
elif rope_type == "llama3":
rotary_config.update(
{
"scale_factor": config["rope_scaling"]["factor"],
"low_frequency_factor": config["rope_scaling"]["low_freq_factor"],
"high_frequency_factor": config["rope_scaling"]["high_freq_factor"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
}
)
elif rope_type == "yarn":
rotary_config.update(
{
"attention_factor": config["rope_scaling"]["attention_factor"],
"beta_fast": config["rope_scaling"]["beta_fast"],
"beta_slow": config["rope_scaling"]["beta_slow"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
}
)
# transformers 5.x consolidates rope_theta + rope_scaling into rope_parameters
if "rope_parameters" in config:
rope_params = config["rope_parameters"]
rope_type = rope_params.get("rope_type", "default")
rotary_config = {
"type": rope_type,
"theta": rope_params["rope_theta"],
}
if rope_type == "default":
pass
elif rope_type == "llama3":
rotary_config.update(
{
"scale_factor": rope_params["factor"],
"low_frequency_factor": rope_params["low_freq_factor"],
"high_frequency_factor": rope_params["high_freq_factor"],
"original_context_length": rope_params["original_max_position_embeddings"],
}
)
elif rope_type == "yarn":
rotary_config.update(
{
"attention_factor": rope_params["attention_factor"],
"beta_fast": rope_params["beta_fast"],
"beta_slow": rope_params["beta_slow"],
"original_context_length": rope_params["original_max_position_embeddings"],
}
)
else:
raise NotImplementedError(f"Unsupported rotary type: {rope_type}")
else:
raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}")
# transformers 4.x format: rope_theta at top level, rope_scaling separate
try:
rope_type = config["rope_scaling"]["rope_type"]
except (KeyError, TypeError):
rope_type = "default"
rotary_config = {
"type": rope_type,
"theta": config["rope_theta"],
}
if rope_type == "default":
pass
elif rope_type == "llama3":
rotary_config.update(
{
"scale_factor": config["rope_scaling"]["factor"],
"low_frequency_factor": config["rope_scaling"]["low_freq_factor"],
"high_frequency_factor": config["rope_scaling"]["high_freq_factor"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
}
)
elif rope_type == "yarn":
rotary_config.update(
{
"attention_factor": config["rope_scaling"]["attention_factor"],
"beta_fast": config["rope_scaling"]["beta_fast"],
"beta_slow": config["rope_scaling"]["beta_slow"],
"original_context_length": config["rope_scaling"]["original_max_position_embeddings"],
}
)
else:
raise NotImplementedError(f"Unsupported rotary type: {rope_type}")
out = {
"rotary": rotary_config,
"heads": config["num_attention_heads"],
Expand All @@ -235,36 +267,40 @@ def import_config(cls, config: dict) -> dict:
def export_config(cls, config: AttentionConfig) -> dict:
cls._check_config(config)
Assert.eq(config.softmax_scale_power, 0.5)
out = {
"num_attention_heads": config.heads,
"num_key_value_heads": config.head_groups,
"head_dim": config.head_size,
"attention_bias": config.add_linear_biases,
"attention_dropout": config.dropout,
"rope_theta": config.rotary.theta,
}
rope_parameters = {"rope_theta": config.rotary.theta}
if type(config.rotary) is DefaultRotaryConfig:
pass
rope_parameters["rope_type"] = "default"
elif type(config.rotary) is Llama3RotaryConfig:
out["rope_scaling"] = {
"rope_type": "llama3",
"factor": config.rotary.scale_factor,
"low_freq_factor": config.rotary.low_frequency_factor,
"high_freq_factor": config.rotary.high_frequency_factor,
"original_max_position_embeddings": config.rotary.original_context_length,
}
rope_parameters.update(
{
"rope_type": "llama3",
"factor": config.rotary.scale_factor,
"low_freq_factor": config.rotary.low_frequency_factor,
"high_freq_factor": config.rotary.high_frequency_factor,
"original_max_position_embeddings": config.rotary.original_context_length,
}
)
elif type(config.rotary) is YarnRotaryConfig:
out["rope_scaling"] = {
"rope_type": "yarn",
"attention_factor": config.rotary.attention_factor,
"beta_fast": config.rotary.beta_fast,
"beta_slow": config.rotary.beta_slow,
"original_max_position_embeddings": config.rotary.original_context_length,
}
rope_parameters.update(
{
"rope_type": "yarn",
"attention_factor": config.rotary.attention_factor,
"beta_fast": config.rotary.beta_fast,
"beta_slow": config.rotary.beta_slow,
"original_max_position_embeddings": config.rotary.original_context_length,
}
)
else:
raise NotImplementedError(f"Unsupported rotary type: {type(config.rotary).__name__}")

return out
return {
"num_attention_heads": config.heads,
"num_key_value_heads": config.head_groups,
"head_dim": config.head_size,
"attention_bias": config.add_linear_biases,
"attention_dropout": config.dropout,
"rope_parameters": rope_parameters,
}

@classmethod
def _check_config(cls, config: AttentionConfig) -> None:
Expand Down
4 changes: 3 additions & 1 deletion fast_llm/models/gpt/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@
class HuggingfaceGPTModelConfig(HuggingfaceModelConfig):
model_type = "fast_llm_gpt"
model_config_class = GPTModelConfig
fast_llm_config: GPTModelConfig

if typing.TYPE_CHECKING:
fast_llm_config: GPTModelConfig


class HuggingfaceGPTModelForCausalLM(HuggingfacePreTrainedModel):
Expand Down
4 changes: 3 additions & 1 deletion fast_llm/models/multimodal/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
class HuggingfaceMultiModalModelConfig(HuggingfaceGPTModelConfig):
model_type = "fast_llm_multi_modal"
model_config_class = MultiModalModelConfig
fast_llm_config: MultiModalModelConfig

if typing.TYPE_CHECKING:
fast_llm_config: MultiModalModelConfig


class HuggingfaceMultiModalModelForCausalLM(HuggingfaceGPTModelForCausalLM):
Expand Down
3 changes: 2 additions & 1 deletion fast_llm_external_models/apriel2/modeling_apriel2.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def get_mask_sizes(self, cache_position, layer_idx):
For SSM/linear layers:
kv_offset = 0, kv_length = query_length (no KV cache to attend to)
"""
query_length = cache_position.shape[0]
query_length = cache_position if isinstance(cache_position, int) else cache_position.shape[0]
layer = self.layers[layer_idx]

# Handle stochastic layers by getting the active mixer's cache
Expand Down Expand Up @@ -794,6 +794,7 @@ def setup(
hidden_size=hidden_size,
num_attention_heads=num_heads,
partial_rotary_factor=1.0,
rope_parameters={"rope_theta": rope_theta, "rope_type": "default"},
)
return nn.ModuleDict({"rotary_emb": MistralRotaryEmbedding(config=rotary_config)})

Expand Down
20 changes: 14 additions & 6 deletions fast_llm_external_models/mtp_llama/modeling_mtp_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,29 @@ def extra_repr(self):
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, config: MTPLlamaConfig, device=None):
super().__init__()
# BC: "rope_type" was originally "type"
if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
else:
self.rope_type = "default"
self.rope_type = config.rope_parameters.get("rope_type", "default")
self.max_seq_len_cached = config.max_position_embeddings
self.original_max_seq_len = config.max_position_embeddings

self.config = config
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
if self.rope_type == "default":
self.rope_init_fn = self.compute_default_rope_parameters
else:
self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.original_inv_freq = self.inv_freq

@staticmethod
def compute_default_rope_parameters(config, device=None, seq_len=None):
base = config.rope_parameters["rope_theta"]
dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
)
return inv_freq, 1.0

def _dynamic_frequency_update(self, position_ids, device):
"""
dynamic RoPE layers should recompute `inv_freq` in the following situations:
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ OPTIONAL =

# Huggingface tools
HUGGINGFACE =
transformers>=4.57.3,<5.0.0
transformers>=5.4.0,<6.0.0
hf-transfer>=0.1.9
datasets>=4.4.1
huggingface-hub>=0.36.0
Expand Down
6 changes: 3 additions & 3 deletions tests/models/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,13 +391,13 @@ def test_huggingface_model(model_testing_config, get_convert_path, testing_devic
hidden_states = output.hidden_states + (output.logits,)
    # Llava models don't return vision hidden states, so we run the vision model directly instead.
if model_testing_config.model_type == "multimodal":
if hasattr(model, "vision_tower"):
vision_output = model.vision_tower(
if hasattr(model.model, "vision_tower"):
vision_output = model.model.vision_tower(
pixel_values=kwargs["pixel_values"],
image_sizes=kwargs["image_sizes"],
output_hidden_states=True,
)
adapter_output = model.multi_modal_projector(vision_output.hidden_states[-1])
adapter_output = model.model.multi_modal_projector(vision_output.hidden_states[-1])
hidden_states = vision_output.hidden_states + (adapter_output,) + hidden_states
hidden_states_ref_ = hidden_states_ref.copy()
# Adjust the vision hidden states
Expand Down
Loading