Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/backend/server/static_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit",
"Qwen/Qwen3-Next-80B-A3B-Thinking": "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit",
"Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit",
# Qwen 3.5 MoE Series
"Qwen/Qwen3.5-0.8B": "Qwen/Qwen3.5-0.8B",
"Qwen/Qwen3.5-35B-A3B": "mlx-community/Qwen3.5-35B-A3B-4bit",
# Qwen 3.6 Series
"Qwen/Qwen3.6-35B-A3B": "mlx-community/Qwen3.6-35B-A3B-4bit",
"Qwen/Qwen3.6-27B": "mlx-community/Qwen3.6-27B-mxfp4",
# Qwen 3 Large MoE Models
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit",
Expand Down
29 changes: 18 additions & 11 deletions src/parallax/server/shard_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,12 @@
# Overrides for config `model_type` values whose mlx_lm implementation lives in
# a module that is not named after the model type. Model types that are not
# listed here resolve to `mlx_lm.models.<model_type>`.
MODEL_CLASS_MAP = {
"kimi_k2": "mlx_lm.models.deepseek_v3",
"minimax_m2": "mlx_lm.models.minimax",
"qwen3_5_moe": "mlx_lm.models.qwen3_5",
}

# Maps an architecture class name (key) to another architecture class name
# (value).
# NOTE(review): the code that consumes this table is not visible here —
# presumably the key is handled as if it were the value; confirm at the lookup site.
ARCHITECTURE_CLASS_ALIASES = {
"GlmMoeDsaForCausalLM": "DeepseekV32ForCausalLM",
"Qwen3_5MoeForConditionalGeneration": "Qwen3_5ForConditionalGeneration",
}


Expand Down Expand Up @@ -310,6 +312,21 @@ def _cast_weight_array(weight_array: mx.array, dtype: mx.Dtype) -> mx.array:
return weight_array.astype(dtype)
return weight_array

@staticmethod
def _load_mlx_lm_module_and_args(model_type: str, config: Dict[str, Any]):
    """Import the mlx_lm module for ``model_type`` and build its model args.

    Args:
        model_type: The ``model_type`` string from the model config. It is
            looked up in ``MODEL_CLASS_MAP``; when absent, the module path
            defaults to ``mlx_lm.models.<model_type>``.
        config: Model configuration dictionary passed to ``from_dict``.

    Returns:
        A ``(module, model_args)`` tuple, where ``model_args`` is produced by
        the module's ``TextModelArgs.from_dict`` when the module defines
        ``TextModelArgs``, and by ``ModelArgs.from_dict`` otherwise.
    """
    module_path = MODEL_CLASS_MAP.get(model_type, f"mlx_lm.models.{model_type}")
    arch_module = importlib.import_module(module_path)

    # Prefer the text-only args class whenever the module exposes one.
    args_attr = "TextModelArgs" if hasattr(arch_module, "TextModelArgs") else "ModelArgs"
    args_factory = getattr(arch_module, args_attr)

    return arch_module, args_factory.from_dict(config)

def load(
self, lazy: bool = False, strict: bool = False, use_selective_download: bool = True
) -> Tuple[nn.Module, Dict[str, Any], Any]:
Expand Down Expand Up @@ -366,18 +383,8 @@ def load(
if not model_type:
raise ValueError("model_type not found in config.json")

if model_type in MODEL_CLASS_MAP:
model_class = MODEL_CLASS_MAP[model_type]
else:
model_class = f"mlx_lm.models.{model_type}"

try:
arch_module = importlib.import_module(model_class)
if model_type == "qwen3_5" and hasattr(arch_module, "TextModelArgs"):
model_args_class = getattr(arch_module, "TextModelArgs")
else:
model_args_class = getattr(arch_module, "ModelArgs")
model_args = model_args_class.from_dict(config)
arch_module, model_args = self._load_mlx_lm_module_and_args(model_type, config)
self.arch_module = arch_module
self.model_args = model_args

Expand Down
83 changes: 80 additions & 3 deletions src/parallax/utils/model_download.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import json
import logging
import os
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional, Set

from huggingface_hub import hf_hub_download as _hf_hub_download
from huggingface_hub import snapshot_download as _snapshot_download
from modelscope import snapshot_download as _ms_snapshot_download
from modelscope.hub.file_download import model_file_download as _ms_model_file_download

from parallax.utils.weight_filter_utils import (
determine_needed_weight_files_for_download,
normalize_language_model_weight_key,
should_include_weight_key,
)

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -96,7 +98,7 @@ def selective_model_download(
if start_layer is not None and end_layer is not None:
logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})")

needed_weight_files = determine_needed_weight_files_for_download(
needed_weight_files = _determine_needed_weight_files_for_download(
model_path=model_path,
start_layer=start_layer,
end_layer=end_layer,
Expand Down Expand Up @@ -149,5 +151,80 @@ def selective_model_download(
]


def _determine_needed_weight_files_for_download(
    model_path: Path,
    start_layer: int,
    end_layer: int,
    config: Optional[Dict] = None,
) -> List[str]:
    """Return the weight files needed for layers ``[start_layer, end_layer)``.

    The model configuration is taken from ``config`` when it is non-empty;
    otherwise ``model_path / "config.json"`` is loaded and passed through
    ``normalize_model_config``. Both ``num_hidden_layers`` and
    ``tie_word_embeddings`` are read from that single resolved configuration,
    so the on-disk fallback applies to both fields.

    Args:
        model_path: Directory containing the model's config and index files.
        start_layer: First layer of the shard (inclusive).
        end_layer: End of the shard's layer range (exclusive).
        config: Optional model configuration dictionary.

    Returns:
        A sorted list of weight file names. When there is no
        ``model.safetensors.index.json``, the first existing file among
        ``model.safetensors``, ``pytorch_model.bin`` and ``model.bin`` is
        returned as a single-element list. An empty list is returned when no
        weight file is found or the index has an empty ``weight_map``.
    """
    effective_config = config if config else None
    if effective_config is None:
        config_file = model_path / "config.json"
        if config_file.exists():
            # Kept function-local as in the original code.
            # NOTE(review): presumably avoids a circular import — confirm.
            from parallax.utils.utils import normalize_model_config

            with open(config_file, "r", encoding="utf-8") as f:
                effective_config = normalize_model_config(json.load(f))

    is_first_shard = start_layer == 0
    is_last_shard = False
    tie_word_embeddings = False
    if effective_config is not None:
        # A missing "num_hidden_layers" defaults to 0, which makes any
        # non-negative end_layer count as the last shard.
        num_hidden_layers = effective_config.get("num_hidden_layers", 0)
        is_last_shard = end_layer >= num_hidden_layers
        tie_word_embeddings = effective_config.get("tie_word_embeddings", False)

    index_file = model_path / "model.safetensors.index.json"

    if not index_file.exists():
        logger.debug(f"Index file not found at {index_file}, checking for single weight file")
        # For non-sharded models, look for single weight file
        single_weight_files = [
            "model.safetensors",
            "pytorch_model.bin",
            "model.bin",
        ]
        for weight_file in single_weight_files:
            if (model_path / weight_file).exists():
                logger.debug(f"Found single weight file: {weight_file}")
                return [weight_file]

        logger.debug("No weight files found (neither index nor single file)")
        return []

    with open(index_file, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    weight_map = index_data.get("weight_map", {})
    if not weight_map:
        logger.debug("weight_map is empty in index file")
        return []

    needed_files: Set[str] = set()

    for key, filename in weight_map.items():
        # A file already selected needs no further key checks.
        if filename in needed_files:
            continue
        key = normalize_language_model_weight_key(key)
        if should_include_weight_key(
            key=key,
            start_layer=start_layer,
            end_layer=end_layer,
            is_first_shard=is_first_shard,
            is_last_shard=is_last_shard,
            tie_word_embeddings=tie_word_embeddings,
        ):
            needed_files.add(filename)

    result = sorted(needed_files)
    logger.debug(
        f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})"
    )
    return result


def _use_modelscope() -> bool:
    """Return True when the variable named by ``_USE_MODELSCOPE_ENV`` is set.

    Only presence is checked; the variable's value (even an empty string) is
    not inspected.
    """
    return os.environ.get(_USE_MODELSCOPE_ENV) is not None
2 changes: 1 addition & 1 deletion src/parallax/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ def load_config_only(name: str, local_files_only: bool = False):
def normalize_model_config(config: dict) -> dict:
"""Expose nested text model fields at the top level for VLM-style configs."""
text_config = config.get("text_config")
if config.get("model_type") == "qwen3_5" and isinstance(text_config, dict):
if config.get("model_type") in {"qwen3_5", "qwen3_5_moe"} and isinstance(text_config, dict):
normalized = {**config, **text_config}
normalized["model_type"] = config["model_type"]
normalized["architectures"] = config.get("architectures", normalized.get("architectures"))
Expand Down
73 changes: 0 additions & 73 deletions src/parallax/utils/weight_filter_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,76 +116,3 @@ def filter_weight_files_by_layer_range_for_load(
)

return filtered_files


def determine_needed_weight_files_for_download(
    model_path: Path,
    start_layer: int,
    end_layer: int,
    config: Optional[Dict] = None,
) -> List[str]:
    """Return the weight files needed for layers ``[start_layer, end_layer)``.

    The model configuration is taken from ``config`` when it is non-empty;
    otherwise ``model_path / "config.json"`` is loaded as-is. Both
    ``num_hidden_layers`` and ``tie_word_embeddings`` are read from that
    single resolved configuration, so the on-disk fallback applies to both
    fields.

    Args:
        model_path: Directory containing the model's config and index files.
        start_layer: First layer of the shard (inclusive).
        end_layer: End of the shard's layer range (exclusive).
        config: Optional model configuration dictionary.

    Returns:
        A sorted list of weight file names. When there is no
        ``model.safetensors.index.json``, the first existing file among
        ``model.safetensors``, ``pytorch_model.bin`` and ``model.bin`` is
        returned as a single-element list. An empty list is returned when no
        weight file is found or the index has an empty ``weight_map``.
    """
    effective_config = config if config else None
    if effective_config is None:
        config_file = model_path / "config.json"
        if config_file.exists():
            with open(config_file, "r", encoding="utf-8") as f:
                effective_config = json.load(f)

    is_first_shard = start_layer == 0
    is_last_shard = False
    tie_word_embeddings = False
    if effective_config is not None:
        # A missing "num_hidden_layers" defaults to 0, which makes any
        # non-negative end_layer count as the last shard.
        num_hidden_layers = effective_config.get("num_hidden_layers", 0)
        is_last_shard = end_layer >= num_hidden_layers
        tie_word_embeddings = effective_config.get("tie_word_embeddings", False)

    index_file = model_path / "model.safetensors.index.json"

    if not index_file.exists():
        logger.debug(f"Index file not found at {index_file}, checking for single weight file")
        # For non-sharded models, look for single weight file
        single_weight_files = [
            "model.safetensors",
            "pytorch_model.bin",
            "model.bin",
        ]
        for weight_file in single_weight_files:
            if (model_path / weight_file).exists():
                logger.debug(f"Found single weight file: {weight_file}")
                return [weight_file]

        logger.debug("No weight files found (neither index nor single file)")
        return []

    with open(index_file, "r", encoding="utf-8") as f:
        index_data = json.load(f)

    weight_map = index_data.get("weight_map", {})
    if not weight_map:
        logger.debug("weight_map is empty in index file")
        return []

    needed_files: Set[str] = set()

    for key, filename in weight_map.items():
        # A file already selected needs no further key checks.
        if filename in needed_files:
            continue
        key = normalize_language_model_weight_key(key)
        if should_include_weight_key(
            key=key,
            start_layer=start_layer,
            end_layer=end_layer,
            is_first_shard=is_first_shard,
            is_last_shard=is_last_shard,
            tie_word_embeddings=tie_word_embeddings,
        ):
            needed_files.add(filename)

    result = sorted(needed_files)
    logger.debug(
        f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})"
    )
    return result
86 changes: 86 additions & 0 deletions tests/test_shard_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Tests for the shard_loader module.
"""

import json
import sys
from unittest.mock import Mock, patch

Expand All @@ -14,6 +15,8 @@
MLXModelLoader,
normalize_language_model_weight_key,
)
from parallax.utils.model_download import _determine_needed_weight_files_for_download
from parallax.utils.utils import normalize_model_config
from parallax.utils.weight_filter_utils import should_include_weight_key


Expand All @@ -27,10 +30,20 @@ def test_normalize_nested_language_model_weight_keys():
normalize_language_model_weight_key("model.language_model.layers.12.mlp.up_proj.weight")
== "model.layers.12.mlp.up_proj.weight"
)
assert (
normalize_language_model_weight_key(
"language_model.model.layers.12.mlp.switch_mlp.up_proj.weight"
)
== "model.layers.12.mlp.switch_mlp.up_proj.weight"
)
assert (
normalize_language_model_weight_key("model.language_model.norm.weight")
== "model.norm.weight"
)
assert (
normalize_language_model_weight_key("language_model.model.norm.weight")
== "model.norm.weight"
)
assert (
normalize_language_model_weight_key("model.language_model.lm_head.weight")
== "lm_head.weight"
Expand Down Expand Up @@ -111,6 +124,79 @@ def test_mlx_lm_sanitize_uses_local_layer_keys_for_shards():
) == ["layers.0.linear_attn.conv1d.weight"]


def test_qwen35_moe_uses_qwen35_text_args_and_sanitizer_module():
    """A ``qwen3_5_moe`` config resolves to the ``qwen3_5`` mlx_lm module.

    The nested ``text_config`` values must be visible on the resulting model
    args after the config has been normalized.
    """
    loader = MLXModelLoader("test_model_path")

    text_config = {
        "model_type": "qwen3_5_moe_text",
        "hidden_size": 2048,
        "num_hidden_layers": 40,
        "num_attention_heads": 16,
        "num_key_value_heads": 2,
        "vocab_size": 248320,
        "num_experts": 256,
        "num_experts_per_tok": 8,
        "moe_intermediate_size": 512,
    }
    raw_config = {
        "model_type": "qwen3_5_moe",
        "architectures": ["Qwen3_5MoeForConditionalGeneration"],
        "text_config": text_config,
    }
    config = normalize_model_config(raw_config)

    sanitizer_module, model_args = loader._load_mlx_lm_module_and_args("qwen3_5_moe", config)

    assert MODEL_CLASS_MAP["qwen3_5_moe"] == "mlx_lm.models.qwen3_5"
    assert sanitizer_module.__name__ == "mlx_lm.models.qwen3_5"

    # Each of these must match the value supplied in the nested text config.
    expected_args = {
        "num_hidden_layers": 40,
        "hidden_size": 2048,
        "num_experts": 256,
        "num_experts_per_tok": 8,
        "moe_intermediate_size": 512,
    }
    for attr_name, expected_value in expected_args.items():
        assert getattr(model_args, attr_name) == expected_value


@pytest.mark.skipif(sys.platform != "darwin", reason="MLX tests require macOS")
def test_register_block_class_includes_qwen35_moe():
    """The loader's block class map has an entry for the Qwen3.5 MoE architecture."""
    registered_blocks = MLXModelLoader("test_model_path").block_class_map

    assert "Qwen3_5MoeForConditionalGeneration" in registered_blocks


def test_selective_download_uses_nested_qwen35_moe_num_layers(tmp_path):
    """Layer count nested under ``text_config`` is honoured for file selection.

    With 40 layers declared only inside ``text_config``, the shard ``[39, 40)``
    must select both the layer-39 file and the file holding the final norm and
    ``lm_head`` weights.
    """
    model_config = {
        "model_type": "qwen3_5_moe",
        "text_config": {
            "num_hidden_layers": 40,
            "tie_word_embeddings": False,
        },
    }
    weight_index = {
        "weight_map": {
            "language_model.model.layers.39.linear_attn.in_proj_qkv.weight": (
                "layers-39.safetensors"
            ),
            "language_model.model.norm.weight": "final.safetensors",
            "language_model.lm_head.weight": "final.safetensors",
        }
    }

    config_path = tmp_path / "config.json"
    config_path.write_text(json.dumps(model_config))
    index_path = tmp_path / "model.safetensors.index.json"
    index_path.write_text(json.dumps(weight_index))

    needed_files = _determine_needed_weight_files_for_download(
        tmp_path,
        start_layer=39,
        end_layer=40,
    )

    assert needed_files == ["final.safetensors", "layers-39.safetensors"]


@pytest.mark.skipif(sys.platform != "darwin", reason="MLX tests require macOS")
class TestMLXModelLoader:
"""Test MLXModelLoader functionality."""
Expand Down
Loading
Loading