diff --git a/xrspatial/focal.py b/xrspatial/focal.py
index f5ee07e8..9e38ce9a 100644
--- a/xrspatial/focal.py
+++ b/xrspatial/focal.py
@@ -25,8 +25,8 @@ class cupy(object):
         ndarray = False
 
 
-from xrspatial.convolution import (_available_memory_bytes, _convolve_2d_cupy, _convolve_2d_numpy,
-                                   _promote_float, convolve_2d, custom_kernel)
+from xrspatial.convolution import (_available_memory_bytes, _promote_float, convolve_2d,
+                                   custom_kernel)
 from xrspatial.dataset_support import supports_dataset
 from xrspatial.utils import (ArrayTypeFunctionMapping, _boundary_to_dask, _pad_array,
                              _validate_boundary, _validate_raster, _validate_scalar, cuda_args,
@@ -1346,12 +1346,12 @@ def _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
     # large numbers and loses precision in float32 for big rasters / weights.
     weight_sum = weight_sum.astype(np.float64)
     sq_weight_sum = sq_weight_sum.astype(np.float64)
-    numerator = weighted_sum.astype(np.float64) - float(global_mean) * weight_sum
+    numerator = weighted_sum.astype(np.float64) - global_mean * weight_sum
     variance_term = (n * sq_weight_sum - weight_sum * weight_sum) / (n - 1)
     # Guard against tiny negatives from float rounding and the degenerate
     # single-cell neighborhood (variance_term == 0) before the sqrt.
     variance_term = np.where(variance_term > 0, variance_term, np.nan)
-    denominator = float(global_std) * np.sqrt(variance_term)
+    denominator = global_std * np.sqrt(variance_term)
     z = numerator / denominator
     return np.where(np.isfinite(z), z, 0.0).astype(np.float32)
 
@@ -1395,99 +1395,68 @@ def _hotspots_numpy(raster, kernel, boundary='nan'):
 
 
 def _hotspots_dask_numpy(raster, kernel, boundary='nan'):
-    data = raster.data
-    if not np.issubdtype(data.dtype, np.floating):
-        data = data.astype(np.float32)
-
-    # Pass 1: eagerly compute global Gi* terms (three scalars).
-    # This reads all chunks once, produces a few bytes, then frees all
-    # intermediate state -- no barrier that would force materialization
-    # of the full convolution output.
-    valid = ~da.isnan(data)
-    global_mean, global_std, n = da.compute(
-        da.nanmean(data), da.nanstd(data), valid.sum())
-    global_mean = np.float32(global_mean)
-    global_std = np.float32(global_std)
-    n = int(n)
-    _gistar_global_stats(global_mean, global_std, n)
+    # Match the numpy path: compute in float32 so the convolution and the
+    # float32 map_overlap meta agree regardless of input dtype.
+    data = raster.data.astype(np.float32)
 
-    kernel = kernel.astype(np.float32)
-    pad_h = kernel.shape[0] // 2
-    pad_w = kernel.shape[1] // 2
+    # Global Gi* terms stay lazy: 0-d dask arrays that broadcast into the
+    # z-score below. Nothing is computed during graph construction.
+    valid = ~da.isnan(data)
+    global_mean = da.nanmean(data)
+    global_std = da.nanstd(data)
+    n = valid.sum()
+
+    # Per-cell Gi* convolution terms via convolve_2d's lazy dask path,
+    # mirroring _gistar_convolutions_numpy on the single-array backend so
+    # the dask result matches numpy. NaN cells are excluded via the
+    # validity mask.
+    valid_f = valid.astype(np.float32)
+    filled = da.where(valid, data, np.float32(0.0))
+    weighted_sum = convolve_2d(filled, kernel, boundary)
+    weight_sum = convolve_2d(valid_f, kernel, boundary)
+    sq_weight_sum = convolve_2d(valid_f, kernel * kernel, boundary)
 
-    # Pass 2: fuse the three convolutions + Gi* z-score + classification
-    # into one map_overlap call. Each chunk reads source + halo, produces
-    # int8 output, and frees all intermediates immediately.
-    _func = partial(
-        _hotspots_chunk,
-        kernel=kernel,
-        global_mean=global_mean,
-        global_std=global_std,
-        n=n,
-    )
-    out = data.map_overlap(
-        _func,
-        depth=(pad_h, pad_w),
-        boundary=_boundary_to_dask(boundary),
-        meta=np.array((), dtype=np.int8),
-    )
+    # Gi* z-score via broadcast of the lazy 0-d global terms, then classify
+    # per block.
+    z_array = _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
+                             global_mean, global_std, n)
+    out = z_array.map_blocks(_calc_hotspots_numpy,
+                             meta=np.array((), dtype=np.int8))
     return out
 
 
-def _hotspots_chunk(chunk, kernel, global_mean, global_std, n):
-    """Fused per-chunk: convolve Gi* terms -> z-score -> classify."""
-    valid = (~np.isnan(chunk)).astype(np.float32)
-    filled = np.where(valid > 0, chunk, np.float32(0.0))
-    weighted_sum = _convolve_2d_numpy(filled, kernel)
-    weight_sum = _convolve_2d_numpy(valid, kernel)
-    sq_weight_sum = _convolve_2d_numpy(valid, kernel * kernel)
-    z = _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
-                       global_mean, global_std, n)
-    return _calc_hotspots_numpy(z)
+def _calc_hotspots_cupy(z):
+    """Per-chunk GPU classification of a z-score array."""
+    out = cupy.zeros_like(z, dtype=cupy.int8)
+    griddim, blockdim = cuda_args(z.shape)
+    _run_gpu_hotspots[griddim, blockdim](z, out)
+    return out
 
 
 def _hotspots_dask_cupy(raster, kernel, boundary='nan'):
-    data = raster.data
-    if not cupy.issubdtype(data.dtype, cupy.floating):
-        data = data.astype(cupy.float32)
-
-    # Pass 1: global Gi* terms (three scalars, eager)
-    valid_global = ~da.isnan(data)
-    global_mean, global_std, n = da.compute(
-        da.nanmean(data), da.nanstd(data), valid_global.sum())
-    global_mean = np.float32(float(global_mean))
-    global_std = np.float32(float(global_std))
-    n = int(n)
-    _gistar_global_stats(global_mean, global_std, n)
+    # Match the numpy path: compute in float32 so the convolution and the
+    # float32 map_overlap meta agree regardless of input dtype.
+    data = raster.data.astype(cupy.float32)
 
-    kernel = kernel.astype(np.float32)
-    sq_kernel = kernel * kernel
-    pad_h = kernel.shape[0] // 2
-    pad_w = kernel.shape[1] // 2
+    # Global Gi* terms stay lazy: 0-d dask arrays that broadcast into the
+    # z-score below. Nothing is computed during graph construction.
+    valid = ~da.isnan(data)
+    global_mean = da.nanmean(data)
+    global_std = da.nanstd(data)
+    n = valid.sum()
+
+    # Per-cell Gi* convolution terms via convolve_2d's lazy dask+cupy path;
+    # each chunk stays on the device (no host round trip).
+    valid_f = valid.astype(cupy.float32)
+    filled = da.where(valid, data, cupy.float32(0.0))
+    weighted_sum = convolve_2d(filled, kernel, boundary)
+    weight_sum = convolve_2d(valid_f, kernel, boundary)
+    sq_weight_sum = convolve_2d(valid_f, kernel * kernel, boundary)
 
-    # Pass 2: fuse the three convolutions + Gi* z-score + classification,
-    # all on the GPU. Reuse the _run_gpu_hotspots kernel (same as the
-    # single-GPU path) so each chunk stays on the device -- no host round
-    # trip per chunk.
-    def _chunk_fn(chunk):
-        valid = (~cupy.isnan(chunk)).astype(cupy.float32)
-        filled = cupy.where(valid > 0, chunk, cupy.float32(0.0))
-        weighted_sum = _convolve_2d_cupy(filled, kernel)
-        weight_sum = _convolve_2d_cupy(valid, kernel)
-        sq_weight_sum = _convolve_2d_cupy(valid, sq_kernel)
-        z = _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
-                           global_mean, global_std, n)
-        out = cupy.zeros_like(z, dtype=cupy.int8)
-        griddim, blockdim = cuda_args(z.shape)
-        _run_gpu_hotspots[griddim, blockdim](z, out)
-        return out
-
-    out = data.map_overlap(
-        _chunk_fn,
-        depth=(pad_h, pad_w),
-        boundary=_boundary_to_dask(boundary, is_cupy=True),
-        meta=cupy.array((), dtype=cupy.int8),
-    )
+    z_array = _gistar_zscore(weighted_sum, weight_sum, sq_weight_sum,
+                             global_mean, global_std, n)
+    out = z_array.map_blocks(_calc_hotspots_cupy,
+                             meta=cupy.array((), dtype=cupy.int8))
     return out
 
 
diff --git a/xrspatial/tests/test_dask_laziness.py b/xrspatial/tests/test_dask_laziness.py
index da4216f8..5dff79db 100644
--- a/xrspatial/tests/test_dask_laziness.py
+++ b/xrspatial/tests/test_dask_laziness.py
@@ -16,6 +16,7 @@
 try:
     import dask
     import dask.array as da
+    import dask.callbacks  # noqa: F401 (registers dask.callbacks.Callback)
 except ImportError:
     da = None
 
@@ -32,6 +33,22 @@ def _is_lazy(result):
     return False
 
 
+# Fall back to ``object`` when dask is missing: the import guard above leaves
+# ``dask`` unbound, and an unconditional base would raise NameError on import.
+class _TaskCounter(dask.callbacks.Callback if da is not None else object):
+    """Count dask tasks executed while the context is active.
+
+    Used to assert that a function builds its graph lazily (zero tasks run
+    during the call) rather than eagerly triggering computation.
+    """
+
+    def __init__(self):
+        self.count = 0
+
+    def _posttask(self, key, result, dsk, state, worker_id):
+        self.count += 1
+
+
 # ---------------------------------------------------------------------------
 # Fixtures
 # ---------------------------------------------------------------------------
@@ -117,6 +134,21 @@ def test_hotspots(self, elev):
         kernel = np.ones((3, 3), dtype=np.float32) / 9
         assert _is_lazy(hotspots(elev, kernel))
 
+    def test_hotspots_no_eager_compute(self, elev):
+        # hotspots() must build the graph without running any dask tasks.
+        # Previously it eagerly computed global mean/std, executing ~12
+        # tasks on the call itself (issue #2772).
+        from xrspatial.focal import hotspots
+        kernel = np.ones((3, 3), dtype=np.float32) / 9
+        with _TaskCounter() as counter:
+            result = hotspots(elev, kernel)
+        assert counter.count == 0, (
+            f"hotspots() ran {counter.count} dask tasks on call; expected 0 "
+            f"(should stay lazy)"
+        )
+        # The graph must still compute to a real result.
+        assert result.compute().shape == elev.shape
+
 
 # ---------------------------------------------------------------------------
 # Classification (partially lazy -- returns dask after computing small stats)