Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions pyresample/future/resamplers/nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,32 @@ def query_no_distance(target_lons, target_lats, valid_output_index,
distance_upper_bound=radius,
mask=mask)

if neighbours == 1:
# Nearest-neighbor resampling only consumes one neighbor, so avoid
# building and masking the full (rows, cols, neighbours) array shape.
index_array = np.asarray(index_array, dtype=np.int64)
index_array[index_array >= kdtree.n] = -1
out = np.full(voi.size, -1, dtype=np.int64)
out[voir] = index_array
return out.reshape(voi.shape + (1,))

if index_array.ndim == 1:
index_array = index_array[:, None]

# KDTree query returns out-of-bounds neighbors as `len(arr)`
# which is an invalid index, we mask those out so -1 represents
# invalid values
#
# voi is 2D (trows, tcols)
# index_array is 2D (valid output pixels, neighbors)
# there are as many Trues in voi as rows in index_array
good_pixels = index_array < kdtree.n
res_ia = np.empty(shape, dtype=int)
mask = np.zeros(shape, dtype=bool)
mask[voi, :] = good_pixels
res_ia[mask] = index_array[good_pixels]
res_ia[~mask] = -1
return res_ia
#
# Write (valid_output_pixels, neighbours) index array into an output filled with
# -1 and then overwrite out-of-bounds values in-place.
out = np.full(shape, -1, dtype=np.int64)
out[voi, :] = index_array
out[out >= kdtree.n] = -1
return out


def _my_index(index_arr, vii, data_arr, vii_slices=None, ia_slices=None,
Expand Down Expand Up @@ -144,7 +154,12 @@ def _compute_radius_of_influence(self):
logger.warning("Could not calculate destination definition "
"resolution")
dst_res = np.nan
radius_of_influence = np.nanmax([src_res, dst_res])
if np.isnan(src_res):
radius_of_influence = dst_res
elif np.isnan(dst_res):
radius_of_influence = src_res
else:
radius_of_influence = max(src_res, dst_res)
Comment on lines +157 to +162
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why this is better? Because it prefers the source resolution?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It works just the same and avoids the all-nans warning print, which is quite expensive in Python. So it benefits only the worst-case scenario.

if np.isnan(radius_of_influence):
logger.warning("Could not calculate radius_of_influence, falling "
"back to 10000 meters. This may produce lower "
Expand Down Expand Up @@ -487,7 +502,9 @@ def _verify_input_object_type(self, data):
"to dask arrays for computation and then converted back. To "
"avoid this warning convert your numpy array before providing "
"it to the resampler.", PerformanceWarning, stacklevel=3)
data = data.copy()
# Avoid copying the underlying ndarray; we only need a new wrapper
# object so we can replace `.data` with a dask array.
data = data.copy(deep=False)
data.data = da.from_array(data.data, chunks="auto")
return data

Expand Down
71 changes: 71 additions & 0 deletions pyresample/test/test_resamplers/test_nearest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,13 @@
import numpy as np
import pytest
import xarray as xr
from pykdtree.kdtree import KDTree
from pytest_lazy_fixtures import lf

from pyresample.future.geometry import AreaDefinition, SwathDefinition
from pyresample.future.resamplers import KDTreeNearestXarrayResampler
from pyresample.future.resamplers._transform_utils import lonlat2xyz
from pyresample.future.resamplers.nearest import query_no_distance
from pyresample.test.utils import assert_maximum_dask_computes, assert_warnings_contain, catch_warnings
from pyresample.utils.errors import PerformanceWarning

Expand Down Expand Up @@ -300,3 +303,71 @@ def test_inconsistent_input_shapes(self, src_geom, match, call_precompute,
resampler.precompute(mask=data_2d_float32_xarray_dask.notnull())
else:
resampler.resample(data_2d_float32_xarray_dask)


class TestQueryNoDistance:
"""Tests for direct KDTree query index remapping."""

def test_unselected_and_oob_are_minus_one(self):
voi = np.array([[True, False], [True, False]])
tlons = np.array([[0.0, 0.0], [10.0, 0.0]], dtype=np.float64)
tlats = np.zeros_like(tlons)

src_lons = np.array([0.0], dtype=np.float64)
src_lats = np.array([0.0], dtype=np.float64)
src_xyz = lonlat2xyz(src_lons, src_lats).astype(np.float64, copy=False)
kdtree = KDTree(src_xyz)

res = query_no_distance(
tlons,
tlats,
voi,
neighbours=1,
epsilon=0.0,
radius=1.0, # meters; only exact match is within this ROI
kdtree=kdtree,
)

np.testing.assert_array_equal(res[..., 0], np.array([[0, -1], [-1, -1]]))

def test_forwards_filtered_source_mask(self):
voi = np.array([[True]])

src_lons = np.array([[0.0, 0.0001], [0.0002, 0.0003]], dtype=np.float64)
src_lats = np.zeros_like(src_lons)
valid_input_index = np.array([[True, True], [True, False]])

src_xyz = lonlat2xyz(src_lons, src_lats).astype(np.float64, copy=False)
kdtree = KDTree(src_xyz[valid_input_index.ravel()])

target_lons = np.array([[0.0]], dtype=np.float64)
target_lats = np.array([[0.0]], dtype=np.float64)

res_unmasked = query_no_distance(
target_lons,
target_lats,
voi,
neighbours=1,
epsilon=0.0,
radius=1000.0,
kdtree=kdtree,
)

# Mask out the nearest source point (after valid_input_index filtering).
source_mask = np.array([[True, False], [False, True]])
res_masked = query_no_distance(
target_lons,
target_lats,
voi,
mask=source_mask,
valid_input_index=valid_input_index,
neighbours=1,
epsilon=0.0,
radius=1000.0,
kdtree=kdtree,
)

assert res_unmasked.shape == (1, 1, 1)
assert res_masked.shape == (1, 1, 1)
assert res_unmasked[0, 0, 0] == 0
assert res_masked[0, 0, 0] == 1
Loading