diff --git a/micro_sam/sam_annotator/_annotator.py b/micro_sam/sam_annotator/_annotator.py index 5cd0b196..39613867 100644 --- a/micro_sam/sam_annotator/_annotator.py +++ b/micro_sam/sam_annotator/_annotator.py @@ -362,6 +362,7 @@ class _ClassifierBase(QtWidgets.QScrollArea): label_widget_title = "Label names:" max_components = 256 # PCA upper bound (256 pixel channels, 257 object features) tool_key = None # "pixel" | "object", selects the tool-specific tooltips + supports_apply_to_volume = True # if False the tool always runs over the full image/volume # # Hooks the subclasses implement. @@ -494,8 +495,9 @@ def _update_image(self, segmentation_result=None): self._ndim = len(state.image_shape) if state.ndim is None else state.ndim self._shape = tuple(state.image_shape)[:self._ndim] - # The 'Apply to Volume' checkbox only makes sense for 3d data. - self._apply_to_volume.visible = self._ndim == 3 + # The 'Apply to Volume' checkbox only makes sense for 3d data (and only for tools that have it). + if self._apply_to_volume is not None: + self._apply_to_volume.visible = self._ndim == 3 # The features depend on the image, so they have to be recomputed for a new image. self._invalidate_features() @@ -583,23 +585,31 @@ def _create_train_widget(self): # The 'Train and predict' button is kept at the top level, outside the settings dropdown. # A single 'Apply to Volume' checkbox governs both 'Train and predict' and 'Clear Annotations' # (shown only for 3d data, see '_update_image'): when checked they act on the whole volume, - # when unchecked (the default) only on the current slice. + # when unchecked (the default) only on the current slice. Tools that do not support it + # ('supports_apply_to_volume' False) omit the checkbox and always run over the full image/volume. train_button = PushButton(text="Train and predict [Shift + T]") train_button.native.setToolTip(get_tooltip("classification", "train_button")) clear_button = PushButton(text="Clear Annotations [C]") clear_button.native.setToolTip(get_tooltip("classification", "clear_button")) - apply_to_volume = CheckBox(value=False, text="Apply to Volume") - apply_to_volume.native.setToolTip(get_tooltip("classification", "apply_to_volume")) - train_button.clicked.connect(lambda: self._run_train_and_predict(apply_to_volume.value)) - clear_button.clicked.connect(lambda: self._clear_annotations(apply_to_volume.value)) + + apply_to_volume = None + if self.supports_apply_to_volume: + apply_to_volume = CheckBox(value=False, text="Apply to Volume") + apply_to_volume.native.setToolTip(get_tooltip("classification", "apply_to_volume")) + + def _volume_value(): + return True if apply_to_volume is None else apply_to_volume.value + + train_button.clicked.connect(lambda: self._run_train_and_predict(_volume_value())) + clear_button.clicked.connect(lambda: self._clear_annotations(_volume_value())) @self._viewer.bind_key("Shift-T", overwrite=True) def _train_and_predict(event=None): - self._run_train_and_predict(apply_to_volume.value) + self._run_train_and_predict(_volume_value()) @self._viewer.bind_key("c", overwrite=True) def _clear(event=None): - self._clear_annotations(apply_to_volume.value) + self._clear_annotations(_volume_value()) # The two buttons sit side-by-side and expand to share the row width equally. QSizePolicy.Policy # is nested in Qt6 and top-level in Qt5. @@ -608,7 +618,8 @@ def _clear(event=None): size_policy = getattr(QtWidgets.QSizePolicy, "Policy", QtWidgets.QSizePolicy) for button in (train_button, clear_button): button.native.setSizePolicy(size_policy.Expanding, size_policy.Fixed) - container = Container(widgets=[apply_to_volume, button_row], labels=False) + widgets_ = ([apply_to_volume] if apply_to_volume is not None else []) + [button_row] + container = Container(widgets=widgets_, labels=False) return container, apply_to_volume def _create_classifier_io_widget(self): @@ -706,7 +717,8 @@ def _create_widgets(self): self._embedding_widget.run_button.clicked.connect(self._update_image) self._train_and_predict_widget, self._apply_to_volume = self._create_train_widget() - self._apply_to_volume.visible = False + if self._apply_to_volume is not None: + self._apply_to_volume.visible = False self._classifier_io_widget = self._create_classifier_io_widget() settings = QtWidgets.QWidget() @@ -938,9 +950,15 @@ def _restore_from_spec(self, spec): # the size options are rebuilt when it changes. family, size = spec.get("model_family"), spec.get("model_size") if ew is not None and family is not None: - ew.model_family_dropdown.setCurrentText(family) - if size is not None: - ew.model_size_dropdown.setCurrentText(size) + # The classification widget routes the family to the primary or advanced selector; other + # widgets fall back to setting the family dropdown directly. + setter = getattr(ew, "set_model_family_size", None) + if setter is not None: + setter(family, size) + else: + ew.model_family_dropdown.setCurrentText(family) + if size is not None: + ew.model_size_dropdown.setCurrentText(size) # Tiling, tile/halo params and custom weights via the shared sync helper (these field names match). # 'ew.model_type' may be unset until embeddings are computed, so fall back via getattr. diff --git a/micro_sam/sam_annotator/_tooltips.py b/micro_sam/sam_annotator/_tooltips.py index 93b46275..08c01b11 100644 --- a/micro_sam/sam_annotator/_tooltips.py +++ b/micro_sam/sam_annotator/_tooltips.py @@ -7,8 +7,10 @@ "embeddings_save_path": "Select path to save or load the computed image embeddings.", "halo": "Enter overlap values for computing tiled embeddings. Enter only x-value for quadratic size.\n Only active when tiling is used.", # noqa "image": "Select the napari image layer.", - "model_family": "Select the segment anything model family.", - "model_size": "Select the image encoder size of the segment anything model.", + "model_family": "Select the segment anything 2 model family.", + "model_family_advanced": "Select the advanced (non-SAM2) model family, e.g. a SAM1 family. Switched on via 'Advanced Models' in the embedding settings.", # noqa + "model_size": "Select the image encoder size of the segment anything 2 model.", + "advanced_model": "Switch the model list above to advanced models beyond the default SAM2 models (currently SAM1). Only available for the classification tools.", # noqa "automatic_segmentation_mode": "Select the automatic segmentation mode.", "run_button": "Compute embeddings or load embeddings if embedding_save_path is specified.", "tiling": "Enter tile size for computing tiled embeddings. Enter only x-value for quadratic size or both for non-quadratic.", # noqa @@ -28,7 +30,7 @@ "unified_segment": { "apply_to_volume": "Choose if segmentation is run for the current slice/frame only or for the full volume/all frames.", # noqa "batched": "Enable to segment multiple objects at once: each positive point and each box defines a separate object. Only available for SAM2 models.", # noqa - "segment_button": "Run Segment Anything on the current point/box prompts to segment the object. Shortcut: S.", # noqa + "segment_button": "Run Segment Anything 2 on the current point/box prompts to segment the object. Shortcut: S.", # noqa "clear_button": "Clear the current prompts and the current-object segmentation (whole volume or current slice per 'Apply to Volume' for 3d data). Shortcut: Shift + C.", # noqa "settings": "Settings for interactive segmentation across slices (projection mode and propagation parameters).", # noqa }, diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 4601d9f6..4aea61e9 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1254,6 +1254,27 @@ def _generate_message(message_type: str, message: str) -> bool: raise ValueError(f"Invalid message type {message_type}") +def _ask_load_or_recompute(message: str) -> str: + """Ask the user whether to load existing embeddings or recompute them. + + Returns 'load', 'recompute' or 'cancel'. + """ + box = QtWidgets.QMessageBox() + box.setWindowTitle("Existing embeddings found") + box.setText(message) + load_button = box.addButton("Load", QtWidgets.QMessageBox.AcceptRole) + recompute_button = box.addButton("Recompute", QtWidgets.QMessageBox.DestructiveRole) + box.addButton("Cancel", QtWidgets.QMessageBox.RejectRole) + box.setDefaultButton(load_button) + box.exec_() + clicked = box.clickedButton() + if clicked is load_button: + return "load" + if clicked is recompute_button: + return "recompute" + return "cancel" + + def _validate_embeddings(viewer: "napari.viewer.Viewer"): state = AnnotatorState() if state.image_embeddings is None: @@ -1719,11 +1740,25 @@ def _create_settings_widget(self): ) setting_values.layout().addLayout(layout) + # Hook for subclasses to add extra model controls at the end of the settings (no-op by default). + self._add_extra_model_settings(setting_values.layout()) + settings = _make_collapsible( setting_values, title="Embedding Settings", tooltip=get_tooltip("embedding", "settings"), ) return settings + def _add_extra_model_settings(self, layout): + """Hook to add extra model controls to the embedding settings. No-op by default; the + classification embedding widget uses it to add the optional advanced-model selector.""" + pass + + def _apply_loaded_model_selection(self, model_name): + """Reflect a loaded model in the model family / size dropdowns. No-op by default: the + post-compute '_sync_embedding_widget' already syncs the SAM2-only widget, whose family + names match. Subclasses with custom family handling (classification) override this.""" + pass + def _selected_image_ndim(self): # Spatial dimensionality of the currently selected image layer (2 or 3), or None if no image. image = self.image_selection.get_value() @@ -1852,16 +1887,40 @@ def _validate_inputs(self): msg = f"The embeddings don't match with the image: {img_signature} {f.attrs['data_signature']}" return _generate_message("error", msg) - # Load existing parameters. - self.model_type = f.attrs.get( - "model_name", f.attrs["model_type"] - ) + # The model the saved embeddings were computed with. + saved_model = f.attrs.get("model_name", f.attrs["model_type"]) + + # Ask the user whether to load the saved embeddings or recompute them. The default + # depends on whether the current model selection still matches the saved embeddings: + # a model/config change makes recompute the natural choice, an exact match makes + # loading the natural choice. + config_changed = (not self.custom_weights) and saved_model != self.model_type + if config_changed: + message = ( + f"Saved embeddings use '{saved_model}', but '{self.model_type}' is selected. " + "Load the saved embeddings or recompute?" + ) + else: + message = f"Embeddings for '{saved_model}' already exist. Load or recompute?" + choice = _ask_load_or_recompute(message) + + if choice == "cancel": + return True + + if choice == "recompute": + # Recompute with the user's current selection: clear the saved file so the backend + # recomputes from scratch (works for any model and even when the model is unchanged). + # Tiling and model stay as the user set them in the widget. + zarr.open(self.embeddings_save_path, mode="w") + return False + + # 'load': adopt the saved model and tiling, then load the existing embeddings. + self.model_type = saved_model if self._validate_model_support(): return True - if ( - "tile_shape" in f.attrs - and f.attrs["tile_shape"] is not None - ): + # Reflect the loaded model in the model family / size dropdowns. + self._apply_loaded_model_selection(saved_model) + if "tile_shape" in f.attrs and f.attrs["tile_shape"] is not None: self.tile_x, self.tile_y = f.attrs["tile_shape"] self.halo_x, self.halo_y = f.attrs["halo"] # Reflect the loaded tiling parameters in the UI. @@ -1870,23 +1929,9 @@ def _validate_inputs(self): self.halo_x_param.setValue(self.halo_x) self.halo_y_param.setValue(self.halo_y) self.tiling_dropdown.setCurrentText("yes") - val_results = { - "message_type": "info", - "message": ( - f"Load embeddings for model: {self.model_type} with tile shape: " - f"{self.tile_x}, {self.tile_y} and halo: {self.halo_x}, {self.halo_y}." - ), - } else: self.tiling_dropdown.setCurrentText("no") - val_results = { - "message_type": "info", - "message": f"Load embeddings for model: {self.model_type}.", - } - - return _generate_message( - val_results["message_type"], val_results["message"] - ) + return False except RuntimeError as e: val_results = { @@ -1901,16 +1946,15 @@ def _validate_inputs(self): return False def _validate_existing_embeddings(self, state): + # When an embeddings save path is set, '_validate_inputs' already offered load-vs-recompute, + # so don't prompt again here. This only handles the in-memory case (no save path). + if self.embeddings_save_path: + return False if state.image_embeddings is None: return False - else: - val_results = { - "message_type": "info", - "message": "Embeddings have already been precomputed. Press OK to recompute the embeddings.", - } - return _generate_message( - val_results["message_type"], val_results["message"] - ) + return _generate_message( + "info", "Embeddings have already been precomputed. Press OK to recompute the embeddings." + ) def __call__(self, skip_validate=False): self._validate_model_type_and_custom_weights() @@ -2016,100 +2060,172 @@ def pbar_update(update): class ClassificationEmbeddingWidget(EmbeddingWidget): """Embedding widget for the classification tools (pixel and object classification). - Unlike the SAM2-only `EmbeddingWidget` used by the segmentation and tracking annotators, - this exposes all SAM1, micro-sam (finetuned SAM1) and SAM2 model families. Classification - operates directly on the image encoder embeddings, so any of these models can be used. + The model selection mirrors the segmentation/tracking `EmbeddingWidget` exactly: the same single + 'Model:' dropdown with the SAM2 'Natural Images' and 'Microscopy' families (same names and config). + Since classification operates directly on the image-encoder embeddings, it can additionally use + models beyond SAM2. An opt-in 'Advanced Models' checkbox in the embedding settings swaps that one + dropdown to the advanced families instead of adding a second dropdown (currently the SAM1 families; + future backends such as DINO can be added to `_advanced_family_suffixes`). """ size_order = ["tiny", "small", "base", "large", "huge"] - def _create_model_section(self, default_model="vit_t_cells", create_layout=True): - # The model family mapped to the model-name suffix. SAM2 families use the 'hvit_' prefix - # ('_sam2' for the natural-image backbones, '_cells' for the finetuned microscopy model); - # the SAM1 families use the 'vit_' prefix. - self.supported_dropdown_maps = { - "Natural Images (SAM1)": "", - "Natural Images (SAM2)": "_sam2", - "Microscopy (SAM2)": "_cells", - "Light Microscopy (SAM1)": "_lm", - "Electron Microscopy (SAM1)": "_em_organelles", - "Medical Imaging (SAM1)": "_medical_imaging", - "Histopathology (SAM1)": "_histopathology", - } - self._model_size_map = {"t": "tiny", "s": "small", "b": "base", "l": "large", "h": "huge"} + # Advanced (non-SAM2) families: UI label -> model-name suffix on the SAM1 'vit_' prefix. Future + # backends (e.g. DINO) get their own entries here, resolved in '_get_model_size_options'. + _advanced_family_suffixes = { + "Natural Images (SAM1)": "", + "Light Microscopy (SAM1)": "_lm", + "Electron Microscopy (SAM1)": "_em_organelles", + "Medical Imaging (SAM1)": "_medical_imaging", + "Histopathology (SAM1)": "_histopathology", + } + _advanced_size_map = {"t": "tiny", "b": "base", "l": "large", "h": "huge"} + # Older saved classifiers stored the primary SAM2 families under '(SAM2)' labels; map them to the + # current names so loading such a classifier still restores the right family. + _primary_family_aliases = {"Natural Images (SAM2)": "Natural Images", "Microscopy (SAM2)": "Microscopy"} + + def _advanced_active(self): + # The single 'Model:' dropdown holds either the primary or the advanced families, so the + # current family's membership is the source of truth (robust to the dropdown being blanked). + return getattr(self, "model_family", None) in self._advanced_family_suffixes + + def _add_extra_model_settings(self, layout): + # 'Advanced Models' swaps the single 'Model:' dropdown above between the primary (SAM2) and the + # advanced (SAM1) families - one dropdown only, to avoid confusion. Added last in the settings. + self._primary_families = list(self.supported_dropdown_maps.keys()) + # The default primary family ('Microscopy'), restored when advanced is switched back off. + self._default_primary_family = self.model_family + self.advanced = False + self.advanced_checkbox = self._add_boolean_param( + "advanced", self.advanced, title="Advanced Models", + tooltip=get_tooltip("embedding", "advanced_model"), + ) + self.advanced_checkbox.stateChanged.connect(self._on_advanced_toggled) + layout.addWidget(self.advanced_checkbox) + + # The inherited dropdowns auto-bind their attribute by indexing the option list captured at + # creation; we swap those lists (the families, and the per-family sizes), which makes the + # captured index stale (wrong value, or out of range). Drop that auto-bind and let + # '_update_model_type' (wired to 'currentTextChanged') sync the attribute from the text. + for dropdown in (self.model_family_dropdown, self.model_size_dropdown): + try: + dropdown.currentIndexChanged.disconnect() + except TypeError: + pass + + def _set_family_choices(self, families, select=None): + # Replace the 'Model:' dropdown items, select 'select' (default first), then resolve the model. + if select not in families: + select = families[0] + self.model_family_dropdown.blockSignals(True) + self.model_family_dropdown.clear() + self.model_family_dropdown.addItems(families) + self.model_family_dropdown.setCurrentText(select) + self.model_family_dropdown.blockSignals(False) + self.model_family = select + self._update_model_type() + + def _on_advanced_toggled(self, state): + advanced = self.advanced_checkbox.isChecked() + if advanced: + self._set_family_choices(list(self._advanced_family_suffixes)) + else: # Back to the SAM2 families, defaulting to 'Microscopy' rather than the first entry. + self._set_family_choices(self._primary_families, select=self._default_primary_family) + # Reflect the active tier in the 'Model:' dropdown tooltip. + self.model_family_dropdown.setToolTip( + get_tooltip("embedding", "model_family_advanced" if advanced else "model_family") + ) - self._default_model_choice = default_model - self.model_family = {v: k for k, v in self.supported_dropdown_maps.items()}[default_model[5:]] + def _reset_inputs_to_defaults(self): + # Switching off advanced restores the primary families; the base reset then selects the default. + if getattr(self, "advanced_checkbox", None) is not None and self.advanced_checkbox.isChecked(): + self.advanced_checkbox.setChecked(False) + super()._reset_inputs_to_defaults() + + def set_model_family_size(self, family, size): + """Restore a saved (family, size): swap the dropdown to the matching tier, then select it.""" + family = self._primary_family_aliases.get(family, family) + self.advanced_checkbox.setChecked(family in self._advanced_family_suffixes) + self.model_family_dropdown.setCurrentText(family) + if size: + self.model_size_dropdown.setCurrentText(size) - kwargs = {} - if create_layout: - layout = QtWidgets.QVBoxLayout() - kwargs["layout"] = layout + def _validate_model_support(self): + if super()._validate_model_support(): + return True + # The vit-tiny backbone needs MobileSAM; warn (instead of crashing later) if it is selected + # without MobileSAM installed. The model stays selectable - we just block this compute. + from ..util import VIT_T_SUPPORT + if not VIT_T_SUPPORT and (self.model_type or "").startswith("vit_t"): + return _generate_message( + "error", + f"'{self.model_type}' (vit-tiny) requires MobileSAM. Install MobileSAM or pick another size.", + ) + return False - self.model_family_dropdown, layout = self._add_choice_param( - "model_family", self.model_family, list(self.supported_dropdown_maps.keys()), - title="Model:", tooltip=get_tooltip("embedding", "model_family"), **kwargs, - ) - self.model_family_dropdown.currentTextChanged.connect(self._update_model_type) - return layout + def _family_and_size_for_model(self, model_name): + """Map a stored model name to its (family label, size label) for this widget's dropdowns.""" + full_size_map = {"t": "tiny", "s": "small", "b": "base", "l": "large", "h": "huge"} + if model_name.startswith("hvit_"): # SAM2 (primary families). + size = full_size_map.get(model_name[5]) + family = "Microscopy" if model_name.endswith("_cells") else "Natural Images" + else: # SAM1 (advanced families): 'vit_'. + size = full_size_map.get(model_name[4]) + suffix = model_name[5:] + family = {v: k for k, v in self._advanced_family_suffixes.items()}.get(suffix, "Natural Images (SAM1)") + return family, size + + def _apply_loaded_model_selection(self, model_name): + # Set the family (primary or advanced) and size dropdowns to match the loaded embeddings. + family, size = self._family_and_size_for_model(model_name) + self.set_model_family_size(family, size) def _get_model_size_options(self): - # Build the available sizes for the selected family, mapping each UI label to the model name. + # Primary (SAM2) sizes come from the inherited logic; advanced families resolve to SAM1 names. + if not self._advanced_active(): + return super()._get_model_size_options() + from ..v1.util import get_model_names + suffix = self._advanced_family_suffixes[self.model_family] + available = {m for m in get_model_names() if not m.endswith("decoder")} self.model_size_mapping = {} - if self.model_family == "Natural Images (SAM2)": - for key in ("t", "s", "b", "l"): - self.model_size_mapping[self._model_size_map[key]] = f"hvit_{key}" - elif self.model_family == "Microscopy (SAM2)": - from micro_sam.v2.util import FINETUNED_MODELS - for key in ("t", "s", "b", "l"): - name = f"hvit_{key}_cells" - if name in FINETUNED_MODELS: - self.model_size_mapping[self._model_size_map[key]] = name - else: - from ..v1.util import get_model_names - suffix = self.supported_dropdown_maps[self.model_family] - available = {m for m in get_model_names() if not m.endswith("decoder")} - for key, label in self._model_size_map.items(): - name = f"vit_{key}{suffix}" - if name in available: - self.model_size_mapping[label] = name - + for key, label in self._advanced_size_map.items(): + name = f"vit_{key}{suffix}" + if name in available: + self.model_size_mapping[label] = name self.model_size_options = sorted(self.model_size_mapping.keys(), key=self.size_order.index) def _update_model_type(self): + # Sync the family from the dropdown first: the inherited auto-bind closure captures the original + # (primary) option list, so after the dropdown is swapped to the advanced families it can set a + # stale value; re-reading the current text here is authoritative and decides the branch below. + self.model_family = self.model_family_dropdown.currentText() or self.model_family + # Primary mode defers to the inherited SAM2 logic; advanced mode rebuilds the size dropdown for + # the current SAM1 family and resolves its model name. + if not self._advanced_active(): + return super()._update_model_type() current_selection = self.model_size_dropdown.currentText() self._get_model_size_options() - # NOTE: We prevent recursive updates while we rebuild the dropdown. self.model_size_dropdown.blockSignals(True) self.model_size_dropdown.clear() self.model_size_dropdown.addItems(self.model_size_options) - - # Restore the previous selection if still valid, else default to the first available size. if current_selection in self.model_size_options: self.model_size = current_selection elif self.model_size_options: self.model_size = self.model_size_options[0] self.model_type = self.model_size_mapping.get(self.model_size) - self.model_size_dropdown.setCurrentText(self.model_size) self.model_size_dropdown.update() self.model_size_dropdown.blockSignals(False) def _validate_model_type_and_custom_weights(self): - # Map the selected family and size to the actual model name. - if self.model_family in self.supported_dropdown_maps: + # Advanced mode (without custom weights): resolve the SAM1 model name from family + size. + if self._advanced_active() and not self.custom_weights: self._get_model_size_options() - if self.model_size in self.model_size_mapping: - self.model_type = self.model_size_mapping[self.model_size] - - # For 'custom_weights', we remove the displayed text on top of the drop-down menu. - if self.custom_weights: - # NOTE: We prevent recursive updates for this step temporarily. - self.model_family_dropdown.blockSignals(True) - self.model_family_dropdown.setCurrentIndex(-1) - self.model_family_dropdown.update() - self.model_family_dropdown.blockSignals(False) + self.model_type = self.model_size_mapping.get(self.model_size, self.model_type) + return + # Primary mode, or custom weights: the inherited logic resolves the type and blanks the dropdown. + super()._validate_model_type_and_custom_weights() # diff --git a/micro_sam/sam_annotator/object_classifier.py b/micro_sam/sam_annotator/object_classifier.py index 311b771b..1c2c6b8d 100644 --- a/micro_sam/sam_annotator/object_classifier.py +++ b/micro_sam/sam_annotator/object_classifier.py @@ -93,6 +93,7 @@ class ObjectClassifier(_ClassifierBase): label_widget_title = "Object label names:" max_components = OBJECT_FEATURES tool_key = "object" + supports_apply_to_volume = False # object classification always runs over the full image/volume def _get_selected_segmentation_layer(self): state = AnnotatorState() diff --git a/micro_sam/v1/util.py b/micro_sam/v1/util.py index 169af6ea..925ce884 100644 --- a/micro_sam/v1/util.py +++ b/micro_sam/v1/util.py @@ -801,30 +801,40 @@ def _compute_tiled_3d(input_, predictor, tile_shape, halo, f, pbar_init, pbar_up def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): + """Validate saved embeddings against the requested configuration. + + Returns True if the saved embeddings are stale and should be recomputed (the model or tiling + configuration changed), False if they can be loaded. Raises if they belong to different image + data (data signature mismatch). + """ # We may have an empty zarr file that was already created to save the embeddings in. # In this case the embeddings will be computed and we don't need to perform any checks. if "input_size" not in f.attrs: - return + return False signature = _get_embedding_signature(input_, predictor, tile_shape, halo) + stale = False for key, val in signature.items(): - # Check whether the key is missing from the attrs or if the value is not matching. - if key not in f.attrs or f.attrs[key] != val: - # These keys were recently added, so we don't want to fail yet if they don't - # match in order to not invalidate previous embedding files. - # Instead we just raise a warning. (For the version we probably also don't want to fail - # i the future since it should not invalidate the embeddings). - if key in ("micro_sam_version", "model_hash", "model_name"): - warnings.warn( - f"The signature for {key} in embeddings file {save_path} has a mismatch: " - f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " - "But please recompute them if model predictions don't look as expected." - ) - else: - raise RuntimeError( - f"Embeddings file {save_path} is invalid due to mismatch in {key}: " - f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." - ) + # A key absent from an older file should not invalidate it (it predates that key). + if key not in f.attrs or f.attrs[key] == val: + continue + # Different image data: surface as an error rather than silently overwriting it. + if key == "data_signature": + raise RuntimeError( + f"Embeddings file {save_path} is invalid due to mismatch in {key}: " + f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." + ) + # A version bump alone does not invalidate the embeddings. + if key == "micro_sam_version": + warnings.warn( + f"The signature for {key} in embeddings file {save_path} has a mismatch: " + f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " + "But please recompute them if model predictions don't look as expected." + ) + continue + # Model or tiling changed: the saved embeddings are stale and must be recomputed. + stale = True + return stale def precompute_image_embeddings( @@ -881,7 +891,9 @@ def precompute_image_embeddings( # check that the saved embeddings in there match the parameters of the function call. elif os.path.exists(save_path): f = zarr.open(save_path, mode="a") - _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) + if _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): + # Stale embeddings (model or tiling changed): truncate and recompute, overwriting them. + f = zarr.open(save_path, mode="w") # We have a save path and it does not exist yet. Create the zarr file to which the # embeddings will then be saved. diff --git a/micro_sam/v2/util.py b/micro_sam/v2/util.py index a3cb776b..6a806f1c 100644 --- a/micro_sam/v2/util.py +++ b/micro_sam/v2/util.py @@ -262,34 +262,44 @@ def get_sam2_model( def _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): + """Validate saved embeddings against the requested configuration. + + Returns True if the saved embeddings are stale and should be recomputed (the model or tiling + configuration changed), False if they can be loaded. Raises if they belong to different image + data (data signature mismatch). + """ # We may have an empty zarr file that was already created to save the embeddings in. # In this case the embeddings will be computed and we don't need to perform any checks. if "input_size" not in f.attrs: - return + return False # Creates all the metadta that is stored along with the embeddings. # TODO: This is currently paired with `micro_sam`-level metadata. Should we get separate for `micro_sam.v2`? from micro_sam.util import _get_embedding_signature signature = _get_embedding_signature(input_, predictor, tile_shape, halo) + stale = False for key, val in signature.items(): - # Check whether the key is missing from the attrs or if the value is not matching. - if key not in f.attrs or f.attrs[key] != val: - # These keys were recently added, so we don't want to fail yet if they don't - # match in order to not invalidate previous embedding files. - # Instead we just raise a warning. (For the version we probably also don't want to fail - # i the future since it should not invalidate the embeddings). - if key in ("micro_sam_version", "model_hash", "model_name"): - warnings.warn( - f"The signature for {key} in embeddings file {save_path} has a mismatch: " - f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " - "But please recompute them if model predictions don't look as expected." - ) - else: - raise RuntimeError( - f"Embeddings file {save_path} is invalid due to mismatch in {key}: " - f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." - ) + # A key absent from an older file should not invalidate it (it predates that key). + if key not in f.attrs or f.attrs[key] == val: + continue + # Different image data: surface as an error rather than silently overwriting it. + if key == "data_signature": + raise RuntimeError( + f"Embeddings file {save_path} is invalid due to mismatch in {key}: " + f"{f.attrs.get(key)} != {val}. Please recompute embeddings in a new file." + ) + # A version bump alone does not invalidate the embeddings. + if key == "micro_sam_version": + warnings.warn( + f"The signature for {key} in embeddings file {save_path} has a mismatch: " + f"{f.attrs.get(key)} != {val}. This key was recently added, so your embeddings are likely correct. " + "But please recompute them if model predictions don't look as expected." + ) + continue + # Model or tiling changed: the saved embeddings are stale and must be recomputed. + stale = True + return stale def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update): @@ -696,7 +706,9 @@ def precompute_image_embeddings( # check tha tthe saved embeedidng in there match the parameters of the function call.abs elif os.path.exists(save_path): f = zarr.open(save_path, mode="a") - _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo) + if _check_saved_embeddings(input_, predictor, f, save_path, tile_shape, halo): + # Stale embeddings (model or tiling changed): truncate and recompute, overwriting them. + f = zarr.open(save_path, mode="w") # We have a save path and it does not exist yet. Create the zarr file to which the # embeddings will then be saved.