Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 107 additions & 20 deletions nemo_automodel/components/_peft/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ class PeftConfig:
dropout_position: Literal["pre", "post"] = "post"
lora_A_init: str = "xavier"
lora_dtype: Optional[torch.dtype] = None
use_memory_efficient_lora: bool = True
use_triton: bool = False
moe_rank_scaling: bool = False

Expand All @@ -72,6 +73,7 @@ def from_dict(cls, d: dict[str, Any]):
dropout_position=d.get("dropout_position", "post"),
lora_A_init=d.get("lora_A_init", "xavier"),
lora_dtype=d.get("lora_dtype", None),
use_memory_efficient_lora=d.get("use_memory_efficient_lora", True),
use_triton=d.get("use_triton", False),
moe_rank_scaling=d.get("moe_rank_scaling", False),
)
Expand Down Expand Up @@ -102,6 +104,7 @@ def __init__(
dropout_position="post",
lora_A_init_method="xavier",
lora_dtype=None,
use_memory_efficient_lora=True,
):
"""
LinearLora constructor.
Expand Down Expand Up @@ -138,6 +141,7 @@ def __init__(
dropout_position=dropout_position,
lora_A_init_method=lora_A_init_method,
lora_dtype=lora_dtype,
use_memory_efficient_lora=use_memory_efficient_lora,
)

@torch.no_grad
Expand Down Expand Up @@ -165,6 +169,7 @@ def _init_adapter(
dropout_position="post",
lora_A_init_method="xavier",
lora_dtype=None,
use_memory_efficient_lora=True,
):
"""
Adds LoRA weights to obj. Obj is either a LinearLoRA or an nn.Module (when monkey-patching).
Expand All @@ -182,6 +187,7 @@ def _init_adapter(
obj.dim = dim
obj.scale = alpha / dim
obj.use_dora = bool(use_dora)
obj.use_memory_efficient_lora = bool(use_memory_efficient_lora)

# Freezer
device = obj.weight.device
Expand Down Expand Up @@ -227,6 +233,22 @@ def _dora_weight_norm(self) -> torch.Tensor:
weight_norm = torch.linalg.norm(weight + self.scale * delta_w, dim=1).to(weight.dtype)
return weight_norm.detach()

def _should_use_memory_efficient_lora(self, x: torch.Tensor) -> bool:
"""Return whether this LoRA branch can use the custom autograd path."""
if not getattr(self, "use_memory_efficient_lora", False):
return False
if isinstance(x, DTensor):
return False
if isinstance(getattr(self.lora_A, "weight", None), DTensor):
return False
if isinstance(getattr(self.lora_B, "weight", None), DTensor):
return False
if torch.compiler.is_compiling():
return False
if HAS_TE and isinstance(getattr(self, "lora_A", None), transformer_engine.pytorch.Linear):
return False
return True

def forward(self, x):
"""
Forward pass through the original linear layer augmented with the LoRA pathway.
Expand Down Expand Up @@ -275,9 +297,21 @@ def forward(self, x):
# Apply scale before lora_B to keep lora_res as a Partial tensor.
# This allows both res and lora_res to remain Partial, so only one reduce-scatter is needed after addition.
# Multiplying after lora_B would convert Partial to Replicate, causing an extra reduce-scatter operation.
lora_res = self.lora_B(self.lora_A(x) * self.scale)
use_memory_efficient_lora = self._should_use_memory_efficient_lora(x)
if use_memory_efficient_lora:
if self.dropout_position == "pre" or not self.training or self.dropout_p == 0.0:
return LoRATritonFunction.apply(
x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, False, res
)
lora_res = LoRATritonFunction.apply(
x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, False
)
else:
lora_res = self.lora_B(self.lora_A(x) * self.scale)
if self.dropout_position == "post":
lora_res = F.dropout(lora_res, p=self.dropout_p, training=self.training)
if use_memory_efficient_lora:
return lora_res.add_(res)
return res + lora_res

if getattr(self, "lora_magnitude", None) is None:
Expand Down Expand Up @@ -357,9 +391,18 @@ def forward(self, x):

if self.dropout_position == "pre":
x = F.dropout(x, p=self.dropout_p, training=self.training)
lora_res = LoRATritonFunction.apply(x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype)
if self.use_memory_efficient_lora:
if self.dropout_position == "pre" or not self.training or self.dropout_p == 0.0:
return LoRATritonFunction.apply(
x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, True, res
)
lora_res = LoRATritonFunction.apply(x, self.lora_A.weight, self.lora_B.weight, self.scale, x.dtype, True)
else:
lora_res = self.lora_B(self.lora_A(x) * self.scale)
if self.dropout_position == "post":
lora_res = F.dropout(lora_res, p=self.dropout_p, training=self.training)
if self.use_memory_efficient_lora:
return lora_res.add_(res)

return res + lora_res

Expand All @@ -373,6 +416,7 @@ def patch_linear_module(
dropout_position="post",
lora_A_init_method="xavier",
lora_dtype=None,
use_memory_efficient_lora=True,
use_triton=True,
layer_name=None,
):
Expand All @@ -396,8 +440,10 @@ def patch_linear_module(
Defaults to 'post' (choices: 'pre', 'post').
lora_A_init_method (str, optional): lora_a init method. Defaults to 'xavier'.
lora_dtype (_type_, optional): Lora weights' dtype. By default will use orig_linear's dtype
but orig_linear might use non-trainable dtype (e.g., 4bit), in which case the user must
specify the dtype manually. Defaults to None.
but orig_linear might use non-trainable dtype (e.g., 4bit), in which case the user must
specify the dtype manually. Defaults to None.
use_memory_efficient_lora (bool, optional): Use the custom autograd implementation for standard LoRA.
When Triton is enabled this uses Triton kernels; otherwise it uses PyTorch matmuls. Defaults to True.
use_triton (bool, optional): By default we use the triton kernel LoRA implementation.

Returns:
Expand Down Expand Up @@ -428,6 +474,7 @@ def patch_linear_module(
dropout_position=dropout_position,
lora_A_init_method=lora_A_init_method,
lora_dtype=lora_dtype,
use_memory_efficient_lora=use_memory_efficient_lora,
)
cls = orig_linear.__class__
new_cls = type("PatchedLinearLoRA", (linear_lora_cls, cls), {})
Expand Down Expand Up @@ -605,6 +652,7 @@ def apply_lora_to_linear_modules(
dropout_position=peft_config.dropout_position,
lora_A_init_method=peft_config.lora_A_init,
lora_dtype=lora_dtype,
use_memory_efficient_lora=getattr(peft_config, "use_memory_efficient_lora", True),
use_triton=peft_config.use_triton,
layer_name=name,
)
Expand All @@ -614,57 +662,96 @@ def apply_lora_to_linear_modules(

class LoRATritonFunction(torch.autograd.Function):
"""
Autograd function that calls the triton kernel wrappers for the LoRA forward and backward passes.
Autograd function that avoids saving the LoRA A activation.

The default path calls Triton kernel wrappers for forward and backward. Callers can pass
``use_triton_kernel=False`` to use PyTorch matmuls while keeping the same memory-efficient
saved tensor behavior.
"""

@staticmethod
def setup_context(ctx, inputs, output):
"""
Stores context for LoRA backward pass.
"""
x, lora_A, lora_B, scale, _ = inputs
x, lora_A, lora_B, scale, dtype, *rest = inputs
ctx.save_for_backward(x, lora_A, lora_B)
ctx.scale = scale
ctx.dtype = dtype
ctx.use_triton_kernel = bool(rest[0]) if rest else True
ctx.has_residual = len(rest) > 1 and rest[1] is not None
ctx.num_inputs = len(inputs)

@staticmethod
def forward(x, lora_A, lora_B, scale, dtype):
def forward(x, lora_A, lora_B, scale, dtype, use_triton_kernel=True, res=None):
"""
Forward method for LoRATriton.
Forward method for memory-efficient LoRA.

Reshapes 3D tensors into 2D and then calls the triton kernel.
Reshapes 3D tensors into 2D and then calls either Triton kernels or PyTorch matmuls. When ``res`` is
provided, the residual is added in-place into the LoRA output to avoid allocating a separate add result.
"""
reshape = x.dim() == 3
if reshape:
bs, seq_len, d = x.shape
x = x.reshape(-1, d)
if res is not None:
res = res.reshape(-1, res.shape[-1])

lora_res = lora_forward_wrapper(x, lora_A.t(), lora_B.t(), res=None, scale=scale, dtype=dtype)
if use_triton_kernel:
lora_res = lora_forward_wrapper(x, lora_A.t(), lora_B.t(), res=None, scale=scale, dtype=dtype)
else:
lora_res = F.linear(F.linear(x, lora_A) * scale, lora_B)

if res is not None:
lora_res.add_(res)

if reshape:
return lora_res.view(bs, seq_len, -1)
else:
return lora_res
return lora_res

@staticmethod
def backward(ctx, d_y):
"""
Backward method for LoRATriton.
Backward method for memory-efficient LoRA.

Reshapes 3D tensors into 2D and then calls the kernels to update d_lora_a, d_lora_b, and dx.
Reshapes 3D tensors into 2D and then updates d_lora_a, d_lora_b, and dx. The PyTorch matmul
path recomputes ``x @ lora_A.T`` here instead of saving it from forward.
"""
x, lora_A, lora_B = ctx.saved_tensors
scale = ctx.scale
dtype = x.dtype
d_res = d_y if ctx.has_residual and ctx.needs_input_grad[6] else None

reshape = x.dim() == 3
if reshape:
bs, seq_len, d = x.shape
d_y = d_y.reshape(-1, d_y.shape[-1])
x = x.reshape(-1, d)

d_lora_A, d_x = lora_da_dx_update_wrapper(x.t(), d_y, lora_B, lora_A, scale, dtype=dtype)
d_lora_B = lora_db_update_wrapper(lora_A, x.t(), d_y, scale, dtype)

if reshape:
if ctx.use_triton_kernel:
d_lora_A, d_x = lora_da_dx_update_wrapper(x.t(), d_y, lora_B, lora_A, scale, dtype=ctx.dtype)
d_lora_B = lora_db_update_wrapper(lora_A, x.t(), d_y, scale, ctx.dtype)
d_lora_A = d_lora_A.t()
else:
d_x = d_lora_A = d_lora_B = None
needs_x, needs_lora_A, needs_lora_B = ctx.needs_input_grad[:3]
if needs_x or needs_lora_A:
d_y_lora_B = torch.matmul(d_y, lora_B)
if needs_x:
d_x = torch.empty_like(x)
d_x.addmm_(d_y_lora_B, lora_A, beta=0, alpha=scale)
if needs_lora_A:
d_lora_A = torch.matmul(d_y_lora_B.t(), x) * scale

if needs_lora_B:
d_lora_B = torch.empty_like(lora_B)
d_lora_B.addmm_(d_y.t(), F.linear(x, lora_A), beta=0, alpha=scale)

if reshape and d_x is not None:
d_x = d_x.view(bs, seq_len, d)
return d_x, d_lora_A.t(), d_lora_B, None, None

gradients = (d_x, d_lora_A, d_lora_B, None, None)
if ctx.num_inputs == 7:
return gradients + (None, d_res)
if ctx.num_inputs == 6:
return gradients + (None,)
return gradients
1 change: 1 addition & 0 deletions tests/functional_tests/checkpoint/test_peft.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def test_hf_peft_checkpoint(force_hf, use_triton):
"moe_rank_scaling": False,
"target_modules": [],
"use_dora": False,
"use_memory_efficient_lora": True,
"use_triton": False,
}

Expand Down
19 changes: 10 additions & 9 deletions tests/functional_tests/checkpoint/test_peft_vlm.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,7 @@ def test_hf_peft_checkpoint():
"moe_rank_scaling": False,
"target_modules": [],
"use_dora": False,
"use_memory_efficient_lora": True,
"use_triton": True,
}

Expand Down Expand Up @@ -772,6 +773,7 @@ def test_hf_peft_checkpoint():
source_model = trainer.model_parts[0]

from nemo_automodel.components.checkpoint.checkpointing import _load_hf_checkpoint_preserving_dtype

hf_model_path = cfg.get("model.pretrained_model_name_or_path")
hf_state_dict = _load_hf_checkpoint_preserving_dtype(hf_model_path) or {}
print(f"HF checkpoint loaded: {len(hf_state_dict)} keys from {hf_model_path}", flush=True)
Expand All @@ -782,9 +784,7 @@ def test_hf_peft_checkpoint():
print(f"Model param keys (first 10, no lora): {model_keys_sorted[:10]}", flush=True)
param_mismatches = []
buffer_mismatches = []
for (sn, sp), (rn, rp) in zip(
source_model.named_parameters(), restored_model.named_parameters()
):
for (sn, sp), (rn, rp) in zip(source_model.named_parameters(), restored_model.named_parameters()):
assert sn == rn, f"Parameter name mismatch: {sn} vs {rn}"
sp_full = sp.full_tensor() if hasattr(sp, "full_tensor") else sp
rp_full = rp.full_tensor() if hasattr(rp, "full_tensor") else rp
Expand Down Expand Up @@ -818,9 +818,7 @@ def test_hf_peft_checkpoint():
f"src_norm={sp_full.float().norm().item():.4f} rst_norm={rp_full.float().norm().item():.4f} "
f"| {src_vs_hf} | {rst_vs_hf}"
)
for (sn, sb), (rn, rb) in zip(
source_model.named_buffers(), restored_model.named_buffers()
):
for (sn, sb), (rn, rb) in zip(source_model.named_buffers(), restored_model.named_buffers()):
assert sn == rn, f"Buffer name mismatch: {sn} vs {rn}"
if sb.is_meta or rb.is_meta:
buffer_mismatches.append(f" BUFFER {sn}: src_meta={sb.is_meta} rst_meta={rb.is_meta}")
Expand All @@ -836,13 +834,16 @@ def test_hf_peft_checkpoint():
f"max_diff={diff.max().item():.6e} mean_diff={diff.mean().item():.6e}"
)
if param_mismatches or buffer_mismatches:
print(f"\n{'='*80}", flush=True)
print(f"WEIGHT COMPARISON: {len(param_mismatches)} param mismatches, {len(buffer_mismatches)} buffer mismatches", flush=True)
print(f"\n{'=' * 80}", flush=True)
print(
f"WEIGHT COMPARISON: {len(param_mismatches)} param mismatches, {len(buffer_mismatches)} buffer mismatches",
flush=True,
)
for m in param_mismatches:
print(m, flush=True)
for m in buffer_mismatches:
print(m, flush=True)
print(f"{'='*80}\n", flush=True)
print(f"{'=' * 80}\n", flush=True)
else:
print("WEIGHT COMPARISON: All parameters and buffers match exactly.", flush=True)

Expand Down
Loading
Loading