
float8 + FSDP2: crash on uneven sharding (tensor_to_amax + storage mismatch) #3982

@platers

Description


Two crashes occur when convert_to_float8_training is used with enable_fsdp_float8_all_gather=True under FSDP2 and param.numel() % world_size != 0:

  1. tensor_to_amax on empty shards: RuntimeError: max(): Expected reduction dim for input.numel() == 0
  2. Storage mismatch after all-gather: RuntimeError: setStorage: sizes [4096, 1024] requiring 4194304 out of bounds for storage of size 4190208

Related to #1938.
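Both crashes trace back to how dim-0 row sharding divides a weight across ranks. Below is a back-of-envelope sketch of that arithmetic; chunk_sizes is a helper written here for illustration, mirroring torch.chunk-style ceil splitting rather than FSDP2's actual implementation:

```python
import math

def chunk_sizes(dim0: int, world_size: int) -> list[int]:
    """Dim-0 chunk sizes from ceil-based splitting (illustration only,
    not FSDP2's real sharding code)."""
    full = math.ceil(dim0 / world_size)
    sizes = []
    remaining = dim0
    for _ in range(world_size):
        take = min(full, max(remaining, 0))
        sizes.append(take)
        remaining -= take
    return sizes

# Uneven case from the repro: a [4096, 1024] weight across 3 ranks.
print(chunk_sizes(4096, 3))  # [1366, 1366, 1364] -- ranks hold unequal shards
# FSDP2 pads every shard up to 1366 rows; if the fp8 hook ignores that
# padding, the gathered storage comes up short of 4096 * 1024 = 4194304.

# Crash 1: when dim0 < world_size, some ranks get zero-element shards,
# and tensor_to_amax's torch.max(torch.abs(x)) fails on an empty tensor.
print(chunk_sizes(2, 3))  # [1, 1, 0]
```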

Minimal Repro

import torch, torch.nn as nn, torch.distributed as dist
from torch.distributed._composable.fsdp import fully_shard
from torchao.float8 import convert_to_float8_training
from torchao.float8.config import Float8LinearConfig

dist.init_process_group("nccl")
rank, ws = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

model = nn.Sequential(
    nn.Linear(1024, 4096, bias=False),
    nn.Linear(4096, 1024, bias=False),
).cuda().bfloat16()

convert_to_float8_training(model, config=Float8LinearConfig(enable_fsdp_float8_all_gather=True))
fully_shard(model)

x = torch.randn(2, 128, 1024, device="cuda", dtype=torch.bfloat16)
model(x).sum().backward()  # crashes when 1024*4096 % world_size != 0

The repro passes with world_size=4 (4194304 % 4 == 0) and crashes with world_size=3 (4194304 % 3 != 0).
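Until this is fixed, the uneven case can be detected up front. A hypothetical pre-flight check (uneven_params is written here for illustration and is not a torchao API) that flags parameters whose numel is not divisible by the world size:

```python
def uneven_params(shapes: dict[str, tuple[int, ...]], world_size: int) -> list[str]:
    """Return names of parameters whose numel is not divisible by world_size
    (hypothetical helper, illustration only)."""
    out = []
    for name, shape in shapes.items():
        numel = 1
        for d in shape:
            numel *= d
        if numel % world_size != 0:
            out.append(name)
    return out

shapes = {"0.weight": (4096, 1024), "1.weight": (1024, 4096)}
print(uneven_params(shapes, 4))  # [] -- world_size=4 shards evenly
print(uneven_params(shapes, 3))  # both weights flagged -> expect the crash
```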

Root Cause

WeightWithDynamicFloat8CastTensor.fsdp_pre_all_gather uses the old 1-param signature (mesh only). When sharding is uneven, FSDP2 pads the bf16 shards, but the fp8 hook returns data sized to the unpadded shard. After all-gather, the reconstructed fp8 storage therefore holds fewer elements than the original parameter (here 4190208 instead of 4194304), and as_strided fails when restoring the [4096, 1024] view.

Additionally, tensor_to_amax calls torch.max(torch.abs(x)) which crashes on empty tensors from small-output layers where some ranks get zero-element shards.

Suggested Fix

  1. Guard tensor_to_amax for empty tensors: return torch.tensor(0.0) when x.numel() == 0
  2. Upgrade fsdp_pre_all_gather to the 5-param signature so FSDP2 handles padding correctly
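Fix (1) could look roughly like the following. This is a minimal sketch, not the actual torchao helper, which carries additional logic (e.g. dtype handling) elided here:

```python
import torch

def tensor_to_amax(x: torch.Tensor) -> torch.Tensor:
    # Sketch of the empty-tensor guard from fix (1).
    if x.numel() == 0:
        # Zero-element shard on this rank: report a zero amax instead of
        # letting torch.max raise on an empty reduction.
        return torch.tensor(0.0, device=x.device)
    return torch.max(torch.abs(x))

print(tensor_to_amax(torch.empty(0)))             # tensor(0.) instead of RuntimeError
print(tensor_to_amax(torch.tensor([-3.0, 2.0])))  # tensor(3.)
```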

Environment

  • torchao 0.17.0.dev20260302+cu128
  • torch 2.12.0.dev20260221+cu128
  • H100 80GB
