Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
88e7bf5
fail if not fit on GPU, so I can test it
robert-graf May 27, 2026
edc368f
Merge branch 'main' into development_robert
robert-graf May 27, 2026
4a6ee90
update memory requirements
robert-graf May 28, 2026
48a9d2a
fix bug for very elongated segmentations
robert-graf May 29, 2026
9211f4f
Merge branch 'development_robert' of github.com:Hendrik-code/TPTBox i…
robert-graf May 29, 2026
41fa833
should not use Runtime Errors
robert-graf May 29, 2026
a5e621c
Merge branch 'development_robert' of github.com:Hendrik-code/TPTBox i…
robert-graf May 29, 2026
397cd36
perf(nnunet): enable cuDNN autotune + TF32 for sliding-window inference
Hendrik-code Jun 4, 2026
065e3d9
perf(nnunet): use torch.inference_mode() instead of no_grad() for inf…
Hendrik-code Jun 4, 2026
271557a
perf(nnunet): stop clearing CUDA cache on the per-fold happy path
Hendrik-code Jun 4, 2026
3fbc842
fix(nnunet): repair fold ensembling broken by loaded_networks cache
Hendrik-code Jun 4, 2026
93ad675
perf(nnunet): add opt-in persistent model cache (cache_model)
Hendrik-code Jun 4, 2026
cf37216
perf(nnunet): optionally batch sliding-window tiles (tile_batch_size)
Hendrik-code Jun 4, 2026
57a8d19
bench: add nnU-Net inference timing harness
Hendrik-code Jun 4, 2026
365ab8b
bug fixes
robert-graf Jun 8, 2026
99170c8
Merge branch 'development_robert' of github.com:Hendrik-code/TPTBox i…
robert-graf Jun 8, 2026
69e872e
Merge branch 'nnunetinference' into development_robert
robert-graf Jun 8, 2026
a7efb70
fix bug made by claude
robert-graf Jun 9, 2026
7b49086
speed up argmax. Yes this much code is needed for this.
robert-graf Jun 10, 2026
afcb402
ruff
robert-graf Jun 10, 2026
d1bfde2
update tests. add ravel
robert-graf Jun 10, 2026
a4d1e6d
minor fallback plus updated speedtest
Hendrik-code Jun 10, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion TPTBox/core/bids_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def enumerate_subjects(self, sort: bool = False, shuffle: bool = False) -> list[
return s
return self.subjects.items() # type: ignore

def iter_subjects(self, sort: bool = False) -> list[tuple[str, Subject_Container]]:
def iter_subjects(self, sort: bool = False, shuffle: bool = False) -> list[tuple[str, Subject_Container]]:
"""Iterate over all subjects (alias for :meth:`enumerate_subjects` without shuffle).

Args:
Expand All @@ -498,6 +498,10 @@ def iter_subjects(self, sort: bool = False) -> list[tuple[str, Subject_Container
"""
if sort:
return sorted(self.subjects.items())
if shuffle:
s = list(self.subjects.items())
random.shuffle(s)
return s
return self.subjects.items() # type: ignore

def __len__(self):
Expand Down
53 changes: 35 additions & 18 deletions TPTBox/core/nii_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,33 +1180,36 @@ def resample_from_to(self, to_vox_map:Image_Reference|Has_Grid|tuple[SHAPE,AFFIN
if mapping.assert_affine(self,raise_error=False,origin_tolerance=0.000001,error_tolerance=0.000001,shape_tolerance=0):
log.print(f"resample_from_to skipped; already in space: {self}",verbose=verbose)
return self if inplace else self.copy()

m1 = mapping if mapping.orientation == self.orientation else mapping.make_empty_POI().reorient(self.orientation)
if m1.assert_affine(self,raise_error=False,origin_tolerance=0.00001,error_tolerance=0.00001,shape_tolerance=0):
log.print(f"resample_from_to only need reorientation; {self.orientation}",verbose=verbose)
ret = self.reorient(mapping.orientation,inplace=inplace)
ret.affine = mapping.affine #remove floating point error
return ret
if self.orientation == mapping.orientation and np.allclose(self.zoom , mapping.zoom, atol=1e-6):
shift = (np.array(self.origin) - np.array(m1.origin)) / np.array(m1.zoom)
if np.allclose(shift, np.round(shift), atol=1e-6):
s = self.reorient(mapping.orientation,inplace=inplace) # noqa: PLW0642
shift = (np.array(self.origin) - np.array(mapping.origin)) / np.array(mapping.zoom)
shift = np.round(shift).astype(int)
dst_shape = np.array(mapping.shape)
src_shape = np.array(s.shape)
# padding before = how much dst starts before src
pad_before = shift
# padding after = remaining dst size after src
pad_after = dst_shape-shift-src_shape
pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after))
ret = s.apply_pad(pad, mode=mode,inplace=inplace,verbose=verbose)

if np.allclose(self.zoom, m1.zoom, atol=1e-6):
s = self.reorient(mapping.orientation, inplace=inplace)
# Compute voxel offset directly from the affines after both
# images are in the same orientation. This is robust to axis
# permutations and flips.
voxel_offset = np.linalg.inv(mapping.affine) @ s.affine @ np.array([0, 0, 0, 1])
shift = np.round(voxel_offset[:3]).astype(int)

dst_shape = np.array(mapping.shape)
src_shape = np.array(s.shape)
# padding before = how much dst starts before src
pad_before = shift
# padding after = remaining dst size after src
pad_after = dst_shape - shift - src_shape
pad = tuple((int(b), int(a)) for b, a in zip(pad_before, pad_after))
try:
ret = s.apply_pad(pad,mode=mode,inplace=inplace,verbose=verbose)
valid = ret.assert_affine(mapping,raise_error=False,origin_tolerance=0.0001,error_tolerance=0.0001,shape_tolerance=0)
if valid:
log.print(f"resample_from_to only needs padding/cropping {pad}",verbose=verbose)
ret.affine = mapping.affine #remove floating point error
ret.affine = mapping.affine # remove floating point error
return ret
except ValueError as e:
log.warning("Padding failed.",e,verbose=verbose)


assert mapping is not None
Expand Down Expand Up @@ -2505,7 +2508,7 @@ def to_stl(
try:
verts, faces, normals, values = marching_cubes(seg_arr, gradient_direction="ascent", step_size=1)
except RuntimeError as e:
raise RuntimeError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None
raise IndexError(str(e),f"{label=}, {self.unique()}, {out_path=}") from None
# Remove padding offset (since we padded by 1 voxel)
verts -= 1
# Apply bounding box offset (still voxel space)
Expand Down Expand Up @@ -2696,6 +2699,20 @@ def extract_label(self,label:int|Enum|Sequence[int]|Sequence[Enum]|None, keep_la
if keep_label:
seg_arr = seg_arr * self.get_seg_array()
return self.set_array(seg_arr,inplace=inplace)
def ravel(self,order:Literal["K", "A", "C", "F"] | None="C")->np.ndarray:
"""Return a contiguous flattened array.

A 1-D array, containing the elements of the input, is returned. A copy is made only if needed.

As of NumPy 1.10, the returned array will have the same type as the input array. (for example, a masked array will be returned for a masked array input)

Args:
order (Literal["K", "A", "C", "F"] | None, optional): The elements of a are read using this index order. ‘C’ means to index the elements in row-major, C-style order, with the last axis index changing fastest, back to the first axis index changing slowest. ‘F’ means to index the elements in column-major, Fortran-style order, with the first index changing fastest, and the last index changing slowest. Note that the ‘C’ and ‘F’ options take no account of the memory layout of the underlying array, and only refer to the order of axis indexing. ‘A’ means to read the elements in Fortran-like index order if a is Fortran contiguous in memory, C-like order otherwise. ‘K’ means to read the elements in the order they occur in memory, except for reversing the data when strides are negative. By default, ‘C’ index order is used. Defaults to "C".

Returns:
np.ndarray
"""
return self.get_array().ravel(order=order)
def extract_label_(self, label: int | Enum | Sequence[int] | Sequence[Enum], keep_label=False) -> Self:
"""In-place variant of `extract_label`."""
return self.extract_label(label,keep_label,inplace=True)
Expand Down
45 changes: 44 additions & 1 deletion TPTBox/core/np_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def np_count_nonzero(arr: np.ndarray) -> int:
return np.count_nonzero(arr)


def np_unique(arr: np.ndarray) -> list[int]:
def old_np_unique(arr: np.ndarray) -> list[int]:
"""Returns each existing label in the array (including zero!).

Uses cc3d statistics for unsigned-integer arrays for speed, and falls back
Expand All @@ -181,9 +181,52 @@ def np_unique(arr: np.ndarray) -> list[int]:
return list(np.unique(arr))


def np_unique(arr: np.ndarray) -> list[int]:
"""Returns each existing label in the array (including zero!).

Uses cc3d statistics for unsigned-integer arrays for speed, and falls back
to ``numpy.unique`` for other dtypes.

Args:
arr (np.ndarray): Input label array.

Returns:
list[int]: Sorted list of every distinct label value present in ``arr``,
including 0 (background).
"""
if np.issubdtype(arr.dtype, np.unsignedinteger):
# bincount is O(max_val) but ~5-10x faster than np.unique for dense label arrays
max_val = int(arr.max())
if max_val < 2**20: # ~1M labels threshold — bincount stays fast
counts = np.bincount(arr.ravel())
return list(np.where(counts > 0)[0])
# For sparse label spaces fall back to np.unique
return old_np_unique(arr)


def np_unique_withoutzero(arr: UINTARRAY) -> list[int]:
"""Returns each existing non-zero label in the array (excluding background zero).

Args:
arr (UINTARRAY): Input unsigned-integer label array.

Returns:
list[int]: Sorted list of every distinct label value present in ``arr``,
excluding 0 (background).
"""
if np.issubdtype(arr.dtype, np.unsignedinteger):
max_val = int(arr.max())
if max_val == 0:
return []
if max_val < 2**20:
counts = np.bincount(arr.ravel())
return list(np.where(counts[1:] > 0)[0] + 1)
return [i for i in np.unique(arr) if i != 0]


def old_np_unique_withoutzero(arr: UINTARRAY) -> list[int]:
"""Returns each existing non-zero label in the array (excluding background zero).

Args:
arr (UINTARRAY): Input unsigned-integer label array.

Expand Down
7 changes: 4 additions & 3 deletions TPTBox/registration/_deformable/multilabel_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def __init__( # noqa: C901
poi_target_cms: POI | None = None,
max_history=100,
change_after_point_reg=lambda x, y, z, w: (x, y, z, w),
tether_distance=1,
**args,
):
"""Initialize a multi-stage registration pipeline from an atlas to a target image.
Expand Down Expand Up @@ -90,7 +91,7 @@ def __init__( # noqa: C901
"be": ("BSplineBending", {"stride": 1}),
"seg": "MSE",
"Dice": "Dice",
"Tether": Tether_Seg(delta=5),
"Tether": Tether_Seg(delta=tether_distance),
}

assert target_seg.seg, target_seg.seg
Expand Down Expand Up @@ -187,7 +188,7 @@ def __init__( # noqa: C901
poi_cms = poi_cms.resample_from_to(atlas_seg_)

self.reg_point = Point_Registration(poi_target, poi_cms, verbose=False)
atlas_reg = self.reg_point.transform_nii(atlas_seg_)
atlas_reg = self.reg_point.transform_nii(atlas_seg_, c_val=0)

if not atlas_reg.is_segmentation_in_border():
print("point registration ok")
Expand All @@ -204,7 +205,7 @@ def __init__( # noqa: C901
target_img = target_img.apply_pad(resize_param) if target_img is not None else None

self.reg_point = Point_Registration(poi_target.resample_from_to(target_seg), poi_cms.resample_from_to(atlas_seg))
atlas_reg = self.reg_point.transform_nii(atlas_seg)
atlas_reg = self.reg_point.transform_nii(atlas_seg, c_val=0)
atlas_img_reg = self.reg_point.transform_nii(atlas_img) if atlas_img is not None else None

if crop:
Expand Down
4 changes: 2 additions & 2 deletions TPTBox/registration/_ridged_intensity/affine_deepali.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,14 @@ def forward(
target: torch.Tensor, # shape: (B, C, X, Y, Z)
mask: torch.Tensor | None = None, # noqa: ARG002
) -> torch.Tensor:
w = max(target.shape[2:])
w = min(target.shape[2:])
com_fixed = center_of_mass_cc(target) # (B, C, 3)
com_warped = center_of_mass_cc(source) # (B, C, 3)

l_com = torch.norm(com_fixed - com_warped, dim=-1) / w # (B, C)

# Zero out channels with small displacement (<10) or NaNs
l_com = torch.where(l_com < self.delta, torch.zeros_like(l_com), l_com)
l_com = torch.where(l_com * w < self.delta, torch.zeros_like(l_com), l_com)
l_com = torch.nan_to_num(l_com, nan=0.0)

return l_com.mean() # type: ignore
Expand Down
84 changes: 65 additions & 19 deletions TPTBox/segmentation/VibeSeg/inference_nnunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@
out_base = Path(__file__).parent.parent / "nnUNet/"
_model_path_ = out_base / "nnUNet_results"

# Opt-in cache of loaded predictors (enable via cache_model=True). Keyed by model identity plus
# the device/runtime settings that affect the loaded predictor, so repeated inference (e.g. a loop
# over many files with the same model) reuses the in-memory model instead of reloading weights from
# disk and re-uploading them to the GPU on every call.
_model_cache: dict = {}


def get_ds_info(idx: int, _model_path: str | Path | None = None, exit_one_fail: bool = True, logger=logger) -> dict:
"""Load and return the ``dataset.json`` for the model with the given dataset index.
Expand Down Expand Up @@ -87,13 +93,16 @@ def run_inference_on_file(
ddevice: Literal["cpu", "cuda", "mps"] = "cuda",
model_path=None,
step_size: float = 0.5,
memory_base: int = 5000, # Base memory in MB, default is 5GB
memory_factor: int = 160, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB
memory_max: int = 160000, # in MB, default is 160GB
memory_base: float | None = None, # Base memory in MB, default is 5GB
memory_factor: float | None = None, # prod(shape)*memory_factor / 1000, 160 ~> 30 GB
memory_max: int = 990000, # in MB, default is 990GB (so it is most likely ignored and replaced by Max Memory of the GPU)
wait_till_gpu_percent_is_free: float = 0.1,
tile_batch_size: int = 1,
verbose: bool = True,
auto_download: bool = False,
cache_model: bool = False,
_key_ResEnc: str = "__nnUNet*ResEnc",
fail_on_missing_memory=False,
logger=logger,
) -> tuple[Image_Reference, np.ndarray | None]:
"""Load a VibeSeg model and run inference on the supplied NIfTI images.
Expand Down Expand Up @@ -135,7 +144,18 @@ def run_inference_on_file(
memory_max: Hard cap on assumed GPU memory in MB.
wait_till_gpu_percent_is_free: Minimum free GPU fraction to require
before starting inference.
tile_batch_size: Number of sliding-window tiles to run per network
forward pass. ``1`` (default) keeps the original per-tile behaviour;
larger values batch tiles to better saturate the GPU at the cost of
higher peak memory.
verbose: Print progress information.
cache_model: If ``True``, keep the loaded predictor in a process-wide
cache and reuse it on subsequent calls with identical model and
device/runtime settings. Avoids reloading weights from disk and
re-uploading them to the GPU when segmenting many files in a loop, at
the cost of holding the model in GPU memory between calls. The GPU
cache is also left warm (no ``empty_cache``) so the allocator can
reuse buffers across images.

Returns:
A tuple ``(seg_nii, softmax_logits)`` where ``seg_nii`` is the
Expand Down Expand Up @@ -196,20 +216,44 @@ def run_inference_on_file(
if "labels" in ds_info2:
ds_info["labels_mapping"] = ds_info2["labels"]

nnunet = load_inf_model(
nnunet_path,
allow_non_final=True,
use_folds=tuple(folds) if len(folds) != 5 else None,
gpu=gpu,
ddevice=ddevice,
step_size=step_size,
memory_base=memory_base,
memory_factor=memory_factor,
memory_max=memory_max,
wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free,
if memory_base is None:
memory_base = float(ds_info.get("memory_base", 5000))
if memory_factor is None:
memory_factor = float(ds_info.get("memory_factor", 160))

use_folds_arg = tuple(folds) if len(folds) != 5 else None
# Include every setting that changes the loaded predictor so a cache hit is always equivalent
# to a fresh load; differing settings simply miss the cache and reload.
cache_key = (
str(nnunet_path),
use_folds_arg,
ddevice,
gpu,
step_size,
memory_base,
memory_factor,
memory_max,
wait_till_gpu_percent_is_free,
tile_batch_size,
)

# _unets[idx] = nnunet
nnunet = _model_cache.get(cache_key) if cache_model else None
if nnunet is None:
nnunet = load_inf_model(
nnunet_path,
allow_non_final=True,
use_folds=use_folds_arg,
gpu=gpu,
ddevice=ddevice,
step_size=step_size,
memory_base=memory_base,
memory_factor=memory_factor,
memory_max=memory_max,
wait_till_gpu_percent_is_free=wait_till_gpu_percent_is_free,
tile_batch_size=tile_batch_size,
fail_on_missing_memory=fail_on_missing_memory,
)
if cache_model:
_model_cache[cache_key] = nnunet
if "orientation" in ds_info:
orientation = ds_info["orientation"]

Expand Down Expand Up @@ -315,9 +359,11 @@ def to_int(a: str, k: None | int = None):
seg_nii.map_labels_(mapping)
if out_file is not None and (not Path(out_file).exists() or override):
seg_nii.set_dtype("smallest_uint").save(out_file)
del nnunet

torch.cuda.empty_cache()
if not cache_model:
# When caching we keep the predictor alive (it stays referenced by _model_cache, so del
# would not free it anyway) and leave the CUDA cache warm so the next image reuses buffers.
del nnunet
torch.cuda.empty_cache()
return seg_nii, softmax_logits


Expand Down
8 changes: 8 additions & 0 deletions TPTBox/segmentation/VibeSeg/vibeseg.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@
72: "bone_other",
}

defaults = {
100: {"memory_base": 5500, "memory_factor": 25},
}


def run_vibeseg(
i: Image_Reference,
Expand Down Expand Up @@ -113,6 +117,10 @@ def run_vibeseg(
Returns:
Segmentation ``NII`` saved at *out_seg*.
"""
if dataset_id in defaults:
for k, v in defaults[dataset_id].items():
if k not in args:
args[k] = v
return run_inference_on_file(
dataset_id,
[to_nii(i)] if not isinstance(i, (list, tuple)) else [to_nii(j) for j in i],
Expand Down
Loading
Loading