OOM when training AR wan-2.1-14b model #20

@xilongzhou

Description

Hi,

Thanks for this great project summarizing so many fast generation methods! I am trying to train an AR 14B model, and I am currently using this modified config file:

def create_config():
    config = config_sft_causal_default.create_config()
    config.model_class = L(CausalSFTModel)(config=None)
    config.model.fsdp_meta_init = True

    config.trainer.logging_iter = 100
    config.model.net_optimizer.lr = 5e-5
    config.model.guidance_scale = 5.0
    config.model.student_sample_steps = 50

    config.model.sample_t_cfg.time_dist_type = "uniform"
    config.model.sample_t_cfg.min_t = 0.001
    config.model.sample_t_cfg.max_t = 0.999

    config.model.precision = "bfloat16"
    config.model.precision_fsdp = "bfloat16"

    # WAN VAE compression: latent frames = (T - 1) / 4 + 1, spatial H/8 x W/8
    config.model.input_shape = [16, 21, 60, 104]  # cthw
    config.model.net = CausalWan_14B_Config
    # config.model.enable_preprocessors = False

    config.dataloader_train = MyLoaderConfig
    config.dataloader_train.batch_size = 1

    # 480p (832x480) resolution
    config.dataloader_train.img_size = (config.model.input_shape[-1] * 8, config.model.input_shape[-2] * 8)
    config.dataloader_train.sequence_length = (config.model.input_shape[1] - 1) * 4 + 1

    config.trainer.max_iter = 5000
    config.trainer.save_ckpt_iter = 500
    config.trainer.validation_iter = 500

    config.log_config.group = "my_wan_14b_sft_ar_df"
    return config

MyLoaderConfig = L(VideoWDSLoader)(
    datatags=["WDS:mydata/db1"],
    batch_size=1,
    key_map={"real": "mp4", "condition": "txt"},
    presets_map={"neg_condition": "neg_prompt_wan"},
    sequence_length=81,
    img_size=(832,480),
    num_workers=1,
)

MyLoaderConfig loads mp4 and txt data. I disabled the preprocessors since I would like to load raw video instead of video latents as input. My run script is as simple as this:

torchrun --nproc_per_node=2 train.py --config fastgen/configs/experiments/WanT2V/my_config_sft_causal_14b.py - trainer.fsdp=True log_config.name=test_fsdp

I tried both trainer.ddp=True and trainer.fsdp=True, but both hit an OOM error. I am using 95 GB H100 GPUs, so memory should not be an issue. I also tried the supervised fine-tuning config for the 14B model and still hit OOM.
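For reference, here is a rough back-of-envelope estimate of the training-state memory for a 14B-parameter model. It assumes bf16 params/grads and Adam with fp32 master weights and two fp32 moment buffers (a common mixed-precision default; I am not certain this is exactly what this codebase does), and it ignores activations, which add more on top:

```python
# Back-of-envelope per-GPU memory for training a 14B model.
# Assumption: bf16 params and grads, Adam keeping fp32 master
# weights plus two fp32 moment buffers (exp_avg, exp_avg_sq).
n_params = 14e9
bytes_bf16, bytes_fp32 = 2, 4

params = n_params * bytes_bf16      # 28 GB
grads = n_params * bytes_bf16       # 28 GB
optim = n_params * bytes_fp32 * 3   # master + exp_avg + exp_avg_sq = 168 GB

total = params + grads + optim      # ~224 GB before activations

print(f"DDP (full copy per GPU):   {total / 1e9:.0f} GB")
print(f"FSDP sharded over 2 GPUs:  {total / 2 / 1e9:.0f} GB")
```

Under these assumptions, DDP needs ~224 GB per GPU and even FSDP sharded over 2 GPUs needs ~112 GB before activations, both above 95 GB, which would be consistent with the OOM I am seeing.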

Could you please help me with this?

Thanks!
