Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions benchmarks/prototype/moe_training/bench_moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
from torchao.prototype.moe_training.config import (
FP8GroupedMMRecipe,
MXFP8GroupedMMConfig,
MXFP8GroupedMMRecipe,
MXFP8TrainingConfig,
MXFP8TrainingRecipe,
)
from torchao.quantization.quant_api import quantize_

Expand Down Expand Up @@ -60,11 +60,13 @@ def bench_moe_training_fsdp(args: argparse.Namespace):
if recipe_name == "fp8_rowwise":
recipe = FP8GroupedMMRecipe.FP8_ROWWISE
elif recipe_name == "mxfp8_rceil":
recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL
recipe = MXFP8TrainingRecipe.MXFP8_RCEIL
elif recipe_name == "mxfp8_rceil_wgrad_with_hp":
recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP
recipe = MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP
else:
raise ValueError(f"Unknown recipe: {recipe_name}")

# Check hardware requirements
if (
recipe == FP8GroupedMMRecipe.FP8_ROWWISE
and torch.cuda.get_device_capability()
Expand All @@ -78,8 +80,8 @@ def bench_moe_training_fsdp(args: argparse.Namespace):
)
return

elif (
recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL
if (
recipe == MXFP8TrainingRecipe.MXFP8_RCEIL
and torch.cuda.get_device_capability()
!= (
10,
Expand Down Expand Up @@ -110,7 +112,7 @@ def bench_moe_training_fsdp(args: argparse.Namespace):
model = copy.deepcopy(ref_model)

# Token group alignment size must be 16 for fp8 rowwise training
alignment_size = 32 if recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL else 16
alignment_size = 32 if recipe == MXFP8TrainingRecipe.MXFP8_RCEIL else 16
set_token_group_alignment_size_m(alignment_size)

# assert starting params are identical for both models
Expand All @@ -125,7 +127,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
return False

# quantize test model
config = MXFP8GroupedMMConfig.from_recipe(recipe)
config = MXFP8TrainingConfig.from_recipe(recipe)
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# inputs
Expand Down
14 changes: 7 additions & 7 deletions benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
from torchao.prototype.moe_training.config import (
FP8GroupedMMRecipe,
MXFP8GroupedMMConfig,
MXFP8GroupedMMRecipe,
MXFP8TrainingConfig,
MXFP8TrainingRecipe,
)
from torchao.quantization.quant_api import quantize_

Expand All @@ -50,9 +50,9 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile:
if recipe_name.upper() == "fp8_rowwise":
recipe = FP8GroupedMMRecipe.FP8_ROWWISE
elif recipe_name.upper() == "mxfp8_rceil":
recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL
recipe = MXFP8TrainingRecipe.MXFP8_RCEIL
elif recipe_name.upper() == "mxfp8_rceil_wgrad_with_hp":
recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP
recipe = MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP
else:
raise ValueError(f"Unknown recipe: {recipe_name}")
if (
Expand All @@ -69,7 +69,7 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile:
return

elif (
recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL
recipe == MXFP8TrainingRecipe.MXFP8_RCEIL
and torch.cuda.get_device_capability()
!= (
10,
Expand Down Expand Up @@ -104,7 +104,7 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile:
model = copy.deepcopy(ref_model)

# Token group alignment size must be 16 for fp8 rowwise training
alignment_size = 32 if recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL else 16
alignment_size = 32 if recipe == MXFP8TrainingRecipe.MXFP8_RCEIL else 16
set_token_group_alignment_size_m(alignment_size)

# assert starting params are identical for both models
Expand All @@ -119,7 +119,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
return False

# quantize test model
config = MXFP8GroupedMMConfig.from_recipe(recipe)
config = MXFP8TrainingConfig.from_recipe(recipe)
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# FSDP2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from torchao.prototype.moe_training.config import (
FP8GroupedMMConfig,
FP8GroupedMMRecipe,
MXFP8GroupedMMConfig,
MXFP8GroupedMMRecipe,
MXFP8TrainingConfig,
MXFP8TrainingRecipe,
)
from torchao.prototype.moe_training.utils import generate_jagged_offs
from torchao.utils import is_MI300, is_MI350, is_ROCM
Expand All @@ -42,7 +42,7 @@
class ExperimentConfig:
high_precision_dtype: torch.dtype
MNKG: tuple[int]
recipe: Union[FP8GroupedMMRecipe, MXFP8GroupedMMRecipe]
recipe: Union[FP8GroupedMMRecipe, MXFP8TrainingRecipe]


@dataclass(frozen=True)
Expand Down Expand Up @@ -92,9 +92,8 @@ def get_configs() -> List[ExperimentConfig]:
(128000, 2048, 7168, 8),
]
recipes = [
FP8GroupedMMRecipe.FP8_ROWWISE,
MXFP8GroupedMMRecipe.MXFP8_RCEIL,
MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
MXFP8TrainingRecipe.MXFP8_RCEIL,
MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
]
high_precision_dtypes = [torch.bfloat16]
configs = []
Expand Down Expand Up @@ -173,7 +172,7 @@ def run_experiment(
if isinstance(config.recipe, FP8GroupedMMRecipe):
quant_config = FP8GroupedMMConfig.from_recipe(config.recipe)
else:
quant_config = MXFP8GroupedMMConfig.from_recipe(config.recipe)
quant_config = MXFP8TrainingConfig.from_recipe(config.recipe)

# fwd_bwd scaled benchmark + profiling
scaled_fwd_bwd_us = bench_fwd_bwd_microseconds(
Expand Down Expand Up @@ -276,8 +275,8 @@ def main(args: argparse.Namespace):
continue

elif config.recipe in (
MXFP8GroupedMMRecipe.MXFP8_RCEIL,
MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
MXFP8TrainingRecipe.MXFP8_RCEIL,
MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
) and torch.cuda.get_device_capability() != (10, 0):
logging.warning(
f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
Expand Down
30 changes: 9 additions & 21 deletions test/prototype/moe_training/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@

from torchao.float8.float8_utils import compute_error
from torchao.prototype.moe_training.config import (
FP8GroupedMMRecipe,
MXFP8GroupedMMConfig,
MXFP8GroupedMMRecipe,
MXFP8TrainingConfig,
MXFP8TrainingRecipe,
)
from torchao.quantization.quant_api import quantize_

Expand Down Expand Up @@ -133,28 +132,21 @@ def distributed_env():
"recipe_config",
[
{
"recipe": FP8GroupedMMRecipe.FP8_ROWWISE,
"group_alignment_size": 16,
"min_out_sqnr": 29.0,
"min_input_grad_sqnr": 29.0,
"min_param_grad_sqnr": 23.0,
},
{
"recipe": MXFP8GroupedMMRecipe.MXFP8_RCEIL,
"recipe": MXFP8TrainingRecipe.MXFP8_RCEIL,
"group_alignment_size": 32,
"min_out_sqnr": 27.0,
"min_input_grad_sqnr": 29.0,
"min_param_grad_sqnr": 21.0,
},
{
"recipe": MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
"recipe": MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
"group_alignment_size": 32,
"min_out_sqnr": 27.0,
"min_input_grad_sqnr": 29.0,
"min_param_grad_sqnr": 25.0,
},
{
"recipe": MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL,
"recipe": MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL,
"group_alignment_size": 32,
"min_out_sqnr": 27.0,
"min_input_grad_sqnr": 29.0,
Expand Down Expand Up @@ -183,19 +175,15 @@ def test_moe_training_parallel(
)
assert torch.cuda.is_available()

# Skip FP8 tests - FP8GroupedMMConfig not yet implemented
if isinstance(recipe, FP8GroupedMMRecipe):
pytest.skip("FP8GroupedMMConfig not yet implemented, will be added separately")

if recipe in (
MXFP8GroupedMMRecipe.MXFP8_RCEIL,
MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
MXFP8TrainingRecipe.MXFP8_RCEIL,
MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP,
):
if torch.cuda.get_device_capability() != (10, 0):
pytest.skip(
f"Non-emulated mode only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}"
)
elif recipe == MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL:
elif recipe == MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL:
if compile:
pytest.skip("MXFP8 emulated mode does not support torch.compile")

Expand Down Expand Up @@ -238,7 +226,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
return False

# quantize test model using MXFP8 config
config = MXFP8GroupedMMConfig.from_recipe(recipe)
config = MXFP8TrainingConfig.from_recipe(recipe)
quantize_(model, config=config, filter_fn=moe_module_filter_fn)

# validate that only the experts were converted
Expand Down
Loading
Loading