diff --git a/benchmarks/prototype/moe_training/bench_moe_layer.py b/benchmarks/prototype/moe_training/bench_moe_layer.py index b435131478..c8a8859dd6 100644 --- a/benchmarks/prototype/moe_training/bench_moe_layer.py +++ b/benchmarks/prototype/moe_training/bench_moe_layer.py @@ -17,8 +17,8 @@ from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd from torchao.prototype.moe_training.config import ( FP8GroupedMMRecipe, - MXFP8GroupedMMConfig, - MXFP8GroupedMMRecipe, + MXFP8TrainingConfig, + MXFP8TrainingRecipe, ) from torchao.quantization.quant_api import quantize_ @@ -60,11 +60,13 @@ def bench_moe_training_fsdp(args: argparse.Namespace): if recipe_name == "fp8_rowwise": recipe = FP8GroupedMMRecipe.FP8_ROWWISE elif recipe_name == "mxfp8_rceil": - recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL + recipe = MXFP8TrainingRecipe.MXFP8_RCEIL elif recipe_name == "mxfp8_rceil_wgrad_with_hp": - recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP + recipe = MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP else: raise ValueError(f"Unknown recipe: {recipe_name}") + + # Check hardware requirements if ( recipe == FP8GroupedMMRecipe.FP8_ROWWISE and torch.cuda.get_device_capability() @@ -78,8 +80,8 @@ def bench_moe_training_fsdp(args: argparse.Namespace): ) return - elif ( - recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL + if ( + recipe == MXFP8TrainingRecipe.MXFP8_RCEIL and torch.cuda.get_device_capability() != ( 10, @@ -110,7 +112,7 @@ def bench_moe_training_fsdp(args: argparse.Namespace): model = copy.deepcopy(ref_model) # Token group alignment size must be 16 for fp8 rowwise training - alignment_size = 32 if recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL else 16 + alignment_size = 32 if recipe == MXFP8TrainingRecipe.MXFP8_RCEIL else 16 set_token_group_alignment_size_m(alignment_size) # assert starting params are identical for both models @@ -125,7 +127,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = 
MXFP8GroupedMMConfig.from_recipe(recipe) + config = MXFP8TrainingConfig.from_recipe(recipe) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # inputs diff --git a/benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py b/benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py index d8908e6fdb..fc6f59b74a 100644 --- a/benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py +++ b/benchmarks/prototype/moe_training/benchmark_moe_layer_fsdp.py @@ -26,8 +26,8 @@ from benchmarks.utils import bench_fwd_bwd_microseconds, profile_fwd_bwd from torchao.prototype.moe_training.config import ( FP8GroupedMMRecipe, - MXFP8GroupedMMConfig, - MXFP8GroupedMMRecipe, + MXFP8TrainingConfig, + MXFP8TrainingRecipe, ) from torchao.quantization.quant_api import quantize_ @@ -50,9 +50,9 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile: if recipe_name.upper() == "fp8_rowwise": recipe = FP8GroupedMMRecipe.FP8_ROWWISE elif recipe_name.upper() == "mxfp8_rceil": - recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL + recipe = MXFP8TrainingRecipe.MXFP8_RCEIL elif recipe_name.upper() == "mxfp8_rceil_wgrad_with_hp": - recipe = MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP + recipe = MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP else: raise ValueError(f"Unknown recipe: {recipe_name}") if ( @@ -69,7 +69,7 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile: return elif ( - recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL + recipe == MXFP8TrainingRecipe.MXFP8_RCEIL and torch.cuda.get_device_capability() != ( 10, @@ -104,7 +104,7 @@ def bench_moe_training_fsdp(recipe_name: str, enable_profile: bool, use_compile: model = copy.deepcopy(ref_model) # Token group alignment size must be 16 for fp8 rowwise training - alignment_size = 32 if recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL else 16 + alignment_size = 32 if recipe == MXFP8TrainingRecipe.MXFP8_RCEIL else 16 set_token_group_alignment_size_m(alignment_size) # assert starting 
params are identical for both models @@ -119,7 +119,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model - config = MXFP8GroupedMMConfig.from_recipe(recipe) + config = MXFP8TrainingConfig.from_recipe(recipe) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # FSDP2 diff --git a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py index 82755a2003..6dc8058d79 100644 --- a/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py +++ b/benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py @@ -23,8 +23,8 @@ from torchao.prototype.moe_training.config import ( FP8GroupedMMConfig, FP8GroupedMMRecipe, - MXFP8GroupedMMConfig, - MXFP8GroupedMMRecipe, + MXFP8TrainingConfig, + MXFP8TrainingRecipe, ) from torchao.prototype.moe_training.utils import generate_jagged_offs from torchao.utils import is_MI300, is_MI350, is_ROCM @@ -42,7 +42,7 @@ class ExperimentConfig: high_precision_dtype: torch.dtype MNKG: tuple[int] - recipe: Union[FP8GroupedMMRecipe, MXFP8GroupedMMRecipe] + recipe: Union[FP8GroupedMMRecipe, MXFP8TrainingRecipe] @dataclass(frozen=True) @@ -92,9 +92,8 @@ def get_configs() -> List[ExperimentConfig]: (128000, 2048, 7168, 8), ] recipes = [ - FP8GroupedMMRecipe.FP8_ROWWISE, - MXFP8GroupedMMRecipe.MXFP8_RCEIL, - MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, + MXFP8TrainingRecipe.MXFP8_RCEIL, + MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, ] high_precision_dtypes = [torch.bfloat16] configs = [] @@ -173,7 +172,7 @@ def run_experiment( if isinstance(config.recipe, FP8GroupedMMRecipe): quant_config = FP8GroupedMMConfig.from_recipe(config.recipe) else: - quant_config = MXFP8GroupedMMConfig.from_recipe(config.recipe) + quant_config = MXFP8TrainingConfig.from_recipe(config.recipe) # fwd_bwd scaled benchmark + profiling scaled_fwd_bwd_us = bench_fwd_bwd_microseconds( @@ -276,8 +275,8 @@ def main(args: 
argparse.Namespace): continue elif config.recipe in ( - MXFP8GroupedMMRecipe.MXFP8_RCEIL, - MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, + MXFP8TrainingRecipe.MXFP8_RCEIL, + MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, ) and torch.cuda.get_device_capability() != (10, 0): logging.warning( f"Skipping MXFP8 benchmarks, only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" diff --git a/test/prototype/moe_training/test_distributed.py b/test/prototype/moe_training/test_distributed.py index 3bdad3e85a..e5c5be7e58 100644 --- a/test/prototype/moe_training/test_distributed.py +++ b/test/prototype/moe_training/test_distributed.py @@ -52,9 +52,8 @@ from torchao.float8.float8_utils import compute_error from torchao.prototype.moe_training.config import ( - FP8GroupedMMRecipe, - MXFP8GroupedMMConfig, - MXFP8GroupedMMRecipe, + MXFP8TrainingConfig, + MXFP8TrainingRecipe, ) from torchao.quantization.quant_api import quantize_ @@ -133,28 +132,21 @@ def distributed_env(): "recipe_config", [ { - "recipe": FP8GroupedMMRecipe.FP8_ROWWISE, - "group_alignment_size": 16, - "min_out_sqnr": 29.0, - "min_input_grad_sqnr": 29.0, - "min_param_grad_sqnr": 23.0, - }, - { - "recipe": MXFP8GroupedMMRecipe.MXFP8_RCEIL, + "recipe": MXFP8TrainingRecipe.MXFP8_RCEIL, "group_alignment_size": 32, "min_out_sqnr": 27.0, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 21.0, }, { - "recipe": MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, + "recipe": MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, "group_alignment_size": 32, "min_out_sqnr": 27.0, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 25.0, }, { - "recipe": MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL, + "recipe": MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL, "group_alignment_size": 32, "min_out_sqnr": 27.0, "min_input_grad_sqnr": 29.0, @@ -183,19 +175,15 @@ def test_moe_training_parallel( ) assert torch.cuda.is_available() - # Skip FP8 tests - FP8GroupedMMConfig not yet implemented - if isinstance(recipe, 
FP8GroupedMMRecipe): - pytest.skip("FP8GroupedMMConfig not yet implemented, will be added separately") - if recipe in ( - MXFP8GroupedMMRecipe.MXFP8_RCEIL, - MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, + MXFP8TrainingRecipe.MXFP8_RCEIL, + MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, ): if torch.cuda.get_device_capability() != (10, 0): pytest.skip( f"Non-emulated mode only supported on compute capability 10.0 and found {torch.cuda.get_device_capability()}" ) - elif recipe == MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL: + elif recipe == MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL: if compile: pytest.skip("MXFP8 emulated mode does not support torch.compile") @@ -238,7 +226,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: return False # quantize test model using MXFP8 config - config = MXFP8GroupedMMConfig.from_recipe(recipe) + config = MXFP8TrainingConfig.from_recipe(recipe) quantize_(model, config=config, filter_fn=moe_module_filter_fn) # validate that only the experts were converted diff --git a/test/prototype/moe_training/test_fqn_to_config.py b/test/prototype/moe_training/test_fqn_to_config.py index 5b3ed33b3f..ec95c7d742 100644 --- a/test/prototype/moe_training/test_fqn_to_config.py +++ b/test/prototype/moe_training/test_fqn_to_config.py @@ -14,10 +14,10 @@ from torch import nn from torchao.prototype.moe_training.config import ( - MXFP8GroupedMMConfig, - MXFP8GroupedMMRecipe, + MXFP8TrainingConfig, + MXFP8TrainingRecipe, ) -from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor +from torchao.prototype.moe_training.tensor import MXFP8TrainingTensor from torchao.prototype.mx_formats.config import MXLinearConfig, MXLinearRecipeName from torchao.prototype.mx_formats.mx_linear import MXLinear from torchao.quantization import FqnToConfig @@ -61,10 +61,10 @@ def test_fqn_to_config_simple(): config = FqnToConfig( fqn_to_config=OrderedDict( [ - # Apply MXFP8GroupedMMConfig to expert parameters + # Apply MXFP8TrainingConfig to expert 
parameters ( "experts", - MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_RCEIL), + MXFP8TrainingConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL), ), # Apply MXLinearConfig to dense layers ( @@ -87,14 +87,14 @@ def test_fqn_to_config_simple(): quantize_(model, config, filter_fn=None) # Verify transformations - assert isinstance(model.experts.w1.data, ScaledGroupedMMTensor), ( - "w1 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w1.data, MXFP8TrainingTensor), ( + "w1 should be MXFP8TrainingTensor" ) - assert isinstance(model.experts.w2.data, ScaledGroupedMMTensor), ( - "w2 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w2.data, MXFP8TrainingTensor), ( + "w2 should be MXFP8TrainingTensor" ) - assert model.experts.w1.data.config == MXFP8GroupedMMConfig.from_recipe( - MXFP8GroupedMMRecipe.MXFP8_RCEIL + assert model.experts.w1.data.config == MXFP8TrainingConfig.from_recipe( + MXFP8TrainingRecipe.MXFP8_RCEIL ) assert isinstance(model.pre_moe, MXLinear), "pre_moe should be MXLinear" assert isinstance(model.post_moe, MXLinear), "post_moe should be MXLinear" @@ -110,7 +110,7 @@ def test_fqn_to_config_with_regex(): [ ( "re:.*experts.*", - MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_RCEIL), + MXFP8TrainingConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL), ), ( "re:^(pre_moe|post_moe)$", @@ -125,14 +125,14 @@ def test_fqn_to_config_with_regex(): quantize_(model, config, filter_fn=None) # Verify transformations - assert isinstance(model.experts.w1.data, ScaledGroupedMMTensor), ( - "w1 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w1.data, MXFP8TrainingTensor), ( + "w1 should be MXFP8TrainingTensor" ) - assert model.experts.w1.data.config == MXFP8GroupedMMConfig.from_recipe( - MXFP8GroupedMMRecipe.MXFP8_RCEIL + assert model.experts.w1.data.config == MXFP8TrainingConfig.from_recipe( + MXFP8TrainingRecipe.MXFP8_RCEIL ) - assert isinstance(model.experts.w2.data, ScaledGroupedMMTensor), ( - 
"w2 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w2.data, MXFP8TrainingTensor), ( + "w2 should be MXFP8TrainingTensor" ) assert isinstance(model.pre_moe, MXLinear), "pre_moe should be MXLinear" assert isinstance(model.post_moe, MXLinear), "post_moe should be MXLinear" @@ -148,7 +148,7 @@ def test_fqn_to_config_experts_only(): [ ( "re:.*experts.*", - MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_RCEIL), + MXFP8TrainingConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL), ), ] ) @@ -157,11 +157,11 @@ def test_fqn_to_config_experts_only(): quantize_(model, config, filter_fn=None) # Verify transformations - assert isinstance(model.experts.w1.data, ScaledGroupedMMTensor), ( - "w1 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w1.data, MXFP8TrainingTensor), ( + "w1 should be MXFP8TrainingTensor" ) - assert isinstance(model.experts.w2.data, ScaledGroupedMMTensor), ( - "w2 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w2.data, MXFP8TrainingTensor), ( + "w2 should be MXFP8TrainingTensor" ) # Dense layers should remain unchanged assert isinstance(model.pre_moe, nn.Linear) and not isinstance( @@ -182,7 +182,7 @@ def test_fqn_to_config_selective_layers(): [ ( "re:.*experts.*", - MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_RCEIL), + MXFP8TrainingConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL), ), ( "pre_moe", @@ -197,11 +197,11 @@ def test_fqn_to_config_selective_layers(): quantize_(model, config, filter_fn=None) # Verify transformations - assert isinstance(model.experts.w1.data, ScaledGroupedMMTensor), ( - "w1 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w1.data, MXFP8TrainingTensor), ( + "w1 should be MXFP8TrainingTensor" ) - assert isinstance(model.experts.w2.data, ScaledGroupedMMTensor), ( - "w2 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w2.data, MXFP8TrainingTensor), ( + "w2 should be MXFP8TrainingTensor" ) assert 
isinstance(model.pre_moe, MXLinear), "pre_moe should be MXLinear" # post_moe should remain unchanged @@ -219,8 +219,8 @@ def test_fqn_to_config_mxfp8_wgrad_with_hp(): [ ( "re:.*experts.*", - MXFP8GroupedMMConfig.from_recipe( - MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP + MXFP8TrainingConfig.from_recipe( + MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP ), ), ( @@ -236,14 +236,14 @@ def test_fqn_to_config_mxfp8_wgrad_with_hp(): quantize_(model, config, filter_fn=None) # Verify transformations - assert isinstance(model.experts.w1.data, ScaledGroupedMMTensor), ( - "w1 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w1.data, MXFP8TrainingTensor), ( + "w1 should be MXFP8TrainingTensor" ) - assert model.experts.w1.data.config == MXFP8GroupedMMConfig.from_recipe( - MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP + assert model.experts.w1.data.config == MXFP8TrainingConfig.from_recipe( + MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP ), "w1 should use RCEIL_WGRAD_WITH_HP recipe" - assert isinstance(model.experts.w2.data, ScaledGroupedMMTensor), ( - "w2 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w2.data, MXFP8TrainingTensor), ( + "w2 should be MXFP8TrainingTensor" ) assert isinstance(model.pre_moe, MXLinear), "pre_moe should be MXLinear" assert isinstance(model.post_moe, MXLinear), "post_moe should be MXLinear" @@ -270,10 +270,10 @@ def test_fqn_to_config_dense_only(): quantize_(model, config, filter_fn=None) # Verify only Linear layers were transformed - assert not isinstance(model.experts.w1.data, ScaledGroupedMMTensor), ( + assert not isinstance(model.experts.w1.data, MXFP8TrainingTensor), ( "w1 should remain regular tensor" ) - assert not isinstance(model.experts.w2.data, ScaledGroupedMMTensor), ( + assert not isinstance(model.experts.w2.data, MXFP8TrainingTensor), ( "w2 should remain regular tensor" ) assert isinstance(model.pre_moe, MXLinear), "pre_moe should be MXLinear" @@ -291,12 +291,12 @@ def 
test_fqn_to_config_specific_expert_params(): # Apply different MXFP8 recipes to test granular fqn selection ( "experts.w1", - MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_RCEIL), + MXFP8TrainingConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_RCEIL), ), ( "experts.w2", - MXFP8GroupedMMConfig.from_recipe( - MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP + MXFP8TrainingConfig.from_recipe( + MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP ), ), ( @@ -311,17 +311,17 @@ def test_fqn_to_config_specific_expert_params(): quantize_(model, config, filter_fn=None) # Verify different recipes were applied - assert isinstance(model.experts.w1.data, ScaledGroupedMMTensor), ( - "w1 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w1.data, MXFP8TrainingTensor), ( + "w1 should be MXFP8TrainingTensor" ) - assert model.experts.w1.data.config == MXFP8GroupedMMConfig.from_recipe( - MXFP8GroupedMMRecipe.MXFP8_RCEIL + assert model.experts.w1.data.config == MXFP8TrainingConfig.from_recipe( + MXFP8TrainingRecipe.MXFP8_RCEIL ), "w1 should use MXFP8 RCEIL" - assert isinstance(model.experts.w2.data, ScaledGroupedMMTensor), ( - "w2 should be ScaledGroupedMMTensor" + assert isinstance(model.experts.w2.data, MXFP8TrainingTensor), ( + "w2 should be MXFP8TrainingTensor" ) - assert model.experts.w2.data.config == MXFP8GroupedMMConfig.from_recipe( - MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP + assert model.experts.w2.data.config == MXFP8TrainingConfig.from_recipe( + MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP ), "w2 should use MXFP8 RCEIL_WGRAD_WITH_HP" assert isinstance(model.pre_moe, MXLinear), "pre_moe should be MXLinear" assert isinstance(model.post_moe, MXLinear), "post_moe should be MXLinear" diff --git a/test/prototype/moe_training/test_mxfp8_training_tensor.py b/test/prototype/moe_training/test_mxfp8_training_tensor.py new file mode 100644 index 0000000000..e75531177e --- /dev/null +++ b/test/prototype/moe_training/test_mxfp8_training_tensor.py @@ -0,0 +1,143 
@@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +import torch.nn.functional as F + +from torchao.utils import torch_version_at_least + +# Skip module if basic requirements aren't met +if not (torch_version_at_least("2.7.0") and torch.cuda.is_available()): + pytest.skip("CUDA and PyTorch 2.7.0+ required", allow_module_level=True) + +from torchao.prototype.moe_training.config import ( + MXFP8TrainingConfig, + MXFP8TrainingRecipe, +) +from torchao.prototype.moe_training.tensor import MXFP8TrainingTensor +from torchao.quantization.utils import compute_error + + +@pytest.mark.parametrize("op_name", ["mm", "matmul", "linear"]) +@pytest.mark.parametrize("batch_size", [None, 2, 4]) +def test_mxfp8_training_tensor_ops_fwd_bwd(op_name, batch_size): + # mm doesn't support batching + if op_name == "mm" and batch_size is not None: + pytest.skip("mm doesn't support batching") + + config = MXFP8TrainingConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL) + + # Create input tensors - dimensions must be divisible by 32 + # Use larger sizes for better SQNR, especially with bias in linear ops + M, K, N = 1024, 1024, 2048 + if batch_size is None: + A_shape = (M, K) + else: + A_shape = (batch_size, M, K) + + A = torch.randn(*A_shape, dtype=torch.bfloat16, device="cuda", requires_grad=True) + B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda", requires_grad=True) + bias = ( + torch.randn(N, dtype=torch.bfloat16, device="cuda") + if op_name == "linear" + else None + ) + + # Reference computation with bf16 + A_ref = A.clone().detach().requires_grad_(True) + B_ref = B.clone().detach().requires_grad_(True) + + if op_name == "mm": + result_ref = torch.mm(A_ref, B_ref.t()) + elif op_name == "matmul": + result_ref = torch.matmul(A_ref, B_ref.t()) + elif op_name == "linear": + 
result_ref = F.linear(A_ref, B_ref, bias) + + # MXFP8 computation + B_mxfp8 = MXFP8TrainingTensor(B, config) + + if op_name == "mm": + result_mxfp8 = torch.mm(A, B_mxfp8) + elif op_name == "matmul": + result_mxfp8 = torch.matmul(A, B_mxfp8) + elif op_name == "linear": + result_mxfp8 = F.linear(A, B_mxfp8, bias) + + # Validate forward pass + assert result_mxfp8.shape == result_ref.shape, "Shape mismatch" + assert result_mxfp8.dtype == torch.bfloat16, "Dtype should be bfloat16" + assert not isinstance(result_mxfp8, MXFP8TrainingTensor), ( + "Result should be unwrapped" + ) + + # Check forward SQNR + # Linear with bias has slightly lower SQNR due to bias addition + sqnr_fwd = compute_error(result_ref, result_mxfp8) + min_sqnr_fwd = 26.0 if op_name == "linear" else 27.0 + assert sqnr_fwd >= min_sqnr_fwd, ( + f"Forward SQNR {sqnr_fwd} is too low, must be >= {min_sqnr_fwd}" + ) + + # Backward pass with MSE loss to avoid contiguity issues + labels_ref = torch.ones_like(result_ref) + labels_mxfp8 = torch.ones_like(result_mxfp8) + loss_ref = F.mse_loss(result_ref, labels_ref) + loss_mxfp8 = F.mse_loss(result_mxfp8, labels_mxfp8) + loss_ref.backward() + loss_mxfp8.backward() + + # Verify gradients exist + assert A.grad is not None, "A.grad should be computed" + assert A_ref.grad is not None, "A_ref.grad should be computed" + assert B_mxfp8.grad is not None, "B_mxfp8.grad should be computed" + assert B_ref.grad is not None, "B_ref.grad should be computed" + + # Check input gradient SQNR + sqnr_input_grad = compute_error(A_ref.grad, A.grad) + min_sqnr_input_grad = 25.0 + assert sqnr_input_grad >= min_sqnr_input_grad, ( + f"Input grad SQNR {sqnr_input_grad} is too low, must be >= {min_sqnr_input_grad}" + ) + + # Check weight gradient SQNR + sqnr_weight_grad = compute_error(B_ref.grad, B_mxfp8.grad) + min_sqnr_weight_grad = 24.0 + assert sqnr_weight_grad >= min_sqnr_weight_grad, ( + f"Weight grad SQNR {sqnr_weight_grad} is too low, must be >= {min_sqnr_weight_grad}" + ) + + +def 
test_mxfp8_training_tensor_ops_preserve_subclass(): + config = MXFP8TrainingConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL) + + B = torch.randn(64, 32, dtype=torch.bfloat16, device="cuda") + B_mxfp8 = MXFP8TrainingTensor(B, config) + + # view + result = B_mxfp8.view(32, 64) + assert isinstance(result, MXFP8TrainingTensor), "view should preserve subclass" + + # transpose.int + result = B_mxfp8.transpose(0, 1) + assert isinstance(result, MXFP8TrainingTensor), ( + "transpose.int should preserve subclass" + ) + + # transpose.default + result = B_mxfp8.t() + assert isinstance(result, MXFP8TrainingTensor), ( + "transpose.default should preserve subclass" + ) + + # clone + result = B_mxfp8.clone() + assert isinstance(result, MXFP8TrainingTensor), "clone should preserve subclass" + + # slice + result = B_mxfp8[:32, :] + assert isinstance(result, MXFP8TrainingTensor), "slice should preserve subclass" diff --git a/test/prototype/moe_training/test_scaled_grouped_mm.py b/test/prototype/moe_training/test_scaled_grouped_mm.py index 1dd5970bdc..5860822376 100644 --- a/test/prototype/moe_training/test_scaled_grouped_mm.py +++ b/test/prototype/moe_training/test_scaled_grouped_mm.py @@ -32,8 +32,8 @@ from torchao.prototype.moe_training.config import ( FP8GroupedMMConfig, FP8GroupedMMRecipe, - MXFP8GroupedMMConfig, - MXFP8GroupedMMRecipe, + MXFP8TrainingConfig, + MXFP8TrainingRecipe, ) from torchao.prototype.moe_training.mxfp8_grouped_mm import ( _emulated_mxfp8_scaled_grouped_mm_2d_2d, @@ -63,7 +63,7 @@ @pytest.mark.parametrize("n", [8192]) @pytest.mark.parametrize("k", [5120]) @pytest.mark.parametrize("n_groups", [1, 2, 4, 8]) -def test_valid_scaled_grouped_mm_2d_3d(m, n, k, n_groups): +def test_fp8_rowwise_scaled_grouped_mm(m, n, k, n_groups): if is_ROCM(): if not (is_MI300() or is_MI350()): pytest.skip("FP8 rowwise test requires MI300 or MI350 on ROCm") @@ -168,7 +168,7 @@ def test_K_or_N_dim_not_multiple_of_16(m, n, k): b_t = b.transpose(-2, -1) b_t = 
b_t.transpose(-2, -1).contiguous().transpose(-2, -1) - config = MXFP8GroupedMMConfig.from_recipe(MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL) + config = MXFP8TrainingConfig.from_recipe(MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL) offs = torch.arange(m, n_groups * m + 1, m, device="cuda", dtype=torch.int32) # Compute output. @@ -344,9 +344,9 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts): @skip_if_rocm("ROCm not supported") @pytest.mark.parametrize("M,K,N", [(32768, 5120, 8192), (16640, 7168, 2048)]) -@pytest.mark.parametrize("num_experts", (2, 4, 8, 16)) +@pytest.mark.parametrize("num_experts", (1, 8)) @pytest.mark.parametrize("wgrad_with_hp", (True, False)) -@pytest.mark.parametrize("use_compile", (True, False)) +@pytest.mark.parametrize("use_compile", (False, True)) @pytest.mark.parametrize( "kernel_preference", (KernelPreference.AUTO, KernelPreference.EMULATED) ) diff --git a/test/prototype/moe_training/test_training.py b/test/prototype/moe_training/test_training.py index 7d9f4bec4b..01d7c25abf 100644 --- a/test/prototype/moe_training/test_training.py +++ b/test/prototype/moe_training/test_training.py @@ -15,8 +15,8 @@ from torchao.prototype.moe_training.config import ( FP8GroupedMMConfig, FP8GroupedMMRecipe, - MXFP8GroupedMMConfig, - MXFP8GroupedMMRecipe, + MXFP8TrainingConfig, + MXFP8TrainingRecipe, ) from torchao.quantization.quant_api import quantize_ from torchao.quantization.quantize_.common import KernelPreference @@ -31,8 +31,7 @@ @pytest.mark.parametrize( - "target_fqns", - [["experts"]], + "target_fqns", [["experts"], ["shared_experts"], ["experts", "shared_experts"]] ) @pytest.mark.parametrize("compile", [False, True]) @pytest.mark.parametrize( @@ -42,30 +41,23 @@ "recipe_config", [ { - "recipe": FP8GroupedMMRecipe.FP8_ROWWISE, - "group_alignment_size": 16, - "min_out_sqnr": 29.0, - "min_input_grad_sqnr": 29.0, - "min_param_grad_sqnr": 23.0, - }, - { - "recipe": MXFP8GroupedMMRecipe.MXFP8_RCEIL, + "recipe": MXFP8TrainingRecipe.MXFP8_RCEIL, 
"group_alignment_size": 32, - "min_out_sqnr": 28.0, + "min_out_sqnr": 26.5, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 21.0, }, { - "recipe": MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, + "recipe": MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, "group_alignment_size": 32, - "min_out_sqnr": 28.0, + "min_out_sqnr": 26.5, "min_input_grad_sqnr": 29.0, - "min_param_grad_sqnr": 25.0, + "min_param_grad_sqnr": 23.0, }, { - "recipe": MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL, + "recipe": MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL, "group_alignment_size": 32, - "min_out_sqnr": 27.0, + "min_out_sqnr": 26.5, "min_input_grad_sqnr": 29.0, "min_param_grad_sqnr": 21.0, }, @@ -93,7 +85,7 @@ def test_moe_training( assert torch.cuda.is_available() # Emulated mode with compile is not supported - if recipe == MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL and compile: + if recipe == MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL and compile: pytest.skip( "Skipping compile=True with kernel_preference=EMULATED, not currently supported" ) @@ -111,8 +103,8 @@ def test_moe_training( # MXFP8 hardware path requires SM100 if recipe in ( - MXFP8GroupedMMRecipe.MXFP8_RCEIL, - MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, + MXFP8TrainingRecipe.MXFP8_RCEIL, + MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP, ) and torch.cuda.get_device_capability() != ( 10, 0, @@ -128,6 +120,7 @@ def test_moe_training( set_token_group_alignment_size_m(group_alignment_size) model_args = MoEArgs( num_experts=8, + num_shared_experts=1, ) init_std = 0.02 device = torch.device("cuda") @@ -154,8 +147,8 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool: # quantize test model config_cls = ( - MXFP8GroupedMMConfig - if isinstance(recipe, MXFP8GroupedMMRecipe) + MXFP8TrainingConfig + if isinstance(recipe, MXFP8TrainingRecipe) else FP8GroupedMMConfig ) config = config_cls.from_recipe(recipe) diff --git a/test/prototype/moe_training/testing_utils.py b/test/prototype/moe_training/testing_utils.py index 
1d062b5b8b..863f2bcdb3 100644 --- a/test/prototype/moe_training/testing_utils.py +++ b/test/prototype/moe_training/testing_utils.py @@ -1,7 +1,7 @@ import torch from torch import nn -from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor +from torchao.prototype.moe_training.tensor import MXFP8TrainingTensor def _validate_model_conversion( @@ -16,7 +16,7 @@ def _recursive_validate( # check current module params for param_name, param in module.named_parameters(recurse=False): - is_converted_type = isinstance(param, ScaledGroupedMMTensor) + is_converted_type = isinstance(param, MXFP8TrainingTensor) if is_converted_type: assert is_allowed_module, ( f"Module {cur_fqn} is not in target_fqns, but has converted param {param_name}." diff --git a/torchao/prototype/moe_training/README.md b/torchao/prototype/moe_training/README.md index d984c8cc41..c185e4ccb4 100644 --- a/torchao/prototype/moe_training/README.md +++ b/torchao/prototype/moe_training/README.md @@ -273,16 +273,16 @@ This prototype is specifically designed to be used on MoE models using where expert weights are implemented as 3D nn.Parameters with `num_experts` as the leading dim. -The `MXFP8GroupedMMConfig` has a module handler registered to it which will +The `MXFP8TrainingConfig` has a module handler registered to it which will find all nn.Parameters whose parent module matches the module filter function, -and swap their data tensor with a ScaledGroupedMMTensor. +and swap their data tensor with a MXFP8TrainingTensor. -The ScaledGroupedMMTensor is a tensor subclass which overrides the +The MXFP8TrainingTensor is a tensor subclass which overrides the `torch._grouped_mm` op by dispatching to a differentiable scaled grouped mm, which performs dynamic quantization on scaled grouped GEMM operands in both the forward and backward pass, based on the quantization config (FP8/MXFP8/etc). -For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor. 
+For all other ops, MXFP8TrainingTensor behaves like a regular torch.Tensor. ## Limitations - The new CUDA kernel for MXFP8 quantization of the non-transposed expert weights in the backwards pass does not support TP yet. diff --git a/torchao/prototype/moe_training/config.py b/torchao/prototype/moe_training/config.py index 6bb999f6d9..8259dce0ac 100644 --- a/torchao/prototype/moe_training/config.py +++ b/torchao/prototype/moe_training/config.py @@ -24,7 +24,7 @@ class FP8GroupedMMRecipe(Enum): FP8_ROWWISE = "fp8_rowwise" -class MXFP8GroupedMMRecipe(Enum): +class MXFP8TrainingRecipe(Enum): """MXFP8 recipes for grouped matrix multiplication.""" # TODO: add floor variants @@ -33,14 +33,19 @@ class MXFP8GroupedMMRecipe(Enum): MXFP8_EMULATED_RCEIL = "mxfp8_emulated_rceil" -class GroupedMMConfig(AOBaseConfig): - """Base configuration for grouped matrix multiplication. Not intended to be used directly.""" +class TrainingBaseConfig(AOBaseConfig): + """ + Base configuration for low precision training. Not intended to be used directly. + + Purpose is to support generic model conversion function for linear and grouped gemm + low precision training. + """ pass @dataclass -class FP8GroupedMMConfig(GroupedMMConfig): +class FP8GroupedMMConfig(TrainingBaseConfig): """ Configuration for FP8 grouped matrix multiplication. """ @@ -67,23 +72,19 @@ def from_recipe( # register as pytree constant so we can use dynamo nonstrict trace in torchao.prototype.moe_training.ep @register_as_pytree_constant @dataclass -class MXFP8GroupedMMConfig(GroupedMMConfig): +class MXFP8TrainingConfig(TrainingBaseConfig): """ - The MXFP8GroupedMMConfig is specifically designed to be used on MoE models using - `torch._grouped_mm` to implement expert computation in token-choice routing, - where expert weights are implemented as 3D nn.Parameters wit `num_experts` as - the leading dim. + The MXFP8TrainingConfig defines the MXFP8 training config for nn.Linear layers + and grouped GEMM ops. 
- MXFP8GroupedMMConfig has a module handler registered to it which will + MXFP8TrainingConfig has a module handler registered to it which will find all nn.Parameters whose parent module matches the module filter function, - and swap their data tensor with a ScaledGroupedMMTensor. + and swap their data tensor with a MXFP8TrainingTensor. - The ScaledGroupedMMTensor is a tensor subclass which overrides the - `torch._grouped_mm` op by dispatching to a differentiable scaled grouped mm, - which performs dynamic quantization on scaled grouped GEMM operands in both - the forward and backward pass, based on the quantization config (FP8/MXFP8/etc). + The MXFP8TrainingTensor dispatches matmul and grouped gemm ops to custom + autograd functions which dynamically quantize inputs to MXFP8. - For all other ops, ScaledGroupedMMTensor behaves like a regular torch.Tensor. + For all other ops, MXFP8TrainingTensor behaves like a regular torch.Tensor. """ # AUTO = Use best supported kernel for quantization ops and GEMMs (CUDA and Triton for quantizatoin, CUTLASS for MXFP8 grouped GEM @@ -104,24 +105,24 @@ class MXFP8GroupedMMConfig(GroupedMMConfig): @classmethod def from_recipe( cls, - recipe: MXFP8GroupedMMRecipe, - ) -> "MXFP8GroupedMMConfig": - """Factory method to create a MXFP8GroupedMMConfig from a MXFP8GroupedMMRecipe.""" - if recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL: + recipe: MXFP8TrainingRecipe, + ) -> "MXFP8TrainingConfig": + """Factory method to create a MXFP8TrainingConfig from a MXFP8TrainingRecipe.""" + if recipe == MXFP8TrainingRecipe.MXFP8_RCEIL: return cls( kernel_preference=KernelPreference.AUTO, out_dtype=torch.bfloat16, wgrad_with_hp=False, scale_calculation_mode=ScaleCalculationMode.RCEIL, ) - elif recipe == MXFP8GroupedMMRecipe.MXFP8_RCEIL_WGRAD_WITH_HP: + elif recipe == MXFP8TrainingRecipe.MXFP8_RCEIL_WGRAD_WITH_HP: return cls( kernel_preference=KernelPreference.AUTO, out_dtype=torch.bfloat16, wgrad_with_hp=True, 
scale_calculation_mode=ScaleCalculationMode.RCEIL, ) - elif recipe == MXFP8GroupedMMRecipe.MXFP8_EMULATED_RCEIL: + elif recipe == MXFP8TrainingRecipe.MXFP8_EMULATED_RCEIL: return cls( kernel_preference=KernelPreference.EMULATED, out_dtype=torch.bfloat16, @@ -132,7 +133,7 @@ def from_recipe( raise ValueError(f"Unsupported MXFP8 recipe: {recipe}") def __eq__(self, other): - if isinstance(other, MXFP8GroupedMMConfig): + if isinstance(other, MXFP8TrainingConfig): return ( self.kernel_preference == other.kernel_preference and self.out_dtype == other.out_dtype @@ -152,19 +153,18 @@ def __hash__(self): ) -@register_quantize_module_handler(FP8GroupedMMConfig) -@register_quantize_module_handler(MXFP8GroupedMMConfig) +@register_quantize_module_handler(MXFP8TrainingConfig) def _moe_training_transform( module: nn.Module, - config: GroupedMMConfig, + config: TrainingBaseConfig, parameter_name: Optional[str] = None, ) -> nn.Module: """ - Swaps `torch.nn.Parameter` data tensor with a ScaledGroupedMMTensor. + Swaps `torch.nn.Parameter` data tensor with a MXFP8TrainingTensor. Args: module: Module to modify. - config: GroupedMMConfig which defines how to perform the MoE training transform. + config: TrainingBaseConfig which defines how to perform the training transform (i.e., convert linears and grouped GEMMs) parameter_name: If specified, only transform this specific parameter. Otherwise transform all parameters. 
Returns: diff --git a/torchao/prototype/moe_training/conversion_utils.py b/torchao/prototype/moe_training/conversion_utils.py index 0ce11b7402..c5641f5e98 100644 --- a/torchao/prototype/moe_training/conversion_utils.py +++ b/torchao/prototype/moe_training/conversion_utils.py @@ -9,7 +9,7 @@ from torch import nn from torchao.prototype.moe_training.config import ( - GroupedMMConfig, + TrainingBaseConfig, ) logger: logging.Logger = logging.getLogger(__name__) @@ -19,12 +19,12 @@ def _swap_params( module: nn.Module, *, module_filter_fn: Optional[Callable[[nn.Module, str], bool]] = None, - config: Optional[GroupedMMConfig] = None, + config: Optional[TrainingBaseConfig] = None, target_parameter_name: Optional[str] = None, ) -> nn.Module: """ Recurses through the nn.Module, recursively swapping the data tensor of - each nn.Parameter with a ScaledGroupedMMTensor. Only applies if the module + each nn.Parameter with a MXFP8TrainingTensor. Only applies if the module passed the module_filter_fn, if specified. Args: @@ -36,7 +36,7 @@ def _swap_params( Returns: nn.Module: The modified module with swapped linear layers. 
""" - from torchao.prototype.moe_training.tensor import ScaledGroupedMMTensor + from torchao.prototype.moe_training.tensor import MXFP8TrainingTensor if isinstance(module, nn.Parameter) and ( module_filter_fn is None or module_filter_fn(module, "") @@ -45,8 +45,8 @@ def _swap_params( raise AssertionError( f"Does not support a root nn.Parameter with children: {module}" ) - if not isinstance(module.data, ScaledGroupedMMTensor): - new_data = ScaledGroupedMMTensor(module.data, config) + if not isinstance(module.data, MXFP8TrainingTensor): + new_data = MXFP8TrainingTensor(module.data, config) return nn.Parameter(new_data, requires_grad=module.requires_grad) return module @@ -67,7 +67,6 @@ def post_order_traversal( new_fqn = f"{cur_fqn}.{child_module_name}" post_order_traversal(child_module, new_fqn, module) - if module_filter_fn is None or module_filter_fn(module, cur_fqn): for param_name, param in module.named_parameters(recurse=False): if ( @@ -75,14 +74,14 @@ def post_order_traversal( and param_name != target_parameter_name ): continue - if not isinstance(param.data, ScaledGroupedMMTensor): + if not isinstance(param.data, MXFP8TrainingTensor): new_param = nn.Parameter( - ScaledGroupedMMTensor(param.data, config), + MXFP8TrainingTensor(param.data, config), requires_grad=param.requires_grad, ) setattr(module, param_name, new_param) logger.info( - f"Swapped {cur_fqn}.{param_name} to ScaledGroupedMMTensor" + f"Swapped {cur_fqn}.{param_name} to MXFP8TrainingTensor" ) post_order_traversal(root_module) diff --git a/torchao/prototype/moe_training/tensor.py b/torchao/prototype/moe_training/tensor.py index 14d3c72ce7..53652f1a6d 100644 --- a/torchao/prototype/moe_training/tensor.py +++ b/torchao/prototype/moe_training/tensor.py @@ -17,8 +17,8 @@ from torchao.prototype.moe_training.config import ( FP8GroupedMMConfig, - GroupedMMConfig, - MXFP8GroupedMMConfig, + MXFP8TrainingConfig, + TrainingBaseConfig, ) from torchao.prototype.moe_training.fp8_grouped_mm import ( 
_to_fp8_rowwise_then_scaled_grouped_mm, @@ -26,6 +26,10 @@ from torchao.prototype.moe_training.mxfp8_grouped_mm import ( _to_mxfp8_then_scaled_grouped_mm, ) +from torchao.prototype.mx_formats.mx_linear import _to_mxfp8_then_scaled_mm +from torchao.utils import TorchAOBaseTensor + +aten = torch.ops.aten logger: logging.Logger = logging.getLogger(__name__) @@ -41,25 +45,28 @@ torch.ops.aten.split.Tensor, torch.ops.aten.clone.default, torch.ops.aten.transpose.int, + torch.ops.aten.t.default, } -class ScaledGroupedMMTensor(torch.Tensor): +class MXFP8TrainingTensor(TorchAOBaseTensor): """ - ScaledGroupedMMTensor is a simple tensor subclass that wraps a regular tensor - and overrides the torch._grouped_mm op by dispatching to the + MXFP8TrainingTensor is a simple tensor subclass that wraps a regular tensor + and overrides mm and grouped_mm ops, dispatching to autograd functions that + dynamically quantize the op inputs to MXFP8: differentiable _quantize_then_scaled_grouped_mm autograd function. """ - config: GroupedMMConfig = None + config: MXFP8TrainingConfig = None grouped_mm_func_name = "_grouped_mm" + mm_func_names = ("mm", "matmul", "linear") offs_arg_name = "offs" @staticmethod def __new__( cls, tensor: torch.Tensor, - config: GroupedMMConfig, + config: MXFP8TrainingConfig, ): self = torch.Tensor._make_wrapper_subclass( cls, @@ -79,14 +86,14 @@ def __new__( def __init__( self, tensor: torch.Tensor, - config: GroupedMMConfig, + config: MXFP8TrainingConfig, ): self._data = tensor self.config = config @classmethod def __torch_function__(cls, func, types, args, kwargs={}): - # override the grouped mm op to use the differentiable _quantize_then_scaled_grouped_mm + # grouped_mm op override if func.__name__ == cls.grouped_mm_func_name: # Use torchao scaled grouped mm with dynamic quant for # "2d x 3d with offsets" case (used for routed experts). @@ -96,26 +103,45 @@ def __torch_function__(cls, func, types, args, kwargs={}): # used for shared experts. 
This is basically the grouped_mm # kernel handling a bmm. A, B = args[0], args[1] - assert not isinstance(A, ScaledGroupedMMTensor), ( - "A should not be a ScaledGroupedMMTensor" + assert not isinstance(A, MXFP8TrainingTensor), ( + "A should not be a MXFP8TrainingTensor" ) - assert isinstance(B, ScaledGroupedMMTensor), ( - "B should be a ScaledGroupedMMTensor" + assert isinstance(B, MXFP8TrainingTensor), ( + "B should be a MXFP8TrainingTensor" ) config = B.config A_is_2d = A.ndim == 2 B_is_2d_or_3d = B.ndim == 2 or B.ndim == 3 - has_offs = kwargs.get(cls.offs_arg_name) is not None - other_args = args[2:] - if A_is_2d and B_is_2d_or_3d and has_offs: - return _quantize_then_scaled_grouped_mm( + offs = kwargs.get(cls.offs_arg_name, None) + if A_is_2d and B_is_2d_or_3d and offs is not None: + return _to_mxfp8_then_scaled_grouped_mm( A, B, - config, - *other_args, - **kwargs, + offs, + out_dtype=config.out_dtype, + kernel_preference=config.kernel_preference, + wgrad_with_hp=config.wgrad_with_hp, + scale_calculation_mode=config.scale_calculation_mode, ) + # linear op override + elif func.__name__ in cls.mm_func_names: + A, B = args[0], args[1] + assert not isinstance(A, MXFP8TrainingTensor), ( + "A should not be a MXFP8TrainingTensor" + ) + assert isinstance(B, MXFP8TrainingTensor), ( + "B should be a MXFP8TrainingTensor" + ) + config = B.config + return _to_mxfp8_then_scaled_mm( + A, + B, + kernel_preference=config.kernel_preference, + scale_calculation_mode=config.scale_calculation_mode, + wgrad_with_hp=config.wgrad_with_hp, + ) + # Disable torch_function by hand because we don't want # the wrapping behavior of the super() impl, go directly to dispatch with torch._C.DisableTorchFunctionSubclass(): @@ -132,20 +158,20 @@ def unwrap(t): config = t.config else: assert t.config == config, ( - "All ScaledGroupedMMTensor instances must have the same config" + "All MXFP8TrainingTensor instances must have the same config" ) return t._data args_unwrapped, kwargs_unwrapped = 
pytree.tree_map_only( - ScaledGroupedMMTensor, unwrap, (args, kwargs or {}) + MXFP8TrainingTensor, unwrap, (args, kwargs or {}) ) assert config is not None, ( - f"__torch_dispatch__ called on {func.__name__} without any ScaledGroupedMMTensor arguments" + f"__torch_dispatch__ called on {func.__name__} without any MXFP8TrainingTensor arguments" ) # detach is special case if func == torch.ops.aten.detach.default: - return ScaledGroupedMMTensor(args_unwrapped[0], config) + return MXFP8TrainingTensor(args_unwrapped[0], config) # perform op out = func(*args_unwrapped, **kwargs_unwrapped) @@ -154,15 +180,15 @@ def unwrap(t): if func not in _ops_to_preserve_subclass: return out - # wrap outputs back into ScaledGroupedMMTensor for ops that do preserve subclass + # wrap outputs back into MXFP8TrainingTensor for ops that do preserve subclass return pytree.tree_map_only( torch.Tensor, - lambda x: ScaledGroupedMMTensor(x, config), + lambda x: MXFP8TrainingTensor(x, config), out, ) def __repr__(self): - return f"ScaledGroupedMMTensor(data={self._data}, config={self.config})" + return f"MXFP8TrainingTensor(data={self._data}, config={self.config})" def __tensor_flatten__(self): metadata = { @@ -172,7 +198,7 @@ def __tensor_flatten__(self): @staticmethod def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): - return ScaledGroupedMMTensor( + return MXFP8TrainingTensor( inner_tensors["_data"], flatten_spec["config"], ) @@ -203,17 +229,17 @@ def fsdp_post_all_gather( # For training step 1+, out=unsharded param. 
if out is not None: - if isinstance(out, ScaledGroupedMMTensor): + if isinstance(out, MXFP8TrainingTensor): out_data = out._data out.config = self.config elif isinstance(out, DTensor) and isinstance( - out._local_tensor, ScaledGroupedMMTensor + out._local_tensor, MXFP8TrainingTensor ): out_data = out._local_tensor._data out._local_tensor.config = self.config else: raise RuntimeError( - f"expect out to be ScaledGroupedMMTensor or DTensor with local_tensor=ScaledGroupedMM, but got {type(out)}" + f"expect out to be MXFP8TrainingTensor or DTensor with local_tensor=MXFP8TrainingTensor, but got {type(out)}" ) # If `data` (all gather outputs) is already in the mixed precision policy param_dtype, @@ -232,17 +258,17 @@ def fsdp_post_all_gather( return - # For training step 0, out=None, so we need to return a new ScaledGroupedMMTensor. - output = ScaledGroupedMMTensor(data, self.config) + # For training step 0, out=None, so we need to return a new MXFP8TrainingTensor. + output = MXFP8TrainingTensor(data, self.config) inner_tensors = (data,) return output, inner_tensors -# dispatching helper for ScaledGroupedMMTensor +# dispatching helper for MXFP8TrainingTensor def _quantize_then_scaled_grouped_mm( A: torch.Tensor, B_t: torch.Tensor, - config: GroupedMMConfig, + config: TrainingBaseConfig, offs: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ @@ -255,7 +281,7 @@ def _quantize_then_scaled_grouped_mm( B_t (bf16/float32 torch.Tensor): The second high-precision input tensor which must be 3D, which must be shape (E, K, N) and in column-major memory layout. offs (int32 torch.Tensor): The offsets to use to mark the starting index of each group along dim0 of the A tensor. - config (MXFP8GroupedMMConfig): Configuration for grouped matmul quantization. + config (TrainingBaseConfig): Configuration for grouped matmul quantization.
""" # Dispatch based on derived dtype if isinstance(config, FP8GroupedMMConfig): @@ -266,7 +292,7 @@ def _quantize_then_scaled_grouped_mm( config.out_dtype, config.float8_dtype, ) - elif isinstance(config, MXFP8GroupedMMConfig): + elif isinstance(config, MXFP8TrainingConfig): return _to_mxfp8_then_scaled_grouped_mm( A, B_t, diff --git a/torchao/prototype/mx_formats/mx_linear.py b/torchao/prototype/mx_formats/mx_linear.py index e694d2fb7f..b30d4a5128 100644 --- a/torchao/prototype/mx_formats/mx_linear.py +++ b/torchao/prototype/mx_formats/mx_linear.py @@ -26,6 +26,63 @@ ) +# convenience wrapper +def _to_mxfp8_then_scaled_mm( + input_hp: torch.Tensor, + weight_hp: torch.Tensor, + kernel_preference: KernelPreference, + scale_calculation_mode: ScaleCalculationMode, + wgrad_with_hp: bool = False, +) -> torch.Tensor: + """ + Performs a matrix multiplication with MXFP8 quantization on both forward and backward passes. + + This function wraps the `mx_mm` autograd function to provide differentiable MXFP8 + matrix multiplication. It dynamically quantizes activations, weights, and gradients + to MXFP8 format for each matmul operation: + + - Forward: input @ weight_t = output (both quantized to MXFP8) + - Backward: grad_output @ weight = grad_input (both quantized to MXFP8) + - Backward: input_t @ grad_output = grad_weight (quantized to MXFP8 unless wgrad_with_hp=True) + + Args: + input_hp: High precision input tensor of shape [..., in_features] + weight_hp: High precision weight tensor of shape [out_features, in_features] + kernel_preference: Whether to use AUTO (best kernel for each operation) or EMULATED mode + scale_calculation_mode: Scale calculation method (RCEIL or FLOOR) for MXFP8 quantization + wgrad_with_hp: If True, compute grad_weight in high precision instead of MXFP8. 
Default: False + + Returns: + Output tensor of shape [..., out_features] in high precision + + Note: + Forward and backward grad_input are always computed using MXFP8 with block_size=32 + and element_dtype=float8_e4m3fn. Backward grad_weight uses MXFP8 by default, but can + optionally use high precision when wgrad_with_hp=True for improved accuracy. + The Triton kernel is used for dim0 quantization and CUDA kernel for dim1 quantization. + """ + in_elem_dtype = torch.float8_e4m3fn + w_elem_dtype = torch.float8_e4m3fn + grad_elem_dtype = torch.float8_e4m3fn + block_size = 32 + mxfp8_dim0_cast_kernel_choice = MXFP8Dim0CastKernelChoice.TRITON + mxfp8_dim1_cast_kernel_choice = MXFP8Dim1CastKernelChoice.CUDA + + return mx_mm.apply( + input_hp, + weight_hp, + in_elem_dtype, + w_elem_dtype, + grad_elem_dtype, + block_size, + kernel_preference, + mxfp8_dim0_cast_kernel_choice, + mxfp8_dim1_cast_kernel_choice, + scale_calculation_mode, + wgrad_with_hp, + ) + + @torch._dynamo.allow_in_graph class mx_mm(torch.autograd.Function): # There are three gemms in a forward + backward of a Linear layer: @@ -49,6 +106,7 @@ def forward( mxfp8_dim0_cast_kernel_choice: MXFP8Dim0CastKernelChoice, mxfp8_dim1_cast_kernel_choice: MXFP8Dim1CastKernelChoice, scale_calculation_mode: ScaleCalculationMode, + wgrad_with_hp: bool, ): ctx.save_for_backward(input_hp, weight_hp) ctx.in_elem_dtype = in_elem_dtype @@ -56,6 +114,7 @@ def forward( ctx.grad_elem_dtype = grad_elem_dtype ctx.block_size = block_size ctx.kernel_preference = kernel_preference + ctx.wgrad_with_hp = wgrad_with_hp ctx.mxfp8_dim0_cast_kernel_choice = mxfp8_dim0_cast_kernel_choice ctx.mxfp8_dim1_cast_kernel_choice = mxfp8_dim1_cast_kernel_choice ctx.scale_calculation_mode = scale_calculation_mode @@ -96,6 +155,7 @@ def backward(ctx, grad_output_hp: torch.Tensor): mxfp8_dim0_cast_kernel_choice = ctx.mxfp8_dim0_cast_kernel_choice mxfp8_dim1_cast_kernel_choice = ctx.mxfp8_dim1_cast_kernel_choice scale_calculation_mode = 
ctx.scale_calculation_mode + wgrad_with_hp = ctx.wgrad_with_hp grad_output_orig_shape = grad_output_hp.shape grad_output_hp_r = grad_output_hp.reshape(-1, grad_output_orig_shape[-1]) @@ -139,50 +199,67 @@ def backward(ctx, grad_output_hp: torch.Tensor): ) # input_t @ grad_output = grad_weight - if mxfp8_dim1_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: - grad_output_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper( - grad_output_hp_r, - block_size, - grad_elem_dtype, - grad_output_hp_r.dtype, - kernel_preference, - mxfp8_dim1_cast_kernel_choice, - scale_calculation_mode, - ) + if wgrad_with_hp: + # Compute grad_weight in high precision if wgrad_with_hp is True + grad_weight = torch.mm(grad_output_hp_r.t(), input_hp_r) else: - grad_output_mx_dim1 = MXTensor.to_mx( - grad_output_hp_r.t().contiguous(), - grad_elem_dtype, - block_size, - kernel_preference=kernel_preference, - scaling_mode=scale_calculation_mode, - mxfp8_dim0_cast_kernel_choice=mxfp8_dim0_cast_kernel_choice, - ) + # Compute grad_weight with MXFP8 quantization + if mxfp8_dim1_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: + grad_output_mx_dim1 = _to_mxfp8_dim1_kernel_wrapper( + grad_output_hp_r, + block_size, + grad_elem_dtype, + grad_output_hp_r.dtype, + kernel_preference, + mxfp8_dim1_cast_kernel_choice, + scale_calculation_mode, + ) + else: + grad_output_mx_dim1 = MXTensor.to_mx( + grad_output_hp_r.t().contiguous(), + grad_elem_dtype, + block_size, + kernel_preference=kernel_preference, + scaling_mode=scale_calculation_mode, + mxfp8_dim0_cast_kernel_choice=mxfp8_dim0_cast_kernel_choice, + ) - if mxfp8_dim1_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: - input_t_mx_dim0_tmp = _to_mxfp8_dim1_kernel_wrapper( - input_hp_r, - block_size, - in_elem_dtype, - input_hp_r.dtype, - kernel_preference, - mxfp8_dim1_cast_kernel_choice, - scale_calculation_mode, - ) - input_t_mx_dim0 = input_t_mx_dim0_tmp.t() - else: - input_t_mx_dim0_tmp = MXTensor.to_mx( - input_hp_r.t().contiguous(), - 
in_elem_dtype, - block_size, - kernel_preference=kernel_preference, - scaling_mode=scale_calculation_mode, - mxfp8_dim0_cast_kernel_choice=mxfp8_dim0_cast_kernel_choice, - ) - input_t_mx_dim0 = input_t_mx_dim0_tmp.t() - grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) + if mxfp8_dim1_cast_kernel_choice != MXFP8Dim1CastKernelChoice.TORCH: + input_t_mx_dim0_tmp = _to_mxfp8_dim1_kernel_wrapper( + input_hp_r, + block_size, + in_elem_dtype, + input_hp_r.dtype, + kernel_preference, + mxfp8_dim1_cast_kernel_choice, + scale_calculation_mode, + ) + input_t_mx_dim0 = input_t_mx_dim0_tmp.t() + else: + input_t_mx_dim0_tmp = MXTensor.to_mx( + input_hp_r.t().contiguous(), + in_elem_dtype, + block_size, + kernel_preference=kernel_preference, + scaling_mode=scale_calculation_mode, + mxfp8_dim0_cast_kernel_choice=mxfp8_dim0_cast_kernel_choice, + ) + input_t_mx_dim0 = input_t_mx_dim0_tmp.t() + grad_weight = torch.mm(grad_output_mx_dim1, input_t_mx_dim0) - return grad_input, grad_weight, None, None, None, None, None, None, None, None + return ( + grad_input, + grad_weight, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) class MXLinear(torch.nn.Linear): @@ -218,6 +295,7 @@ def forward(self, x): w = self.weight config = self.config + wgrad_with_hp = False y = mx_mm.apply( x, w, @@ -229,6 +307,7 @@ def forward(self, x): config.mxfp8_dim0_cast_kernel_choice, config.mxfp8_dim1_cast_kernel_choice, config.scale_calculation_mode, + wgrad_with_hp, # temporary, for tests to pass pending deletion of MXLinear ) if self.bias is not None: y = y + self.bias