Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions docs/guides/gradient-checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,18 @@ distributed:

### Configure Programmatically
```python
from nemo_automodel.components.distributed.config import FSDP2Config
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager

config = FSDP2Config(activation_checkpointing=True)
# device_mesh is created elsewhere (e.g. by the recipe via setup_distributed)
manager = FSDP2Manager(config, device_mesh=device_mesh, moe_mesh=moe_mesh)
model = manager.parallelize(model)
from nemo_automodel import NeMoAutoModelForCausalLM
from nemo_automodel.components.distributed.config import DistributedSetup

distributed_setup = DistributedSetup.build(
strategy="fsdp2",
activation_checkpointing=True,
)

model = NeMoAutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.2-1B",
distributed_setup=distributed_setup,
)
```

## Combine with Linear-Cut Cross-Entropy (LC-CE)
Expand Down Expand Up @@ -82,4 +87,4 @@ automodel examples/llm_finetune/llama3_2/llama_3_2_1b_my_finetune.yaml
If you run with the above settings (activation checkpointing, LC-CE, and FSDP all enabled), look for a log line similar to:
```
... | mem 7.30 GiB | ...
```
```
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

"""Extract per-layer activations from NeMo AutoModel using the training code path.

Uses the same distributed setup as training (EP, FSDP, backend config) and
registers forward hooks on decoder layers to capture hidden states. Saves
Uses the same distributed environment and mesh context as training (EP, FSDP,
backend config) and registers forward hooks on decoder layers to capture hidden states. Saves
activations and final logits for comparison against HF Transformers.

Run via torchrun:
Expand Down Expand Up @@ -84,33 +84,29 @@ def main():
sys.argv = ["extract_nemo_activations.py", "--config", args.config] + extra
cfg = parse_args_and_load_config()

# --- Distributed setup ---
# --- Distributed environment and mesh context ---
from nemo_automodel._transformers.utils import apply_cache_compatibility_patches
from nemo_automodel.components.loggers.log_utils import setup_logging
from nemo_automodel.recipes._dist_setup import setup_distributed
from nemo_automodel.recipes._dist_utils import create_distributed_setup_from_config
from nemo_automodel.recipes.llm.train_ft import build_distributed, build_model
from nemo_automodel.shared.te_patches import apply_te_patches

dist_env = build_distributed(cfg.get("dist_env", {}))
setup_logging()
apply_cache_compatibility_patches()
apply_te_patches()
dist_setup = setup_distributed(cfg, world_size=dist_env.world_size)
distributed_setup = create_distributed_setup_from_config(cfg, world_size=dist_env.world_size)
mesh_context = distributed_setup.mesh_context

if dist_setup.cp_size > 1 and cfg.get("model.backend.rope_fusion", False):
if mesh_context.cp_size > 1 and cfg.get("model.backend.rope_fusion", False):
cfg.model.backend.rope_fusion = False

# --- Build model ---
model = build_model(
cfg.model,
cfg_peft=None,
seed=cfg.get("seed", 42),
device_mesh=dist_setup.device_mesh,
moe_mesh=dist_setup.moe_mesh,
distributed_config=dist_setup.strategy_config,
pipeline_config=dist_setup.pipeline_config,
cfg_moe=dist_setup.moe_config,
activation_checkpointing=dist_setup.activation_checkpointing,
distributed_setup=distributed_setup,
)
model.eval()

Expand Down
90 changes: 51 additions & 39 deletions nemo_automodel/_diffusers/auto_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@
import torch
import torch.nn as nn

from nemo_automodel.components.distributed import parallelizer
from nemo_automodel.components.distributed import DistributedSetup, ParallelismSizes, parallelizer
from nemo_automodel.components.distributed.config import DDPConfig, FSDP2Config
from nemo_automodel.components.distributed.ddp import DDPManager
from nemo_automodel.components.distributed.fsdp2 import FSDP2Manager
from nemo_automodel.components.distributed.mesh_utils import create_device_mesh
from nemo_automodel.components.distributed.parallelizer import (
HunyuanParallelizationStrategy,
WanParallelizationStrategy,
Expand Down Expand Up @@ -208,23 +207,20 @@ def _create_parallel_manager(manager_args: Dict[str, Any]) -> ParallelManager:
"""
Factory function to create the appropriate parallel manager based on config.

Constructs the proper config objects (FSDP2Config / DDPConfig) and, for FSDP2,
creates the required device mesh before instantiating the manager. This mirrors
Builds a ``DistributedSetup`` via ``DistributedSetup.build(...)``, then instantiates the
requested manager from the setup's strategy config (plus, for FSDP2, its device and MoE meshes). This mirrors
the pattern used by ``_instantiate_distributed`` in the transformers infrastructure.

The manager type is determined by the ``_manager_type`` key in *manager_args*:
- ``'ddp'``: Creates :class:`DDPConfig` + :class:`DDPManager`
- ``'fsdp2'`` (default): Creates :class:`FSDP2Config`, builds a
:class:`DeviceMesh` via :func:`create_device_mesh`, then creates
:class:`FSDP2Manager`
- ``'ddp'``: Creates a DDP ``DistributedSetup`` + ``DDPManager``
- ``'fsdp2'`` (default): Creates an FSDP2 ``DistributedSetup`` + ``FSDP2Manager``

Args:
manager_args: Flat dictionary of arguments. Recognised keys:

Common:
``_manager_type`` (str): ``'fsdp2'`` or ``'ddp'``.
``activation_checkpointing`` (bool): Enable activation checkpointing.
``backend`` (str): Distributed backend (default ``'nccl'``).

FSDP2-specific (mesh creation):
``world_size`` (int): Total number of processes.
Expand All @@ -244,46 +240,62 @@ def _create_parallel_manager(manager_args: Dict[str, Any]) -> ParallelManager:
"""
args = manager_args.copy()
manager_type = args.pop("_manager_type", "fsdp2").lower()
if "backend" in args:
raise ValueError(
"backend is not a parallel manager option; configure the process group before parallelization."
)
parallelism = ParallelismSizes(
dp_size=args.get("dp_size"),
dp_replicate_size=args.get("dp_replicate_size"),
tp_size=args.get("tp_size", 1),
pp_size=args.get("pp_size", 1),
cp_size=args.get("cp_size", 1),
ep_size=args.get("ep_size", 1),
)

if manager_type == "ddp":
config = DDPConfig(
distributed_setup = DistributedSetup.build(
strategy=DDPConfig(
activation_checkpointing=args.get("activation_checkpointing", False),
),
parallelism_sizes=parallelism,
activation_checkpointing=args.get("activation_checkpointing", False),
backend=args.get("backend", "nccl"),
world_size=args.get("world_size"),
)
logger.info("[Parallel] Creating DDPManager with config: %s", config)
return DDPManager(config)
logger.info("[Parallel] Creating DDPManager with config: %s", distributed_setup.strategy_config)
return DDPManager(distributed_setup.strategy_config)

elif manager_type == "fsdp2":
config = FSDP2Config(
world_size = args.get("world_size")
if world_size is None:
world_size = torch.distributed.get_world_size()

distributed_setup = DistributedSetup.build(
strategy=FSDP2Config(
mp_policy=args["mp_policy"] if "mp_policy" in args else None,
sequence_parallel=args.get("sequence_parallel", False),
tp_plan=args.get("tp_plan", None),
patch_is_packed_sequence=args.get("patch_is_packed_sequence", False),
offload_policy=args.get("offload_policy", None),
defer_fsdp_grad_sync=args.get("defer_fsdp_grad_sync", True),
enable_async_tensor_parallel=args.get("enable_async_tensor_parallel", False),
enable_compile=args.get("enable_compile", False),
enable_fsdp2_prefetch=args.get("enable_fsdp2_prefetch", False),
fsdp2_backward_prefetch_depth=args.get("fsdp2_backward_prefetch_depth", 2),
fsdp2_forward_prefetch_depth=args.get("fsdp2_forward_prefetch_depth", 1),
),
parallelism_sizes=parallelism,
activation_checkpointing=args.get("activation_checkpointing", False),
mp_policy=args.get("mp_policy", None),
backend=args.get("backend", "nccl"),
sequence_parallel=args.get("sequence_parallel", False),
tp_plan=args.get("tp_plan", None),
patch_is_packed_sequence=args.get("patch_is_packed_sequence", False),
offload_policy=args.get("offload_policy", None),
defer_fsdp_grad_sync=args.get("defer_fsdp_grad_sync", True),
enable_async_tensor_parallel=args.get("enable_async_tensor_parallel", False),
enable_compile=args.get("enable_compile", False),
enable_fsdp2_prefetch=args.get("enable_fsdp2_prefetch", False),
fsdp2_backward_prefetch_depth=args.get("fsdp2_backward_prefetch_depth", 2),
fsdp2_forward_prefetch_depth=args.get("fsdp2_forward_prefetch_depth", 1),
)

world_size = args.get("world_size") or torch.distributed.get_world_size()
device_mesh, moe_mesh = create_device_mesh(
config,
dp_size=args.get("dp_size"),
dp_replicate_size=args.get("dp_replicate_size"),
tp_size=args.get("tp_size", 1),
pp_size=args.get("pp_size", 1),
cp_size=args.get("cp_size", 1),
ep_size=args.get("ep_size", 1),
world_size=world_size,
)

logger.info("[Parallel] Creating FSDP2Manager with config: %s", config)
return FSDP2Manager(config, device_mesh=device_mesh, moe_mesh=moe_mesh)
mesh_context = distributed_setup.mesh_context
logger.info("[Parallel] Creating FSDP2Manager with config: %s", distributed_setup.strategy_config)
return FSDP2Manager(
distributed_setup.strategy_config,
device_mesh=mesh_context.device_mesh,
moe_mesh=mesh_context.moe_mesh,
)

else:
raise ValueError(f"Unknown manager type: '{manager_type}'. Expected 'ddp' or 'fsdp2'.")
Expand Down
Loading
Loading