@@ -71,6 +71,79 @@ def _call_mapped_ll2cr(lons, lats, target_geo_def):
7171 return res
7272
7373
74+ def _ll2cr_block_extent (ll2cr_block ):
75+ """Compute row/column bounds for a single ll2cr block.
76+
77+ Args:
78+ ll2cr_block: ll2cr output block as ``(cols, rows)`` arrays, or the
79+ empty sentinel returned by ``_call_ll2cr``.
80+
81+ Returns:
82+ ``(row_min, row_max, col_min, col_max)`` as floats, or ``None`` when
83+ the block contains no valid finite coordinates.
84+ """
85+ # Empty ll2cr results: ((shape, fill, dtype), (shape, fill, dtype))
86+ if isinstance (ll2cr_block [0 ], tuple ):
87+ return None
88+
89+ cols = np .asarray (ll2cr_block [0 ])
90+ rows = np .asarray (ll2cr_block [1 ])
91+ valid = np .isfinite (cols ) & np .isfinite (rows )
92+ if not np .any (valid ):
93+ return None
94+
95+ valid_rows = rows [valid ]
96+ valid_cols = cols [valid ]
97+ row_min = float (valid_rows .min ())
98+ row_max = float (valid_rows .max ())
99+ col_min = float (valid_cols .min ())
100+ col_max = float (valid_cols .max ())
101+ return row_min , row_max , col_min , col_max
102+
103+
104+ def _pad_bounds (bounds , overlap_margin ):
105+ """Pad ll2cr bounds by a constant overlap margin.
106+
107+ Args:
108+ bounds: ll2cr bounds tuple ``(row_min, row_max, col_min, col_max)``,
109+ or ``None``.
110+ overlap_margin: Non-negative overlap margin in grid cells.
111+
112+ Returns:
113+ Padded bounds tuple, or ``None`` when input bounds is ``None``.
114+ """
115+ if bounds is None :
116+ return None
117+ row_min , row_max , col_min , col_max = bounds
118+ return (
119+ row_min - overlap_margin ,
120+ row_max + overlap_margin ,
121+ col_min - overlap_margin ,
122+ col_max + overlap_margin ,
123+ )
124+
125+
126+ def _chunk_intersects_bounds (bounds , y_slice , x_slice ):
127+ """Check whether a target chunk overlaps pre-padded ll2cr bounds.
128+
129+ Args:
130+ bounds: ll2cr bounds tuple ``(row_min, row_max, col_min, col_max)``,
131+ already padded for overlap, or ``None``.
132+ y_slice: Output chunk row slice.
133+ x_slice: Output chunk column slice.
134+
135+ Returns:
136+ ``True`` if the chunk intersects the bounds.
137+ """
138+ if bounds is None :
139+ return True
140+ row_min , row_max , col_min , col_max = bounds
141+ return (
142+ y_slice .stop > row_min and y_slice .start <= row_max and
143+ x_slice .stop > col_min and x_slice .start <= col_max
144+ )
145+
146+
74147def _delayed_fornav (ll2cr_result , target_geo_def , y_slice , x_slice , data , fill_value , kwargs ):
75148 # Adjust cols and rows for this sub-area
76149 subdef = target_geo_def [y_slice , x_slice ]
@@ -107,6 +180,21 @@ def _chunk_callable(x_chunk, axis, keepdims, **kwargs):
107180 return x_chunk
108181
109182
183+ def _sum_arrays (arrays ):
184+ """Sum arrays with one initial copy and in-place accumulation.
185+
186+ Args:
187+ arrays: Non-empty sequence of NumPy arrays with compatible shapes.
188+
189+ Returns:
190+ Element-wise sum as a NumPy array.
191+ """
192+ total = arrays [0 ].copy ()
193+ for arr in arrays [1 :]:
194+ total += arr
195+ return total
196+
197+
110198def _combine_fornav (x_chunk , axis , keepdims , computing_meta = False ,
111199 maximum_weight_mode = False ):
112200 if computing_meta or _is_empty_chunk (x_chunk ):
@@ -126,6 +214,8 @@ def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False,
126214 # split step - return "empty" chunk placeholder
127215 return x_chunk [0 ]
128216 return np .full (* x_chunk [0 ][0 ]), np .full (* x_chunk [0 ][1 ])
217+ if len (valid_chunks ) == 1 :
218+ return valid_chunks [0 ]
129219 weights = [x [0 ] for x in valid_chunks ]
130220 accums = [x [1 ] for x in valid_chunks ]
131221 if maximum_weight_mode :
@@ -135,9 +225,7 @@ def _combine_fornav(x_chunk, axis, keepdims, computing_meta=False,
135225 weights = np .take_along_axis (weights , max_indexes , axis = 0 ).squeeze (axis = 0 )
136226 accums = np .take_along_axis (accums , max_indexes , axis = 0 ).squeeze (axis = 0 )
137227 return weights , accums
138- # NOTE: We use the builtin "sum" function below because it does not copy
139- # the numpy arrays. Using numpy.sum would do that.
140- return sum (weights ), sum (accums )
228+ return _sum_arrays (weights ), _sum_arrays (accums )
141229
142230
143231def _is_empty_chunk (x_chunk ):
@@ -224,60 +312,79 @@ def _get_rows_per_scan(self, rows_per_scan=None):
224312 rows_per_scan = self .source_geo_def .shape [0 ]
225313 return rows_per_scan
226314
227- def _fill_block_cache_with_ll2cr_results (self , ll2cr_result ,
228- num_row_blocks ,
229- num_col_blocks ,
230- persist ):
315+ def _ll2cr_cache_matches (self , rows_per_scan , persist ):
316+ return (
317+ self .cache .get ('rows_per_scan' ) == rows_per_scan and
318+ self .cache .get ('persist' ) == persist
319+ )
320+
    def _get_ll2cr_blocks(self, ll2cr_result, persist):
        """Build block descriptors for each ll2cr output block.

        Args:
            ll2cr_result: Dask array returned by ``_call_mapped_ll2cr``.
            persist: If True, persist every block on the scheduler/workers and
                compute per-block row/column extents so empty blocks can be
                dropped before fornav task generation.

        Returns:
            Tuple ``(ll2cr_blocks, block_dependencies)``. ``ll2cr_blocks`` is
            a list of ``(in_row_idx, in_col_idx, dask_key, extent)`` tuples
            where ``extent`` is ``None`` when unknown. ``block_dependencies``
            is the list of persisted delayed objects backing those keys, or
            ``None`` when ``persist`` is False.
        """
        ll2cr_blocks = []
        block_dependencies = None
        if persist:
            ll2cr_delayeds = ll2cr_result.to_delayed()
            # Flatten the 2D delayed grid while remembering block indices.
            flat_delayeds = [
                (in_row_idx, in_col_idx, delayed_block)
                for in_row_idx, delayed_row in enumerate(ll2cr_delayeds)
                for in_col_idx, delayed_block in enumerate(delayed_row)
            ]
            block_dependencies = []
            persisted_delayeds = dask.persist(
                *(delayed for _, _, delayed in flat_delayeds))
            # Compute only per-block extents on workers to avoid materializing
            # full ll2cr blocks in the client process.
            extent_delayeds = [dask.delayed(_ll2cr_block_extent)(d) for d in persisted_delayeds]
            computed_extents = dask.compute(*extent_delayeds)
            for (in_row_idx, in_col_idx, _), persisted_delayed, extent in zip(
                    flat_delayeds, persisted_delayeds, computed_extents, strict=True):
                if extent is None:
                    # Block has no valid finite coordinates; skip it entirely.
                    continue
                ll2cr_blocks.append((in_row_idx, in_col_idx, persisted_delayed.key, extent))
                block_dependencies.append(persisted_delayed)
        else:
            # Without persisting we cannot know extents; keep every block with
            # a None extent (treated as "always intersects" downstream).
            num_row_blocks, num_col_blocks = ll2cr_result.numblocks[-2:]
            for in_row_idx in range(num_row_blocks):
                for in_col_idx in range(num_col_blocks):
                    ll2cr_blocks.append((
                        in_row_idx,
                        in_col_idx,
                        (ll2cr_result.name, in_row_idx, in_col_idx),
                        None,
                    ))
        return ll2cr_blocks, block_dependencies
249355
    def precompute(self, cache_dir=None, rows_per_scan=None, persist=False,
                   **kwargs):
        """Generate row and column arrays and store it for later use.

        Args:
            cache_dir: Unused by EWA resampling; a warning is logged if set.
            rows_per_scan: Rows per instrument scan; resolved via
                ``_get_rows_per_scan`` when not provided.
            persist: Whether to persist ll2cr blocks on the cluster and
                precompute their extents.
            **kwargs: Extra options; a provided ``mask`` is ignored with a
                warning.

        Returns:
            ``None``. Results are stored in ``self.cache``.
        """
        rows_per_scan = self._get_rows_per_scan(rows_per_scan)
        if self._ll2cr_cache_matches(rows_per_scan, persist):
            # this resampler should be used for one SwathDefinition
            # no need to recompute matching ll2cr output again
            return None

        if kwargs.get('mask') is not None:
            logger.warning("'mask' parameter has no affect during EWA "
                           "resampling")

        if cache_dir:
            logger.warning("'cache_dir' is not used by EWA resampling")

        new_chunks = self._new_chunks(self.source_geo_def.lons, rows_per_scan)
        lons, lats = self.source_geo_def.get_lonlats(chunks=new_chunks)
        # run ll2cr to get column/row indexes
        # if chunk does not overlap target area then None is returned
        # otherwise a 3D array (2, y, x) of cols, rows are returned
        ll2cr_result = _call_mapped_ll2cr(lons, lats, self.target_geo_def)
        ll2cr_blocks, block_dependencies = self._get_ll2cr_blocks(
            ll2cr_result, persist)

        # save the dask arrays in the class instance cache
        self.cache = {
            'll2cr_result': ll2cr_result,
            'll2cr_blocks': ll2cr_blocks,
            'll2cr_block_dependencies': block_dependencies,
            # Remember the settings so a later precompute() with the same
            # arguments can reuse this cache.
            'rows_per_scan': rows_per_scan,
            'persist': persist,
        }
        return None
283390
def _generate_fornav_dask_tasks(out_chunks, ll2cr_blocks, task_name,
                                input_name, target_geo_def, fill_value, kwargs):
    """Build the raw dask graph mapping every ll2cr block onto every output chunk.

    Args:
        out_chunks: Output dask chunk structure ``(row_chunk_sizes, col_chunk_sizes)``.
        ll2cr_blocks: Sequence of ``(in_row_idx, in_col_idx, dask_key, extent)``
            tuples describing ll2cr blocks.
        task_name: Name prefix for the generated graph keys.
        input_name: Dask array name of the input data.
        target_geo_def: Target area definition, sliceable by row/col.
        fill_value: Fill value passed through to ``_delayed_fornav``.
        kwargs: fornav keyword arguments; ``weight_delta_max`` and
            ``weight_distance_max`` also determine the overlap margin.

    Returns:
        Dict mapping ``(task_name, z, out_row, out_col)`` keys to either a
        ``_delayed_fornav`` task tuple or an "empty" placeholder spec.
    """
    y_start = 0
    output_stack = {}
    # Pad block extents by the largest weighting radius so target pixels near
    # a chunk edge still receive contributions from neighboring input blocks.
    overlap_margin = max(
        float(kwargs.get("weight_delta_max", 0.0)),
        float(kwargs.get("weight_distance_max", 0.0)),
        0.0,
    )
    indexed_blocks = []
    for z_idx, (in_row_idx, in_col_idx, ll2cr_block, block_extent) in enumerate(ll2cr_blocks):
        block_bounds = _pad_bounds(block_extent, overlap_margin)
        indexed_blocks.append((z_idx, in_row_idx, in_col_idx, ll2cr_block, block_bounds))
    for out_row_idx in range(len(out_chunks[0])):
        y_end = y_start + out_chunks[0][out_row_idx]
        x_start = 0
        for out_col_idx in range(len(out_chunks[1])):
            x_end = x_start + out_chunks[1][out_col_idx]
            y_slice = slice(y_start, y_end)
            x_slice = slice(x_start, x_end)
            # "Empty" spec understood downstream: ((shape, fill, dtype), ...)
            placeholder = (
                ((y_end - y_start, x_end - x_start), 0, np.float32),
                ((y_end - y_start, x_end - x_start), 0, np.float32),
            )
            for z_idx, in_row_idx, in_col_idx, ll2cr_block, block_bounds in indexed_blocks:
                key = (task_name, z_idx, out_row_idx, out_col_idx)
                if _chunk_intersects_bounds(block_bounds, y_slice, x_slice):
                    output_stack[key] = (_delayed_fornav,
                                         ll2cr_block,
                                         target_geo_def, y_slice, x_slice,
                                         (input_name, in_row_idx, in_col_idx), fill_value, kwargs)
                else:
                    # Block cannot touch this chunk; emit a cheap placeholder
                    # instead of scheduling a fornav task.
                    output_stack[key] = placeholder
            x_start = x_end
        y_start = y_end
    return output_stack
342465
343466 def _run_fornav_single (self , data , out_chunks , target_geo_def , fill_value , ** kwargs ):
344467 ll2cr_result = self .cache ['ll2cr_result' ]
345- ll2cr_blocks = self .cache ['ll2cr_blocks' ].items ()
346- ll2cr_numblocks = ll2cr_result .shape if isinstance (ll2cr_result , np .ndarray ) else ll2cr_result .numblocks
468+ ll2cr_blocks = self .cache ['ll2cr_blocks' ]
469+ ll2cr_block_dependencies = self .cache .get ('ll2cr_block_dependencies' )
470+ if not ll2cr_blocks :
471+ return da .full (target_geo_def .shape , fill_value , dtype = data .dtype ,
472+ chunks = out_chunks )
347473 fornav_task_name = f"fornav-{ data .name } -{ ll2cr_result .name } "
348474 maximum_weight_mode = kwargs .setdefault ('maximum_weight_mode' , False )
349475 weight_sum_min = kwargs .setdefault ('weight_sum_min' , - 1.0 )
@@ -357,8 +483,12 @@ def _run_fornav_single(self, data, out_chunks, target_geo_def, fill_value, **kwa
357483
358484 dsk_graph = HighLevelGraph .from_collections (fornav_task_name ,
359485 output_stack ,
360- dependencies = [data , ll2cr_result ])
361- stack_chunks = ((1 ,) * (ll2cr_numblocks [0 ] * ll2cr_numblocks [1 ]),) + out_chunks
486+ dependencies = (
487+ (data , ll2cr_result )
488+ if ll2cr_block_dependencies is None
489+ else (data , * ll2cr_block_dependencies )
490+ ))
491+ stack_chunks = ((1 ,) * len (ll2cr_blocks ),) + out_chunks
362492 out_stack = da .Array (dsk_graph , fornav_task_name , stack_chunks , data .dtype )
363493 combine_fornav_with_kwargs = partial (
364494 _combine_fornav , maximum_weight_mode = maximum_weight_mode )
0 commit comments