diff --git a/TPTBox/core/bids_files.py b/TPTBox/core/bids_files.py index 4fdd4bf..f459b95 100755 --- a/TPTBox/core/bids_files.py +++ b/TPTBox/core/bids_files.py @@ -486,7 +486,7 @@ def enumerate_subjects(self, sort: bool = False, shuffle: bool = False) -> list[ return s return self.subjects.items() # type: ignore - def iter_subjects(self, sort: bool = False) -> list[tuple[str, Subject_Container]]: + def iter_subjects(self, sort: bool = False, shuffle: bool = False) -> list[tuple[str, Subject_Container]]: """Iterate over all subjects (alias for :meth:`enumerate_subjects` without shuffle). Args: @@ -498,6 +498,10 @@ def iter_subjects(self, sort: bool = False) -> list[tuple[str, Subject_Container """ if sort: return sorted(self.subjects.items()) + if shuffle: + s = list(self.subjects.items()) + random.shuffle(s) + return s return self.subjects.items() # type: ignore def __len__(self): diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 041f4d5..c5aa88c 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -1180,33 +1180,36 @@ def resample_from_to(self, to_vox_map:Image_Reference|Has_Grid|tuple[SHAPE,AFFIN if mapping.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0): log.print(f"resample_from_to skipped; already in space: {self}",verbose=verbose) return self if inplace else self.copy() - m1 = mapping if mapping.orientation == self.orientation else mapping.make_empty_POI().reorient(self.orientation) if m1.assert_affine(self,raise_error=False,origin_tolerance=0.00001,error_tolerance=0.00001,shape_tolerance=0): log.print(f"resample_from_to only need reorientation; {self.orientation}",verbose=verbose) ret = self.reorient(mapping.orientation,inplace=inplace) ret.affine = mapping.affine #remove floating point error return ret - if self.orientation == mapping.orientation and np.allclose(self.zoom , mapping.zoom, atol=1e-6): - shift = (np.array(self.origin) - np.array(m1.origin)) / np.array(m1.zoom) - if np.allclose(shift, np.round(shift), atol=1e-6): - s = self.reorient(mapping.orientation,inplace=inplace) # noqa: PLW0642 - shift = (np.array(self.origin) - np.array(mapping.origin)) / np.array(mapping.zoom) - shift = np.round(shift).astype(int) - dst_shape = np.array(mapping.shape) - src_shape = np.array(s.shape) - # padding before = how much dst starts before src - pad_before = shift - # padding after = remaining dst size after src - pad_after = dst_shape-shift-src_shape - pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after)) - ret = s.apply_pad(pad, mode=mode,inplace=inplace,verbose=verbose) - + if np.allclose(self.zoom, m1.zoom, atol=1e-6): + s = self.reorient(mapping.orientation, inplace=inplace) + # Compute voxel offset directly from the affines after both + # images are in the same orientation. This is robust to axis + # permutations and flips. + voxel_offset = np.linalg.inv(mapping.affine) @ s.affine @ np.array([0, 0, 0, 1]) + shift = np.round(voxel_offset[:3]).astype(int) + + dst_shape = np.array(mapping.shape) + src_shape = np.array(s.shape) + # padding before = how much dst starts before src + pad_before = shift + # padding after = remaining dst size after src + pad_after = dst_shape - shift - src_shape + pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after)) + try: + ret = s.apply_pad(pad,mode=mode,inplace=inplace,verbose=verbose) valid = ret.assert_affine(mapping,raise_error=False,origin_tolerance=0.0001,error_tolerance=0.0001,shape_tolerance=0) if valid: log.print(f"resample_from_to only needs padding/cropping {pad}",verbose=verbose) - ret.affine = mapping.affine #remove floating point error + ret.affine = mapping.affine # remove floating point error return ret + except ValueError as e: + log.warning("Padding failed.",e,verbose=verbose) assert mapping is not None @@ -2505,7 +2508,7 @@ def to_stl( try: verts, faces, normals, values = marching_cubes(seg_arr, gradient_direction="ascent", step_size=1) except RuntimeError as e: - raise RuntimeError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None + raise IndexError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None # Remove padding offset (since we padded by 1 voxel) verts -= 1 # Apply bounding box offset (still voxel space) @@ -2696,6 +2699,20 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum]|None, keep_la if keep_label: seg_arr = seg_arr * self.get_seg_array() return self.set_array(seg_arr,inplace=inplace) + def ravel(self,order:Literal["K", "A", "C", "F"] | None="C")->np.ndarray: + """Return a contiguous flattened array. + + A 1-D array, containing the elements of the input, is returned. A copy is made only if needed. + + As of NumPy 1.10, the returned array will have the same type as the input array. (for example, a masked array will be returned for a masked array input) + + Args: + order (Literal["K", "A", "C", "F"] | None, optional): The elements of a are read using this index order. ‘C’ means to index the elements in row-major, C-style order, with the last axis index changing fastest, back to the first axis index changing slowest. ‘F’ means to index the elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing slowest. Note that the ‘C’ and ‘F’ options take no account of the memory layout of the underlying array, and only refer to the order of axis indexing. ‘A’ means to read the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise. ‘K’ means to read the elements in the order they occur in memory, except for reversing the data when strides are negative. By default, ‘C’ index order is used. Defaults to "C". + + Returns: + np.ndarray + """ + return self.get_array().ravel(order=order) def extract_label_(self, label: int | Enum | Sequence[int] | Sequence[Enum], keep_label=False) -> Self: """In-place variant of `extract_label`.""" return self.extract_label(label,keep_label,inplace=True) diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 826a02a..2b3ebbd 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -160,7 +160,7 @@ def np_count_nonzero(arr: np.ndarray) -> int: return np.count_nonzero(arr) -def np_unique(arr: np.ndarray) -> list[int]: +def old_np_unique(arr: np.ndarray) -> list[int]: """Returns each existing label in the array (including zero!). Uses cc3d statistics for unsigned-integer arrays for speed, and falls back @@ -181,9 +181,52 @@ def np_unique(arr: np.ndarray) -> list[int]: return list(np.unique(arr)) +def np_unique(arr: np.ndarray) -> list[int]: + """Returns each existing label in the array (including zero!). + + Uses cc3d statistics for unsigned-integer arrays for speed, and falls back + to ``numpy.unique`` for other dtypes. + + Args: + arr (np.ndarray): Input label array. + + Returns: + list[int]: Sorted list of every distinct label value present in ``arr``, + including 0 (background). + """ + if np.issubdtype(arr.dtype, np.unsignedinteger): + # bincount is O(max_val) but ~5-10x faster than np.unique for dense label arrays + max_val = int(arr.max()) + if max_val < 2**20: # ~1M labels threshold — bincount stays fast + counts = np.bincount(arr.ravel()) + return list(np.where(counts > 0)[0]) + # For sparse label spaces fall back to np.unique + return old_np_unique(arr) + + def np_unique_withoutzero(arr: UINTARRAY) -> list[int]: """Returns each existing non-zero label in the array (excluding background zero). + Args: + arr (UINTARRAY): Input unsigned-integer label array. + + Returns: + list[int]: Sorted list of every distinct label value present in ``arr``, + excluding 0 (background). + """ + if np.issubdtype(arr.dtype, np.unsignedinteger): + max_val = int(arr.max()) + if max_val == 0: + return [] + if max_val < 2**20: + counts = np.bincount(arr.ravel()) + return list(np.where(counts[1:] > 0)[0] + 1) + return [i for i in np.unique(arr) if i != 0] + + +def old_np_unique_withoutzero(arr: UINTARRAY) -> list[int]: + """Returns each existing non-zero label in the array (excluding background zero). + Args: arr (UINTARRAY): Input unsigned-integer label array. diff --git a/TPTBox/registration/_deformable/multilabel_segmentation.py b/TPTBox/registration/_deformable/multilabel_segmentation.py index 33cadce..46d153c 100644 --- a/TPTBox/registration/_deformable/multilabel_segmentation.py +++ b/TPTBox/registration/_deformable/multilabel_segmentation.py @@ -54,6 +54,7 @@ def __init__( # noqa: C901 poi_target_cms: POI | None = None, max_history=100, change_after_point_reg=lambda x, y, z, w: (x, y, z, w), + tether_distance=1, **args, ): """Initialize a multi-stage registration pipeline from an atlas to a target image. @@ -90,7 +91,7 @@ def __init__( # noqa: C901 "be": ("BSplineBending", {"stride": 1}), "seg": "MSE", "Dice": "Dice", - "Tether": Tether_Seg(delta=5), + "Tether": Tether_Seg(delta=tether_distance), } assert target_seg.seg, target_seg.seg @@ -187,7 +188,7 @@ def __init__( # noqa: C901 poi_cms = poi_cms.resample_from_to(atlas_seg_) self.reg_point = Point_Registration(poi_target, poi_cms, verbose=False) - atlas_reg = self.reg_point.transform_nii(atlas_seg_) + atlas_reg = self.reg_point.transform_nii(atlas_seg_, c_val=0) if not atlas_reg.is_segmentation_in_border(): print("point registration ok") @@ -204,7 +205,7 @@ def __init__( # noqa: C901 target_img = target_img.apply_pad(resize_param) if target_img is not None else None self.reg_point = Point_Registration(poi_target.resample_from_to(target_seg), poi_cms.resample_from_to(atlas_seg)) - atlas_reg = self.reg_point.transform_nii(atlas_seg) + atlas_reg = self.reg_point.transform_nii(atlas_seg, c_val=0) atlas_img_reg = self.reg_point.transform_nii(atlas_img) if atlas_img is not None else None if crop: diff --git a/TPTBox/registration/_ridged_intensity/affine_deepali.py b/TPTBox/registration/_ridged_intensity/affine_deepali.py index 5f1777f..b35a3fb 100644 --- a/TPTBox/registration/_ridged_intensity/affine_deepali.py +++ b/TPTBox/registration/_ridged_intensity/affine_deepali.py @@ -92,14 +92,14 @@ def forward( target: torch.Tensor, # shape: (B, C, X, Y, Z) mask: torch.Tensor | None = None, # noqa: ARG002 ) -> torch.Tensor: - w = max(target.shape[2:]) + w = min(target.shape[2:]) com_fixed = center_of_mass_cc(target) # (B, C, 3) com_warped = center_of_mass_cc(source) # (B, C, 3) l_com = torch.norm(com_fixed - com_warped, dim=-1) / w # (B, C) # Zero out channels with small displacement (<10) or NaNs - l_com = torch.where(l_com < self.delta, torch.zeros_like(l_com), l_com) + l_com = torch.where(l_com * w < self.delta, torch.zeros_like(l_com), l_com) l_com = torch.nan_to_num(l_com, nan=0.0) return l_com.mean() # type: ignore diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index 26d6472..1ea4a46 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -16,6 +16,12 @@ out_base = Path(__file__).parent.parent / "nnUNet/" _model_path_ = out_base / "nnUNet_results" +# Opt-in cache of loaded predictors (enable via cache_model=True). Keyed by model identity plus +# the device/runtime settings that affect the loaded predictor, so repeated inference (e.g. a loop +# over many files with the same model) reuses the in-memory model instead of reloading weights from +# disk and re-uploading them to the GPU on every call. +_model_cache: dict = {} + def get_ds_info(idx: int, _model_path: str | Path | None = None, exit_one_fail: bool = True, logger=logger) -> dict: """Load and return the ``dataset.json`` for the model with the given dataset index. @@ -87,13 +93,16 @@ def run_inference_on_file( ddevice: Literal["cpu", "cuda", "mps"] = "cuda", model_path=None, step_size: float = 0.5, - memory_base: int = 5000, # Base memory in MB, default is 5GB - memory_factor: int = 160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB - memory_max: int = 160000, # in MB, default is 160GB + memory_base: float | None = None, # Base memory in MB, default is 5GB + memory_factor: float | None = None, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB + memory_max: int = 990000, # in MB, default is 990GB (so it is most likely ignored and replaced by Max Memory of the GPU) wait_till_gpu_percent_is_free: float = 0.1, + tile_batch_size: int = 1, verbose: bool = True, auto_download: bool = False, + cache_model: bool = False, _key_ResEnc: str = "__nnUNet*ResEnc", + fail_on_missing_memory=False, logger=logger, ) -> tuple[Image_Reference, np.ndarray | None]: """Load a VibeSeg model and run inference on the supplied NIfTI images. @@ -135,7 +144,18 @@ def run_inference_on_file( memory_max: Hard cap on assumed GPU memory in MB. wait_till_gpu_percent_is_free: Minimum free GPU fraction to require before starting inference. + tile_batch_size: Number of sliding-window tiles to run per network + forward pass. ``1`` (default) keeps the original per-tile behaviour; + larger values batch tiles to better saturate the GPU at the cost of + higher peak memory. verbose: Print progress information. + cache_model: If ``True``, keep the loaded predictor in a process-wide + cache and reuse it on subsequent calls with identical model and + device/runtime settings. Avoids reloading weights from disk and + re-uploading them to the GPU when segmenting many files in a loop, at + the cost of holding the model in GPU memory between calls. The GPU + cache is also left warm (no ``empty_cache``) so the allocator can + reuse buffers across images. Returns: A tuple ``(seg_nii, softmax_logits)`` where ``seg_nii`` is the @@ -196,20 +216,44 @@ def run_inference_on_file( if "labels" in ds_info2: ds_info["labels_mapping"] = ds_info2["labels"] - nnunet = load_inf_model( - nnunet_path, - allow_non_final=True, - use_folds=tuple(folds) if len(folds) != 5 else None, - gpu=gpu, - ddevice=ddevice, - step_size=step_size, - memory_base=memory_base, - memory_factor=memory_factor, - memory_max=memory_max, - wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + if memory_base is None: + memory_base = float(ds_info.get("memory_base", 5000)) + if memory_factor is None: + memory_factor = float(ds_info.get("memory_factor", 160)) + + use_folds_arg = tuple(folds) if len(folds) != 5 else None + # Include every setting that changes the loaded predictor so a cache hit is always equivalent + # to a fresh load; differing settings simply miss the cache and reload. + cache_key = ( + str(nnunet_path), + use_folds_arg, + ddevice, + gpu, + step_size, + memory_base, + memory_factor, + memory_max, + wait_till_gpu_percent_is_free, + tile_batch_size, ) - - # _unets[idx] = nnunet + nnunet = _model_cache.get(cache_key) if cache_model else None + if nnunet is None: + nnunet = load_inf_model( + nnunet_path, + allow_non_final=True, + use_folds=use_folds_arg, + gpu=gpu, + ddevice=ddevice, + step_size=step_size, + memory_base=memory_base, + memory_factor=memory_factor, + memory_max=memory_max, + wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + tile_batch_size=tile_batch_size, + fail_on_missing_memory=fail_on_missing_memory, + ) + if cache_model: + _model_cache[cache_key] = nnunet if "orientation" in ds_info: orientation = ds_info["orientation"] @@ -315,9 +359,11 @@ def to_int(a: str, k: None | int = None): seg_nii.map_labels_(mapping) if out_file is not None and (not Path(out_file).exists() or override): seg_nii.set_dtype("smallest_uint").save(out_file) - del nnunet - - torch.cuda.empty_cache() + if not cache_model: + # When caching we keep the predictor alive (it stays referenced by _model_cache, so del + # would not free it anyway) and leave the CUDA cache warm so the next image reuses buffers. + del nnunet + torch.cuda.empty_cache() return seg_nii, softmax_logits diff --git a/TPTBox/segmentation/VibeSeg/vibeseg.py b/TPTBox/segmentation/VibeSeg/vibeseg.py index 087702f..04d4615 100644 --- a/TPTBox/segmentation/VibeSeg/vibeseg.py +++ b/TPTBox/segmentation/VibeSeg/vibeseg.py @@ -84,6 +84,10 @@ 72: "bone_other", } +defaults = { + 100: {"memory_base": 5500, "memory_factor": 25}, +} + def run_vibeseg( i: Image_Reference, @@ -113,6 +117,10 @@ def run_vibeseg( Returns: Segmentation ``NII`` saved at *out_seg*. """ + if dataset_id in defaults: + for k, v in defaults[dataset_id].items(): + if k not in args: + args[k] = v return run_inference_on_file( dataset_id, [to_nii(i)] if not isinstance(i, (list, tuple)) else [to_nii(j) for j in i], diff --git a/TPTBox/segmentation/nnUnet_utils/export_prediction.py b/TPTBox/segmentation/nnUnet_utils/export_prediction.py index 36ca786..712effe 100755 --- a/TPTBox/segmentation/nnUnet_utils/export_prediction.py +++ b/TPTBox/segmentation/nnUnet_utils/export_prediction.py @@ -6,11 +6,147 @@ import numpy as np import torch from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice -from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_pickle +from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle from nnunetv2.utilities.label_handling.label_handling import LabelManager +from tqdm import tqdm from TPTBox.segmentation.nnUnet_utils.plans_handler import ConfigurationManager, PlansManager +SAFETY_FACTOR = 0.5 # only use 50% of VRAM + + +def _argmax_with_gpu_fallback(predicted_logits: torch.Tensor | np.ndarray, device: torch.device, chunk_size: int = 64) -> np.ndarray: + """Computes argmax(0). + + Tiered argmax: + 1. Full array on GPU + 2. Chunked on GPU (if full array doesn't fit) + 3. Chunked on CPU (if even a single chunk doesn't fit) + """ + from TPTBox.segmentation.nnUnet_utils.predictor import empty_cache + + empty_cache(device) + + def _get_free_vram(device: torch.device) -> int: + try: + """Returns free VRAM in bytes.""" + free, _ = torch.cuda.mem_get_info(device) + return int(free * SAFETY_FACTOR) + except Exception: + return 0 + + def _array_bytes(shape: tuple, dtype: torch.dtype = torch.float16) -> int: + n_elements = 1 + for s in shape: + n_elements *= s + return n_elements * torch.finfo(dtype).bits // 8 + + def _to_cpu_tensor(arr: torch.Tensor | np.ndarray) -> torch.Tensor: + if isinstance(arr, np.ndarray): + return torch.from_numpy(arr) + return arr.cpu() + + def _chunked_argmax_gpu(t: torch.Tensor, device: torch.device) -> np.ndarray: + X = t.shape[1] + out = np.empty(t.shape[1:], dtype=np.int16) + for x in tqdm(range(0, X, chunk_size), "argmax gpu"): + chunk = t[:, x : x + chunk_size].to(device) + out[x : x + chunk_size] = torch.argmax(chunk, dim=0).cpu().numpy() + del chunk + empty_cache(device) + return out + + def _chunked_argmax_cpu(t: torch.Tensor | np.ndarray) -> np.ndarray: + arr = t.numpy() if isinstance(t, torch.Tensor) else t + if not arr.flags["C_CONTIGUOUS"]: + arr = np.ascontiguousarray(arr) + X = arr.shape[1] + out = np.empty(arr.shape[1:], dtype=np.int16) + for x in tqdm(range(0, X, chunk_size), "argmax cpu"): + out[x : x + chunk_size] = arr[:, x : x + chunk_size].argmax(0) + return out + + t = _to_cpu_tensor(predicted_logits) + + if device is None or not torch.cuda.is_available(): + return _chunked_argmax_cpu(t) + + full_bytes = _array_bytes(t.shape) + free_vram = _get_free_vram(device) + + print(f"[argmax] array: {full_bytes / 1e6:.1f} MB, VRAM: {free_vram / 1e6:.1f} MB") + + # Tier 1: full GPU + if full_bytes <= free_vram or device.type == "mps": + try: + return torch.argmax(t.to(device), dim=0).cpu().numpy().astype(np.int16) + except torch.cuda.OutOfMemoryError: + print("[argmax] full GPU OOM despite estimate, trying chunked GPU") + empty_cache(device) + except Exception as e: + print(e) + empty_cache(device) + + for i in range(10): + chunk_shape = (t.shape[0], min(max(int(chunk_size / 2**i), 1), t.shape[1]), *t.shape[2:]) + chunk_bytes = _array_bytes(chunk_shape) + if chunk_bytes <= free_vram: + chunk_size = max(int(chunk_size / 2**i), 1) + break + print(f"[argmax] array chunk: {chunk_bytes / 1e6:.1f} MB, VRAM: {free_vram / 1e6:.1f} MB, {chunk_size=}") + + # Tier 2: chunked GPU + if chunk_bytes <= free_vram: + print("[argmax] using chunked GPU") + try: + return _chunked_argmax_gpu(t, device) + except torch.cuda.OutOfMemoryError: + print("[argmax] chunked GPU OOM despite estimate, falling back to CPU") + empty_cache(device) + else: + print("[argmax] chunk too large for VRAM, falling back to CPU") + + # Tier 3: chunked CPU + return _chunked_argmax_cpu(t) + + +@torch.inference_mode() +def convert_probabilities_to_segmentation(self, predicted_probabilities: np.ndarray | torch.Tensor, device, chunk_size=64) -> np.ndarray: + """Assumes that inference_nonlinearity was already applied! + + predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions + """ + if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)): + raise RuntimeError(f"Unexpected input type. Expected np.ndarray or torch.Tensor, got {type(predicted_probabilities)}") # noqa: TRY004 + + if self.has_regions: + assert self.regions_class_order is not None, "if region-based training is requested then you need to define regions_class_order!" + # check correct number of outputs + assert predicted_probabilities.shape[0] == self.num_segmentation_heads, ( + f"unexpected number of channels in predicted_probabilities. Expected {self.num_segmentation_heads}, " + f"got {predicted_probabilities.shape[0]}. Remember that predicted_probabilities should have shape " + f"(c, x, y(, z))." + ) + if self.has_regions: + if isinstance(predicted_probabilities, np.ndarray): + segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.uint16) + else: + # no uint16 in torch + segmentation = torch.zeros( + predicted_probabilities.shape[1:], + dtype=torch.int16, + device=predicted_probabilities.device, + ) + for i, c in enumerate(self.regions_class_order): + segmentation[predicted_probabilities[i] > 0.5] = c + if isinstance(segmentation, torch.Tensor): + segmentation = segmentation.cpu().numpy() + else: + # Issensee is no longer right when saying "numpy is faster than torch" newer torch versions no longer have this issue, on GPU we even get a 20x improvment. :facepalm: + segmentation = _argmax_with_gpu_fallback(predicted_probabilities, device, chunk_size=chunk_size) + + return segmentation + def convert_predicted_logits_to_segmentation_with_correct_shape( predicted_logits: torch.Tensor | np.ndarray, @@ -20,6 +156,7 @@ def convert_predicted_logits_to_segmentation_with_correct_shape( properties_dict: dict, return_probabilities: bool = False, num_threads_torch: int = 8, + device=None, ) -> np.ndarray: """Revert all preprocessing steps and return a segmentation in the original image space. @@ -62,19 +199,12 @@ def convert_predicted_logits_to_segmentation_with_correct_shape( predicted_logits = configuration_manager.resampling_fn_probabilities( predicted_logits, properties_dict["shape_after_cropping_and_before_resampling"], current_spacing, properties_dict["spacing"] ) - # return value of resampling_fn_probabilities can be ndarray or Tensor but that doesnt matter because - # apply_inference_nonlin will covnert to torch - # And this is stupid because convert_probabilities_to_segmentation transforms it back to a numpy... if label_manager.has_regions or return_probabilities: # Softmax does not change when we use argmax in the next step predicted_logits = label_manager.apply_inference_nonlin(predicted_logits) - # segmentation may be torch.Tensor but we continue with numpy - if isinstance(predicted_logits, torch.Tensor): - predicted_logits = predicted_logits.cpu().numpy() - - segmentation: np.ndarray = label_manager.convert_probabilities_to_segmentation(np.ascontiguousarray(predicted_logits)) # type: ignore + # segmentation: np.ndarray = label_manager.convert_probabilities_to_segmentation(predicted_logits) # type: ignore + segmentation: np.ndarray = convert_probabilities_to_segmentation(label_manager, predicted_logits, device) segmentation = segmentation.astype(np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16) - # if not return_probabilities: del predicted_logits # put segmentation in bbox (revert cropping) segmentation_reverted_cropping = np.zeros( diff --git a/TPTBox/segmentation/nnUnet_utils/inference_api.py b/TPTBox/segmentation/nnUnet_utils/inference_api.py index 4913952..9ddc538 100755 --- a/TPTBox/segmentation/nnUnet_utils/inference_api.py +++ b/TPTBox/segmentation/nnUnet_utils/inference_api.py @@ -2,13 +2,10 @@ from pathlib import Path -import nibabel as nib import numpy as np -import SimpleITK as sitk # noqa: N813 import torch from TPTBox import NII, Log_Type, No_Logger -from TPTBox.core import sitk_utils from .predictor import nnUNetPredictor @@ -31,11 +28,14 @@ def load_inf_model( inference_augmentation: bool = False, use_gaussian: bool = True, verbose: bool = False, + fast_perf: bool = True, gpu: int | None = None, - memory_base: int = 5000, - memory_factor: int = 160, - memory_max: int = 160000, + memory_base: float = 5000, + memory_factor: float = 160, + memory_max: float = 160000, wait_till_gpu_percent_is_free: float = 0.3, + fail_on_missing_memory=False, + tile_batch_size: int = 1, ) -> nnUNetPredictor: """Load and initialise an nnU-Net model predictor from a trained model folder. @@ -55,6 +55,11 @@ def load_inf_model( inference_augmentation: If True, enable test-time mirroring augmentation. use_gaussian: If True, apply Gaussian weighting in the sliding window. verbose: If True, print progress information during model initialisation. + fast_perf: If True (and running on CUDA), enable cuDNN autotuning and TF32 + matmul/conv. Every sliding-window tile has the same ``patch_size`` + shape, so cuDNN can pick the fastest convolution algorithms once and + reuse them. TF32 speeds up fp32 ops on Ampere+ GPUs with negligible + accuracy impact. These are global ``torch.backends`` flags. gpu: GPU device index forwarded to the predictor. ``None`` defaults to 0. memory_base: Base GPU memory reservation in MB (default 5 000 MB = 5 GB). memory_factor: Per-voxel memory scaling factor. The formula is @@ -63,6 +68,9 @@ def load_inf_model( memory_max: Maximum GPU memory cap in MB (default 160 000 MB = 160 GB). wait_till_gpu_percent_is_free: Fraction of GPU memory that must be free before inference is started. + tile_batch_size: Number of sliding-window tiles per network forward pass. + ``1`` reproduces the original per-tile path; larger values batch + tiles to improve GPU utilisation at higher peak memory. Returns: Initialised ``nnUNetPredictor`` ready for inference. @@ -84,6 +92,16 @@ def load_inf_model( _interop = True except Exception as e: print(e) + if fast_perf: + # All sliding-window tiles share the same (patch_size) shape, so cuDNN can + # autotune the fastest conv algorithms once and reuse them across tiles/images. + # TF32 accelerates fp32 matmul/conv on Ampere+ with negligible accuracy impact. + try: + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + except Exception as e: + print(e) device = torch.device("cuda") else: device = torch.device("mps") @@ -103,6 +121,8 @@ def load_inf_model( memory_factor=memory_factor, memory_max=memory_max, wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + fail_on_missing_memory=fail_on_missing_memory, + tile_batch_size=tile_batch_size, ) check_name = "checkpoint_final.pth" # if not allow_non_final else "checkpoint_best.pth" try: diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index eb654e9..ff9e6e2 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -70,6 +70,10 @@ class nnUNetPredictor: memory_max: Clamp on maximum GPU memory (MB) to assume available. wait_till_gpu_percent_is_free: Fraction of GPU memory that must be free before inference starts. Waits up to 40 minutes. + tile_batch_size: Number of sliding-window tiles to run per network + forward pass. ``1`` (default) reproduces the original per-tile loop + exactly; larger values batch tiles together to improve GPU + utilisation at the cost of higher peak memory. """ def __init__( @@ -83,10 +87,12 @@ def __init__( verbose: bool = False, verbose_preprocessing: bool = False, allow_tqdm: bool = True, - memory_base=5000, # Base memory in MB, default is 5GB - memory_factor=160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB - memory_max: int = 160000, # in MB, default is 160GB + memory_base: float = 5000, # Base memory in MB, default is 5GB + memory_factor: float = 160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB + memory_max: float = 160000, # in MB, default is 160GB + fail_on_missing_memory=False, wait_till_gpu_percent_is_free=0.3, + tile_batch_size: int = 1, ): self.verbose = verbose self.verbose_preprocessing = verbose_preprocessing @@ -108,16 +114,22 @@ def __init__( self.use_mirroring = use_mirroring if device.type == "cuda": device = torch.device(type="cuda", index=cuda_id) # set the desired GPU with CUDA_VISIBLE_DEVICES! - self.do_not_use_half_precision = device.type in ["cpu", "mps"] # float16 not supported by cpu + # float16 not supported by cpu + self.do_not_use_half_precision = device.type in ["cpu", "mps"] if device.type != "cuda" and perform_everything_on_gpu: print("perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False") perform_everything_on_gpu = False self.device = device self.perform_everything_on_gpu = perform_everything_on_gpu + self.fail_on_missing_memory = fail_on_missing_memory self.memory_base = memory_base self.memory_factor = memory_factor self.memory_max = memory_max self.wait_till_gpu_percent_is_free = wait_till_gpu_percent_is_free + # Number of sliding-window tiles to push through the network in one forward pass. All tiles + # share the same patch_size, so they batch densely. 1 reproduces the original per-tile path + # exactly; larger values raise GPU utilisation (and peak memory) for small patches. + self.tile_batch_size = tile_batch_size def initialize_from_trained_model_folder( self, @@ -152,7 +164,11 @@ def initialize_from_trained_model_folder( ## LOAD NNUNET 1 models from nnunet.training.model_restore import restore_model - pkl_file1 = join(model_training_output_dir, f"fold_{use_folds[0]}", "model_final_checkpoint.model.pkl") # type: ignore + pkl_file1 = join( + model_training_output_dir, + f"fold_{use_folds[0]}", + "model_final_checkpoint.model.pkl", + ) # type: ignore trainer = restore_model(pkl_file1, fp16=True) trainer.output_folder = model_training_output_dir trainer.output_folder_base = model_training_output_dir @@ -281,26 +297,37 @@ def mapp(d: dict): print("compiling network") self.network = torch.compile(self.network) # type: ignore - self.loaded_networks = [] - if cache_state_dicts: - for params in self.list_of_parameters: - if not isinstance(self.network, OptimizedModule): - self.network.load_state_dict(params) # type: ignore - else: - self.network._orig_mod.load_state_dict(params) - if self.device.type == "cuda" and not torch.cuda.is_available(): - Print_Logger().on_warning( - "No CUDA device. If you have a CUDA-able GPU (Nvidia), reinstall pytorch with cuda or for non-cuda devices use ddevice=cpu or ddevice=mps" - ) - if self.device.type == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()): - Print_Logger().on_warning("No MPS device found. Use ddevice=cpu or ddevice=mps") - self.network.to(self.device) - self.network.eval() # type: ignore - self.loaded_networks.append(self.network) - # print(type(self.loaded_networks[0])) + # Warn early if the requested device is unavailable (runs once, independent of folds). + if self.device.type == "cuda" and not torch.cuda.is_available(): + Print_Logger().on_warning( + "No CUDA device. If you have a CUDA-able GPU (Nvidia), reinstall pytorch with cuda or for non-cuda devices use ddevice=cpu or ddevice=mps" + ) + if self.device.type == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()): + Print_Logger().on_warning("No MPS device found. Use ddevice=cpu or ddevice=mps") + + # loaded_networks holds one ready-to-run network per fold (or None to load weights + # lazily per fold). We only cache the single-fold case: previously this loop appended + # the SAME self.network object once per fold, so every entry ended up holding the LAST + # fold's weights. That silently collapsed an N-fold ensemble to a single fold while + # still paying Nx the compute. For >1 fold we keep loaded_networks=None and let + # predict_logits_from_preprocessed_data swap weights per fold via load_state_dict (a + # true ensemble that needs only one network's worth of GPU memory). + self.loaded_networks = None + if cache_state_dicts and len(self.list_of_parameters) == 1: + params = self.list_of_parameters[0] + if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(params) # type: ignore + else: + self.network._orig_mod.load_state_dict(params) + self.network.to(self.device) + self.network.eval() # type: ignore + self.loaded_networks = [self.network] def predict_single_npy_array( - self, input_image: np.ndarray, image_properties: dict, save_or_return_probabilities: bool = False + self, + input_image: np.ndarray, + image_properties: dict, + save_or_return_probabilities: bool = False, ) -> np.ndarray: """Run full inference on a single numpy image array. @@ -337,7 +364,10 @@ def predict_single_npy_array( if self.verbose: print("predicting") predicted_logits = self.predict_logits_from_preprocessed_data(dct["data"]) # type: ignore - print("convert_predicted_logits_to_segmentation_with_correct_shape", predicted_logits.shape) + print( + "convert_predicted_logits_to_segmentation_with_correct_shape", + predicted_logits.shape, + ) import time t = time.time() @@ -348,6 +378,7 @@ def predict_single_npy_array( self.label_manager, dct["data_properites"], return_probabilities=save_or_return_probabilities, + device=self.device, ) print("convert_predicted_logits_to_segmentation_with_correct_shape; Took", time.time() - t, " seconds") @@ -382,7 +413,7 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: in # things a lot faster for some datasets. original_perform_everything_on_gpu = self.perform_everything_on_gpu assert self.list_of_parameters is not None - with torch.no_grad(): + with torch.inference_mode(): prediction = None try: for idx, params in enumerate(self.list_of_parameters): @@ -416,7 +447,7 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: in prediction = None self.perform_everything_on_gpu = False empty_cache(self.device) - if attempts == 0: + if attempts == 0 or self.fail_on_missing_memory: raise return self.predict_logits_from_preprocessed_data(data, attempts=attempts - 1) @@ -554,7 +585,6 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ network = network.to(self.device) # type: ignore assert self.configuration_manager is not None assert self.label_manager is not None - empty_cache(self.device) # Autocast is a little bitch. # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection) @@ -563,7 +593,7 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ # is set. Why. (this is why we don't make use of enabled=False) # So autocast will only be active if we have a cuda device. with ( - torch.no_grad(), + torch.inference_mode(), torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context(), ): assert len(input_image.shape) == 4, "input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)" @@ -603,12 +633,12 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ time.sleep(1) def check_mem(shape): - memory = get_gpu_memory_MB(device) + memory = get_gpu_memory_MB(device) * 0.80 max_memory = self.memory_max min_memory = self.memory_base factor = self.memory_factor # print(shape, "usage", np.prod(shape) / 1000000 * factor, max(min(memory, max_memory), min_memory)) - return (np.prod(shape) / 1000000 * factor) + min_memory // 2 < max(min(memory, max_memory), min_memory) + return (np.prod(shape) / 1000000 * factor) + min_memory < max(min(memory, max_memory), min_memory) with tqdm(total=len(slicers), disable=not self.allow_tqdm) as pbar: if not check_mem(shape) or "nnUNetPlans_2d" not in self.configuration_manager.configuration.get("data_identifier", "3D"): @@ -642,13 +672,13 @@ def check_mem(shape): break splits[j] += 1 + predicted_logits, n_predictions = self._run_sub(data, network, device, slicers, pbar) pbar.desc = "finish" pbar.update(0) predicted_logits /= n_predictions del n_predictions predicted_logits = predicted_logits.cpu() - empty_cache(self.device) return predicted_logits[(slice(None), *slicer_revert_padding[1:])] def _run_prediction_splits( @@ -700,7 +730,7 @@ def _run_prediction_splits( predicted_logits /= n_predictions del n_predictions - empty_cache(self.device) + # empty_cache(self.device) return predicted_logits def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss: bool = True): @@ -749,22 +779,29 @@ def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss: bool return predicted_logits, n_predictions, gaussian, results_device def _run_sub(self, data: torch.Tensor, network, results_device, slicers, pbar: tqdm, addendum: str = ""): - """Iterate over slicers, run inference per tile, and accumulate results.""" + """Iterate over slicers, run inference per tile (optionally batched), and accumulate results.""" + slicers = list(slicers) try: data = data.to(self.device) # type: ignore predicted_logits, n_predictions, gaussian, results_device = self._allocate(data, results_device, pbar) pbar.desc = f"running prediction {addendum}" prediction = None work_on = None - for sl in slicers: - pbar.update(1) - work_on = data[sl][None] + batch_size = max(1, self.tile_batch_size) + for batch_start in range(0, len(slicers), batch_size): + batch_slicers = slicers[batch_start : batch_start + batch_size] + # batch_size == 1 keeps the original view (no copy); larger batches stack tiles into a + # dense (B, C, *patch) tensor (valid because all tiles share the same patch_size). + work_on = data[batch_slicers[0]][None] if batch_size == 1 else torch.stack([data[sl] for sl in batch_slicers], dim=0) work_on = work_on.to(self.device, non_blocking=False) - prediction = self._internal_maybe_mirror_and_predict(work_on, network=network)[0].to(results_device) - if prediction.shape[0] != predicted_logits.shape[0]: - prediction.squeeze_(0) - predicted_logits[sl] += prediction * gaussian if self.use_gaussian else prediction - n_predictions[sl[1:]] += gaussian if self.use_gaussian else 1 + prediction = self._internal_maybe_mirror_and_predict(work_on, network=network).to(results_device) + for b, sl in enumerate(batch_slicers): + pbar.update(1) + pred = prediction[b] + if pred.shape[0] != predicted_logits.shape[0]: + pred = pred.squeeze(0) + predicted_logits[sl] += pred * gaussian if self.use_gaussian else pred + n_predictions[sl[1:]] += gaussian if self.use_gaussian else 1 return predicted_logits, n_predictions # noqa: TRY300 except RuntimeError: del predicted_logits diff --git a/TPTBox/tests/speedtests/speedtest_npunique.py b/TPTBox/tests/speedtests/speedtest_npunique.py index 484606d..afdc91f 100644 --- a/TPTBox/tests/speedtests/speedtest_npunique.py +++ b/TPTBox/tests/speedtests/speedtest_npunique.py @@ -18,13 +18,14 @@ np_unique, np_unique_withoutzero, np_volume, + old_np_unique, ) from TPTBox.tests.speedtests.speedtest import speed_test from TPTBox.tests.test_utils import get_nii def get_nii_array(): num_points = random.randint(1, 30) - nii, points, orientation, sizes = get_nii(x=(140, 140, 150), num_point=num_points) + nii, points, orientation, sizes = get_nii(x=(400, 400, 400), num_point=num_points) # nii.map_labels_({1: -1}, verbose=False) arr = nii.get_seg_array().astype(np.uint8) # arr[arr == 1] = -1 @@ -34,7 +35,7 @@ def get_nii_array(): speed_test( repeats=50, get_input_func=get_nii_array, - functions=[np_unique, np.unique, np_is_empty, np.max], + functions=[np_unique, old_np_unique, np.unique, np_is_empty, np.max], assert_equal_function=lambda x, y: True, # np.all([x[i] == y[i] for i in range(len(x))]), # noqa: ARG005 # np.all([x[i] == y[i] for i in range(len(x))]) ) diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore new file mode 100644 index 0000000..8321588 --- /dev/null +++ b/benchmarks/.gitignore @@ -0,0 +1,2 @@ +results/ +.bench_cache/ diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..72ccc9b --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,77 @@ +# nnU-Net inference timing harness + +Tools to measure the speed impact of the inference optimisations in +`TPTBox/segmentation/nnUnet_utils/` and `TPTBox/segmentation/VibeSeg/inference_nnunet.py`. + +## Files + +- `benchmark_nnunet_inference.py` — the measurement tool. Splits + `run_inference_on_file` into phases (**load → preprocess → sliding-window + predict → postprocess → other**) with CUDA synchronisation, warmup, repeats + and peak-memory tracking. It instruments the pipeline by monkeypatching (no + library changes) and drops any CLI flag a given commit does not yet support, + so the *same file* runs against every commit. +- `bench_across_commits.sh` — checks out the baseline + each optimisation commit, + runs one fixed config, and prints per-commit deltas. +- `bench_flag_sweep.sh` — stays on the current commit and sweeps the opt-in flags. + +## Requirements + +- An **editable** TPTBox install (`poetry install`) so that checking out a commit + changes the imported code. +- Model weights for the chosen `--dataset-id` (default `100`). The first run may + download them; that happens during warmup and is excluded from the numbers. +- A GPU for `--device cuda` (the default). `--synthetic` removes the need for a + real input image — only the model is required. + +## Quick start + +```bash +# 1) per-commit deltas (always-on changes), synthetic input +benchmarks/bench_across_commits.sh --device cuda --repeats 5 + +# 2) opt-in flag effects at HEAD +benchmarks/bench_flag_sweep.sh --device cuda --repeats 5 + +# 3) one config by hand (e.g. on your real data) +python benchmarks/benchmark_nnunet_inference.py run \ + --dataset-id 100 --input water.nii.gz fat.nii.gz \ + --tile-batch-size 4 --repeats 5 --json /tmp/tb4.json +``` + +## Which tool measures which commit + +| Commit | Change | How it shows up | +|---|---|---| +| `cuDNN/TF32` | autotune + TF32 | `bench_across_commits.sh`: `predict` drops at this commit (after warmup) | +| `inference_mode` | no_grad → inference_mode | `bench_across_commits.sh`: small `predict`/`peak_mem` drop | +| `empty_cache` | drop per-fold cache clears | `bench_across_commits.sh`: `predict` drop, larger with more folds | +| `fold fix` | repair `loaded_networks` | `fold_status` column flips from `DUPLICATED(...)` to `lazy-per-fold`/`distinct`. Time is ~unchanged (correctness fix) — the old code already paid for N passes. | +| `cache_model` | persistent model cache | `bench_flag_sweep.sh` → `sweep_cache_model.json`: `load: first` is large, `load: steady-median` ≈ 0 | +| `tile_batch_size` | batch tiles per forward | `bench_flag_sweep.sh`: `tile_batch_4/8` reduce `forward_calls` and `predict`, raise `peak_mem` | + +Biggest raw-speed levers (quality trade-off): `--max-folds 1` (≈ folds×) and +`--step-size 0.7` (fewer tiles). + +## Reading the output + +- **`total: first` vs `steady-median`** — the first call pays lazy CUDA init, + cuDNN autotuning and (without caching) model load; steady-median is the + representative per-image cost. +- **`forward_calls`** — number of network forward passes = `folds × tiles` (÷ batch). + Drops with `--tile-batch-size`, scales with folds. +- **`fold_status`** — `DUPLICATED(...)` means the ensemble was silently collapsed + to one fold (pre-fix); `lazy-per-fold` (multi-fold) / `distinct` is correct. +- **`peak_mem_mb`** — `torch.cuda.max_memory_allocated`; watch this rise with + `--tile-batch-size`. + +## Notes / caveats + +- Synthetic input is random noise with the model's channel count and a chosen + shape/spacing — fine for *timing* (the network cost is content-independent), not + for assessing segmentation quality. It is cached in a temp dir and reused across + commits so the workload is identical. +- TF32 and the fold fix change the numeric output between commits, so outputs are + not bit-identical across the range — expected, not a harness bug. +- `bench_across_commits.sh` restores your original branch on exit (even on error). +- Results and the synthetic cache (`results/`, `.bench_cache/`) are git-ignored. diff --git a/benchmarks/bench_across_commits.sh b/benchmarks/bench_across_commits.sh new file mode 100755 index 0000000..18e6f2f --- /dev/null +++ b/benchmarks/bench_across_commits.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +# Replay the benchmark across the optimisation commits and tabulate per-commit deltas. +# +# Isolates the ALWAYS-ON changes (cuDNN/TF32, inference_mode, empty_cache, fold fix): each commit +# is checked out, run with one fixed config, and compared. The opt-in commits (cache_model, +# tile_batch_size) are no-ops under the default config and show ~0 here -- use bench_flag_sweep.sh +# for those. +# +# Usage: +# benchmarks/bench_across_commits.sh [extra args forwarded to `benchmark ... run`] +# BASELINE= benchmarks/bench_across_commits.sh --device cuda --repeats 5 +# With no extra args it uses a synthetic input. Requires a clean working tree and an editable +# TPTBox install (so checking out a commit changes the imported code). +set -euo pipefail + +REPO_ROOT="$(git rev-parse --show-toplevel)" +cd "$REPO_ROOT" + +if [ -n "$(git status --porcelain)" ]; then + echo "error: working tree is not clean. Commit or stash before running." >&2 + exit 1 +fi + +BASELINE="${BASELINE:-3906165}" +ORIG_REF="$(git symbolic-ref --short -q HEAD || git rev-parse HEAD)" +TMP_BENCH="$(mktemp -t bench_nnunet_XXXX.py)" +cp "$REPO_ROOT/benchmarks/benchmark_nnunet_inference.py" "$TMP_BENCH" +RESULTS="$REPO_ROOT/benchmarks/results" +mkdir -p "$RESULTS" + +cleanup() { git checkout -q "$ORIG_REF"; rm -f "$TMP_BENCH"; } +trap cleanup EXIT + +# baseline + every commit since it, in chronological order +mapfile -t COMMITS < <(printf '%s\n' "$BASELINE"; git rev-list --reverse "${BASELINE}..${ORIG_REF}") + +if [ $# -eq 0 ]; then RUN_ARGS=(--synthetic); else RUN_ARGS=("$@"); fi + +JSONS=() +i=0 +for sha in "${COMMITS[@]}"; do + short="$(git rev-parse --short "$sha")" + out="$RESULTS/$(printf '%02d' "$i")_${short}.json" + echo "================================================================" + echo "[$i] $short : $(git log -1 --format=%s "$sha")" + echo "================================================================" + git checkout -q "$sha" + python "$TMP_BENCH" run "${RUN_ARGS[@]}" --json "$out" + JSONS+=("$out") + i=$((i + 1)) +done + +git checkout -q "$ORIG_REF" +echo +echo "################ COMPARISON ################" +python "$TMP_BENCH" compare "${JSONS[@]}" diff --git a/benchmarks/bench_flag_sweep.sh b/benchmarks/bench_flag_sweep.sh new file mode 100755 index 0000000..4418e58 --- /dev/null +++ b/benchmarks/bench_flag_sweep.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Measure the OPT-IN optimisations at the current commit by sweeping their flags in one process. +# (cache_model needs >1 call to show amortisation, so it cannot be measured by the cross-commit +# driver -- that is what this script is for.) +# +# Usage: +# benchmarks/bench_flag_sweep.sh [extra args forwarded to `benchmark ... run`] +# benchmarks/bench_flag_sweep.sh --input water.nii.gz fat.nii.gz --device cuda +# With no extra args it uses a synthetic input. +set -euo pipefail + +REPO_ROOT="$(git rev-parse --show-toplevel)" +BENCH="$REPO_ROOT/benchmarks/benchmark_nnunet_inference.py" +RESULTS="$REPO_ROOT/benchmarks/results" +mkdir -p "$RESULTS" + +if [ $# -eq 0 ]; then COMMON=(--synthetic); else COMMON=("$@"); fi + +run() { # name : extra flags... + local name="$1"; shift + echo "================ $name ================" + python "$BENCH" run "${COMMON[@]}" "$@" --json "$RESULTS/sweep_${name}.json" +} + +run baseline +run tile_batch_4 --tile-batch-size 4 +run tile_batch_8 --tile-batch-size 8 +run cache_model --cache-model --repeats 6 +run max_folds_1 --max-folds 1 +run step_0p7 --step-size 0.7 + +echo +echo "################ COMPARISON ################" +python "$BENCH" compare \ + "$RESULTS/sweep_baseline.json" \ + "$RESULTS/sweep_tile_batch_4.json" \ + "$RESULTS/sweep_tile_batch_8.json" \ + "$RESULTS/sweep_cache_model.json" \ + "$RESULTS/sweep_max_folds_1.json" \ + "$RESULTS/sweep_step_0p7.json" +echo +echo "cache_model: compare 'load: first' vs 'load: steady-median' in sweep_cache_model.json --" +echo "the median load should collapse to ~0 once the model is cached." diff --git a/benchmarks/benchmark_nnunet_inference.py b/benchmarks/benchmark_nnunet_inference.py new file mode 100644 index 0000000..7c674e8 --- /dev/null +++ b/benchmarks/benchmark_nnunet_inference.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python +"""Timing harness for the VibeSeg / nnU-Net inference pipeline. + +Measures ``run_inference_on_file`` split into phases (model load, preprocess, +sliding-window forward, postprocess) with correct CUDA synchronisation, warmup, +repeats and peak-memory tracking. It is written so the *same file* runs against +any commit in the optimisation range: arguments that a given commit does not yet +support are silently dropped (see :func:`supported_kwargs`). + +Two ways to use it: + +1. Per-commit deltas (always-on changes: cuDNN/TF32, inference_mode, + empty_cache, fold fix). Drive it with ``bench_across_commits.sh``. + +2. Flag sweep at HEAD (opt-in changes: cache_model, tile_batch_size, max_folds, + step_size). Drive it with ``bench_flag_sweep.sh`` or call ``run`` directly + with the relevant flags. + +Examples:: + + # synthetic input (only the model weights are required), single config + python benchmark_nnunet_inference.py run --dataset-id 100 --synthetic \ + --shape 320 320 96 --repeats 5 --json /tmp/head.json + + # real input(s); multi-channel models take several --input paths + python benchmark_nnunet_inference.py run --dataset-id 100 \ + --input water.nii.gz fat.nii.gz --tile-batch-size 4 --cache-model + + # compare result JSONs produced by several runs/commits + python benchmark_nnunet_inference.py compare results/*.json + +Caveats: TPTBox must be importable from the *working tree* (an editable +``poetry install`` does this), otherwise checking out a commit will not change +the measured code. The first run may download model weights; that happens during +warmup and is excluded from the reported numbers. +""" + +from __future__ import annotations + +import argparse +import functools +import inspect +import json +import statistics +import sys +import time +from pathlib import Path + +import numpy as np +import torch + +# --- phase timers / counters populated by monkeypatching (no library changes) --------------- +TIMINGS: dict[str, float] = {} +COUNTERS: dict[str, object] = {} + + +def _sync(device: str) -> None: + if device == "cuda": + torch.cuda.synchronize() + + +def _reset_collectors() -> None: + TIMINGS.clear() + COUNTERS.clear() + COUNTERS["forward_calls"] = 0 + COUNTERS["tiles"] = 0 + COUNTERS["folds"] = None + COUNTERS["fold_status"] = "n/a" + + +def _timed(name: str, device: str): + """Wrap a callable so its synchronised wall time accumulates into ``TIMINGS[name]``.""" + + def deco(fn): + @functools.wraps(fn) + def wrapper(*a, **k): + _sync(device) + t = time.perf_counter() + try: + return fn(*a, **k) + finally: + _sync(device) + TIMINGS[name] = TIMINGS.get(name, 0.0) + (time.perf_counter() - t) + + return wrapper + + return deco + + +def install_patches(device: str) -> None: + """Monkeypatch the inference stack to record per-phase timings and forward counts. + + Patches the names in the modules where they are actually *called* so the + instrumentation works regardless of how each commit imports them. + """ + from TPTBox.segmentation.nnUnet_utils import inference_api, predictor + + # model load (run_inference_on_file does `from inference_api import load_inf_model` at call + # time, so patching the attribute here is picked up by that local import). + inference_api.load_inf_model = _timed("load", device)(inference_api.load_inf_model) + + # postprocess: convert_predicted_logits_... is called via predictor's own namespace binding. + predictor.convert_predicted_logits_to_segmentation_with_correct_shape = _timed("postprocess", device)( + predictor.convert_predicted_logits_to_segmentation_with_correct_shape + ) + + P = predictor.nnUNetPredictor + P.predict_single_npy_array = _timed("single_npy", device)(P.predict_single_npy_array) + P.predict_logits_from_preprocessed_data = _timed("predict", device)(_fold_probe(P.predict_logits_from_preprocessed_data)) + P._internal_maybe_mirror_and_predict = _forward_counter(P._internal_maybe_mirror_and_predict) + + +def _fold_probe(fn): + """Record fold count and whether the loaded_networks cache is correct (surfaces the #6 bug).""" + + @functools.wraps(fn) + def wrapper(self, *a, **k): + params = getattr(self, "list_of_parameters", None) + COUNTERS["folds"] = len(params) if params is not None else None + loaded = getattr(self, "loaded_networks", None) + if loaded is None: + COUNTERS["fold_status"] = "lazy-per-fold" # correct: weights swapped per fold + else: + ids = {id(n) for n in loaded} + COUNTERS["fold_status"] = "distinct" if len(ids) == len(loaded) else f"DUPLICATED({len(loaded)}->{len(ids)})" + return fn(self, *a, **k) + + return wrapper + + +def _forward_counter(fn): + @functools.wraps(fn) + def wrapper(self, x, *a, **k): + COUNTERS["forward_calls"] = COUNTERS.get("forward_calls", 0) + 1 + COUNTERS["tiles"] = COUNTERS.get("tiles", 0) + int(x.shape[0]) + return fn(self, x, *a, **k) + + return wrapper + + +# --- input handling -------------------------------------------------------------------------- +def make_synthetic(shape: list[int], spacing: list[float], channels: int, seed: int, cache_dir: Path) -> list[str]: + """Create (and cache) ``channels`` random NIfTI volumes with the given shape/spacing.""" + import nibabel as nib + + cache_dir.mkdir(parents=True, exist_ok=True) + rng = np.random.default_rng(seed) + affine = np.diag([*spacing, 1.0]).astype(float) + paths = [] + tag = f"{channels}ch_{'x'.join(map(str, shape))}_sp{'-'.join(str(s) for s in spacing)}_s{seed}" + for c in range(channels): + p = cache_dir / f"synthetic_{tag}_ch{c}.nii.gz" + if not p.exists(): + arr = (rng.standard_normal(tuple(shape)).astype(np.float32) * 200.0) + 100.0 + nib.save(nib.Nifti1Image(arr, affine), str(p)) + paths.append(str(p)) + return paths + + +def resolve_channels(dataset_id: int | None, fallback: int) -> int: + if dataset_id is None: + return fallback + try: + from TPTBox.segmentation.VibeSeg.inference_nnunet import get_ds_info + + ds = get_ds_info(dataset_id, exit_one_fail=False) + if ds and "channel_names" in ds: + return len(ds["channel_names"]) + except Exception as e: # noqa: BLE001 - best effort, fall back to the user value + print(f"[warn] could not read dataset.json for channel count ({e}); using --channels={fallback}") + return fallback + + +# --- core measurement ------------------------------------------------------------------------ +def supported_kwargs(fn, kwargs: dict) -> dict: + """Keep only kwargs the target accepts, so the harness runs on commits that lack newer flags.""" + params = inspect.signature(fn).parameters + if any(p.kind == p.VAR_KEYWORD for p in params.values()): + return dict(kwargs) + keep = {k: v for k, v in kwargs.items() if k in params} + dropped = sorted(set(kwargs) - set(keep)) + if dropped: + print(f"[info] this commit ignores unsupported args: {dropped}") + return keep + + +def run_once(idx, inputs: list[str], call_kwargs: dict, device: str) -> dict: + from TPTBox import to_nii + from TPTBox.segmentation.VibeSeg.inference_nnunet import run_inference_on_file + + _reset_collectors() + niis = [to_nii(p) for p in inputs] # reload each repeat so in-place ops never bleed across runs + if device == "cuda": + torch.cuda.reset_peak_memory_stats() + _sync(device) + t0 = time.perf_counter() + run_inference_on_file(idx, niis, **supported_kwargs(run_inference_on_file, call_kwargs)) + _sync(device) + total = time.perf_counter() - t0 + + phases = { + "load": TIMINGS.get("load", 0.0), + "preprocess": max(0.0, TIMINGS.get("single_npy", 0.0) - TIMINGS.get("predict", 0.0) - TIMINGS.get("postprocess", 0.0)), + "predict": TIMINGS.get("predict", 0.0), + "postprocess": TIMINGS.get("postprocess", 0.0), + } + phases["other"] = max(0.0, total - sum(phases.values())) # NII I/O, reorient, rescale-back, save + return { + "total": total, + "phases": phases, + "peak_mem_mb": (torch.cuda.max_memory_allocated() / 1e6) if device == "cuda" else 0.0, + "forward_calls": COUNTERS["forward_calls"], + "tiles": COUNTERS["tiles"], + "folds": COUNTERS["folds"], + "fold_status": COUNTERS["fold_status"], + } + + +def summarize(repeats: list[dict]) -> dict: + """Median over steady-state repeats (repeat 0 kept separately as the cold/first call).""" + steady = repeats[1:] if len(repeats) > 1 else repeats + med = lambda key: statistics.median(r[key] for r in steady) # noqa: E731 + phase_keys = repeats[0]["phases"].keys() + return { + "total_first": repeats[0]["total"], + "total_median": med("total"), + "phases_median": {k: statistics.median(r["phases"][k] for r in steady) for k in phase_keys}, + "load_first": repeats[0]["phases"]["load"], + "load_median": statistics.median(r["phases"]["load"] for r in steady), + "peak_mem_mb": max(r["peak_mem_mb"] for r in repeats), + "forward_calls": repeats[0]["forward_calls"], + "tiles": repeats[0]["tiles"], + "folds": repeats[0]["folds"], + "fold_status": repeats[0]["fold_status"], + } + + +def commit_hash() -> str: + import subprocess + + try: + return subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip() # noqa: S607 + except Exception: # noqa: BLE001 + return "unknown" + + +# --- subcommands ----------------------------------------------------------------------------- +def cmd_run(args: argparse.Namespace) -> None: + device = args.device + install_patches(device) + + channels = resolve_channels(args.dataset_id if args.model_path is None else None, args.channels) + if args.synthetic: + inputs = make_synthetic(args.shape, args.spacing, channels, args.seed, Path(args.cache_dir)) + else: + if not args.input: + sys.exit("error: provide --input PATH [PATH ...] or use --synthetic") + inputs = args.input + idx = Path(args.model_path) if args.model_path else args.dataset_id + + call_kwargs = { + "out_file": None, + "override": True, + "ddevice": device, + "gpu": args.gpu, + "verbose": False, + "max_folds": args.max_folds, + "step_size": args.step_size, + "tile_batch_size": args.tile_batch_size, + "cache_model": args.cache_model, + "keep_size": args.keep_size, + "padd": args.padd, + } + + print(f"warmup ({args.warmup}) ...") + for _ in range(max(0, args.warmup)): + run_once(idx, inputs, call_kwargs, device) + + repeats = [] + for i in range(args.repeats): + r = run_once(idx, inputs, call_kwargs, device) + repeats.append(r) + print( + f" repeat {i}: total={r['total']:.3f}s predict={r['phases']['predict']:.3f}s " + f"load={r['phases']['load']:.3f}s peak={r['peak_mem_mb']:.0f}MB forwards={r['forward_calls']}" + ) + + summary = summarize(repeats) + result = { + "commit": commit_hash(), + "device": device, + "inputs": inputs, + "config": {k: call_kwargs[k] for k in ("max_folds", "step_size", "tile_batch_size", "cache_model", "keep_size", "padd")}, + "repeats": repeats, + "summary": summary, + } + _print_summary(result) + if args.json: + Path(args.json).parent.mkdir(parents=True, exist_ok=True) + Path(args.json).write_text(json.dumps(result, indent=2)) + print(f"\nwrote {args.json}") + + +def _print_summary(result: dict) -> None: + s = result["summary"] + print(f"\n=== {result['commit']} | device={result['device']} | config={result['config']} ===") + print(f"folds={s['folds']} fold_status={s['fold_status']} forward_calls={s['forward_calls']} tiles={s['tiles']}") + print(f"peak_mem={s['peak_mem_mb']:.0f} MB") + print(f"total: first={s['total_first']:.3f}s steady-median={s['total_median']:.3f}s") + print(f"load: first={s['load_first']:.3f}s steady-median={s['load_median']:.3f}s") + print("phase medians (steady-state):") + for k, v in s["phases_median"].items(): + print(f" {k:<11s} {v:.3f}s") + + +def cmd_compare(args: argparse.Namespace) -> None: + rows = [] + for f in args.files: + d = json.loads(Path(f).read_text()) + s = d["summary"] + rows.append((d["commit"], s, d.get("config", {}))) + + hdr = f"{'commit':<10} {'total_med':>10} {'predict':>9} {'load_1st':>9} {'peak_MB':>9} {'fwd':>6} {'folds':>5} fold_status" + print(hdr) + print("-" * len(hdr)) + base = None + prev = None + for commit, s, _cfg in rows: + tm = s["total_median"] + line = ( + f"{commit:<10} {tm:>10.3f} {s['phases_median']['predict']:>9.3f} {s['load_first']:>9.3f} " + f"{s['peak_mem_mb']:>9.0f} {s['forward_calls']:>6} {s['folds']!s:>5} {s['fold_status']}" + ) + print(line) + base = base if base is not None else tm + if prev is not None: + d_prev = (tm - prev) / prev * 100 + d_base = (tm - base) / base * 100 + print(f"{'':<10} {'':>10} {'':>9} {'':>9} {'':>9} {'':>6} {'':>5} Δprev={d_prev:+.1f}% Δbaseline={d_base:+.1f}%") + prev = tm + print( + "\nNote: opt-in commits (cache_model, tile_batch_size) show ~0 here under the default config; " + "measure them with bench_flag_sweep.sh at HEAD." + ) + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + sub = p.add_subparsers(dest="cmd", required=True) + + r = sub.add_parser("run", help="benchmark one configuration") + r.add_argument("--dataset-id", type=int, default=100) + r.add_argument("--model-path", default=None, help="explicit model folder; overrides --dataset-id") + r.add_argument("--input", nargs="+", default=None, help="input NIfTI path(s); one per model channel") + r.add_argument("--synthetic", action="store_true", help="generate random input(s) instead of reading files") + r.add_argument("--shape", type=int, nargs=3, default=[320, 320, 96]) + r.add_argument("--spacing", type=float, nargs=3, default=[1.40625, 1.40625, 3.0]) + r.add_argument("--channels", type=int, default=1, help="used only if dataset.json channel count is unavailable") + r.add_argument("--seed", type=int, default=1234) + r.add_argument("--cache-dir", default=str(Path(__file__).parent / ".bench_cache")) + r.add_argument("--device", choices=["cuda", "cpu", "mps"], default="cuda") + r.add_argument("--gpu", type=int, default=0) + r.add_argument("--repeats", type=int, default=5) + r.add_argument("--warmup", type=int, default=1) + # optimisation knobs (dropped automatically on commits that predate them) + r.add_argument("--max-folds", type=int, default=None) + r.add_argument("--step-size", type=float, default=0.5) + r.add_argument("--tile-batch-size", type=int, default=1) + r.add_argument("--cache-model", action="store_true") + r.add_argument("--keep-size", action="store_true") + r.add_argument("--padd", type=int, default=0) + r.add_argument("--json", default=None, help="write structured results to this path") + r.set_defaults(func=cmd_run) + + c = sub.add_parser("compare", help="tabulate result JSONs in the given order") + c.add_argument("files", nargs="+") + c.set_defaults(func=cmd_compare) + return p + + +def main() -> None: + args = build_parser().parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index d6f222e..c5356aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -210,6 +210,8 @@ convention = "google" # Test files: don't require docstrings on test functions / helper utilities "unit_tests/**" = ["D101", "D102", "D103", "D205", "D415", "ANN201"] "TPTBox/tests/**" = ["D101", "D102", "D103", "D205", "D415", "ANN201"] +# Dev tooling (standalone scripts): same relaxed docstring rules, not an importable package +"benchmarks/**" = ["D101", "D102", "D103", "D205", "D415", "ANN201", "INP001"] [tool.ruff.lint.mccabe] # Flag errors (`C901`) whenever the complexity level exceeds 5. diff --git a/unit_tests/test_auto_segmentation.py b/unit_tests/test_auto_segmentation.py index 5b31b2b..4f20f6e 100644 --- a/unit_tests/test_auto_segmentation.py +++ b/unit_tests/test_auto_segmentation.py @@ -39,6 +39,14 @@ has_torch = False +try: + import ants + + has_ants = True +except ModuleNotFoundError: + has_ants = False + + class Test_test_samples(unittest.TestCase): # def test_load_ct(self): # ct_nii, subreg_nii, vert_nii, label = get_test_ct() @@ -60,7 +68,7 @@ def test_get_outpaths_spineps(self): assert "out_spine" in out assert "out_vert" in out - @unittest.skipIf(not has_spineps, "requires spineps to be installed") + @unittest.skipIf(not has_spineps or not has_ants, "requires spineps to be installed") def test_spineps(self): tests_path = get_tests_dir() if (tests_path / "derivative").exists():