feat(diffusion): add Wan2.2 T2V-A14B two-stage finetuning support #2284

Open
linnanwang wants to merge 4 commits into main from wan22

@linnanwang
Contributor

@linnanwang linnanwang commented May 21, 2026

What does this PR do?

Adds end-to-end finetuning and inference support for Wan2.2-T2V-A14B, a two-stage text-to-video diffusion model whose denoising pipeline routes between a high-noise transformer and a low-noise transformer_2 across a configurable timestep boundary.

Changelog

  • NeMoAutoDiffusionPipeline.from_pretrained: new active_transformer kwarg ("transformer" | "transformer_2"); when set on a two-transformer pipeline the unused transformer is dropped before device placement / FSDP2 wrapping so only one ~14B model occupies GPU memory.
  • TrainDiffusionRecipe (recipes/diffusion/train.py): reads model.stage (high_noise | low_noise) and model.boundary_ratio (falls back to pipe.config.boundary_ratio); derives flow_matching.sigma_min / sigma_max from the stage + boundary so each stage only trains on its own noise range; threads active_transformer into the pipeline loader; suffixes the wandb run name with the stage.
  • examples/diffusion/finetune/wan2_2_t2v_flow.yaml: new finetune config — A14B hub path, stage knob, boundary_ratio: 0.875, bumped dp_size, explicit activation checkpointing.
  • examples/diffusion/generate/configs/generate_wan22.yaml: new inference config — A14B hub path, two optional checkpoint paths, guidance_scale_2, VAE cpu offload defaulted on.
  • examples/diffusion/generate/generate.py: load_checkpoint_into_pipeline accepts model.checkpoint_high_noise / model.checkpoint_low_noise (both optional, mutually exclusive with the legacy single model.checkpoint) and loads each into the matching pipe.transformer / pipe.transformer_2 attribute.
  • tools/diffusion/processors/wan.py: new Wan22Processor subclass registered as wan2.2; defaults to Wan-AI/Wan2.2-T2V-A14B-Diffusers, marks cache files with model_version: "wan2.2" so Wan2.1 and Wan2.2 caches can coexist.
  • tools/diffusion/preprocessing_multiprocess.py: wan2.2 added to the --processor choices.
  • tools/diffusion/processors/__init__.py: export Wan22Processor.
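
The stage handling described for TrainDiffusionRecipe can be sketched roughly as follows. This is an illustration only, not the code in recipes/diffusion/train.py: the helper names `stage_sigma_range` and `active_transformer_for` are hypothetical, and the sketch assumes sigma runs from 0.0 (clean) to 1.0 (pure noise), with the high-noise stage owning the range at or above the boundary.

```python
def stage_sigma_range(stage: str, boundary_ratio: float) -> tuple[float, float]:
    """Return (sigma_min, sigma_max) for one finetuning stage.

    Assumption: sigma lies in [0, 1] and the high-noise stage trains on
    sigmas at or above the boundary, the low-noise stage on those below it.
    """
    if not 0.0 < boundary_ratio < 1.0:
        raise ValueError(f"boundary_ratio must be in (0, 1), got {boundary_ratio!r}")
    if stage == "high_noise":
        return boundary_ratio, 1.0
    if stage == "low_noise":
        return 0.0, boundary_ratio
    raise ValueError(f"stage must be 'high_noise' or 'low_noise', got {stage!r}")


def active_transformer_for(stage: str) -> str:
    """Map a stage name to the pipeline attribute that gets trained."""
    mapping = {"high_noise": "transformer", "low_noise": "transformer_2"}
    if stage not in mapping:
        raise ValueError(f"unknown stage {stage!r}")
    return mapping[stage]


if __name__ == "__main__":
    for stage in ("high_noise", "low_noise"):
        print(stage, active_transformer_for(stage), stage_sigma_range(stage, 0.875))
```

With boundary_ratio: 0.875 from the example config, this sketch gives the high-noise stage the range (0.875, 1.0) and the low-noise stage (0.0, 0.875).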

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?

If you haven't finished some of the above items, you can still open a "Draft" PR.

Additional Information

  • Related to # (issue)
  • Workflow: preprocess data once with --processor wan2.2, run finetuning twice (model.stage: high_noise and model.stage: low_noise) with distinct checkpoint.checkpoint_dir per stage, then point generate_wan22.yaml at the two resulting consolidated checkpoint dirs (either or both optional — missing stages fall back to hub-pretrained weights).
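
For the inference step of this workflow, the checkpoint options could be resolved along these lines. This is a sketch under assumptions, not the logic in generate.py: `resolve_checkpoints` is a hypothetical helper, and routing the legacy single model.checkpoint to pipe.transformer is a guess rather than something the PR states.

```python
from typing import Dict, Optional


def resolve_checkpoints(
    checkpoint: Optional[str] = None,
    checkpoint_high_noise: Optional[str] = None,
    checkpoint_low_noise: Optional[str] = None,
) -> Dict[str, str]:
    """Map pipeline attribute names to the checkpoint dir to load into each.

    An attribute missing from the result keeps its hub-pretrained weights.
    """
    staged = {
        "transformer": checkpoint_high_noise,      # high-noise stage
        "transformer_2": checkpoint_low_noise,     # low-noise stage
    }
    staged = {name: path for name, path in staged.items() if path}

    if checkpoint and staged:
        raise ValueError(
            "model.checkpoint is mutually exclusive with "
            "model.checkpoint_high_noise / model.checkpoint_low_noise"
        )
    if checkpoint:
        # Assumption: the legacy single checkpoint targets pipe.transformer.
        return {"transformer": checkpoint}
    return staged


if __name__ == "__main__":
    # Only the low-noise stage was finetuned; the directory name is made up.
    print(resolve_checkpoints(checkpoint_low_noise="ckpt_low_noise"))
```

An empty result means neither stage was finetuned and both transformers stay at the hub-pretrained weights.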

Signed-off-by: linnan wang <linnanw@nvidia.com>

copy-pr-bot Bot commented May 21, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@pthombre
Contributor

/claude review

Comment on lines +372 to +393
@ProcessorRegistry.register("wan2.2")
class Wan22Processor(WanProcessor):
    """
    Processor for Wan2.2-T2V-A14B (two-stage) video model.

    Wan2.2 reuses the same ``AutoencoderKLWan`` VAE class and UMT5 text encoder
    as Wan2.1, but pulls VAE / text-encoder weights from the A14B hub. Cache
    files emitted by this processor record ``model_version: "wan2.2"`` so
    Wan2.1 and Wan2.2 caches remain unambiguous side-by-side.
    """

    @property
    def model_type(self) -> str:
        return "wan22"

    @property
    def model_version(self) -> str:
        return "wan2.2"

    @property
    def default_model_name(self) -> str:
        return "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
Contributor

No tests for Wan22Processor. The existing tests/unit_tests/diffusion_processors/test_wan.py covers WanProcessor properties and wan/wan2.1 registry entries. Please add analogous tests for the new class:

  • Wan22Processor().model_type == "wan22"
  • Wan22Processor().model_version == "wan2.2"
  • Wan22Processor().default_model_name == "Wan-AI/Wan2.2-T2V-A14B-Diffusers"
  • ProcessorRegistry.is_registered("wan2.2") returns True
  • ProcessorRegistry.get("wan2.2") returns a Wan22Processor instance
  • get_cache_data emits model_version: "wan2.2" (the whole point of the subclass)
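
A rough shape for the first five checks is sketched below. To keep it runnable on its own, it defines minimal stand-ins for ProcessorRegistry, WanProcessor and Wan22Processor; these are not the repository's implementations, and the real test module would import those names instead. The get_cache_data check is left out because its signature does not appear in this diff.

```python
# Stand-ins so this sketch runs on its own. In the real test module these
# three names would be imported from the repository instead.
class ProcessorRegistry:
    _processors = {}

    @classmethod
    def register(cls, name):
        def decorator(processor_cls):
            cls._processors[name] = processor_cls
            return processor_cls
        return decorator

    @classmethod
    def is_registered(cls, name):
        return name in cls._processors

    @classmethod
    def get(cls, name):
        # Returns an instance, as the review comment above expects.
        return cls._processors[name]()


class WanProcessor:
    """Placeholder base class; the real one handles VAE / text encoding."""


@ProcessorRegistry.register("wan2.2")
class Wan22Processor(WanProcessor):
    @property
    def model_type(self) -> str:
        return "wan22"

    @property
    def model_version(self) -> str:
        return "wan2.2"

    @property
    def default_model_name(self) -> str:
        return "Wan-AI/Wan2.2-T2V-A14B-Diffusers"


def test_wan22_properties():
    processor = Wan22Processor()
    assert processor.model_type == "wan22"
    assert processor.model_version == "wan2.2"
    assert processor.default_model_name == "Wan-AI/Wan2.2-T2V-A14B-Diffusers"


def test_wan22_registry():
    assert ProcessorRegistry.is_registered("wan2.2")
    assert isinstance(ProcessorRegistry.get("wan2.2"), Wan22Processor)
```

Both functions follow the pytest naming convention, so they could be dropped into a test module once the stand-ins are replaced by real imports.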

Comment on lines +191 to +216
def _select_active_transformer(pipe, active_transformer: str) -> None:
    """Keep only the chosen transformer on a two-transformer pipeline.

    Two-stage diffusion pipelines (Wan2.2 T2V-A14B) register both
    ``transformer`` (high-noise) and ``transformer_2`` (low-noise). Finetuning
    only needs one at a time. This helper swaps the chosen one into
    ``pipe.transformer`` and nulls the other so subsequent device placement,
    LoRA injection, and FSDP2 wrapping only touch the active model.

    Args:
        pipe: A diffusers pipeline that may expose ``transformer_2``.
        active_transformer: Either ``"transformer"`` or ``"transformer_2"``.

    Raises:
        ValueError: If ``active_transformer`` is unrecognized.
        AttributeError: If ``active_transformer="transformer_2"`` but the pipeline
            has no ``transformer_2`` attribute (model is not a two-stage variant).
    """
    if active_transformer not in ("transformer", "transformer_2"):
        raise ValueError(f"active_transformer must be 'transformer' or 'transformer_2', got {active_transformer!r}")

    has_t2 = getattr(pipe, "transformer_2", None) is not None
    if active_transformer == "transformer_2":
        if not has_t2:
            raise AttributeError(
                "active_transformer='transformer_2' requested but the loaded pipeline "
Contributor

No tests for _select_active_transformer. This function has several branches (invalid input → ValueError, missing transformer_2 → AttributeError, swapping transformer_2 into the transformer slot, dropping transformer_2 when transformer is selected). The existing test file (tests/unit_tests/_diffusers/test_auto_diffusion_pipeline.py) already has DummyPipeline infrastructure that would make these easy to cover.
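
The four branches could be covered along these lines. The diff excerpt above is truncated, so `select_active_transformer` here is a stand-in rebuilt from the docstring, not the PR's actual function body. `SimpleNamespace` is used in place of the DummyPipeline helpers, whose interface is not shown here.

```python
from types import SimpleNamespace


def select_active_transformer(pipe, active_transformer: str) -> None:
    """Stand-in reconstructed from the docstring; not the PR's real body."""
    if active_transformer not in ("transformer", "transformer_2"):
        raise ValueError(f"unrecognized active_transformer: {active_transformer!r}")
    has_t2 = getattr(pipe, "transformer_2", None) is not None
    if active_transformer == "transformer_2":
        if not has_t2:
            raise AttributeError("pipeline has no transformer_2")
        # Swap the low-noise model into the slot later stages operate on.
        pipe.transformer = pipe.transformer_2
    if has_t2:
        # Drop the second reference so only one model is placed on device.
        pipe.transformer_2 = None


def raises(exc_type, fn, *args) -> bool:
    """Return True if fn(*args) raises exc_type (tiny pytest.raises stand-in)."""
    try:
        fn(*args)
    except exc_type:
        return True
    return False


def test_invalid_name_raises_value_error():
    pipe = SimpleNamespace(transformer="high", transformer_2="low")
    assert raises(ValueError, select_active_transformer, pipe, "transformer_3")


def test_missing_transformer_2_raises_attribute_error():
    pipe = SimpleNamespace(transformer="high")
    assert raises(AttributeError, select_active_transformer, pipe, "transformer_2")


def test_transformer_2_is_swapped_in():
    pipe = SimpleNamespace(transformer="high", transformer_2="low")
    select_active_transformer(pipe, "transformer_2")
    assert pipe.transformer == "low"
    assert pipe.transformer_2 is None


def test_transformer_2_is_dropped_when_transformer_selected():
    pipe = SimpleNamespace(transformer="high", transformer_2="low")
    select_active_transformer(pipe, "transformer")
    assert pipe.transformer == "high"
    assert pipe.transformer_2 is None
```

In the real test file, `pytest.raises` and the existing DummyPipeline would likely replace the two stand-in helpers.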
