Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 76 additions & 46 deletions embedding_net/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,53 +3,83 @@

def get_aug(name='default', input_shape=[48, 48, 3]):
if name == 'default':
augmentations = A.Compose([
A.RandomBrightnessContrast(p=0.4),
A.RandomGamma(p=0.4),
A.HueSaturationValue(hue_shift_limit=20,
sat_shift_limit=30, val_shift_limit=30, p=0.4),
A.CLAHE(p=0.4),
A.Blur(blur_limit=1, p=0.3),
A.GaussNoise(var_limit=(50, 80), p=0.3)
], p=1)
return A.Compose(
[
A.RandomBrightnessContrast(p=0.4),
A.RandomGamma(p=0.4),
A.HueSaturationValue(
hue_shift_limit=20,
sat_shift_limit=30,
val_shift_limit=30,
p=0.4,
),
A.CLAHE(p=0.4),
A.Blur(blur_limit=1, p=0.3),
A.GaussNoise(var_limit=(50, 80), p=0.3),
],
p=1,
)

elif name == 'plates':
augmentations = A.Compose([
A.RandomBrightnessContrast(p=0.4),
A.RandomGamma(p=0.4),
A.HueSaturationValue(hue_shift_limit=20,
sat_shift_limit=30,
val_shift_limit=30,
p=0.4),
A.CLAHE(p=0.4),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.Blur(blur_limit=1, p=0.3),
A.GaussNoise(var_limit=(50, 80), p=0.3),
A.RandomCrop(p=0.8, height=2*input_shape[1]/3, width=2*input_shape[0]/3)
], p=1)
return A.Compose(
[
A.RandomBrightnessContrast(p=0.4),
A.RandomGamma(p=0.4),
A.HueSaturationValue(
hue_shift_limit=20,
sat_shift_limit=30,
val_shift_limit=30,
p=0.4,
),
A.CLAHE(p=0.4),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.Blur(blur_limit=1, p=0.3),
A.GaussNoise(var_limit=(50, 80), p=0.3),
A.RandomCrop(
p=0.8,
height=2 * input_shape[1] / 3,
width=2 * input_shape[0] / 3,
),
],
p=1,
)

elif name == 'deepfake':
augmentations = A.Compose([
A.HorizontalFlip(p=0.5),
], p=1)
return A.Compose(
[
A.HorizontalFlip(p=0.5),
],
p=1,
)

elif name == 'plates2':
augmentations = A.Compose([
A.CLAHE(clip_limit=(1,4),p=0.3),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomBrightness(limit=0.2, p=0.3),
A.RandomContrast(limit=0.2, p=0.3),
# A.Rotate(limit=360, p=0.9),
A.RandomRotate90(p=0.3),
A.HueSaturationValue(hue_shift_limit=(-50,50),
sat_shift_limit=(-15,15),
val_shift_limit=(-15,15),
p=0.5),
# A.Blur(blur_limit=(5,7), p=0.3),
A.GaussNoise(var_limit=(10, 50), p=0.3),
A.CenterCrop(p=1, height=2*input_shape[1]//3, width=2*input_shape[0]//3),
A.Resize(p=1, height=input_shape[1], width=input_shape[0])
], p=1)
else:
augmentations = None
return A.Compose(
[
A.CLAHE(clip_limit=(1, 4), p=0.3),
A.HorizontalFlip(p=0.5),
A.VerticalFlip(p=0.5),
A.RandomBrightness(limit=0.2, p=0.3),
A.RandomContrast(limit=0.2, p=0.3),
# A.Rotate(limit=360, p=0.9),
A.RandomRotate90(p=0.3),
A.HueSaturationValue(
hue_shift_limit=(-50, 50),
sat_shift_limit=(-15, 15),
val_shift_limit=(-15, 15),
p=0.5,
),
# # A.Blur(blur_limit=(5,7), p=0.3),
A.GaussNoise(var_limit=(10, 50), p=0.3),
A.CenterCrop(
p=1,
height=2 * input_shape[1] // 3,
width=2 * input_shape[0] // 3,
),
A.Resize(p=1, height=input_shape[1], width=input_shape[0]),
],
p=1,
)

return augmentations
else:
return None
Comment on lines -6 to +85
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function get_aug refactored with the following changes:

This removes the following comments ( why? ):

#             A.Blur(blur_limit=(5,7), p=0.3),

26 changes: 8 additions & 18 deletions embedding_net/datagenerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ def __init__(self, dataset_path,
self.dataset_path = dataset_path
self.class_files_paths = {}
self.class_names = []

if train_csv_file is not None:
self.class_files_paths = self._load_from_dataframe(train_csv_file, image_id_column, label_column, is_google)
else:
self.class_files_paths = self._load_from_directory()

Comment on lines -29 to +34
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Found the following improvement in Function ENDataLoader.__init__:

self.n_classes = len(self.class_names)
self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()}

Expand Down Expand Up @@ -94,7 +94,7 @@ def _load_from_directory(self):
for class_name, class_dir_path in tqdm.tqdm(zip(self.class_names, class_dir_paths)):
subdirs = [f.path for f in os.scandir(class_dir_path) if f.is_dir()]
temp_list = []
if len(subdirs)>0:
if subdirs:
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENDataLoader._load_from_directory refactored with the following changes:

for subdir in subdirs:
class_image_paths = [f.path for f in os.scandir(subdir) if f.is_file() and
(f.name.endswith('.jpg') or
Expand Down Expand Up @@ -134,10 +134,7 @@ def __init__(self, class_files_paths,
self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()}

def __len__(self):
if self.val_gen:
return self.n_batches_val
else:
return self.n_batches
return self.n_batches_val if self.val_gen else self.n_batches
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function ENDataGenerator.__len__ refactored with the following changes:


def __getitem__(self, index):
pass
Expand Down Expand Up @@ -207,7 +204,7 @@ def get_batch_triplets_mining(self):
all_embeddings_list = []
all_images_list = []


Comment on lines -210 to +207
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function TripletsDataGenerator.get_batch_triplets_mining refactored with the following changes:

for idx, cl_img_idxs in enumerate(selected_images):
images = self._get_images_set(selected_classes[idx], cl_img_idxs, with_aug=self.augmentations)
all_images_list.append(images)
Expand Down Expand Up @@ -243,7 +240,7 @@ def get_batch_triplets_mining(self):
triplet_negatives.append(all_images[hard_negative])
targets.append(1)

if len(triplet_anchors) == 0:
if not triplet_anchors:
triplet_anchors.append(all_images[anchor_positive[0]])
triplet_positives.append(all_images[anchor_positive[1]])
triplet_negatives.append(all_images[negative_indices[0]])
Expand Down Expand Up @@ -282,9 +279,7 @@ def get_batch_triplets(self):
np.zeros((self.batch_size, self.input_shape[0], self.input_shape[1], 3))]
targets = np.zeros((self.batch_size,))

count = 0

for i in range(self.batch_size):
for count, i in enumerate(range(self.batch_size)):
Comment on lines -285 to +282
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function SimpleTripletsDataGenerator.get_batch_triplets refactored with the following changes:

selected_class_idx = random.randrange(0, self.n_classes)
selected_class = self.class_names[selected_class_idx]
selected_class_n_elements = self.n_samples[selected_class]
Expand All @@ -306,8 +301,6 @@ def get_batch_triplets(self):
triplets[1][count, :, :, :] = imgs[1]
triplets[2][count, :, :, :] = imgs[2]
targets[i] = 1
count += 1

return triplets, targets

def __getitem__(self, index):
Expand Down Expand Up @@ -398,9 +391,8 @@ def get_batch(self):
np.zeros((self.batch_size, self.input_shape[0], self.input_shape[1], 3))]
targets = np.zeros((self.batch_size, self.n_classes))

count = 0
with_aug = self.augmentations
for i in range(self.batch_size):
for count, i in enumerate(range(self.batch_size)):
Comment on lines -401 to +395
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function SimpleDataGenerator.get_batch refactored with the following changes:

selected_class_idx = random.randrange(0, self.n_classes)
selected_class = self.class_names[selected_class_idx]
selected_class_n_elements = len(self.class_files_paths[selected_class])
Expand All @@ -410,8 +402,6 @@ def get_batch(self):
img = self._get_images_set([selected_class], [indx], with_aug=with_aug)
images[0][count, :, :, :] = img[0]
targets[i][selected_class_idx] = 1
count += 1

return images, targets

def __getitem__(self, index):
Expand Down
40 changes: 14 additions & 26 deletions embedding_net/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def _create_base_model(self):
self.classification_model = Model(inputs=[self.base_model.layers[0].input],outputs=[output])

def _generate_encodings(self, imgs):
encodings = self.base_model.predict(imgs)
return encodings
return self.base_model.predict(imgs)
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function EmbeddingNet._generate_encodings refactored with the following changes:



def train_embeddings_classifier(self, data_loader,
Expand All @@ -61,27 +60,25 @@ def train_embeddings_classifier(self, data_loader,
def generate_encodings(self, data_loader, max_n_samples=10,
shuffle=True):
data_paths, data_labels, data_encodings = [], [], []
encoded_training_data = {}

for class_name in data_loader.class_names:
data_list = data_loader.train_data[class_name]
if len(data_list)>max_n_samples:
if shuffle:
random.shuffle(data_list)
data_list = data_list[:max_n_samples]

Comment on lines -64 to +69
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function EmbeddingNet.generate_encodings refactored with the following changes:

data_paths += data_list
imgs = get_images(data_list, self.params_model['input_shape'])
encods = self._generate_encodings(imgs)
for encod in encods:
data_encodings.append(encod)
data_labels.append(class_name)

encoded_training_data['paths'] = data_paths
encoded_training_data['labels'] = data_labels
encoded_training_data['encodings'] = np.squeeze(np.array(data_encodings))

return encoded_training_data
return {
'paths': data_paths,
'labels': data_labels,
'encodings': np.squeeze(np.array(data_encodings)),
}

def save_encodings(self, encoded_training_data,
save_folder='./',
Expand Down Expand Up @@ -113,23 +110,16 @@ def save_onnx(self, save_folder, save_name='base_model.onnx'):
keras2onnx.save_model(onnx_model, os.path.join(save_folder, save_name))

def predict(self, image):
if type(image) is str:
img = cv2.imread(image)
else:
img = image
img = cv2.imread(image) if type(image) is str else image
img = cv2.resize(img, (self.params_model['input_shape'][0],
self.params_model['input_shape'][1]))
encoding = self.base_model.predict(np.expand_dims(img, axis=0))
distances = self.calculate_distances(encoding)
max_element = np.argmin(distances)
predicted_label = self.encoded_training_data['labels'][max_element]
return predicted_label
return self.encoded_training_data['labels'][max_element]
Comment on lines -116 to +119
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function EmbeddingNet.predict refactored with the following changes:


def predict_knn(self, image, with_top5=False):
if type(image) is str:
img = cv2.imread(image)
else:
img = image
img = cv2.imread(image) if type(image) is str else image
Comment on lines -129 to +122
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function EmbeddingNet.predict_knn refactored with the following changes:

img = cv2.resize(img, (self.input_shape[0], self.input_shape[1]))

encoding = self.base_model.predict(np.expand_dims(img, axis=0))
Expand All @@ -145,8 +135,6 @@ def calculate_prediction_accuracy(self, data_loader):
correct_top1 = 0
correct_top5 = 0

accuracies = {'top1':0,
'top5':0 }
Comment on lines -148 to -149
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Function EmbeddingNet.calculate_prediction_accuracy refactored with the following changes:

total_n_of_images = len(data_loader.images_paths['val'])
for img_path, img_label in zip(data_loader.images_paths['val'],
data_loader.images_labels['val']):
Expand All @@ -155,10 +143,10 @@ def calculate_prediction_accuracy(self, data_loader):
correct_top1 += 1
if img_label in prediction_top5:
correct_top5 += 1
accuracies['top1'] = correct_top1/total_n_of_images
accuracies['top5'] = correct_top5/total_n_of_images

return accuracies
return {
'top1': correct_top1 / total_n_of_images,
'top5': correct_top5 / total_n_of_images,
}


class TripletNet(EmbeddingNet):
Expand Down
Loading