Two crashes occur when using `convert_to_float8_training` with `enable_fsdp_float8_all_gather=True` and FSDP2 when `param.numel() % world_size != 0`:
- `tensor_to_amax` on empty shards: `RuntimeError: max(): Expected reduction dim for input.numel() == 0`
- Storage mismatch after all-gather: `RuntimeError: setStorage: sizes [4096, 1024] requiring 4194304 out of bounds for storage of size 4190208`
Related to #1938.
## Minimal Repro
```python
import torch, torch.nn as nn, torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearConfig

dist.init_process_group("nccl")
rank, ws = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

model = nn.Sequential(
    nn.Linear(1024, 4096, bias=False),
    nn.Linear(4096, 1024, bias=False),
).cuda().bfloat16()
convert_to_float8_training(model, config=Float8LinearConfig(enable_fsdp_float8_all_gather=True))
fully_shard(model)

x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()  # crashes when 1024*4096 % world_size != 0
```
Passes with `world_size=4` (`4194304 % 4 == 0`), crashes with `world_size=3` (`4194304 % 3 != 0`).
## Root Cause
`WeightWithDynamicFloat8CastTensor.fsdp_pre_all_gather` uses the old 1-param signature. When sharding is uneven, FSDP2 pads the bf16 shards, but the fp8 hook returns data with the unpadded shard size. After all-gather, the reconstructed fp8 storage has `floor(numel / ws) * ws` elements instead of the original `numel`, causing `as_strided` to fail.

Additionally, `tensor_to_amax` calls `torch.max(torch.abs(x))`, which crashes on empty tensors from small-output layers where some ranks get zero-element shards.
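The size mismatch can be illustrated with toy numbers (a sketch, not torchao code; the sizes here are hypothetical, not the repro's real dimensions):

```python
# FSDP2 pads every rank's dim-0 shard to the same ceil size so the
# all-gather output is uniform, then truncates back to the original size.
rows, ws = 10, 3
padded = -(-rows // ws)    # ceil(10/3) = 4 rows per rank after padding
unpadded = rows // ws      # floor(10/3) = 3 rows: what an unpadded hook returns
print(padded * ws)         # 12 >= 10: padded gather covers the full tensor
print(unpadded * ws)       # 9 < 10: too few elements for as_strided over [10, ...]
```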
## Suggested Fix
- Guard `tensor_to_amax` for empty tensors: return `torch.tensor(0.0)` when `x.numel() == 0`
- Upgrade `fsdp_pre_all_gather` to the 5-param signature so FSDP2 handles padding correctly
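A minimal sketch of the first guard (assuming `tensor_to_amax` has roughly this shape in torchao; the exact upstream implementation may differ):

```python
import torch

def tensor_to_amax(x: torch.Tensor) -> torch.Tensor:
    # Guard: zero-element shards from uneven FSDP2 sharding would crash
    # torch.max (no reduction dim exists for an empty tensor), so return 0.
    if x.numel() == 0:
        return torch.tensor(0.0, device=x.device)
    return torch.max(torch.abs(x))
```

An amax of 0 for an empty shard is a no-op contribution once amaxes are reduced across ranks, since `abs()` is non-negative everywhere else.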
## Environment
- torchao 0.17.0.dev20260302+cu128
- torch 2.12.0.dev20260221+cu128
- H100 80GB