
feat(train): PEFT (LoRA / LoHa) adapter support for finetune CLI #1296

Open
medyas wants to merge 1 commit into SWivid:main from medyas:feat/peft-adapters


@medyas medyas commented May 17, 2026

Closes #621.

Adds optional PEFT (parameter-efficient finetuning) support via the peft library. Defaults preserve current behavior (--peft_method none), so existing full-finetune users are not affected.

What

CLI (src/f5_tts/train/finetune_cli.py) — 4 new flags:

--peft_method [none|lora|loha]     # default: none
--peft_rank N                       # default: 8
--peft_alpha N                      # default: 8
--peft_target_modules "a,b,c"       # default: DiT/UNetT attn+FFN linears

Default targets: to_q,to_k,to_v,to_out.0,ff.ff.0.0,ff.ff.2. Excludes (NaN-safe): AdaLN-Zero modulation (attn_norm,ff_norm), final zero-init linears (norm_out,proj_out), and conditioning paths (time_embed,text_embed,input_embed,long_skip_connection).
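The four flags above could be wired up roughly as follows. This is a minimal sketch, not the code in finetune_cli.py: the helper names `build_parser` and `peft_kwargs` are hypothetical, and the integer types for rank and alpha are an assumption. It only maps the flags onto the keyword arguments that peft's `LoraConfig` (`r`, `lora_alpha`, `target_modules`) and `LoHaConfig` (`r`, `alpha`, `target_modules`) accept, without importing peft.

```python
import argparse

# Default targets as listed in this PR description.
DEFAULT_TARGETS = "to_q,to_k,to_v,to_out.0,ff.ff.0.0,ff.ff.2"


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser exposing only the four PEFT flags."""
    parser = argparse.ArgumentParser(description="PEFT flag sketch")
    parser.add_argument("--peft_method", choices=["none", "lora", "loha"], default="none")
    parser.add_argument("--peft_rank", type=int, default=8)
    parser.add_argument("--peft_alpha", type=int, default=8)
    parser.add_argument("--peft_target_modules", type=str, default=DEFAULT_TARGETS)
    return parser


def peft_kwargs(args):
    """Return (config class name, kwargs), or None when PEFT is disabled."""
    if args.peft_method == "none":
        return None
    # Split the comma-separated list and drop empty entries and stray spaces.
    targets = [t.strip() for t in args.peft_target_modules.split(",") if t.strip()]
    if args.peft_method == "lora":
        return "LoraConfig", {
            "r": args.peft_rank,
            "lora_alpha": args.peft_alpha,
            "target_modules": targets,
        }
    # LoHaConfig names the scaling argument `alpha` rather than `lora_alpha`.
    return "LoHaConfig", {
        "r": args.peft_rank,
        "alpha": args.peft_alpha,
        "target_modules": targets,
    }
```

With no flags, `peft_kwargs` returns None, which corresponds to the unchanged full-finetune path.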

Trainer (src/f5_tts/model/trainer.py)

  • New kwarg peft_config: PeftConfig | None. When set: self.model = get_peft_model(self.model, peft_config) before accelerator.prepare().
  • EMA is disabled under PEFT (tracking EMA across base + adapter bloats state and degrades adapter learning). All EMA accesses are guarded by self.ema_model is not None.
  • Optimizer is now built from self.model.parameters() instead of model.parameters(), so it sees the wrapped model with its frozen base. The old call used the unwrapped local variable, which was a latent bug even before PEFT; the same change fixes both.
  • save_checkpoint additionally writes adapter-only safetensors via save_pretrained() to checkpoint_path/adapter_<step>/ for portable, small (~MB) artifacts.
  • load_checkpoint tolerates missing ema_model_state_dict and uses strict=False when PEFT is enabled (base key prefix differs after wrap).
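The checkpoint conventions in the last two bullets can be sketched in a few lines. This is an illustration only: `adapter_dir` and `plan_load` are hypothetical helpers, not functions in trainer.py. The key name `ema_model_state_dict` and the `adapter_<step>` directory layout come from the description above.

```python
from pathlib import Path


def adapter_dir(checkpoint_path, step):
    """Directory for the adapter-only artifact: checkpoint_path/adapter_<step>/."""
    return Path(checkpoint_path) / f"adapter_{step}"


def plan_load(checkpoint, peft_enabled):
    """Decide how to load a checkpoint dict.

    Returns (ema_state_or_None, strict). A missing EMA entry is tolerated,
    and strict loading is relaxed when PEFT is enabled because the wrapped
    model's key prefix differs from the saved one.
    """
    ema_state = checkpoint.get("ema_model_state_dict")  # None when absent
    strict = not peft_enabled
    return ema_state, strict
```

A caller would then restore EMA only when the first return value is not None, mirroring the self.ema_model is not None guards.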

CLI pre-load path

When --peft_method != none AND --finetune, the pretrain checkpoint is pre-loaded into the bare CFM before adapter wrap (strict=False). This avoids the Trainer's resume path trying to load un-prefixed pretrained .safetensors into a wrapped (base_model.model.…) model.
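To see why the pre-load has to happen before the wrap, compare key sets. The sketch below uses made-up parameter names (`transformer.proj_out.weight` and its bias are hypothetical stand-ins) and models only the `base_model.model.` prefix mentioned above. Any further nesting PEFT adds inside adapted layers is not modeled.

```python
PEFT_PREFIX = "base_model.model."


def diff_keys(checkpoint_keys, model_keys):
    """Return (missing, unexpected) the way a strict state-dict load would report them."""
    ckpt, model = set(checkpoint_keys), set(model_keys)
    missing = sorted(model - ckpt)      # model expects these, checkpoint lacks them
    unexpected = sorted(ckpt - model)   # checkpoint has these, model does not
    return missing, unexpected


# Hypothetical un-prefixed keys, as in a pretrained .safetensors file.
bare_keys = ["transformer.proj_out.bias", "transformer.proj_out.weight"]
# The same parameters as seen through the wrapped model.
wrapped_keys = [PEFT_PREFIX + k for k in bare_keys]

print(diff_keys(bare_keys, bare_keys))     # bare CFM: nothing missing or unexpected
print(diff_keys(bare_keys, wrapped_keys))  # wrapped model: every key mismatches
```

Loading into the bare CFM first means the prefixed view never has to be reconciled with the un-prefixed file.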

Dependency

New optional extra: pip install -e ".[peft]".

Why this design

  • Defaults preserved: every existing finetune command behaves identically. Only users who pass --peft_method lora|loha see new code paths.
  • EMA skip over EMA-wrap: EMA on adapters is poorly defined and the marginal benefit doesn't justify the state-management surface area in a first cut. Easy to add as a follow-up.
  • Adapter-only artifact: keeps Hub-friendly small files alongside the existing full .pt. Users can ship just the adapter (~MB) once trained.

Validation

  • Ruff lint + format clean (matching .pre-commit-config.yaml v0.11.2).
  • AST parse of both modified files.
  • CPU smoke tests (LoRA + LoHa forward, trainable-param count band, LoHa zero-init identity contract, adapter save/reload roundtrip): 5/5 pass.

GPU smoke (full pipeline at scale) is queued as a follow-up. Happy to gate the merge on it landing if reviewers prefer.

Recommended target_modules / rank for F5-TTS DiT

  • LoRA: r=16-32, alpha=r.
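  • LoHa: r=4-8, alpha=r. The LyCORIS rule of thumb is r <= sqrt(hidden_dim); F5TTS_Base has dim=1024, so the bound is r <= 32 and r=8 sits comfortably inside it. The Hadamard product lets the effective rank reach up to r² (64 at r=8) at about the parameter count of LoRA r=16, or roughly half that of LoRA r=32.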
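The rank guidance above can be checked with some quick arithmetic. This sketch assumes the standard formulations: LoRA adds two factors per linear, LoHa adds two such pairs combined by a Hadamard product, and biases are ignored. The 1024×1024 layer shape is an assumption chosen to match dim=1024, and the function names are illustrative.

```python
def lora_params(d_in, d_out, r):
    """Trainable parameters LoRA adds to one d_in x d_out linear (A: r x d_in, B: d_out x r)."""
    return r * (d_in + d_out)


def loha_params(d_in, d_out, r):
    """LoHa uses two low-rank pairs, so twice the LoRA count at the same r."""
    return 2 * r * (d_in + d_out)


def loha_rank_bound(d_in, d_out, r):
    """Upper bound on the rank of the Hadamard product of two rank-r matrices."""
    return min(r * r, d_in, d_out)


d = 1024  # assumed square projection, matching F5TTS_Base dim=1024
for name, count in [
    ("LoRA r=16", lora_params(d, d, 16)),
    ("LoRA r=32", lora_params(d, d, 32)),
    ("LoHa r=8", loha_params(d, d, 8)),
]:
    print(f"{name}: {count} params")
print("LoHa r=8 rank bound:", loha_rank_bound(d, d, 8))
```

Under these assumptions LoHa at r=8 costs the same as LoRA at r=16 while allowing a rank of up to 64.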

Both honor the same target/exclude conventions to avoid touching AdaLN-Zero modulation or the final zero-init linears. Adapting those breaks F5-TTS's init contract and produces NaN within the first few updates.
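A rough way to sanity-check a target list is to emulate suffix matching on module names. The sketch assumes that a list-valued target_modules matches a module when its name equals a target or ends with "." plus the target, which is how PEFT's list-based matching is understood to work. The module paths below are hypothetical examples, not names read from the F5-TTS model.

```python
from __future__ import annotations

DEFAULT_TARGETS = ["to_q", "to_k", "to_v", "to_out.0", "ff.ff.0.0", "ff.ff.2"]


def is_target(module_name: str, targets: list[str]) -> bool:
    """True when the module name equals a target or ends with '.<target>'."""
    return any(module_name == t or module_name.endswith("." + t) for t in targets)


# Hypothetical module paths for illustration only.
candidates = [
    "transformer.transformer_blocks.0.attn.to_q",
    "transformer.transformer_blocks.0.attn.to_out.0",
    "transformer.transformer_blocks.0.ff.ff.0.0",
    "transformer.transformer_blocks.0.attn_norm.linear",
    "transformer.time_embed.time_mlp.0",
    "transformer.proj_out",
]

for name in candidates:
    print(name, "adapted" if is_target(name, DEFAULT_TARGETS) else "skipped")
```

Because matching is by suffix, a short target such as "0" would catch far more than intended, which is why the defaults spell out to_out.0 and ff.ff.0.0 in full.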

Issue SWivid#621 requested LoRA support. This adds a generic PEFT path via the
Hugging Face peft library, defaulting to none, so existing full-finetune
behavior is unchanged.

Trainer (src/f5_tts/model/trainer.py)
  - New kwarg peft_config (peft.PeftConfig | None). When set, wrap self.model
    with get_peft_model() before accelerator.prepare().
  - EMA is skipped under PEFT (tracking ema across base+adapter bloats state
    and degrades adapter learning). All ema_model accesses guarded.
  - Optimizer changed to use self.model.parameters() (was model.parameters()
    on the unwrapped variable, which never reflected PEFT's frozen base).
  - save_checkpoint also writes adapter-only safetensors via save_pretrained
    to checkpoint_path/adapter_<step>/ for portable, small (~MB) artifacts.
  - load_checkpoint tolerates a missing ema_model_state_dict and loads with
    strict=False when PEFT is enabled (base key prefix differs).

CLI (src/f5_tts/train/finetune_cli.py)
  - --peft_method [none|lora|loha]  (default: none)
  - --peft_rank        (default: 8)
  - --peft_alpha       (default: 8)
  - --peft_target_modules  (default: DiT/UNetT attn+FFN linears)
  - When --peft_method != none AND --finetune, the pretrain checkpoint is
    pre-loaded into the bare CFM before adapter wrap (strict=False) so the
    Trainer's load_checkpoint path doesn't see the un-prefixed pretrained
    .safetensors and mismatch against the wrapped (base_model.model.…) keys.
  - Excludes AdaLN-Zero modulation (attn_norm, ff_norm), final zero-init
    linears (norm_out, proj_out), and conditioning paths (time_embed,
    text_embed, input_embed, long_skip_connection); adapting them breaks
    the F5-TTS init contract (NaN within the first steps).

pyproject.toml
  - New optional extra: pip install -e ".[peft]"

Closes SWivid#621
Development

Successfully merging this pull request may close these issues.

LoRA support