-
Notifications
You must be signed in to change notification settings - Fork 8
Sourcery Starbot ⭐ refactored RocketFlash/EmbeddingNet #13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -26,12 +26,12 @@ def __init__(self, dataset_path, | |
| self.dataset_path = dataset_path | ||
| self.class_files_paths = {} | ||
| self.class_names = [] | ||
|
|
||
| if train_csv_file is not None: | ||
| self.class_files_paths = self._load_from_dataframe(train_csv_file, image_id_column, label_column, is_google) | ||
| else: | ||
| self.class_files_paths = self._load_from_directory() | ||
|
|
||
|
Comment on lines
-29
to
+34
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Found the following improvement in Function |
||
| self.n_classes = len(self.class_names) | ||
| self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()} | ||
|
|
||
|
|
@@ -94,7 +94,7 @@ def _load_from_directory(self): | |
| for class_name, class_dir_path in tqdm.tqdm(zip(self.class_names, class_dir_paths)): | ||
| subdirs = [f.path for f in os.scandir(class_dir_path) if f.is_dir()] | ||
| temp_list = [] | ||
| if len(subdirs)>0: | ||
| if subdirs: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| for subdir in subdirs: | ||
| class_image_paths = [f.path for f in os.scandir(subdir) if f.is_file() and | ||
| (f.name.endswith('.jpg') or | ||
|
|
@@ -134,10 +134,7 @@ def __init__(self, class_files_paths, | |
| self.n_samples = {k: len(v) for k, v in self.class_files_paths.items()} | ||
|
|
||
| def __len__(self): | ||
| if self.val_gen: | ||
| return self.n_batches_val | ||
| else: | ||
| return self.n_batches | ||
| return self.n_batches_val if self.val_gen else self.n_batches | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| def __getitem__(self, index): | ||
| pass | ||
|
|
@@ -207,7 +204,7 @@ def get_batch_triplets_mining(self): | |
| all_embeddings_list = [] | ||
| all_images_list = [] | ||
|
|
||
|
|
||
|
Comment on lines
-210
to
+207
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| for idx, cl_img_idxs in enumerate(selected_images): | ||
| images = self._get_images_set(selected_classes[idx], cl_img_idxs, with_aug=self.augmentations) | ||
| all_images_list.append(images) | ||
|
|
@@ -243,7 +240,7 @@ def get_batch_triplets_mining(self): | |
| triplet_negatives.append(all_images[hard_negative]) | ||
| targets.append(1) | ||
|
|
||
| if len(triplet_anchors) == 0: | ||
| if not triplet_anchors: | ||
| triplet_anchors.append(all_images[anchor_positive[0]]) | ||
| triplet_positives.append(all_images[anchor_positive[1]]) | ||
| triplet_negatives.append(all_images[negative_indices[0]]) | ||
|
|
@@ -282,9 +279,7 @@ def get_batch_triplets(self): | |
| np.zeros((self.batch_size, self.input_shape[0], self.input_shape[1], 3))] | ||
| targets = np.zeros((self.batch_size,)) | ||
|
|
||
| count = 0 | ||
|
|
||
| for i in range(self.batch_size): | ||
| for count, i in enumerate(range(self.batch_size)): | ||
|
Comment on lines
-285
to
+282
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| selected_class_idx = random.randrange(0, self.n_classes) | ||
| selected_class = self.class_names[selected_class_idx] | ||
| selected_class_n_elements = self.n_samples[selected_class] | ||
|
|
@@ -306,8 +301,6 @@ def get_batch_triplets(self): | |
| triplets[1][count, :, :, :] = imgs[1] | ||
| triplets[2][count, :, :, :] = imgs[2] | ||
| targets[i] = 1 | ||
| count += 1 | ||
|
|
||
| return triplets, targets | ||
|
|
||
| def __getitem__(self, index): | ||
|
|
@@ -398,9 +391,8 @@ def get_batch(self): | |
| np.zeros((self.batch_size, self.input_shape[0], self.input_shape[1], 3))] | ||
| targets = np.zeros((self.batch_size, self.n_classes)) | ||
|
|
||
| count = 0 | ||
| with_aug = self.augmentations | ||
| for i in range(self.batch_size): | ||
| for count, i in enumerate(range(self.batch_size)): | ||
|
Comment on lines
-401
to
+395
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| selected_class_idx = random.randrange(0, self.n_classes) | ||
| selected_class = self.class_names[selected_class_idx] | ||
| selected_class_n_elements = len(self.class_files_paths[selected_class]) | ||
|
|
@@ -410,8 +402,6 @@ def get_batch(self): | |
| img = self._get_images_set([selected_class], [indx], with_aug=with_aug) | ||
| images[0][count, :, :, :] = img[0] | ||
| targets[i][selected_class_idx] = 1 | ||
| count += 1 | ||
|
|
||
| return images, targets | ||
|
|
||
| def __getitem__(self, index): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,8 +45,7 @@ def _create_base_model(self): | |
| self.classification_model = Model(inputs=[self.base_model.layers[0].input],outputs=[output]) | ||
|
|
||
| def _generate_encodings(self, imgs): | ||
| encodings = self.base_model.predict(imgs) | ||
| return encodings | ||
| return self.base_model.predict(imgs) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| def train_embeddings_classifier(self, data_loader, | ||
|
|
@@ -61,27 +60,25 @@ def train_embeddings_classifier(self, data_loader, | |
| def generate_encodings(self, data_loader, max_n_samples=10, | ||
| shuffle=True): | ||
| data_paths, data_labels, data_encodings = [], [], [] | ||
| encoded_training_data = {} | ||
|
|
||
| for class_name in data_loader.class_names: | ||
| data_list = data_loader.train_data[class_name] | ||
| if len(data_list)>max_n_samples: | ||
| if shuffle: | ||
| random.shuffle(data_list) | ||
| data_list = data_list[:max_n_samples] | ||
|
|
||
|
Comment on lines
-64
to
+69
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| data_paths += data_list | ||
| imgs = get_images(data_list, self.params_model['input_shape']) | ||
| encods = self._generate_encodings(imgs) | ||
| for encod in encods: | ||
| data_encodings.append(encod) | ||
| data_labels.append(class_name) | ||
|
|
||
| encoded_training_data['paths'] = data_paths | ||
| encoded_training_data['labels'] = data_labels | ||
| encoded_training_data['encodings'] = np.squeeze(np.array(data_encodings)) | ||
|
|
||
| return encoded_training_data | ||
| return { | ||
| 'paths': data_paths, | ||
| 'labels': data_labels, | ||
| 'encodings': np.squeeze(np.array(data_encodings)), | ||
| } | ||
|
|
||
| def save_encodings(self, encoded_training_data, | ||
| save_folder='./', | ||
|
|
@@ -113,23 +110,16 @@ def save_onnx(self, save_folder, save_name='base_model.onnx'): | |
| keras2onnx.save_model(onnx_model, os.path.join(save_folder, save_name)) | ||
|
|
||
| def predict(self, image): | ||
| if type(image) is str: | ||
| img = cv2.imread(image) | ||
| else: | ||
| img = image | ||
| img = cv2.imread(image) if type(image) is str else image | ||
| img = cv2.resize(img, (self.params_model['input_shape'][0], | ||
| self.params_model['input_shape'][1])) | ||
| encoding = self.base_model.predict(np.expand_dims(img, axis=0)) | ||
| distances = self.calculate_distances(encoding) | ||
| max_element = np.argmin(distances) | ||
| predicted_label = self.encoded_training_data['labels'][max_element] | ||
| return predicted_label | ||
| return self.encoded_training_data['labels'][max_element] | ||
|
Comment on lines
-116
to
+119
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| def predict_knn(self, image, with_top5=False): | ||
| if type(image) is str: | ||
| img = cv2.imread(image) | ||
| else: | ||
| img = image | ||
| img = cv2.imread(image) if type(image) is str else image | ||
|
Comment on lines
-129
to
+122
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| img = cv2.resize(img, (self.input_shape[0], self.input_shape[1])) | ||
|
|
||
| encoding = self.base_model.predict(np.expand_dims(img, axis=0)) | ||
|
|
@@ -145,8 +135,6 @@ def calculate_prediction_accuracy(self, data_loader): | |
| correct_top1 = 0 | ||
| correct_top5 = 0 | ||
|
|
||
| accuracies = {'top1':0, | ||
| 'top5':0 } | ||
|
Comment on lines
-148
to
-149
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| total_n_of_images = len(data_loader.images_paths['val']) | ||
| for img_path, img_label in zip(data_loader.images_paths['val'], | ||
| data_loader.images_labels['val']): | ||
|
|
@@ -155,10 +143,10 @@ def calculate_prediction_accuracy(self, data_loader): | |
| correct_top1 += 1 | ||
| if img_label in prediction_top5: | ||
| correct_top5 += 1 | ||
| accuracies['top1'] = correct_top1/total_n_of_images | ||
| accuracies['top5'] = correct_top5/total_n_of_images | ||
|
|
||
| return accuracies | ||
| return { | ||
| 'top1': correct_top1 / total_n_of_images, | ||
| 'top5': correct_top5 / total_n_of_images, | ||
| } | ||
|
|
||
|
|
||
| class TripletNet(EmbeddingNet): | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
get_augrefactored with the following changes:lift-return-into-if)This removes the following comments ( why? ):