@@ -1,5 +1,4 @@
-import json
-from dataclasses import asdict, dataclass, field
+from dataclasses import dataclass, field
 from typing import Literal, Optional


@@ -9,22 +8,40 @@ class FreezeArguments:
     Arguments pertaining to the freeze (partial-parameter) training.
     """

-    name_module_trainable: str = field(
+    freeze_trainable_layers: int = field(
+        default=2,
+        metadata={
+            "help": (
+                "The number of trainable layers for freeze (partial-parameter) fine-tuning. "
+                "Positive numbers mean the last n layers are set as trainable, "
+                "negative numbers mean the first n layers are set as trainable."
+            )
+        },
+    )
+    freeze_trainable_modules: str = field(
         default="all",
         metadata={
-            "help": """Name of trainable modules for partial-parameter (freeze) fine-tuning. \
-                    Use commas to separate multiple modules. \
-                    Use "all" to specify all the available modules. \
-                    LLaMA choices: ["mlp", "self_attn"], \
-                    BLOOM & Falcon & ChatGLM choices: ["mlp", "self_attention"], \
-                    Qwen choices: ["mlp", "attn"], \
-                    InternLM2 choices: ["feed_forward", "attention"], \
-                    Others choices: the same as LLaMA."""
+            "help": (
+                "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the available modules. "
+                "LLaMA choices: [`mlp`, `self_attn`], "
+                "BLOOM & Falcon & ChatGLM choices: [`mlp`, `self_attention`], "
+                "Qwen choices: [`mlp`, `attn`], "
+                "InternLM2 choices: [`feed_forward`, `attention`], "
+                "Others choices: the same as LLaMA."
+            )
         },
     )
-    num_layer_trainable: int = field(
-        default=2,
-        metadata={"help": "The number of trainable layers for partial-parameter (freeze) fine-tuning."},
+    freeze_extra_modules: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": (
+                "Name(s) of modules apart from hidden layers to be set as trainable "
+                "for freeze (partial-parameter) fine-tuning. "
+                "Use commas to separate multiple modules."
+            )
+        },
     )


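The three renamed freeze fields are plain dataclass fields, so a quick sketch shows how they are meant to be filled in. This is illustrative only: `FreezeArgsSketch` and the module names in the comments are assumptions for the example, not part of this commit.

```python
from dataclasses import dataclass
from typing import Optional

# Illustrative stand-in for the renamed freeze arguments above;
# the real fields live on FreezeArguments in this file.
@dataclass
class FreezeArgsSketch:
    freeze_trainable_layers: int = 2            # >0: last n layers trainable, <0: first n layers
    freeze_trainable_modules: str = "all"       # comma-separated module names, or "all"
    freeze_extra_modules: Optional[str] = None  # e.g. "embed_tokens,lm_head" (hypothetical names)

args = FreezeArgsSketch(freeze_trainable_layers=-4, freeze_trainable_modules="mlp,self_attn")
print(args.freeze_trainable_modules.split(","))  # ['mlp', 'self_attn']
```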
@@ -37,7 +54,11 @@ class LoraArguments:
     additional_target: Optional[str] = field(
         default=None,
         metadata={
-            "help": "Name(s) of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint."
+            "help": (
+                "Name(s) of modules apart from LoRA layers to be set as trainable "
+                "and saved in the final checkpoint. "
+                "Use commas to separate multiple modules."
+            )
         },
     )
     lora_alpha: Optional[int] = field(
@@ -55,15 +76,17 @@ class LoraArguments:
     lora_target: str = field(
         default="all",
         metadata={
-            "help": """Name(s) of target modules to apply LoRA. \
-                    Use commas to separate multiple modules. \
-                    Use "all" to specify all the linear modules. \
-                    LLaMA choices: ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], \
-                    BLOOM & Falcon & ChatGLM choices: ["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], \
-                    Baichuan choices: ["W_pack", "o_proj", "gate_proj", "up_proj", "down_proj"], \
-                    Qwen choices: ["c_attn", "attn.c_proj", "w1", "w2", "mlp.c_proj"], \
-                    InternLM2 choices: ["wqkv", "wo", "w1", "w2", "w3"], \
-                    Others choices: the same as LLaMA."""
+            "help": (
+                "Name(s) of target modules to apply LoRA. "
+                "Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules. "
+                "LLaMA choices: [`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
+                "BLOOM & Falcon & ChatGLM choices: [`query_key_value`, `dense`, `dense_h_to_4h`, `dense_4h_to_h`], "
+                "Baichuan choices: [`W_pack`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
+                "Qwen choices: [`c_attn`, `attn.c_proj`, `w1`, `w2`, `mlp.c_proj`], "
+                "InternLM2 choices: [`wqkv`, `wo`, `w1`, `w2`, `w3`], "
+                "Others choices: the same as LLaMA."
+            )
         },
     )
     loraplus_lr_ratio: Optional[float] = field(
@@ -177,8 +200,10 @@ class GaloreArguments:
     galore_target: str = field(
         default="all",
         metadata={
-            "help": """Name(s) of modules to apply GaLore. Use commas to separate multiple modules. \
-                    Use "all" to specify all the linear modules."""
+            "help": (
+                "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
+                "Use `all` to specify all the linear modules."
+            )
         },
     )
     galore_rank: int = field(
@@ -238,16 +263,20 @@ class BAdamArgument:
     badam_mask_mode: Literal["adjacent", "scatter"] = field(
         default="adjacent",
         metadata={
-            "help": """The mode of the mask for BAdam optimizer. \
-                    `adjacent` means that the trainable parameters are adjacent to each other, \
-                    `scatter` means that trainable parameters are randomly choosed from the weight."""
+            "help": (
+                "The mode of the mask for BAdam optimizer. "
+                "`adjacent` means that the trainable parameters are adjacent to each other, "
+                "`scatter` means that trainable parameters are randomly chosen from the weight."
+            )
         },
     )
     badam_verbose: int = field(
         default=0,
         metadata={
-            "help": """The verbosity level of BAdam optimizer. \
-                    0 for no print, 1 for print the block prefix, 2 for print trainable parameters"""
+            "help": (
+                "The verbosity level of BAdam optimizer. "
+                "0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
+            )
         },
     )

@@ -285,7 +314,8 @@ def split_arg(arg):
                 return [item.strip() for item in arg.split(",")]
             return arg

-        self.name_module_trainable = split_arg(self.name_module_trainable)
+        self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
+        self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
         self.lora_alpha = self.lora_alpha or self.lora_rank * 2
         self.lora_target = split_arg(self.lora_target)
         self.additional_target = split_arg(self.additional_target)
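For readers skimming the hunk above, `split_arg` is the `__post_init__` helper that turns these comma-separated strings into lists; the renamed `freeze_trainable_modules` and the new `freeze_extra_modules` now go through it as well. A minimal standalone copy, shown only to illustrate its behavior (the type hints are added here for clarity):

```python
from typing import List, Optional

# Standalone copy of the split_arg helper used in __post_init__ above.
def split_arg(arg: Optional[str]) -> Optional[List[str]]:
    if isinstance(arg, str):
        return [item.strip() for item in arg.split(",")]
    return arg  # None (e.g. freeze_extra_modules left unset) passes through unchanged

print(split_arg("mlp, self_attn"))  # ['mlp', 'self_attn']
print(split_arg("all"))             # ['all']
print(split_arg(None))              # None
```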
@@ -315,17 +345,3 @@ def split_arg(arg):

         if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
             raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
-
-    def save_to_json(self, json_path: str):
-        r"""Saves the content of this instance in JSON format inside `json_path`."""
-        json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
-        with open(json_path, "w", encoding="utf-8") as f:
-            f.write(json_string)
-
-    @classmethod
-    def load_from_json(cls, json_path: str):
-        r"""Creates an instance from the content of `json_path`."""
-        with open(json_path, "r", encoding="utf-8") as f:
-            text = f.read()
-
-        return cls(**json.loads(text))
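Since this commit removes `save_to_json` / `load_from_json` (together with the `json` and `asdict` imports at the top), callers that still need a JSON round-trip can reproduce it with the standard library alone. A minimal sketch, not part of the commit; the `Args` dataclass and the file name are placeholders:

```python
import json
from dataclasses import asdict, dataclass

@dataclass
class Args:  # placeholder standing in for the real arguments dataclass
    lora_rank: int = 8
    lora_target: str = "all"

def save_to_json(obj, json_path: str) -> None:
    # Mirrors the removed method: sorted keys, 2-space indent, trailing newline.
    with open(json_path, "w", encoding="utf-8") as f:
        f.write(json.dumps(asdict(obj), indent=2, sort_keys=True) + "\n")

def load_from_json(cls, json_path: str):
    # Mirrors the removed classmethod: rebuild the dataclass from the JSON dict.
    with open(json_path, "r", encoding="utf-8") as f:
        return cls(**json.loads(f.read()))

save_to_json(Args(lora_rank=16), "args.json")
print(load_from_json(Args, "args.json"))  # Args(lora_rank=16, lora_target='all')
```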