diff --git a/docs/guides/gradient-checkpointing.md b/docs/guides/gradient-checkpointing.md index 58156b8453..1b418a6df9 100644 --- a/docs/guides/gradient-checkpointing.md +++ b/docs/guides/gradient-checkpointing.md @@ -25,13 +25,18 @@ distributed: ### Configure Programmatically ```python -from nemo_automodel.components.distributed.config import FSDP2Config -from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager - -config = FSDP2Config(activation_checkpointing=True) -# device_mesh is created elsewhere (e.g. by the recipe via setup_distributed) -manager = FSDP2Manager(config, device_mesh=device_mesh, moe_mesh=moe_mesh) -model = manager.parallelize(model) +from nemo_automodel import NeMoAutoModelForCausalLM +from nemo_automodel.components.distributed.config import DistributedSetup + +distributed_setup = DistributedSetup.build( + strategy="fsdp2", + activation_checkpointing=True, +) + +model = NeMoAutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-3.2-1B", + distributed_setup=distributed_setup, +) ``` ## Combine with Linear-Cut Cross-Entropy (LC-CE) @@ -82,4 +87,4 @@ automodel examples/llm_finetune/llama3_2/llama_3_2_1b_my_finetune.yaml If we run with the above settings (activation ckpt = on, lc-ce = on, fsdp = on), look for a log line similar to: ``` ... | mem 7.30 GiB | ... -``` \ No newline at end of file +``` diff --git a/examples/convergence/tulu3/model-verification/extract_nemo_activations.py b/examples/convergence/tulu3/model-verification/extract_nemo_activations.py index 53050832b1..889a435768 100644 --- a/examples/convergence/tulu3/model-verification/extract_nemo_activations.py +++ b/examples/convergence/tulu3/model-verification/extract_nemo_activations.py @@ -15,8 +15,8 @@ """Extract per-layer activations from NeMo AutoModel using the training code path. -Uses the same distributed setup as training (EP, FSDP, backend config) and -registers forward hooks on decoder layers to capture hidden states. Saves +Uses the same distributed environment and mesh context as training (EP, FSDP, +backend config) and registers forward hooks on decoder layers to capture hidden states. Saves activations and final logits for comparison against HF Transformers. Run via torchrun: @@ -84,10 +84,10 @@ def main(): sys.argv = ["extract_nemo_activations.py", "--config", args.config] + extra cfg = parse_args_and_load_config() - # --- Distributed setup --- + # --- Distributed environment and mesh context --- from nemo_automodel._transformers.utils import apply_cache_compatibility_patches from nemo_automodel.components.loggers.log_utils import setup_logging - from nemo_automodel.recipes._dist_setup import setup_distributed + from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config from nemo_automodel.recipes.llm.train_ft import build_distributed, build_model from nemo_automodel.shared.te_patches import apply_te_patches @@ -95,9 +95,10 @@ def main(): setup_logging() apply_cache_compatibility_patches() apply_te_patches() - dist_setup = setup_distributed(cfg, world_size=dist_env.world_size) + distributed_setup = create_distributed_setup_from_config(cfg, world_size=dist_env.world_size) + mesh_context = distributed_setup.mesh_context - if dist_setup.cp_size > 1 and cfg.get("model.backend.rope_fusion", False): + if mesh_context.cp_size > 1 and cfg.get("model.backend.rope_fusion", False): cfg.model.backend.rope_fusion = False # --- Build model --- @@ -105,12 +106,7 @@ def main(): cfg.model, cfg_peft=None, seed=cfg.get("seed", 42), - device_mesh=dist_setup.device_mesh, - moe_mesh=dist_setup.moe_mesh, - distributed_config=dist_setup.strategy_config, - pipeline_config=dist_setup.pipeline_config, - cfg_moe=dist_setup.moe_config, - activation_checkpointing=dist_setup.activation_checkpointing, + distributed_setup=distributed_setup, ) model.eval() diff --git a/nemo_automodel/_diffusers/auto_diffusion_pipeline.py b/nemo_automodel/_diffusers/auto_diffusion_pipeline.py index 3ad3a0d495..cb7b46eda4 100644 --- a/nemo_automodel/_diffusers/auto_diffusion_pipeline.py +++ b/nemo_automodel/_diffusers/auto_diffusion_pipeline.py @@ -48,11 +48,10 @@ import torch import torch.nn as nn -from nemo_automodel.components.distributed import parallelizer +from nemo_automodel.components.distributed import DistributedSetup, ParallelismSizes, parallelizer from nemo_automodel.components.distributed.config import DDPConfig, FSDP2Config from nemo_automodel.components.distributed.ddp import DDPManager from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager -from nemo_automodel.components.distributed.mesh_utils import create_device_mesh from nemo_automodel.components.distributed.parallelizer import ( HunyuanParallelizationStrategy, WanParallelizationStrategy, @@ -208,15 +207,13 @@ def _create_parallel_manager(manager_args: Dict[str, Any]) -> ParallelManager: """ Factory function to create the appropriate parallel manager based on config. - Constructs the proper config objects (FSDP2Config / DDPConfig) and, for FSDP2, - creates the required device mesh before instantiating the manager. This mirrors + Builds a ``DistributedSetup`` via ``DistributedSetup.build(...)``, then instantiates the + requested manager from the setup's strategy config and meshes. This mirrors the pattern used by ``_instantiate_distributed`` in the transformers infrastructure. The manager type is determined by the ``_manager_type`` key in *manager_args*: - - ``'ddp'``: Creates :class:`DDPConfig` + :class:`DDPManager` - - ``'fsdp2'`` (default): Creates :class:`FSDP2Config`, builds a - :class:`DeviceMesh` via :func:`create_device_mesh`, then creates - :class:`FSDP2Manager` + - ``'ddp'``: Creates a DDP ``DistributedSetup`` + ``DDPManager`` + - ``'fsdp2'`` (default): Creates an FSDP2 ``DistributedSetup`` + ``FSDP2Manager`` Args: manager_args: Flat dictionary of arguments. Recognised keys: @@ -224,7 +221,6 @@ def _create_parallel_manager(manager_args: Dict[str, Any]) -> ParallelManager: Common: ``_manager_type`` (str): ``'fsdp2'`` or ``'ddp'``. ``activation_checkpointing`` (bool): Enable activation checkpointing. - ``backend`` (str): Distributed backend (default ``'nccl'``). FSDP2-specific (mesh creation): ``world_size`` (int): Total number of processes. @@ -244,46 +240,62 @@ def _create_parallel_manager(manager_args: Dict[str, Any]) -> ParallelManager: """ args = manager_args.copy() manager_type = args.pop("_manager_type", "fsdp2").lower() + if "backend" in args: + raise ValueError( + "backend is not a parallel manager option; configure the process group before parallelization." + ) + parallelism = ParallelismSizes( + dp_size=args.get("dp_size"), + dp_replicate_size=args.get("dp_replicate_size"), + tp_size=args.get("tp_size", 1), + pp_size=args.get("pp_size", 1), + cp_size=args.get("cp_size", 1), + ep_size=args.get("ep_size", 1), + ) if manager_type == "ddp": - config = DDPConfig( + distributed_setup = DistributedSetup.build( + strategy=DDPConfig( + activation_checkpointing=args.get("activation_checkpointing", False), + ), + parallelism_sizes=parallelism, activation_checkpointing=args.get("activation_checkpointing", False), - backend=args.get("backend", "nccl"), + world_size=args.get("world_size"), ) - logger.info("[Parallel] Creating DDPManager with config: %s", config) - return DDPManager(config) + logger.info("[Parallel] Creating DDPManager with config: %s", distributed_setup.strategy_config) + return DDPManager(distributed_setup.strategy_config) elif manager_type == "fsdp2": - config = FSDP2Config( + world_size = args.get("world_size") + if world_size is None: + world_size = torch.distributed.get_world_size() + + distributed_setup = DistributedSetup.build( + strategy=FSDP2Config( + mp_policy=args["mp_policy"] if "mp_policy" in args else None, + sequence_parallel=args.get("sequence_parallel", False), + tp_plan=args.get("tp_plan", None), + patch_is_packed_sequence=args.get("patch_is_packed_sequence", False), + offload_policy=args.get("offload_policy", None), + defer_fsdp_grad_sync=args.get("defer_fsdp_grad_sync", True), + enable_async_tensor_parallel=args.get("enable_async_tensor_parallel", False), + enable_compile=args.get("enable_compile", False), + enable_fsdp2_prefetch=args.get("enable_fsdp2_prefetch", False), + fsdp2_backward_prefetch_depth=args.get("fsdp2_backward_prefetch_depth", 2), + fsdp2_forward_prefetch_depth=args.get("fsdp2_forward_prefetch_depth", 1), + ), + parallelism_sizes=parallelism, activation_checkpointing=args.get("activation_checkpointing", False), - mp_policy=args.get("mp_policy", None), - backend=args.get("backend", "nccl"), - sequence_parallel=args.get("sequence_parallel", False), - tp_plan=args.get("tp_plan", None), - patch_is_packed_sequence=args.get("patch_is_packed_sequence", False), - offload_policy=args.get("offload_policy", None), - defer_fsdp_grad_sync=args.get("defer_fsdp_grad_sync", True), - enable_async_tensor_parallel=args.get("enable_async_tensor_parallel", False), - enable_compile=args.get("enable_compile", False), - enable_fsdp2_prefetch=args.get("enable_fsdp2_prefetch", False), - fsdp2_backward_prefetch_depth=args.get("fsdp2_backward_prefetch_depth", 2), - fsdp2_forward_prefetch_depth=args.get("fsdp2_forward_prefetch_depth", 1), - ) - - world_size = args.get("world_size") or torch.distributed.get_world_size() - device_mesh, moe_mesh = create_device_mesh( - config, - dp_size=args.get("dp_size"), - dp_replicate_size=args.get("dp_replicate_size"), - tp_size=args.get("tp_size", 1), - pp_size=args.get("pp_size", 1), - cp_size=args.get("cp_size", 1), - ep_size=args.get("ep_size", 1), world_size=world_size, ) - logger.info("[Parallel] Creating FSDP2Manager with config: %s", config) - return FSDP2Manager(config, device_mesh=device_mesh, moe_mesh=moe_mesh) + mesh_context = distributed_setup.mesh_context + logger.info("[Parallel] Creating FSDP2Manager with config: %s", distributed_setup.strategy_config) + return FSDP2Manager( + distributed_setup.strategy_config, + device_mesh=mesh_context.device_mesh, + moe_mesh=mesh_context.moe_mesh, + ) else: raise ValueError(f"Unknown manager type: '{manager_type}'. Expected 'ddp' or 'fsdp2'.") diff --git a/nemo_automodel/_transformers/auto_model.py b/nemo_automodel/_transformers/auto_model.py index f08f6a2145..4aaf0c8e01 100644 --- a/nemo_automodel/_transformers/auto_model.py +++ b/nemo_automodel/_transformers/auto_model.py @@ -51,15 +51,11 @@ from transformers.models.auto.auto_factory import _BaseAutoModelClass # noqa: E402 from transformers.utils import ContextManagers # noqa: E402 -from nemo_automodel.components.distributed.config import ( # noqa: E402 - DistributedConfig, -) +from nemo_automodel.components.distributed.config import DistributedSetup # noqa: E402 from nemo_automodel.components.distributed.ddp import DDPManager # noqa: E402 from nemo_automodel.components.distributed.init_utils import get_world_size_safe # noqa: E402 from nemo_automodel.components.distributed.megatron_fsdp import MegatronFSDPManager # noqa: E402 from nemo_automodel.components.distributed.pipelining.autopipeline import AutoPipeline # noqa: E402, F401 -from nemo_automodel.components.distributed.pipelining.config import PipelineConfig # noqa: E402 -from nemo_automodel.components.moe.config import MoEParallelizerConfig # noqa: E402 from nemo_automodel.components.quantization.qat import QATConfig # noqa: E402 from nemo_automodel.components.utils.model_utils import ( # noqa: E402 init_empty_weights, @@ -112,6 +108,40 @@ _MAX_BUILD_RETRIES = 5 _remote_code_compat_applied = False +_DISTRIBUTED_SETUP_ONLY_KWARGS = { + "moe_mesh", + "distributed_config", + "pipeline_config", + "moe_config", + "activation_checkpointing", + "tp_plan", +} + + +def _reject_separate_distributed_kwargs(kwargs: dict) -> None: + provided = sorted(_DISTRIBUTED_SETUP_ONLY_KWARGS & set(kwargs)) + if provided: + raise TypeError( + "Distributed settings must be passed with distributed_setup; " + f"separate distributed kwargs are not accepted: {provided}" + ) + + +def _resolve_distributed_setup( + *, + distributed_setup: Optional[DistributedSetup], + device_mesh: Optional["DeviceMesh"] = None, +) -> DistributedSetup: + """Return a setup, upcasting raw mesh inputs into topology-only setup.""" + if distributed_setup is not None: + if device_mesh is not None: + raise ValueError("Pass either distributed_setup or device_mesh, not both") + return distributed_setup + + if isinstance(device_mesh, MeshContext): + raise TypeError("device_mesh expects a DeviceMesh; pass DistributedSetup for MeshContext or MoE topology") + + return DistributedSetup(mesh_context=MeshContext.from_meshes(device_mesh)) def _patch_remote_code_compat(): @@ -572,14 +602,9 @@ def from_pretrained( attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION, quantization_config=None, force_hf: bool = False, + distributed_setup: Optional[DistributedSetup] = None, device_mesh: Optional["DeviceMesh"] = None, - moe_mesh: Optional["DeviceMesh"] = None, - tp_plan: Optional[dict] = None, - distributed_config: Optional[DistributedConfig] = None, - pipeline_config: Optional[PipelineConfig] = None, qat_config: Optional[QATConfig] = None, - moe_config: Optional[MoEParallelizerConfig] = None, - activation_checkpointing: bool = False, peft_config: Optional[dict] = None, fp8_config: Optional["FP8Config"] = None, compile_config: Optional["CompileConfig"] = None, @@ -621,23 +646,15 @@ def from_pretrained( will be applied to the model. force_hf (bool, default=False): If `True`, force the use of HF model implementation. If `False`, the model will be loaded using the custom model implementation if available. - device_mesh (DeviceMesh | None, optional): Pre-created device mesh for - distributed training. Parallelism sizes (tp, pp, cp, ep) are inferred - from this. Default: None. - moe_mesh (DeviceMesh | None, optional): FSDP2-only. Device mesh for expert - parallelism. ep_size is inferred from this. Default: None. - tp_plan (dict | None, optional): Custom tensor parallel plan. If provided, - overrides the tp_plan on distributed_config. Default: None. - distributed_config (FSDP2Config | MegatronFSDPConfig | DDPConfig | None, optional): - Strategy-specific distributed training configuration. Default: None. - pipeline_config (PipelineConfig | None, optional): Pipeline parallelism - configuration including loss_fn. Default: None. + distributed_setup (DistributedSetup | None, optional): Resolved distributed + topology and policy object. Default: None. + device_mesh (DeviceMesh | None, optional): Pre-created Hugging Face-style + device mesh. NeMo wraps it in a topology-only ``DistributedSetup`` + internally. Use ``distributed_setup`` when passing NeMo-specific + policies such as strategy, pipeline, MoE, or activation checkpointing. + Default: None. qat_config (QATConfig | None, optional): Quantization-Aware Training configuration. Default: None. - moe_config (MoEParallelizerConfig | None, optional): MoE parallelizer - configuration. Default: None. - activation_checkpointing (bool, default=False): Enable activation checkpointing - for transformer blocks to reduce memory usage. Default: False. peft_config (dict | None, optional): PEFT/LoRA configuration dictionary. If provided, LoRA adapters will be applied to the model. Default: None. fp8_config (FP8Config | None, optional): FP8 quantization configuration. @@ -652,16 +669,22 @@ def from_pretrained( transformers.PreTrainedModel: The loaded (and possibly patched) model instance with all infrastructure applied. """ - if tp_plan is not None and distributed_config is not None: - distributed_config.tp_plan = tp_plan - - mesh = MeshContext.from_meshes(device_mesh, moe_mesh) + _reject_separate_distributed_kwargs(kwargs) + setup = _resolve_distributed_setup( + distributed_setup=distributed_setup, + device_mesh=device_mesh, + ) + mesh = setup.mesh_context + distributed_config = setup.strategy_config + pipeline_config = setup.pipeline_config + moe_parallel_config = setup.moe_parallel_config + activation_checkpointing = setup.activation_checkpointing model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure( distributed_config=distributed_config, pipeline_config=pipeline_config, qat_config=qat_config, - moe_config=moe_config, + moe_parallel_config=moe_parallel_config, activation_checkpointing=activation_checkpointing, device=torch.device("cuda", torch.cuda.current_device()), mesh=mesh, @@ -679,7 +702,7 @@ def from_pretrained( raise is_hf_model = get_is_hf_model(hf_config, force_hf) - sdpa_method = resolve_sdpa_method(sdpa_method, device_mesh, activation_checkpointing) + sdpa_method = resolve_sdpa_method(sdpa_method, mesh.device_mesh, activation_checkpointing) return cls._build_model( pretrained_model_name_or_path, @@ -717,14 +740,9 @@ def from_config( attn_implementation: str = DEFAULT_ATTN_IMPLEMENTATION, quantization_config=None, force_hf: bool = False, + distributed_setup: Optional[DistributedSetup] = None, device_mesh: Optional["DeviceMesh"] = None, - moe_mesh: Optional["DeviceMesh"] = None, - tp_plan: Optional[dict] = None, - distributed_config: Optional[DistributedConfig] = None, - pipeline_config: Optional[PipelineConfig] = None, qat_config: Optional[QATConfig] = None, - moe_config: Optional[MoEParallelizerConfig] = None, - activation_checkpointing: bool = False, peft_config: Optional[dict] = None, fp8_config: Optional["FP8Config"] = None, compile_config: Optional["CompileConfig"] = None, @@ -744,10 +762,16 @@ def from_config( torch_dtype (str | torch.dtype, default="auto"): Data type for model parameters. If "auto", defaults to ``torch.bfloat16``. """ - if tp_plan is not None and distributed_config is not None: - distributed_config.tp_plan = tp_plan - - mesh = MeshContext.from_meshes(device_mesh, moe_mesh) + _reject_separate_distributed_kwargs(kwargs) + setup = _resolve_distributed_setup( + distributed_setup=distributed_setup, + device_mesh=device_mesh, + ) + mesh = setup.mesh_context + distributed_config = setup.strategy_config + pipeline_config = setup.pipeline_config + moe_parallel_config = setup.moe_parallel_config + activation_checkpointing = setup.activation_checkpointing # Only instantiate infrastructure when distributed_config is provided model_wrapper = autopipeline = parallelize_fn = qat_quantizer = None @@ -757,7 +781,7 @@ def from_config( distributed_config=distributed_config, pipeline_config=pipeline_config, qat_config=qat_config, - moe_config=moe_config, + moe_parallel_config=moe_parallel_config, activation_checkpointing=activation_checkpointing, device=torch.device("cuda", torch.cuda.current_device()), mesh=mesh, @@ -783,7 +807,7 @@ def from_config( _consume_config_overrides(config, kwargs) is_hf_model = get_is_hf_model(config, force_hf) - sdpa_method = resolve_sdpa_method(sdpa_method, device_mesh, activation_checkpointing) + sdpa_method = resolve_sdpa_method(sdpa_method, mesh.device_mesh, activation_checkpointing) return cls._build_model( config, @@ -823,8 +847,6 @@ class NeMoAutoModelForCausalLM(_BaseNeMoAutoModelClass, AutoModelForCausalLM): functional model. - TODO(@akoumpa): extend this beyond liger_kernel. - Notes: ----- - No changes are made to the model's public API; forward signatures, @@ -853,8 +875,6 @@ class NeMoAutoModelForImageTextToText(_BaseNeMoAutoModelClass, AutoModelForImage functional model. - @akoumpa: currently only supporting liger_kernel for demonstration purposes. - Notes: ----- - No changes are made to the model's public API; forward signatures, @@ -983,11 +1003,8 @@ def from_pretrained( use_sdpa_patching: bool = True, sdpa_method: Optional[List[SDPBackend]] = None, torch_dtype="auto", + distributed_setup: Optional[DistributedSetup] = None, device_mesh: Optional["DeviceMesh"] = None, - moe_mesh: Optional["DeviceMesh"] = None, - tp_plan: Optional[dict] = None, - distributed_config: Optional[DistributedConfig] = None, - moe_config: Optional[MoEParallelizerConfig] = None, compile_config: Optional["CompileConfig"] = None, peft_config: Optional[dict] = None, **kwargs, @@ -1008,11 +1025,9 @@ def from_pretrained( use_sdpa_patching: Whether to apply SDPA patching. sdpa_method: SDPA backend methods to use. torch_dtype: Data type passed to the underlying model initialization. - device_mesh: Pre-created device mesh for distributed training. - moe_mesh: Device mesh for expert parallelism (FSDP2 only). - tp_plan: Custom tensor parallel plan; overrides distributed_config.tp_plan. - distributed_config: Strategy-specific distributed training configuration. - moe_config: MoE parallelizer configuration. + distributed_setup: Resolved distributed topology and policy object. + device_mesh: Pre-created Hugging Face-style device mesh. NeMo wraps it + in a topology-only ``DistributedSetup`` internally. compile_config: Configuration for torch.compile. peft_config: PEFT/LoRA configuration dictionary. **kwargs: Additional arguments passed to the encoder's ``build()`` method. @@ -1023,6 +1038,7 @@ def from_pretrained( Notes: If kernel patching fails, the method retries with adjusted parameters. """ + _reject_separate_distributed_kwargs(kwargs) from nemo_automodel._transformers import retrieval as _enc_mod encoder_cls = getattr(_enc_mod, cls._ENCODER_CLS_NAME) @@ -1037,11 +1053,8 @@ def _retry(**override): use_sdpa_patching=override.get("use_sdpa_patching", use_sdpa_patching), sdpa_method=sdpa_method, torch_dtype=torch_dtype, + distributed_setup=distributed_setup, device_mesh=device_mesh, - moe_mesh=moe_mesh, - tp_plan=tp_plan, - distributed_config=distributed_config, - moe_config=moe_config, compile_config=compile_config, peft_config=peft_config, **kwargs, @@ -1052,16 +1065,21 @@ def _retry(**override): build_kwargs.pop("cp_size", None) build_kwargs.pop("has_packed_sequence", None) - if tp_plan is not None and distributed_config is not None: - distributed_config.tp_plan = tp_plan - - mesh = MeshContext.from_meshes(device_mesh, moe_mesh) + setup = _resolve_distributed_setup( + distributed_setup=distributed_setup, + device_mesh=device_mesh, + ) + mesh = setup.mesh_context + distributed_config = setup.strategy_config + moe_parallel_config = setup.moe_parallel_config + activation_checkpointing = setup.activation_checkpointing model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure( distributed_config=distributed_config, pipeline_config=None, qat_config=None, - moe_config=moe_config, + moe_parallel_config=moe_parallel_config, + activation_checkpointing=activation_checkpointing, device=torch.device("cuda", torch.cuda.current_device()), mesh=mesh, ) @@ -1130,7 +1148,7 @@ class NeMoAutoModelBiEncoder(_NeMoAutoModelForRetrievalBase): ... "meta-llama/Llama-3.2-1B", ... pooling="cls", ... l2_normalize=False, - ... distributed_config=FSDP2Config(), + ... distributed_setup=distributed_setup, ... ) """ @@ -1177,7 +1195,7 @@ class NeMoAutoModelCrossEncoder(_NeMoAutoModelForRetrievalBase): >>> model = NeMoAutoModelCrossEncoder.from_pretrained("meta-llama/Llama-3.2-1B") >>> model = NeMoAutoModelCrossEncoder.from_pretrained( ... "meta-llama/Llama-3.2-1B", - ... distributed_config=FSDP2Config(), + ... distributed_setup=distributed_setup, ... ) """ diff --git a/nemo_automodel/_transformers/infrastructure.py b/nemo_automodel/_transformers/infrastructure.py index 496d1fe1f3..9ba6c43075 100644 --- a/nemo_automodel/_transformers/infrastructure.py +++ b/nemo_automodel/_transformers/infrastructure.py @@ -25,6 +25,7 @@ import logging from contextlib import nullcontext +from dataclasses import is_dataclass, replace from functools import partial from typing import TYPE_CHECKING, Optional, Union @@ -40,9 +41,10 @@ ) from nemo_automodel.components.distributed.config import ( DDPConfig, - DistributedConfig, + DistributedStrategyConfig, FSDP2Config, MegatronFSDPConfig, + MoEParallelizerConfig, ) from nemo_automodel.components.distributed.ddp import DDPManager from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager @@ -52,7 +54,6 @@ from nemo_automodel.components.distributed.pipelining.autopipeline import AutoPipeline from nemo_automodel.components.distributed.pipelining.config import PipelineConfig from nemo_automodel.components.loss.masked_ce import MaskedCrossEntropy -from nemo_automodel.components.moe.config import MoEParallelizerConfig from nemo_automodel.components.quantization.fp8 import apply_fp8_to_model from nemo_automodel.components.quantization.qat import QATConfig from nemo_automodel.components.utils.compile_utils import compile_model @@ -67,7 +68,6 @@ ) if TYPE_CHECKING: - from torch.distributed.device_mesh import DeviceMesh from torchao.quantization.qat.linear import Int4WeightOnlyQATQuantizer, Int8DynActInt4WeightQATQuantizer logger = logging.getLogger(__name__) @@ -149,7 +149,7 @@ def _shard_ep_fsdp(model, model_wrapper, parallelize_fn, mesh: MeshContext): # Infrastructure instantiation (config -> runtime objects) def _instantiate_distributed( - config: DistributedConfig, + config: DistributedStrategyConfig | None, mesh: MeshContext, ) -> Union[FSDP2Manager, MegatronFSDPManager, DDPManager, None]: """Instantiate the appropriate distributed manager from config. @@ -181,6 +181,20 @@ def _instantiate_distributed( raise ValueError(f"Unknown distributed config type: {type(config)}") +def _with_activation_checkpointing( + config: Optional[DistributedStrategyConfig], + activation_checkpointing: bool, +) -> Optional[DistributedStrategyConfig]: + """Return a strategy config whose AC flag matches the resolved setup.""" + if config is None or not hasattr(config, "activation_checkpointing"): + return config + if getattr(config, "activation_checkpointing") is activation_checkpointing: + return config + if not is_dataclass(config): + return config + return replace(config, activation_checkpointing=activation_checkpointing) + + def _instantiate_pipeline( config: Optional[PipelineConfig], mesh: MeshContext, @@ -193,8 +207,8 @@ def _instantiate_pipeline( config: Pipeline config. If None or pp_size <= 1, returns None. mesh: MeshContext holding device_mesh, moe_mesh, and axis names. device: Target device for pipeline computation. - strategy_config: Strategy config fallback when ``mesh`` was rebuilt from - raw device meshes and no longer carries the recipe-level config. + strategy_config: Strategy config used to route distributed policy into + pipeline setup. Returns: AutoPipeline instance, or None if pipeline parallelism is not enabled. @@ -207,7 +221,6 @@ def _instantiate_pipeline( # Route the existing FSDP2Config.defer_fsdp_grad_sync into the pipeline so # the same knob controls grad-sync behavior under PP. - strategy_config = getattr(mesh, "strategy_config", None) or strategy_config if strategy_config is not None and hasattr(strategy_config, "defer_fsdp_grad_sync"): config_dict.setdefault("defer_fsdp_grad_sync", strategy_config.defer_fsdp_grad_sync) @@ -256,17 +269,13 @@ def parallelize_for_pp( def instantiate_infrastructure( *, - distributed_config: Optional[DistributedConfig] = None, + distributed_config: Optional[DistributedStrategyConfig] = None, pipeline_config: Optional[PipelineConfig] = None, qat_config: Optional[QATConfig] = None, - moe_config: Optional[MoEParallelizerConfig] = None, - activation_checkpointing: bool = False, + moe_parallel_config: Optional[MoEParallelizerConfig] = None, + activation_checkpointing: Optional[bool] = None, device: Optional[torch.device] = None, mesh: Optional[MeshContext] = None, - # Deprecated -- prefer passing ``mesh`` directly - device_mesh: Optional["DeviceMesh"] = None, - moe_mesh: Optional["DeviceMesh"] = None, - ep_size: int = 1, ) -> tuple: """Instantiate infrastructure objects from config classes. @@ -279,15 +288,11 @@ def instantiate_infrastructure( or DDPConfig). pipeline_config: Pipeline parallelism config. qat_config: Quantization-aware training config. - moe_config: MoE parallelizer config (for expert parallel models). + moe_parallel_config: MoE parallelizer config (for expert parallel models). activation_checkpointing: Enable activation checkpointing for transformer blocks. - Defaults to False. + If ``None``, inferred from ``distributed_config.activation_checkpointing``. device: Target device for model. mesh: MeshContext holding device meshes, sizes, and axis names. - If None, built from the legacy ``device_mesh`` / ``moe_mesh`` params. - device_mesh: (deprecated) Device mesh for distributed operations. - moe_mesh: (deprecated) Optional MOE mesh for expert parallelism. - ep_size: (deprecated) Expert parallelism size. Ignored when ``mesh`` is provided. Returns: tuple: (model_wrapper, autopipeline, parallelize_fn, qat_quantizer) @@ -298,19 +303,25 @@ def instantiate_infrastructure( - qat_quantizer: QAT quantizer instance (or None) """ if mesh is None: - mesh = MeshContext.from_meshes(device_mesh, moe_mesh) + mesh = MeshContext() - ep_size = mesh.ep_size if mesh.ep_size > 1 else ep_size + if activation_checkpointing is None: + activation_checkpointing = bool(getattr(distributed_config, "activation_checkpointing", False)) + distributed_config = _with_activation_checkpointing(distributed_config, activation_checkpointing) model_wrapper = _instantiate_distributed(distributed_config, mesh) autopipeline = _instantiate_pipeline(pipeline_config, mesh, device, distributed_config) parallelize_fn = None - if ep_size > 1: + if mesh.ep_size > 1: from nemo_automodel.components.moe.parallelizer import parallelize_model + if moe_parallel_config is None: + moe_parallel_config = MoEParallelizerConfig() parallelize_fn = partial( - parallelize_model, activation_checkpointing=activation_checkpointing, **moe_config.to_dict() + parallelize_model, + activation_checkpointing=activation_checkpointing, + **moe_parallel_config.to_dict(), ) elif autopipeline is not None and model_wrapper is not None: parallelize_fn = partial(parallelize_for_pp, model_wrapper=model_wrapper) diff --git a/nemo_automodel/components/distributed/__init__.py b/nemo_automodel/components/distributed/__init__.py index d368b8c4a1..5ad6ee2871 100644 --- a/nemo_automodel/components/distributed/__init__.py +++ b/nemo_automodel/components/distributed/__init__.py @@ -12,7 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nemo_automodel.components.distributed.config import DDPConfig, FSDP2Config, MegatronFSDPConfig +from nemo_automodel.components.distributed.config import ( + DDPConfig, + DistributedSetup, + FSDP2Config, + MegatronFSDPConfig, + MoEParallelizerConfig, +) +from nemo_automodel.components.distributed.init_utils import DistInfo, initialize_distributed +from nemo_automodel.components.distributed.mesh import MeshContext, ParallelismSizes from nemo_automodel.components.distributed.pipelining.config import PipelineConfig -__all__ = ["FSDP2Config", "MegatronFSDPConfig", "DDPConfig", "PipelineConfig"] +__all__ = [ + "DDPConfig", + "DistributedSetup", + "DistInfo", + "FSDP2Config", + "MegatronFSDPConfig", + "MeshContext", + "MoEParallelizerConfig", + "ParallelismSizes", + "PipelineConfig", + "initialize_distributed", +] diff --git a/nemo_automodel/components/distributed/config.py b/nemo_automodel/components/distributed/config.py index eb3af0423d..d97989b656 100644 --- a/nemo_automodel/components/distributed/config.py +++ b/nemo_automodel/components/distributed/config.py @@ -16,8 +16,8 @@ Strategy-specific distributed training configuration classes. Design principle: -- Size params (dp_size, dp_replicate_size, tp_size, pp_size, cp_size, ep_size) go directly - on the from_pretrained/from_config method signature +- Size params (dp_size, dp_replicate_size, tp_size, pp_size, cp_size, ep_size) + are grouped in ``ParallelismSizes``. - dp_replicate_size is FSDP2-only: raises assertion if passed with non-FSDP2 config - Strategy-specific configs contain only *additional* flags unique to each strategy - Managers become normal classes that accept (config, device_mesh) @@ -35,14 +35,101 @@ config = DDPConfig(activation_checkpointing=True) """ -from dataclasses import InitVar, dataclass, fields -from typing import Any, Dict, List, Optional, Union +from __future__ import annotations + +from dataclasses import dataclass, field, fields +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union import torch from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy -# Type alias for API signature -DistributedConfig = Union["FSDP2Config", "MegatronFSDPConfig", "DDPConfig"] +if TYPE_CHECKING: + from nemo_automodel.components.distributed.mesh import MeshContext, ParallelismSizes + from nemo_automodel.components.distributed.pipelining.config import PipelineConfig + +# Type aliases for API signatures. +DistributedStrategyConfig = Union["FSDP2Config", "MegatronFSDPConfig", "DDPConfig"] + + +@dataclass(frozen=True) +class DistributedSetup: + """Resolved distributed topology and execution policies.""" + + mesh_context: "MeshContext" + strategy_config: DistributedStrategyConfig | None = None + pipeline_config: "PipelineConfig | None" = None + moe_parallel_config: "MoEParallelizerConfig | None" = None + activation_checkpointing: bool = False + + @classmethod + def build( + cls, + strategy: str | DistributedStrategyConfig = "fsdp2", + parallelism_sizes: "ParallelismSizes | None" = None, + pipeline_config: "PipelineConfig | dict | None" = None, + moe_parallel_config: "MoEParallelizerConfig | dict | None" = None, + activation_checkpointing: bool = False, + world_size: int | None = None, + ) -> "DistributedSetup": + """Create a resolved distributed setup from sizes and policy configs. + + Intentionally, this function is forgiving wrt the input types, allowing + strings for the strategy and dicts for the pipeline and MoE configs. + """ + from nemo_automodel.components.distributed.init_utils import get_world_size_safe + from nemo_automodel.components.distributed.mesh import MeshContext, ParallelismSizes + from nemo_automodel.components.distributed.pipelining.config import PipelineConfig + + if world_size is None: + world_size = get_world_size_safe() + + strategy_config = _resolve_strategy_config(strategy) + + if parallelism_sizes is None: + parallelism_sizes = ParallelismSizes() + + pp_size = parallelism_sizes.pp_size + ep_size = parallelism_sizes.ep_size + if pipeline_config is not None and pp_size <= 1: + raise ValueError("pipeline_config requires pp_size > 1") + if moe_parallel_config is not None and ep_size <= 1: + raise ValueError("moe_parallel_config requires ep_size > 1") + if pp_size > 1 and pipeline_config is None: + pipeline_config = PipelineConfig() + if isinstance(pipeline_config, dict): + pipeline_config = PipelineConfig(**pipeline_config) + if ep_size > 1 and moe_parallel_config is None: + moe_parallel_config = MoEParallelizerConfig() + if isinstance(moe_parallel_config, dict): + moe_parallel_config = MoEParallelizerConfig(**moe_parallel_config) + + mesh_context = MeshContext.build( + strategy_config, + parallelism_sizes=parallelism_sizes, + world_size=world_size, + ) + + return cls( + mesh_context=mesh_context, + strategy_config=strategy_config, + pipeline_config=pipeline_config, + moe_parallel_config=moe_parallel_config, + activation_checkpointing=activation_checkpointing, + ) + + +@dataclass +class MoEParallelizerConfig: + """Configuration for MoE model parallelization (EP + FSDP settings).""" + + ignore_router_for_ac: bool = False + reshard_after_forward: bool = False + lm_head_precision: Optional[Union[str, torch.dtype]] = None + wrap_outer_model: bool = True + mp_policy: Optional[MixedPrecisionPolicy] = None + + def to_dict(self) -> Dict[str, Any]: + return {f.name: getattr(self, f.name) for f in fields(self)} @dataclass @@ -51,7 +138,7 @@ class FSDP2Config: Additional configuration for FSDP2 distributed training. Note: Size parameters (dp_size, dp_replicate_size, tp_size, pp_size, cp_size, ep_size) - are passed separately on the from_pretrained/from_config method signature. + are grouped separately in ``ParallelismSizes``. Attributes: sequence_parallel (bool): Enable sequence parallelism in TP plan. @@ -79,7 +166,6 @@ class FSDP2Config: Can be set from YAML as a string (e.g. ``autocast_dtype: bfloat16``). activation_checkpointing (bool): Enable activation checkpointing. defer_fsdp_grad_sync (bool): Defer FSDP gradient sync to final micro-batch. - backend (str): Distributed backend. enable_async_tensor_parallel (bool): Enable async tensor parallelism via ``torch._inductor.config._micro_pipeline_tp``. Overlaps ReduceScatter with compute in row-parallel layers. Requires ``sequence_parallel=True`` (forced @@ -101,27 +187,24 @@ class FSDP2Config: sequence_parallel: bool = False tp_plan: Optional[dict] = None patch_is_packed_sequence: bool = False - mp_policy: Optional[MixedPrecisionPolicy] = None + mp_policy: Optional[MixedPrecisionPolicy] = field( + default_factory=lambda: MixedPrecisionPolicy( + param_dtype=torch.bfloat16, + reduce_dtype=torch.float32, + output_dtype=torch.bfloat16, + cast_forward_inputs=True, + ) + ) offload_policy: Optional[CPUOffloadPolicy] = None autocast_dtype: Optional[torch.dtype] = None activation_checkpointing: bool = False defer_fsdp_grad_sync: bool = True - backend: str = "nccl" enable_async_tensor_parallel: bool = False enable_compile: bool = False enable_fsdp2_prefetch: bool = False fsdp2_backward_prefetch_depth: int = 2 fsdp2_forward_prefetch_depth: int = 1 - def __post_init__(self): - if self.mp_policy is None: - self.mp_policy = MixedPrecisionPolicy( - param_dtype=torch.bfloat16, - reduce_dtype=torch.float32, - output_dtype=torch.bfloat16, - cast_forward_inputs=True, - ) - def to_dict(self) -> Dict[str, Any]: """Convert config to dictionary (shallow, preserves policy objects).""" return {f.name: getattr(self, f.name) for f in fields(self)} @@ -132,13 +215,11 @@ class MegatronFSDPConfig: """ Additional configuration for MegatronFSDP distributed training. - Note: Size parameters (dp_size, tp_size, cp_size) are passed separately on - the from_pretrained/from_config method signature. MegatronFSDP does not + Note: Size parameters (dp_size, tp_size, cp_size) are grouped separately in + ``ParallelismSizes``. MegatronFSDP does not support pp_size, dp_replicate_size, or ep_size. Attributes: - sequence_parallel (bool): Enable sequence parallelism in TP plan. - Note: Not supported with MegatronFSDP right now. megatron_fsdp_unit_modules (Optional[List[str]]): List of unit modules to be wrapped with MegatronFSDP. zero_dp_strategy (int): Data parallel sharding strategy. @@ -156,12 +237,11 @@ class MegatronFSDPConfig: fsdp_double_buffer (bool): Use double buffer if True. activation_checkpointing (bool): Enable activation checkpointing for transformer MLP layers to save memory. - backend (str): Distributed backend, e.g. 'nccl' or 'gloo'. """ - sequence_parallel: bool = False - tp_plan: InitVar[Optional[dict]] = None - megatron_fsdp_unit_modules: Optional[List[str]] = None + megatron_fsdp_unit_modules: List[str] = field( + default_factory=lambda: ["transformers.models.llama.modeling_llama.LlamaDecoderLayer"] + ) zero_dp_strategy: int = 3 init_fsdp_with_meta_device: bool = False grad_reduce_in_fp32: bool = False @@ -176,13 +256,6 @@ class MegatronFSDPConfig: nccl_ub: bool = False fsdp_double_buffer: bool = False activation_checkpointing: bool = False - backend: str = "nccl" - - def __post_init__(self, tp_plan: Optional[dict]): - if tp_plan is not None: - raise ValueError("MegatronFSDPConfig does not support custom TP plans. Use FSDP2Config instead.") - if self.megatron_fsdp_unit_modules is None: - self.megatron_fsdp_unit_modules = ["transformers.models.llama.modeling_llama.LlamaDecoderLayer"] def to_dict(self) -> Dict[str, Any]: """Convert config to dictionary (shallow, preserves objects).""" @@ -199,12 +272,55 @@ class DDPConfig: Attributes: activation_checkpointing (bool): Enable activation checkpointing if True. - backend (str): Distributed backend, e.g. 'nccl' or 'gloo'. """ activation_checkpointing: bool = False - backend: str = "nccl" def to_dict(self) -> Dict[str, Any]: """Convert config to dictionary.""" return {f.name: getattr(self, f.name) for f in fields(self)} + + +_StrategyConfigClass = type[FSDP2Config] | type[MegatronFSDPConfig] | type[DDPConfig] +_STRATEGY_MAP: Dict[str, _StrategyConfigClass] = { + "fsdp2": FSDP2Config, + "megatron_fsdp": MegatronFSDPConfig, + "megatron-fsdp": MegatronFSDPConfig, + "mfsdp": MegatronFSDPConfig, + "ddp": DDPConfig, +} + + +def _resolve_strategy_config( + strategy: str | DistributedStrategyConfig, + **strategy_kwargs: Any, +) -> DistributedStrategyConfig: + """Resolve a setup-level strategy name or config object.""" + if isinstance(strategy, (FSDP2Config, MegatronFSDPConfig, DDPConfig)): + if strategy_kwargs: + raise ValueError("Strategy kwargs cannot be passed with an instantiated strategy config.") + return strategy + + if not isinstance(strategy, str): + raise ValueError(f"Unknown distributed strategy type: {type(strategy)}") + + strategy_name = strategy.lower() + if strategy_name not in _STRATEGY_MAP: + valid = sorted(_STRATEGY_MAP) + raise ValueError(f"Unknown strategy: {strategy}. Valid strategies: {valid}") + strategy_cls = _STRATEGY_MAP[strategy_name] + valid_fields = {f.name for f in fields(strategy_cls)} + unknown = set(strategy_kwargs) - valid_fields + if unknown: + raise ValueError(f"Unknown options for strategy '{strategy_name}': {sorted(unknown)}") + return strategy_cls(**strategy_kwargs) + + +__all__ = [ + "DDPConfig", + "DistributedSetup", + "DistributedStrategyConfig", + "FSDP2Config", + "MegatronFSDPConfig", + "MoEParallelizerConfig", +] diff --git a/nemo_automodel/components/distributed/ddp.py b/nemo_automodel/components/distributed/ddp.py index 91d064b6d3..8b8034e9a9 100644 --- a/nemo_automodel/components/distributed/ddp.py +++ b/nemo_automodel/components/distributed/ddp.py @@ -50,7 +50,6 @@ def __init__(self, config: DDPConfig): # Extract config fields for easy access self.activation_checkpointing = config.activation_checkpointing - self.backend = config.backend # Setup distributed environment self._setup_distributed() @@ -59,7 +58,7 @@ def _setup_distributed(self): """ Initialize device configuration for DDP. - Sets the rank, world_size, and device based on the backend. + Sets the rank, world_size, and device based on the process group backend. """ if not dist.is_available(): raise RuntimeError("torch.distributed not available") @@ -70,8 +69,8 @@ def _setup_distributed(self): self.rank = dist.get_rank() self.world_size = dist.get_world_size() - # Pin GPU if using NCCL - if self.backend == "nccl": + backend = str(dist.get_backend()).lower() + if "nccl" in backend and torch.cuda.is_available(): local_gpu = self.rank % torch.cuda.device_count() torch.cuda.set_device(local_gpu) self.device = torch.device("cuda", index=local_gpu) @@ -93,7 +92,9 @@ def parallelize(self, model): """ if dist.get_world_size() == 1: logger.info("World size is 1, skipping parallelization.") - model = model.to("cuda").to(torch.bfloat16) + model = model.to(self.device) + if self.device.type == "cuda": + model = model.to(torch.bfloat16) if self.activation_checkpointing: if hasattr(model, "gradient_checkpointing_enable"): model.gradient_checkpointing_enable() diff --git a/nemo_automodel/components/distributed/device_mesh.py b/nemo_automodel/components/distributed/device_mesh.py deleted file mode 100644 index 59d2513d6e..0000000000 --- a/nemo_automodel/components/distributed/device_mesh.py +++ /dev/null @@ -1,294 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Device mesh creation utilities for distributed training. - -This module provides a central function to create device meshes based on the -distributed config type (FSDP2, MegatronFSDP, or DDP). - -Usage: - from nemo_automodel.components.distributed.config import FSDP2Config - from nemo_automodel.components.distributed.device_mesh import create_device_mesh - - config = FSDP2Config(sequence_parallel=True) - device_mesh, moe_mesh = create_device_mesh( - config, - tp_size=2, - pp_size=1, - dp_replicate_size=2, - world_size=8, - ) -""" - -from typing import Optional, Tuple, Union - -from torch.distributed.device_mesh import DeviceMesh, init_device_mesh - -from nemo_automodel.components.distributed.config import ( - DDPConfig, - FSDP2Config, - MegatronFSDPConfig, -) -from nemo_automodel.components.distributed.mesh_utils import _unflatten_compat - - -def create_device_mesh( - distributed_config: Union[FSDP2Config, MegatronFSDPConfig, DDPConfig], - *, - dp_size: Optional[int] = None, - dp_replicate_size: Optional[int] = None, - tp_size: int = 1, - pp_size: int = 1, - cp_size: int = 1, - ep_size: int = 1, - world_size: int, -) -> Tuple[Optional[DeviceMesh], Optional[DeviceMesh]]: - """Create device mesh based on distributed config type. - - Routes to the appropriate mesh creation logic based on config type. - - Args: - distributed_config: The distributed config (FSDP2Config, MegatronFSDPConfig, - or DDPConfig). - dp_size: Data parallel size. If None, inferred from world_size and other - parallelism sizes. - dp_replicate_size: FSDP2-only. Size of the replication group for HSDP - (Hybrid Sharded Data Parallel). If None or <= 0, defaults to 1. - Must be a divisor of dp_size. - tp_size: Tensor parallel size. - pp_size: Pipeline parallel size. - cp_size: Context parallel size. - ep_size: Expert parallel size (for MoE models). - world_size: Total number of processes. - - Returns: - tuple: (device_mesh, moe_mesh) - - For FSDP2Config: Full device mesh + optional moe_mesh (if ep_size > 1) - - For MegatronFSDPConfig: Device mesh + None - - For DDPConfig: (None, None) - DDP doesn't use device mesh - - Raises: - ValueError: If dp_replicate_size is provided with non-FSDP2 config. - ValueError: If world_size is not divisible by parallelism sizes. - """ - # Validate FSDP2-only params - if dp_replicate_size is not None and dp_replicate_size > 1: - if not isinstance(distributed_config, FSDP2Config): - raise ValueError("dp_replicate_size is only supported with FSDP2Config") - - if isinstance(distributed_config, FSDP2Config): - return _create_fsdp2_device_mesh( - dp_size=dp_size, - dp_replicate_size=dp_replicate_size, - tp_size=tp_size, - pp_size=pp_size, - cp_size=cp_size, - ep_size=ep_size, - world_size=world_size, - backend=distributed_config.backend, - ) - elif isinstance(distributed_config, MegatronFSDPConfig): - mesh = _create_megatron_fsdp_device_mesh( - dp_size=dp_size, - tp_size=tp_size, - cp_size=cp_size, - world_size=world_size, - backend=distributed_config.backend, - ) - return mesh, None - elif isinstance(distributed_config, DDPConfig): - return None, None # DDP doesn't use device mesh - else: - raise ValueError(f"Unknown distributed config type: {type(distributed_config)}") - - -def _create_fsdp2_device_mesh( - dp_size: Optional[int], - dp_replicate_size: Optional[int], - tp_size: int, - pp_size: int, - cp_size: int, - ep_size: int, - world_size: int, - backend: str, -) -> Tuple[DeviceMesh, Optional[DeviceMesh]]: - """ - Create device mesh for FSDP2. - - Mesh shape: (pp_size, dp_replicate_size, dp_shard_size, cp_size, tp_size) - Mesh names: ("pp", "dp_replicate", "dp_shard", "cp", "tp") - - Also creates flattened submeshes: - - "dp": dp_replicate + dp_shard - - "dp_shard_cp": dp_shard + cp - - "dp_cp": dp_replicate + dp_shard + cp - - Args: - dp_size: Data parallel size. If None, inferred from world_size. - dp_replicate_size: Size of the replication group for HSDP. - tp_size: Tensor parallel size. - pp_size: Pipeline parallel size. - cp_size: Context parallel size. - ep_size: Expert parallel size (for MoE models). - world_size: Total number of processes. - backend: Distributed backend ('nccl' or 'gloo'). - - Returns: - tuple: (device_mesh, moe_mesh) - """ - # Normalize sizes - if tp_size is None or tp_size <= 0: - tp_size = 1 - if cp_size is None or cp_size <= 0: - cp_size = 1 - if pp_size is None or pp_size <= 0: - pp_size = 1 - if ep_size is None or ep_size <= 0: - ep_size = 1 - - # Infer dp_size if not provided - if dp_size is None or dp_size <= 0: - total_parallel_ranks = tp_size * cp_size * pp_size - if world_size % total_parallel_ranks != 0: - raise ValueError( - f"world_size ({world_size}) must be divisible by (tp_size * cp_size * pp_size) " - f"({tp_size} * {cp_size} * {pp_size} = {total_parallel_ranks})" - ) - dp_size = world_size // total_parallel_ranks - - if dp_replicate_size is None or dp_replicate_size <= 0: - dp_replicate_size = 1 - - # HSDP usecase: dp_size = dp_replicate_size * dp_shard_size - assert dp_size % dp_replicate_size == 0, "dp_size must be a multiple of dp_replicate_size" - assert dp_replicate_size < dp_size or dp_replicate_size == 1, ( - "dp_replicate_size must be less than dp_size since ddp usecase is not supported by FSDP2" - ) - - # Expert parallelism: EP spans all non-pp dims (dp, cp, tp) - non_pp_size = dp_size * cp_size * tp_size - assert non_pp_size % ep_size == 0, f"{non_pp_size=} must be a multiple of {ep_size=}" - if ep_size < non_pp_size: - ep_shard_size = non_pp_size // ep_size - else: - ep_shard_size = 1 - - dp_shard_size = dp_size // dp_replicate_size - - # Build main device mesh - mesh_shape = (pp_size, dp_replicate_size, dp_shard_size, cp_size, tp_size) - mesh_names = ("pp", "dp_replicate", "dp_shard", "cp", "tp") - for shape, name in zip(mesh_shape, mesh_names): - assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}" - assert shape > 0, f"Expected {name} > 0, got {shape}" - - device_mesh = init_device_mesh( - device_type="cuda" if backend == "nccl" else "cpu", - mesh_shape=mesh_shape, - mesh_dim_names=mesh_names, - ) - - # Create flattened submeshes - # Based on https://github.com/pytorch/torchtitan/blob/d282cf2ce9ca8049b4b8423c1d7578c80426576f/torchtitan/distributed/parallel_dims.py#L191 - dp_mesh_dim_names = [] # Mesh for data loading (no communication on this mesh) - dp_shard_cp_mesh_dim_names = [] # Mesh for param sharding - dp_cp_mesh_dim_names = [] # Mesh for loss all-reduce - - # for dp_replicate: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - # for dp_shard: - dp_mesh_dim_names.append("dp_shard") - dp_shard_cp_mesh_dim_names.append("dp_shard") - dp_cp_mesh_dim_names.append("dp_shard") - # for cp: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - - # Flatten submeshes - device_mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - device_mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") - device_mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - - # Derive EP mesh by flattening all non-pp dims and unflattening into (ep_shard, ep). - moe_mesh = None - if ep_size > 1: - non_pp_mesh = device_mesh[("dp_replicate", "dp_shard", "cp", "tp")]._flatten() - moe_mesh = _unflatten_compat( - non_pp_mesh, - 0, - (ep_shard_size, ep_size), - ("ep_shard", "ep"), - ) - - return device_mesh, moe_mesh - - -def _create_megatron_fsdp_device_mesh( - dp_size: Optional[int], - tp_size: int, - cp_size: int, - world_size: int, - backend: str, -) -> DeviceMesh: - """ - Create device mesh for MegatronFSDP. - - Mesh shape: (dp_size, cp_size, tp_size) - Mesh names: ("dp", "cp", "tp") - - Also creates flattened submesh "dp_cp" if cp_size > 1. - - Args: - dp_size: Data parallel size. If None, inferred from world_size. - tp_size: Tensor parallel size. - cp_size: Context parallel size. - world_size: Total number of processes. - backend: Distributed backend ('nccl' or 'gloo'). - - Returns: - DeviceMesh: The device mesh for MegatronFSDP. - """ - # Normalize sizes - tp_size = tp_size or 1 - cp_size = cp_size or 1 - - # Infer dp_size if not provided - if dp_size is None or dp_size <= 0: - total_parallel_ranks = tp_size * cp_size - if world_size % total_parallel_ranks != 0: - raise ValueError( - f"world_size ({world_size}) must be divisible by (tp_size * cp_size) " - f"({tp_size} * {cp_size} = {total_parallel_ranks})" - ) - dp_size = world_size // total_parallel_ranks - - mesh_shape = (dp_size, cp_size, tp_size) - mesh_names = ("dp", "cp", "tp") - for shape, name in zip(mesh_shape, mesh_names): - assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}" - assert shape > 0, f"Expected {name} > 0, got {shape}" - - # Build mesh [dp, cp, tp] - device_mesh = init_device_mesh( - device_type="cuda" if backend == "nccl" else "cpu", - mesh_shape=mesh_shape, - mesh_dim_names=mesh_names, - ) - - # Flatten dp+cp if cp > 1 - if cp_size > 1: - device_mesh[("dp", "cp")]._flatten(mesh_dim_name="dp_cp") - - return device_mesh diff --git a/nemo_automodel/components/distributed/fsdp2.py b/nemo_automodel/components/distributed/fsdp2.py index 91b231b251..3e44795495 100644 --- a/nemo_automodel/components/distributed/fsdp2.py +++ b/nemo_automodel/components/distributed/fsdp2.py @@ -73,7 +73,7 @@ class FSDP2Manager: from nemo_automodel.components.distributed.config import FSDP2Config config = FSDP2Config(sequence_parallel=True, activation_checkpointing=True) - # device_mesh created externally via create_device_mesh() + # device_mesh created externally via MeshContext.build() manager = FSDP2Manager(config, device_mesh=device_mesh, moe_mesh=moe_mesh) model = manager.parallelize(model) """ @@ -95,7 +95,6 @@ def __init__( self.offload_policy = config.offload_policy self.activation_checkpointing = config.activation_checkpointing self.defer_fsdp_grad_sync = config.defer_fsdp_grad_sync - self.backend = config.backend self.enable_async_tensor_parallel = config.enable_async_tensor_parallel self.enable_compile = config.enable_compile self.enable_fsdp2_prefetch = config.enable_fsdp2_prefetch diff --git a/nemo_automodel/components/distributed/megatron_fsdp.py b/nemo_automodel/components/distributed/megatron_fsdp.py index 1ba338b558..64df6c73ce 100644 --- a/nemo_automodel/components/distributed/megatron_fsdp.py +++ b/nemo_automodel/components/distributed/megatron_fsdp.py @@ -57,7 +57,7 @@ class MegatronFSDPManager: from nemo_automodel.components.distributed.config import MegatronFSDPConfig config = MegatronFSDPConfig(zero_dp_strategy=3, overlap_grad_reduce=True) - # device_mesh created externally via create_device_mesh() + # device_mesh created externally via MeshContext.build() manager = MegatronFSDPManager(config, device_mesh=device_mesh) model, optimizer = manager.parallelize(model, optimizer) """ @@ -71,7 +71,6 @@ def __init__( self.device_mesh = device_mesh # Extract config fields for easy access - self.sequence_parallel = config.sequence_parallel self.megatron_fsdp_unit_modules = config.megatron_fsdp_unit_modules self.zero_dp_strategy = config.zero_dp_strategy self.init_fsdp_with_meta_device = config.init_fsdp_with_meta_device @@ -87,7 +86,6 @@ def __init__( self.nccl_ub = config.nccl_ub self.fsdp_double_buffer = config.fsdp_double_buffer self.activation_checkpointing = config.activation_checkpointing - self.backend = config.backend def parallelize(self, model, optimizer=None): """ diff --git a/nemo_automodel/components/distributed/mesh.py b/nemo_automodel/components/distributed/mesh.py index ef6d4057d7..d90f7c3426 100644 --- a/nemo_automodel/components/distributed/mesh.py +++ b/nemo_automodel/components/distributed/mesh.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Typed MeshContext dataclass, validation, and strategy map. +"""MeshContext dataclass, construction, and validation. -``MeshContext`` is the single source of truth for everything related to -distributed training: strategy config, device meshes, and axis names. +``MeshContext`` is the single source of truth for distributed topology: +device meshes, parallelism sizes, and axis names. Parallelism sizes (``tp_size``, ``pp_size``, etc.) are derived at runtime from the attached ``DeviceMesh`` objects via ``@property``. When no mesh @@ -24,36 +24,22 @@ All inputs and outputs are typed Python objects (dataclasses, enums, etc.). YAML / dict parsing belongs in the recipe layer — see -``nemo_automodel.recipes._dist_setup``. +``nemo_automodel.recipes._dist_utils``. """ from dataclasses import dataclass, field from enum import Enum -from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union +from typing import TYPE_CHECKING, Dict, Optional, Tuple -from nemo_automodel.components.distributed.config import ( - DDPConfig, - FSDP2Config, - MegatronFSDPConfig, -) +from nemo_automodel.components.distributed.config import DistributedStrategyConfig +from nemo_automodel.components.distributed.init_utils import get_world_size_safe if TYPE_CHECKING: from torch.distributed.device_mesh import DeviceMesh - from nemo_automodel.components.distributed.pipelining.config import PipelineConfig - from nemo_automodel.components.moe.config import MoEParallelizerConfig - - -#: Maps strategy name (from YAML) → strategy dataclass. -STRATEGY_MAP: Dict[str, type] = { - "fsdp2": FSDP2Config, - "megatron_fsdp": MegatronFSDPConfig, - "ddp": DDPConfig, -} - class MeshAxisName(str, Enum): - """Canonical mesh-dimension names used by ``DeviceMesh`` and helpers. + """Canonical mesh axis names used by ``DeviceMesh`` and helpers. Inherits from ``str`` so each member compares equal to (and can be used wherever) a plain string — e.g. ``MeshAxisName.TP == "tp"``. @@ -71,56 +57,61 @@ class MeshAxisName(str, Enum): EP_SHARD = "ep_shard" -#: All values accepted as ``DeviceMesh`` dimension names. +#: All values accepted as ``DeviceMesh`` axis names. _VALID_AXIS_NAMES: frozenset = frozenset(MeshAxisName) +@dataclass(frozen=True, kw_only=True) +class ParallelismSizes: + """Build-time requested parallelism sizes. + + This is durable user intent, not runtime topology. ``MeshContext`` derives + its size properties from live ``DeviceMesh`` objects after build. + """ + + dp_size: int | None = None + dp_replicate_size: int | None = None + tp_size: int = 1 + pp_size: int = 1 + cp_size: int = 1 + ep_size: int = 1 + + @dataclass class MeshContext: - """Runtime distributed training context: configs + device meshes. + """Runtime distributed topology context. Parallelism sizes (``tp_size``, ``pp_size``, etc.) are **not** stored as fields; they are ``@property`` accessors that read directly from the attached ``DeviceMesh`` / ``moe_mesh``. When no mesh is present the properties return safe defaults (``1`` for sizes, ``None`` for dp / hsdp). - All ``DeviceMesh`` objects passed in must use dimension names from + All ``DeviceMesh`` objects passed in must use axis names from :class:`MeshAxisName`; a ``ValueError`` is raised on construction if any unknown name is encountered. Lifecycle --------- 1. Recipes parse YAML to obtain sizes and strategy configs. - 2. Sizes are passed to ``create_device_mesh`` to build ``DeviceMesh`` + 2. Sizes are passed to :meth:`build` to build ``DeviceMesh`` objects. - 3. ``MeshContext`` is created with those meshes; dimension names are + 3. ``MeshContext`` is created with those meshes; axis names are validated automatically in ``__post_init__``. Alternatively, :meth:`from_meshes` constructs an instance directly from ``DeviceMesh`` objects (used by ``NeMoAutoModel.from_pretrained``). Attributes: - strategy_config: Strategy-specific config (FSDP2, MegatronFSDP, or DDP). device_mesh: Device mesh for distributed training. moe_mesh: MoE-specific device mesh. - pipeline_config: Pipeline-parallel schedule/splitting config. - moe_config: MoE parallelizer settings. - activation_checkpointing: Whether activation checkpointing is enabled. """ - # config fields - strategy_config: Optional[Union["FSDP2Config", "MegatronFSDPConfig", "DDPConfig"]] = None - pipeline_config: Optional["PipelineConfig"] = None - moe_config: Optional["MoEParallelizerConfig"] = None - activation_checkpointing: bool = False - # runtime mesh references device_mesh: Optional["DeviceMesh"] = field(default=None, repr=False) moe_mesh: Optional["DeviceMesh"] = field(default=None, repr=False) def __post_init__(self) -> None: - _validate_mesh_dim_names(self) - _validate_distributed_setup(self) + _validate_mesh_axis_names(self) # Parallelism sizes — derived from the attached meshes @property @@ -192,17 +183,44 @@ def parallelize_axis_kwargs(self) -> Dict[str, object]: else None, } + @classmethod + def build( + cls, + strategy_config: DistributedStrategyConfig, + parallelism_sizes: ParallelismSizes | None = None, + *, + world_size: int | None = None, + ) -> "MeshContext": + """Build a topology-only :class:`MeshContext` from parallelism sizes. + + Args: + strategy_config: Already-instantiated distributed strategy config. + parallelism_sizes: Requested data, tensor, pipeline, context, and expert + parallelism sizes. If ``None``, defaults to no parallelism with + DP inferred from ``world_size``. + world_size: Total process count. If ``None``, inferred from the + distributed environment. + """ + if world_size is None: + world_size = get_world_size_safe() + if parallelism_sizes is None: + parallelism_sizes = ParallelismSizes() + + from nemo_automodel.components.distributed.mesh_utils import _create_device_meshes + + device_mesh, moe_mesh = _create_device_meshes( + strategy_config, + parallelism_sizes, + world_size=world_size, + ) + return cls.from_meshes(device_mesh, moe_mesh) + # Convenience constructor @classmethod def from_meshes( cls, device_mesh: Optional["DeviceMesh"], moe_mesh: Optional["DeviceMesh"] = None, - *, - strategy_config: Optional[Union["FSDP2Config", "MegatronFSDPConfig", "DDPConfig"]] = None, - pipeline_config: Optional["PipelineConfig"] = None, - moe_config: Optional["MoEParallelizerConfig"] = None, - activation_checkpointing: bool = False, ) -> "MeshContext": """Build a :class:`MeshContext` from ``DeviceMesh`` objects. @@ -211,10 +229,6 @@ def from_meshes( YAML config. """ return cls( - strategy_config=strategy_config, - pipeline_config=pipeline_config, - moe_config=moe_config, - activation_checkpointing=activation_checkpointing, device_mesh=device_mesh, moe_mesh=moe_mesh, ) @@ -225,7 +239,7 @@ def _get_axis_size(mesh: Optional["DeviceMesh"], axis: MeshAxisName, default=1) """Return the size of *axis* if present in *mesh*, else *default*.""" if mesh is None: return default - # Check mesh dims and _flatten() results on root mesh + # Check mesh axes and _flatten() results on root mesh. if axis in mesh.mesh_dim_names: return mesh[axis].size() if hasattr(mesh, "_get_root_mesh"): @@ -245,8 +259,8 @@ def _optional_axis(mesh: Optional["DeviceMesh"], axis: MeshAxisName) -> Optional # Validation utils -def _validate_mesh_dim_names(mesh_context: "MeshContext") -> None: - """Ensure every dimension name in the attached meshes is a :class:`MeshAxisName`.""" +def _validate_mesh_axis_names(mesh_context: "MeshContext") -> None: + """Ensure every axis name in the attached meshes is a :class:`MeshAxisName`.""" for label in ("device_mesh", "moe_mesh"): mesh = getattr(mesh_context, label) if mesh is None: @@ -254,45 +268,12 @@ def _validate_mesh_dim_names(mesh_context: "MeshContext") -> None: bad = {n for n in mesh.mesh_dim_names if n not in _VALID_AXIS_NAMES} if bad: raise ValueError( - f"{label} contains unknown dimension names {bad}; allowed names are {sorted(_VALID_AXIS_NAMES)}" + f"{label} contains unknown axis names {bad}; allowed names are {sorted(_VALID_AXIS_NAMES)}" ) -def _validate_distributed_setup(mesh_context: "MeshContext") -> None: - """Validate cross-field constraints on a :class:`MeshContext`. - - Called automatically by ``MeshContext.__post_init__`` when a - ``strategy_config`` is present. Can also be invoked explicitly - after mutating a context. - - Raises: - ValueError: If any constraint is violated. - """ - if mesh_context.strategy_config is None: - return - - if isinstance(mesh_context.strategy_config, MegatronFSDPConfig): - if mesh_context.pp_size > 1: - raise ValueError("megatron_fsdp does not support pipeline parallelism") - if mesh_context.ep_size > 1: - raise ValueError("megatron_fsdp does not support expert parallelism") - if mesh_context.strategy_config.sequence_parallel: - raise ValueError("megatron_fsdp does not yet support sequence_parallel") - - if isinstance(mesh_context.strategy_config, DDPConfig): - if mesh_context.tp_size > 1: - raise ValueError("ddp does not support tensor parallelism") - if mesh_context.pp_size > 1: - raise ValueError("ddp does not support pipeline parallelism") - if mesh_context.cp_size > 1: - raise ValueError("ddp does not support context parallelism") - if mesh_context.ep_size > 1: - raise ValueError("ddp does not support expert parallelism") - if mesh_context.dp_replicate_size is not None and mesh_context.dp_replicate_size > 1: - raise ValueError("ddp does not support HSDP (dp_replicate_size)") - - if mesh_context.pipeline_config is not None and mesh_context.pp_size <= 1: - raise ValueError("pipeline config requires pp_size > 1") - - if mesh_context.moe_config is not None and mesh_context.ep_size <= 1: - raise ValueError("moe config requires ep_size > 1") +__all__ = [ + "MeshAxisName", + "MeshContext", + "ParallelismSizes", +] diff --git a/nemo_automodel/components/distributed/mesh_utils.py b/nemo_automodel/components/distributed/mesh_utils.py index aefe78d86c..a9f95333d6 100644 --- a/nemo_automodel/components/distributed/mesh_utils.py +++ b/nemo_automodel/components/distributed/mesh_utils.py @@ -12,321 +12,244 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Device mesh creation utilities for distributed training. - -This module provides a central function to create device meshes based on the -distributed config type (FSDP2, MegatronFSDP, or DDP). - -Usage: - from nemo_automodel.components.distributed.config import FSDP2Config - from nemo_automodel.components.distributed.mesh_utils import create_device_mesh - - config = FSDP2Config(sequence_parallel=True) - device_mesh, moe_mesh = create_device_mesh( - config, - tp_size=2, - pp_size=1, - dp_replicate_size=2, - world_size=8, - ) -""" +"""Device mesh construction and access utilities for distributed training.""" -from typing import Optional, Tuple, Union +from dataclasses import dataclass, field +import torch +import torch.distributed as dist from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from nemo_automodel.components.distributed.config import ( DDPConfig, + DistributedStrategyConfig, FSDP2Config, MegatronFSDPConfig, ) -from nemo_automodel.components.distributed.mesh import MeshAxisName +from nemo_automodel.components.distributed.mesh import MeshAxisName, ParallelismSizes +__all__ = [ + "_create_device_meshes", + "_create_fsdp2_device_mesh", + "_create_megatron_fsdp_device_mesh", + "_unflatten_compat", + "get_flat_mesh", + "get_submesh", + "get_fsdp_dp_mesh", +] -def create_device_mesh( - distributed_config: Union[FSDP2Config, MegatronFSDPConfig, DDPConfig], - *, - dp_size: Optional[int] = None, - dp_replicate_size: Optional[int] = None, - tp_size: int = 1, - pp_size: int = 1, - cp_size: int = 1, - ep_size: int = 1, - world_size: int, -) -> Tuple[Optional[DeviceMesh], Optional[DeviceMesh]]: - """Create device mesh based on distributed config type. - Routes to the appropriate mesh creation logic based on config type. +def _degree(value: int | None) -> int: + return value if isinstance(value, int) and value > 0 else 1 + + +def _require_size_one(strategy_name: str, size: int | None, feature_name: str) -> None: + if _degree(size) > 1: + raise ValueError(f"{strategy_name} does not support {feature_name}") - Args: - distributed_config: The distributed config (FSDP2Config, MegatronFSDPConfig, - or DDPConfig). - dp_size: Data parallel size. If None, inferred from world_size and other - parallelism sizes. - dp_replicate_size: FSDP2-only. Size of the replication group for HSDP - (Hybrid Sharded Data Parallel). If None or <= 0, defaults to 1. - Must be a divisor of dp_size. - tp_size: Tensor parallel size. - pp_size: Pipeline parallel size. - cp_size: Context parallel size. - ep_size: Expert parallel size (for MoE models). - world_size: Total number of processes. - - Returns: - tuple: (device_mesh, moe_mesh) - - For FSDP2Config: Full device mesh + optional moe_mesh (if ep_size > 1) - - For MegatronFSDPConfig: Device mesh + None - - For DDPConfig: (None, None) - DDP doesn't use device mesh - - Raises: - ValueError: If dp_replicate_size is provided with non-FSDP2 config. - ValueError: If world_size is not divisible by parallelism sizes. - """ - # Validate FSDP2-only params - if dp_replicate_size is not None and dp_replicate_size > 1: - if not isinstance(distributed_config, FSDP2Config): - raise ValueError("dp_replicate_size is only supported with FSDP2Config") - if isinstance(distributed_config, FSDP2Config): +@dataclass(frozen=True) +class _MeshSpec: + """Named mesh shape plus derived flattened axes.""" + + shape: tuple[int, ...] + axes: tuple[MeshAxisName, ...] + flattened_axes: dict[MeshAxisName, tuple[MeshAxisName, ...]] = field(default_factory=dict) + + +def _create_device_meshes( + strategy_config: DistributedStrategyConfig, + parallelism: ParallelismSizes, + *, + world_size: int, +) -> tuple[DeviceMesh | None, DeviceMesh | None]: + """Create raw device meshes based on distributed config type.""" + if ( + parallelism.dp_replicate_size is not None + and parallelism.dp_replicate_size > 1 + and not isinstance(strategy_config, FSDP2Config) + ): + raise ValueError("dp_replicate_size is only supported with FSDP2Config") + + if isinstance(strategy_config, FSDP2Config): return _create_fsdp2_device_mesh( - dp_size=dp_size, - dp_replicate_size=dp_replicate_size, - tp_size=tp_size, - pp_size=pp_size, - cp_size=cp_size, - ep_size=ep_size, + parallelism, world_size=world_size, - backend=distributed_config.backend, ) - elif isinstance(distributed_config, MegatronFSDPConfig): + elif isinstance(strategy_config, MegatronFSDPConfig): + _require_size_one("megatron_fsdp", parallelism.pp_size, "pipeline parallelism") + _require_size_one("megatron_fsdp", parallelism.ep_size, "expert parallelism") mesh = _create_megatron_fsdp_device_mesh( - dp_size=dp_size, - tp_size=tp_size, - cp_size=cp_size, + parallelism, world_size=world_size, - backend=distributed_config.backend, ) return mesh, None - elif isinstance(distributed_config, DDPConfig): - return None, None # DDP doesn't use device mesh + elif isinstance(strategy_config, DDPConfig): + _require_size_one("ddp", parallelism.tp_size, "tensor parallelism") + _require_size_one("ddp", parallelism.pp_size, "pipeline parallelism") + _require_size_one("ddp", parallelism.cp_size, "context parallelism") + _require_size_one("ddp", parallelism.ep_size, "expert parallelism") + return None, None else: - raise ValueError(f"Unknown distributed config type: {type(distributed_config)}") + raise ValueError(f"Unknown distributed strategy config type: {type(strategy_config)}") -def _create_fsdp2_device_mesh( - dp_size: Optional[int], - dp_replicate_size: Optional[int], - tp_size: int, - pp_size: int, - cp_size: int, - ep_size: int, +def _infer_dp_size( + dp_size: int | None, + *, world_size: int, - backend: str, -) -> Tuple[DeviceMesh, Optional[DeviceMesh]]: - """ - Create device mesh for FSDP2. + non_dp_size: int, + expression: str, + factors: tuple[int, ...], +) -> int: + if dp_size is not None and dp_size > 0: + return dp_size + + if world_size % non_dp_size != 0: + factors_str = " * ".join(str(factor) for factor in factors) + raise ValueError( + f"world_size ({world_size}) must be divisible by ({expression}) ({factors_str} = {non_dp_size})" + ) + return world_size // non_dp_size - Mesh shape: (pp_size, dp_replicate_size, dp_shard_size, cp_size, tp_size) - Mesh names: ("pp", "dp_replicate", "dp_shard", "cp", "tp") - Also creates flattened submeshes: - - "dp": dp_replicate + dp_shard - - "dp_shard_cp": dp_shard + cp - - "dp_cp": dp_replicate + dp_shard + cp +def _mesh_device_type() -> str: + if dist.is_available() and dist.is_initialized(): + backend = str(dist.get_backend()).lower() + return "cuda" if "nccl" in backend and torch.cuda.is_available() else "cpu" + return "cuda" if torch.cuda.is_available() else "cpu" - Args: - dp_size: Data parallel size. If None, inferred from world_size. - dp_replicate_size: Size of the replication group for HSDP. - tp_size: Tensor parallel size. - pp_size: Pipeline parallel size. - cp_size: Context parallel size. - ep_size: Expert parallel size (for MoE models). - world_size: Total number of processes. - backend: Distributed backend ('nccl' or 'gloo'). - - Returns: - tuple: (device_mesh, moe_mesh) - """ - # Normalize sizes - if tp_size is None or tp_size <= 0: - tp_size = 1 - if cp_size is None or cp_size <= 0: - cp_size = 1 - if pp_size is None or pp_size <= 0: - pp_size = 1 - if ep_size is None or ep_size <= 0: - ep_size = 1 - - # Infer dp_size if not provided - if dp_size is None or dp_size <= 0: - total_parallel_ranks = tp_size * cp_size * pp_size - if world_size % total_parallel_ranks != 0: - raise ValueError( - f"world_size ({world_size}) must be divisible by (tp_size * cp_size * pp_size) " - f"({tp_size} * {cp_size} * {pp_size} = {total_parallel_ranks})" - ) - dp_size = world_size // total_parallel_ranks - - if dp_replicate_size is None or dp_replicate_size <= 0: - dp_replicate_size = 1 - - # HSDP usecase: dp_size = dp_replicate_size * dp_shard_size - assert dp_size % dp_replicate_size == 0, "dp_size must be a multiple of dp_replicate_size" - assert dp_replicate_size < dp_size or dp_replicate_size == 1, ( - f"dp_replicate_size={dp_replicate_size} must be less than dp_size={dp_size} " - "since DDP usecase is not supported by FSDP2" - ) - # Expert parallelism: EP spans all non-pp dims (dp, cp, tp) - non_pp_size = dp_size * cp_size * tp_size - assert non_pp_size % ep_size == 0, f"{non_pp_size=} must be a multiple of {ep_size=}" - if ep_size < non_pp_size: - ep_shard_size = non_pp_size // ep_size - else: - ep_shard_size = 1 +def _init_named_mesh(spec: _MeshSpec) -> DeviceMesh: + _validate_mesh_spec(spec) + device_mesh = init_device_mesh( + device_type=_mesh_device_type(), + mesh_shape=spec.shape, + mesh_dim_names=spec.axes, + ) + _register_flattened_axes(device_mesh, spec.flattened_axes) + return device_mesh - dp_shard_size = dp_size // dp_replicate_size - # Build main device mesh - mesh_shape = (pp_size, dp_replicate_size, dp_shard_size, cp_size, tp_size) - mesh_names = ( - MeshAxisName.PP, - MeshAxisName.DP_REPLICATE, - MeshAxisName.DP_SHARD, - MeshAxisName.CP, - MeshAxisName.TP, - ) - for shape, name in zip(mesh_shape, mesh_names): - assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}" - assert shape > 0, f"Expected {name} > 0, got {shape}" +def _validate_mesh_spec(spec: _MeshSpec) -> None: + for shape, axis in zip(spec.shape, spec.axes): + assert isinstance(shape, int), f"Expected {axis} to be an int, but got {type(shape)}" + assert shape > 0, f"Expected {axis} > 0, got {shape}" - device_mesh = init_device_mesh( - device_type="cuda" if backend == "nccl" else "cpu", - mesh_shape=mesh_shape, - mesh_dim_names=mesh_names, - ) - # Create flattened submeshes - # Based on https://github.com/pytorch/torchtitan/blob/d282cf2ce9ca8049b4b8423c1d7578c80426576f/torchtitan/distributed/parallel_dims.py#L191 - dp_mesh_dim_names = [] # Mesh for data loading (no communication on this mesh) - dp_shard_cp_mesh_dim_names = [] # Mesh for param sharding - dp_cp_mesh_dim_names = [] # Mesh for loss all-reduce - - # for dp_replicate: - dp_mesh_dim_names.append(MeshAxisName.DP_REPLICATE) - dp_cp_mesh_dim_names.append(MeshAxisName.DP_REPLICATE) - # for dp_shard: - dp_mesh_dim_names.append(MeshAxisName.DP_SHARD) - dp_shard_cp_mesh_dim_names.append(MeshAxisName.DP_SHARD) - dp_cp_mesh_dim_names.append(MeshAxisName.DP_SHARD) - # for cp: - dp_shard_cp_mesh_dim_names.append(MeshAxisName.CP) - dp_cp_mesh_dim_names.append(MeshAxisName.CP) - - # Flatten submeshes. - # PyTorch >= 2.10 stores results in root._flatten_mapping automatically. - # PyTorch 2.9.x returns the mesh but does NOT store it, so we keep our own - # mapping and attach it to the root mesh for use in get_flat_mesh(). - _dp_flat = device_mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name=MeshAxisName.DP) - _dp_shard_cp_flat = device_mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name=MeshAxisName.DP_SHARD_CP) - _dp_cp_flat = device_mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name=MeshAxisName.DP_CP) +def _register_flattened_axes( + device_mesh: DeviceMesh, flattened_axes: dict[MeshAxisName, tuple[MeshAxisName, ...]] +) -> None: + if not flattened_axes: + return if not hasattr(device_mesh, "_flatten_mapping"): device_mesh._flatten_mapping = {} - device_mesh._flatten_mapping.setdefault(MeshAxisName.DP, _dp_flat) - device_mesh._flatten_mapping.setdefault(MeshAxisName.DP_SHARD_CP, _dp_shard_cp_flat) - device_mesh._flatten_mapping.setdefault(MeshAxisName.DP_CP, _dp_cp_flat) + for flattened_axis, source_axes in flattened_axes.items(): + flattened_mesh = device_mesh[source_axes]._flatten(mesh_dim_name=flattened_axis) + device_mesh._flatten_mapping.setdefault(flattened_axis, flattened_mesh) + + +def _create_fsdp2_device_mesh( + parallelism: ParallelismSizes, + *, + world_size: int, +) -> tuple[DeviceMesh, DeviceMesh | None]: + """Create the FSDP2 root mesh and optional MoE mesh.""" + tp_size = _degree(parallelism.tp_size) + cp_size = _degree(parallelism.cp_size) + pp_size = _degree(parallelism.pp_size) + ep_size = _degree(parallelism.ep_size) + dp_replicate_size = _degree(parallelism.dp_replicate_size) + dp_size = _infer_dp_size( + parallelism.dp_size, + world_size=world_size, + non_dp_size=tp_size * cp_size * pp_size, + expression="tp_size * cp_size * pp_size", + factors=(tp_size, cp_size, pp_size), + ) + + if dp_size % dp_replicate_size != 0: + raise ValueError("dp_size must be a multiple of dp_replicate_size") + if dp_replicate_size >= dp_size and dp_replicate_size != 1: + raise ValueError( + f"dp_replicate_size={dp_replicate_size} must be less than dp_size={dp_size} " + "since DDP usecase is not supported by FSDP2" + ) + + non_pp_size = dp_size * cp_size * tp_size + if non_pp_size % ep_size != 0: + raise ValueError(f"{non_pp_size=} must be a multiple of {ep_size=}") + ep_shard_size = non_pp_size // ep_size if ep_size < non_pp_size else 1 + dp_shard_size = dp_size // dp_replicate_size + + device_mesh = _init_named_mesh( + _MeshSpec( + shape=(pp_size, dp_replicate_size, dp_shard_size, cp_size, tp_size), + axes=( + MeshAxisName.PP, + MeshAxisName.DP_REPLICATE, + MeshAxisName.DP_SHARD, + MeshAxisName.CP, + MeshAxisName.TP, + ), + flattened_axes={ + MeshAxisName.DP: (MeshAxisName.DP_REPLICATE, MeshAxisName.DP_SHARD), + MeshAxisName.DP_SHARD_CP: (MeshAxisName.DP_SHARD, MeshAxisName.CP), + MeshAxisName.DP_CP: (MeshAxisName.DP_REPLICATE, MeshAxisName.DP_SHARD, MeshAxisName.CP), + }, + ), + ) - # Derive EP mesh by flattening all non-pp dims and unflattening into (ep_shard, ep). - # EP spans dp, cp, and tp — the full non-pp rank space. moe_mesh = None if ep_size > 1: - non_pp_dims = (MeshAxisName.DP_REPLICATE, MeshAxisName.DP_SHARD, MeshAxisName.CP, MeshAxisName.TP) - non_pp_mesh = device_mesh[non_pp_dims]._flatten() - moe_mesh = _unflatten_compat( - non_pp_mesh, - 0, - (ep_shard_size, ep_size), - (MeshAxisName.EP_SHARD, MeshAxisName.EP), - ) + moe_mesh = _create_moe_mesh(device_mesh, ep_shard_size=ep_shard_size, ep_size=ep_size) return device_mesh, moe_mesh def _create_megatron_fsdp_device_mesh( - dp_size: Optional[int], - tp_size: int, - cp_size: int, + parallelism: ParallelismSizes, + *, world_size: int, - backend: str, ) -> DeviceMesh: - """ - Create device mesh for MegatronFSDP. - - Mesh shape: (dp_size, cp_size, tp_size) - Mesh names: ("dp", "cp", "tp") - - Also creates flattened submesh "dp_cp" if cp_size > 1. - - Args: - dp_size: Data parallel size. If None, inferred from world_size. - tp_size: Tensor parallel size. - cp_size: Context parallel size. - world_size: Total number of processes. - backend: Distributed backend ('nccl' or 'gloo'). - - Returns: - DeviceMesh: The device mesh for MegatronFSDP. - """ - # Normalize sizes - tp_size = tp_size or 1 - cp_size = cp_size or 1 - - # Infer dp_size if not provided - if dp_size is None or dp_size <= 0: - total_parallel_ranks = tp_size * cp_size - if world_size % total_parallel_ranks != 0: - raise ValueError( - f"world_size ({world_size}) must be divisible by (tp_size * cp_size) " - f"({tp_size} * {cp_size} = {total_parallel_ranks})" - ) - dp_size = world_size // total_parallel_ranks - - mesh_shape = (dp_size, cp_size, tp_size) - mesh_names = (MeshAxisName.DP, MeshAxisName.CP, MeshAxisName.TP) - for shape, name in zip(mesh_shape, mesh_names): - assert isinstance(shape, int), f"Expected {name} to be an int, but got {type(shape)}" - assert shape > 0, f"Expected {name} > 0, got {shape}" - - # Build mesh [dp, cp, tp] - device_mesh = init_device_mesh( - device_type="cuda" if backend == "nccl" else "cpu", - mesh_shape=mesh_shape, - mesh_dim_names=mesh_names, + """Create the Megatron FSDP mesh.""" + tp_size = _degree(parallelism.tp_size) + cp_size = _degree(parallelism.cp_size) + dp_size = _infer_dp_size( + parallelism.dp_size, + world_size=world_size, + non_dp_size=tp_size * cp_size, + expression="tp_size * cp_size", + factors=(tp_size, cp_size), ) - # Flatten dp+cp if cp > 1 - if cp_size > 1: - _dp_cp_flat = device_mesh[(MeshAxisName.DP, MeshAxisName.CP)]._flatten(mesh_dim_name=MeshAxisName.DP_CP) - if not hasattr(device_mesh, "_flatten_mapping"): - device_mesh._flatten_mapping = {} - device_mesh._flatten_mapping.setdefault(MeshAxisName.DP_CP, _dp_cp_flat) - - return device_mesh + return _init_named_mesh( + _MeshSpec( + shape=(dp_size, cp_size, tp_size), + axes=(MeshAxisName.DP, MeshAxisName.CP, MeshAxisName.TP), + flattened_axes={MeshAxisName.DP_CP: (MeshAxisName.DP, MeshAxisName.CP)} if cp_size > 1 else {}, + ), + ) -def _unflatten_compat(flat_mesh: "DeviceMesh", dim: int, sizes: tuple, names: tuple) -> "DeviceMesh": - """Compatibility shim for DeviceMesh._unflatten(), which was added in PyTorch 2.10. +def _create_moe_mesh(device_mesh: DeviceMesh, *, ep_shard_size: int, ep_size: int) -> DeviceMesh: + non_pp_axes = (MeshAxisName.DP_REPLICATE, MeshAxisName.DP_SHARD, MeshAxisName.CP, MeshAxisName.TP) + return _unflatten_compat( + device_mesh[non_pp_axes]._flatten(), + 0, + (ep_shard_size, ep_size), + (MeshAxisName.EP_SHARD, MeshAxisName.EP), + ) - Reconstructs a multi-dimensional mesh from a flat mesh by reshaping its - rank tensor. ``dim`` must be 0 (only case used in this codebase). - """ - from torch.distributed.device_mesh import DeviceMesh +def _unflatten_compat(flat_mesh: DeviceMesh, axis: int, sizes: tuple, names: tuple) -> DeviceMesh: + """Compatibility shim for DeviceMesh._unflatten(), added in PyTorch 2.10.""" if hasattr(flat_mesh, "_unflatten"): - return flat_mesh._unflatten(dim, sizes, names) - # PyTorch 2.9.x fallback: reshape the underlying rank tensor directly. + return flat_mesh._unflatten(axis, sizes, names) new_mesh_tensor = flat_mesh.mesh.reshape(sizes) - return DeviceMesh(flat_mesh.device_type, new_mesh_tensor, mesh_dim_names=names) + from torch.distributed.device_mesh import DeviceMesh as _DeviceMesh + + return _DeviceMesh(flat_mesh.device_type, new_mesh_tensor, mesh_dim_names=names) def get_flat_mesh(device_mesh: "DeviceMesh", name: str) -> "DeviceMesh": diff --git a/nemo_automodel/components/moe/config.py b/nemo_automodel/components/moe/config.py index c98b361cfc..7ddc2adc29 100644 --- a/nemo_automodel/components/moe/config.py +++ b/nemo_automodel/components/moe/config.py @@ -12,31 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""MoE parallelizer configuration.""" +"""MoE model configuration.""" -from dataclasses import dataclass, fields -from typing import Any, Dict, Literal, Optional, Union +from dataclasses import dataclass +from typing import Literal, Optional import torch -from torch.distributed.fsdp._fully_shard import MixedPrecisionPolicy from nemo_automodel.shared.utils import dtype_from_str -@dataclass -class MoEParallelizerConfig: - """Configuration for MoE model parallelization (EP + FSDP settings).""" - - ignore_router_for_ac: bool = False - reshard_after_forward: bool = False - lm_head_precision: Optional[Union[str, torch.dtype]] = None - wrap_outer_model: bool = True - mp_policy: Optional[MixedPrecisionPolicy] = None - - def to_dict(self) -> Dict[str, Any]: - return {f.name: getattr(self, f.name) for f in fields(self)} - - @dataclass(kw_only=True) class MoEConfig: """Configuration for routed and shared MoE expert modules.""" diff --git a/nemo_automodel/recipes/_dist_setup.py b/nemo_automodel/recipes/_dist_utils.py similarity index 64% rename from nemo_automodel/recipes/_dist_setup.py rename to nemo_automodel/recipes/_dist_utils.py index a5cdfe9ae9..50806841d0 100644 --- a/nemo_automodel/recipes/_dist_setup.py +++ b/nemo_automodel/recipes/_dist_utils.py @@ -15,19 +15,21 @@ """Recipe-level helpers for parsing YAML distributed configs. This module bridges the gap between raw YAML / :class:`ConfigNode` dicts -and the typed :class:`MeshContext` used by the component layer. -All dict handling lives here; the component layer (``mesh``) stays purely typed. +and the typed :class:`DistributedSetup` used by the component layer. +All dict handling lives here; the component layer stays typed. This module does +not initialize ``torch.distributed``. Recipes call ``initialize_distributed`` +first, then pass the resulting world size here. """ -import dataclasses from typing import Any, Dict, Optional -from nemo_automodel.components.distributed.mesh import ( - STRATEGY_MAP, - MeshContext, +from nemo_automodel.components.distributed.config import ( + DistributedSetup, + MoEParallelizerConfig, + _resolve_strategy_config, ) +from nemo_automodel.components.distributed.mesh import ParallelismSizes from nemo_automodel.components.distributed.pipelining.config import PipelineConfig -from nemo_automodel.components.moe.config import MoEParallelizerConfig from nemo_automodel.shared.utils import dtype_from_str _PARALLELISM_DEFAULTS: Dict[str, Any] = { @@ -40,18 +42,6 @@ } -def _validate_strategy_kwargs( - strategy_name: str, - strategy_cls: type, - strategy_kwargs: Dict[str, Any], -) -> None: - """Check that *strategy_kwargs* only contains fields recognised by *strategy_cls*.""" - valid_fields = {f.name for f in dataclasses.fields(strategy_cls)} - unknown = set(strategy_kwargs) - valid_fields - if unknown: - raise ValueError(f"Unknown options for strategy '{strategy_name}': {sorted(unknown)}") - - def parse_distributed_section(cfg_dict: dict) -> dict: """Parse a flat distributed config dict into components for mesh creation. @@ -59,22 +49,20 @@ def parse_distributed_section(cfg_dict: dict) -> dict: - ``strategy_config`` – instantiated strategy dataclass - ``pipeline_config`` – :class:`PipelineConfig` or ``None`` - - ``moe_config`` – :class:`MoEParallelizerConfig` or ``None`` + - ``moe_parallel_config`` – :class:`MoEParallelizerConfig` or ``None`` - ``activation_checkpointing`` – bool + - ``parallelism_sizes`` – :class:`ParallelismSizes` - ``tp_size``, ``pp_size``, ``cp_size``, ``ep_size``, ``dp_size``, ``dp_replicate_size`` – parallelism sizes - ``pp_enabled`` – ``True`` when ``pp_size > 1`` Device meshes are **not** created here; that is done by - :func:`setup_distributed`. + :meth:`DistributedSetup.build`. """ cfg = cfg_dict.copy() # shallow copy — never mutate the caller's dict # -- strategy ----------------------------------------------------------- - strategy_name: str = cfg.pop("strategy", "fsdp2") - if strategy_name not in STRATEGY_MAP: - raise ValueError(f"Unknown strategy: {strategy_name}. Valid strategies: {list(STRATEGY_MAP.keys())}") - strategy_cls = STRATEGY_MAP[strategy_name] + strategy_name: str = cfg.pop("strategy", "fsdp2").lower() # -- parallelism sizes -------------------------------------------------- # Use `val if val is not None` so that explicit YAML nulls (``ep_size:`` @@ -143,11 +131,6 @@ def parse_distributed_section(cfg_dict: dict) -> dict: if isinstance(val, str): strategy_kwargs["autocast_dtype"] = dtype_from_str(val) - _validate_strategy_kwargs(strategy_name, strategy_cls, strategy_kwargs) - - # Route activation_checkpointing: for non-EP configs it goes on the - # strategy config; for EP configs it stays only on MeshContext - # (the MoE infra reads it from there). ep_size: int = parallelism.get("ep_size") or 1 # YAML-level sanity: silently discard sub-configs that don't apply to the @@ -158,10 +141,8 @@ def parse_distributed_section(cfg_dict: dict) -> dict: pipeline_dict = None if moe_dict is not None and ep_size <= 1: moe_dict = None - if ep_size <= 1: - strategy_kwargs["activation_checkpointing"] = activation_checkpointing - strategy_config = strategy_cls(**strategy_kwargs) + strategy_config = _resolve_strategy_config(strategy_name, **strategy_kwargs) if pipeline_dict is not None: pipeline_config = PipelineConfig(**pipeline_dict) @@ -181,62 +162,106 @@ def parse_distributed_section(cfg_dict: dict) -> dict: mp_raw[key] = dtype_from_str(mp_raw[key]) moe_dict["mp_policy"] = target(**mp_raw) - moe_config = MoEParallelizerConfig(**(moe_dict or {})) if ep_size > 1 else None - - # Full cross-field validation is deferred to MeshContext.__post_init__ - # (called automatically when setup_distributed constructs the context). + moe_parallel_config = MoEParallelizerConfig(**(moe_dict or {})) if ep_size > 1 else None return { "strategy_config": strategy_config, "pipeline_config": pipeline_config, - "moe_config": moe_config, + "moe_parallel_config": moe_parallel_config, "activation_checkpointing": activation_checkpointing, + "parallelism_sizes": ParallelismSizes(**parallelism), "pp_enabled": parallelism["pp_size"] > 1, **parallelism, } -def setup_distributed(cfg: Any, world_size: Optional[int] = None) -> MeshContext: - """Parse ``cfg.distributed`` and create device meshes. - - This is the main entry-point called by recipes. It converts the - config section into a fully-initialised :class:`MeshContext` - (including ``device_mesh`` and ``moe_mesh``). +def _distributed_cfg_to_dict(cfg: Any | None) -> dict: + """Return a distributed config dict from ``cfg`` or an empty fallback.""" + if cfg is None: + return {} + if isinstance(cfg, dict): + return cfg.copy() + distributed_cfg = cfg.distributed + return distributed_cfg.to_dict() if hasattr(distributed_cfg, "to_dict") else dict(distributed_cfg) + + +def create_distributed_setup_from_config( + cfg: Any | None = None, + world_size: Optional[int] = None, + *, + strategy: str | None = None, + dp_size: int | None = None, + dp_replicate_size: int | None = None, + tp_size: int | None = None, + pp_size: int | None = None, + cp_size: int | None = None, + ep_size: int | None = None, + pipeline: dict | None = None, + moe: dict | None = None, + **strategy_kwargs: Any, +) -> DistributedSetup: + """Parse recipe distributed settings and create a distributed setup. + + This is the recipe-level adapter around :meth:`DistributedSetup.build`. + It converts a YAML/config section or programmatic keyword arguments into a + fully initialized :class:`DistributedSetup` (including ``device_mesh`` and + ``moe_mesh`` through ``setup.mesh_context``). It does not initialize the + process group; call ``initialize_distributed`` before this in distributed + recipes. Args: - cfg: Top-level config (must have a ``distributed`` key). + cfg: Optional distributed config dict or top-level config with a + ``distributed`` key. Used as fallback when explicit keyword + arguments are omitted. world_size: Total number of processes in the job. If ``None`` (default), the value is auto-detected from ``torch.distributed`` if initialized, or from the ``WORLD_SIZE`` environment variable, falling back to ``1``. + strategy: Distributed strategy name (``fsdp2``, ``megatron_fsdp``, + ``megatron-fsdp``, ``mfsdp``, or ``ddp``). + dp_size: Data-parallel size. If ``None``, inferred by mesh creation. + dp_replicate_size: HSDP replicate size for FSDP2. + tp_size: Tensor-parallel size. + pp_size: Pipeline-parallel size. + cp_size: Context-parallel size. + ep_size: Expert-parallel size. + pipeline: Optional pipeline sub-config. + moe: Optional MoE parallelizer sub-config. + **strategy_kwargs: Additional strategy-specific options. Returns: - A :class:`MeshContext` with device meshes attached. + A :class:`DistributedSetup` with device meshes and policy configs attached. """ - from nemo_automodel.components.distributed import mesh_utils from nemo_automodel.components.distributed.init_utils import get_world_size_safe if world_size is None: world_size = get_world_size_safe() - cfg_dict = cfg.distributed.to_dict() if not isinstance(cfg, dict) else cfg - parsed = parse_distributed_section(cfg_dict) - - device_mesh, moe_mesh = mesh_utils.create_device_mesh( - parsed["strategy_config"], - dp_size=parsed["dp_size"], - dp_replicate_size=parsed["dp_replicate_size"], - tp_size=parsed["tp_size"], - pp_size=parsed["pp_size"], - cp_size=parsed["cp_size"], - ep_size=parsed["ep_size"], - world_size=world_size, - ) + cfg_dict = _distributed_cfg_to_dict(cfg) + + explicit_overrides = { + "strategy": strategy, + "dp_size": dp_size, + "dp_replicate_size": dp_replicate_size, + "tp_size": tp_size, + "pp_size": pp_size, + "cp_size": cp_size, + "ep_size": ep_size, + "pipeline": pipeline, + "moe": moe, + } + for key, value in explicit_overrides.items(): + if value is not None: + cfg_dict[key] = value + for key, value in strategy_kwargs.items(): + if value is not None: + cfg_dict[key] = value - return MeshContext( - strategy_config=parsed["strategy_config"], + parsed = parse_distributed_section(cfg_dict) + return DistributedSetup.build( + strategy=parsed["strategy_config"], + parallelism_sizes=parsed["parallelism_sizes"], pipeline_config=parsed["pipeline_config"], - moe_config=parsed["moe_config"], + moe_parallel_config=parsed["moe_parallel_config"], activation_checkpointing=parsed["activation_checkpointing"], - device_mesh=device_mesh, - moe_mesh=moe_mesh, + world_size=world_size, ) diff --git a/nemo_automodel/recipes/base_recipe.py b/nemo_automodel/recipes/base_recipe.py index 2bf3af9350..7ed66b8bc3 100644 --- a/nemo_automodel/recipes/base_recipe.py +++ b/nemo_automodel/recipes/base_recipe.py @@ -204,6 +204,22 @@ class BaseRecipe: BaseRecipe provides checkpoint load/save functionality for recipes. """ + @staticmethod + def _distributed_setup_attributes(distributed_setup): + """Return common recipe attributes derived from a distributed setup.""" + mesh_context = distributed_setup.mesh_context + return ( + distributed_setup, + mesh_context, + distributed_setup.strategy_config, + mesh_context.device_mesh, + mesh_context.moe_mesh, + mesh_context.pp_enabled, + distributed_setup.pipeline_config, + distributed_setup.moe_parallel_config, + distributed_setup.activation_checkpointing, + ) + def __setattr__(self, key, value): """ Overriden __setattr__ to keep track of stateful classes. diff --git a/nemo_automodel/recipes/diffusion/train.py b/nemo_automodel/recipes/diffusion/train.py index 8b7c291433..bb978187a0 100644 --- a/nemo_automodel/recipes/diffusion/train.py +++ b/nemo_automodel/recipes/diffusion/train.py @@ -40,6 +40,7 @@ prepare_for_final_backward, prepare_for_grad_accumulation, ) +from nemo_automodel.recipes._dist_utils import parse_distributed_section from nemo_automodel.recipes.base_recipe import BaseRecipe from nemo_automodel.recipes.llm.train_ft import build_distributed, build_wandb @@ -85,6 +86,64 @@ def _calculate_throughput_metrics( } +def _build_diffusion_parallel_manager_args( + *, + fsdp_cfg: Optional[Dict[str, Any]], + ddp_cfg: Optional[Dict[str, Any]], + world_size: int, + dtype: torch.dtype, + lora_enabled: bool, +) -> Dict[str, Any]: + """Build diffusion transformer manager args through the shared distributed parser.""" + if fsdp_cfg is not None and ddp_cfg is not None: + raise ValueError( + "Cannot specify both 'fsdp' and 'ddp' configurations. " + "Please provide only one distributed training strategy." + ) + + if ddp_cfg is not None: + parsed = parse_distributed_section({"strategy": "ddp", **ddp_cfg}) + return { + "_manager_type": "ddp", + "world_size": world_size, + **parsed["strategy_config"].to_dict(), + "activation_checkpointing": parsed["activation_checkpointing"], + } + + fsdp_options = dict(fsdp_cfg or {}) + ignored_options = {"use_hf_tp_plan": fsdp_options.pop("use_hf_tp_plan", False)} + + param_dtype = None if lora_enabled else dtype + parsed = parse_distributed_section( + { + "strategy": "fsdp2", + "activation_checkpointing": True, + "defer_fsdp_grad_sync": True, + "enable_fsdp2_prefetch": True, + **fsdp_options, + "mp_policy": MixedPrecisionPolicy( + param_dtype=param_dtype, + reduce_dtype=torch.float32, + output_dtype=dtype, + ), + } + ) + + return { + "_manager_type": "fsdp2", + "world_size": world_size, + "dp_size": parsed["dp_size"], + "dp_replicate_size": parsed["dp_replicate_size"], + "tp_size": parsed["tp_size"], + "cp_size": parsed["cp_size"], + "pp_size": parsed["pp_size"], + "ep_size": parsed["ep_size"], + **parsed["strategy_config"].to_dict(), + "activation_checkpointing": parsed["activation_checkpointing"], + **ignored_options, + } + + def build_model_and_optimizer( *, model_id: str, @@ -130,13 +189,6 @@ def build_model_and_optimizer( ValueError: If both fsdp_cfg and ddp_cfg are provided. ValueError: If finetune_mode is False and pipeline_spec is not provided. """ - # Validate mutually exclusive configs - if fsdp_cfg is not None and ddp_cfg is not None: - raise ValueError( - "Cannot specify both 'fsdp' and 'ddp' configurations. " - "Please provide only one distributed training strategy." - ) - logging.info("[INFO] Building NeMoAutoDiffusionPipeline with transformer parallel scheme...") if not dist.is_initialized(): @@ -145,67 +197,17 @@ def build_model_and_optimizer( world_size = dist.get_world_size() if dist.is_initialized() else 1 lora_enabled = peft_cfg is not None - # param_dtype=None when LoRA: FSDP2 does not cast any parameter. - # bf16 base weights stay bf16 (loaded dtype). - # bf16 LoRA weights stay bf16 (set via peft_cfg.lora_dtype in pipeline). - # param_dtype=dtype when full fine-tune: FSDP2 casts everything to dtype (bf16). - param_dtype = None if lora_enabled else dtype - - # Build manager args based on which config is provided if ddp_cfg is not None: - # DDP configuration logging.info("[INFO] Using DDP (DistributedDataParallel) for training") - manager_args: Dict[str, Any] = { - "_manager_type": "ddp", - "backend": ddp_cfg.get("backend", "nccl"), - "world_size": world_size, - "activation_checkpointing": ddp_cfg.get("activation_checkpointing", False), - } else: - # FSDP configuration (default) - fsdp_cfg = fsdp_cfg or {} logging.info("[INFO] Using FSDP2 (Fully Sharded Data Parallel) for training") - - dp_size = fsdp_cfg.get("dp_size") - tp_size = fsdp_cfg.get("tp_size", 1) - cp_size = fsdp_cfg.get("cp_size", 1) - pp_size = fsdp_cfg.get("pp_size", 1) - - if dp_size is None: - denom = tp_size * cp_size * pp_size - if world_size % denom != 0: - raise ValueError( - f"world_size ({world_size}) must be divisible by " - f"tp_size*cp_size*pp_size ({tp_size}*{cp_size}*{pp_size}={denom})" - ) - dp_size = world_size // denom - - manager_args: Dict[str, Any] = { - "_manager_type": "fsdp2", - "dp_size": dp_size, - "dp_replicate_size": fsdp_cfg.get("dp_replicate_size", None), - "tp_size": tp_size, - "cp_size": cp_size, - "pp_size": pp_size, - "backend": "nccl", - "world_size": world_size, - "use_hf_tp_plan": fsdp_cfg.get("use_hf_tp_plan", False), - "sequence_parallel": fsdp_cfg.get("sequence_parallel", False), - "tp_plan": fsdp_cfg.get("tp_plan", None), - "patch_is_packed_sequence": fsdp_cfg.get("patch_is_packed_sequence", False), - "activation_checkpointing": fsdp_cfg.get("activation_checkpointing", True), - "defer_fsdp_grad_sync": fsdp_cfg.get("defer_fsdp_grad_sync", True), - "enable_async_tensor_parallel": fsdp_cfg.get("enable_async_tensor_parallel", False), - "enable_compile": fsdp_cfg.get("enable_compile", False), - "enable_fsdp2_prefetch": fsdp_cfg.get("enable_fsdp2_prefetch", True), - "fsdp2_backward_prefetch_depth": fsdp_cfg.get("fsdp2_backward_prefetch_depth", 2), - "fsdp2_forward_prefetch_depth": fsdp_cfg.get("fsdp2_forward_prefetch_depth", 1), - "mp_policy": MixedPrecisionPolicy( - param_dtype=param_dtype, - reduce_dtype=torch.float32, - output_dtype=dtype, - ), - } + manager_args = _build_diffusion_parallel_manager_args( + fsdp_cfg=fsdp_cfg, + ddp_cfg=ddp_cfg, + world_size=world_size, + dtype=dtype, + lora_enabled=lora_enabled, + ) parallel_scheme = {"transformer": manager_args} diff --git a/nemo_automodel/recipes/llm/kd.py b/nemo_automodel/recipes/llm/kd.py index 215601c737..a783e06c60 100644 --- a/nemo_automodel/recipes/llm/kd.py +++ b/nemo_automodel/recipes/llm/kd.py @@ -53,6 +53,7 @@ from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer from nemo_automodel.components.config._arg_parser import parse_args_and_load_config +from nemo_automodel.components.distributed.config import DistributedSetup from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx from nemo_automodel.components.distributed.pipelining.config import PipelineConfig from nemo_automodel.components.distributed.utils import get_sync_ctx @@ -91,9 +92,7 @@ def _build_teacher_model( cfg_teacher, seed, has_packed_sequence, - device_mesh=None, - moe_mesh=None, - distributed_config=None, + distributed_setup: DistributedSetup | None = None, device=None, ): """Build and initialize the teacher model for knowledge distillation. @@ -105,9 +104,7 @@ def _build_teacher_model( cfg_teacher: Configuration for teacher model instantiation. seed: Random seed for reproducibility. has_packed_sequence: Whether using packed sequences. - device_mesh: Device mesh for distributed training. - moe_mesh: MOE mesh for expert parallelism. - distributed_config: Strategy-specific distributed config. + distributed_setup: Resolved distributed topology and policy object. device: Device to place the teacher model on. Returns: @@ -125,9 +122,7 @@ def _build_teacher_model( with ScopedRNG(seed=seed, ranked=True): kwargs: Dict[str, Any] = { "has_packed_sequence": has_packed_sequence, - "device_mesh": device_mesh, - "moe_mesh": moe_mesh, - "distributed_config": distributed_config, + "distributed_setup": distributed_setup, } teacher_model = cfg_teacher.instantiate(**kwargs) @@ -147,11 +142,9 @@ def _build_teacher_model_with_pp( cfg_teacher, seed: int, has_packed_sequence: bool, - device_mesh, - moe_mesh, - distributed_config, pipeline_config: PipelineConfig, - dist_setup, + distributed_setup: DistributedSetup, + activation_checkpointing: bool, ) -> Any: """Build teacher model with same parallelization as student (TP/EP/SP/PP). @@ -166,11 +159,9 @@ def _build_teacher_model_with_pp( cfg_teacher: Configuration for teacher model instantiation. seed: Random seed for reproducibility. has_packed_sequence: Whether using packed sequences. - device_mesh: Device mesh for distributed training. - moe_mesh: MOE mesh for expert parallelism. - distributed_config: Strategy-specific distributed config. pipeline_config: PipelineConfig from the student, used as a template. - dist_setup: Distributed setup object (provides moe_config, activation_checkpointing). + distributed_setup: Student distributed setup, used as a template. + activation_checkpointing: Whether to enable activation checkpointing. Returns: The frozen teacher AutoPipeline with a ``_teacher_logits_capture`` attribute. @@ -202,6 +193,13 @@ def _teacher_capture_loss_fn(logits, target, **kwargs): scale_grads_in_schedule=pipeline_config.scale_grads_in_schedule, loss_fn=_teacher_capture_loss_fn, ) + teacher_distributed_setup = DistributedSetup( + mesh_context=distributed_setup.mesh_context, + strategy_config=distributed_setup.strategy_config, + pipeline_config=teacher_pipeline_config, + moe_parallel_config=distributed_setup.moe_parallel_config, + activation_checkpointing=activation_checkpointing, + ) with ScopedRNG(seed=seed, ranked=True): teacher_model = build_model( @@ -212,13 +210,8 @@ def _teacher_capture_loss_fn(logits, target, **kwargs): cfg_fp8=None, cfg_compile=None, cfg_quantization=None, - device_mesh=device_mesh, - moe_mesh=moe_mesh, - distributed_config=distributed_config, - pipeline_config=teacher_pipeline_config, + distributed_setup=teacher_distributed_setup, cfg_qat=None, - cfg_moe=dist_setup.moe_config, - activation_checkpointing=dist_setup.activation_checkpointing, ) # Freeze all teacher parameters. @@ -276,11 +269,9 @@ def setup(self): # noqa: C901 – same complexity as parent cfg_teacher=self.cfg.get("teacher_model", None), seed=self.cfg.get("seed", 42), has_packed_sequence=self.cfg.get("packed_sequence.packed_sequence_size", 0) > 0, - device_mesh=self.device_mesh, - moe_mesh=self.moe_mesh, - distributed_config=self.distributed_config, pipeline_config=self.pipeline_config, - dist_setup=self.dist_setup, + distributed_setup=self.distributed_setup, + activation_checkpointing=self.activation_checkpointing, ) self.teacher_pp = self.teacher_model if self.pipeline_config.pp_microbatch_size != self.pipeline_config.pp_batch_size: @@ -294,9 +285,7 @@ def setup(self): # noqa: C901 – same complexity as parent cfg_teacher=self.cfg.get("teacher_model", None), seed=self.cfg.get("seed", 42), has_packed_sequence=self.cfg.get("packed_sequence.packed_sequence_size", 0) > 0, - device_mesh=self.device_mesh, - moe_mesh=self.moe_mesh, - distributed_config=self.distributed_config, + distributed_setup=self.distributed_setup, device=teacher_device, ) self.teacher_pp = None diff --git a/nemo_automodel/recipes/llm/train_ft.py b/nemo_automodel/recipes/llm/train_ft.py index 46dea2b553..31bfb68fc8 100644 --- a/nemo_automodel/recipes/llm/train_ft.py +++ b/nemo_automodel/recipes/llm/train_ft.py @@ -60,7 +60,7 @@ from nemo_automodel.components.datasets.llm.megatron.sampler import create_megatron_sampler from nemo_automodel.components.datasets.llm.megatron_dataset import MegatronPretraining from nemo_automodel.components.datasets.llm.packed_sequence import pack_dataset -from nemo_automodel.components.distributed.config import MegatronFSDPConfig +from nemo_automodel.components.distributed.config import DistributedSetup, MegatronFSDPConfig from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx from nemo_automodel.components.distributed.init_utils import ( initialize_distributed, @@ -106,7 +106,7 @@ filter_forward_kwargs, resolve_trust_remote_code, ) -from nemo_automodel.recipes._dist_setup import setup_distributed +from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config from nemo_automodel.recipes.base_recipe import BaseRecipe from nemo_automodel.shared.te_patches import apply_te_patches from nemo_automodel.shared.utils import dtype_from_str @@ -180,15 +180,11 @@ def build_model( cfg_fp8=None, cfg_compile=None, cfg_quantization=None, - device_mesh=None, - moe_mesh=None, - distributed_config=None, - pipeline_config=None, + distributed_setup: DistributedSetup | None = None, cfg_qat=None, - cfg_moe=None, - activation_checkpointing=False, unfreeze_modules: list[str] | None = None, sdpa_method: list[str] | None = None, + device_mesh=None, ) -> tuple[nn.Module | AutoPipeline, list["Optimizer"]]: # noqa: F821 """Build and initialize a model. @@ -200,28 +196,24 @@ def build_model( cfg_fp8: Configuration for FP8. cfg_compile: Configuration for torch.compile. cfg_quantization: Configuration for BitsAndBytes quantization. - device_mesh: Device mesh for distributed training. - moe_mesh: MOE mesh for expert parallelism. - distributed_config: Strategy-specific distributed config (FSDP2Config, etc.). - pipeline_config: Pipeline parallelism config. + distributed_setup: Resolved distributed topology and policy object. cfg_qat: Configuration for QAT (will be instantiated to QATConfig). - cfg_moe: MoEParallelizerConfig instance, or ConfigNode to be converted. - activation_checkpointing: Whether to enable activation checkpointing. unfreeze_modules: List of module names/substrings to unfreeze. sdpa_method: Explicit list of SDPA backend name strings (e.g. ``["flash_attention", "efficient_attention"]``), or ``None`` to auto-select based on CP / activation checkpointing. + device_mesh: Pre-created device mesh forwarded when ``distributed_setup`` is not provided. """ with ScopedRNG(seed=seed, ranked=True): kwargs = { "has_packed_sequence": has_packed_sequence, "peft_config": cfg_peft, - "device_mesh": device_mesh, - "moe_mesh": moe_mesh, - "distributed_config": distributed_config, - "pipeline_config": pipeline_config, "sdpa_method": sdpa_method, } + if distributed_setup is not None: + kwargs["distributed_setup"] = distributed_setup + elif device_mesh is not None: + kwargs["device_mesh"] = device_mesh if cfg_qat is not None and cfg_qat.get("enabled", False): if cfg_peft is not None: @@ -235,19 +227,6 @@ def build_model( if quantizer_attr is not None: kwargs["qat_config"] = quantizer_attr.instantiate() - if cfg_moe is not None: - from nemo_automodel.components.moe.config import MoEParallelizerConfig - - if isinstance(cfg_moe, MoEParallelizerConfig): - kwargs["moe_config"] = cfg_moe - else: - moe_dict = cfg_moe.to_dict() if hasattr(cfg_moe, "to_dict") else dict(cfg_moe) - # activation_checkpointing is handled separately; strip config keys - moe_dict.pop("activation_checkpointing", None) - moe_dict.pop("_target_", None) - kwargs["moe_config"] = MoEParallelizerConfig(**moe_dict) - kwargs["activation_checkpointing"] = activation_checkpointing - if cfg_fp8 is not None: kwargs["fp8_config"] = build_fp8_config(cfg_fp8) if cfg_compile is not None: @@ -280,13 +259,15 @@ def build_model( # exactly as from_pretrained/from_config do internally. model = cfg_model.instantiate() - mesh = MeshContext.from_meshes(device_mesh, moe_mesh) + setup = distributed_setup or DistributedSetup(mesh_context=MeshContext()) + mesh = setup.mesh_context + pipeline_config = setup.pipeline_config model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure( - distributed_config=distributed_config, + distributed_config=setup.strategy_config, pipeline_config=pipeline_config, qat_config=kwargs.get("qat_config"), - moe_config=kwargs.get("moe_config"), - activation_checkpointing=kwargs.get("activation_checkpointing", False), + moe_parallel_config=setup.moe_parallel_config, + activation_checkpointing=setup.activation_checkpointing, device=torch.device("cuda", torch.cuda.current_device()), mesh=mesh, ) @@ -899,12 +880,19 @@ def setup(self): # Enable NVTX patching only when explicitly requested in config self.enable_nvtx = bool(self.cfg.get("nvtx", False)) - self.dist_setup = setup_distributed(self.cfg, world_size=self.dist_env.world_size) - self.distributed_config = self.dist_setup.strategy_config - self.device_mesh = self.dist_setup.device_mesh - self.moe_mesh = self.dist_setup.moe_mesh - self.pp_enabled = self.dist_setup.pp_enabled - self.pipeline_config = self.dist_setup.pipeline_config + ( + self.distributed_setup, + self.mesh_context, + self.distributed_config, + self.device_mesh, + self.moe_mesh, + self.pp_enabled, + self.pipeline_config, + self.moe_parallel_config, + self.activation_checkpointing, + ) = self._distributed_setup_attributes( + create_distributed_setup_from_config(self.cfg, world_size=self.dist_env.world_size) + ) if self.dist_env.is_main and hasattr(self.cfg, "wandb"): suppress_wandb_log_messages() @@ -933,13 +921,13 @@ def setup(self): pp_batch_size = self.cfg.step_scheduler.local_batch_size pp_microbatch_size = self.cfg.get("distributed.pipeline.pp_microbatch_size", 1) - assert pp_batch_size // pp_microbatch_size >= self.dist_setup.pp_size, ( - f"pp_batch_size {pp_batch_size} // pp_microbatch_size {pp_microbatch_size} must be >= pp_size {self.dist_setup.pp_size}" + assert pp_batch_size // pp_microbatch_size >= self.mesh_context.pp_size, ( + f"pp_batch_size {pp_batch_size} // pp_microbatch_size {pp_microbatch_size} must be >= pp_size {self.mesh_context.pp_size}" ) # THD override logic if ( - self.dist_setup.cp_size > 1 + self.mesh_context.cp_size > 1 and _uses_te_dot_product_attention(self.cfg.model) and _uses_thd_collater(self.cfg.dataloader) ): @@ -998,8 +986,8 @@ def setup(self): ) # Disable fused RoPE when context parallelism is enabled (cp > 1) - if self.dist_setup.cp_size > 1 and self.cfg.get("model.backend.rope_fusion", False): - logging.info("Disabling rope_fusion because cp_size=%d > 1", self.dist_setup.cp_size) + if self.mesh_context.cp_size > 1 and self.cfg.get("model.backend.rope_fusion", False): + logging.info("Disabling rope_fusion because cp_size=%d > 1", self.mesh_context.cp_size) self.cfg.model.backend.rope_fusion = False model = build_model( @@ -1010,13 +998,8 @@ def setup(self): cfg_fp8=self.cfg.get("fp8", None), cfg_compile=self.cfg.get("compile", None), cfg_quantization=self.cfg.get("quantization", None), - device_mesh=self.device_mesh, - moe_mesh=self.moe_mesh, - distributed_config=self.distributed_config, - pipeline_config=self.pipeline_config, + distributed_setup=self.distributed_setup, cfg_qat=self.cfg.get("qat", None), - cfg_moe=self.dist_setup.moe_config, - activation_checkpointing=self.dist_setup.activation_checkpointing, sdpa_method=self.cfg.get("sdpa_method", None), ) self.optimizer = build_optimizer(model, self.cfg.optimizer, self.distributed_config, self.device_mesh) @@ -1050,11 +1033,11 @@ def setup(self): self._configure_pipeline_loss_fn() _packed_seq_size = self.cfg.get("packed_sequence.packed_sequence_size", 0) - if self.dist_setup.cp_size > 1 and _packed_seq_size > 0: + if self.mesh_context.cp_size > 1 and _packed_seq_size > 0: _m = self.model_parts[0] if hasattr(_m, "supports") and not _m.supports_cp_with_sequence_packing: raise ValueError( - f"Context parallelism (cp_size={self.dist_setup.cp_size}) with packed sequences " + f"Context parallelism (cp_size={self.mesh_context.cp_size}) with packed sequences " f"is not supported for {type(_m).__name__}.\n" f"Either disable sequence packing:\n" f" packed_sequence:\n" diff --git a/nemo_automodel/recipes/llm/train_seq_cls.py b/nemo_automodel/recipes/llm/train_seq_cls.py index aafef3da5e..fca441840d 100644 --- a/nemo_automodel/recipes/llm/train_seq_cls.py +++ b/nemo_automodel/recipes/llm/train_seq_cls.py @@ -31,7 +31,7 @@ from nemo_automodel.components.training.utils import clip_grad_norm from nemo_automodel.components.utils.flops_utils import calculate_mfu from nemo_automodel.components.utils.model_utils import filter_forward_kwargs -from nemo_automodel.recipes._dist_setup import setup_distributed +from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config from nemo_automodel.recipes.base_recipe import BaseRecipe from nemo_automodel.recipes.llm.train_ft import ( _get_model_name, @@ -60,12 +60,19 @@ def setup(self): apply_cache_compatibility_patches() self.rng = StatefulRNG(seed=self.cfg.get("seed", 42), ranked=True) - self.dist_setup = setup_distributed(self.cfg, world_size=self.dist_env.world_size) - self.distributed_config = self.dist_setup.strategy_config - self.device_mesh = self.dist_setup.device_mesh - self.moe_mesh = self.dist_setup.moe_mesh - self.pp_enabled = self.dist_setup.pp_enabled - self.pipeline_config = self.dist_setup.pipeline_config + ( + self.distributed_setup, + self.mesh_context, + self.distributed_config, + self.device_mesh, + self.moe_mesh, + self.pp_enabled, + self.pipeline_config, + self.moe_parallel_config, + self.activation_checkpointing, + ) = self._distributed_setup_attributes( + create_distributed_setup_from_config(self.cfg, world_size=self.dist_env.world_size) + ) if self.dist_env.is_main and hasattr(self.cfg, "wandb"): suppress_wandb_log_messages() @@ -115,9 +122,7 @@ def setup(self): has_packed_sequence=use_hf_fa2, cfg_compile=self.cfg.get("compile", None), cfg_quantization=self.cfg.get("quantization", None), - device_mesh=self.device_mesh, - moe_mesh=self.moe_mesh, - distributed_config=self.distributed_config, + distributed_setup=self.distributed_setup, unfreeze_modules=["classifier"] if self.peft_config is not None else None, ) self.optimizer = build_optimizer(model, self.cfg.optimizer, self.distributed_config, self.device_mesh) diff --git a/nemo_automodel/recipes/retrieval/train_bi_encoder.py b/nemo_automodel/recipes/retrieval/train_bi_encoder.py index 5d5181e6a5..00ef3d54a2 100644 --- a/nemo_automodel/recipes/retrieval/train_bi_encoder.py +++ b/nemo_automodel/recipes/retrieval/train_bi_encoder.py @@ -34,7 +34,7 @@ from nemo_automodel.components.loggers.wandb_utils import suppress_wandb_log_messages from nemo_automodel.components.training.rng import ScopedRNG, StatefulRNG from nemo_automodel.components.training.utils import scale_grads_and_clip_grad_norm -from nemo_automodel.recipes._dist_setup import setup_distributed +from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config from nemo_automodel.recipes.base_recipe import BaseRecipe from nemo_automodel.recipes.llm.train_ft import ( build_checkpoint_config, @@ -153,12 +153,19 @@ def setup(self): apply_te_patches() self.rng = StatefulRNG(seed=self.cfg.get("seed", 42), ranked=True) - self.dist_setup = setup_distributed(self.cfg, world_size=self.dist_env.world_size) - self.distributed_config = self.dist_setup.strategy_config - self.device_mesh = self.dist_setup.device_mesh - self.moe_mesh = self.dist_setup.moe_mesh - self.pp_enabled = self.dist_setup.pp_enabled - self.pipeline_config = self.dist_setup.pipeline_config + ( + self.distributed_setup, + self.mesh_context, + self.distributed_config, + self.device_mesh, + self.moe_mesh, + self.pp_enabled, + self.pipeline_config, + self.moe_parallel_config, + self.activation_checkpointing, + ) = self._distributed_setup_attributes( + create_distributed_setup_from_config(self.cfg, world_size=self.dist_env.world_size) + ) if self.pp_enabled: raise NotImplementedError("Encoder does not support pipeline parallelism") @@ -198,9 +205,7 @@ def setup(self): with ScopedRNG(seed=self.cfg.get("seed", 42), ranked=True): model = self.cfg.model.instantiate( - device_mesh=self.device_mesh, - moe_mesh=self.moe_mesh, - distributed_config=self.distributed_config, + distributed_setup=self.distributed_setup, peft_config=self.peft_config, ) diff --git a/nemo_automodel/recipes/vlm/finetune.py b/nemo_automodel/recipes/vlm/finetune.py index afab19148a..473babb081 100644 --- a/nemo_automodel/recipes/vlm/finetune.py +++ b/nemo_automodel/recipes/vlm/finetune.py @@ -55,7 +55,7 @@ from nemo_automodel.components.datasets.llm.formatting_utils import _resolve_chat_template from nemo_automodel.components.datasets.vlm.collate_fns import COLLATE_FNS from nemo_automodel.components.datasets.vlm.pp_media import stage_vlm_media_for_pp, wrap_vlm_collate_for_pp -from nemo_automodel.components.distributed.config import MegatronFSDPConfig +from nemo_automodel.components.distributed.config import DistributedSetup, MegatronFSDPConfig from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx from nemo_automodel.components.distributed.init_utils import initialize_distributed from nemo_automodel.components.distributed.pipelining import AutoPipeline @@ -84,7 +84,7 @@ ) from nemo_automodel.components.utils.compile_utils import build_compile_config from nemo_automodel.components.utils.model_utils import VLM_INPUT_KEYS, _supports_logits_to_keep, filter_forward_kwargs -from nemo_automodel.recipes._dist_setup import setup_distributed +from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config from nemo_automodel.recipes.base_recipe import BaseRecipe if TYPE_CHECKING: @@ -117,12 +117,7 @@ def build_model( seed, cfg_fp8=None, cfg_compile=None, - device_mesh=None, - moe_mesh=None, - distributed_config=None, - pipeline_config=None, - cfg_moe=None, - activation_checkpointing=False, + distributed_setup: DistributedSetup | None = None, ) -> tuple[nn.Module | AutoPipeline, list["Optimizer"]]: # noqa: F821 """Build and initialize a model for VLM. @@ -133,25 +128,10 @@ def build_model( # Build infrastructure kwargs kwargs = { "peft_config": cfg_peft, - "device_mesh": device_mesh, - "moe_mesh": moe_mesh, - "distributed_config": distributed_config, - "pipeline_config": pipeline_config, "freeze_config": cfg_freeze.to_dict() if cfg_freeze is not None else None, } - - if cfg_moe is not None: - from nemo_automodel.components.moe.config import MoEParallelizerConfig - - if isinstance(cfg_moe, MoEParallelizerConfig): - kwargs["moe_config"] = cfg_moe - else: - moe_dict = cfg_moe.to_dict() if hasattr(cfg_moe, "to_dict") else dict(cfg_moe) - # activation_checkpointing is handled separately; strip config keys - moe_dict.pop("activation_checkpointing", None) - moe_dict.pop("_target_", None) - kwargs["moe_config"] = MoEParallelizerConfig(**moe_dict) - kwargs["activation_checkpointing"] = activation_checkpointing + if distributed_setup is not None: + kwargs["distributed_setup"] = distributed_setup if cfg_fp8 is not None: fp8_config = build_fp8_config(cfg_fp8) @@ -680,12 +660,19 @@ def setup(self): # Set up the stateful random number generator self.rng = StatefulRNG(seed=self.cfg.get("seed", 42), ranked=True) - self.dist_setup = setup_distributed(self.cfg, world_size=self.dist_env.world_size) - self.distributed_config = self.dist_setup.strategy_config - self.device_mesh = self.dist_setup.device_mesh - self.moe_mesh = self.dist_setup.moe_mesh - self.pp_enabled = self.dist_setup.pp_enabled - self.pipeline_config = self.dist_setup.pipeline_config + ( + self.distributed_setup, + self.mesh_context, + self.distributed_config, + self.device_mesh, + self.moe_mesh, + self.pp_enabled, + self.pipeline_config, + self.moe_parallel_config, + self.activation_checkpointing, + ) = self._distributed_setup_attributes( + create_distributed_setup_from_config(self.cfg, world_size=self.dist_env.world_size) + ) if self.dist_env.is_main and hasattr(self.cfg, "wandb"): suppress_wandb_log_messages() @@ -708,8 +695,8 @@ def setup(self): pp_batch_size = self.cfg.step_scheduler.local_batch_size pp_microbatch_size = self.cfg.get("distributed.pipeline.pp_microbatch_size", 1) - assert pp_batch_size // pp_microbatch_size >= self.dist_setup.pp_size, ( - f"pp_batch_size {pp_batch_size} // pp_microbatch_size {pp_microbatch_size} must be >= pp_size {self.dist_setup.pp_size}" + assert pp_batch_size // pp_microbatch_size >= self.mesh_context.pp_size, ( + f"pp_batch_size {pp_batch_size} // pp_microbatch_size {pp_microbatch_size} must be >= pp_size {self.mesh_context.pp_size}" ) assert not isinstance(self.distributed_config, MegatronFSDPConfig), ( @@ -753,8 +740,8 @@ def setup(self): ) # Disable fused RoPE when context parallelism is enabled (cp > 1) - if self.dist_setup.cp_size > 1 and self.cfg.get("model.backend.rope_fusion", False): - logging.info("Disabling rope_fusion because cp_size=%d > 1", self.dist_setup.cp_size) + if self.mesh_context.cp_size > 1 and self.cfg.get("model.backend.rope_fusion", False): + logging.info("Disabling rope_fusion because cp_size=%d > 1", self.mesh_context.cp_size) self.cfg.model.backend.rope_fusion = False model = build_model( @@ -764,12 +751,7 @@ def setup(self): seed=self.cfg.get("seed", 42), cfg_fp8=self.cfg.get("fp8", None), cfg_compile=self.cfg.get("compile", None), - device_mesh=self.device_mesh, - moe_mesh=self.moe_mesh, - distributed_config=self.distributed_config, - pipeline_config=self.pipeline_config, - cfg_moe=self.dist_setup.moe_config, - activation_checkpointing=self.dist_setup.activation_checkpointing, + distributed_setup=self.distributed_setup, ) self.optimizer = build_optimizer(model, self.cfg.optimizer, self.distributed_config, self.device_mesh) @@ -791,7 +773,7 @@ def setup(self): pp_n_microbatches = None pp_cp_preembed = ( self.pp_enabled - and self.dist_setup.cp_size > 1 + and self.mesh_context.cp_size > 1 and hasattr(self.model_parts[0], "prepare_model_inputs_for_cp") ) if self.pp_enabled and not pp_cp_preembed: diff --git a/nemo_automodel/recipes/vlm/kd.py b/nemo_automodel/recipes/vlm/kd.py index e483facad6..439ffdc86a 100644 --- a/nemo_automodel/recipes/vlm/kd.py +++ b/nemo_automodel/recipes/vlm/kd.py @@ -51,6 +51,7 @@ from nemo_automodel._transformers.auto_tokenizer import NeMoAutoTokenizer from nemo_automodel.components.config._arg_parser import parse_args_and_load_config +from nemo_automodel.components.distributed.config import DistributedSetup from nemo_automodel.components.distributed.cp_utils import make_cp_batch_and_ctx from nemo_automodel.components.distributed.utils import get_sync_ctx from nemo_automodel.components.loggers.metric_logger import MetricsSample @@ -82,9 +83,7 @@ def _build_teacher_model( cfg_teacher, cfg_freeze, seed: int, - device_mesh=None, - moe_mesh=None, - distributed_config=None, + distributed_setup: DistributedSetup | None = None, device=None, ) -> torch.nn.Module: """Build and initialize the teacher VLM for knowledge distillation. @@ -96,9 +95,7 @@ def _build_teacher_model( cfg_teacher: Configuration for teacher model instantiation. cfg_freeze: Freeze configuration for the teacher model. seed: Random seed for reproducibility. - device_mesh: Device mesh for distributed training. - moe_mesh: MOE mesh for expert parallelism. - distributed_config: Strategy-specific distributed config. + distributed_setup: Resolved distributed topology and policy object. device: Device to place the teacher model on. Returns: @@ -114,12 +111,7 @@ def _build_teacher_model( seed=seed, cfg_fp8=None, cfg_compile=None, - device_mesh=device_mesh, - moe_mesh=moe_mesh, - distributed_config=distributed_config, - pipeline_config=None, - cfg_moe=None, - activation_checkpointing=False, + distributed_setup=distributed_setup, ) if device is not None: @@ -207,9 +199,7 @@ def setup(self): cfg_teacher=self.cfg.get("teacher_model", None), cfg_freeze=self.cfg.get("teacher_freeze_config", None), seed=self.cfg.get("seed", 42), - device_mesh=self.device_mesh, - moe_mesh=self.moe_mesh, - distributed_config=self.distributed_config, + distributed_setup=getattr(self, "distributed_setup", None), device=teacher_device, ) diff --git a/skills/distributed-training/SKILL.md b/skills/distributed-training/SKILL.md index 23b529a0fc..da9bc7013b 100644 --- a/skills/distributed-training/SKILL.md +++ b/skills/distributed-training/SKILL.md @@ -32,7 +32,7 @@ Decision tree: ## YAML Config Structure The `distributed` section in the recipe YAML maps directly to -`parse_distributed_section()` in `recipes/_dist_setup.py`: +`parse_distributed_section()` in `recipes/_dist_utils.py`: ```yaml distributed: @@ -69,11 +69,12 @@ dp_size = world_size / (tp_size * pp_size * cp_size) ## Infrastructure Flow ``` -YAML distributed section - -> parse_distributed_section() [recipes/_dist_setup.py] - -> setup_distributed() [recipes/_dist_setup.py] - -> create_device_mesh() [components/distributed/device_mesh.py] - -> MeshContext(...) [components/distributed/mesh.py] +initialize_distributed() [components/distributed/init_utils.py] + -> initializes torch.distributed process group and returns DistInfo +YAML distributed section + DistInfo.world_size + -> parse_distributed_section() [recipes/_dist_utils.py] + -> create_distributed_setup_from_config() [recipes/_dist_utils.py] + -> DistributedSetup.build() [components/distributed/config.py] -> instantiate_infrastructure() [_transformers/infrastructure.py] -> _instantiate_distributed() -> FSDP2Manager / MegatronFSDPManager / DDPManager -> _instantiate_pipeline() -> AutoPipeline (if pp_size > 1) @@ -152,8 +153,9 @@ distributed: activation_checkpointing: true ``` -This is forwarded to the strategy config for non-EP models, or read from -`MeshContext.activation_checkpointing` for EP models. +This is a model-build/training behavior flag, not mesh topology. Dense +strategies read it from the strategy config; EP/MoE paths pass the recipe-level +flag directly into model infrastructure. ### Gradient Sync Deferral @@ -372,25 +374,17 @@ scaling dimension: When not using YAML recipes, configure distributed training via Python: ```python -from nemo_automodel.components.distributed.config import FSDP2Config -from nemo_automodel.components.distributed.device_mesh import create_device_mesh -from nemo_automodel.components.distributed.mesh import MeshContext +from nemo_automodel.components.distributed import FSDP2Config, create_mesh_context, initialize_distributed from nemo_automodel._transformers.infrastructure import instantiate_infrastructure -# 1. Create strategy config +dist_env = initialize_distributed("nccl") config = FSDP2Config(sequence_parallel=True, activation_checkpointing=True) -# 2. Create device mesh -device_mesh, moe_mesh = create_device_mesh( - config, tp_size=2, pp_size=1, cp_size=1, ep_size=1, world_size=8, +mesh = create_mesh_context( + config, tp_size=2, pp_size=1, cp_size=1, ep_size=1, world_size=dist_env.world_size, ) -# 3. Build MeshContext -mesh = MeshContext.from_meshes( - device_mesh, moe_mesh, strategy_config=config, activation_checkpointing=True, -) - -# 4. Instantiate infrastructure +# 3. Instantiate infrastructure model_wrapper, autopipeline, parallelize_fn, qat_quantizer = instantiate_infrastructure( distributed_config=config, mesh=mesh, ) @@ -428,13 +422,13 @@ components/distributed/mesh.py MeshAxisName -- PP, DP, DP_REPLICATE, DP_SHARD, DP_SHARD_CP, DP_CP, CP, TP, EP, EP_SHARD ``` -Device mesh creation: +Mesh context and raw mesh creation: ``` components/distributed/device_mesh.py - create_device_mesh() -- routes to FSDP2/MegatronFSDP/DDP mesh creation + create_mesh_context() -- builds MeshContext from strategy + parallelism + _create_device_meshes() -- routes to FSDP2/MegatronFSDP/DDP raw mesh creation _create_fsdp2_device_mesh() -- shape (pp, dp_replicate, dp_shard, cp, tp) + flattened submeshes - _create_moe_mesh() -- shape (pp, ep_shard, ep) ``` Distributed managers: @@ -477,9 +471,9 @@ _transformers/infrastructure.py YAML parsing: ``` -recipes/_dist_setup.py +recipes/_dist_utils.py parse_distributed_section() -- YAML dict -> typed configs + sizes - setup_distributed() -- full entry-point: parse + create meshes + MeshContext + create_distributed_setup_from_config() -- recipe adapter: parse + create DistributedSetup; does not init process group ``` MoE config: diff --git a/tests/functional_tests/hf_peft/deepseek_moe_lora_small_for_test.yaml b/tests/functional_tests/hf_peft/deepseek_moe_lora_small_for_test.yaml index 9ca57b7225..a7131f32a1 100644 --- a/tests/functional_tests/hf_peft/deepseek_moe_lora_small_for_test.yaml +++ b/tests/functional_tests/hf_peft/deepseek_moe_lora_small_for_test.yaml @@ -50,8 +50,7 @@ model: qk_rope_head_dim: 64 v_head_dim: 64 qk_nope_head_dim: 64 - moe_config: - _target_: nemo_automodel.components.moe.layers.MoEConfig + moe_overrides: n_routed_experts: 8 n_shared_experts: 1 n_activated_experts: 2 diff --git a/tests/functional_tests/training/test_megatron_data_sharding.py b/tests/functional_tests/training/test_megatron_data_sharding.py index 2fa69942ce..af7cf1d0af 100644 --- a/tests/functional_tests/training/test_megatron_data_sharding.py +++ b/tests/functional_tests/training/test_megatron_data_sharding.py @@ -18,24 +18,26 @@ import torch.distributed as dist from nemo_automodel.components.config._arg_parser import parse_args_and_load_config -from nemo_automodel.recipes._dist_setup import setup_distributed +from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config from nemo_automodel.recipes.llm.train_ft import build_dataloader, build_distributed """ This test is to make sure that JSONL dataset can be checkpointed and loaded correctly. """ + def gather_helper(input_tensor): tensor_list = [torch.zeros_like(input_tensor) for _ in range(2)] dist.all_gather(tensor_list, input_tensor) return tensor_list + def test_megatron_data_sharding(): cfg_path = Path(__file__).parents[4] / "examples" / "llm_pretrain" / "megatron_pretrain_gpt2.yaml" cfg = parse_args_and_load_config(cfg_path) dist_env = build_distributed(cfg.get("dist_env", {})) - dist_setup = setup_distributed(cfg, world_size=dist_env.world_size) - device_mesh = dist_setup.device_mesh + mesh_context = create_distributed_setup_from_config(cfg, world_size=dist_env.world_size).mesh_context + device_mesh = mesh_context.device_mesh dp_rank = device_mesh["dp"].get_local_rank() dp_world_size = device_mesh["dp"].size() tp_world_size = device_mesh["tp"].size() @@ -64,7 +66,9 @@ def test_megatron_data_sharding(): batch_to_test = {k: v.to(dist.get_rank()) for k, v in batch_to_test.items()} # ensure that labels are inputs left shifted by 1 - assert torch.all(batch_to_test["labels"][:, :-1] == batch_to_test["input_ids"][:, 1:]), "Labels are not inputs left shifted by 1" + assert torch.all(batch_to_test["labels"][:, :-1] == batch_to_test["input_ids"][:, 1:]), ( + "Labels are not inputs left shifted by 1" + ) dist.barrier() del dataset diff --git a/tests/functional_tests/training/test_megatron_dataset_checkpoint.py b/tests/functional_tests/training/test_megatron_dataset_checkpoint.py index 27a6a70b9a..a0fc445226 100644 --- a/tests/functional_tests/training/test_megatron_dataset_checkpoint.py +++ b/tests/functional_tests/training/test_megatron_dataset_checkpoint.py @@ -20,19 +20,20 @@ from nemo_automodel.components.checkpoint.checkpointing import Checkpointer, CheckpointingConfig from nemo_automodel.components.config._arg_parser import parse_args_and_load_config -from nemo_automodel.recipes._dist_setup import setup_distributed +from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config from nemo_automodel.recipes.llm.train_ft import build_dataloader, build_distributed """ This test is to make sure that JSONL dataset can be checkpointed and loaded correctly. """ + def test_megatron_dataset_checkpointing(): cfg_path = Path(__file__).parents[4] / "examples" / "llm_pretrain" / "megatron_pretrain_gpt2.yaml" cfg = parse_args_and_load_config(cfg_path) dist_env = build_distributed(cfg.get("dist_env", {})) - dist_setup = setup_distributed(cfg, world_size=dist_env.world_size) - device_mesh = dist_setup.device_mesh + mesh_context = create_distributed_setup_from_config(cfg, world_size=dist_env.world_size).mesh_context + device_mesh = mesh_context.device_mesh dp_rank = device_mesh["dp"].get_local_rank() dp_world_size = device_mesh["dp"].size() tp_rank = device_mesh["tp"].get_local_rank() @@ -113,14 +114,18 @@ def test_megatron_dataset_checkpointing(): initial_batch = next(iter(dataset)) for k in ["input_ids", "labels"]: - assert torch.any(initial_batch[k] != expected_batch[k]), f"Initial batch key {k, initial_batch[k]} should not be equal to expected batch key {k, expected_batch[k]}" + assert torch.any(initial_batch[k] != expected_batch[k]), ( + f"Initial batch key {k, initial_batch[k]} should not be equal to expected batch key {k, expected_batch[k]}" + ) # load checkpoint checkpointer.load_on_dp_ranks(dataset, "dataloader", cfg.checkpoint.checkpoint_dir) for i, batch in enumerate(dataset): for k in batch.keys(): - assert torch.all(batch[k] == expected_batch[k]), f"Batch key {k, batch[k]} is not equal to expected batch key {k, expected_batch[k]}" + assert torch.all(batch[k] == expected_batch[k]), ( + f"Batch key {k, batch[k]} is not equal to expected batch key {k, expected_batch[k]}" + ) break torch.distributed.barrier(device_mesh["dp"].get_group()) diff --git a/tests/unit_tests/_diffusers/test_auto_diffusion_pipeline.py b/tests/unit_tests/_diffusers/test_auto_diffusion_pipeline.py index 565ea419b1..6deb006004 100644 --- a/tests/unit_tests/_diffusers/test_auto_diffusion_pipeline.py +++ b/tests/unit_tests/_diffusers/test_auto_diffusion_pipeline.py @@ -13,11 +13,15 @@ # limitations under the License. import logging +from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest import torch +from nemo_automodel.components.distributed.config import DDPConfig, FSDP2Config +from nemo_automodel.components.distributed.mesh import ParallelismSizes + # Check if diffusers can be imported properly (may fail due to peft/transformers incompatibility) try: DIFFUSERS_AVAILABLE = True @@ -223,61 +227,82 @@ def test_pipeline_spec_validate_for_from_config_passes_with_cls(): def test_create_parallel_manager_fsdp2_default(): from nemo_automodel._diffusers.auto_diffusion_pipeline import _create_parallel_manager + mock_config = Mock() mock_mesh = Mock() mock_moe_mesh = Mock() + mock_setup = SimpleNamespace( + strategy_config=mock_config, + mesh_context=SimpleNamespace(device_mesh=mock_mesh, moe_mesh=mock_moe_mesh), + ) with ( patch(f"{MODULE_PATH}.FSDP2Manager") as MockFSDP2, - patch(f"{MODULE_PATH}.FSDP2Config") as MockConfig, - patch(f"{MODULE_PATH}.create_device_mesh", return_value=(mock_mesh, mock_moe_mesh)), + patch(f"{MODULE_PATH}.DistributedSetup.build", return_value=mock_setup) as MockBuildSetup, ): MockFSDP2.return_value = Mock() manager = _create_parallel_manager({"world_size": 1}) - MockConfig.assert_called_once() - MockFSDP2.assert_called_once_with(MockConfig.return_value, device_mesh=mock_mesh, moe_mesh=mock_moe_mesh) + MockBuildSetup.assert_called_once() + assert MockBuildSetup.call_args.kwargs["world_size"] == 1 + MockFSDP2.assert_called_once_with(mock_config, device_mesh=mock_mesh, moe_mesh=mock_moe_mesh) assert manager is MockFSDP2.return_value def test_create_parallel_manager_ddp(): from nemo_automodel._diffusers.auto_diffusion_pipeline import _create_parallel_manager + mock_config = Mock() + mock_setup = SimpleNamespace(strategy_config=mock_config) with ( patch(f"{MODULE_PATH}.DDPManager") as MockDDP, - patch(f"{MODULE_PATH}.DDPConfig") as MockConfig, + patch(f"{MODULE_PATH}.DistributedSetup.build", return_value=mock_setup) as MockBuildSetup, ): MockDDP.return_value = Mock() manager = _create_parallel_manager({"_manager_type": "ddp", "some_arg": "value"}) - MockConfig.assert_called_once_with(activation_checkpointing=False, backend="nccl") - MockDDP.assert_called_once_with(MockConfig.return_value) + MockBuildSetup.assert_called_once() + build_kwargs = MockBuildSetup.call_args.kwargs + assert isinstance(build_kwargs["strategy"], DDPConfig) + assert build_kwargs["parallelism_sizes"] == ParallelismSizes() + assert build_kwargs["world_size"] is None + assert build_kwargs["activation_checkpointing"] is False + assert not hasattr(build_kwargs["strategy"], "backend") + MockDDP.assert_called_once_with(mock_config) assert manager is MockDDP.return_value def test_create_parallel_manager_explicit_fsdp2(): from nemo_automodel._diffusers.auto_diffusion_pipeline import _create_parallel_manager + mock_config = Mock() mock_mesh = Mock() mock_moe_mesh = Mock() + mock_setup = SimpleNamespace( + strategy_config=mock_config, + mesh_context=SimpleNamespace(device_mesh=mock_mesh, moe_mesh=mock_moe_mesh), + ) with ( patch(f"{MODULE_PATH}.FSDP2Manager") as MockFSDP2, - patch(f"{MODULE_PATH}.FSDP2Config") as MockConfig, - patch(f"{MODULE_PATH}.create_device_mesh", return_value=(mock_mesh, mock_moe_mesh)), + patch(f"{MODULE_PATH}.DistributedSetup.build", return_value=mock_setup), ): MockFSDP2.return_value = Mock() _create_parallel_manager({"_manager_type": "fsdp2", "world_size": 1}) - MockFSDP2.assert_called_once_with(MockConfig.return_value, device_mesh=mock_mesh, moe_mesh=mock_moe_mesh) + MockFSDP2.assert_called_once_with(mock_config, device_mesh=mock_mesh, moe_mesh=mock_moe_mesh) def test_create_parallel_manager_fsdp2_passes_perf_options(): from nemo_automodel._diffusers.auto_diffusion_pipeline import _create_parallel_manager + mock_config = Mock() mock_mesh = Mock() mock_moe_mesh = Mock() + mock_setup = SimpleNamespace( + strategy_config=mock_config, + mesh_context=SimpleNamespace(device_mesh=mock_mesh, moe_mesh=mock_moe_mesh), + ) with ( patch(f"{MODULE_PATH}.FSDP2Manager") as MockFSDP2, - patch(f"{MODULE_PATH}.FSDP2Config") as MockConfig, - patch(f"{MODULE_PATH}.create_device_mesh", return_value=(mock_mesh, mock_moe_mesh)), + patch(f"{MODULE_PATH}.DistributedSetup.build", return_value=mock_setup) as MockBuildSetup, ): MockFSDP2.return_value = Mock() _create_parallel_manager( @@ -296,16 +321,17 @@ def test_create_parallel_manager_fsdp2_passes_perf_options(): } ) - config_kwargs = MockConfig.call_args.kwargs - assert config_kwargs["sequence_parallel"] is True - assert config_kwargs["tp_plan"] == {"layer": "colwise"} - assert config_kwargs["patch_is_packed_sequence"] is True - assert config_kwargs["defer_fsdp_grad_sync"] is False - assert config_kwargs["enable_async_tensor_parallel"] is True - assert config_kwargs["enable_compile"] is True - assert config_kwargs["enable_fsdp2_prefetch"] is True - assert config_kwargs["fsdp2_backward_prefetch_depth"] == 4 - assert config_kwargs["fsdp2_forward_prefetch_depth"] == 3 + strategy_config = MockBuildSetup.call_args.kwargs["strategy"] + assert isinstance(strategy_config, FSDP2Config) + assert strategy_config.sequence_parallel is True + assert strategy_config.tp_plan == {"layer": "colwise"} + assert strategy_config.patch_is_packed_sequence is True + assert strategy_config.defer_fsdp_grad_sync is False + assert strategy_config.enable_async_tensor_parallel is True + assert strategy_config.enable_compile is True + assert strategy_config.enable_fsdp2_prefetch is True + assert strategy_config.fsdp2_backward_prefetch_depth == 4 + assert strategy_config.fsdp2_forward_prefetch_depth == 3 def test_create_parallel_manager_unknown_type_raises(): @@ -315,6 +341,13 @@ def test_create_parallel_manager_unknown_type_raises(): _create_parallel_manager({"_manager_type": "unknown"}) +def test_create_parallel_manager_rejects_backend_option(): + from nemo_automodel._diffusers.auto_diffusion_pipeline import _create_parallel_manager + + with pytest.raises(ValueError, match="backend is not a parallel manager option"): + _create_parallel_manager({"_manager_type": "ddp", "backend": "gloo"}) + + def test_create_parallel_manager_does_not_mutate_input(): from nemo_automodel._diffusers.auto_diffusion_pipeline import _create_parallel_manager @@ -664,7 +697,7 @@ def import_class(name): def test_import_diffusers_class_success(): from nemo_automodel._diffusers.auto_diffusion_pipeline import _import_diffusers_class - with patch("diffusers.SomeClass", create=True, new="sentinel"): + with patch.dict("sys.modules", {"diffusers": SimpleNamespace(SomeClass="sentinel")}): result = _import_diffusers_class("SomeClass") assert result == "sentinel" @@ -672,7 +705,10 @@ def test_import_diffusers_class_success(): def test_import_diffusers_class_missing_raises(): from nemo_automodel._diffusers.auto_diffusion_pipeline import _import_diffusers_class - with pytest.raises(ImportError, match="not found in diffusers"): + with ( + patch.dict("sys.modules", {"diffusers": SimpleNamespace()}), + pytest.raises(ImportError, match="not found in diffusers"), + ): _import_diffusers_class("NonExistentClassName12345") diff --git a/tests/unit_tests/_transformers/test_auto_model.py b/tests/unit_tests/_transformers/test_auto_model.py index 4be75101a4..75b41e5efb 100644 --- a/tests/unit_tests/_transformers/test_auto_model.py +++ b/tests/unit_tests/_transformers/test_auto_model.py @@ -30,6 +30,7 @@ _init_model, _patch_attention, _patch_remote_code_compat, + _resolve_distributed_setup, ) from nemo_automodel._transformers.infrastructure import _apply_peft_and_lower_precision from nemo_automodel._transformers.model_init import ( @@ -41,9 +42,170 @@ no_hf_meta_device, ) from nemo_automodel.components.checkpoint.utils import _get_checkpoint_tensor_dtypes +from nemo_automodel.components.distributed.config import DistributedSetup, FSDP2Config, MoEParallelizerConfig +from nemo_automodel.components.distributed.mesh import MeshAxisName, MeshContext from nemo_automodel.components.models.common.hf_checkpointing_mixin import HFCheckpointingMixin +class _FakeMesh: + def __init__(self, sizes): + self._sizes = sizes + self.mesh_dim_names = tuple(sizes) + + def __getitem__(self, axis): + return types.SimpleNamespace(size=lambda: self._sizes[axis]) + + +class TestResolveMeshContext: + def test_device_mesh_input_builds_topology_only_setup(self): + device_mesh = _FakeMesh({MeshAxisName.DP_SHARD: 1, MeshAxisName.CP: 1, MeshAxisName.TP: 1}) + + setup = _resolve_distributed_setup( + distributed_setup=None, + device_mesh=device_mesh, + ) + + assert setup.mesh_context.device_mesh is device_mesh + assert setup.mesh_context.moe_mesh is None + assert setup.strategy_config is None + assert setup.pipeline_config is None + assert setup.moe_parallel_config is None + assert setup.activation_checkpointing is False + + def test_device_mesh_rejects_mesh_context(self): + mesh_context = MeshContext() + + with pytest.raises(TypeError, match="DeviceMesh"): + _resolve_distributed_setup( + distributed_setup=None, + device_mesh=mesh_context, + ) + + def test_distributed_setup_and_device_mesh_are_mutually_exclusive(self): + device_mesh = _FakeMesh({MeshAxisName.DP_SHARD: 1, MeshAxisName.CP: 1, MeshAxisName.TP: 1}) + distributed_setup = DistributedSetup(mesh_context=MeshContext()) + + with pytest.raises(ValueError, match="either distributed_setup or device_mesh"): + _resolve_distributed_setup( + distributed_setup=distributed_setup, + device_mesh=device_mesh, + ) + + def test_distributed_setup_input_supplies_configs(self): + device_mesh = _FakeMesh({MeshAxisName.DP_SHARD: 1, MeshAxisName.CP: 1, MeshAxisName.TP: 1}) + moe_mesh = _FakeMesh({MeshAxisName.EP_SHARD: 1, MeshAxisName.EP: 8}) + distributed_config = FSDP2Config(activation_checkpointing=True) + moe_config = MoEParallelizerConfig() + source_setup = DistributedSetup( + mesh_context=MeshContext.from_meshes(device_mesh, moe_mesh), + strategy_config=distributed_config, + moe_parallel_config=moe_config, + activation_checkpointing=True, + ) + + setup = _resolve_distributed_setup( + distributed_setup=source_setup, + ) + + assert setup.mesh_context.device_mesh is device_mesh + assert setup.mesh_context.moe_mesh is moe_mesh + assert setup.strategy_config is distributed_config + assert setup.moe_parallel_config is moe_config + assert setup.activation_checkpointing is True + + def test_missing_distributed_setup_returns_empty_setup(self): + setup = _resolve_distributed_setup(distributed_setup=None) + + assert isinstance(setup.mesh_context, MeshContext) + assert setup.strategy_config is None + assert setup.activation_checkpointing is False + + +class TestFromPretrainedDeviceMesh: + def test_from_pretrained_accepts_device_mesh_as_topology_shortcut(self): + device_mesh = _FakeMesh({MeshAxisName.DP_SHARD: 1, MeshAxisName.CP: 1, MeshAxisName.TP: 1}) + sentinel_model = object() + + with ( + patch("torch.cuda.current_device", return_value=0), + patch("nemo_automodel._transformers.auto_model.instantiate_infrastructure") as mock_infra, + patch("nemo_automodel._transformers.auto_model.get_hf_config", return_value=MagicMock()), + patch("nemo_automodel._transformers.auto_model.get_is_hf_model", return_value=True), + patch("nemo_automodel._transformers.auto_model.resolve_sdpa_method", return_value=None) as mock_sdpa, + patch.object(NeMoAutoModelForCausalLM, "_build_model", return_value=sentinel_model) as mock_build, + ): + mock_infra.return_value = (None, None, None, None) + + result = NeMoAutoModelForCausalLM.from_pretrained("test-model", device_mesh=device_mesh) + + assert result is sentinel_model + assert mock_infra.call_args.kwargs["distributed_config"] is None + assert mock_infra.call_args.kwargs["moe_parallel_config"] is None + assert mock_infra.call_args.kwargs["activation_checkpointing"] is False + assert mock_infra.call_args.kwargs["mesh"].device_mesh is device_mesh + assert mock_infra.call_args.kwargs["mesh"].moe_mesh is None + mock_sdpa.assert_called_once_with(None, device_mesh, False) + assert mock_build.call_args.kwargs["mesh"].device_mesh is device_mesh + + def test_from_pretrained_accepts_distributed_setup(self): + device_mesh = _FakeMesh({MeshAxisName.DP_SHARD: 1, MeshAxisName.CP: 1, MeshAxisName.TP: 1}) + moe_mesh = _FakeMesh({MeshAxisName.EP_SHARD: 1, MeshAxisName.EP: 8}) + distributed_config = FSDP2Config(activation_checkpointing=True) + moe_config = MoEParallelizerConfig() + distributed_setup = DistributedSetup( + mesh_context=MeshContext.from_meshes(device_mesh, moe_mesh), + strategy_config=distributed_config, + moe_parallel_config=moe_config, + activation_checkpointing=True, + ) + sentinel_model = object() + + with ( + patch("torch.cuda.current_device", return_value=0), + patch("nemo_automodel._transformers.auto_model.instantiate_infrastructure") as mock_infra, + patch("nemo_automodel._transformers.auto_model.get_hf_config", return_value=MagicMock()), + patch("nemo_automodel._transformers.auto_model.get_is_hf_model", return_value=True), + patch("nemo_automodel._transformers.auto_model.resolve_sdpa_method", return_value=None) as mock_sdpa, + patch.object(NeMoAutoModelForCausalLM, "_build_model", return_value=sentinel_model) as mock_build, + ): + mock_infra.return_value = (None, None, None, None) + + result = NeMoAutoModelForCausalLM.from_pretrained("test-model", distributed_setup=distributed_setup) + + assert result is sentinel_model + assert mock_infra.call_args.kwargs["distributed_config"] is distributed_config + assert mock_infra.call_args.kwargs["moe_parallel_config"] is moe_config + assert mock_infra.call_args.kwargs["activation_checkpointing"] is True + assert mock_infra.call_args.kwargs["mesh"].device_mesh is device_mesh + mock_sdpa.assert_called_once_with(None, device_mesh, True) + assert mock_build.call_args.kwargs["mesh"].moe_mesh is moe_mesh + + def test_from_pretrained_rejects_distributed_setup_with_device_mesh(self): + device_mesh = _FakeMesh({MeshAxisName.DP_SHARD: 1, MeshAxisName.CP: 1, MeshAxisName.TP: 1}) + distributed_setup = DistributedSetup(mesh_context=MeshContext()) + + with pytest.raises(ValueError, match="either distributed_setup or device_mesh"): + NeMoAutoModelForCausalLM.from_pretrained( + "test-model", + distributed_setup=distributed_setup, + device_mesh=device_mesh, + ) + + def test_from_pretrained_rejects_separate_distributed_kwargs(self): + with pytest.raises(TypeError, match="distributed_setup"): + NeMoAutoModelForCausalLM.from_pretrained( + "test-model", + distributed_config=FSDP2Config(), + ) + + def test_from_pretrained_rejects_moe_mesh_kwarg(self): + with pytest.raises(TypeError, match="distributed_setup"): + NeMoAutoModelForCausalLM.from_pretrained( + "test-model", + moe_mesh=object(), + ) + + class TestPatchAttention: """Test cases for _patch_attention function.""" diff --git a/tests/unit_tests/distributed/test_device_mesh.py b/tests/unit_tests/distributed/test_device_mesh.py new file mode 100644 index 0000000000..d99e21899c --- /dev/null +++ b/tests/unit_tests/distributed/test_device_mesh.py @@ -0,0 +1,237 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the high-level mesh context builder.""" + +import pytest + +from nemo_automodel.components.distributed.config import ( + DDPConfig, + DistributedSetup, + FSDP2Config, + MegatronFSDPConfig, + MoEParallelizerConfig, +) +from nemo_automodel.components.distributed.mesh import MeshAxisName, MeshContext, ParallelismSizes +from nemo_automodel.components.distributed.pipelining.config import PipelineConfig + + +class _FakeAxis: + def __init__(self, size: int) -> None: + self._size = size + + def size(self) -> int: + return self._size + + +class _FakeMesh: + def __init__(self, sizes: dict[MeshAxisName, int]) -> None: + self.mesh_dim_names = tuple(sizes) + self._sizes = sizes + + def __getitem__(self, axis: MeshAxisName) -> _FakeAxis: + return _FakeAxis(self._sizes[axis]) + + +@pytest.fixture +def captured_raw_mesh_call(monkeypatch): + captured: dict = {} + + def fake_create_device_meshes(strategy_config, parallelism, **kwargs): + captured["strategy_config"] = strategy_config + captured["parallelism"] = parallelism + captured.update(kwargs) + return None, None + + monkeypatch.setattr( + "nemo_automodel.components.distributed.mesh_utils._create_device_meshes", + fake_create_device_meshes, + ) + return captured + + +def test_mesh_context_build_accepts_ddp_config(captured_raw_mesh_call): + config = DDPConfig() + + ctx = MeshContext.build(config, world_size=4) + + assert isinstance(ctx, MeshContext) + assert not hasattr(ctx, "strategy_config") + assert captured_raw_mesh_call["strategy_config"] is config + assert not hasattr(ctx, "activation_checkpointing") + assert captured_raw_mesh_call["world_size"] == 4 + + +@pytest.mark.parametrize("strategy", ["megatron_fsdp", "megatron-fsdp", "mfsdp"]) +def test_distributed_setup_config_accepts_megatron_fsdp_names(strategy, captured_raw_mesh_call): + setup = DistributedSetup.build(strategy=strategy, world_size=4) + + assert isinstance(setup.mesh_context, MeshContext) + assert isinstance(captured_raw_mesh_call["strategy_config"], MegatronFSDPConfig) + assert captured_raw_mesh_call["world_size"] == 4 + + +def test_mesh_context_build_accepts_existing_config(captured_raw_mesh_call): + config = FSDP2Config() + + ctx = MeshContext.build(config, world_size=8) + + assert isinstance(ctx, MeshContext) + assert captured_raw_mesh_call["strategy_config"] is config + assert captured_raw_mesh_call["world_size"] == 8 + + +def test_mesh_context_build_passes_parallelism_to_raw_mesh_builder(captured_raw_mesh_call): + MeshContext.build( + FSDP2Config(), + parallelism_sizes=ParallelismSizes(dp_size=4, dp_replicate_size=2, tp_size=2, cp_size=2), + world_size=16, + ) + + parallelism = captured_raw_mesh_call["parallelism"] + assert parallelism.dp_size == 4 + assert parallelism.dp_replicate_size == 2 + assert parallelism.tp_size == 2 + assert parallelism.pp_size == 1 + assert parallelism.cp_size == 2 + assert parallelism.ep_size == 1 + + +def test_mesh_context_build_requires_strategy_config(): + with pytest.raises(ValueError, match="Unknown distributed strategy config type"): + MeshContext.build("ddp", world_size=1) # type: ignore[arg-type] + + +def test_distributed_setup_config_rejects_unknown_strategy(): + with pytest.raises(ValueError, match="Unknown strategy"): + DistributedSetup.build(strategy="unknown", world_size=1) + + +def test_distributed_setup_config_defaults_parallel_subconfigs(monkeypatch): + device_mesh = _FakeMesh( + { + MeshAxisName.PP: 2, + MeshAxisName.DP_REPLICATE: 1, + MeshAxisName.DP_SHARD: 1, + MeshAxisName.CP: 1, + MeshAxisName.TP: 1, + } + ) + moe_mesh = _FakeMesh({MeshAxisName.EP_SHARD: 1, MeshAxisName.EP: 2}) + + def fake_create_device_meshes(strategy_config, parallelism, **kwargs): + return device_mesh, moe_mesh + + monkeypatch.setattr( + "nemo_automodel.components.distributed.mesh_utils._create_device_meshes", + fake_create_device_meshes, + ) + + setup = DistributedSetup.build( + strategy="fsdp2", + parallelism_sizes=ParallelismSizes(pp_size=2, ep_size=2), + world_size=4, + ) + + assert isinstance(setup, DistributedSetup) + assert isinstance(setup.pipeline_config, PipelineConfig) + assert isinstance(setup.moe_parallel_config, MoEParallelizerConfig) + + +def test_distributed_setup_config_keeps_activation_checkpointing_separate(monkeypatch): + device_mesh = _FakeMesh( + { + MeshAxisName.PP: 1, + MeshAxisName.DP_REPLICATE: 1, + MeshAxisName.DP_SHARD: 2, + MeshAxisName.CP: 1, + MeshAxisName.TP: 1, + } + ) + moe_mesh = _FakeMesh({MeshAxisName.EP_SHARD: 1, MeshAxisName.EP: 2}) + + def fake_create_device_meshes(strategy_config, parallelism, **kwargs): + return device_mesh, moe_mesh + + monkeypatch.setattr( + "nemo_automodel.components.distributed.mesh_utils._create_device_meshes", + fake_create_device_meshes, + ) + + setup = DistributedSetup.build( + strategy="fsdp2", + parallelism_sizes=ParallelismSizes(ep_size=2), + activation_checkpointing=True, + world_size=2, + ) + + assert not hasattr(setup.mesh_context, "activation_checkpointing") + assert setup.activation_checkpointing is True + assert setup.strategy_config.activation_checkpointing is False + + +def test_distributed_setup_config_does_not_infer_activation_checkpointing_from_strategy_config(captured_raw_mesh_call): + setup = DistributedSetup.build( + strategy=FSDP2Config(activation_checkpointing=True), + world_size=1, + ) + + assert setup.activation_checkpointing is False + assert captured_raw_mesh_call["strategy_config"].activation_checkpointing is True + + +def test_distributed_setup_config_activation_checkpointing_override(captured_raw_mesh_call): + setup = DistributedSetup.build( + strategy=FSDP2Config(activation_checkpointing=True), + activation_checkpointing=False, + world_size=1, + ) + + assert setup.activation_checkpointing is False + assert setup.strategy_config.activation_checkpointing is True + + +def test_distributed_setup_config_rejects_pipeline_config_without_pipeline_parallelism(captured_raw_mesh_call): + with pytest.raises(ValueError, match="pipeline_config requires pp_size > 1"): + DistributedSetup.build( + strategy=FSDP2Config(), + pipeline_config=PipelineConfig(), + parallelism_sizes=ParallelismSizes(pp_size=1), + world_size=1, + ) + + +def test_distributed_setup_config_rejects_moe_config_without_expert_parallelism(captured_raw_mesh_call): + with pytest.raises(ValueError, match="moe_parallel_config requires ep_size > 1"): + DistributedSetup.build( + strategy=FSDP2Config(), + moe_parallel_config=MoEParallelizerConfig(), + parallelism_sizes=ParallelismSizes(ep_size=1), + world_size=1, + ) + + +def test_distributed_setup_config_builds_runtime_setup(captured_raw_mesh_call): + setup = DistributedSetup.build( + strategy=FSDP2Config(sequence_parallel=True), + parallelism_sizes=ParallelismSizes(tp_size=2), + activation_checkpointing=True, + world_size=4, + ) + + assert isinstance(setup, DistributedSetup) + assert setup.strategy_config.sequence_parallel is True + assert setup.activation_checkpointing is True + assert captured_raw_mesh_call["parallelism"].tp_size == 2 + assert captured_raw_mesh_call["world_size"] == 4 diff --git a/tests/unit_tests/distributed/test_mesh.py b/tests/unit_tests/distributed/test_mesh.py index 98e2f07e87..66a6eefb0a 100644 --- a/tests/unit_tests/distributed/test_mesh.py +++ b/tests/unit_tests/distributed/test_mesh.py @@ -12,25 +12,26 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the **component-layer** mesh module (MeshContext, validation, STRATEGY_MAP). +"""Tests for the **component-layer** mesh module (MeshContext and validation). -Dict-parsing tests live in ``tests/unit_tests/recipes/test_dist_setup.py``. +Dict-parsing tests live in ``tests/unit_tests/recipes/test_dist_utils.py``. """ from unittest.mock import Mock import pytest -from nemo_automodel.components.distributed.config import DDPConfig, FSDP2Config, MegatronFSDPConfig +from nemo_automodel.components.distributed.config import ( + DDPConfig, + DistributedSetup, + FSDP2Config, + MegatronFSDPConfig, +) from nemo_automodel.components.distributed.mesh import ( - STRATEGY_MAP, MeshAxisName, MeshContext, _get_axis_size, - _validate_distributed_setup, ) -from nemo_automodel.components.distributed.pipelining.config import PipelineConfig -from nemo_automodel.components.moe.config import MoEParallelizerConfig # --------------------------------------------------------------------------- # MeshContext – defaults (no mesh attached) @@ -54,32 +55,12 @@ def test_pp_enabled_false_by_default(self): def test_default_config_fields(self): ctx = MeshContext() - assert ctx.strategy_config is None - assert ctx.pipeline_config is None - assert ctx.moe_config is None - assert ctx.activation_checkpointing is False + assert not hasattr(ctx, "strategy_config") + assert not hasattr(ctx, "pipeline_config") + assert not hasattr(ctx, "moe_config") assert ctx.device_mesh is None assert ctx.moe_mesh is None - def test_with_strategy_config(self): - cfg = FSDP2Config() - ctx = MeshContext(strategy_config=cfg) - assert ctx.strategy_config is cfg - - def test_with_pipeline_config(self): - pc = PipelineConfig(pp_schedule="1f1b", pp_microbatch_size=4) - ctx = MeshContext(pipeline_config=pc) - assert ctx.pipeline_config is pc - - def test_with_moe_config(self): - mc = MoEParallelizerConfig(ignore_router_for_ac=True) - ctx = MeshContext(moe_config=mc) - assert ctx.moe_config is mc - - def test_activation_checkpointing_flag(self): - ctx = MeshContext(activation_checkpointing=True) - assert ctx.activation_checkpointing is True - # --------------------------------------------------------------------------- # MeshContext.from_meshes (no real mesh — smoke test) @@ -93,11 +74,6 @@ def test_from_none_meshes(self): assert ctx.moe_mesh is None assert ctx.tp_size == 1 - def test_from_meshes_with_strategy(self): - cfg = FSDP2Config() - ctx = MeshContext.from_meshes(None, strategy_config=cfg) - assert ctx.strategy_config is cfg - # --------------------------------------------------------------------------- # MeshContext – helper methods @@ -115,8 +91,16 @@ def test_enum_is_str(self): def test_all_expected_members(self): names = {m.value for m in MeshAxisName} assert names == { - "pp", "dp", "dp_replicate", "dp_shard", "dp_shard_cp", - "dp_cp", "cp", "tp", "ep", "ep_shard", + "pp", + "dp", + "dp_replicate", + "dp_shard", + "dp_shard_cp", + "dp_cp", + "cp", + "tp", + "ep", + "ep_shard", } @@ -186,65 +170,16 @@ def test_parallelize_axis_kwargs(self): # --------------------------------------------------------------------------- -# validate_distributed_setup – happy paths -# --------------------------------------------------------------------------- - - -class TestValidateHappyPaths: - def test_minimal_fsdp2(self): - _validate_distributed_setup(MeshContext(strategy_config=FSDP2Config())) - - def test_minimal_megatron_fsdp(self): - _validate_distributed_setup(MeshContext(strategy_config=MegatronFSDPConfig())) - - def test_minimal_ddp(self): - _validate_distributed_setup(MeshContext(strategy_config=DDPConfig())) - - def test_activation_checkpointing_on_strategy(self): - _validate_distributed_setup( - MeshContext(strategy_config=FSDP2Config(activation_checkpointing=True)), - ) - - -# --------------------------------------------------------------------------- -# validate_distributed_setup – constraint violations +# DistributedSetup – simple runtime bundle # --------------------------------------------------------------------------- -class TestValidation: - def test_megatron_fsdp_rejects_sequence_parallel(self): - with pytest.raises(ValueError, match="sequence_parallel"): - _validate_distributed_setup( - MeshContext(strategy_config=MegatronFSDPConfig(sequence_parallel=True)), - ) - - def test_pipeline_requires_pp_gt_1(self): - pc = PipelineConfig(pp_schedule="1f1b") - with pytest.raises(ValueError, match="pp_size > 1"): - _validate_distributed_setup( - MeshContext(strategy_config=FSDP2Config(), pipeline_config=pc), - ) +class TestDistributedSetup: + def test_minimal_setup_holds_mesh_and_policy(self): + setup = DistributedSetup(mesh_context=MeshContext(), strategy_config=FSDP2Config(activation_checkpointing=True)) - def test_moe_requires_ep_gt_1(self): - mc = MoEParallelizerConfig() - with pytest.raises(ValueError, match="ep_size > 1"): - _validate_distributed_setup( - MeshContext(strategy_config=FSDP2Config(), moe_config=mc), - ) - - -# --------------------------------------------------------------------------- -# STRATEGY_MAP -# --------------------------------------------------------------------------- - - -class TestStrategyMap: - def test_strategy_map_entries(self): - assert STRATEGY_MAP == { - "fsdp2": FSDP2Config, - "megatron_fsdp": MegatronFSDPConfig, - "ddp": DDPConfig, - } + assert isinstance(setup.mesh_context, MeshContext) + assert setup.strategy_config.activation_checkpointing is True # --------------------------------------------------------------------------- @@ -254,19 +189,21 @@ def test_strategy_map_entries(self): class TestIntegration: def test_megatron_fsdp_with_valid_options(self): - _validate_distributed_setup( - MeshContext( - strategy_config=MegatronFSDPConfig( - zero_dp_strategy=2, - overlap_grad_reduce=False, - activation_checkpointing=True, - ), + setup = DistributedSetup( + mesh_context=MeshContext(), + strategy_config=MegatronFSDPConfig( + zero_dp_strategy=2, + overlap_grad_reduce=False, + activation_checkpointing=True, ), ) + assert setup.strategy_config.zero_dp_strategy == 2 + assert setup.strategy_config.overlap_grad_reduce is False - def test_fsdp2_validates_at_construction(self): - """MeshContext.__post_init__ runs validation automatically.""" - ctx = MeshContext( + def test_fsdp2_validates_on_distributed_setup(self): + """DistributedSetup validates strategy policy against mesh topology.""" + setup = DistributedSetup( + mesh_context=MeshContext(), strategy_config=FSDP2Config( sequence_parallel=True, activation_checkpointing=True, @@ -274,12 +211,13 @@ def test_fsdp2_validates_at_construction(self): ), ) # No meshes → sizes default to 1 / None, which is valid for FSDP2. - assert ctx.tp_size == 1 + assert setup.mesh_context.tp_size == 1 @pytest.mark.parametrize( "strategy_config", - [FSDP2Config(backend="gloo"), MegatronFSDPConfig(backend="gloo"), DDPConfig(backend="gloo")], + [FSDP2Config(), MegatronFSDPConfig(), DDPConfig()], ids=["fsdp2", "megatron_fsdp", "ddp"], ) - def test_backend_configuration(self, strategy_config): - _validate_distributed_setup(MeshContext(strategy_config=strategy_config)) + def test_strategy_configs_do_not_carry_process_group_backend(self, strategy_config): + setup = DistributedSetup(mesh_context=MeshContext(), strategy_config=strategy_config) + assert not hasattr(setup.strategy_config, "backend") diff --git a/tests/unit_tests/distributed/test_mesh_utils.py b/tests/unit_tests/distributed/test_mesh_utils.py index dfc6abf79a..c0b19512a6 100644 --- a/tests/unit_tests/distributed/test_mesh_utils.py +++ b/tests/unit_tests/distributed/test_mesh_utils.py @@ -26,6 +26,30 @@ get_submesh, ) + +def test_mesh_utils_reexports_mesh_creation_helpers(): + """Raw device mesh constructors live in mesh_utils, not MeshContext.""" + from nemo_automodel.components.distributed import mesh, mesh_utils + from nemo_automodel.components.distributed.mesh import MeshContext + + assert callable(mesh_utils._create_device_meshes) + assert callable(mesh_utils._create_fsdp2_device_mesh) + assert callable(mesh_utils._create_megatron_fsdp_device_mesh) + assert not hasattr(MeshContext, "_create_device_meshes") + assert not hasattr(mesh, "_create_device_meshes") + + +def test_distributed_package_exports_user_entrypoints(): + """Users can import programmatic distributed entry points from one namespace.""" + from nemo_automodel.components import distributed + from nemo_automodel.components.distributed.init_utils import initialize_distributed + from nemo_automodel.components.distributed.mesh import MeshContext, ParallelismSizes + + assert distributed.MeshContext is MeshContext + assert distributed.ParallelismSizes is ParallelismSizes + assert distributed.initialize_distributed is initialize_distributed + + # --------------------------------------------------------------------------- # get_flat_mesh # --------------------------------------------------------------------------- diff --git a/tests/unit_tests/models/bi_encoder/test_bi_encoder_model.py b/tests/unit_tests/models/bi_encoder/test_bi_encoder_model.py index 29f6c0bf6e..a3c395f87e 100644 --- a/tests/unit_tests/models/bi_encoder/test_bi_encoder_model.py +++ b/tests/unit_tests/models/bi_encoder/test_bi_encoder_model.py @@ -30,7 +30,9 @@ class DummyMesh: def _apply_common_mocks(monkeypatch): """Mock CUDA-dependent infrastructure so tests run without a GPU.""" monkeypatch.setattr(am, "instantiate_infrastructure", lambda **kwargs: (None, None, None, None)) - monkeypatch.setattr(am, "MeshContext", type("MeshContext", (), {"from_meshes": staticmethod(lambda *a, **k: DummyMesh())})) + monkeypatch.setattr( + am, "MeshContext", type("MeshContext", (), {"from_meshes": staticmethod(lambda *a, **k: DummyMesh())}) + ) monkeypatch.setattr(am.torch.cuda, "current_device", lambda: 0) @@ -76,7 +78,7 @@ def fake_apply_infrastructure(model, **kwargs): # Patches applied assert "liger" in model.marker and "sdpa" in model.marker # Ensure HF kwargs injected + passthrough of parameters to build - assert last_kwargs["attn_implementation"] == "flash_attention_2" + assert last_kwargs["attn_implementation"] == am.DEFAULT_ATTN_IMPLEMENTATION assert last_kwargs["some_other_kwarg"] == "x" diff --git a/tests/unit_tests/recipes/test_diffusion_train_metrics.py b/tests/unit_tests/recipes/test_diffusion_train_metrics.py index b5188dd580..44648e6da1 100644 --- a/tests/unit_tests/recipes/test_diffusion_train_metrics.py +++ b/tests/unit_tests/recipes/test_diffusion_train_metrics.py @@ -23,6 +23,7 @@ from nemo_automodel.recipes.diffusion import train as diffusion_train from nemo_automodel.recipes.diffusion.train import ( TrainDiffusionRecipe, + _build_diffusion_parallel_manager_args, _calculate_throughput_metrics, _count_local_batch_group_samples, _get_diffusion_microbatch_size, @@ -103,6 +104,60 @@ def set_attention_backend(self, attention_backend): self.attention_backend = attention_backend +def test_build_diffusion_parallel_manager_args_uses_shared_fsdp_defaults(): + manager_args = _build_diffusion_parallel_manager_args( + fsdp_cfg=None, + ddp_cfg=None, + world_size=8, + dtype=torch.float16, + lora_enabled=False, + ) + + assert manager_args["_manager_type"] == "fsdp2" + assert manager_args["world_size"] == 8 + assert manager_args["dp_size"] is None + assert manager_args["tp_size"] == 1 + assert manager_args["pp_size"] == 1 + assert manager_args["cp_size"] == 1 + assert manager_args["ep_size"] == 1 + assert manager_args["activation_checkpointing"] is True + assert manager_args["defer_fsdp_grad_sync"] is True + assert manager_args["enable_fsdp2_prefetch"] is True + assert manager_args["use_hf_tp_plan"] is False + assert manager_args["mp_policy"].param_dtype == torch.float16 + assert manager_args["mp_policy"].reduce_dtype == torch.float32 + assert manager_args["mp_policy"].output_dtype == torch.float16 + + +def test_build_diffusion_parallel_manager_args_keeps_lora_param_dtype_uncast(): + manager_args = _build_diffusion_parallel_manager_args( + fsdp_cfg={}, + ddp_cfg=None, + world_size=1, + dtype=torch.bfloat16, + lora_enabled=True, + ) + + assert manager_args["mp_policy"].param_dtype is None + assert manager_args["mp_policy"].output_dtype == torch.bfloat16 + + +def test_build_diffusion_parallel_manager_args_parses_ddp_config(): + manager_args = _build_diffusion_parallel_manager_args( + fsdp_cfg=None, + ddp_cfg={"activation_checkpointing": True}, + world_size=4, + dtype=torch.bfloat16, + lora_enabled=False, + ) + + assert manager_args == { + "_manager_type": "ddp", + "world_size": 4, + "activation_checkpointing": True, + } + + def test_build_model_and_optimizer_forwards_perf_options_and_optimizer_kwargs(monkeypatch): pipe = SimpleNamespace(transformer=_TinyTransformer()) manager = SimpleNamespace(device_mesh="mesh") diff --git a/tests/unit_tests/recipes/test_dist_setup.py b/tests/unit_tests/recipes/test_dist_utils.py similarity index 72% rename from tests/unit_tests/recipes/test_dist_setup.py rename to tests/unit_tests/recipes/test_dist_utils.py index 1f6a5845a2..33566b1e8b 100644 --- a/tests/unit_tests/recipes/test_dist_setup.py +++ b/tests/unit_tests/recipes/test_dist_utils.py @@ -12,17 +12,44 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for the **recipe-layer** YAML / dict parsing (``_dist_setup``). +"""Tests for the **recipe-layer** YAML / dict parsing (``_dist_utils``). Typed validation tests live in ``tests/unit_tests/distributed/test_mesh.py``. """ import pytest -from nemo_automodel.components.distributed.config import DDPConfig, FSDP2Config, MegatronFSDPConfig +from nemo_automodel.components.distributed.config import ( + DDPConfig, + DistributedSetup, + FSDP2Config, + MegatronFSDPConfig, + MoEParallelizerConfig, +) +from nemo_automodel.components.distributed.mesh import MeshAxisName, MeshContext, ParallelismSizes from nemo_automodel.components.distributed.pipelining.config import PipelineConfig -from nemo_automodel.components.moe.config import MoEParallelizerConfig -from nemo_automodel.recipes._dist_setup import parse_distributed_section, setup_distributed +from nemo_automodel.recipes._dist_utils import ( + create_distributed_setup_from_config, + parse_distributed_section, +) + + +class _FakeAxis: + def __init__(self, size: int): + self._size = size + + def size(self) -> int: + return self._size + + +class _FakeMesh: + def __init__(self, sizes: dict[MeshAxisName, int]): + self.mesh_dim_names = tuple(sizes) + self._sizes = sizes + + def __getitem__(self, axis: MeshAxisName): + return _FakeAxis(self._sizes[axis]) + # --------------------------------------------------------------------------- # Basic dict parsing @@ -45,21 +72,23 @@ def test_minimal_fsdp2(self): assert result["dp_replicate_size"] is None assert result["pp_enabled"] is False assert result["pipeline_config"] is None - assert result["moe_config"] is None + assert result["moe_parallel_config"] is None + assert result["activation_checkpointing"] is False def test_default_strategy_is_fsdp2(self): result = parse_distributed_section({}) assert isinstance(result["strategy_config"], FSDP2Config) - def test_megatron_fsdp(self): - result = parse_distributed_section({"strategy": "megatron_fsdp"}) + @pytest.mark.parametrize("strategy", ["megatron_fsdp", "megatron-fsdp", "mfsdp"]) + def test_megatron_fsdp_names(self, strategy): + result = parse_distributed_section({"strategy": strategy}) assert isinstance(result["strategy_config"], MegatronFSDPConfig) assert result["strategy_config"].zero_dp_strategy == 3 def test_ddp(self): result = parse_distributed_section({"strategy": "ddp"}) assert isinstance(result["strategy_config"], DDPConfig) - assert result["strategy_config"].backend == "nccl" + assert result["strategy_config"].activation_checkpointing is False def test_all_parallelism_keys(self): cfg = { @@ -148,8 +177,8 @@ class TestMoE: def test_moe_config_created(self): cfg = {"ep_size": 2, "moe": {"ignore_router_for_ac": True}} result = parse_distributed_section(cfg) - assert isinstance(result["moe_config"], MoEParallelizerConfig) - assert result["moe_config"].ignore_router_for_ac is True + assert isinstance(result["moe_parallel_config"], MoEParallelizerConfig) + assert result["moe_parallel_config"].ignore_router_for_ac is True def test_moe_fields_pass_through(self): cfg = { @@ -157,17 +186,17 @@ def test_moe_fields_pass_through(self): "moe": {"ignore_router_for_ac": True, "reshard_after_forward": True, "wrap_outer_model": False}, } result = parse_distributed_section(cfg) - assert result["moe_config"].reshard_after_forward is True - assert result["moe_config"].wrap_outer_model is False + assert result["moe_parallel_config"].reshard_after_forward is True + assert result["moe_parallel_config"].wrap_outer_model is False def test_empty_moe_dict_uses_defaults(self): result = parse_distributed_section({"ep_size": 2, "moe": {}}) - assert isinstance(result["moe_config"], MoEParallelizerConfig) - assert result["moe_config"].ignore_router_for_ac is False + assert isinstance(result["moe_parallel_config"], MoEParallelizerConfig) + assert result["moe_parallel_config"].ignore_router_for_ac is False def test_mp_policy_none_when_omitted(self): result = parse_distributed_section({"ep_size": 2, "moe": {}}) - assert result["moe_config"].mp_policy is None + assert result["moe_parallel_config"].mp_policy is None def test_mp_policy_target_instantiated(self): """mp_policy with resolved _target_ callable is instantiated to MixedPrecisionPolicy.""" @@ -187,7 +216,7 @@ def test_mp_policy_target_instantiated(self): }, } result = parse_distributed_section(cfg) - mp = result["moe_config"].mp_policy + mp = result["moe_parallel_config"].mp_policy assert isinstance(mp, MixedPrecisionPolicy) assert mp.param_dtype == torch.bfloat16 assert mp.reduce_dtype == torch.float32 @@ -208,7 +237,7 @@ def test_mp_policy_in_to_dict(self): }, } result = parse_distributed_section(cfg) - d = result["moe_config"].to_dict() + d = result["moe_parallel_config"].to_dict() assert "mp_policy" in d assert isinstance(d["mp_policy"], MixedPrecisionPolicy) @@ -220,21 +249,21 @@ def test_mp_policy_passthrough_when_already_instantiated(self): policy = MixedPrecisionPolicy(param_dtype=torch.float16, reduce_dtype=torch.float32) cfg = {"ep_size": 2, "moe": {"mp_policy": policy}} result = parse_distributed_section(cfg) - assert result["moe_config"].mp_policy is policy + assert result["moe_parallel_config"].mp_policy is policy # --------------------------------------------------------------------------- -# activation_checkpointing routing (EP-aware) +# activation_checkpointing is parsed separately from topology # --------------------------------------------------------------------------- -class TestActivationCheckpointingRouting: - def test_routes_to_strategy_when_no_ep(self): +class TestActivationCheckpointingParsing: + def test_parses_when_no_ep(self): result = parse_distributed_section({"strategy": "fsdp2", "activation_checkpointing": True, "ep_size": 1}) - assert result["strategy_config"].activation_checkpointing is True + assert result["strategy_config"].activation_checkpointing is False assert result["activation_checkpointing"] is True - def test_not_on_strategy_when_ep_gt_1(self): + def test_parses_when_ep_gt_1(self): result = parse_distributed_section( {"strategy": "fsdp2", "activation_checkpointing": True, "ep_size": 2, "moe": {}} ) @@ -248,7 +277,8 @@ def test_defaults_to_false(self): def test_works_with_ddp(self): result = parse_distributed_section({"strategy": "ddp", "activation_checkpointing": True}) - assert result["strategy_config"].activation_checkpointing is True + assert result["strategy_config"].activation_checkpointing is False + assert result["activation_checkpointing"] is True # --------------------------------------------------------------------------- @@ -280,12 +310,12 @@ def test_moe_requires_ep_gt_1(self): """MoE section is silently discarded when ep_size <= 1 (common when a YAML template is overridden via CLI).""" result = parse_distributed_section({"ep_size": 1, "moe": {"ignore_router_for_ac": True}}) - assert result["moe_config"] is None + assert result["moe_parallel_config"] is None def test_moe_rejects_default_ep_size(self): """MoE section is silently discarded when ep_size defaults to 1.""" result = parse_distributed_section({"moe": {"ignore_router_for_ac": True}}) - assert result["moe_config"] is None + assert result["moe_parallel_config"] is None def test_unknown_field_for_strategy(self): with pytest.raises(ValueError, match="Unknown options"): @@ -322,7 +352,8 @@ def test_megatron_fsdp_with_valid_options(self): result = parse_distributed_section(cfg) assert result["strategy_config"].zero_dp_strategy == 2 assert result["strategy_config"].overlap_grad_reduce is False - assert result["strategy_config"].activation_checkpointing is True + assert result["strategy_config"].activation_checkpointing is False + assert result["activation_checkpointing"] is True assert result["tp_size"] == 2 def test_fsdp2_full_config(self): @@ -339,7 +370,8 @@ def test_fsdp2_full_config(self): } result = parse_distributed_section(cfg) assert result["strategy_config"].sequence_parallel is True - assert result["strategy_config"].activation_checkpointing is True + assert result["strategy_config"].activation_checkpointing is False + assert result["activation_checkpointing"] is True assert result["pp_enabled"] is True assert isinstance(result["pipeline_config"], PipelineConfig) @@ -354,13 +386,13 @@ def test_combined_pipeline_and_moe(self): result = parse_distributed_section(cfg) assert result["pp_enabled"] is True assert isinstance(result["pipeline_config"], PipelineConfig) - assert isinstance(result["moe_config"], MoEParallelizerConfig) - assert result["moe_config"].ignore_router_for_ac is True + assert isinstance(result["moe_parallel_config"], MoEParallelizerConfig) + assert result["moe_parallel_config"].ignore_router_for_ac is True - @pytest.mark.parametrize("strategy", ["fsdp2", "megatron_fsdp", "ddp"]) - def test_backend_configuration(self, strategy): - result = parse_distributed_section({"strategy": strategy, "backend": "gloo"}) - assert result["strategy_config"].backend == "gloo" + @pytest.mark.parametrize("strategy", ["fsdp2", "megatron_fsdp", "megatron-fsdp", "mfsdp", "ddp"]) + def test_process_group_backend_is_not_a_strategy_option(self, strategy): + with pytest.raises(ValueError, match="Unknown options"): + parse_distributed_section({"strategy": strategy, "backend": "gloo"}) # --------------------------------------------------------------------------- @@ -375,7 +407,7 @@ class TestNoneParallelismValues: def test_ep_size_none_defaults_to_1(self): result = parse_distributed_section({"strategy": "fsdp2", "ep_size": None}) assert result["ep_size"] == 1 - assert result["moe_config"] is None + assert result["moe_parallel_config"] is None def test_pp_size_none_defaults_to_1(self): result = parse_distributed_section({"strategy": "fsdp2", "pp_size": None}) @@ -385,7 +417,8 @@ def test_pp_size_none_defaults_to_1(self): def test_ep_size_none_routes_ac_to_strategy(self): result = parse_distributed_section({"strategy": "fsdp2", "activation_checkpointing": True, "ep_size": None}) - assert result["strategy_config"].activation_checkpointing is True + assert result["strategy_config"].activation_checkpointing is False + assert result["activation_checkpointing"] is True def test_pp_size_none_discards_pipeline_dict(self): result = parse_distributed_section({"pp_size": None, "pipeline": {"pp_schedule": "1f1b"}}) @@ -393,40 +426,47 @@ def test_pp_size_none_discards_pipeline_dict(self): def test_ep_size_none_discards_moe_dict(self): result = parse_distributed_section({"ep_size": None, "moe": {"ignore_router_for_ac": True}}) - assert result["moe_config"] is None + assert result["moe_parallel_config"] is None # --------------------------------------------------------------------------- -# setup_distributed: world_size auto-detection +# create_distributed_setup_from_config: world_size auto-detection # --------------------------------------------------------------------------- -class TestSetupDistributedWorldSizeAutoDetect: - """``setup_distributed`` accepts an optional ``world_size`` and auto-detects +class TestCreateDistributedSetupFromConfigWorldSizeAutoDetect: + """``create_distributed_setup_from_config`` accepts an optional ``world_size`` and auto-detects it from ``torch.distributed`` / ``WORLD_SIZE`` when not provided.""" @pytest.fixture def patched_mesh(self, monkeypatch): - """Stub create_device_mesh to capture the world_size it receives.""" + """Stub mesh context creation to capture the world_size it receives.""" captured: dict = {} - def fake_create_device_mesh(strategy_config, **kwargs): + def fake_build(cls, strategy_config, parallelism_sizes=None, **kwargs): + parallelism = parallelism_sizes or ParallelismSizes() + captured["strategy_config"] = strategy_config + captured["parallelism"] = parallelism captured.update(kwargs) - return ("device_mesh_sentinel", None) + device_mesh = _FakeMesh( + { + MeshAxisName.PP: parallelism.pp_size or 1, + MeshAxisName.DP_REPLICATE: parallelism.dp_replicate_size or 1, + MeshAxisName.DP_SHARD: parallelism.dp_size or 1, + MeshAxisName.CP: parallelism.cp_size or 1, + MeshAxisName.TP: parallelism.tp_size or 1, + } + ) + moe_mesh = None + if (parallelism.ep_size or 1) > 1: + moe_mesh = _FakeMesh({MeshAxisName.EP_SHARD: 1, MeshAxisName.EP: parallelism.ep_size}) + return cls.from_meshes(device_mesh, moe_mesh) - monkeypatch.setattr( - "nemo_automodel.components.distributed.mesh_utils.create_device_mesh", - fake_create_device_mesh, - ) - # MeshContext.__post_init__ runs full validation against real meshes; bypass it. - monkeypatch.setattr( - "nemo_automodel.recipes._dist_setup.MeshContext", - lambda **kw: kw, - ) + monkeypatch.setattr(MeshContext, "build", classmethod(fake_build)) return captured def test_explicit_world_size_used(self, patched_mesh): - setup_distributed({"strategy": "fsdp2"}, world_size=4) + create_distributed_setup_from_config({"strategy": "fsdp2"}, world_size=4) assert patched_mesh["world_size"] == 4 def test_auto_detect_from_env(self, monkeypatch, patched_mesh): @@ -435,7 +475,7 @@ def test_auto_detect_from_env(self, monkeypatch, patched_mesh): monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False) monkeypatch.setenv("WORLD_SIZE", "8") - setup_distributed({"strategy": "fsdp2"}) + create_distributed_setup_from_config({"strategy": "fsdp2"}) assert patched_mesh["world_size"] == 8 def test_auto_detect_defaults_to_one(self, monkeypatch, patched_mesh): @@ -444,7 +484,7 @@ def test_auto_detect_defaults_to_one(self, monkeypatch, patched_mesh): monkeypatch.setattr(torch.distributed, "is_initialized", lambda: False) monkeypatch.delenv("WORLD_SIZE", raising=False) - setup_distributed({"strategy": "fsdp2"}) + create_distributed_setup_from_config({"strategy": "fsdp2"}) assert patched_mesh["world_size"] == 1 def test_auto_detect_from_torch_distributed(self, monkeypatch, patched_mesh): @@ -453,5 +493,48 @@ def test_auto_detect_from_torch_distributed(self, monkeypatch, patched_mesh): monkeypatch.setattr(torch.distributed, "is_initialized", lambda: True) monkeypatch.setattr(torch.distributed, "get_world_size", lambda: 16) - setup_distributed({"strategy": "fsdp2"}) + create_distributed_setup_from_config({"strategy": "fsdp2"}) assert patched_mesh["world_size"] == 16 + + def test_programmatic_args_without_cfg(self, patched_mesh): + result = create_distributed_setup_from_config( + strategy="fsdp2", + tp_size=2, + ep_size=2, + activation_checkpointing=True, + world_size=4, + ) + + assert patched_mesh["parallelism"] == ParallelismSizes(tp_size=2, ep_size=2) + assert isinstance(result, DistributedSetup) + assert isinstance(result.strategy_config, FSDP2Config) + assert isinstance(result.moe_parallel_config, MoEParallelizerConfig) + assert result.strategy_config.activation_checkpointing is False + assert result.activation_checkpointing is True + + @pytest.mark.parametrize("strategy", ["megatron_fsdp", "megatron-fsdp", "mfsdp"]) + def test_programmatic_megatron_fsdp_names(self, strategy, patched_mesh): + result = create_distributed_setup_from_config(strategy=strategy, world_size=1) + + assert isinstance(result.strategy_config, MegatronFSDPConfig) + + def test_programmatic_args_override_cfg_fallback(self, patched_mesh): + create_distributed_setup_from_config( + {"strategy": "fsdp2", "tp_size": 1, "ep_size": 1}, + tp_size=2, + ep_size=2, + world_size=4, + ) + + assert patched_mesh["parallelism"] == ParallelismSizes(tp_size=2, ep_size=2) + + def test_strategy_kwargs_are_forwarded_to_strategy_config(self, patched_mesh): + result = create_distributed_setup_from_config( + strategy="fsdp2", + sequence_parallel=True, + defer_fsdp_grad_sync=False, + world_size=1, + ) + + assert result.strategy_config.sequence_parallel is True + assert result.strategy_config.defer_fsdp_grad_sync is False diff --git a/tests/unit_tests/recipes/test_finetune_vlm_cp_wiring.py b/tests/unit_tests/recipes/test_finetune_vlm_cp_wiring.py index 637fd3d7b4..4815267868 100644 --- a/tests/unit_tests/recipes/test_finetune_vlm_cp_wiring.py +++ b/tests/unit_tests/recipes/test_finetune_vlm_cp_wiring.py @@ -395,17 +395,19 @@ def _patch_pp_setup_minimals(monkeypatch, *, cp_size, stage0, dataloader_calls): monkeypatch.setattr(vlm_finetune, "_supports_logits_to_keep", lambda model: True) monkeypatch.setattr( vlm_finetune, - "setup_distributed", + "create_distributed_setup_from_config", lambda cfg, world_size: SimpleNamespace( + mesh_context=SimpleNamespace( + pp_enabled=True, + device_mesh=None, + moe_mesh=None, + cp_size=cp_size, + pp_size=2, + ), strategy_config=SimpleNamespace(), pipeline_config=SimpleNamespace(), - moe_config=None, + moe_parallel_config=None, activation_checkpointing=False, - pp_enabled=True, - device_mesh=None, - moe_mesh=None, - cp_size=cp_size, - pp_size=2, ), ) monkeypatch.setattr( diff --git a/tests/unit_tests/recipes/test_finetune_vlm_helpers.py b/tests/unit_tests/recipes/test_finetune_vlm_helpers.py index 26c53f0085..c8bd06f2f5 100644 --- a/tests/unit_tests/recipes/test_finetune_vlm_helpers.py +++ b/tests/unit_tests/recipes/test_finetune_vlm_helpers.py @@ -174,10 +174,11 @@ def to_dict(self): assert captured_kwargs["freeze_config"] == {"freeze_language_model": False, "freeze_vision_tower": True} -def test_build_model_passes_moe_config_from_parallelizer_config(): - """Test that cfg_moe as MoEParallelizerConfig is forwarded directly.""" +def test_build_model_passes_distributed_setup(): + """Distributed policy is passed through the single setup object.""" from nemo_automodel._transformers import NeMoAutoModelForImageTextToText - from nemo_automodel.components.moe.config import MoEParallelizerConfig + from nemo_automodel.components.distributed.config import DistributedSetup + from nemo_automodel.components.distributed.mesh import MeshContext captured_kwargs = {} @@ -193,7 +194,7 @@ def get(self, key, default=None): return getattr(self, key, default) cfg_model = CapturingModelConfig() - moe_cfg = MoEParallelizerConfig() + distributed_setup = DistributedSetup(mesh_context=MeshContext()) with patch("nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep", return_value=True): build_model( @@ -201,55 +202,12 @@ def get(self, key, default=None): cfg_freeze=None, cfg_peft=None, seed=123, - cfg_moe=moe_cfg, - activation_checkpointing=True, + distributed_setup=distributed_setup, ) - assert "moe_config" in captured_kwargs - assert captured_kwargs["moe_config"] is moe_cfg - assert captured_kwargs["activation_checkpointing"] is True - - -def test_build_model_passes_moe_config_from_dict_like(): - """Test that cfg_moe with to_dict() is converted to MoEParallelizerConfig.""" - from nemo_automodel._transformers import NeMoAutoModelForImageTextToText - from nemo_automodel.components.moe.config import MoEParallelizerConfig - - captured_kwargs = {} - - class CapturingModelConfig: - def __init__(self): - self._target_ = NeMoAutoModelForImageTextToText.from_pretrained - - def instantiate(self, **kwargs): - captured_kwargs.update(kwargs) - return DummyModel() - - def get(self, key, default=None): - return getattr(self, key, default) - - class DictLikeMoeConfig: - def to_dict(self): - return { - "activation_checkpointing": True, # should be stripped - "_target_": "some.target", # should be stripped - } - - cfg_model = CapturingModelConfig() - - with patch("nemo_automodel.recipes.vlm.finetune._supports_logits_to_keep", return_value=True): - build_model( - cfg_model=cfg_model, - cfg_freeze=None, - cfg_peft=None, - seed=123, - cfg_moe=DictLikeMoeConfig(), - activation_checkpointing=False, - ) - - assert "moe_config" in captured_kwargs - assert isinstance(captured_kwargs["moe_config"], MoEParallelizerConfig) - assert captured_kwargs["activation_checkpointing"] is False + assert captured_kwargs["distributed_setup"] is distributed_setup + assert "moe_config" not in captured_kwargs + assert "activation_checkpointing" not in captured_kwargs def test_build_model_no_moe_config_when_cfg_moe_is_none(): @@ -277,7 +235,6 @@ def get(self, key, default=None): cfg_freeze=None, cfg_peft=None, seed=123, - cfg_moe=None, ) assert "moe_config" not in captured_kwargs @@ -1023,7 +980,6 @@ def get(self, key, default=None): cfg_freeze=None, cfg_peft=None, seed=42, - device_mesh=mock_device_mesh, ) build_optimizer(model, cfg_opt, None, mock_device_mesh) @@ -2390,7 +2346,6 @@ def test_build_optimizer_disables_foreach_with_tp(): cfg_freeze=None, cfg_peft=None, seed=42, - device_mesh=mock_device_mesh, ) optimizer = build_optimizer(model, cfg_opt, None, mock_device_mesh) @@ -2512,16 +2467,19 @@ def _patch_vlm_setup_minimals(monkeypatch, cp_size): lambda *a, **k: SimpleNamespace(checkpoint_dir="ckpts", model_state_dict_keys=None), ) monkeypatch.setattr( - "nemo_automodel.recipes.vlm.finetune.setup_distributed", + "nemo_automodel.recipes.vlm.finetune.create_distributed_setup_from_config", lambda cfg, world_size: SimpleNamespace( + mesh_context=SimpleNamespace( + pp_enabled=False, + device_mesh=None, + moe_mesh=None, + cp_size=cp_size, + pp_size=1, + ), strategy_config=None, pipeline_config=None, - moe_config=None, + moe_parallel_config=None, activation_checkpointing=False, - pp_enabled=False, - device_mesh=None, - moe_mesh=None, - cp_size=cp_size, ), ) monkeypatch.setattr( @@ -2654,7 +2612,10 @@ def test_n_images_per_sample_packed(self): n_images_per_sample = torch.tensor([3, 1]) pv_chunks, ig_chunks = chunk_vlm_media( - pixel_values, image_grid, batch_size=2, n_microbatches=2, + pixel_values, + image_grid, + batch_size=2, + n_microbatches=2, n_images_per_sample=n_images_per_sample, ) assert len(pv_chunks) == 2 @@ -2670,7 +2631,10 @@ def test_legacy_one_image_per_sample(self): pixel_values = torch.randn(int(patch_counts.sum()), 64) pv_chunks, ig_chunks = chunk_vlm_media( - pixel_values, image_grid, batch_size=4, n_microbatches=2, + pixel_values, + image_grid, + batch_size=4, + n_microbatches=2, ) assert len(pv_chunks) == 2 assert ig_chunks[0].shape[0] == 2 @@ -2708,7 +2672,10 @@ def test_fallback_mismatched_images_raises(self): with pytest.raises(ValueError, match="VLM PP chunking cannot align"): chunk_vlm_media( - pixel_values, image_grid, batch_size=2, n_microbatches=2, + pixel_values, + image_grid, + batch_size=2, + n_microbatches=2, ) def test_n_videos_per_sample_packed(self): diff --git a/tests/unit_tests/recipes/test_train_ft.py b/tests/unit_tests/recipes/test_train_ft.py index 125fc3a812..28d0520f55 100644 --- a/tests/unit_tests/recipes/test_train_ft.py +++ b/tests/unit_tests/recipes/test_train_ft.py @@ -487,18 +487,21 @@ def _patch_setup_minimals(monkeypatch, patch_fn): "nemo_automodel.recipes.llm.train_ft.build_checkpoint_config", lambda *a, **k: SimpleNamespace(checkpoint_dir="ckpts", model_state_dict_keys=None), ) - # Stub setup_distributed to avoid requiring torch.distributed init + # Stub create_distributed_setup_from_config to avoid requiring torch.distributed init monkeypatch.setattr( - "nemo_automodel.recipes.llm.train_ft.setup_distributed", + "nemo_automodel.recipes.llm.train_ft.create_distributed_setup_from_config", lambda cfg, world_size: SimpleNamespace( + mesh_context=SimpleNamespace( + pp_enabled=False, + device_mesh=None, + moe_mesh=None, + cp_size=1, + pp_size=1, + ), strategy_config=None, pipeline_config=None, - moe_config=None, + moe_parallel_config=None, activation_checkpointing=False, - pp_enabled=False, - device_mesh=None, - moe_mesh=None, - cp_size=1, ), ) @@ -1263,7 +1266,6 @@ def test_build_optimizer_disables_foreach_with_tp(): cfg_model=cfg_model, cfg_peft=None, seed=42, - device_mesh=mock_mesh, ) _ = build_optimizer(model, cfg_opt, None, mock_mesh) @@ -1867,18 +1869,21 @@ def _minimal_cfg_with_rope_fusion(cp_size: int, rope_fusion: bool): def _patch_setup_minimals_with_cp(monkeypatch, cp_size): """Variant of _patch_setup_minimals that lets us control cp_size.""" _patch_setup_minimals(monkeypatch, lambda *a, **k: None) - # Override setup_distributed to expose the desired cp_size + # Override create_distributed_setup_from_config to expose the desired cp_size monkeypatch.setattr( - "nemo_automodel.recipes.llm.train_ft.setup_distributed", + "nemo_automodel.recipes.llm.train_ft.create_distributed_setup_from_config", lambda cfg, world_size: SimpleNamespace( + mesh_context=SimpleNamespace( + pp_enabled=False, + device_mesh=None, + moe_mesh=None, + cp_size=cp_size, + pp_size=1, + ), strategy_config=None, pipeline_config=None, - moe_config=None, + moe_parallel_config=None, activation_checkpointing=False, - pp_enabled=False, - device_mesh=None, - moe_mesh=None, - cp_size=cp_size, ), )