Skip to content

feat: opt-in perf — selective GC, CFG-Zero*, torch.compile#1297

Open
medyas wants to merge 1 commit into
SWivid:mainfrom
medyas:feat/perf-improvements
Open

feat: opt-in perf — selective GC, CFG-Zero*, torch.compile#1297
medyas wants to merge 1 commit into
SWivid:mainfrom
medyas:feat/perf-improvements

Conversation

@medyas
Copy link
Copy Markdown

@medyas medyas commented May 17, 2026

Three orthogonal, opt-in perf improvements. All defaults preserved — existing users see zero behavior change.

1. Selective gradient checkpointing — DiT / MMDiT

New kwarg gc_checkpoint_interval: int = 1 on both backbones. When checkpoint_activations=True, only checkpoint every Nth transformer block. interval=1 (default) preserves the current full-GC behavior. interval=2 cuts ~50% of the activation memory at ~10% throughput cost — usually a net win once batch size is doubled.

CLI: --gc_checkpoint_interval N (default 0 = disabled). When >0, sets checkpoint_activations=True and the interval. Guarded via inspect.signature so passing it with E2TTS_Base (UNetT — no checkpoint_activations kwarg) warns and ignores rather than crashing.

2. CFG-Zero* — cfm.sample()

arXiv 2503.18886. Drop-in inference-time, no retrain.

  • cfg_zero_init_steps: int = 0: skip the first N solver steps — fn() returns zeros, no DiT forward. Composes with EPSS to drop the cheapest early-noise calls. Smoke shows forwards 4 → 2 at zero_init_steps=2 / steps=4.
  • cfg_zero_star_velocity: bool = False: per-step projection scalar α* = <pred, null_pred> / ||null_pred||² used in place of the fixed cfg_strength scalar in the CFG formula. Authors report fidelity uptick + freedom to skip early CFG entirely.

3. torch.compile integration — Trainer

New Trainer kwarg torch_compile_mode: str | None = None ('default', 'reduce-overhead', or 'max-autotune'). Applied to self.model BEFORE accelerator.prepare so DDP wraps the compiled module. Bumps torch._dynamo.config.cache_size_limit to ≥64 to absorb variable mel-frame seq-lens triggering recompiles.

CLI: --torch_compile <mode> (default None).

Note: when paired with PR #1296 (PEFT), callers should add peft.utils.hotswap.prepare_model_for_compiled_hotswap(model, target_rank=R) between get_peft_model and torch.compile to avoid recompile on adapter swap. That wiring is deferred to the PEFT PR.

Stacked impact

All three opt-in, no retrain. With EPSS (already merged), the realistic end-to-end ceiling vs original F5-TTS is approximately:

Layer Factor
EPSS (already in main, use_epss=True)
CFG-Zero* (zero-init 1-2 steps of 7) 1.15-1.25×
torch.compile reduce-overhead 1.5×
Selective GC interval=2 + 2× batch on training 1.5-1.8× train wall

Inference free ceiling: ~7-9× vs original. Train wall: ~1.5-1.8×.

Validation

  • AST parse + ruff 0.11.2 lint + format checks clean.
  • CPU smoke tests (tests_perf_smoke.py, 5/5):
    • DiT forward at gc_checkpoint_interval ∈ {0, 1, 2, 4}.
    • Grad flow under selective GC (81 grad tensors, finite, loss ≈ 2.01).
    • cfg_zero_init_steps reduces DiT forwards as expected.
    • cfg_zero_star_velocity output finite.
    • CLI flag parsing.
  • GPU smoke (Kaggle) queued as a follow-up; happy to gate merge on it.

Source pointers

These are intentionally three small, orthogonal additions. Happy to split into separate PRs if reviewers prefer.

Three orthogonal perf wins as opt-in flags. All defaults preserved.

1) Selective gradient checkpointing (dit.py, mmdit.py)
   - New kwarg `gc_checkpoint_interval: int = 1` on DiT and MMDiT.
   - When `checkpoint_activations=True`, only ckpt every Nth transformer
     block. interval=1 (default) = prior behavior (every block);
     interval=2 = ~50% activation mem at ~10% throughput cost; often net
     positive once batch size is doubled.
   - finetune_cli flag `--gc_checkpoint_interval N` (default 0 = off):
     when >0, sets checkpoint_activations=True AND the interval. Guarded
     via inspect.signature so passing it with UNetT (E2TTS_Base) warns
     and ignores rather than crashing.

2) CFG-Zero* (cfm.py, arxiv 2503.18886)
   - Drop-in, inference-time only, no retrain.
   - `cfg_zero_init_steps: int = 0`: skip the first N solver steps —
     fn() returns zeros, no DiT forward. Composes with EPSS to drop the
     cheapest early-noise calls. Smoke shows forwards 4 → 2 at
     zero_init_steps=2 / steps=4.
   - `cfg_zero_star_velocity: bool = False`: per-step projection scalar
     α* = <pred, null_pred> / ||null_pred||² used in place of the fixed
     cfg_strength scalar in the CFG formula. Authors report fidelity
     uptick + freedom to skip early CFG entirely.

3) torch.compile (trainer.py)
   - New kwarg `torch_compile_mode: str | None = None` ('default',
     'reduce-overhead', or 'max-autotune'). Applied to self.model BEFORE
     accelerator.prepare so DDP wraps the compiled module.
   - Bumps `torch._dynamo.config.cache_size_limit` to at least 64 to
     absorb variable mel-frame seq-lens triggering recompiles.
   - CLI flag `--torch_compile <mode>` (default None).

   Note: when both this and PR SWivid#1296 (PEFT) land, callers should pair
   torch.compile with peft.utils.hotswap.prepare_model_for_compiled_hotswap
   to avoid recompile on adapter swap. That wiring is deferred to the
   PEFT PR.

Stacked impact (no retrain): selective GC + torch.compile on training,
CFG-Zero* on inference. With existing EPSS already merged the realistic
end-to-end ceiling vs original F5-TTS is ~10x free inference + 1.5-1.8x
free train wall.

Validation: AST + ruff (0.11.2) + format checks clean. CPU smoke tests
(5/5 pass): forward at gc_intervals {0,1,2,4}, grad flow under selective
GC, cfg_zero_init forward-call reduction, cfg_zero_star_velocity output
finite, CLI flag parsing.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants