feat(train): PEFT (LoRA / LoHa) adapter support for finetune CLI #1296
Open
medyas wants to merge 1 commit into
Issue SWivid#621 requested LoRA support. This adds a generic PEFT path via the Hugging Face peft library, defaulting to none, so existing full-finetune behavior is unchanged.

Trainer (src/f5_tts/model/trainer.py)
- New kwarg peft_config (peft.PeftConfig | None). When set, wrap self.model with get_peft_model() before accelerator.prepare().
- EMA is skipped under PEFT (tracking EMA across base + adapter bloats state and degrades adapter learning). All ema_model accesses are guarded.
- Optimizer changed to use self.model.parameters() (it was model.parameters() on the unwrapped variable, which never reflected PEFT's frozen base).
- save_checkpoint also writes adapter-only safetensors via save_pretrained to checkpoint_path/adapter_<step>/ for portable, small (~MB) artifacts.
- load_checkpoint tolerates a missing ema_model_state_dict and loads with strict=False when PEFT is enabled (base prefix differs).

CLI (src/f5_tts/train/finetune_cli.py)
- --peft_method [none|lora|loha] (default: none)
- --peft_rank (default: 8)
- --peft_alpha (default: 8)
- --peft_target_modules (default: DiT/UNetT attn+FFN linears)
- When --peft_method != none AND --finetune, the pretrain checkpoint is pre-loaded into the bare CFM before adapter wrap (strict=False), so the Trainer's load_checkpoint path doesn't see the un-prefixed pretrained .safetensors and mismatch against the wrapped (base_model.model.…) keys.
- Excludes AdaLN-Zero modulation + final zero-init linears (norm/proj_out /time/text/input_embed/long_skip): adapting them breaks the F5-TTS init contract (NaN within first steps).

pyproject.toml
- New optional extra: pip install -e ".[peft]"

Closes SWivid#621
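For readers who want to see the flag surface at a glance, here is a minimal sketch of how the four flags could be declared with argparse. It is not the code from this PR: the int types, the nargs="+" form for --peft_target_modules, the help strings, and the helper name needs_preload are all assumptions made for illustration. Only the flag names, the choices, and the defaults come from the description.

```python
import argparse

# Default targets as listed in the PR description (attention + FFN linears).
DEFAULT_TARGETS = ["to_q", "to_k", "to_v", "to_out.0", "ff.ff.0.0", "ff.ff.2"]


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser holding only the PEFT-related flags."""
    parser = argparse.ArgumentParser(description="PEFT flags (illustrative sketch)")
    parser.add_argument("--finetune", action="store_true",
                        help="start from a pretrained checkpoint")
    parser.add_argument("--peft_method", choices=["none", "lora", "loha"],
                        default="none", help="adapter type; none = full finetune")
    # int types are an assumption; the description only states the default of 8.
    parser.add_argument("--peft_rank", type=int, default=8)
    parser.add_argument("--peft_alpha", type=int, default=8)
    # A space-separated list is an assumption about how targets are passed.
    parser.add_argument("--peft_target_modules", nargs="+",
                        default=list(DEFAULT_TARGETS))
    return parser


def needs_preload(args: argparse.Namespace) -> bool:
    """Pre-load the pretrained weights into the bare model before wrapping
    only when an adapter is requested AND --finetune is set."""
    return args.peft_method != "none" and args.finetune


default_args = build_parser().parse_args([])
loha_args = build_parser().parse_args(
    ["--finetune", "--peft_method", "loha", "--peft_rank", "4"]
)
```

With no PEFT flags the parsed namespace keeps --peft_method at none, so needs_preload returns False and the existing full-finetune path is the only one exercised.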
Closes #621.
Adds optional PEFT (parameter-efficient finetuning) support via the peft library. Defaults preserve current behavior (--peft_method none), so existing full-finetune users are not affected.

What
CLI (src/f5_tts/train/finetune_cli.py): 4 new flags.

- --peft_method [none|lora|loha] (default: none)
- --peft_rank (default: 8)
- --peft_alpha (default: 8)
- --peft_target_modules (default: DiT/UNetT attn+FFN linears)

Default targets: to_q, to_k, to_v, to_out.0, ff.ff.0.0, ff.ff.2. Excludes (NaN-safe): AdaLN-Zero modulation (attn_norm, ff_norm), final zero-init linears (norm_out, proj_out), and conditioning paths (time_embed, text_embed, input_embed, long_skip_connection).

Trainer (src/f5_tts/model/trainer.py)

- New kwarg peft_config: PeftConfig | None. When set: self.model = get_peft_model(self.model, peft_config) before accelerator.prepare().
- EMA is skipped under PEFT; every EMA access is guarded by self.ema_model is not None.
- Optimizer switched from model.parameters() to self.model.parameters() so it sees the wrapped, frozen-base view (this was also a latent bug pre-PEFT, fixed by the same change).
- save_checkpoint additionally writes adapter-only safetensors via save_pretrained() to checkpoint_path/adapter_<step>/ for portable, small (~MB) artifacts.
- load_checkpoint tolerates a missing ema_model_state_dict and uses strict=False when PEFT is enabled (base key prefix differs after wrap).

CLI pre-load path
When --peft_method != none AND --finetune, the pretrain checkpoint is pre-loaded into the bare CFM before adapter wrap (strict=False). This avoids the Trainer's resume path trying to load un-prefixed pretrained .safetensors into a wrapped (base_model.model.…) model.

Dependency
New optional extra: pip install -e ".[peft]".

Why this design
- Only runs that pass --peft_method lora|loha see new code paths.
- Adapter-only checkpoints are written alongside the full .pt. Users can ship just the adapter (~MB) once trained.

Validation
.pre-commit-config.yaml v0.11.2).

GPU smoke (full pipeline at scale) is queued as a follow-up. Happy to gate the merge on it landing if reviewers prefer.
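To make the default-target and exclude lists from the "What" section easier to check by eye, here is a small sketch of name-based selection. It mirrors, as far as I understand it, the matching PEFT applies when target_modules is a list (exact name, or a suffix that follows a dot). The module names in example_names are hypothetical stand-ins, not dumped from the real DiT, and is_targeted is an illustrative helper rather than a PEFT or F5-TTS function.

```python
# Default targets as listed in the PR description.
DEFAULT_TARGETS = ("to_q", "to_k", "to_v", "to_out.0", "ff.ff.0.0", "ff.ff.2")


def is_targeted(module_name: str, targets=DEFAULT_TARGETS) -> bool:
    """True if the name equals a target or ends with '.<target>'."""
    return any(
        module_name == target or module_name.endswith("." + target)
        for target in targets
    )


# Hypothetical module names, chosen only to exercise the matching rule.
example_names = [
    "transformer.transformer_blocks.0.attn.to_q",
    "transformer.transformer_blocks.0.attn.to_out.0",
    "transformer.transformer_blocks.0.ff.ff.0.0",
    "transformer.transformer_blocks.0.ff.ff.2",
    "transformer.transformer_blocks.0.attn_norm.linear",  # AdaLN-Zero modulation
    "transformer.norm_out.linear",                        # final zero-init path
    "transformer.proj_out",                               # final zero-init linear
    "transformer.time_embed.time_mlp.0",                  # conditioning path
]

selected = [name for name in example_names if is_targeted(name)]
skipped = [name for name in example_names if not is_targeted(name)]
```

Under these assumed names, only the four attention and FFN linears end up in selected. The modulation, output, and conditioning modules never match a default target, so they stay frozen without needing an explicit exclude list.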
Recommended target_modules / rank for F5-TTS DiT
- LoRA: r=16-32, alpha=r.
- LoHa: r=4-8, alpha=r (LyCORIS rule: r <= sqrt(hidden_dim); F5TTS_Base has dim=1024, so r=8 is a sane sweet spot; effective rank ≈ r² via the Hadamard product, with ~½ the LoRA params).

Both honor the same target/exclude conventions to avoid touching AdaLN-Zero modulation or final zero-init linears (which break F5-TTS's init contract and produce NaN within the first updates).
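The rank arithmetic behind these recommendations can be sketched in a few lines. This assumes the standard factorizations (LoRA as one rank-r pair, LoHa as the Hadamard product of two rank-r pairs) and a square 1024x1024 projection; real layer shapes in the model may differ (the FFN linears, for example), and the function names are illustrative only.

```python
import math


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """LoRA: delta_W = B @ A, with A of shape (r, d_in) and B of shape (d_out, r)."""
    return r * (d_in + d_out)


def loha_params(d_in: int, d_out: int, r: int) -> int:
    """LoHa: delta_W = (B1 @ A1) * (B2 @ A2) elementwise, i.e. two rank-r pairs."""
    return 2 * r * (d_in + d_out)


def loha_rank_bound(d_in: int, d_out: int, r: int) -> int:
    """Upper bound on rank(delta_W): rank(X * Y) <= rank(X) * rank(Y)."""
    return min(r * r, d_in, d_out)


dim = 1024                        # F5TTS_Base hidden dim, per the description
rank_ceiling = math.isqrt(dim)    # 32, from the r <= sqrt(hidden_dim) rule

loha_r8 = loha_params(dim, dim, 8)      # 32768 trainable params per layer
lora_r16 = lora_params(dim, dim, 16)    # 32768
lora_r32 = lora_params(dim, dim, 32)    # 65536
loha_r8_bound = loha_rank_bound(dim, dim, 8)  # 64
```

On this assumed square layer, LoHa at r=8 costs the same as LoRA at r=16 and half of LoRA at r=32, while its update can reach rank 64. That puts r=8 well under the sqrt(1024) = 32 ceiling.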