Hi,
Thanks for this great project summarizing so many fast generation methods! I am trying to train an AR 14B model, and I am currently using this modified config file:
```python
def create_config():
    config = config_sft_causal_default.create_config()
    config.model_class = L(CausalSFTModel)(config=None)
    config.model.fsdp_meta_init = True
    config.trainer.logging_iter = 100
    config.model.net_optimizer.lr = 5e-5
    config.model.guidance_scale = 5.0
    config.model.student_sample_steps = 50
    config.model.sample_t_cfg.time_dist_type = "uniform"
    config.model.sample_t_cfg.min_t = 0.001
    config.model.sample_t_cfg.max_t = 0.999
    config.model.precision = "bfloat16"
    config.model.precision_fsdp = "bfloat16"
    # VAE compress ratio for WAN: (1+T/4) * H / 8 * W / 8
    config.model.input_shape = [16, 21, 60, 104]  # cthw
    config.model.net = CausalWan_14B_Config
    # config.model.enable_preprocessors = False
    config.dataloader_train = MyLoaderConfig
    config.dataloader_train.batch_size = 1
    # 480p (832x480) resolution
    config.dataloader_train.img_size = (config.model.input_shape[-1] * 8, config.model.input_shape[-2] * 8)
    config.dataloader_train.sequence_length = (config.model.input_shape[1] - 1) * 4 + 1
    config.trainer.max_iter = 5000
    config.trainer.save_ckpt_iter = 500
    config.trainer.validation_iter = 500
    config.log_config.group = "my_wan_14b_sft_ar_df"
    return config

MyLoaderConfig = L(VideoWDSLoader)(
    datatags=["WDS:mydata/db1"],
    batch_size=1,
    key_map={"real": "mp4", "condition": "txt"},
    presets_map={"neg_condition": "neg_prompt_wan"},
    sequence_length=81,
    img_size=(832, 480),
    num_workers=1,
)
```
MyLoaderConfig loads mp4 and txt entries. I disabled the preprocessors since I would like to load raw video instead of video latents as input. My run script is as simple as this:
```shell
torchrun --nproc_per_node=2 train.py --config fastgen/configs/experiments/WanT2V/my_config_sft_causal_14b.py - trainer.fsdp=True log_config.name=test_fsdp
```
I tried both trainer.ddp=True and trainer.fsdp=True, but both hit an OOM error. I am using 95GB H100s, so memory should not be an issue. I also tried the supervised fine-tuning config for the 14B model and still hit the OOM issue.
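For context, a back-of-the-envelope estimate (a rough sketch, assuming a 14B-parameter net trained with Adam, bf16 weights/grads, and fp32 optimizer states, before counting any activations) of why this might OOM even on 95GB cards:

```python
# Rough training-state memory for a 14B model, excluding activations.
# Assumed (hypothetical, for illustration): bf16 params + grads, Adam with
# fp32 master weights and two fp32 moment buffers.
params = 14e9

bytes_total = params * (
    2      # bf16 weights
    + 2    # bf16 gradients
    + 4    # fp32 master weights
    + 4    # fp32 Adam first moment (m)
    + 4    # fp32 Adam second moment (v)
)
gib = bytes_total / 2**30
print(f"{gib:.0f} GiB of total training state")           # ~209 GiB
print(f"{gib / 2:.0f} GiB per GPU if fully sharded over 2 GPUs")
```

Under these assumptions, plain DDP replicates the full ~209 GiB on each GPU, and even fully sharded FSDP over 2 GPUs leaves ~104 GiB per GPU of state alone, which already exceeds 95GB before activations. So OOM on 2 GPUs may be expected unless more GPUs, CPU offload, or activation checkpointing are used; I may be missing something, though.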
Could you please help me with this?
Thanks!