Skip to content

Commit 2de35c8

Browse files
committed
Optimize Dask EWA persist and prune fornav tasks
1 parent b8c4b74 commit 2de35c8

2 files changed

Lines changed: 446 additions & 43 deletions

File tree

pyresample/ewa/dask_ewa.py

Lines changed: 173 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,79 @@ def _call_mapped_ll2cr(lons, lats, target_geo_def):
7171
return res
7272

7373

74+
def _ll2cr_block_extent(ll2cr_block):
    """Find the row/column box enclosing a block's finite coordinates.

    Args:
        ll2cr_block: A single ll2cr output block. Element 0 is read as the
            column coordinates and element 1 as the row coordinates. A block
            whose element 0 is a tuple is treated as the "empty" placeholder
            form ``((shape, fill, dtype), (shape, fill, dtype))``.

    Returns:
        ``(row_min, row_max, col_min, col_max)`` as Python floats, or
        ``None`` for a placeholder block or for a block in which no position
        has both a finite column and a finite row.

    """
    col_part = ll2cr_block[0]
    if isinstance(col_part, tuple):
        # Placeholder block: there are no coordinates to measure.
        return None

    col_coords = np.asarray(col_part)
    row_coords = np.asarray(ll2cr_block[1])
    # A position only counts when *both* of its coordinates are finite.
    finite_mask = np.isfinite(col_coords) & np.isfinite(row_coords)
    if not np.any(finite_mask):
        return None

    kept_rows = row_coords[finite_mask]
    kept_cols = col_coords[finite_mask]
    return (
        float(kept_rows.min()),
        float(kept_rows.max()),
        float(kept_cols.min()),
        float(kept_cols.max()),
    )
102+
103+
104+
def _pad_bounds(bounds, overlap_margin):
    """Grow a row/column bounding box by the same margin on every side.

    Args:
        bounds: ``(row_min, row_max, col_min, col_max)`` tuple, or ``None``.
        overlap_margin: Amount subtracted from both minimums and added to
            both maximums. The value is not validated here.

    Returns:
        The widened ``(row_min, row_max, col_min, col_max)`` tuple, or
        ``None`` when ``bounds`` is ``None``.

    """
    if bounds is None:
        return None

    lo_row, hi_row, lo_col, hi_col = bounds
    padded_rows = (lo_row - overlap_margin, hi_row + overlap_margin)
    padded_cols = (lo_col - overlap_margin, hi_col + overlap_margin)
    return padded_rows + padded_cols
124+
125+
126+
def _chunk_intersects_bounds(bounds, y_slice, x_slice):
    """Tell whether an output chunk overlaps a row/column bounding box.

    Args:
        bounds: ``(row_min, row_max, col_min, col_max)`` tuple, or ``None``.
            No padding is applied here; callers are expected to pass
            already-padded bounds.
        y_slice: Row slice of the output chunk (``start``/``stop`` are read).
        x_slice: Column slice of the output chunk (``start``/``stop`` are
            read).

    Returns:
        ``True`` when ``bounds`` is ``None`` or when the chunk overlaps the
        box in both the row and the column direction; otherwise ``False``.

    """
    if bounds is None:
        # Nothing is known about the block, so never rule the chunk out.
        return True

    lo_row, hi_row, lo_col, hi_col = bounds
    # The slice stop is exclusive (strict comparison against the minimum);
    # the slice start is inclusive (non-strict comparison against the maximum).
    rows_overlap = y_slice.stop > lo_row and y_slice.start <= hi_row
    return rows_overlap and (
        x_slice.stop > lo_col and x_slice.start <= hi_col
    )
145+
146+
74147
def _delayed_fornav(ll2cr_result, target_geo_def, y_slice, x_slice, data, fill_value, kwargs):
75148
# Adjust cols and rows for this sub-area
76149
subdef = target_geo_def[y_slice, x_slice]
@@ -107,6 +180,21 @@ def _chunk_callable(x_chunk, axis, keepdims, **kwargs):
107180
return x_chunk
108181

109182

183+
def _sum_arrays(arrays):
    """Sum arrays element-wise using one accumulator and in-place adds.

    The accumulator is a fresh copy of the first array, so none of the
    inputs are modified. It is allocated in the promoted dtype of all
    inputs: an in-place ``+=`` into a plain copy of the first array would
    raise a casting error whenever a later array has a wider dtype (for
    example an integer array followed by a float array).

    Args:
        arrays: Non-empty sequence of NumPy arrays. Every array after the
            first must be broadcastable to the shape of the first one.

    Returns:
        Element-wise sum as a new NumPy array with the shape of the first
        input and the promoted dtype of all inputs.

    Raises:
        IndexError: If ``arrays`` is empty.

    """
    first = arrays[0]
    rest = arrays[1:]

    # Work out the dtype able to hold every input before allocating.
    acc_dtype = first.dtype
    for arr in rest:
        acc_dtype = np.promote_types(acc_dtype, arr.dtype)

    # ``order='C'`` matches the layout that ``ndarray.copy()`` produces.
    total = first.astype(acc_dtype, order='C', copy=True)
    for arr in rest:
        total += arr
    return total
196+
197+
110198
def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False,
111199
maximum_weight_mode=False):
112200
if computing_meta or _is_empty_chunk(x_chunk):
@@ -126,6 +214,8 @@ def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False,
126214
# split step - return "empty" chunk placeholder
127215
return x_chunk[0]
128216
return np.full(*x_chunk[0][0]), np.full(*x_chunk[0][1])
217+
if len(valid_chunks) == 1:
218+
return valid_chunks[0]
129219
weights = [x[0] for x in valid_chunks]
130220
accums = [x[1] for x in valid_chunks]
131221
if maximum_weight_mode:
@@ -135,9 +225,7 @@ def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False,
135225
weights = np.take_along_axis(weights, max_indexes, axis=0).squeeze(axis=0)
136226
accums = np.take_along_axis(accums, max_indexes, axis=0).squeeze(axis=0)
137227
return weights, accums
138-
# NOTE: We use the builtin "sum" function below because it does not copy
139-
# the numpy arrays. Using numpy.sum would do that.
140-
return sum(weights), sum(accums)
228+
return _sum_arrays(weights), _sum_arrays(accums)
141229

142230

143231
def _is_empty_chunk(x_chunk):
@@ -224,60 +312,79 @@ def _get_rows_per_scan(self, rows_per_scan=None):
224312
rows_per_scan = self.source_geo_def.shape[0]
225313
return rows_per_scan
226314

227-
def _fill_block_cache_with_ll2cr_results(self, ll2cr_result,
228-
num_row_blocks,
229-
num_col_blocks,
230-
persist):
315+
def _ll2cr_cache_matches(self, rows_per_scan, persist):
    """Check whether the cached settings equal the requested ones.

    Args:
        rows_per_scan: Value compared with the cache's ``'rows_per_scan'``
            entry.
        persist: Value compared with the cache's ``'persist'`` entry.

    Returns:
        ``True`` only when both cached entries equal the given values. A
        missing entry is read as ``None``.

    """
    cached = self.cache
    same_scan = cached.get('rows_per_scan') == rows_per_scan
    return same_scan and cached.get('persist') == persist
320+
321+
def _get_ll2cr_blocks(self, ll2cr_result, persist):
    """List the ll2cr blocks that later fornav tasks should reference.

    Args:
        ll2cr_result: Dask collection holding the ll2cr output. Its
            ``name`` and ``numblocks`` are read when ``persist`` is false;
            ``to_delayed()`` is called when ``persist`` is true.
        persist: When true, persist every block, compute each block's
            extent and drop blocks that have none.

    Returns:
        ``(ll2cr_blocks, block_dependencies)``. ``ll2cr_blocks`` is a list
        of ``(in_row_idx, in_col_idx, key, extent)`` tuples. Without
        ``persist`` every block is listed with the key
        ``(ll2cr_result.name, in_row_idx, in_col_idx)`` and a ``None``
        extent, and ``block_dependencies`` is ``None``. With ``persist``
        the key is the persisted delayed object's key, the extent is the
        ``_ll2cr_block_extent`` result, and ``block_dependencies`` is the
        list of persisted delayed objects for the blocks that were kept.

    """
    if not persist:
        n_row_blocks, n_col_blocks = ll2cr_result.numblocks[-2:]
        lazy_blocks = [
            (row_idx, col_idx, (ll2cr_result.name, row_idx, col_idx), None)
            for row_idx in range(n_row_blocks)
            for col_idx in range(n_col_blocks)
        ]
        return lazy_blocks, None

    # Flatten the grid of delayed blocks, remembering each block's position.
    block_positions = []
    pending = []
    for row_idx, delayed_row in enumerate(ll2cr_result.to_delayed()):
        for col_idx, delayed_block in enumerate(delayed_row):
            block_positions.append((row_idx, col_idx))
            pending.append(delayed_block)

    persisted = dask.persist(*pending)
    # Only the small per-block extents are computed and brought back here,
    # so the full ll2cr blocks are not materialized in this process.
    extent_tasks = [dask.delayed(_ll2cr_block_extent)(blk) for blk in persisted]
    extents = dask.compute(*extent_tasks)

    kept_blocks = []
    kept_delayeds = []
    for (row_idx, col_idx), persisted_block, extent in zip(
            block_positions, persisted, extents, strict=True):
        if extent is None:
            # No valid coordinates in this block: nothing to resample from it.
            continue
        kept_blocks.append((row_idx, col_idx, persisted_block.key, extent))
        kept_delayeds.append(persisted_block)
    return kept_blocks, kept_delayeds
249355

250356
def precompute(self, cache_dir=None, rows_per_scan=None, persist=False,
               **kwargs):
    """Run ll2cr for the source swath and keep the results in the cache.

    Nothing is recomputed when ``self._ll2cr_cache_matches`` reports that
    the cache already holds results for the same ``rows_per_scan`` and
    ``persist`` settings.

    Args:
        cache_dir: Not used; a warning is logged when a truthy value is
            given.
        rows_per_scan: Resolved through ``self._get_rows_per_scan``; the
            resolved value is used for chunking and stored in the cache.
        persist: Forwarded to ``self._get_ll2cr_blocks`` and stored in the
            cache.
        **kwargs: Only ``mask`` is inspected; a warning is logged when it
            is not ``None``.

    Returns:
        Always ``None``; the results are stored in ``self.cache``.

    """
    scan_rows = self._get_rows_per_scan(rows_per_scan)
    if self._ll2cr_cache_matches(scan_rows, persist):
        # The cached ll2cr output was produced with the same settings.
        return None

    if kwargs.get('mask') is not None:
        logger.warning("'mask' parameter has no affect during EWA resampling")
    if cache_dir:
        logger.warning("'cache_dir' is not used by EWA resampling")

    swath_def = self.source_geo_def
    swath_chunks = self._new_chunks(swath_def.lons, scan_rows)
    lons, lats = swath_def.get_lonlats(chunks=swath_chunks)
    # Map the swath longitudes/latitudes to column/row positions in the
    # target grid, one result per chunk.
    ll2cr_result = _call_mapped_ll2cr(lons, lats, self.target_geo_def)
    blocks, dependencies = self._get_ll2cr_blocks(ll2cr_result, persist)

    # Keep the results, and the settings that produced them, on the instance.
    self.cache = {
        'll2cr_result': ll2cr_result,
        'll2cr_blocks': blocks,
        'll2cr_block_dependencies': dependencies,
        'rows_per_scan': scan_rows,
        'persist': persist,
    }
    return None
283390

@@ -323,27 +430,46 @@ def _generate_fornav_dask_tasks(out_chunks, ll2cr_blocks, task_name,
323430
input_name, target_geo_def, fill_value, kwargs):
324431
y_start = 0
325432
output_stack = {}
433+
overlap_margin = max(
434+
float(kwargs.get("weight_delta_max", 0.0)),
435+
float(kwargs.get("weight_distance_max", 0.0)),
436+
0.0,
437+
)
438+
indexed_blocks = []
439+
for z_idx, (in_row_idx, in_col_idx, ll2cr_block, block_extent) in enumerate(ll2cr_blocks):
440+
block_bounds = _pad_bounds(block_extent, overlap_margin)
441+
indexed_blocks.append((z_idx, in_row_idx, in_col_idx, ll2cr_block, block_bounds))
326442
for out_row_idx in range(len(out_chunks[0])):
327443
y_end = y_start + out_chunks[0][out_row_idx]
328444
x_start = 0
329445
for out_col_idx in range(len(out_chunks[1])):
330446
x_end = x_start + out_chunks[1][out_col_idx]
331447
y_slice = slice(y_start, y_end)
332448
x_slice = slice(x_start, x_end)
333-
for z_idx, ((_, in_row_idx, in_col_idx), ll2cr_block) in enumerate(ll2cr_blocks):
449+
placeholder = (
450+
((y_end - y_start, x_end - x_start), 0, np.float32),
451+
((y_end - y_start, x_end - x_start), 0, np.float32),
452+
)
453+
for z_idx, in_row_idx, in_col_idx, ll2cr_block, block_bounds in indexed_blocks:
334454
key = (task_name, z_idx, out_row_idx, out_col_idx)
335-
output_stack[key] = (_delayed_fornav,
336-
ll2cr_block,
337-
target_geo_def, y_slice, x_slice,
338-
(input_name, in_row_idx, in_col_idx), fill_value, kwargs)
455+
if _chunk_intersects_bounds(block_bounds, y_slice, x_slice):
456+
output_stack[key] = (_delayed_fornav,
457+
ll2cr_block,
458+
target_geo_def, y_slice, x_slice,
459+
(input_name, in_row_idx, in_col_idx), fill_value, kwargs)
460+
else:
461+
output_stack[key] = placeholder
339462
x_start = x_end
340463
y_start = y_end
341464
return output_stack
342465

343466
def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value, **kwargs):
344467
ll2cr_result = self.cache['ll2cr_result']
345-
ll2cr_blocks = self.cache['ll2cr_blocks'].items()
346-
ll2cr_numblocks = ll2cr_result.shape if isinstance(ll2cr_result, np.ndarray) else ll2cr_result.numblocks
468+
ll2cr_blocks = self.cache['ll2cr_blocks']
469+
ll2cr_block_dependencies = self.cache.get('ll2cr_block_dependencies')
470+
if not ll2cr_blocks:
471+
return da.full(target_geo_def.shape, fill_value, dtype=data.dtype,
472+
chunks=out_chunks)
347473
fornav_task_name = f"fornav-{data.name}-{ll2cr_result.name}"
348474
maximum_weight_mode = kwargs.setdefault('maximum_weight_mode', False)
349475
weight_sum_min = kwargs.setdefault('weight_sum_min', -1.0)
@@ -357,8 +483,12 @@ def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value, **kwa
357483

358484
dsk_graph = HighLevelGraph.from_collections(fornav_task_name,
359485
output_stack,
360-
dependencies=[data, ll2cr_result])
361-
stack_chunks = ((1,) * (ll2cr_numblocks[0] * ll2cr_numblocks[1]),) + out_chunks
486+
dependencies=(
487+
(data, ll2cr_result)
488+
if ll2cr_block_dependencies is None
489+
else (data, *ll2cr_block_dependencies)
490+
))
491+
stack_chunks = ((1,) * len(ll2cr_blocks),) + out_chunks
362492
out_stack = da.Array(dsk_graph, fornav_task_name, stack_chunks, data.dtype)
363493
combine_fornav_with_kwargs = partial(
364494
_combine_fornav, maximum_weight_mode=maximum_weight_mode)

0 commit comments

Comments
 (0)