-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathfed_sweep_part.py
More file actions
119 lines (105 loc) · 3.91 KB
/
fed_sweep_part.py
File metadata and controls
119 lines (105 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
"""Grid Search for FedETuning"""
import os
import sys
import itertools as it
from loguru import logger
from multiprocessing import Pool
from configs.tuning import hyperparameter_grid
def run_process(proc):
    """Run a shell command string and block until it finishes.

    Worker callable for ``multiprocessing.Pool.map`` below, so it takes a
    single positional argument: the full shell command line to execute.
    """
    os.system(proc)
# ---- command-line arguments -------------------------------------------------
# argv: <machine_env> <task_name> <fl_algorithm> <config_path> <tuning_type>
#       <port_start> <device>
machine_env = sys.argv[1]
task_name = sys.argv[2]
fl_algorithm = sys.argv[3]
config_path = sys.argv[4]
tuning_type = sys.argv[5]
port_start = int(sys.argv[6])
device = sys.argv[7]

# Known machine environments map to fixed workspace roots; any other value
# is treated as a literal run directory path.
_RUN_DIRS_BY_ENV = {
    "ali-dsw": "/mnt/workspace",
    "ali-dlc": "/root/data",
    "hit": "/data/xiangjing",
}
run_dirs = _RUN_DIRS_BY_ENV.get(machine_env, machine_env)
# ---- GPU / federated topology -----------------------------------------------
# "device" is a comma-separated CUDA device list, e.g. "0,1,2,3".
device_idx_list = device.split(",")
n_gpu = len(device_idx_list)
# One federated participant slot per provided GPU.
world_size = n_gpu
logger.info(f"world_size is {world_size}")

# ---- task-specific model configuration ---------------------------------------
max_seq = 512
dataset_name = "legal"
metric_name = "legal"
model_name = 'roberta-wwm-ext'
data_file = 'legal/silo'
if task_name == "LCP":
    model_output_mode = "seq_classification"
elif task_name == 'LJP':
    model_output_mode = 'seq_regression'
elif task_name == 'LER':
    model_output_mode = "token_classification_crf"
elif task_name == 'LRE':
    model_output_mode = "seq_classification"
elif task_name == 'LAM':
    model_output_mode = "multi_seq_classification"
elif task_name == 'LDG':
    model_output_mode = 'seq_generation'
    # The generation task uses a GPT-2 backbone and a longer context window.
    model_name = 'gpt2-chinese-cluecorpussmall'
    max_seq = 1024
else:
    # Previously this branch only logged and fell through, leaving
    # model_output_mode undefined and crashing later with a NameError.
    # Fail fast with a clear message instead.
    logger.info(f"not support {task_name}")
    sys.exit(f"unsupported task_name: {task_name}")
logger.info(f"{task_name}'s max_seq is {max_seq}")
# ---- build one launch command per hyper-parameter combination -----------------
cmds = []
hyper_parameter = hyperparameter_grid[tuning_type]
param_names = list(hyper_parameter.keys())
gpu_index = 0  # global GPU-slot cursor; each task consumes world_size//2 slots
for parameter in it.product(*hyper_parameter.values()):
    # it.product yields values in key order, so zip pairs each value with its
    # hyper-parameter name directly (the original rebuilt the position with
    # list.index per key: O(n^2) and fragile for duplicate values).
    specific_parameter_dict = dict(zip(param_names, parameter))
    if "lora_rank" in specific_parameter_dict:
        # Keep LoRA alpha tied to the rank during the sweep.
        specific_parameter_dict["lora_alpha"] = specific_parameter_dict["lora_rank"]

    port = port_start + gpu_index  # unique rendezvous port per task
    options = [
        "--model_name_or_path", f"{run_dirs}/pretrain/nlp/{model_name}/",
        "--output_dir", f"{run_dirs}/output/{data_file}",
        "--task_name", f"{task_name}",
        "--fl_algorithm", f"{fl_algorithm}",
        "--raw_dataset_path", f"{run_dirs}/datasets/{data_file}",
        "--partition_dataset_path", f"{run_dirs}/datasets/{data_file}",
        "--max_seq_length", f"{max_seq}",
        "--world_size", f"{world_size}",
        "--port", f"{port}",
        "--dataset_name", dataset_name,
        "--metric_name", metric_name,
        "--model_output_mode", model_output_mode,
        "--tuning_type", f"{tuning_type}_{model_name}",
        "--raw_tuning_type", tuning_type,
        "--config_path", config_path,
        "--do_grid", "True",
    ]
    for key, value in specific_parameter_dict.items():
        options.extend(["--" + key, str(value)])

    # Rank 0 is the federated server; it runs on the current GPU slot.
    server_cmd = (
        f'CUDA_VISIBLE_DEVICES={device_idx_list[gpu_index % n_gpu]} python main.py '
        + " ".join(options + ["--rank", "0"])
    )
    print(f"server_cmd cmd: {server_cmd}")
    one_cmd_list = [server_cmd]

    # Clients take ranks 1..world_size//2-1, each on the next GPU slot.
    # The prefix is rebuilt per client instead of the original cumulative
    # str.replace of the CUDA_VISIBLE_DEVICES assignment, which depended on
    # exact substring matching of the previous device string.
    for rank in range(1, world_size // 2):
        gpu_index += 1
        client_cmd = (
            f'CUDA_VISIBLE_DEVICES={device_idx_list[gpu_index % n_gpu]} python main.py '
            + " ".join(options + ["--rank", str(rank)])
        )
        one_cmd_list.append(client_cmd)
    gpu_index += 1

    # Launch the server and all clients concurrently, then wait for all.
    one_cmd = " & ".join(one_cmd_list) + " & wait"
    cmds.append(one_cmd)
    print(one_cmd)

logger.warning(f"run {len(cmds)} grid-search tasks for {model_name}_{task_name}_{tuning_type}")
# Each task occupies half of the provided GPUs, so run n_gpu//2 tasks at once.
# The context manager guarantees the pool is cleaned up (the original never
# closed or joined it).
with Pool(processes=n_gpu // 2) as pool:  # each task uses half the GPUs
    pool.map(run_process, cmds)