feat: opt-in perf — selective GC, CFG-Zero*, torch.compile#1297
Open
medyas wants to merge 1 commit into
Open
Conversation
Three orthogonal perf wins as opt-in flags. All defaults preserved.
1) Selective gradient checkpointing (dit.py, mmdit.py)
- New kwarg `gc_checkpoint_interval: int = 1` on DiT and MMDiT.
- When `checkpoint_activations=True`, only ckpt every Nth transformer
block. interval=1 (default) = prior behavior (every block);
interval=2 = ~50% activation mem at ~10% throughput cost; often net
positive once batch size is doubled.
- finetune_cli flag `--gc_checkpoint_interval N` (default 0 = off):
when >0, sets checkpoint_activations=True AND the interval. Guarded
via inspect.signature so passing it with UNetT (E2TTS_Base) warns
and ignores rather than crashing.
2) CFG-Zero* (cfm.py, arxiv 2503.18886)
- Drop-in, inference-time only, no retrain.
- `cfg_zero_init_steps: int = 0`: skip the first N solver steps —
fn() returns zeros, no DiT forward. Composes with EPSS to drop the
cheapest early-noise calls. Smoke shows forwards 4 → 2 at
zero_init_steps=2 / steps=4.
- `cfg_zero_star_velocity: bool = False`: per-step projection scalar
α* = <pred, null_pred> / ||null_pred||² used in place of the fixed
cfg_strength scalar in the CFG formula. Authors report fidelity
uptick + freedom to skip early CFG entirely.
3) torch.compile (trainer.py)
- New kwarg `torch_compile_mode: str | None = None` ('default',
'reduce-overhead', or 'max-autotune'). Applied to self.model BEFORE
accelerator.prepare so DDP wraps the compiled module.
- Bumps `torch._dynamo.config.cache_size_limit` to at least 64 to
absorb variable mel-frame seq-lens triggering recompiles.
- CLI flag `--torch_compile <mode>` (default None).
Note: when both this and PR SWivid#1296 (PEFT) land, callers should pair
torch.compile with peft.utils.hotswap.prepare_model_for_compiled_hotswap
to avoid recompile on adapter swap. That wiring is deferred to the
PEFT PR.
Stacked impact (no retrain): selective GC + torch.compile on training,
CFG-Zero* on inference. With existing EPSS already merged the realistic
end-to-end ceiling vs original F5-TTS is ~10x free inference + 1.5-1.8x
free train wall.
Validation: AST + ruff (0.11.2) + format checks clean. CPU smoke tests
(5/5 pass): forward at gc_intervals {0,1,2,4}, grad flow under selective
GC, cfg_zero_init forward-call reduction, cfg_zero_star_velocity output
finite, CLI flag parsing.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Three orthogonal, opt-in perf improvements. All defaults preserved — existing users see zero behavior change.
1. Selective gradient checkpointing —
DiT/MMDiTNew kwarg
gc_checkpoint_interval: int = 1on both backbones. Whencheckpoint_activations=True, only checkpoint every Nth transformer block.interval=1(default) preserves the current full-GC behavior.interval=2cuts ~50% of the activation memory at ~10% throughput cost — usually a net win once batch size is doubled.CLI:
--gc_checkpoint_interval N(default0= disabled). When>0, setscheckpoint_activations=Trueand the interval. Guarded viainspect.signatureso passing it withE2TTS_Base(UNetT — nocheckpoint_activationskwarg) warns and ignores rather than crashing.2. CFG-Zero* —
cfm.sample()arXiv 2503.18886. Drop-in inference-time, no retrain.
cfg_zero_init_steps: int = 0: skip the first N solver steps —fn()returns zeros, no DiT forward. Composes with EPSS to drop the cheapest early-noise calls. Smoke shows forwards4 → 2atzero_init_steps=2/steps=4.cfg_zero_star_velocity: bool = False: per-step projection scalarα* = <pred, null_pred> / ||null_pred||²used in place of the fixedcfg_strengthscalar in the CFG formula. Authors report fidelity uptick + freedom to skip early CFG entirely.3.
torch.compileintegration —TrainerNew Trainer kwarg
torch_compile_mode: str | None = None('default','reduce-overhead', or'max-autotune'). Applied toself.modelBEFOREaccelerator.prepareso DDP wraps the compiled module. Bumpstorch._dynamo.config.cache_size_limitto ≥64 to absorb variable mel-frame seq-lens triggering recompiles.CLI:
--torch_compile <mode>(defaultNone).Note: when paired with PR #1296 (PEFT), callers should add
peft.utils.hotswap.prepare_model_for_compiled_hotswap(model, target_rank=R)betweenget_peft_modelandtorch.compileto avoid recompile on adapter swap. That wiring is deferred to the PEFT PR.Stacked impact
All three opt-in, no retrain. With EPSS (already merged), the realistic end-to-end ceiling vs original F5-TTS is approximately:
use_epss=True)torch.compile reduce-overheadinterval=2+ 2× batch on trainingInference free ceiling: ~7-9× vs original. Train wall: ~1.5-1.8×.
Validation
0.11.2lint + format checks clean.tests_perf_smoke.py, 5/5):gc_checkpoint_interval ∈ {0, 1, 2, 4}.cfg_zero_init_stepsreduces DiT forwards as expected.cfg_zero_star_velocityoutput finite.Source pointers
These are intentionally three small, orthogonal additions. Happy to split into separate PRs if reviewers prefer.