Changes from 12 commits
Binary file added .DS_Store
Binary file not shown.
11 changes: 11 additions & 0 deletions src/llamafactory/hparams/finetuning_args.py
@@ -88,6 +88,16 @@ class LoraArguments:
)
},
)
lora_parameters: Optional[str] = field(
default=None,
metadata={
"help": (
"Name(s) of nn.Parameters to apply LoRA directly. "
"Use commas to separate multiple parameters. "
"Useful for MoE models with expert parameters."
)
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
@@ -524,6 +534,7 @@ def split_arg(arg):
self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
self.lora_target: list[str] = split_arg(self.lora_target)
self.lora_parameters: Optional[list[str]] = split_arg(self.lora_parameters)
self.oft_target: list[str] = split_arg(self.oft_target)
self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
self.galore_target: list[str] = split_arg(self.galore_target)
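The new field is a plain comma-separated string that split_arg later turns into a list (or leaves as None). As a quick illustration of the assumed parsing behavior, here is a standalone stand-in for split_arg (its actual body is not part of this diff):

    # Stand-in for split_arg, assuming None passes through unchanged and strings
    # are split on commas with surrounding whitespace stripped.
    from typing import Optional


    def split_arg(arg: Optional[str]) -> Optional[list[str]]:
        if isinstance(arg, str):
            return [item.strip() for item in arg.split(",")]
        return arg


    print(split_arg("experts.gate_up_proj, experts.down_proj"))
    # ['experts.gate_up_proj', 'experts.down_proj']
    print(split_arg(None))
    # None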
9 changes: 8 additions & 1 deletion src/llamafactory/model/adapter.py
@@ -198,8 +198,14 @@ def _setup_lora_tuning(
logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

if is_trainable and adapter_to_resume is None: # create new lora weights while training
target_modules = []
target_parameters = []
if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
if finetuning_args.lora_parameters:  # if parameters to adapt were specified, use them
Owner: When we specify the target parameters, the target modules should not be affected.

Owner: This if-else is strange.

Author: Sorry, I may not have understood the idea clearly. I noticed that lora_target defaults to "all", so I assumed that if the user keeps that default and relies on lora_parameters for injection without specifying a target, this if-else is needed to decide between the two.

The current implementation will prevent target_parameters (lora_parameters) from being passed to the PEFT config when target_modules (lora_target) is specified. There is no need to modify this piece of code to pass lora_parameters; instead, I would pass it immediately to peft_kwargs:

            peft_kwargs = {
                "r": finetuning_args.lora_rank,
                "target_modules": target_modules,
                "target_parameters": finetuning_args.lora_parameters,
                "lora_alpha": finetuning_args.lora_alpha,
                "lora_dropout": finetuning_args.lora_dropout,
                "use_rslora": finetuning_args.use_rslora,
                "use_dora": finetuning_args.use_dora,
                "modules_to_save": finetuning_args.additional_target,
            }

target_parameters is an optional argument with None as default.
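For context on the PEFT feature this suggestion relies on: recent peft releases accept a target_parameters list in LoraConfig, which attaches LoRA to bare nn.Parameter objects (e.g., stacked MoE expert weights) rather than whole modules. A minimal, hedged sketch of that usage follows; the model path and parameter name are placeholders, and the minimum peft version is not pinned here:

    # Hedged sketch: LoRA on bare nn.Parameters via LoraConfig(target_parameters=...).
    # "path/to/moe-model" and "experts.gate_up_proj" are placeholders, and this
    # assumes a peft release that supports target_parameters.
    from peft import LoraConfig, get_peft_model
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained("path/to/moe-model")
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        target_parameters=["experts.gate_up_proj"],  # bare nn.Parameter target(s)
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()  # only the injected LoRA weights are trainable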

logger.info_rank0("Using specified LoRA parameters: {}".format(finetuning_args.lora_parameters))
target_parameters = finetuning_args.lora_parameters
else:
target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
else:
target_modules = finetuning_args.lora_target

@@ -235,6 +241,7 @@ def _setup_lora_tuning(
"use_rslora": finetuning_args.use_rslora,
"use_dora": finetuning_args.use_dora,
"modules_to_save": finetuning_args.additional_target,
"target_parameters": target_parameters,
Owner: Is target_parameters always defined?

}
elif finetuning_args.finetuning_type == "oft":
peft_kwargs = {
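Since the review widgets break up the hunk above, here is a small, self-contained restatement of how the new branch appears to choose between module and parameter targets. select_lora_targets and its arguments are illustrative stand-ins rather than code from this PR, and the second call reproduces the reviewer's concern that lora_parameters is dropped whenever lora_target is set explicitly:

    # Illustrative stand-in for the target-selection branch added in adapter.py.
    from typing import Optional


    def select_lora_targets(
        lora_target: list[str],
        lora_parameters: Optional[list[str]],
        all_linear_modules: list[str],
    ) -> tuple[list[str], list[str]]:
        # Returns (target_modules, target_parameters) the way the new branch appears to.
        target_modules: list[str] = []
        target_parameters: list[str] = []
        if len(lora_target) == 1 and lora_target[0] == "all":
            if lora_parameters:  # parameters explicitly requested: skip module discovery
                target_parameters = lora_parameters
            else:
                target_modules = all_linear_modules  # stands in for find_all_linear_modules()
        else:
            target_modules = lora_target  # explicit module list wins; parameters are ignored
        return target_modules, target_parameters


    print(select_lora_targets(["all"], ["experts.gate_up_proj"], ["q_proj", "v_proj"]))
    # ([], ['experts.gate_up_proj'])
    print(select_lora_targets(["q_proj", "v_proj"], ["experts.gate_up_proj"], ["q_proj", "v_proj"]))
    # (['q_proj', 'v_proj'], [])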
11 changes: 8 additions & 3 deletions src/llamafactory/train/test_utils.py
@@ -44,10 +44,16 @@ def compare_model(model_a: "torch.nn.Module", model_b: "torch.nn.Module", diff_k


def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str], set[str]]:
linear_modules, extra_modules = set(), set()
linear_modules, linear_parameters, extra_modules = set(), set(), set()
for name, param in model.named_parameters():
if any(module in name for module in ["lora_A", "lora_B"]):
linear_modules.add(name.split(".lora_", maxsplit=1)[0].split(".")[-1])
parts = name.split(".")
for i, part in enumerate(parts):
if "lora_" in part:
short_name = parts[i - 1] + "." + parts[-1]
linear_parameters.add(short_name)
break
assert param.requires_grad is True
assert param.dtype == torch.float32
elif "modules_to_save" in name:
@@ -58,8 +64,7 @@ def check_lora_model(model: "LoraModel") -> tuple[set[str], set[str]]:
assert param.requires_grad is False
assert param.dtype == torch.float16

return linear_modules, extra_modules

return linear_modules, linear_parameters, extra_modules

def load_train_model(add_valuehead: bool = False, **kwargs) -> "PreTrainedModel":
model_args, _, _, finetuning_args, _ = get_train_args(kwargs)
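As a quick sanity check of the short-name extraction added above, the same parsing logic can be run on a hand-written parameter name of the kind model.named_parameters() yields for a module-targeted LoRA layer (the exact naming for parameter-targeted injection may differ slightly):

    # Hand-written example name; real names come from model.named_parameters().
    name = "base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight"

    parts = name.split(".")
    for i, part in enumerate(parts):
        if "lora_" in part:
            # module just before the lora_* component, plus the leaf name
            short_name = parts[i - 1] + "." + parts[-1]
            break

    print(short_name)  # q_proj.weight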
3 changes: 3 additions & 0 deletions src/llamafactory/webui/components/train.py
@@ -179,6 +179,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
use_pissa = gr.Checkbox()
lora_target = gr.Textbox(scale=2)
additional_target = gr.Textbox(scale=2)
lora_parameters = gr.Textbox(scale=2)

input_elems.update(
{
@@ -192,6 +193,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
use_pissa,
lora_target,
additional_target,
lora_parameters,
}
)
elem_dict.update(
@@ -207,6 +209,7 @@ def create_train_tab(engine: "Engine") -> dict[str, "Component"]:
use_pissa=use_pissa,
lora_target=lora_target,
additional_target=additional_target,
lora_parameters=lora_parameters,
)
)

22 changes: 22 additions & 0 deletions src/llamafactory/webui/locales.py
@@ -1323,6 +1323,28 @@
"info": "LoRA 層以外の学習可能なモジュールの名前。複数のモジュールを区切るにはカンマを使用します。",
},
},
"lora_parameters": {
"en": {
"label": "LoRA parameters (optional)",
"info": "Name(s) of parameters to apply LoRA. Use commas to separate multiple parameters.",
},
"ru": {
"label": "Параметры LoRA (необязательно)",
"info": "Имя(ена) параметров для применения LoRA. Используйте запятые для разделения нескольких параметров.",
},
"zh": {
"label": "LoRA 参数(可选)",
"info": "要应用 LoRA 的参数名称。使用逗号分隔多个参数。",
},
"ko": {
"label": "LoRA 매개변수 (선택 사항)",
"info": "LoRA를 적용할 매개변수의 이름입니다. 여러 매개변수를 구분하려면 쉼표를 사용하십시오.",
},
"ja": {
"label": "LoRA パラメータ (オプション)",
"info": "LoRA を適用するパラメータの名前。複数のパラメータを区切るにはカンマを使用します。",
},
},
"rlhf_tab": {
"en": {
"label": "RLHF configurations",
1 change: 1 addition & 0 deletions src/llamafactory/webui/runner.py
@@ -212,6 +212,7 @@ def _parse_train_args(self, data: dict["Component", Any]) -> dict[str, Any]:
args["pissa_convert"] = get("train.use_pissa")
args["lora_target"] = get("train.lora_target") or "all"
args["additional_target"] = get("train.additional_target") or None
args["lora_parameters"] = get("train.lora_parameters") or None

if args["use_llama_pro"]:
args["freeze_trainable_layers"] = get("train.freeze_trainable_layers")
17 changes: 14 additions & 3 deletions tests/model/test_lora.py
@@ -63,19 +63,19 @@ def fix_valuehead_cpu_loading():

def test_lora_train_qv_modules():
model = load_train_model(lora_target="q_proj,v_proj", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
linear_modules, _, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "v_proj"}


def test_lora_train_all_modules():
model = load_train_model(lora_target="all", **TRAIN_ARGS)
linear_modules, _ = check_lora_model(model)
linear_modules, _, _ = check_lora_model(model)
assert linear_modules == {"q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"}


def test_lora_train_extra_modules():
model = load_train_model(additional_target="embed_tokens,lm_head", **TRAIN_ARGS)
_, extra_modules = check_lora_model(model)
_, _, extra_modules = check_lora_model(model)
assert extra_modules == {"embed_tokens", "lm_head"}


@@ -91,6 +91,17 @@ def test_lora_train_new_adapters():
compare_model(
model, ref_model, diff_keys=["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "gate_proj", "down_proj"]
)

def test_lora_parameters():
model = load_train_model(lora_parameters="q_proj.weight, k_proj.weight", **TRAIN_ARGS)
_, injected_parameters, _ = check_lora_model(model)
assert injected_parameters == {"q_proj.weight", "k_proj.weight"}

def test_lora_target_and_parameters_conflicts():
model = load_train_model(lora_parameters="q_proj.weight", lora_target="q_proj,v_proj", **TRAIN_ARGS)
linear_modules, injected_parameters, _ = check_lora_model(model)
assert injected_parameters == {"q_proj.weight", "v_proj.weight"}
assert linear_modules == {"q_proj", "v_proj"}


@pytest.mark.usefixtures("fix_valuehead_cpu_loading")