From b23310e85ab6d97110d6645dc0f5fa3326d80594 Mon Sep 17 00:00:00 2001 From: Kamil Monicz Date: Mon, 2 Mar 2026 01:10:04 +0100 Subject: [PATCH] Improve nearest-neighbor query perf --- pyresample/future/resamplers/nearest.py | 35 ++++++--- .../test/test_resamplers/test_nearest.py | 71 +++++++++++++++++++ 2 files changed, 97 insertions(+), 9 deletions(-) diff --git a/pyresample/future/resamplers/nearest.py b/pyresample/future/resamplers/nearest.py index e48c8b3b..941b6ceb 100644 --- a/pyresample/future/resamplers/nearest.py +++ b/pyresample/future/resamplers/nearest.py @@ -66,22 +66,32 @@ def query_no_distance(target_lons, target_lats, valid_output_index, distance_upper_bound=radius, mask=mask) + if neighbours == 1: + # Nearest-neighbor resampling only consumes one neighbor, so avoid + # building and masking the full (rows, cols, neighbours) array shape. + index_array = np.asarray(index_array, dtype=np.int64) + index_array[index_array >= kdtree.n] = -1 + out = np.full(voi.size, -1, dtype=np.int64) + out[voir] = index_array + return out.reshape(voi.shape + (1,)) + if index_array.ndim == 1: index_array = index_array[:, None] # KDTree query returns out-of-bounds neighbors as `len(arr)` # which is an invalid index, we mask those out so -1 represents # invalid values + # # voi is 2D (trows, tcols) # index_array is 2D (valid output pixels, neighbors) # there are as many Trues in voi as rows in index_array - good_pixels = index_array < kdtree.n - res_ia = np.empty(shape, dtype=int) - mask = np.zeros(shape, dtype=bool) - mask[voi, :] = good_pixels - res_ia[mask] = index_array[good_pixels] - res_ia[~mask] = -1 - return res_ia + # + # Write (valid_output_pixels, neighbours) index array into an output filled with + # -1 and then overwrite out-of-bounds values in-place. + out = np.full(shape, -1, dtype=np.int64) + out[voi, :] = index_array + out[out >= kdtree.n] = -1 + return out def _my_index(index_arr, vii, data_arr, vii_slices=None, ia_slices=None, @@ -144,7 +154,12 @@ def _compute_radius_of_influence(self): logger.warning("Could not calculate destination definition " "resolution") dst_res = np.nan - radius_of_influence = np.nanmax([src_res, dst_res]) + if np.isnan(src_res): + radius_of_influence = dst_res + elif np.isnan(dst_res): + radius_of_influence = src_res + else: + radius_of_influence = max(src_res, dst_res) if np.isnan(radius_of_influence): logger.warning("Could not calculate radius_of_influence, falling " "back to 10000 meters. This may produce lower " @@ -487,7 +502,9 @@ def _verify_input_object_type(self, data): "to dask arrays for computation and then converted back. To " "avoid this warning convert your numpy array before providing " "it to the resampler.", PerformanceWarning, stacklevel=3) - data = data.copy() + # Avoid copying the underlying ndarray; we only need a new wrapper + # object so we can replace `.data` with a dask array. + data = data.copy(deep=False) data.data = da.from_array(data.data, chunks="auto") return data diff --git a/pyresample/test/test_resamplers/test_nearest.py b/pyresample/test/test_resamplers/test_nearest.py index a4a16db5..241386f3 100644 --- a/pyresample/test/test_resamplers/test_nearest.py +++ b/pyresample/test/test_resamplers/test_nearest.py @@ -24,10 +24,13 @@ import numpy as np import pytest import xarray as xr +from pykdtree.kdtree import KDTree from pytest_lazy_fixtures import lf from pyresample.future.geometry import AreaDefinition, SwathDefinition from pyresample.future.resamplers import KDTreeNearestXarrayResampler +from pyresample.future.resamplers._transform_utils import lonlat2xyz +from pyresample.future.resamplers.nearest import query_no_distance from pyresample.test.utils import assert_maximum_dask_computes, assert_warnings_contain, catch_warnings from pyresample.utils.errors import PerformanceWarning @@ -300,3 +303,71 @@ def test_inconsistent_input_shapes(self, src_geom, match, call_precompute, resampler.precompute(mask=data_2d_float32_xarray_dask.notnull()) else: resampler.resample(data_2d_float32_xarray_dask) + + +class TestQueryNoDistance: + """Tests for direct KDTree query index remapping.""" + + def test_unselected_and_oob_are_minus_one(self): + voi = np.array([[True, False], [True, False]]) + tlons = np.array([[0.0, 0.0], [10.0, 0.0]], dtype=np.float64) + tlats = np.zeros_like(tlons) + + src_lons = np.array([0.0], dtype=np.float64) + src_lats = np.array([0.0], dtype=np.float64) + src_xyz = lonlat2xyz(src_lons, src_lats).astype(np.float64, copy=False) + kdtree = KDTree(src_xyz) + + res = query_no_distance( + tlons, + tlats, + voi, + neighbours=1, + epsilon=0.0, + radius=1.0, # meters; only exact match is within this ROI + kdtree=kdtree, + ) + + np.testing.assert_array_equal(res[..., 0], np.array([[0, -1], [-1, -1]])) + + def test_forwards_filtered_source_mask(self): + voi = np.array([[True]]) + + src_lons = np.array([[0.0, 0.0001], [0.0002, 0.0003]], dtype=np.float64) + src_lats = np.zeros_like(src_lons) + valid_input_index = np.array([[True, True], [True, False]]) + + src_xyz = lonlat2xyz(src_lons, src_lats).astype(np.float64, copy=False) + kdtree = KDTree(src_xyz[valid_input_index.ravel()]) + + target_lons = np.array([[0.0]], dtype=np.float64) + target_lats = np.array([[0.0]], dtype=np.float64) + + res_unmasked = query_no_distance( + target_lons, + target_lats, + voi, + neighbours=1, + epsilon=0.0, + radius=1000.0, + kdtree=kdtree, + ) + + # Mask out the nearest source point (after valid_input_index filtering). + source_mask = np.array([[True, False], [False, True]]) + res_masked = query_no_distance( + target_lons, + target_lats, + voi, + mask=source_mask, + valid_input_index=valid_input_index, + neighbours=1, + epsilon=0.0, + radius=1000.0, + kdtree=kdtree, + ) + + assert res_unmasked.shape == (1, 1, 1) + assert res_masked.shape == (1, 1, 1) + assert res_unmasked[0, 0, 0] == 0 + assert res_masked[0, 0, 0] == 1