-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCodes.py
More file actions
92 lines (86 loc) · 3.78 KB
/
Codes.py
File metadata and controls
92 lines (86 loc) · 3.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
"""
@author: Yoni Choukroun, choukroun.yoni@gmail.com
Error Correction Code Transformer
https://arxiv.org/abs/2203.14966
"""
import numpy as np
import torch
import os
def Read_pc_matrixrix_alist(fileName):
    """Read a binary parity-check matrix H from an alist-format text file.

    Layout consumed here:
      * line 0: ``columnNum rowNum`` (exactly two integers);
      * lines 1-3: skipped;
      * lines 4 .. 4+columnNum-1: for each column, the 1-based row indices of
        its non-zero entries. Values <= 0 (zero padding) are ignored.
    Any lines after the per-column lists are not read.

    Args:
        fileName: path of the alist file.

    Returns:
        Integer ndarray of shape (rowNum, columnNum) with 0/1 entries.

    Raises:
        ValueError: a header/entry token is not an integer, or line 0 does not
            hold exactly two values.
        IndexError: the file has fewer than ``4 + columnNum`` lines, or a row
            index exceeds ``rowNum``.
    """
    with open(fileName, 'r') as file:
        lines = file.readlines()
    # str.split + int raises on malformed tokens. np.fromstring(sep=' ')
    # would only warn and silently return a truncated row.
    columnNum, rowNum = (int(tok) for tok in lines[0].split())
    H = np.zeros((rowNum, columnNum), dtype=int)
    for column in range(columnNum):
        # Plain indexing (not slicing) so a short file still raises IndexError.
        for row in (int(tok) for tok in lines[4 + column].split()):
            if row > 0:  # 0 entries are padding, not row indices
                H[row - 1, column] = 1
    return H
#############################################
def row_reduce(mat, ncols=None):
    """Row-reduce a binary matrix over GF(2), pivoting only in the first ncols columns.

    Row addition is XOR (``^``), so ``mat`` must have a bool or integer dtype.
    For each column j < ncols, the first row at or below the current pivot
    position with a non-zero in column j is swapped up. That row is then
    XOR-ed into every other row with a non-zero in column j. Eliminations act
    on whole rows, including columns >= ncols. The input is not modified.

    Args:
        mat: 2-D array to reduce.
        ncols: number of leading columns eligible as pivot columns
            (default: all columns).

    Returns:
        (reduced, rank): the reduced copy, and the number of pivots found.
    """
    assert mat.ndim == 2
    if ncols is None:
        ncols = mat.shape[1]
    reduced = mat.copy()
    n_rows = reduced.shape[0]
    rank = 0
    for col in range(ncols):
        candidates = np.flatnonzero(reduced[rank:, col])
        if candidates.size == 0:
            continue  # no pivot available in this column
        pivot = rank + candidates[0]
        reduced[[rank, pivot], :] = reduced[[pivot, rank], :]
        # Every row except the pivot row that still has a 1 in this column.
        others = np.flatnonzero(reduced[:, col])
        others = others[others != rank]
        reduced[others, :] ^= reduced[rank, :]
        rank += 1
        if rank == n_rows:
            break  # every row already holds a pivot
    return reduced, rank
def get_generator(pc_matrix_):
    """Derive a binary generator matrix G from a parity-check matrix H.

    Builds the augmented matrix [H^T | I_n] and row-reduces it over GF(2) on
    the H^T columns. The identity block records the row combination applied.
    Each row below the pivot rows has an all-zero H^T part. Its identity part
    is therefore a vector g with g H^T = 0 (mod 2). Those rows are
    row-reduced once more and returned.

    Args:
        pc_matrix_: 2-D 0/1 parity-check matrix of shape (m, n).

    Returns:
        Boolean array with n columns whose rows satisfy G @ H^T = 0 (mod 2).
    """
    print('get_generator', pc_matrix_.shape)
    assert pc_matrix_.ndim == 2
    h_t = pc_matrix_.astype(bool).T  # astype already returns a fresh copy
    n_vars, n_checks = h_t.shape
    augmented = np.concatenate((h_t, np.eye(n_vars, dtype=bool)), axis=-1)
    augmented, rank = row_reduce(augmented, ncols=n_checks)
    null_rows = augmented[rank:, n_checks:]
    generator, _ = row_reduce(null_rows)
    return generator
def get_standard_form(pc_matrix_):
    """Bring a binary parity-check matrix to systematic form [I | P] over GF(2).

    Uses GF(2) row operations (row swaps and XOR) to create an identity block
    in the first min(H.shape) columns. When column ii has no non-zero at or
    below row ii, columns are rotated to bring in a column from the right.
    The result is therefore H with its columns permuted, not just row-reduced.
    Assumes H has no more rows than columns. If no pivot can be found before
    the columns run out (e.g. rank-deficient H), the column indexing raises
    IndexError.

    Args:
        pc_matrix_: 2-D 0/1 parity-check matrix.

    Returns:
        Integer 0/1 array of the same shape in [I | P] form.
    """
    h = pc_matrix_.astype(bool)  # astype copies; the input is left untouched
    n_pivots = min(h.shape)
    spare_col = n_pivots  # next column beyond the identity block to try
    for ii in range(n_pivots):
        # No usable pivot in column ii: shift columns ii+1..n_pivots-1 left by
        # one, move column spare_col into slot n_pivots-1, and park the old
        # column ii at slot spare_col. Repeat until a pivot shows up.
        while not h[ii:, ii].any():
            dest = list(range(ii, n_pivots - 1)) + [n_pivots - 1, spare_col]
            src = list(range(ii + 1, n_pivots)) + [spare_col, ii]
            h[:, dest] = h[:, src]
            spare_col += 1
        pivot = ii + np.flatnonzero(h[ii:, ii])[0]
        h[[ii, pivot], :] = h[[pivot, ii], :]
        # Clear column ii in every other row by XOR-ing in the pivot row.
        hit = h[:, ii].copy()
        hit[ii] = False
        h[hit] = h[hit] ^ h[ii]
    return h.astype(int)
#############################################
def Get_Generator_and_Parity(code, standard_form=False, db_dir='Codes_DB'):
    """Load the parity-check matrix of ``code`` and derive a generator matrix.

    The matrix is read from ``<db_dir>/<code_type>_N<n>_K<k>`` with suffix
    ``.txt`` (POLAR, BCH; parsed by np.loadtxt) or ``.alist`` (CCSDS, LDPC,
    MACKAY).

    Args:
        code: object exposing ``n``, ``k`` and ``code_type`` attributes.
        standard_form: for POLAR/BCH only. Convert H to systematic form
            [I | P] (columns may be permuted) and build G = [P^T | I_k] over
            GF(2). Ignored for alist-based codes, which go through
            ``get_generator``.
        db_dir: directory holding the matrix files. The default 'Codes_DB' is
            resolved relative to the current working directory, as before.

    Returns:
        (GeneratorMatrix, ParityMatrix) as float arrays.

    Raises:
        ValueError: unknown ``code_type`` (a subclass of Exception, so
            existing ``except Exception`` handlers still catch it).
        AssertionError: G @ H^T != 0 (mod 2) or G is all zeros. The check is
            explicit so it also runs under ``python -O``.
    """
    txt_codes = ('POLAR', 'BCH')
    alist_codes = ('CCSDS', 'LDPC', 'MACKAY')
    n, k = code.n, code.k
    path_pc_mat = os.path.join(db_dir, f'{code.code_type}_N{n}_K{k}')
    if code.code_type in txt_codes:
        ParityMatrix = np.loadtxt(path_pc_mat + '.txt')
    elif code.code_type in alist_codes:
        ParityMatrix = Read_pc_matrixrix_alist(path_pc_mat + '.alist')
    else:
        raise ValueError(f'Wrong code {code.code_type}')
    if standard_form and code.code_type not in alist_codes:
        ParityMatrix = get_standard_form(ParityMatrix).astype(int)
        # H = [I_m | P]  =>  G = [-P^T mod 2 | I_k] (i.e. [P^T | I_k] over GF(2)).
        n_checks = min(ParityMatrix.shape)
        P = ParityMatrix[:, n_checks:]
        GeneratorMatrix = np.concatenate(
            [np.mod(-P.transpose(), 2), np.eye(k)], 1).astype(int)
    else:
        GeneratorMatrix = get_generator(ParityMatrix)
    print('GeneratorMatrix', GeneratorMatrix.shape)
    syndrome = np.mod(np.matmul(GeneratorMatrix, ParityMatrix.transpose()), 2)
    if not (np.all(syndrome == 0) and np.sum(GeneratorMatrix) > 0):
        raise AssertionError(
            'Generator/parity-check mismatch: G @ H^T != 0 (mod 2) or G is all zeros')
    return GeneratorMatrix.astype(float), ParityMatrix.astype(float)
#############################################