-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathparameters.py
More file actions
120 lines (108 loc) · 7.03 KB
/
parameters.py
File metadata and controls
120 lines (108 loc) · 7.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import argparse
import json
def args_parser():
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=None, help="Random seed")
# Sequence arguments
parser.add_argument('--snr', type=float, default=None, help="Transmission SNR")
parser.add_argument('--snr_all', type=bool, default=None, help="Transmission SNR")
parser.add_argument('--snr_high', type=float, default=None, help="Transmission SNR")
parser.add_argument('--snr_low', type=float, default=None, help="Transmission SNR")
parser.add_argument('--K_in', type=int, default=None, help="Length of the initial information bits before outer code")
parser.add_argument('--K', type=int, default=None, help="Length of input sequence to the inner code")
parser.add_argument('--M', type=int, default=None, help="Block size")
parser.add_argument('--block_rate', type=int, default=None, help="Number of parity bits produced for each block")
parser.add_argument('--core', type=int, default=None)
parser.add_argument('--NS_model', type=int, default=None) # use 1 for relu, use 2 for gelu
parser.add_argument('--num_iters', type=int, default=None)
parser.add_argument('--num_iters_test', type=int, default=None)
parser.add_argument('--constraint', type=str, default=None, help="Type of power normalization")
# Transformer arguments
parser.add_argument('--heads_trx', type=int, default=None, help="number of heads for the multi-head attention")
parser.add_argument('--heads_trx_rf', type=int, default=None, help="number of heads for the multi-head attention")
parser.add_argument('--d_k_trx', type=int, default=None, help="number of features for each head")
parser.add_argument('--N_trx', type=int, default=None, help="number of layers in the encoder")
parser.add_argument('--N_rec', type=int, default=None, help="number of layers in the decoder")
parser.add_argument('--dropout', type=float, default=None, help="prob of dropout")
parser.add_argument('--custom_attn', type=bool, default=None, help="use custom attention")
parser.add_argument('--model_average', type=bool, default=None, help="use custom attention")
parser.add_argument('--vv', type=int, default=None)
parser.add_argument('--pe_type', type=str, default=None)
parser.add_argument('--save_first', type=int, default=None, help="number of batches to save the statistics")
# Learning arguments
parser.add_argument('--train', type=int, default=None)
parser.add_argument('--total_batches', type=int, default=None, help="number of total batches to train")
parser.add_argument('--load_batches', type=int, default=None, help="number of batches to load")
parser.add_argument('--batch_size', type=int, default=None, help="batch size")
parser.add_argument('--opt_method', type=str, default='adamW', help="Optimization method adamW,lamb,adam")
parser.add_argument('--clip_th', type=float, default=0.5, help="clipping threshold")
parser.add_argument('--use_lr_schedule', type=str, default='linear', help="lr scheduling")
parser.add_argument('--lr', type=float, default=None, help="learning rate")
parser.add_argument('--wd', type=float, default=0.01, help="weight decay")
parser.add_argument('--device', type=str, default=None, help="GPU")
# Outer code type
parser.add_argument('--padd_symbol_len', type=int, default=None,
help="The number of learnable symbols generated"
"to obtain a wanted length of outer"
"coded sequence")
parser.add_argument('--outer_code_type', type=str, default=None)
parser.add_argument('--dec_type', type=str, default=None)
parser.add_argument('--loss_type', type=str, default=None)
parser.add_argument('--usual_koef', type=float, default=None)
parser.add_argument('--first_koef', type=float, default=None)
parser.add_argument('--reloc1', type=bool, default=None)
parser.add_argument("--llrs_norm_loss", type=bool, default=None)
parser.add_argument('--bp_iters', type=int, default=None)
parser.add_argument('--bp_iters_test', type=int, default=None)
parser.add_argument('--add_last', type=int, default=None)
parser.add_argument('--ldpc_type', type=str, default=None)
parser.add_argument('--multilabel', type=str, default=None)
parser.add_argument('--logging_write', type=bool, default=None)
parser.add_argument('--bp_iters_arr', type=list, default=None)
parser.add_argument('--bp_iters_arr_test', type=list, default=None)
parser.add_argument('--backprop_outer', type=bool, default=None)
parser.add_argument('--no_inner', type=bool, default=None)
parser.add_argument('--power_norm', type=bool, default=None)
parser.add_argument('--snr_noisy', type=bool, default=None)
parser.add_argument('--af_module', type=bool, default=None)
parser.add_argument('--random_bpiters', type=bool, default=None)
parser.add_argument('--random_bpiters_iter', type=int, default=None)
parser.add_argument('--weights', type=bool, default=None)
parser.add_argument('--high_bp', type=int, default=None)
parser.add_argument('--low_bp', type=int, default=None)
parser.add_argument('--w_n', type=int, default=None)
parser.add_argument('--loss_koef', type=float, default=None)
parser.add_argument('--no_Tmodel', type=bool, default=None)
parser.add_argument('--no_Rmodel', type=bool, default=None)
parser.add_argument('--test_batch_num', type=int, default=None)
parser.add_argument('--ebno_train', type=bool, default=None)
parser.add_argument('--not_best_model', type=bool, default=None)
parser.add_argument('--test_bestloss', type=bool, default=None)
parser.add_argument('--no_inner_test', type=bool, default=None)
parser.add_argument('--rayleigh', type=bool, default=None)
parser.add_argument('--K_in_polar', type=int, default=None)
parser.add_argument('--check_vanilla', type=bool, default=None)
parser.add_argument('--check_scl', type=bool, default=None)
parser.add_argument('--check_crc', type=bool, default=None)
parser.add_argument('--linear_layers', type=bool, default=None)
parser.add_argument('--pairwise_distance', type=bool, default=None)
parser.add_argument('--protograph', type=bool, default=None)
parser.add_argument('--label_smoothing', type=bool, default=None)
parser.add_argument('--method_bp', type=str, default=None)
args = parser.parse_args()
return args
def read_configs(args=None, code_type=None):
configs_path = f"./configs/config_{code_type}.json"
with open(configs_path) as json_file:
config_dict = json.load(json_file)
if args is not None:
for key, val in vars(args).items():
if val is not None:
if key in config_dict.keys():
config_dict[key] = val
if config_dict["bp_iters_arr"] is not None:
config_dict["bp_iters_arr"] = eval(config_dict["bp_iters_arr"])
if config_dict["bp_iters_arr_test"] is not None:
config_dict["bp_iters_arr_test"] = eval(config_dict["bp_iters_arr_test"])
return config_dict