diff --git a/src/backend/server/static_config.py b/src/backend/server/static_config.py index beb73891..60dede44 100644 --- a/src/backend/server/static_config.py +++ b/src/backend/server/static_config.py @@ -84,7 +84,11 @@ "Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit", "Qwen/Qwen3-Next-80B-A3B-Thinking": "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit", "Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit", + # Qwen 3.5 MoE Series + "Qwen/Qwen3.5-0.8B": "Qwen/Qwen3.5-0.8B", + "Qwen/Qwen3.5-35B-A3B": "mlx-community/Qwen3.5-35B-A3B-4bit", # Qwen 3.6 Series + "Qwen/Qwen3.6-35B-A3B": "mlx-community/Qwen3.6-35B-A3B-4bit", "Qwen/Qwen3.6-27B": "mlx-community/Qwen3.6-27B-mxfp4", # Qwen 3 Large MoE Models "Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit", diff --git a/src/parallax/server/shard_loader.py b/src/parallax/server/shard_loader.py index 3a25d1a6..341d8112 100644 --- a/src/parallax/server/shard_loader.py +++ b/src/parallax/server/shard_loader.py @@ -33,10 +33,12 @@ MODEL_CLASS_MAP = { "kimi_k2": "mlx_lm.models.deepseek_v3", "minimax_m2": "mlx_lm.models.minimax", + "qwen3_5_moe": "mlx_lm.models.qwen3_5", } ARCHITECTURE_CLASS_ALIASES = { "GlmMoeDsaForCausalLM": "DeepseekV32ForCausalLM", + "Qwen3_5MoeForConditionalGeneration": "Qwen3_5ForConditionalGeneration", } @@ -310,6 +312,21 @@ def _cast_weight_array(weight_array: mx.array, dtype: mx.Dtype) -> mx.array: return weight_array.astype(dtype) return weight_array + @staticmethod + def _load_mlx_lm_module_and_args(model_type: str, config: Dict[str, Any]): + if model_type in MODEL_CLASS_MAP: + model_class = MODEL_CLASS_MAP[model_type] + else: + model_class = f"mlx_lm.models.{model_type}" + + arch_module = importlib.import_module(model_class) + if hasattr(arch_module, "TextModelArgs"): + model_args_class = getattr(arch_module, "TextModelArgs") + else: + model_args_class = getattr(arch_module, "ModelArgs") + + return arch_module, model_args_class.from_dict(config) + def load( self, lazy: bool = False, strict: bool = False, use_selective_download: bool = True ) -> Tuple[nn.Module, Dict[str, Any], Any]: @@ -366,18 +383,8 @@ def load( if not model_type: raise ValueError("model_type not found in config.json") - if model_type in MODEL_CLASS_MAP: - model_class = MODEL_CLASS_MAP[model_type] - else: - model_class = f"mlx_lm.models.{model_type}" - try: - arch_module = importlib.import_module(model_class) - if model_type == "qwen3_5" and hasattr(arch_module, "TextModelArgs"): - model_args_class = getattr(arch_module, "TextModelArgs") - else: - model_args_class = getattr(arch_module, "ModelArgs") - model_args = model_args_class.from_dict(config) + arch_module, model_args = self._load_mlx_lm_module_and_args(model_type, config) self.arch_module = arch_module self.model_args = model_args diff --git a/src/parallax/utils/model_download.py b/src/parallax/utils/model_download.py index f91eeab7..1f709a4f 100644 --- a/src/parallax/utils/model_download.py +++ b/src/parallax/utils/model_download.py @@ -1,7 +1,8 @@ +import json import logging import os from pathlib import Path -from typing import Optional +from typing import Dict, List, Optional, Set from huggingface_hub import hf_hub_download as _hf_hub_download from huggingface_hub import snapshot_download as _snapshot_download @@ -9,7 +10,8 @@ from modelscope.hub.file_download import model_file_download as _ms_model_file_download from parallax.utils.weight_filter_utils import ( - determine_needed_weight_files_for_download, + normalize_language_model_weight_key, + should_include_weight_key, ) logger = logging.getLogger(__name__) @@ -96,7 +98,7 @@ def selective_model_download( if start_layer is not None and end_layer is not None: logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})") - needed_weight_files = determine_needed_weight_files_for_download( + needed_weight_files = _determine_needed_weight_files_for_download( model_path=model_path, start_layer=start_layer, end_layer=end_layer, @@ -149,5 +151,80 @@ def selective_model_download( ] +def _determine_needed_weight_files_for_download( + model_path: Path, + start_layer: int, + end_layer: int, + config: Optional[Dict] = None, +) -> List[str]: + is_first_shard = start_layer == 0 + + is_last_shard = False + if config: + num_hidden_layers = config.get("num_hidden_layers", 0) + is_last_shard = end_layer >= num_hidden_layers + else: + config_file = model_path / "config.json" + if config_file.exists(): + from parallax.utils.utils import normalize_model_config + + with open(config_file, "r") as f: + cfg = normalize_model_config(json.load(f)) + num_hidden_layers = cfg.get("num_hidden_layers", 0) + is_last_shard = end_layer >= num_hidden_layers + + index_file = model_path / "model.safetensors.index.json" + + if not index_file.exists(): + logger.debug(f"Index file not found at {index_file}, checking for single weight file") + # For non-sharded models, look for single weight file + single_weight_files = [ + "model.safetensors", + "pytorch_model.bin", + "model.bin", + ] + for weight_file in single_weight_files: + if (model_path / weight_file).exists(): + logger.debug(f"Found single weight file: {weight_file}") + return [weight_file] + + logger.debug("No weight files found (neither index nor single file)") + return [] + + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data.get("weight_map", {}) + if not weight_map: + logger.debug("weight_map is empty in index file") + return [] + + tie_word_embeddings = False + if config: + tie_word_embeddings = config.get("tie_word_embeddings", False) + + needed_files: Set[str] = set() + + for key, filename in weight_map.items(): + if filename in needed_files: + continue + key = normalize_language_model_weight_key(key) + if should_include_weight_key( + key=key, + start_layer=start_layer, + end_layer=end_layer, + is_first_shard=is_first_shard, + is_last_shard=is_last_shard, + tie_word_embeddings=tie_word_embeddings, + ): + needed_files.add(filename) + + result = sorted(list(needed_files)) + logger.debug( + f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})" + ) + return result + + def _use_modelscope() -> bool: return _USE_MODELSCOPE_ENV in os.environ diff --git a/src/parallax/utils/utils.py b/src/parallax/utils/utils.py index 240f8967..014e3bb0 100644 --- a/src/parallax/utils/utils.py +++ b/src/parallax/utils/utils.py @@ -303,7 +303,7 @@ def load_config_only(name: str, local_files_only: bool = False): def normalize_model_config(config: dict) -> dict: """Expose nested text model fields at the top level for VLM-style configs.""" text_config = config.get("text_config") - if config.get("model_type") == "qwen3_5" and isinstance(text_config, dict): + if config.get("model_type") in {"qwen3_5", "qwen3_5_moe"} and isinstance(text_config, dict): normalized = {**config, **text_config} normalized["model_type"] = config["model_type"] normalized["architectures"] = config.get("architectures", normalized.get("architectures")) diff --git a/src/parallax/utils/weight_filter_utils.py b/src/parallax/utils/weight_filter_utils.py index 71d73d0f..5065d61b 100644 --- a/src/parallax/utils/weight_filter_utils.py +++ b/src/parallax/utils/weight_filter_utils.py @@ -116,76 +116,3 @@ def filter_weight_files_by_layer_range_for_load( ) return filtered_files - - -def determine_needed_weight_files_for_download( - model_path: Path, - start_layer: int, - end_layer: int, - config: Optional[Dict] = None, -) -> List[str]: - is_first_shard = start_layer == 0 - - is_last_shard = False - if config: - num_hidden_layers = config.get("num_hidden_layers", 0) - is_last_shard = end_layer >= num_hidden_layers - else: - config_file = model_path / "config.json" - if config_file.exists(): - with open(config_file, "r") as f: - cfg = json.load(f) - num_hidden_layers = cfg.get("num_hidden_layers", 0) - is_last_shard = end_layer >= num_hidden_layers - - index_file = model_path / "model.safetensors.index.json" - - if not index_file.exists(): - logger.debug(f"Index file not found at {index_file}, checking for single weight file") - # For non-sharded models, look for single weight file - single_weight_files = [ - "model.safetensors", - "pytorch_model.bin", - "model.bin", - ] - for weight_file in single_weight_files: - if (model_path / weight_file).exists(): - logger.debug(f"Found single weight file: {weight_file}") - return [weight_file] - - logger.debug("No weight files found (neither index nor single file)") - return [] - - with open(index_file, "r") as f: - index_data = json.load(f) - - weight_map = index_data.get("weight_map", {}) - if not weight_map: - logger.debug("weight_map is empty in index file") - return [] - - tie_word_embeddings = False - if config: - tie_word_embeddings = config.get("tie_word_embeddings", False) - - needed_files: Set[str] = set() - - for key, filename in weight_map.items(): - if filename in needed_files: - continue - key = normalize_language_model_weight_key(key) - if should_include_weight_key( - key=key, - start_layer=start_layer, - end_layer=end_layer, - is_first_shard=is_first_shard, - is_last_shard=is_last_shard, - tie_word_embeddings=tie_word_embeddings, - ): - needed_files.add(filename) - - result = sorted(list(needed_files)) - logger.debug( - f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})" - ) - return result diff --git a/tests/test_shard_loader.py b/tests/test_shard_loader.py index a10be272..5d93fc3d 100644 --- a/tests/test_shard_loader.py +++ b/tests/test_shard_loader.py @@ -2,6 +2,7 @@ Tests for the shard_loader module. """ +import json import sys from unittest.mock import Mock, patch @@ -14,6 +15,8 @@ MLXModelLoader, normalize_language_model_weight_key, ) +from parallax.utils.model_download import _determine_needed_weight_files_for_download +from parallax.utils.utils import normalize_model_config from parallax.utils.weight_filter_utils import should_include_weight_key @@ -27,10 +30,20 @@ def test_normalize_nested_language_model_weight_keys(): normalize_language_model_weight_key("model.language_model.layers.12.mlp.up_proj.weight") == "model.layers.12.mlp.up_proj.weight" ) + assert ( + normalize_language_model_weight_key( + "language_model.model.layers.12.mlp.switch_mlp.up_proj.weight" + ) + == "model.layers.12.mlp.switch_mlp.up_proj.weight" + ) assert ( normalize_language_model_weight_key("model.language_model.norm.weight") == "model.norm.weight" ) + assert ( + normalize_language_model_weight_key("language_model.model.norm.weight") + == "model.norm.weight" + ) assert ( normalize_language_model_weight_key("model.language_model.lm_head.weight") == "lm_head.weight" @@ -111,6 +124,79 @@ def test_mlx_lm_sanitize_uses_local_layer_keys_for_shards(): ) == ["layers.0.linear_attn.conv1d.weight"] +def test_qwen35_moe_uses_qwen35_text_args_and_sanitizer_module(): + loader = MLXModelLoader("test_model_path") + config = normalize_model_config( + { + "model_type": "qwen3_5_moe", + "architectures": ["Qwen3_5MoeForConditionalGeneration"], + "text_config": { + "model_type": "qwen3_5_moe_text", + "hidden_size": 2048, + "num_hidden_layers": 40, + "num_attention_heads": 16, + "num_key_value_heads": 2, + "vocab_size": 248320, + "num_experts": 256, + "num_experts_per_tok": 8, + "moe_intermediate_size": 512, + }, + } + ) + + sanitizer_module, model_args = loader._load_mlx_lm_module_and_args("qwen3_5_moe", config) + + assert MODEL_CLASS_MAP["qwen3_5_moe"] == "mlx_lm.models.qwen3_5" + assert sanitizer_module.__name__ == "mlx_lm.models.qwen3_5" + assert model_args.num_hidden_layers == 40 + assert model_args.hidden_size == 2048 + assert model_args.num_experts == 256 + assert model_args.num_experts_per_tok == 8 + assert model_args.moe_intermediate_size == 512 + + +@pytest.mark.skipif(sys.platform != "darwin", reason="MLX tests require macOS") +def test_register_block_class_includes_qwen35_moe(): + loader = MLXModelLoader("test_model_path") + + assert "Qwen3_5MoeForConditionalGeneration" in loader.block_class_map + + +def test_selective_download_uses_nested_qwen35_moe_num_layers(tmp_path): + (tmp_path / "config.json").write_text( + json.dumps( + { + "model_type": "qwen3_5_moe", + "text_config": { + "num_hidden_layers": 40, + "tie_word_embeddings": False, + }, + } + ) + ) + (tmp_path / "model.safetensors.index.json").write_text( + json.dumps( + { + "weight_map": { + "language_model.model.layers.39.linear_attn.in_proj_qkv.weight": ( + "layers-39.safetensors" + ), + "language_model.model.norm.weight": "final.safetensors", + "language_model.lm_head.weight": "final.safetensors", + } + } + ) + ) + + needed_files = _determine_needed_weight_files_for_download( + tmp_path, + start_layer=39, + end_layer=40, + ) + + assert needed_files == ["final.safetensors", "layers-39.safetensors"] + + @pytest.mark.skipif(sys.platform != "darwin", reason="MLX tests require macOS") class TestMLXModelLoader: """Test MLXModelLoader functionality.""" diff --git a/tests/test_static_config.py b/tests/test_static_config.py index 24e60b8d..6eeb1039 100644 --- a/tests/test_static_config.py +++ b/tests/test_static_config.py @@ -47,3 +47,46 @@ def fake_load_config_only(model_name, local_files_only=False): assert model_info.num_kv_heads == 4 assert model_info.param_bytes_per_element == 0.5 assert model_info.mlx_param_bytes_per_element == 0.5 + + +def test_qwen3_5_moe_4bit_model_info_uses_text_config(monkeypatch): + def fake_load_config_only(model_name, local_files_only=False): + assert model_name in { + "Qwen/Qwen3.5-35B-A3B", + "mlx-community/Qwen3.5-35B-A3B-4bit", + } + return normalize_model_config( + { + "model_type": "qwen3_5_moe", + "architectures": ["Qwen3_5MoeForConditionalGeneration"], + "quantization": {"bits": 4, "group_size": 64, "mode": "affine"}, + "text_config": { + "num_hidden_layers": 40, + "full_attention_interval": 4, + "head_dim": 256, + "hidden_size": 2048, + "moe_intermediate_size": 512, + "num_attention_heads": 16, + "num_experts": 256, + "num_experts_per_tok": 8, + "num_key_value_heads": 2, + "vocab_size": 248320, + }, + } + ) + + monkeypatch.setattr(static_config, "load_config_only", fake_load_config_only) + + model_info = get_model_info("Qwen/Qwen3.5-35B-A3B") + + assert model_info.num_layers == 40 + assert model_info.mlx_model_name == "mlx-community/Qwen3.5-35B-A3B-4bit" + assert model_info.head_size == 256 + assert model_info.hidden_dim == 2048 + assert model_info.num_attention_heads == 16 + assert model_info.num_kv_heads == 2 + assert model_info.num_local_experts == 256 + assert model_info.num_experts_per_tok == 8 + assert model_info.moe_intermediate_dim == 512 + assert model_info.param_bytes_per_element == 0.5 + assert model_info.mlx_param_bytes_per_element == 0.5