From 88e7bf5c1fcec11feb7ade2512b3ccf7f6c438b9 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Wed, 27 May 2026 15:15:15 +0000 Subject: [PATCH 01/17] fail if not fit on GPU, so I can test it --- TPTBox/segmentation/VibeSeg/inference_nnunet.py | 2 ++ TPTBox/segmentation/nnUnet_utils/inference_api.py | 5 ++--- TPTBox/segmentation/nnUnet_utils/predictor.py | 4 +++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index 26d6472..7aaf8d2 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -94,6 +94,7 @@ def run_inference_on_file( verbose: bool = True, auto_download: bool = False, _key_ResEnc: str = "__nnUNet*ResEnc", + fail_on_missing_memory=False, logger=logger, ) -> tuple[Image_Reference, np.ndarray | None]: """Load a VibeSeg model and run inference on the supplied NIfTI images. @@ -207,6 +208,7 @@ def run_inference_on_file( memory_factor=memory_factor, memory_max=memory_max, wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + fail_on_missing_memory=fail_on_missing_memory, ) # _unets[idx] = nnunet diff --git a/TPTBox/segmentation/nnUnet_utils/inference_api.py b/TPTBox/segmentation/nnUnet_utils/inference_api.py index 4913952..8482904 100755 --- a/TPTBox/segmentation/nnUnet_utils/inference_api.py +++ b/TPTBox/segmentation/nnUnet_utils/inference_api.py @@ -2,13 +2,10 @@ from pathlib import Path -import nibabel as nib import numpy as np -import SimpleITK as sitk # noqa: N813 import torch from TPTBox import NII, Log_Type, No_Logger -from TPTBox.core import sitk_utils from .predictor import nnUNetPredictor @@ -36,6 +33,7 @@ def load_inf_model( memory_factor: int = 160, memory_max: int = 160000, wait_till_gpu_percent_is_free: float = 0.3, + fail_on_missing_memory=False, ) -> nnUNetPredictor: """Load and initialise an nnU-Net model predictor from a trained model folder. @@ -103,6 +101,7 @@ def load_inf_model( memory_factor=memory_factor, memory_max=memory_max, wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + fail_on_missing_memory=fail_on_missing_memory, ) check_name = "checkpoint_final.pth" # if not allow_non_final else "checkpoint_best.pth" try: diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index eb654e9..4e7316c 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -86,6 +86,7 @@ def __init__( memory_base=5000, # Base memory in MB, default is 5GB memory_factor=160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB memory_max: int = 160000, # in MB, default is 160GB + fail_on_missing_memory=False, wait_till_gpu_percent_is_free=0.3, ): self.verbose = verbose @@ -114,6 +115,7 @@ def __init__( perform_everything_on_gpu = False self.device = device self.perform_everything_on_gpu = perform_everything_on_gpu + self.fail_on_missing_memory = fail_on_missing_memory self.memory_base = memory_base self.memory_factor = memory_factor self.memory_max = memory_max @@ -416,7 +418,7 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: in prediction = None self.perform_everything_on_gpu = False empty_cache(self.device) - if attempts == 0: + if attempts == 0 or self.fail_on_missing_memory: raise return self.predict_logits_from_preprocessed_data(data, attempts=attempts - 1) From 4a6ee90c16fc40843f0a4322ee2ec921cb16d967 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Thu, 28 May 2026 12:54:23 +0000 Subject: [PATCH 02/17] update memory requirements --- .../segmentation/VibeSeg/inference_nnunet.py | 11 +++++-- TPTBox/segmentation/VibeSeg/vibeseg.py | 8 +++++ .../nnUnet_utils/inference_api.py | 6 ++-- TPTBox/segmentation/nnUnet_utils/predictor.py | 29 +++++++++++++------ 4 files changed, 39 insertions(+), 15 deletions(-) diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index 7aaf8d2..2e6e367 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -87,9 +87,9 @@ def run_inference_on_file( ddevice: Literal["cpu", "cuda", "mps"] = "cuda", model_path=None, step_size: float = 0.5, - memory_base: int = 5000, # Base memory in MB, default is 5GB - memory_factor: int = 160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB - memory_max: int = 160000, # in MB, default is 160GB + memory_base: float | None = None, # Base memory in MB, default is 5GB + memory_factor: float | None = None, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB + memory_max: int = 990000, # in MB, default is 990GB (so it is most likely ignored and replaced by Max Memory of the GPU) wait_till_gpu_percent_is_free: float = 0.1, verbose: bool = True, auto_download: bool = False, @@ -197,6 +197,11 @@ def run_inference_on_file( if "labels" in ds_info2: ds_info["labels_mapping"] = ds_info2["labels"] + if memory_base is None: + memory_base = float(ds_info.get("memory_base", 5000)) + if memory_factor is None: + memory_factor = float(ds_info.get("memory_factor", 160)) + nnunet = load_inf_model( nnunet_path, allow_non_final=True, diff --git a/TPTBox/segmentation/VibeSeg/vibeseg.py b/TPTBox/segmentation/VibeSeg/vibeseg.py index 087702f..04d4615 100644 --- a/TPTBox/segmentation/VibeSeg/vibeseg.py +++ b/TPTBox/segmentation/VibeSeg/vibeseg.py @@ -84,6 +84,10 @@ 72: "bone_other", } +defaults = { + 100: {"memory_base": 5500, "memory_factor": 25}, +} + def run_vibeseg( i: Image_Reference, @@ -113,6 +117,10 @@ def run_vibeseg( Returns: Segmentation ``NII`` saved at *out_seg*. """ + if dataset_id in defaults: + for k, v in defaults[dataset_id].items(): + if k not in args: + args[k] = v return run_inference_on_file( dataset_id, [to_nii(i)] if not isinstance(i, (list, tuple)) else [to_nii(j) for j in i], diff --git a/TPTBox/segmentation/nnUnet_utils/inference_api.py b/TPTBox/segmentation/nnUnet_utils/inference_api.py index 8482904..18dd387 100755 --- a/TPTBox/segmentation/nnUnet_utils/inference_api.py +++ b/TPTBox/segmentation/nnUnet_utils/inference_api.py @@ -29,9 +29,9 @@ def load_inf_model( use_gaussian: bool = True, verbose: bool = False, gpu: int | None = None, - memory_base: int = 5000, - memory_factor: int = 160, - memory_max: int = 160000, + memory_base: float = 5000, + memory_factor: float = 160, + memory_max: float = 160000, wait_till_gpu_percent_is_free: float = 0.3, fail_on_missing_memory=False, ) -> nnUNetPredictor: diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index 4e7316c..c4075e8 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -83,9 +83,9 @@ def __init__( verbose: bool = False, verbose_preprocessing: bool = False, allow_tqdm: bool = True, - memory_base=5000, # Base memory in MB, default is 5GB - memory_factor=160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB - memory_max: int = 160000, # in MB, default is 160GB + memory_base: float = 5000, # Base memory in MB, default is 5GB + memory_factor: float = 160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB + memory_max: float = 160000, # in MB, default is 160GB fail_on_missing_memory=False, wait_till_gpu_percent_is_free=0.3, ): @@ -109,7 +109,8 @@ def __init__( self.use_mirroring = use_mirroring if device.type == "cuda": device = torch.device(type="cuda", index=cuda_id) # set the desired GPU with CUDA_VISIBLE_DEVICES! - self.do_not_use_half_precision = device.type in ["cpu", "mps"] # float16 not supported by cpu + # float16 not supported by cpu + self.do_not_use_half_precision = device.type in ["cpu", "mps"] if device.type != "cuda" and perform_everything_on_gpu: print("perform_everything_on_gpu=True is only supported for cuda devices! Setting this to False") perform_everything_on_gpu = False @@ -154,7 +155,11 @@ def initialize_from_trained_model_folder( ## LOAD NNUNET 1 models from nnunet.training.model_restore import restore_model - pkl_file1 = join(model_training_output_dir, f"fold_{use_folds[0]}", "model_final_checkpoint.model.pkl") # type: ignore + pkl_file1 = join( + model_training_output_dir, + f"fold_{use_folds[0]}", + "model_final_checkpoint.model.pkl", + ) # type: ignore trainer = restore_model(pkl_file1, fp16=True) trainer.output_folder = model_training_output_dir trainer.output_folder_base = model_training_output_dir @@ -302,7 +307,10 @@ def mapp(d: dict): # print(type(self.loaded_networks[0])) def predict_single_npy_array( - self, input_image: np.ndarray, image_properties: dict, save_or_return_probabilities: bool = False + self, + input_image: np.ndarray, + image_properties: dict, + save_or_return_probabilities: bool = False, ) -> np.ndarray: """Run full inference on a single numpy image array. @@ -339,7 +347,10 @@ def predict_single_npy_array( if self.verbose: print("predicting") predicted_logits = self.predict_logits_from_preprocessed_data(dct["data"]) # type: ignore - print("convert_predicted_logits_to_segmentation_with_correct_shape", predicted_logits.shape) + print( + "convert_predicted_logits_to_segmentation_with_correct_shape", + predicted_logits.shape, + ) import time t = time.time() @@ -605,12 +616,12 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ time.sleep(1) def check_mem(shape): - memory = get_gpu_memory_MB(device) + memory = get_gpu_memory_MB(device) * 0.80 max_memory = self.memory_max min_memory = self.memory_base factor = self.memory_factor # print(shape, "usage", np.prod(shape) / 1000000 * factor, max(min(memory, max_memory), min_memory)) - return (np.prod(shape) / 1000000 * factor) + min_memory // 2 < max(min(memory, max_memory), min_memory) + return (np.prod(shape) / 1000000 * factor) + min_memory < max(min(memory, max_memory), min_memory) with tqdm(total=len(slicers), disable=not self.allow_tqdm) as pbar: if not check_mem(shape) or "nnUNetPlans_2d" not in self.configuration_manager.configuration.get("data_identifier", "3D"): From 48a9d2adb3e9608db7f354d902d3caa63d0fad71 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Fri, 29 May 2026 08:37:58 +0000 Subject: [PATCH 03/17] fix bug for very elongated segmentations --- TPTBox/registration/_deformable/multilabel_segmentation.py | 7 ++++--- TPTBox/registration/_ridged_intensity/affine_deepali.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/TPTBox/registration/_deformable/multilabel_segmentation.py b/TPTBox/registration/_deformable/multilabel_segmentation.py index 33cadce..46d153c 100644 --- a/TPTBox/registration/_deformable/multilabel_segmentation.py +++ b/TPTBox/registration/_deformable/multilabel_segmentation.py @@ -54,6 +54,7 @@ def __init__( # noqa: C901 poi_target_cms: POI | None = None, max_history=100, change_after_point_reg=lambda x, y, z, w: (x, y, z, w), + tether_distance=1, **args, ): """Initialize a multi-stage registration pipeline from an atlas to a target image. @@ -90,7 +91,7 @@ def __init__( # noqa: C901 "be": ("BSplineBending", {"stride": 1}), "seg": "MSE", "Dice": "Dice", - "Tether": Tether_Seg(delta=5), + "Tether": Tether_Seg(delta=tether_distance), } assert target_seg.seg, target_seg.seg @@ -187,7 +188,7 @@ def __init__( # noqa: C901 poi_cms = poi_cms.resample_from_to(atlas_seg_) self.reg_point = Point_Registration(poi_target, poi_cms, verbose=False) - atlas_reg = self.reg_point.transform_nii(atlas_seg_) + atlas_reg = self.reg_point.transform_nii(atlas_seg_, c_val=0) if not atlas_reg.is_segmentation_in_border(): print("point registration ok") @@ -204,7 +205,7 @@ def __init__( # noqa: C901 target_img = target_img.apply_pad(resize_param) if target_img is not None else None self.reg_point = Point_Registration(poi_target.resample_from_to(target_seg), poi_cms.resample_from_to(atlas_seg)) - atlas_reg = self.reg_point.transform_nii(atlas_seg) + atlas_reg = self.reg_point.transform_nii(atlas_seg, c_val=0) atlas_img_reg = self.reg_point.transform_nii(atlas_img) if atlas_img is not None else None if crop: diff --git a/TPTBox/registration/_ridged_intensity/affine_deepali.py b/TPTBox/registration/_ridged_intensity/affine_deepali.py index 5f1777f..b35a3fb 100644 --- a/TPTBox/registration/_ridged_intensity/affine_deepali.py +++ b/TPTBox/registration/_ridged_intensity/affine_deepali.py @@ -92,14 +92,14 @@ def forward( target: torch.Tensor, # shape: (B, C, X, Y, Z) mask: torch.Tensor | None = None, # noqa: ARG002 ) -> torch.Tensor: - w = max(target.shape[2:]) + w = min(target.shape[2:]) com_fixed = center_of_mass_cc(target) # (B, C, 3) com_warped = center_of_mass_cc(source) # (B, C, 3) l_com = torch.norm(com_fixed - com_warped, dim=-1) / w # (B, C) # Zero out channels with small displacement (<10) or NaNs - l_com = torch.where(l_com < self.delta, torch.zeros_like(l_com), l_com) + l_com = torch.where(l_com * w < self.delta, torch.zeros_like(l_com), l_com) l_com = torch.nan_to_num(l_com, nan=0.0) return l_com.mean() # type: ignore From 41fa83355e37bc57b926ab2553bda650fda95b36 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Fri, 29 May 2026 09:09:29 +0000 Subject: [PATCH 04/17] should not use Runtime Errors --- TPTBox/core/nii_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 041f4d5..cf4dd39 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -2505,7 +2505,7 @@ def to_stl( try: verts, faces, normals, values = marching_cubes(seg_arr, gradient_direction="ascent", step_size=1) except RuntimeError as e: - raise RuntimeError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None + raise IndexError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None # Remove padding offset (since we padded by 1 voxel) verts -= 1 # Apply bounding box offset (still voxel space) From 397cd36603cfeecf2139b4f69260c52fcdeeec90 Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 4 Jun 2026 15:10:00 +0000 Subject: [PATCH 05/17] perf(nnunet): enable cuDNN autotune + TF32 for sliding-window inference All sliding-window tiles share the same patch_size shape, so cuDNN can autotune the fastest conv algorithms once and reuse them across tiles and images. TF32 accelerates fp32 matmul/conv on Ampere+ GPUs with negligible accuracy impact. Gated behind a new fast_perf flag (default True, CUDA only). Co-Authored-By: Claude Opus 4.8 --- .../segmentation/nnUnet_utils/inference_api.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/TPTBox/segmentation/nnUnet_utils/inference_api.py b/TPTBox/segmentation/nnUnet_utils/inference_api.py index 4913952..365a15f 100755 --- a/TPTBox/segmentation/nnUnet_utils/inference_api.py +++ b/TPTBox/segmentation/nnUnet_utils/inference_api.py @@ -31,6 +31,7 @@ def load_inf_model( inference_augmentation: bool = False, use_gaussian: bool = True, verbose: bool = False, + fast_perf: bool = True, gpu: int | None = None, memory_base: int = 5000, memory_factor: int = 160, @@ -55,6 +56,11 @@ def load_inf_model( inference_augmentation: If True, enable test-time mirroring augmentation. use_gaussian: If True, apply Gaussian weighting in the sliding window. verbose: If True, print progress information during model initialisation. + fast_perf: If True (and running on CUDA), enable cuDNN autotuning and TF32 + matmul/conv. Every sliding-window tile has the same ``patch_size`` + shape, so cuDNN can pick the fastest convolution algorithms once and + reuse them. TF32 speeds up fp32 ops on Ampere+ GPUs with negligible + accuracy impact. These are global ``torch.backends`` flags. gpu: GPU device index forwarded to the predictor. ``None`` defaults to 0. memory_base: Base GPU memory reservation in MB (default 5 000 MB = 5 GB). memory_factor: Per-voxel memory scaling factor. The formula is @@ -84,6 +90,16 @@ def load_inf_model( _interop = True except Exception as e: print(e) + if fast_perf: + # All sliding-window tiles share the same (patch_size) shape, so cuDNN can + # autotune the fastest conv algorithms once and reuse them across tiles/images. + # TF32 accelerates fp32 matmul/conv on Ampere+ with negligible accuracy impact. + try: + torch.backends.cudnn.benchmark = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.backends.cudnn.allow_tf32 = True + except Exception as e: + print(e) device = torch.device("cuda") else: device = torch.device("mps") From 065e3d99a48fb71647e847d0c6d36d53af50161e Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 4 Jun 2026 15:10:18 +0000 Subject: [PATCH 06/17] perf(nnunet): use torch.inference_mode() instead of no_grad() for inference inference_mode disables autograd view/version tracking entirely, so it is slightly faster and uses less memory than no_grad. All tensors here are pure inference outputs (moved to CPU / converted to numpy downstream), so the stricter inference-tensor semantics are safe. Co-Authored-By: Claude Opus 4.8 --- TPTBox/segmentation/nnUnet_utils/predictor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index eb654e9..0de4e5f 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -382,7 +382,7 @@ def predict_logits_from_preprocessed_data(self, data: torch.Tensor, attempts: in # things a lot faster for some datasets. original_perform_everything_on_gpu = self.perform_everything_on_gpu assert self.list_of_parameters is not None - with torch.no_grad(): + with torch.inference_mode(): prediction = None try: for idx, params in enumerate(self.list_of_parameters): @@ -563,7 +563,7 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ # is set. Why. (this is why we don't make use of enabled=False) # So autocast will only be active if we have a cuda device. with ( - torch.no_grad(), + torch.inference_mode(), torch.autocast(self.device.type, enabled=True) if self.device.type == "cuda" else dummy_context(), ): assert len(input_image.shape) == 4, "input_image must be a 4D np.ndarray or torch.Tensor (c, x, y, z)" From 271557a27e3315b6e7e66bde8a62a96ac0e8616d Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 4 Jun 2026 15:11:01 +0000 Subject: [PATCH 07/17] perf(nnunet): stop clearing CUDA cache on the per-fold happy path predict_sliding_window_return_logits runs once per fold; calling torch.cuda.empty_cache() at its start and end returned allocator blocks to the driver right before the next fold reallocated them (slow cudaMalloc + sync). The pool is still cleared once per image after fold averaging, and all OOM-recovery empty_cache() calls are untouched. Co-Authored-By: Claude Opus 4.8 --- TPTBox/segmentation/nnUnet_utils/predictor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index 0de4e5f..f23a93e 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -554,7 +554,6 @@ def predict_sliding_window_return_logits(self, input_image: torch.Tensor, networ network = network.to(self.device) # type: ignore assert self.configuration_manager is not None assert self.label_manager is not None - empty_cache(self.device) # Autocast is a little bitch. # If the device_type is 'cpu' then it's slow as heck on some CPUs (no auto bfloat16 support detection) @@ -648,7 +647,9 @@ def check_mem(shape): predicted_logits /= n_predictions del n_predictions predicted_logits = predicted_logits.cpu() - empty_cache(self.device) + # NOTE: do not empty_cache() here. This runs once per fold; releasing the + # allocator pool now just forces the next fold to re-cudaMalloc. The pool is + # cleared once per image in predict_logits_from_preprocessed_data instead. return predicted_logits[(slice(None), *slicer_revert_padding[1:])] def _run_prediction_splits( From 3fbc8424fab761cef6a55bf18d37893b83683222 Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 4 Jun 2026 15:12:39 +0000 Subject: [PATCH 08/17] fix(nnunet): repair fold ensembling broken by loaded_networks cache initialize_from_trained_model_folder built loaded_networks by appending the same self.network object once per fold, so every entry ended up holding the LAST fold's weights. predict_logits_from_preprocessed_data then ran that one weight set N times and averaged it, silently collapsing the N-fold ensemble to a single fold while still paying Nx compute. Separately, cache_state_dicts=False left loaded_networks as [] (not None), causing an IndexError in the predict loop. Now loaded_networks is None unless exactly one fold is present (preloaded for zero per-prediction reloads). For >1 fold it stays None and the predict loop swaps weights per fold via load_state_dict, restoring a correct ensemble using a single network's worth of GPU memory. Device-availability warnings now fire once regardless of fold count. Co-Authored-By: Claude Opus 4.8 --- TPTBox/segmentation/nnUnet_utils/predictor.py | 42 +++++++++++-------- 1 file changed, 25 insertions(+), 17 deletions(-) diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index f23a93e..b30f320 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -281,23 +281,31 @@ def mapp(d: dict): print("compiling network") self.network = torch.compile(self.network) # type: ignore - self.loaded_networks = [] - if cache_state_dicts: - for params in self.list_of_parameters: - if not isinstance(self.network, OptimizedModule): - self.network.load_state_dict(params) # type: ignore - else: - self.network._orig_mod.load_state_dict(params) - if self.device.type == "cuda" and not torch.cuda.is_available(): - Print_Logger().on_warning( - "No CUDA device. If you have a CUDA-able GPU (Nvidia), reinstall pytorch with cuda or for non-cuda devices use ddevice=cpu or ddevice=mps" - ) - if self.device.type == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()): - Print_Logger().on_warning("No MPS device found. Use ddevice=cpu or ddevice=mps") - self.network.to(self.device) - self.network.eval() # type: ignore - self.loaded_networks.append(self.network) - # print(type(self.loaded_networks[0])) + # Warn early if the requested device is unavailable (runs once, independent of folds). + if self.device.type == "cuda" and not torch.cuda.is_available(): + Print_Logger().on_warning( + "No CUDA device. If you have a CUDA-able GPU (Nvidia), reinstall pytorch with cuda or for non-cuda devices use ddevice=cpu or ddevice=mps" + ) + if self.device.type == "mps" and not (hasattr(torch.backends, "mps") and torch.backends.mps.is_available()): + Print_Logger().on_warning("No MPS device found. Use ddevice=cpu or ddevice=mps") + + # loaded_networks holds one ready-to-run network per fold (or None to load weights + # lazily per fold). We only cache the single-fold case: previously this loop appended + # the SAME self.network object once per fold, so every entry ended up holding the LAST + # fold's weights. That silently collapsed an N-fold ensemble to a single fold while + # still paying Nx the compute. For >1 fold we keep loaded_networks=None and let + # predict_logits_from_preprocessed_data swap weights per fold via load_state_dict (a + # true ensemble that needs only one network's worth of GPU memory). + self.loaded_networks = None + if cache_state_dicts and len(self.list_of_parameters) == 1: + params = self.list_of_parameters[0] + if not isinstance(self.network, OptimizedModule): + self.network.load_state_dict(params) # type: ignore + else: + self.network._orig_mod.load_state_dict(params) + self.network.to(self.device) + self.network.eval() # type: ignore + self.loaded_networks = [self.network] def predict_single_npy_array( self, input_image: np.ndarray, image_properties: dict, save_or_return_probabilities: bool = False From 93ad675d9ac46ea61aba568561c6676b120f59d0 Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 4 Jun 2026 15:14:13 +0000 Subject: [PATCH 09/17] perf(nnunet): add opt-in persistent model cache (cache_model) Every run_inference_on_file call previously reloaded the predictor from disk and re-uploaded weights to the GPU. With cache_model=True the loaded predictor is kept in a process-wide cache keyed by model path + folds + device/runtime settings, so a loop over many files reuses the in-memory model. When caching, the end-of-call del/empty_cache is skipped so the CUDA allocator pool stays warm between images. Default is False to preserve current memory semantics; the flag forwards through run_VibeSeg/run_vibeseg/run_nnunet via **kwargs. Co-Authored-By: Claude Opus 4.8 --- .../segmentation/VibeSeg/inference_nnunet.py | 64 ++++++++++++++----- 1 file changed, 48 insertions(+), 16 deletions(-) diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index 26d6472..77dd435 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -16,6 +16,12 @@ out_base = Path(__file__).parent.parent / "nnUNet/" _model_path_ = out_base / "nnUNet_results" +# Opt-in cache of loaded predictors (enable via cache_model=True). Keyed by model identity plus +# the device/runtime settings that affect the loaded predictor, so repeated inference (e.g. a loop +# over many files with the same model) reuses the in-memory model instead of reloading weights from +# disk and re-uploading them to the GPU on every call. +_model_cache: dict = {} + def get_ds_info(idx: int, _model_path: str | Path | None = None, exit_one_fail: bool = True, logger=logger) -> dict: """Load and return the ``dataset.json`` for the model with the given dataset index. @@ -93,6 +99,7 @@ def run_inference_on_file( wait_till_gpu_percent_is_free: float = 0.1, verbose: bool = True, auto_download: bool = False, + cache_model: bool = False, _key_ResEnc: str = "__nnUNet*ResEnc", logger=logger, ) -> tuple[Image_Reference, np.ndarray | None]: @@ -136,6 +143,13 @@ def run_inference_on_file( wait_till_gpu_percent_is_free: Minimum free GPU fraction to require before starting inference. verbose: Print progress information. + cache_model: If ``True``, keep the loaded predictor in a process-wide + cache and reuse it on subsequent calls with identical model and + device/runtime settings. Avoids reloading weights from disk and + re-uploading them to the GPU when segmenting many files in a loop, at + the cost of holding the model in GPU memory between calls. The GPU + cache is also left warm (no ``empty_cache``) so the allocator can + reuse buffers across images. Returns: A tuple ``(seg_nii, softmax_logits)`` where ``seg_nii`` is the @@ -196,20 +210,36 @@ def run_inference_on_file( if "labels" in ds_info2: ds_info["labels_mapping"] = ds_info2["labels"] - nnunet = load_inf_model( - nnunet_path, - allow_non_final=True, - use_folds=tuple(folds) if len(folds) != 5 else None, - gpu=gpu, - ddevice=ddevice, - step_size=step_size, - memory_base=memory_base, - memory_factor=memory_factor, - memory_max=memory_max, - wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + use_folds_arg = tuple(folds) if len(folds) != 5 else None + # Include every setting that changes the loaded predictor so a cache hit is always equivalent + # to a fresh load; differing settings simply miss the cache and reload. + cache_key = ( + str(nnunet_path), + use_folds_arg, + ddevice, + gpu, + step_size, + memory_base, + memory_factor, + memory_max, + wait_till_gpu_percent_is_free, ) - - # _unets[idx] = nnunet + nnunet = _model_cache.get(cache_key) if cache_model else None + if nnunet is None: + nnunet = load_inf_model( + nnunet_path, + allow_non_final=True, + use_folds=use_folds_arg, + gpu=gpu, + ddevice=ddevice, + step_size=step_size, + memory_base=memory_base, + memory_factor=memory_factor, + memory_max=memory_max, + wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + ) + if cache_model: + _model_cache[cache_key] = nnunet if "orientation" in ds_info: orientation = ds_info["orientation"] @@ -315,9 +345,11 @@ def to_int(a: str, k: None | int = None): seg_nii.map_labels_(mapping) if out_file is not None and (not Path(out_file).exists() or override): seg_nii.set_dtype("smallest_uint").save(out_file) - del nnunet - - torch.cuda.empty_cache() + if not cache_model: + # When caching we keep the predictor alive (it stays referenced by _model_cache, so del + # would not free it anyway) and leave the CUDA cache warm so the next image reuses buffers. + del nnunet + torch.cuda.empty_cache() return seg_nii, softmax_logits From cf37216c1793a5d4fb79613945aba2cbf3f336e8 Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 4 Jun 2026 15:17:57 +0000 Subject: [PATCH 10/17] perf(nnunet): optionally batch sliding-window tiles (tile_batch_size) _run_sub processed one tile per forward pass, leaving the GPU underutilised for small patches. Tiles all share patch_size, so they batch densely: with tile_batch_size>1 they are stacked into a (B, C, *patch) tensor and run in one pass, then scattered back per tile. Mirroring/gaussian/accumulation are unchanged. tile_batch_size=1 (default) takes the original view-based path verbatim, so behaviour and memory are unchanged unless opted in. Threaded through nnUNetPredictor, load_inf_model and run_inference_on_file. Co-Authored-By: Claude Opus 4.8 --- .../segmentation/VibeSeg/inference_nnunet.py | 7 ++++ .../nnUnet_utils/inference_api.py | 5 +++ TPTBox/segmentation/nnUnet_utils/predictor.py | 33 ++++++++++++++----- 3 files changed, 36 insertions(+), 9 deletions(-) diff --git a/TPTBox/segmentation/VibeSeg/inference_nnunet.py b/TPTBox/segmentation/VibeSeg/inference_nnunet.py index 77dd435..6dd3c22 100644 --- a/TPTBox/segmentation/VibeSeg/inference_nnunet.py +++ b/TPTBox/segmentation/VibeSeg/inference_nnunet.py @@ -97,6 +97,7 @@ def run_inference_on_file( memory_factor: int = 160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB memory_max: int = 160000, # in MB, default is 160GB wait_till_gpu_percent_is_free: float = 0.1, + tile_batch_size: int = 1, verbose: bool = True, auto_download: bool = False, cache_model: bool = False, @@ -142,6 +143,10 @@ def run_inference_on_file( memory_max: Hard cap on assumed GPU memory in MB. wait_till_gpu_percent_is_free: Minimum free GPU fraction to require before starting inference. + tile_batch_size: Number of sliding-window tiles to run per network + forward pass. ``1`` (default) keeps the original per-tile behaviour; + larger values batch tiles to better saturate the GPU at the cost of + higher peak memory. verbose: Print progress information. cache_model: If ``True``, keep the loaded predictor in a process-wide cache and reuse it on subsequent calls with identical model and @@ -223,6 +228,7 @@ def run_inference_on_file( memory_factor, memory_max, wait_till_gpu_percent_is_free, + tile_batch_size, ) nnunet = _model_cache.get(cache_key) if cache_model else None if nnunet is None: @@ -237,6 +243,7 @@ def run_inference_on_file( memory_factor=memory_factor, memory_max=memory_max, wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + tile_batch_size=tile_batch_size, ) if cache_model: _model_cache[cache_key] = nnunet diff --git a/TPTBox/segmentation/nnUnet_utils/inference_api.py b/TPTBox/segmentation/nnUnet_utils/inference_api.py index 365a15f..013c9d3 100755 --- a/TPTBox/segmentation/nnUnet_utils/inference_api.py +++ b/TPTBox/segmentation/nnUnet_utils/inference_api.py @@ -37,6 +37,7 @@ def load_inf_model( memory_factor: int = 160, memory_max: int = 160000, wait_till_gpu_percent_is_free: float = 0.3, + tile_batch_size: int = 1, ) -> nnUNetPredictor: """Load and initialise an nnU-Net model predictor from a trained model folder. @@ -69,6 +70,9 @@ def load_inf_model( memory_max: Maximum GPU memory cap in MB (default 160 000 MB = 160 GB). wait_till_gpu_percent_is_free: Fraction of GPU memory that must be free before inference is started. + tile_batch_size: Number of sliding-window tiles per network forward pass. + ``1`` reproduces the original per-tile path; larger values batch + tiles to improve GPU utilisation at higher peak memory. Returns: Initialised ``nnUNetPredictor`` ready for inference. @@ -119,6 +123,7 @@ def load_inf_model( memory_factor=memory_factor, memory_max=memory_max, wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free, + tile_batch_size=tile_batch_size, ) check_name = "checkpoint_final.pth" # if not allow_non_final else "checkpoint_best.pth" try: diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index b30f320..37d6204 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -70,6 +70,10 @@ class nnUNetPredictor: memory_max: Clamp on maximum GPU memory (MB) to assume available. wait_till_gpu_percent_is_free: Fraction of GPU memory that must be free before inference starts. Waits up to 40 minutes. + tile_batch_size: Number of sliding-window tiles to run per network + forward pass. ``1`` (default) reproduces the original per-tile loop + exactly; larger values batch tiles together to improve GPU + utilisation at the cost of higher peak memory. """ def __init__( @@ -87,6 +91,7 @@ def __init__( memory_factor=160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB memory_max: int = 160000, # in MB, default is 160GB wait_till_gpu_percent_is_free=0.3, + tile_batch_size: int = 1, ): self.verbose = verbose self.verbose_preprocessing = verbose_preprocessing @@ -118,6 +123,10 @@ def __init__( self.memory_factor = memory_factor self.memory_max = memory_max self.wait_till_gpu_percent_is_free = wait_till_gpu_percent_is_free + # Number of sliding-window tiles to push through the network in one forward pass. All tiles + # share the same patch_size, so they batch densely. 1 reproduces the original per-tile path + # exactly; larger values raise GPU utilisation (and peak memory) for small patches. + self.tile_batch_size = tile_batch_size def initialize_from_trained_model_folder( self, @@ -758,22 +767,28 @@ def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss: bool return predicted_logits, n_predictions, gaussian, results_device def _run_sub(self, data: torch.Tensor, network, results_device, slicers, pbar: tqdm, addendum: str = ""): - """Iterate over slicers, run inference per tile, and accumulate results.""" + """Iterate over slicers, run inference per tile (optionally batched), and accumulate results.""" try: data = data.to(self.device) # type: ignore predicted_logits, n_predictions, gaussian, results_device = self._allocate(data, results_device, pbar) pbar.desc = f"running prediction {addendum}" prediction = None work_on = None - for sl in slicers: - pbar.update(1) - work_on = data[sl][None] + batch_size = max(1, self.tile_batch_size) + for batch_start in range(0, len(slicers), batch_size): + batch_slicers = slicers[batch_start : batch_start + batch_size] + # batch_size == 1 keeps the original view (no copy); larger batches stack tiles into a + # dense (B, C, *patch) tensor (valid because all tiles share the same patch_size). + work_on = data[batch_slicers[0]][None] if batch_size == 1 else torch.stack([data[sl] for sl in batch_slicers], dim=0) work_on = work_on.to(self.device, non_blocking=False) - prediction = self._internal_maybe_mirror_and_predict(work_on, network=network)[0].to(results_device) - if prediction.shape[0] != predicted_logits.shape[0]: - prediction.squeeze_(0) - predicted_logits[sl] += prediction * gaussian if self.use_gaussian else prediction - n_predictions[sl[1:]] += gaussian if self.use_gaussian else 1 + prediction = self._internal_maybe_mirror_and_predict(work_on, network=network).to(results_device) + for b, sl in enumerate(batch_slicers): + pbar.update(1) + pred = prediction[b] + if pred.shape[0] != predicted_logits.shape[0]: + pred = pred.squeeze(0) + predicted_logits[sl] += pred * gaussian if self.use_gaussian else pred + n_predictions[sl[1:]] += gaussian if self.use_gaussian else 1 return predicted_logits, n_predictions # noqa: TRY300 except RuntimeError: del predicted_logits From 57a8d19bc599a8a4a3f3890a9aa4a2bc6c5f771c Mon Sep 17 00:00:00 2001 From: iback Date: Thu, 4 Jun 2026 18:09:21 +0000 Subject: [PATCH 11/17] bench: add nnU-Net inference timing harness Adds benchmarks/ with a phase-resolved timing tool for run_inference_on_file (load / preprocess / sliding-window predict / postprocess), using CUDA synchronisation, warmup, repeats and peak-memory tracking. It instruments the pipeline by monkeypatching (no library changes) and signature-filters its kwargs so the same file runs against every commit in the optimisation range. - benchmark_nnunet_inference.py: run/compare subcommands; synthetic or real input. - bench_across_commits.sh: replays one config across baseline+each commit and prints per-commit deltas (isolates the always-on changes; the fold fix shows up as the fold_status column flipping from DUPLICATED to lazy-per-fold/distinct). - bench_flag_sweep.sh: sweeps the opt-in flags at HEAD (cache_model needs >1 call, tile_batch_size, max_folds, step_size). - README.md: usage, which tool measures which commit, and caveats. pyproject: exempt benchmarks/ from docstring rules, mirroring the test dirs. Co-Authored-By: Claude Opus 4.8 --- benchmarks/.gitignore | 2 + benchmarks/README.md | 77 +++++ benchmarks/bench_across_commits.sh | 56 ++++ benchmarks/bench_flag_sweep.sh | 43 +++ benchmarks/benchmark_nnunet_inference.py | 387 +++++++++++++++++++++++ pyproject.toml | 2 + 6 files changed, 567 insertions(+) create mode 100644 benchmarks/.gitignore create mode 100644 benchmarks/README.md create mode 100755 benchmarks/bench_across_commits.sh create mode 100755 benchmarks/bench_flag_sweep.sh create mode 100644 benchmarks/benchmark_nnunet_inference.py diff --git a/benchmarks/.gitignore b/benchmarks/.gitignore new file mode 100644 index 0000000..8321588 --- /dev/null +++ b/benchmarks/.gitignore @@ -0,0 +1,2 @@ +results/ +.bench_cache/ diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 0000000..72ccc9b --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,77 @@ +# nnU-Net inference timing harness + +Tools to measure the speed impact of the inference optimisations in +`TPTBox/segmentation/nnUnet_utils/` and `TPTBox/segmentation/VibeSeg/inference_nnunet.py`. + +## Files + +- `benchmark_nnunet_inference.py` — the measurement tool. Splits + `run_inference_on_file` into phases (**load → preprocess → sliding-window + predict → postprocess → other**) with CUDA synchronisation, warmup, repeats + and peak-memory tracking. It instruments the pipeline by monkeypatching (no + library changes) and drops any CLI flag a given commit does not yet support, + so the *same file* runs against every commit. +- `bench_across_commits.sh` — checks out the baseline + each optimisation commit, + runs one fixed config, and prints per-commit deltas. +- `bench_flag_sweep.sh` — stays on the current commit and sweeps the opt-in flags. + +## Requirements + +- An **editable** TPTBox install (`poetry install`) so that checking out a commit + changes the imported code. +- Model weights for the chosen `--dataset-id` (default `100`). The first run may + download them; that happens during warmup and is excluded from the numbers. +- A GPU for `--device cuda` (the default). `--synthetic` removes the need for a + real input image — only the model is required. + +## Quick start + +```bash +# 1) per-commit deltas (always-on changes), synthetic input +benchmarks/bench_across_commits.sh --device cuda --repeats 5 + +# 2) opt-in flag effects at HEAD +benchmarks/bench_flag_sweep.sh --device cuda --repeats 5 + +# 3) one config by hand (e.g. on your real data) +python benchmarks/benchmark_nnunet_inference.py run \ + --dataset-id 100 --input water.nii.gz fat.nii.gz \ + --tile-batch-size 4 --repeats 5 --json /tmp/tb4.json +``` + +## Which tool measures which commit + +| Commit | Change | How it shows up | +|---|---|---| +| `cuDNN/TF32` | autotune + TF32 | `bench_across_commits.sh`: `predict` drops at this commit (after warmup) | +| `inference_mode` | no_grad → inference_mode | `bench_across_commits.sh`: small `predict`/`peak_mem` drop | +| `empty_cache` | drop per-fold cache clears | `bench_across_commits.sh`: `predict` drop, larger with more folds | +| `fold fix` | repair `loaded_networks` | `fold_status` column flips from `DUPLICATED(...)` to `lazy-per-fold`/`distinct`. Time is ~unchanged (correctness fix) — the old code already paid for N passes. | +| `cache_model` | persistent model cache | `bench_flag_sweep.sh` → `sweep_cache_model.json`: `load: first` is large, `load: steady-median` ≈ 0 | +| `tile_batch_size` | batch tiles per forward | `bench_flag_sweep.sh`: `tile_batch_4/8` reduce `forward_calls` and `predict`, raise `peak_mem` | + +Biggest raw-speed levers (quality trade-off): `--max-folds 1` (≈ folds×) and +`--step-size 0.7` (fewer tiles). + +## Reading the output + +- **`total: first` vs `steady-median`** — the first call pays lazy CUDA init, + cuDNN autotuning and (without caching) model load; steady-median is the + representative per-image cost. +- **`forward_calls`** — number of network forward passes = `folds × tiles` (÷ batch). + Drops with `--tile-batch-size`, scales with folds. +- **`fold_status`** — `DUPLICATED(...)` means the ensemble was silently collapsed + to one fold (pre-fix); `lazy-per-fold` (multi-fold) / `distinct` is correct. +- **`peak_mem_mb`** — `torch.cuda.max_memory_allocated`; watch this rise with + `--tile-batch-size`. + +## Notes / caveats + +- Synthetic input is random noise with the model's channel count and a chosen + shape/spacing — fine for *timing* (the network cost is content-independent), not + for assessing segmentation quality. It is cached in a temp dir and reused across + commits so the workload is identical. +- TF32 and the fold fix change the numeric output between commits, so outputs are + not bit-identical across the range — expected, not a harness bug. +- `bench_across_commits.sh` restores your original branch on exit (even on error). +- Results and the synthetic cache (`results/`, `.bench_cache/`) are git-ignored. diff --git a/benchmarks/bench_across_commits.sh b/benchmarks/bench_across_commits.sh new file mode 100755 index 0000000..18e6f2f --- /dev/null +++ b/benchmarks/bench_across_commits.sh @@ -0,0 +1,56 @@ +#!/usr/bin/env bash +# Replay the benchmark across the optimisation commits and tabulate per-commit deltas. +# +# Isolates the ALWAYS-ON changes (cuDNN/TF32, inference_mode, empty_cache, fold fix): each commit +# is checked out, run with one fixed config, and compared. The opt-in commits (cache_model, +# tile_batch_size) are no-ops under the default config and show ~0 here -- use bench_flag_sweep.sh +# for those. +# +# Usage: +# benchmarks/bench_across_commits.sh [extra args forwarded to `benchmark ... run`] +# BASELINE= benchmarks/bench_across_commits.sh --device cuda --repeats 5 +# With no extra args it uses a synthetic input. Requires a clean working tree and an editable +# TPTBox install (so checking out a commit changes the imported code). +set -euo pipefail + +REPO_ROOT="$(git rev-parse --show-toplevel)" +cd "$REPO_ROOT" + +if [ -n "$(git status --porcelain)" ]; then + echo "error: working tree is not clean. Commit or stash before running." >&2 + exit 1 +fi + +BASELINE="${BASELINE:-3906165}" +ORIG_REF="$(git symbolic-ref --short -q HEAD || git rev-parse HEAD)" +TMP_BENCH="$(mktemp -t bench_nnunet_XXXX.py)" +cp "$REPO_ROOT/benchmarks/benchmark_nnunet_inference.py" "$TMP_BENCH" +RESULTS="$REPO_ROOT/benchmarks/results" +mkdir -p "$RESULTS" + +cleanup() { git checkout -q "$ORIG_REF"; rm -f "$TMP_BENCH"; } +trap cleanup EXIT + +# baseline + every commit since it, in chronological order +mapfile -t COMMITS < <(printf '%s\n' "$BASELINE"; git rev-list --reverse "${BASELINE}..${ORIG_REF}") + +if [ $# -eq 0 ]; then RUN_ARGS=(--synthetic); else RUN_ARGS=("$@"); fi + +JSONS=() +i=0 +for sha in "${COMMITS[@]}"; do + short="$(git rev-parse --short "$sha")" + out="$RESULTS/$(printf '%02d' "$i")_${short}.json" + echo "================================================================" + echo "[$i] $short : $(git log -1 --format=%s "$sha")" + echo "================================================================" + git checkout -q "$sha" + python "$TMP_BENCH" run "${RUN_ARGS[@]}" --json "$out" + JSONS+=("$out") + i=$((i + 1)) +done + +git checkout -q "$ORIG_REF" +echo +echo "################ COMPARISON ################" +python "$TMP_BENCH" compare "${JSONS[@]}" diff --git a/benchmarks/bench_flag_sweep.sh b/benchmarks/bench_flag_sweep.sh new file mode 100755 index 0000000..4418e58 --- /dev/null +++ b/benchmarks/bench_flag_sweep.sh @@ -0,0 +1,43 @@ +#!/usr/bin/env bash +# Measure the OPT-IN optimisations at the current commit by sweeping their flags in one process. +# (cache_model needs >1 call to show amortisation, so it cannot be measured by the cross-commit +# driver -- that is what this script is for.) +# +# Usage: +# benchmarks/bench_flag_sweep.sh [extra args forwarded to `benchmark ... run`] +# benchmarks/bench_flag_sweep.sh --input water.nii.gz fat.nii.gz --device cuda +# With no extra args it uses a synthetic input. +set -euo pipefail + +REPO_ROOT="$(git rev-parse --show-toplevel)" +BENCH="$REPO_ROOT/benchmarks/benchmark_nnunet_inference.py" +RESULTS="$REPO_ROOT/benchmarks/results" +mkdir -p "$RESULTS" + +if [ $# -eq 0 ]; then COMMON=(--synthetic); else COMMON=("$@"); fi + +run() { # name : extra flags... + local name="$1"; shift + echo "================ $name ================" + python "$BENCH" run "${COMMON[@]}" "$@" --json "$RESULTS/sweep_${name}.json" +} + +run baseline +run tile_batch_4 --tile-batch-size 4 +run tile_batch_8 --tile-batch-size 8 +run cache_model --cache-model --repeats 6 +run max_folds_1 --max-folds 1 +run step_0p7 --step-size 0.7 + +echo +echo "################ COMPARISON ################" +python "$BENCH" compare \ + "$RESULTS/sweep_baseline.json" \ + "$RESULTS/sweep_tile_batch_4.json" \ + "$RESULTS/sweep_tile_batch_8.json" \ + "$RESULTS/sweep_cache_model.json" \ + "$RESULTS/sweep_max_folds_1.json" \ + "$RESULTS/sweep_step_0p7.json" +echo +echo "cache_model: compare 'load: first' vs 'load: steady-median' in sweep_cache_model.json --" +echo "the median load should collapse to ~0 once the model is cached." diff --git a/benchmarks/benchmark_nnunet_inference.py b/benchmarks/benchmark_nnunet_inference.py new file mode 100644 index 0000000..7c674e8 --- /dev/null +++ b/benchmarks/benchmark_nnunet_inference.py @@ -0,0 +1,387 @@ +#!/usr/bin/env python +"""Timing harness for the VibeSeg / nnU-Net inference pipeline. + +Measures ``run_inference_on_file`` split into phases (model load, preprocess, +sliding-window forward, postprocess) with correct CUDA synchronisation, warmup, +repeats and peak-memory tracking. It is written so the *same file* runs against +any commit in the optimisation range: arguments that a given commit does not yet +support are silently dropped (see :func:`supported_kwargs`). + +Two ways to use it: + +1. Per-commit deltas (always-on changes: cuDNN/TF32, inference_mode, + empty_cache, fold fix). Drive it with ``bench_across_commits.sh``. + +2. Flag sweep at HEAD (opt-in changes: cache_model, tile_batch_size, max_folds, + step_size). Drive it with ``bench_flag_sweep.sh`` or call ``run`` directly + with the relevant flags. + +Examples:: + + # synthetic input (only the model weights are required), single config + python benchmark_nnunet_inference.py run --dataset-id 100 --synthetic \ + --shape 320 320 96 --repeats 5 --json /tmp/head.json + + # real input(s); multi-channel models take several --input paths + python benchmark_nnunet_inference.py run --dataset-id 100 \ + --input water.nii.gz fat.nii.gz --tile-batch-size 4 --cache-model + + # compare result JSONs produced by several runs/commits + python benchmark_nnunet_inference.py compare results/*.json + +Caveats: TPTBox must be importable from the *working tree* (an editable +``poetry install`` does this), otherwise checking out a commit will not change +the measured code. The first run may download model weights; that happens during +warmup and is excluded from the reported numbers. +""" + +from __future__ import annotations + +import argparse +import functools +import inspect +import json +import statistics +import sys +import time +from pathlib import Path + +import numpy as np +import torch + +# --- phase timers / counters populated by monkeypatching (no library changes) --------------- +TIMINGS: dict[str, float] = {} +COUNTERS: dict[str, object] = {} + + +def _sync(device: str) -> None: + if device == "cuda": + torch.cuda.synchronize() + + +def _reset_collectors() -> None: + TIMINGS.clear() + COUNTERS.clear() + COUNTERS["forward_calls"] = 0 + COUNTERS["tiles"] = 0 + COUNTERS["folds"] = None + COUNTERS["fold_status"] = "n/a" + + +def _timed(name: str, device: str): + """Wrap a callable so its synchronised wall time accumulates into ``TIMINGS[name]``.""" + + def deco(fn): + @functools.wraps(fn) + def wrapper(*a, **k): + _sync(device) + t = time.perf_counter() + try: + return fn(*a, **k) + finally: + _sync(device) + TIMINGS[name] = TIMINGS.get(name, 0.0) + (time.perf_counter() - t) + + return wrapper + + return deco + + +def install_patches(device: str) -> None: + """Monkeypatch the inference stack to record per-phase timings and forward counts. + + Patches the names in the modules where they are actually *called* so the + instrumentation works regardless of how each commit imports them. + """ + from TPTBox.segmentation.nnUnet_utils import inference_api, predictor + + # model load (run_inference_on_file does `from inference_api import load_inf_model` at call + # time, so patching the attribute here is picked up by that local import). + inference_api.load_inf_model = _timed("load", device)(inference_api.load_inf_model) + + # postprocess: convert_predicted_logits_... is called via predictor's own namespace binding. + predictor.convert_predicted_logits_to_segmentation_with_correct_shape = _timed("postprocess", device)( + predictor.convert_predicted_logits_to_segmentation_with_correct_shape + ) + + P = predictor.nnUNetPredictor + P.predict_single_npy_array = _timed("single_npy", device)(P.predict_single_npy_array) + P.predict_logits_from_preprocessed_data = _timed("predict", device)(_fold_probe(P.predict_logits_from_preprocessed_data)) + P._internal_maybe_mirror_and_predict = _forward_counter(P._internal_maybe_mirror_and_predict) + + +def _fold_probe(fn): + """Record fold count and whether the loaded_networks cache is correct (surfaces the #6 bug).""" + + @functools.wraps(fn) + def wrapper(self, *a, **k): + params = getattr(self, "list_of_parameters", None) + COUNTERS["folds"] = len(params) if params is not None else None + loaded = getattr(self, "loaded_networks", None) + if loaded is None: + COUNTERS["fold_status"] = "lazy-per-fold" # correct: weights swapped per fold + else: + ids = {id(n) for n in loaded} + COUNTERS["fold_status"] = "distinct" if len(ids) == len(loaded) else f"DUPLICATED({len(loaded)}->{len(ids)})" + return fn(self, *a, **k) + + return wrapper + + +def _forward_counter(fn): + @functools.wraps(fn) + def wrapper(self, x, *a, **k): + COUNTERS["forward_calls"] = COUNTERS.get("forward_calls", 0) + 1 + COUNTERS["tiles"] = COUNTERS.get("tiles", 0) + int(x.shape[0]) + return fn(self, x, *a, **k) + + return wrapper + + +# --- input handling -------------------------------------------------------------------------- +def make_synthetic(shape: list[int], spacing: list[float], channels: int, seed: int, cache_dir: Path) -> list[str]: + """Create (and cache) ``channels`` random NIfTI volumes with the given shape/spacing.""" + import nibabel as nib + + cache_dir.mkdir(parents=True, exist_ok=True) + rng = np.random.default_rng(seed) + affine = np.diag([*spacing, 1.0]).astype(float) + paths = [] + tag = f"{channels}ch_{'x'.join(map(str, shape))}_sp{'-'.join(str(s) for s in spacing)}_s{seed}" + for c in range(channels): + p = cache_dir / f"synthetic_{tag}_ch{c}.nii.gz" + if not p.exists(): + arr = (rng.standard_normal(tuple(shape)).astype(np.float32) * 200.0) + 100.0 + nib.save(nib.Nifti1Image(arr, affine), str(p)) + paths.append(str(p)) + return paths + + +def resolve_channels(dataset_id: int | None, fallback: int) -> int: + if dataset_id is None: + return fallback + try: + from TPTBox.segmentation.VibeSeg.inference_nnunet import get_ds_info + + ds = get_ds_info(dataset_id, exit_one_fail=False) + if ds and "channel_names" in ds: + return len(ds["channel_names"]) + except Exception as e: # noqa: BLE001 - best effort, fall back to the user value + print(f"[warn] could not read dataset.json for channel count ({e}); using --channels={fallback}") + return fallback + + +# --- core measurement ------------------------------------------------------------------------ +def supported_kwargs(fn, kwargs: dict) -> dict: + """Keep only kwargs the target accepts, so the harness runs on commits that lack newer flags.""" + params = inspect.signature(fn).parameters + if any(p.kind == p.VAR_KEYWORD for p in params.values()): + return dict(kwargs) + keep = {k: v for k, v in kwargs.items() if k in params} + dropped = sorted(set(kwargs) - set(keep)) + if dropped: + print(f"[info] this commit ignores unsupported args: {dropped}") + return keep + + +def run_once(idx, inputs: list[str], call_kwargs: dict, device: str) -> dict: + from TPTBox import to_nii + from TPTBox.segmentation.VibeSeg.inference_nnunet import run_inference_on_file + + _reset_collectors() + niis = [to_nii(p) for p in inputs] # reload each repeat so in-place ops never bleed across runs + if device == "cuda": + torch.cuda.reset_peak_memory_stats() + _sync(device) + t0 = time.perf_counter() + run_inference_on_file(idx, niis, **supported_kwargs(run_inference_on_file, call_kwargs)) + _sync(device) + total = time.perf_counter() - t0 + + phases = { + "load": TIMINGS.get("load", 0.0), + "preprocess": max(0.0, TIMINGS.get("single_npy", 0.0) - TIMINGS.get("predict", 0.0) - TIMINGS.get("postprocess", 0.0)), + "predict": TIMINGS.get("predict", 0.0), + "postprocess": TIMINGS.get("postprocess", 0.0), + } + phases["other"] = max(0.0, total - sum(phases.values())) # NII I/O, reorient, rescale-back, save + return { + "total": total, + "phases": phases, + "peak_mem_mb": (torch.cuda.max_memory_allocated() / 1e6) if device == "cuda" else 0.0, + "forward_calls": COUNTERS["forward_calls"], + "tiles": COUNTERS["tiles"], + "folds": COUNTERS["folds"], + "fold_status": COUNTERS["fold_status"], + } + + +def summarize(repeats: list[dict]) -> dict: + """Median over steady-state repeats (repeat 0 kept separately as the cold/first call).""" + steady = repeats[1:] if len(repeats) > 1 else repeats + med = lambda key: statistics.median(r[key] for r in steady) # noqa: E731 + phase_keys = repeats[0]["phases"].keys() + return { + "total_first": repeats[0]["total"], + "total_median": med("total"), + "phases_median": {k: statistics.median(r["phases"][k] for r in steady) for k in phase_keys}, + "load_first": repeats[0]["phases"]["load"], + "load_median": statistics.median(r["phases"]["load"] for r in steady), + "peak_mem_mb": max(r["peak_mem_mb"] for r in repeats), + "forward_calls": repeats[0]["forward_calls"], + "tiles": repeats[0]["tiles"], + "folds": repeats[0]["folds"], + "fold_status": repeats[0]["fold_status"], + } + + +def commit_hash() -> str: + import subprocess + + try: + return subprocess.check_output(["git", "rev-parse", "--short", "HEAD"], text=True).strip() # noqa: S607 + except Exception: # noqa: BLE001 + return "unknown" + + +# --- subcommands ----------------------------------------------------------------------------- +def cmd_run(args: argparse.Namespace) -> None: + device = args.device + install_patches(device) + + channels = resolve_channels(args.dataset_id if args.model_path is None else None, args.channels) + if args.synthetic: + inputs = make_synthetic(args.shape, args.spacing, channels, args.seed, Path(args.cache_dir)) + else: + if not args.input: + sys.exit("error: provide --input PATH [PATH ...] or use --synthetic") + inputs = args.input + idx = Path(args.model_path) if args.model_path else args.dataset_id + + call_kwargs = { + "out_file": None, + "override": True, + "ddevice": device, + "gpu": args.gpu, + "verbose": False, + "max_folds": args.max_folds, + "step_size": args.step_size, + "tile_batch_size": args.tile_batch_size, + "cache_model": args.cache_model, + "keep_size": args.keep_size, + "padd": args.padd, + } + + print(f"warmup ({args.warmup}) ...") + for _ in range(max(0, args.warmup)): + run_once(idx, inputs, call_kwargs, device) + + repeats = [] + for i in range(args.repeats): + r = run_once(idx, inputs, call_kwargs, device) + repeats.append(r) + print( + f" repeat {i}: total={r['total']:.3f}s predict={r['phases']['predict']:.3f}s " + f"load={r['phases']['load']:.3f}s peak={r['peak_mem_mb']:.0f}MB forwards={r['forward_calls']}" + ) + + summary = summarize(repeats) + result = { + "commit": commit_hash(), + "device": device, + "inputs": inputs, + "config": {k: call_kwargs[k] for k in ("max_folds", "step_size", "tile_batch_size", "cache_model", "keep_size", "padd")}, + "repeats": repeats, + "summary": summary, + } + _print_summary(result) + if args.json: + Path(args.json).parent.mkdir(parents=True, exist_ok=True) + Path(args.json).write_text(json.dumps(result, indent=2)) + print(f"\nwrote {args.json}") + + +def _print_summary(result: dict) -> None: + s = result["summary"] + print(f"\n=== {result['commit']} | device={result['device']} | config={result['config']} ===") + print(f"folds={s['folds']} fold_status={s['fold_status']} forward_calls={s['forward_calls']} tiles={s['tiles']}") + print(f"peak_mem={s['peak_mem_mb']:.0f} MB") + print(f"total: first={s['total_first']:.3f}s steady-median={s['total_median']:.3f}s") + print(f"load: first={s['load_first']:.3f}s steady-median={s['load_median']:.3f}s") + print("phase medians (steady-state):") + for k, v in s["phases_median"].items(): + print(f" {k:<11s} {v:.3f}s") + + +def cmd_compare(args: argparse.Namespace) -> None: + rows = [] + for f in args.files: + d = json.loads(Path(f).read_text()) + s = d["summary"] + rows.append((d["commit"], s, d.get("config", {}))) + + hdr = f"{'commit':<10} {'total_med':>10} {'predict':>9} {'load_1st':>9} {'peak_MB':>9} {'fwd':>6} {'folds':>5} fold_status" + print(hdr) + print("-" * len(hdr)) + base = None + prev = None + for commit, s, _cfg in rows: + tm = s["total_median"] + line = ( + f"{commit:<10} {tm:>10.3f} {s['phases_median']['predict']:>9.3f} {s['load_first']:>9.3f} " + f"{s['peak_mem_mb']:>9.0f} {s['forward_calls']:>6} {s['folds']!s:>5} {s['fold_status']}" + ) + print(line) + base = base if base is not None else tm + if prev is not None: + d_prev = (tm - prev) / prev * 100 + d_base = (tm - base) / base * 100 + print(f"{'':<10} {'':>10} {'':>9} {'':>9} {'':>9} {'':>6} {'':>5} Δprev={d_prev:+.1f}% Δbaseline={d_base:+.1f}%") + prev = tm + print( + "\nNote: opt-in commits (cache_model, tile_batch_size) show ~0 here under the default config; " + "measure them with bench_flag_sweep.sh at HEAD." + ) + + +def build_parser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + sub = p.add_subparsers(dest="cmd", required=True) + + r = sub.add_parser("run", help="benchmark one configuration") + r.add_argument("--dataset-id", type=int, default=100) + r.add_argument("--model-path", default=None, help="explicit model folder; overrides --dataset-id") + r.add_argument("--input", nargs="+", default=None, help="input NIfTI path(s); one per model channel") + r.add_argument("--synthetic", action="store_true", help="generate random input(s) instead of reading files") + r.add_argument("--shape", type=int, nargs=3, default=[320, 320, 96]) + r.add_argument("--spacing", type=float, nargs=3, default=[1.40625, 1.40625, 3.0]) + r.add_argument("--channels", type=int, default=1, help="used only if dataset.json channel count is unavailable") + r.add_argument("--seed", type=int, default=1234) + r.add_argument("--cache-dir", default=str(Path(__file__).parent / ".bench_cache")) + r.add_argument("--device", choices=["cuda", "cpu", "mps"], default="cuda") + r.add_argument("--gpu", type=int, default=0) + r.add_argument("--repeats", type=int, default=5) + r.add_argument("--warmup", type=int, default=1) + # optimisation knobs (dropped automatically on commits that predate them) + r.add_argument("--max-folds", type=int, default=None) + r.add_argument("--step-size", type=float, default=0.5) + r.add_argument("--tile-batch-size", type=int, default=1) + r.add_argument("--cache-model", action="store_true") + r.add_argument("--keep-size", action="store_true") + r.add_argument("--padd", type=int, default=0) + r.add_argument("--json", default=None, help="write structured results to this path") + r.set_defaults(func=cmd_run) + + c = sub.add_parser("compare", help="tabulate result JSONs in the given order") + c.add_argument("files", nargs="+") + c.set_defaults(func=cmd_compare) + return p + + +def main() -> None: + args = build_parser().parse_args() + args.func(args) + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index d6f222e..c5356aa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -210,6 +210,8 @@ convention = "google" # Test files: don't require docstrings on test functions / helper utilities "unit_tests/**" = ["D101", "D102", "D103", "D205", "D415", "ANN201"] "TPTBox/tests/**" = ["D101", "D102", "D103", "D205", "D415", "ANN201"] +# Dev tooling (standalone scripts): same relaxed docstring rules, not an importable package +"benchmarks/**" = ["D101", "D102", "D103", "D205", "D415", "ANN201", "INP001"] [tool.ruff.lint.mccabe] # Flag errors (`C901`) whenever the complexity level exceeds 5. From 365ab8b22ba806bff9c15a0d6fbd09e0cc789421 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Mon, 8 Jun 2026 11:39:01 +0000 Subject: [PATCH 12/17] bug fixes --- TPTBox/core/bids_files.py | 6 ++- TPTBox/core/nii_wrapper.py | 37 ++++++++++--------- TPTBox/core/np_utils.py | 24 +++++++++++- .../nnUnet_utils/export_prediction.py | 2 +- 4 files changed, 49 insertions(+), 20 deletions(-) diff --git a/TPTBox/core/bids_files.py b/TPTBox/core/bids_files.py index 4fdd4bf..f459b95 100755 --- a/TPTBox/core/bids_files.py +++ b/TPTBox/core/bids_files.py @@ -486,7 +486,7 @@ def enumerate_subjects(self, sort: bool = False, shuffle: bool = False) -> list[ return s return self.subjects.items() # type: ignore - def iter_subjects(self, sort: bool = False) -> list[tuple[str, Subject_Container]]: + def iter_subjects(self, sort: bool = False, shuffle: bool = False) -> list[tuple[str, Subject_Container]]: """Iterate over all subjects (alias for :meth:`enumerate_subjects` without shuffle). Args: @@ -498,6 +498,10 @@ def iter_subjects(self, sort: bool = False) -> list[tuple[str, Subject_Container """ if sort: return sorted(self.subjects.items()) + if shuffle: + s = list(self.subjects.items()) + random.shuffle(s) + return s return self.subjects.items() # type: ignore def __len__(self): diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index cf4dd39..7b8a95d 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -1180,33 +1180,36 @@ def resample_from_to(self, to_vox_map:Image_Reference|Has_Grid|tuple[SHAPE,AFFIN if mapping.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0): log.print(f"resample_from_to skipped; already in space: {self}",verbose=verbose) return self if inplace else self.copy() - m1 = mapping if mapping.orientation == self.orientation else mapping.make_empty_POI().reorient(self.orientation) if m1.assert_affine(self,raise_error=False,origin_tolerance=0.00001,error_tolerance=0.00001,shape_tolerance=0): log.print(f"resample_from_to only need reorientation; {self.orientation}",verbose=verbose) ret = self.reorient(mapping.orientation,inplace=inplace) ret.affine = mapping.affine #remove floating point error return ret - if self.orientation == mapping.orientation and np.allclose(self.zoom , mapping.zoom, atol=1e-6): - shift = (np.array(self.origin) - np.array(m1.origin)) / np.array(m1.zoom) - if np.allclose(shift, np.round(shift), atol=1e-6): - s = self.reorient(mapping.orientation,inplace=inplace) # noqa: PLW0642 - shift = (np.array(self.origin) - np.array(mapping.origin)) / np.array(mapping.zoom) - shift = np.round(shift).astype(int) - dst_shape = np.array(mapping.shape) - src_shape = np.array(s.shape) - # padding before = how much dst starts before src - pad_before = shift - # padding after = remaining dst size after src - pad_after = dst_shape-shift-src_shape - pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after)) - ret = s.apply_pad(pad, mode=mode,inplace=inplace,verbose=verbose) - + if np.allclose(self.zoom, m1.zoom, atol=1e-6): + s = self.reorient(mapping.orientation, inplace=inplace) + # Compute voxel offset directly from the affines after both + # images are in the same orientation. This is robust to axis + # permutations and flips. + voxel_offset = np.linalg.inv(mapping.affine) @ s.affine @ np.array([0, 0, 0, 1]) + shift = np.round(voxel_offset[:3]).astype(int) + + dst_shape = np.array(mapping.shape) + src_shape = np.array(s.shape) + # padding before = how much dst starts before src + pad_before = shift + # padding after = remaining dst size after src + pad_after = dst_shape - shift - src_shape + pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after)) + try: + ret = s.apply_pad(pad,mode=mode,inplace=inplace,verbose=verbose) valid = ret.assert_affine(mapping,raise_error=False,origin_tolerance=0.0001,error_tolerance=0.0001,shape_tolerance=0) if valid: log.print(f"resample_from_to only needs padding/cropping {pad}",verbose=verbose) - ret.affine = mapping.affine #remove floating point error + ret.affine = mapping.affine # remove floating point error return ret + except ValueError as e: + log.warning("Padding failed.",e,verbose=verbose) assert mapping is not None diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 826a02a..cd808be 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -160,7 +160,7 @@ def np_count_nonzero(arr: np.ndarray) -> int: return np.count_nonzero(arr) -def np_unique(arr: np.ndarray) -> list[int]: +def old_np_unique(arr: np.ndarray) -> list[int]: """Returns each existing label in the array (including zero!). Uses cc3d statistics for unsigned-integer arrays for speed, and falls back @@ -181,7 +181,29 @@ def np_unique(arr: np.ndarray) -> list[int]: return list(np.unique(arr)) +def np_unique(arr: np.ndarray) -> list[int]: + if np.issubdtype(arr.dtype, np.unsignedinteger): + # bincount is O(max_val) but ~5-10x faster than np.unique for dense label arrays + max_val = int(arr.max()) + if max_val < 2**20: # ~1M labels threshold — bincount stays fast + counts = np.bincount(arr.ravel()) + return list(np.where(counts > 0)[0]) + # For sparse label spaces fall back to np.unique + return list(np.unique(arr)) + + def np_unique_withoutzero(arr: UINTARRAY) -> list[int]: + if np.issubdtype(arr.dtype, np.unsignedinteger): + max_val = int(arr.max()) + if max_val == 0: + return [] + if max_val < 2**20: + counts = np.bincount(arr.ravel()) + return list(np.where(counts[1:] > 0)[0] + 1) + return [i for i in np.unique(arr) if i != 0] + + +def old_np_unique_withoutzero(arr: UINTARRAY) -> list[int]: """Returns each existing non-zero label in the array (excluding background zero). Args: diff --git a/TPTBox/segmentation/nnUnet_utils/export_prediction.py b/TPTBox/segmentation/nnUnet_utils/export_prediction.py index 36ca786..b49c371 100755 --- a/TPTBox/segmentation/nnUnet_utils/export_prediction.py +++ b/TPTBox/segmentation/nnUnet_utils/export_prediction.py @@ -6,7 +6,7 @@ import numpy as np import torch from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice -from batchgenerators.utilities.file_and_folder_operations import isfile, load_json, save_pickle +from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle from nnunetv2.utilities.label_handling.label_handling import LabelManager from TPTBox.segmentation.nnUnet_utils.plans_handler import ConfigurationManager, PlansManager From a7efb703528003508832f4968f1a486be5772c42 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Tue, 9 Jun 2026 08:59:46 +0000 Subject: [PATCH 13/17] fix bug made by claude --- TPTBox/segmentation/nnUnet_utils/predictor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index a7b6c4b..f926172 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -781,6 +781,7 @@ def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss: bool def _run_sub(self, data: torch.Tensor, network, results_device, slicers, pbar: tqdm, addendum: str = ""): """Iterate over slicers, run inference per tile (optionally batched), and accumulate results.""" + slicers = list(slicers) try: data = data.to(self.device) # type: ignore predicted_logits, n_predictions, gaussian, results_device = self._allocate(data, results_device, pbar) From 7b49086aa174b68123bf287eb6195f170f786a03 Mon Sep 17 00:00:00 2001 From: ga84mun Date: Wed, 10 Jun 2026 08:56:01 +0000 Subject: [PATCH 14/17] speed up argmax. Yes this much code is needed for this. --- .../nnUnet_utils/export_prediction.py | 148 ++++++++++++++++-- TPTBox/segmentation/nnUnet_utils/predictor.py | 7 +- 2 files changed, 142 insertions(+), 13 deletions(-) diff --git a/TPTBox/segmentation/nnUnet_utils/export_prediction.py b/TPTBox/segmentation/nnUnet_utils/export_prediction.py index b49c371..712effe 100755 --- a/TPTBox/segmentation/nnUnet_utils/export_prediction.py +++ b/TPTBox/segmentation/nnUnet_utils/export_prediction.py @@ -8,9 +8,145 @@ from acvl_utils.cropping_and_padding.bounding_boxes import bounding_box_to_slice from batchgenerators.utilities.file_and_folder_operations import load_json, save_pickle from nnunetv2.utilities.label_handling.label_handling import LabelManager +from tqdm import tqdm from TPTBox.segmentation.nnUnet_utils.plans_handler import ConfigurationManager, PlansManager +SAFETY_FACTOR = 0.5 # only use 50% of VRAM + + +def _argmax_with_gpu_fallback(predicted_logits: torch.Tensor | np.ndarray, device: torch.device, chunk_size: int = 64) -> np.ndarray: + """Computes argmax(0). + + Tiered argmax: + 1. Full array on GPU + 2. Chunked on GPU (if full array doesn't fit) + 3. Chunked on CPU (if even a single chunk doesn't fit) + """ + from TPTBox.segmentation.nnUnet_utils.predictor import empty_cache + + empty_cache(device) + + def _get_free_vram(device: torch.device) -> int: + try: + """Returns free VRAM in bytes.""" + free, _ = torch.cuda.mem_get_info(device) + return int(free * SAFETY_FACTOR) + except Exception: + return 0 + + def _array_bytes(shape: tuple, dtype: torch.dtype = torch.float16) -> int: + n_elements = 1 + for s in shape: + n_elements *= s + return n_elements * torch.finfo(dtype).bits // 8 + + def _to_cpu_tensor(arr: torch.Tensor | np.ndarray) -> torch.Tensor: + if isinstance(arr, np.ndarray): + return torch.from_numpy(arr) + return arr.cpu() + + def _chunked_argmax_gpu(t: torch.Tensor, device: torch.device) -> np.ndarray: + X = t.shape[1] + out = np.empty(t.shape[1:], dtype=np.int16) + for x in tqdm(range(0, X, chunk_size), "argmax gpu"): + chunk = t[:, x : x + chunk_size].to(device) + out[x : x + chunk_size] = torch.argmax(chunk, dim=0).cpu().numpy() + del chunk + empty_cache(device) + return out + + def _chunked_argmax_cpu(t: torch.Tensor | np.ndarray) -> np.ndarray: + arr = t.numpy() if isinstance(t, torch.Tensor) else t + if not arr.flags["C_CONTIGUOUS"]: + arr = np.ascontiguousarray(arr) + X = arr.shape[1] + out = np.empty(arr.shape[1:], dtype=np.int16) + for x in tqdm(range(0, X, chunk_size), "argmax cpu"): + out[x : x + chunk_size] = arr[:, x : x + chunk_size].argmax(0) + return out + + t = _to_cpu_tensor(predicted_logits) + + if device is None or not torch.cuda.is_available(): + return _chunked_argmax_cpu(t) + + full_bytes = _array_bytes(t.shape) + free_vram = _get_free_vram(device) + + print(f"[argmax] array: {full_bytes / 1e6:.1f} MB, VRAM: {free_vram / 1e6:.1f} MB") + + # Tier 1: full GPU + if full_bytes <= free_vram or device.type == "mps": + try: + return torch.argmax(t.to(device), dim=0).cpu().numpy().astype(np.int16) + except torch.cuda.OutOfMemoryError: + print("[argmax] full GPU OOM despite estimate, trying chunked GPU") + empty_cache(device) + except Exception as e: + print(e) + empty_cache(device) + + for i in range(10): + chunk_shape = (t.shape[0], min(max(int(chunk_size / 2**i), 1), t.shape[1]), *t.shape[2:]) + chunk_bytes = _array_bytes(chunk_shape) + if chunk_bytes <= free_vram: + chunk_size = max(int(chunk_size / 2**i), 1) + break + print(f"[argmax] array chunk: {chunk_bytes / 1e6:.1f} MB, VRAM: {free_vram / 1e6:.1f} MB, {chunk_size=}") + + # Tier 2: chunked GPU + if chunk_bytes <= free_vram: + print("[argmax] using chunked GPU") + try: + return _chunked_argmax_gpu(t, device) + except torch.cuda.OutOfMemoryError: + print("[argmax] chunked GPU OOM despite estimate, falling back to CPU") + empty_cache(device) + else: + print("[argmax] chunk too large for VRAM, falling back to CPU") + + # Tier 3: chunked CPU + return _chunked_argmax_cpu(t) + + +@torch.inference_mode() +def convert_probabilities_to_segmentation(self, predicted_probabilities: np.ndarray | torch.Tensor, device, chunk_size=64) -> np.ndarray: + """Assumes that inference_nonlinearity was already applied! + + predicted_probabilities has to have shape (c, x, y(, z)) where c is the number of classes/regions + """ + if not isinstance(predicted_probabilities, (np.ndarray, torch.Tensor)): + raise RuntimeError(f"Unexpected input type. Expected np.ndarray or torch.Tensor, got {type(predicted_probabilities)}") # noqa: TRY004 + + if self.has_regions: + assert self.regions_class_order is not None, "if region-based training is requested then you need to define regions_class_order!" + # check correct number of outputs + assert predicted_probabilities.shape[0] == self.num_segmentation_heads, ( + f"unexpected number of channels in predicted_probabilities. Expected {self.num_segmentation_heads}, " + f"got {predicted_probabilities.shape[0]}. Remember that predicted_probabilities should have shape " + f"(c, x, y(, z))." + ) + if self.has_regions: + if isinstance(predicted_probabilities, np.ndarray): + segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.uint16) + else: + # no uint16 in torch + segmentation = torch.zeros( + predicted_probabilities.shape[1:], + dtype=torch.int16, + device=predicted_probabilities.device, + ) + for i, c in enumerate(self.regions_class_order): + segmentation[predicted_probabilities[i] > 0.5] = c + if isinstance(segmentation, torch.Tensor): + segmentation = segmentation.cpu().numpy() + else: + # Issensee is no longer right when saying "numpy is faster than torch" newer torch versions no longer have this issue, on GPU we even get a 20x improvment. :facepalm: + segmentation = _argmax_with_gpu_fallback(predicted_probabilities, device, chunk_size=chunk_size) + + return segmentation + def convert_predicted_logits_to_segmentation_with_correct_shape( predicted_logits: torch.Tensor | np.ndarray, @@ -20,6 +156,7 @@ def convert_predicted_logits_to_segmentation_with_correct_shape( properties_dict: dict, return_probabilities: bool = False, num_threads_torch: int = 8, + device=None, ) -> np.ndarray: """Revert all preprocessing steps and return a segmentation in the original image space. @@ -62,19 +199,12 @@ def convert_predicted_logits_to_segmentation_with_correct_shape( predicted_logits = configuration_manager.resampling_fn_probabilities( predicted_logits, properties_dict["shape_after_cropping_and_before_resampling"], current_spacing, properties_dict["spacing"] ) - # return value of resampling_fn_probabilities can be ndarray or Tensor but that doesnt matter because - # apply_inference_nonlin will covnert to torch - # And this is stupid because convert_probabilities_to_segmentation transforms it back to a numpy... if label_manager.has_regions or return_probabilities: # Softmax does not change when we use argmax in the next step predicted_logits = label_manager.apply_inference_nonlin(predicted_logits) - # segmentation may be torch.Tensor but we continue with numpy - if isinstance(predicted_logits, torch.Tensor): - predicted_logits = predicted_logits.cpu().numpy() - - segmentation: np.ndarray = label_manager.convert_probabilities_to_segmentation(np.ascontiguousarray(predicted_logits)) # type: ignore + # segmentation: np.ndarray = label_manager.convert_probabilities_to_segmentation(predicted_logits) # type: ignore + segmentation: np.ndarray = convert_probabilities_to_segmentation(label_manager, predicted_logits, device) segmentation = segmentation.astype(np.uint8 if len(label_manager.foreground_labels) < 255 else np.uint16) - # if not return_probabilities: del predicted_logits # put segmentation in bbox (revert cropping) segmentation_reverted_cropping = np.zeros( diff --git a/TPTBox/segmentation/nnUnet_utils/predictor.py b/TPTBox/segmentation/nnUnet_utils/predictor.py index f926172..ff9e6e2 100755 --- a/TPTBox/segmentation/nnUnet_utils/predictor.py +++ b/TPTBox/segmentation/nnUnet_utils/predictor.py @@ -378,6 +378,7 @@ def predict_single_npy_array( self.label_manager, dct["data_properites"], return_probabilities=save_or_return_probabilities, + device=self.device, ) print("convert_predicted_logits_to_segmentation_with_correct_shape; Took", time.time() - t, " seconds") @@ -671,15 +672,13 @@ def check_mem(shape): break splits[j] += 1 + predicted_logits, n_predictions = self._run_sub(data, network, device, slicers, pbar) pbar.desc = "finish" pbar.update(0) predicted_logits /= n_predictions del n_predictions predicted_logits = predicted_logits.cpu() - # NOTE: do not empty_cache() here. This runs once per fold; releasing the - # allocator pool now just forces the next fold to re-cudaMalloc. The pool is - # cleared once per image in predict_logits_from_preprocessed_data instead. return predicted_logits[(slice(None), *slicer_revert_padding[1:])] def _run_prediction_splits( @@ -731,7 +730,7 @@ def _run_prediction_splits( predicted_logits /= n_predictions del n_predictions - empty_cache(self.device) + # empty_cache(self.device) return predicted_logits def _allocate(self, data: torch.Tensor, results_device, pbar: tqdm, gauss: bool = True): From afcb40273c4d4f084096423a093dae02efe89e6e Mon Sep 17 00:00:00 2001 From: ga84mun Date: Wed, 10 Jun 2026 09:01:04 +0000 Subject: [PATCH 15/17] ruff --- TPTBox/core/np_utils.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index cd808be..7b51c0d 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -182,6 +182,18 @@ def old_np_unique(arr: np.ndarray) -> list[int]: def np_unique(arr: np.ndarray) -> list[int]: + """Returns each existing label in the array (including zero!). + + Uses cc3d statistics for unsigned-integer arrays for speed, and falls back + to ``numpy.unique`` for other dtypes. + + Args: + arr (np.ndarray): Input label array. + + Returns: + list[int]: Sorted list of every distinct label value present in ``arr``, + including 0 (background). + """ if np.issubdtype(arr.dtype, np.unsignedinteger): # bincount is O(max_val) but ~5-10x faster than np.unique for dense label arrays max_val = int(arr.max()) @@ -193,6 +205,15 @@ def np_unique(arr: np.ndarray) -> list[int]: def np_unique_withoutzero(arr: UINTARRAY) -> list[int]: + """Returns each existing non-zero label in the array (excluding background zero). + + Args: + arr (UINTARRAY): Input unsigned-integer label array. + + Returns: + list[int]: Sorted list of every distinct label value present in ``arr``, + excluding 0 (background). + """ if np.issubdtype(arr.dtype, np.unsignedinteger): max_val = int(arr.max()) if max_val == 0: From d1bfde2a7e8afbdce9bfbb0b338c91404d00df8a Mon Sep 17 00:00:00 2001 From: ga84mun Date: Wed, 10 Jun 2026 09:10:34 +0000 Subject: [PATCH 16/17] update tests. add ravel --- TPTBox/core/nii_wrapper.py | 14 ++++++++++++++ unit_tests/test_auto_segmentation.py | 10 +++++++++- 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/TPTBox/core/nii_wrapper.py b/TPTBox/core/nii_wrapper.py index 7b8a95d..c5aa88c 100755 --- a/TPTBox/core/nii_wrapper.py +++ b/TPTBox/core/nii_wrapper.py @@ -2699,6 +2699,20 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum]|None, keep_la if keep_label: seg_arr = seg_arr * self.get_seg_array() return self.set_array(seg_arr,inplace=inplace) + def ravel(self,order:Literal["K", "A", "C", "F"] | None="C")->np.ndarray: + """Return a contiguous flattened array. + + A 1-D array, containing the elements of the input, is returned. A copy is made only if needed. + + As of NumPy 1.10, the returned array will have the same type as the input array. (for example, a masked array will be returned for a masked array input) + + Args: + order (Literal["K", "A", "C", "F"] | None, optional): The elements of a are read using this index order. ‘C’ means to index the elements in row-major, C-style order, with the last axis index changing fastest, back to the first axis index changing slowest. ‘F’ means to index the elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing slowest. Note that the ‘C’ and ‘F’ options take no account of the memory layout of the underlying array, and only refer to the order of axis indexing. ‘A’ means to read the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise. ‘K’ means to read the elements in the order they occur in memory, except for reversing the data when strides are negative. By default, ‘C’ index order is used. Defaults to "C". + + Returns: + np.ndarray + """ + return self.get_array().ravel(order=order) def extract_label_(self, label: int | Enum | Sequence[int] | Sequence[Enum], keep_label=False) -> Self: """In-place variant of `extract_label`.""" return self.extract_label(label,keep_label,inplace=True) diff --git a/unit_tests/test_auto_segmentation.py b/unit_tests/test_auto_segmentation.py index 5b31b2b..4f20f6e 100644 --- a/unit_tests/test_auto_segmentation.py +++ b/unit_tests/test_auto_segmentation.py @@ -39,6 +39,14 @@ has_torch = False +try: + import ants + + has_ants = True +except ModuleNotFoundError: + has_ants = False + + class Test_test_samples(unittest.TestCase): # def test_load_ct(self): # ct_nii, subreg_nii, vert_nii, label = get_test_ct() @@ -60,7 +68,7 @@ def test_get_outpaths_spineps(self): assert "out_spine" in out assert "out_vert" in out - @unittest.skipIf(not has_spineps, "requires spineps to be installed") + @unittest.skipIf(not has_spineps or not has_ants, "requires spineps to be installed") def test_spineps(self): tests_path = get_tests_dir() if (tests_path / "derivative").exists(): From a4d1e6d09ff0e8d9fd5abf3f77840e11e957c434 Mon Sep 17 00:00:00 2001 From: iback Date: Wed, 10 Jun 2026 15:23:30 +0000 Subject: [PATCH 17/17] minor fallback plus updated speedtest --- TPTBox/core/np_utils.py | 2 +- TPTBox/tests/speedtests/speedtest_npunique.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/TPTBox/core/np_utils.py b/TPTBox/core/np_utils.py index 7b51c0d..2b3ebbd 100755 --- a/TPTBox/core/np_utils.py +++ b/TPTBox/core/np_utils.py @@ -201,7 +201,7 @@ def np_unique(arr: np.ndarray) -> list[int]: counts = np.bincount(arr.ravel()) return list(np.where(counts > 0)[0]) # For sparse label spaces fall back to np.unique - return list(np.unique(arr)) + return old_np_unique(arr) def np_unique_withoutzero(arr: UINTARRAY) -> list[int]: diff --git a/TPTBox/tests/speedtests/speedtest_npunique.py b/TPTBox/tests/speedtests/speedtest_npunique.py index 484606d..afdc91f 100644 --- a/TPTBox/tests/speedtests/speedtest_npunique.py +++ b/TPTBox/tests/speedtests/speedtest_npunique.py @@ -18,13 +18,14 @@ np_unique, np_unique_withoutzero, np_volume, + old_np_unique, ) from TPTBox.tests.speedtests.speedtest import speed_test from TPTBox.tests.test_utils import get_nii def get_nii_array(): num_points = random.randint(1, 30) - nii, points, orientation, sizes = get_nii(x=(140, 140, 150), num_point=num_points) + nii, points, orientation, sizes = get_nii(x=(400, 400, 400), num_point=num_points) # nii.map_labels_({1: -1}, verbose=False) arr = nii.get_seg_array().astype(np.uint8) # arr[arr == 1] = -1 @@ -34,7 +35,7 @@ def get_nii_array(): speed_test( repeats=50, get_input_func=get_nii_array, - functions=[np_unique, np.unique, np_is_empty, np.max], + functions=[np_unique, old_np_unique, np.unique, np_is_empty, np.max], assert_equal_function=lambda x, y: True, # np.all([x[i] == y[i] for i in range(len(x))]), # noqa: ARG005 # np.all([x[i] == y[i] for i in range(len(x))]) )